Video classification with a 3D convolutional neural network

This tutorial demonstrates training a 3D convolutional neural network (CNN) for video classification using the UCF101 action recognition dataset. A 3D CNN uses a three-dimensional filter to perform convolutions. The kernel can slide in three dimensions, whereas in a 2D CNN it slides in two. The model is based on the work published in A Closer Look at Spatiotemporal Convolutions for Action Recognition by D. Tran et al. (2017). In this tutorial, you will:

  • Build an input pipeline
  • Build a 3D convolutional neural network model with residual connections using the Keras functional API
  • Train the model
  • Evaluate and test the model

This video classification tutorial is the second part in a series of TensorFlow video tutorials. Here are the other three tutorials:

Setup

Begin by installing and importing some necessary libraries, including: remotezip to inspect the contents of a ZIP file, tqdm to use a progress bar, OpenCV to process video files, einops for performing more complex tensor operations, and seaborn and Matplotlib for visualization.

pip install remotezip tqdm opencv-python einops
pip install -U tensorflow keras
import tqdm
import random
import pathlib
import itertools
import collections

import cv2
import einops
import numpy as np
import remotezip as rz
import seaborn as sns
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras import layers

Load and preprocess video data

The hidden cell below defines helper functions to download a slice of data from the UCF-101 dataset, and load it into a tf.data.Dataset. You can learn more about the specific preprocessing steps in the Loading video data tutorial, which walks you through this code in more detail.

The FrameGenerator class at the end of the hidden block is the most important utility here. It creates an iterable object that can feed data into the TensorFlow data pipeline. Specifically, this class contains a Python generator that loads the video frames along with their encoded label. The generator (__call__) function yields the frame array produced by frames_from_video_file and the integer class ID of the label associated with the set of frames. Integer labels are what the SparseCategoricalCrossentropy loss used during training expects.
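As a rough illustration of the generator pattern just described, the sketch below uses a hypothetical ToyFrameGenerator that fills each clip with random pixels instead of calling frames_from_video_file. It is not the tutorial's FrameGenerator; it only mirrors the behavior stated above: __call__ optionally shuffles the videos for training and yields a (frames, label) pair with an integer label.

```python
import random

import numpy as np


class ToyFrameGenerator:
    """Hypothetical stand-in for FrameGenerator that yields (frames, label) pairs."""

    def __init__(self, class_names, videos_per_class, n_frames, height, width,
                 training=False):
        # Map each class name to an integer ID, as a sparse loss expects.
        self.class_ids_for_name = {name: i for i, name in enumerate(sorted(class_names))}
        # One fake "video" per (class name, index) pair.
        self.videos = [(name, i) for name in class_names for i in range(videos_per_class)]
        self.frame_shape = (n_frames, height, width, 3)
        self.training = training

    def __call__(self):
        videos = list(self.videos)
        if self.training:
            random.shuffle(videos)  # new video order on every pass
        for name, _ in videos:
            # Random values in [0, 1) stand in for decoded video frames.
            frames = np.random.rand(*self.frame_shape).astype(np.float32)
            yield frames, self.class_ids_for_name[name]


gen = ToyFrameGenerator(["class_a", "class_b"], videos_per_class=3,
                        n_frames=10, height=32, width=32, training=True)
pairs = list(gen())
```

Passing the object itself (gen, not gen()) to tf.data.Dataset.from_generator lets tf.data call it again each time a new iteration starts, which is why the pattern implements __call__.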

URL = 'https://storage.googleapis.com/thumos14_files/UCF101_videos.zip'
download_dir = pathlib.Path('./UCF101_subset/')
subset_paths = download_ufc_101_subset(URL, 
                        num_classes = 10, 
                        splits = {"train": 30, "val": 10, "test": 10},
                        download_dir = download_dir)
train :
100%|██████████| 300/300 [00:35<00:00,  8.39it/s]
val :
100%|██████████| 100/100 [00:13<00:00,  7.58it/s]
test :
100%|██████████| 100/100 [00:11<00:00,  8.72it/s]

Create the training, validation, and test sets (train_ds, val_ds, and test_ds).
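The cell that builds these datasets is not shown here. The sketch below is one way to build such a split with tf.data.Dataset.from_generator, under stated assumptions: toy_train_generator is hypothetical and yields blank 16 x 16 clips, and the batch size of 8 is inferred from the training log (300 training videos in 38 steps per epoch) rather than stated in the text.

```python
import numpy as np
import tensorflow as tf

# Assumed toy sizes: 10 frames per clip as in the model below, but tiny frames.
N_FRAMES, TOY_HEIGHT, TOY_WIDTH = 10, 16, 16
NUM_TRAIN_VIDEOS, NUM_CLASSES = 300, 10
BATCH_SIZE = 8  # inferred: ceil(300 / 8) = 38 steps per epoch, as in the log


def toy_train_generator():
    """Hypothetical generator yielding (frames, integer class ID) pairs."""
    for i in range(NUM_TRAIN_VIDEOS):
        frames = np.zeros((N_FRAMES, TOY_HEIGHT, TOY_WIDTH, 3), dtype=np.float32)
        yield frames, i % NUM_CLASSES


output_signature = (
    tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),  # (time, height, width, RGB)
    tf.TensorSpec(shape=(), dtype=tf.int16),                       # scalar class ID
)

train_ds = tf.data.Dataset.from_generator(toy_train_generator,
                                          output_signature=output_signature)
train_ds = train_ds.batch(BATCH_SIZE)

batches = list(train_ds)
```

Because from_generator cannot know in advance how many elements the generator will produce, the dataset's cardinality is unknown. This is consistent with the 37/Unknown progress display during the first training epoch in the log.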

Create the model

The following 3D convolutional neural network model is based on the paper A Closer Look at Spatiotemporal Convolutions for Action Recognition by D. Tran et al. (2017). The paper compares several versions of 3D ResNets. Instead of operating on a single image with dimensions (height, width), like standard ResNets, these operate on a video volume with dimensions (time, height, width). The most obvious approach to this problem would be to replace each 2D convolution (layers.Conv2D) with a 3D convolution (layers.Conv3D).

This tutorial uses a (2 + 1)D convolution with residual connections. The (2 + 1)D convolution decomposes the operation into two separate steps, one over the spatial dimensions and one over the temporal dimension. An advantage of factorizing the convolution this way is that it saves parameters.

For each output location, a 3D convolution combines all the vectors from a 3D patch of the volume to create one vector in the output volume.

[Figure: 3D convolutions]
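To make the patch-to-vector description concrete, the following NumPy sketch computes a 3D convolution with stride 1 and no padding. The function name conv3d_valid is hypothetical, and like layers.Conv3D it computes a cross-correlation (the kernel is not flipped).

```python
import numpy as np


def conv3d_valid(volume, kernel):
    """Naive 3D convolution (cross-correlation), stride 1, no padding.

    volume: array of shape (T, H, W, C_in)
    kernel: array of shape (kt, kh, kw, C_in, C_out)
    returns: array of shape (T - kt + 1, H - kh + 1, W - kw + 1, C_out)
    """
    T, H, W, _ = volume.shape
    kt, kh, kw, _, c_out = kernel.shape
    out = np.zeros((T - kt + 1, H - kh + 1, W - kw + 1, c_out), dtype=volume.dtype)
    # The kernel slides along time, height, and width.
    for t in range(out.shape[0]):
        for i in range(out.shape[1]):
            for j in range(out.shape[2]):
                patch = volume[t:t + kt, i:i + kh, j:j + kw, :]  # (kt, kh, kw, C_in)
                # Combine every vector in the 3D patch into one C_out-vector.
                out[t, i, j] = np.tensordot(patch, kernel, axes=4)
    return out


volume = np.ones((4, 5, 5, 2), dtype=np.float32)
kernel = np.ones((3, 3, 3, 2, 6), dtype=np.float32)
out = conv3d_valid(volume, kernel)
```

With all-ones inputs and weights, every entry of out equals 3 * 3 * 3 * 2 = 54, the number of input values in one patch.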

At each output location, a 3D convolution whose kernel spans (time, height, width) positions takes time * height * width * channels inputs and produces channels outputs (assuming the number of input and output channels is the same). So a 3D convolution layer with a kernel size of (3 x 3 x 3) needs a weight matrix with 27 * channels ** 2 entries. The reference paper found that a more effective and efficient approach was to factorize the convolution. Instead of a single 3D convolution to process the time and space dimensions, they proposed a "(2+1)D" convolution, which processes the space and time dimensions separately. The figure below shows the factored spatial and temporal convolutions of a (2 + 1)D convolution.

[Figure: (2+1)D convolutions]

The main advantage of this approach is that it reduces the number of parameters. In the (2 + 1)D convolution, the spatial convolution uses a kernel of shape (1, height, width), while the temporal convolution uses a kernel of shape (time, 1, 1). For example, a (2 + 1)D convolution with kernel size (3 x 3 x 3) needs weight matrices of size (9 * channels**2) + (3 * channels**2), less than half as many as the full 3D convolution. This tutorial implements a (2 + 1)D ResNet modeled on ResNet18, where each convolution in the ResNet is replaced by a (2+1)D convolution; to keep the model small, it uses one residual block per stage rather than ResNet18's two.
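The parameter arithmetic above can be checked with a few lines of Python. This sketch counts kernel weights only (biases are ignored) and assumes, as in the Conv2Plus1D layer below, that the spatial step outputs the same number of channels that the temporal step consumes and produces.

```python
def conv3d_weight_count(kt, kh, kw, channels):
    """Weights in one 3D convolution with `channels` input and output channels (no bias)."""
    return kt * kh * kw * channels ** 2


def conv2plus1d_weight_count(kt, kh, kw, channels):
    """Weights in a (2 + 1)D convolution whose intermediate layer also has `channels` channels."""
    spatial = 1 * kh * kw * channels ** 2   # kernel of shape (1, kh, kw)
    temporal = kt * 1 * 1 * channels ** 2   # kernel of shape (kt, 1, 1)
    return spatial + temporal


for c in (16, 64, 128):
    full = conv3d_weight_count(3, 3, 3, c)
    factored = conv2plus1d_weight_count(3, 3, 3, c)
    print(f"channels={c}: 3D={full}, (2+1)D={factored}, ratio={factored / full:.3f}")
```

For a 3 x 3 x 3 kernel the ratio is always 12/27 ≈ 0.44, regardless of the channel count.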

# Define the dimensions of one frame in the set of frames created
HEIGHT = 224
WIDTH = 224
class Conv2Plus1D(keras.layers.Layer):
  def __init__(self, filters, kernel_size, padding):
    """
      A sequence of convolutional layers that first apply the convolution operation over the
      spatial dimensions, and then the temporal dimension. 
    """
    super().__init__()
    self.seq = keras.Sequential([  
        # Spatial decomposition
        layers.Conv3D(filters=filters,
                      kernel_size=(1, kernel_size[1], kernel_size[2]),
                      padding=padding),
        # Temporal decomposition
        layers.Conv3D(filters=filters, 
                      kernel_size=(kernel_size[0], 1, 1),
                      padding=padding)
        ])

  def call(self, x):
    return self.seq(x)

A ResNet model is made from a sequence of residual blocks. A residual block has two branches. The main branch performs the calculation, but gradients have difficulty flowing through it. The residual branch bypasses the main calculation and mostly just adds the input to the output of the main branch. Gradients flow easily through this branch. Therefore, there is always an easy path from the loss function to the main branch of any residual block, which mitigates the vanishing gradient problem.
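A scalar sketch of the gradient argument, using a hypothetical main branch f(x) = weight * tanh(x) whose derivative is at most weight. Through plain blocks, the chain rule multiplies the small derivatives f'(x) together; through residual blocks, each factor becomes 1 + f'(x) because the identity branch has derivative 1.

```python
import math


def chain_gradient(x, depth, weight, residual):
    """d(output)/d(input) through `depth` blocks built on f(x) = weight * tanh(x).

    Plain blocks compute x -> f(x); residual blocks compute x -> x + f(x).
    The chain rule multiplies the per-block derivatives.
    """
    grad = 1.0
    for _ in range(depth):
        local = weight * (1.0 - math.tanh(x) ** 2)  # f'(x) at the current input
        if residual:
            grad *= 1.0 + local  # identity branch contributes a derivative of 1
            x = x + weight * math.tanh(x)
        else:
            grad *= local
            x = weight * math.tanh(x)
    return grad


plain = chain_gradient(0.5, depth=20, weight=0.1, residual=False)
res = chain_gradient(0.5, depth=20, weight=0.1, residual=True)
print(plain, res)
```

With weight=0.1 and 20 blocks, the plain gradient falls below 0.1 ** 20, while the residual gradient stays at least 1.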

Create the main branch of the residual block with the following class. In contrast to the standard ResNet structure this uses the custom Conv2Plus1D layer instead of layers.Conv2D.

class ResidualMain(keras.layers.Layer):
  """
    Residual block of the model with convolution, layer normalization, and the
    activation function, ReLU.
  """
  def __init__(self, filters, kernel_size):
    super().__init__()
    self.seq = keras.Sequential([
        Conv2Plus1D(filters=filters,
                    kernel_size=kernel_size,
                    padding='same'),
        layers.LayerNormalization(),
        layers.ReLU(),
        Conv2Plus1D(filters=filters, 
                    kernel_size=kernel_size,
                    padding='same'),
        layers.LayerNormalization()
    ])

  def call(self, x):
    return self.seq(x)

To add the residual branch to the main branch, the two must have the same shape. The Project layer below handles cases where the main branch changes the number of channels: it applies a densely-connected (layers.Dense) layer followed by layer normalization to the residual branch.

class Project(keras.layers.Layer):
  """
    Project certain dimensions of the tensor as the data is passed through different 
    sized filters and downsampled. 
  """
  def __init__(self, units):
    super().__init__()
    self.seq = keras.Sequential([
        layers.Dense(units),
        layers.LayerNormalization()
    ])

  def call(self, x):
    return self.seq(x)
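Because layers.Dense applied to a rank-5 tensor contracts only the last (channel) axis, Project acts like a 1 x 1 x 1 convolution: it remixes channels independently at every (time, height, width) location. The NumPy sketch below illustrates this with hypothetical random weights and omits the LayerNormalization step.

```python
import numpy as np

rng = np.random.default_rng(0)

# A video feature map: (batch, time, height, width, channels_in).
x = rng.normal(size=(2, 4, 7, 7, 16)).astype(np.float32)

# Hypothetical Dense weights projecting 16 channels to 32 channels.
kernel = rng.normal(size=(16, 32)).astype(np.float32)
bias = np.zeros(32, dtype=np.float32)

# Like Dense on a rank-5 input: contract only the last (channel) axis.
projected = np.tensordot(x, kernel, axes=([4], [0])) + bias

# At any single (b, t, h, w) location this is an ordinary matrix-vector
# product, which is what a 1 x 1 x 1 convolution computes there.
manual = x[1, 2, 3, 4] @ kernel + bias
```

Only the channel count changes; time, height, and width are untouched, so the projected residual can be added to the main branch's output.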

Use add_residual_block to introduce a skip connection between the layers of the model.

def add_residual_block(input, filters, kernel_size):
  """
    Add a residual block to the model. If the last dimension of the input data
    and the number of filters do not match, project the input so that the last
    dimensions match.
  """
  out = ResidualMain(filters, 
                     kernel_size)(input)

  res = input
  # Using the Keras functional APIs, project the last dimension of the tensor to
  # match the new filter size
  if out.shape[-1] != input.shape[-1]:
    res = Project(out.shape[-1])(res)

  return layers.add([res, out])

Resizing the video is how this model downsamples the data. In particular, downsampling the video frames allows the model to examine specific parts of frames to detect patterns that may be specific to a certain action. Through downsampling, non-essential information can be discarded. Moreover, resizing the video reduces the dimensionality of the data and therefore speeds up processing through the model.

class ResizeVideo(keras.layers.Layer):
  def __init__(self, height, width):
    super().__init__()
    self.height = height
    self.width = width
    self.resizing_layer = layers.Resizing(self.height, self.width)

  def call(self, video):
    """
      Use the einops library to fold time into the batch axis, resize every
      frame, and then restore the original (batch, time) layout.

      Args:
        video: Tensor representation of the video, in the form of a set of frames.

      Returns:
        The video, resized (downsampled) to the new height and width.
    """
    # b stands for batch size, t stands for time, h stands for height, 
    # w stands for width, and c stands for the number of channels.
    old_shape = einops.parse_shape(video, 'b t h w c')
    images = einops.rearrange(video, 'b t h w c -> (b t) h w c')
    images = self.resizing_layer(images)
    videos = einops.rearrange(
        images, '(b t) h w c -> b t h w c',
        t = old_shape['t'])
    return videos
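Since the einops calls in ResizeVideo only merge and split axes, they correspond to plain reshapes. This NumPy sketch with arbitrary toy sizes shows the round trip; a simple stride-2 subsampling stands in for layers.Resizing, which interpolates (bilinearly by default) rather than dropping pixels.

```python
import numpy as np

b, t, h, w, c = 2, 3, 4, 6, 3
video = np.arange(b * t * h * w * c, dtype=np.float32).reshape(b, t, h, w, c)

# 'b t h w c -> (b t) h w c': merge batch and time so every frame becomes one image.
# The merged index is b_index * t + t_index (batch varies slowest).
images = video.reshape(b * t, h, w, c)

# Stand-in for layers.Resizing: keep every second row and column.
small = images[:, ::2, ::2, :]

# '(b t) h w c -> b t h w c' with t supplied: split the merged axis back apart.
resized_video = small.reshape(b, t, small.shape[1], small.shape[2], c)
```

Folding time into the batch axis lets a 2D image operation such as layers.Resizing process every frame of every video in one call.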

Use the Keras functional API to build the residual network.

input_shape = (None, 10, HEIGHT, WIDTH, 3)
input = layers.Input(shape=(input_shape[1:]))
x = input

x = Conv2Plus1D(filters=16, kernel_size=(3, 7, 7), padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = ResizeVideo(HEIGHT // 2, WIDTH // 2)(x)

# Block 1
x = add_residual_block(x, 16, (3, 3, 3))
x = ResizeVideo(HEIGHT // 4, WIDTH // 4)(x)

# Block 2
x = add_residual_block(x, 32, (3, 3, 3))
x = ResizeVideo(HEIGHT // 8, WIDTH // 8)(x)

# Block 3
x = add_residual_block(x, 64, (3, 3, 3))
x = ResizeVideo(HEIGHT // 16, WIDTH // 16)(x)

# Block 4
x = add_residual_block(x, 128, (3, 3, 3))

x = layers.GlobalAveragePooling3D()(x)
x = layers.Flatten()(x)
x = layers.Dense(10)(x)

model = keras.Model(input, x)
frames, label = next(iter(train_ds))
model.build(frames)
# Visualize the model
keras.utils.plot_model(model, expand_nested=True, dpi=60, show_shapes=True)

[Figure: model architecture diagram generated by keras.utils.plot_model]
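As a hand check of the architecture, this plain-Python sketch traces the tensor shape (excluding the batch axis) after each stage of the network defined above. It relies on two properties of that code: every Conv2Plus1D uses padding='same' with stride 1, so only ResizeVideo changes height and width, and the time axis stays at 10 frames throughout.

```python
HEIGHT, WIDTH, N_FRAMES = 224, 224, 10

# (stage, output channels, spatial divisor relative to the input after the stage)
stages = [
    ("stem Conv2Plus1D + resize", 16, 2),
    ("block 1 + resize", 16, 4),
    ("block 2 + resize", 32, 8),
    ("block 3 + resize", 64, 16),
    ("block 4", 128, 16),
]

shapes = []
for name, channels, divisor in stages:
    shape = (N_FRAMES, HEIGHT // divisor, WIDTH // divisor, channels)
    shapes.append(shape)
    print(f"{name:26} -> {shape}")

# GlobalAveragePooling3D averages over time, height and width, keeping channels.
pooled_units = shapes[-1][-1]
print(f"{'GlobalAveragePooling3D':26} -> ({pooled_units},)")
print(f"{'Dense(10)':26} -> (10,)")
```

The final Dense layer maps the 128 pooled features to one logit per class.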

Train the model

For this tutorial, choose the tf.keras.optimizers.Adam optimizer and the tf.keras.losses.SparseCategoricalCrossentropy loss function. Because the labels are integer class IDs and the final Dense layer outputs unnormalized logits, set from_logits=True. Use the metrics argument to view the accuracy of the model at every step.

model.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
              optimizer = keras.optimizers.Adam(learning_rate = 0.0001), 
              metrics = ['accuracy'])
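To show what SparseCategoricalCrossentropy(from_logits=True) computes, here is a NumPy sketch of the same formula (the function name is hypothetical): for each example, the loss is logsumexp(logits) minus the logit of the true class, averaged over the batch.

```python
import numpy as np


def sparse_categorical_crossentropy_from_logits(labels, logits):
    """Mean cross-entropy for integer labels and unnormalized logits."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    # Subtract the per-row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_example = -log_softmax[np.arange(len(labels)), labels]
    return per_example.mean()


# A model that outputs equal logits for all 10 classes.
uniform = np.zeros((4, 10))
print(sparse_categorical_crossentropy_from_logits([0, 3, 7, 9], uniform))
```

Equal logits over 10 classes give ln(10) ≈ 2.303, close to the loss values of roughly 2.3 to 2.6 reported at the start of training in the log.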

Train the model for 50 epochs with the Keras Model.fit method.

history = model.fit(x = train_ds,
                    epochs = 50, 
                    validation_data = val_ds)
Epoch 1/50
37/Unknown 70s 1s/step - accuracy: 0.0925 - loss: 2.5544
38/Unknown 88s 1s/step - accuracy: 0.0928 - loss: 2.5518
38/38 ━━━━━━━━━━━━━━━━━━━━ 102s 2s/step - accuracy: 0.0931 - loss: 2.5492 - val_accuracy: 0.1500 - val_loss: 2.3808
Epoch 2/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 997ms/step - accuracy: 0.1363 - loss: 2.2845
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.1369 - loss: 2.2830 - val_accuracy: 0.1300 - val_loss: 2.4066
Epoch 3/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 997ms/step - accuracy: 0.2140 - loss: 2.2506
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.2141 - loss: 2.2492 - val_accuracy: 0.2300 - val_loss: 2.4361
Epoch 4/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.2208 - loss: 2.1176
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.2211 - loss: 2.1157 - val_accuracy: 0.1600 - val_loss: 2.3365
Epoch 5/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3215 - loss: 1.8784
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.3210 - loss: 1.8794 - val_accuracy: 0.2700 - val_loss: 1.9216
Epoch 6/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3251 - loss: 1.8603
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.3253 - loss: 1.8602 - val_accuracy: 0.2600 - val_loss: 1.9960
Epoch 7/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.2495 - loss: 1.9548
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.2513 - loss: 1.9509 - val_accuracy: 0.1800 - val_loss: 2.3701
Epoch 8/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3715 - loss: 1.8363
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.3709 - loss: 1.8347 - val_accuracy: 0.3000 - val_loss: 2.1785
Epoch 9/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3189 - loss: 1.7731
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.3192 - loss: 1.7733 - val_accuracy: 0.2300 - val_loss: 1.9931
Epoch 10/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4153 - loss: 1.6298
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.4150 - loss: 1.6299 - val_accuracy: 0.3200 - val_loss: 1.7989
Epoch 11/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4934 - loss: 1.5774
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.4933 - loss: 1.5779 - val_accuracy: 0.3700 - val_loss: 1.8226
Epoch 12/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4012 - loss: 1.5900
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.4005 - loss: 1.5913 - val_accuracy: 0.3100 - val_loss: 1.8618
Epoch 13/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.3941 - loss: 1.6043
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.3949 - loss: 1.6036 - val_accuracy: 0.4100 - val_loss: 1.7070
Epoch 14/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 997ms/step - accuracy: 0.4531 - loss: 1.4518
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.4532 - loss: 1.4517 - val_accuracy: 0.4500 - val_loss: 1.4516
Epoch 15/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4660 - loss: 1.3619
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.4664 - loss: 1.3629 - val_accuracy: 0.3700 - val_loss: 1.7023
Epoch 16/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5312 - loss: 1.3279
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.5304 - loss: 1.3284 - val_accuracy: 0.4900 - val_loss: 1.5107
Epoch 17/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.4497 - loss: 1.2881
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.4510 - loss: 1.2877 - val_accuracy: 0.3800 - val_loss: 1.7984
Epoch 18/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5500 - loss: 1.2364
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.5501 - loss: 1.2358 - val_accuracy: 0.4600 - val_loss: 1.5495
Epoch 19/50
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5596 - loss: 1.1740
2024-05-01 01:41:04.819809: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:41:04.819851: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.5594 - loss: 1.1749 - val_accuracy: 0.5800 - val_loss: 1.2670
Epoch 20/50
2024-05-01 01:41:15.465625: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:41:15.465670: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5987 - loss: 1.1106
2024-05-01 01:41:54.198335: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:41:54.198381: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.5985 - loss: 1.1117 - val_accuracy: 0.3400 - val_loss: 2.2354
Epoch 21/50
2024-05-01 01:42:04.762144: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:42:04.762188: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6045 - loss: 1.1643
2024-05-01 01:42:43.206237: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:42:43.206282: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6031 - loss: 1.1657 - val_accuracy: 0.4700 - val_loss: 1.5135
Epoch 22/50
2024-05-01 01:42:53.906904: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:42:53.906955: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 998ms/step - accuracy: 0.6409 - loss: 1.0018
2024-05-01 01:43:32.004066: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:43:32.004113: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6396 - loss: 1.0041 - val_accuracy: 0.4500 - val_loss: 1.5741
Epoch 23/50
2024-05-01 01:43:42.633894: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:43:42.633944: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5682 - loss: 1.1207
2024-05-01 01:44:20.962551: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:44:20.962597: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.5685 - loss: 1.1212 - val_accuracy: 0.5300 - val_loss: 1.2757
Epoch 24/50
2024-05-01 01:44:31.614889: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:44:31.614936: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1000ms/step - accuracy: 0.6101 - loss: 1.1038
2024-05-01 01:45:09.790791: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:45:09.790834: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6100 - loss: 1.1037 - val_accuracy: 0.4800 - val_loss: 1.4255
Epoch 25/50
2024-05-01 01:45:20.174585: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:45:20.174631: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6454 - loss: 0.9844
2024-05-01 01:45:58.413810: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:45:58.413855: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6447 - loss: 0.9850 - val_accuracy: 0.6100 - val_loss: 1.1376
Epoch 26/50
2024-05-01 01:46:08.848147: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:46:08.848192: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7205 - loss: 0.9094
2024-05-01 01:46:47.418722: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:46:47.418778: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7195 - loss: 0.9091 - val_accuracy: 0.5600 - val_loss: 1.1900
Epoch 27/50
2024-05-01 01:46:57.750561: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:46:57.750626: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.5917 - loss: 0.9600
2024-05-01 01:47:36.069689: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:47:36.069734: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.5927 - loss: 0.9599 - val_accuracy: 0.5800 - val_loss: 1.3921
Epoch 28/50
2024-05-01 01:47:46.426764: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:47:46.426825: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6385 - loss: 0.9558
2024-05-01 01:48:24.688769: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:48:24.688815: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6386 - loss: 0.9555 - val_accuracy: 0.5700 - val_loss: 1.1771
Epoch 29/50
2024-05-01 01:48:35.205715: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:48:35.205756: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6613 - loss: 0.9125
2024-05-01 01:49:13.881434: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:49:13.881492: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6605 - loss: 0.9139 - val_accuracy: 0.5900 - val_loss: 1.1538
Epoch 30/50
2024-05-01 01:49:24.385718: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:49:24.385782: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6138 - loss: 0.9618
2024-05-01 01:50:02.737167: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:50:02.737225: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6147 - loss: 0.9599 - val_accuracy: 0.5900 - val_loss: 1.1690
Epoch 31/50
2024-05-01 01:50:13.225610: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:50:13.225665: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6770 - loss: 0.8324
2024-05-01 01:50:51.610035: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:50:51.610079: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.6771 - loss: 0.8325 - val_accuracy: 0.6000 - val_loss: 1.0798
Epoch 32/50
2024-05-01 01:51:01.933931: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:51:01.933998: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7442 - loss: 0.7864
2024-05-01 01:51:40.269265: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:51:40.269319: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7439 - loss: 0.7868 - val_accuracy: 0.5300 - val_loss: 1.3041
Epoch 33/50
2024-05-01 01:51:50.855654: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:51:50.855699: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7000 - loss: 0.8440
2024-05-01 01:52:29.346011: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:52:29.346063: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7002 - loss: 0.8446 - val_accuracy: 0.5900 - val_loss: 1.0299
Epoch 34/50
2024-05-01 01:52:39.590605: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:52:39.590648: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 994ms/step - accuracy: 0.6939 - loss: 0.8452
2024-05-01 01:53:17.561607: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:53:17.561668: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.6939 - loss: 0.8452 - val_accuracy: 0.4900 - val_loss: 1.3111
Epoch 35/50
2024-05-01 01:53:27.975488: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:53:27.975533: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7193 - loss: 0.8416
2024-05-01 01:54:06.554832: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:54:06.554873: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7188 - loss: 0.8415 - val_accuracy: 0.6000 - val_loss: 1.1907
Epoch 36/50
2024-05-01 01:54:17.081471: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:54:17.081523: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7316 - loss: 0.7743
2024-05-01 01:54:55.505768: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:54:55.505809: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7311 - loss: 0.7740 - val_accuracy: 0.6300 - val_loss: 1.0519
Epoch 37/50
2024-05-01 01:55:06.141968: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:55:06.142026: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.6769 - loss: 0.8639
2024-05-01 01:55:44.341043: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:55:44.341085: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.6772 - loss: 0.8645 - val_accuracy: 0.4500 - val_loss: 1.4352
Epoch 38/50
2024-05-01 01:55:54.641586: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:55:54.641647: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7684 - loss: 0.7403
2024-05-01 01:56:32.931421: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:56:32.931462: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7687 - loss: 0.7394 - val_accuracy: 0.5700 - val_loss: 1.0731
Epoch 39/50
2024-05-01 01:56:43.321244: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:56:43.321308: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7513 - loss: 0.6741
2024-05-01 01:57:21.977342: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:57:21.977385: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7519 - loss: 0.6750 - val_accuracy: 0.6200 - val_loss: 1.1427
Epoch 40/50
2024-05-01 01:57:32.426980: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:57:32.427024: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7980 - loss: 0.6434
2024-05-01 01:58:10.789218: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:58:10.789265: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7966 - loss: 0.6455 - val_accuracy: 0.7100 - val_loss: 0.9882
Epoch 41/50
2024-05-01 01:58:21.393892: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:58:21.393937: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - accuracy: 0.7739 - loss: 0.6498
2024-05-01 01:58:59.493428: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:58:59.493468: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7736 - loss: 0.6500 - val_accuracy: 0.6400 - val_loss: 1.0812
Epoch 42/50
2024-05-01 01:59:10.000232: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:59:10.000273: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 996ms/step - accuracy: 0.7882 - loss: 0.6113
2024-05-01 01:59:47.954600: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:59:47.954645: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.7876 - loss: 0.6132 - val_accuracy: 0.6500 - val_loss: 0.9893
Epoch 43/50
2024-05-01 01:59:58.264155: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 01:59:58.264208: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 989ms/step - accuracy: 0.7566 - loss: 0.7420
2024-05-01 02:00:36.047747: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:00:36.047794: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.7562 - loss: 0.7414 - val_accuracy: 0.6400 - val_loss: 0.9363
Epoch 44/50
2024-05-01 02:00:46.694517: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:00:46.694561: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 991ms/step - accuracy: 0.8485 - loss: 0.5736
2024-05-01 02:01:24.638701: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:01:24.638744: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.8476 - loss: 0.5748 - val_accuracy: 0.6500 - val_loss: 1.0424
Epoch 45/50
2024-05-01 02:01:35.172700: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:01:35.172746: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 999ms/step - accuracy: 0.7621 - loss: 0.6177
2024-05-01 02:02:13.523955: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:02:13.523997: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.7629 - loss: 0.6167 - val_accuracy: 0.6500 - val_loss: 0.9271
Epoch 46/50
2024-05-01 02:02:23.865125: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:02:23.865184: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 992ms/step - accuracy: 0.7935 - loss: 0.5451
2024-05-01 02:03:01.694907: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:03:01.694956: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.7934 - loss: 0.5460 - val_accuracy: 0.6700 - val_loss: 0.9752
Epoch 47/50
2024-05-01 02:03:11.908832: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:03:11.908875: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 998ms/step - accuracy: 0.8130 - loss: 0.5944
2024-05-01 02:03:49.848870: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:03:49.848917: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.8132 - loss: 0.5937 - val_accuracy: 0.6700 - val_loss: 0.9059
Epoch 48/50
2024-05-01 02:04:00.126769: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:04:00.126860: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 993ms/step - accuracy: 0.8228 - loss: 0.5414
2024-05-01 02:04:37.949920: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:04:37.949995: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.8226 - loss: 0.5420 - val_accuracy: 0.6800 - val_loss: 0.9071
Epoch 49/50
2024-05-01 02:04:48.384197: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:04:48.384267: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 995ms/step - accuracy: 0.8797 - loss: 0.5166
2024-05-01 02:05:26.346920: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:05:26.346961: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.8795 - loss: 0.5163 - val_accuracy: 0.6600 - val_loss: 1.0068
Epoch 50/50
2024-05-01 02:05:36.808373: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
2024-05-01 02:05:36.808415: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
     [[{ {node IteratorGetNext} }]]
     [[IteratorGetNext/_2]]
38/38 ━━━━━━━━━━━━━━━━━━━━ 0s 991ms/step - accuracy: 0.8331 - loss: 0.5567
38/38 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - accuracy: 0.8329 - loss: 0.5575 - val_accuracy: 0.6900 - val_loss: 0.8807

Visualize the results

Create plots of the loss and accuracy on the training and validation sets:

def plot_history(history):
  """
    Plot the training and validation learning curves.

    Args:
      history: The History object returned by Model.fit; its history attribute
        holds the per-epoch metric values.
  """
  fig, (ax1, ax2) = plt.subplots(2)

  fig.set_size_inches(18.5, 10.5)

  # Plot loss
  ax1.set_title('Loss')
  ax1.plot(history.history['loss'], label='Train')
  ax1.plot(history.history['val_loss'], label='Validation')
  ax1.set_ylabel('Loss')

  # Determine upper bound of y-axis
  max_loss = max(history.history['loss'] + history.history['val_loss'])

  ax1.set_ylim([0, np.ceil(max_loss)])
  ax1.set_xlabel('Epoch')
  ax1.legend()

  # Plot accuracy
  ax2.set_title('Accuracy')
  ax2.plot(history.history['accuracy'], label='Train')
  ax2.plot(history.history['val_accuracy'], label='Validation')
  ax2.set_ylabel('Accuracy')
  ax2.set_ylim([0, 1])
  ax2.set_xlabel('Epoch')
  ax2.legend()

  # Keep the two subplot titles and axis labels from overlapping
  fig.tight_layout()
  plt.show()

plot_history(history)

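The curves can be hard to read epoch by epoch. Model.fit returns a History object, and its history attribute is a plain dictionary. Each metric name maps to a list with one value per epoch, so you can also find the best epoch directly. The sketch below uses a small hand-made dictionary in place of the real history.history. The numbers are illustrative, not results from this tutorial, and best_epoch is a hypothetical helper.

```python
import numpy as np

def best_epoch(history_dict, monitor='val_loss', mode='min'):
  """Return the 1-based epoch number and the best value recorded for `monitor`."""
  values = np.asarray(history_dict[monitor])
  # argmin/argmax return the first occurrence if several epochs tie.
  idx = int(np.argmin(values)) if mode == 'min' else int(np.argmax(values))
  return idx + 1, float(values[idx])

# Illustrative stand-in for history.history (not real training results).
example_history = {
    'loss': [1.9, 1.2, 0.8, 0.6],
    'val_loss': [1.8, 1.1, 0.9, 1.0],
    'accuracy': [0.3, 0.6, 0.75, 0.85],
    'val_accuracy': [0.35, 0.6, 0.68, 0.66],
}

epoch, value = best_epoch(example_history, 'val_loss', 'min')
print(f'Lowest val_loss {value:.2f} at epoch {epoch}')        # Lowest val_loss 0.90 at epoch 3
epoch, value = best_epoch(example_history, 'val_accuracy', 'max')
print(f'Highest val_accuracy {value:.2f} at epoch {epoch}')   # Highest val_accuracy 0.68 at epoch 3
```

With the real object, call best_epoch(history.history, 'val_loss'). A persistent gap between training and validation accuracy is a common sign of overfitting. The final epoch above shows one: about 0.83 training accuracy versus 0.69 validation accuracy.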

Evaluate the model

Use Keras Model.evaluate to get the loss and accuracy on the test dataset.

model.evaluate(test_ds, return_dict=True)
13/13 ━━━━━━━━━━━━━━━━━━━━ 10s 782ms/step - accuracy: 0.6400 - loss: 0.9176
{'accuracy': 0.6899999976158142, 'loss': 0.8081037402153015}
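For a single-label, multi-class classifier like this one, accuracy is the fraction of videos whose highest-scoring class matches the true label. The NumPy sketch below illustrates that computation under two assumptions: the model outputs one score per class, and the labels are integer class ids. The scores are made up, and top1_accuracy is a hypothetical helper, not a Keras function.

```python
import numpy as np

def top1_accuracy(scores, labels):
  """Fraction of rows whose highest-scoring class equals the integer label."""
  scores = np.asarray(scores)
  labels = np.asarray(labels)
  predicted = np.argmax(scores, axis=1)  # Predicted class id for each example
  return float(np.mean(predicted == labels))

# Illustrative scores for 4 videos and 3 classes (not model output).
scores = [[2.0, 0.1, -1.0],
          [0.3, 0.2, 1.5],
          [0.0, 3.0, 0.5],
          [1.0, 2.0, 0.0]]
labels = [0, 2, 1, 0]
print(top1_accuracy(scores, labels))  # 0.75
```

Softmax does not change which class has the highest score, so taking the argmax gives the same prediction whether the scores are raw logits or probabilities.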

To visualize model performance further, use a confusion matrix. A confusion matrix shows which classes the model mistakes for one another, which a single accuracy number cannot reveal. To build the confusion matrix for this multi-class classification problem, collect the actual (ground-truth) labels and the model's predicted labels for a dataset.
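The NumPy sketch below shows what a confusion matrix counts. Entry (i, j) is the number of examples whose actual class is i and whose predicted class is j. tf.math.confusion_matrix uses the same layout, with rows for labels and columns for predictions. The function name and the labels are illustrative only.

```python
import numpy as np

def confusion_matrix_np(actual, predicted, num_classes):
  """Rows are actual classes, columns are predicted classes."""
  cm = np.zeros((num_classes, num_classes), dtype=np.int64)
  # np.add.at accumulates correctly even when the same (row, col) pair repeats.
  np.add.at(cm, (np.asarray(actual), np.asarray(predicted)), 1)
  return cm

actual = [0, 0, 1, 2, 2, 2]
predicted = [0, 1, 1, 2, 0, 2]
cm = confusion_matrix_np(actual, predicted, num_classes=3)
print(cm)
# [[1 1 0]
#  [0 1 0]
#  [1 0 2]]
```

The diagonal holds the correct predictions. Every off-diagonal count is one specific kind of mistake, such as class 2 predicted as class 0.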

def get_actual_predicted_labels(dataset):
  """
    Collect the ground truth labels and the model's predicted class ids for a dataset.

    Args:
      dataset: An iterable data structure, such as a TensorFlow Dataset, with features and labels.

    Return:
      Two 1-D tensors: the ground truth class ids and the predicted class ids.
  """
  actual = []
  predicted = []
  # Make a single pass so each prediction stays paired with its own label, even if
  # the dataset yields its elements in a different (shuffled) order on every pass.
  for frames, labels in dataset:
    actual.append(labels)
    predicted.append(model.predict_on_batch(frames))

  actual = tf.concat(actual, axis=0)
  predicted = tf.concat(predicted, axis=0)
  predicted = tf.argmax(predicted, axis=1)

  return actual, predicted

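Iterating the dataset only once matters when the data source reshuffles on every pass, as a generator that calls random.shuffle each time it is called would. The pure-Python sketch below shows the effect. shuffling_source and fake_predict are hypothetical stand-ins for the data pipeline and for a model that is always correct. Collecting labels on one pass and predictions on another pairs them in different orders, while a single pass keeps them aligned.

```python
import random

def shuffling_source(pairs, seed):
  """Hypothetical data source yielding (features, label) in a new order each pass."""
  rng = random.Random(seed)
  pairs = list(pairs)
  rng.shuffle(pairs)
  for features, label in pairs:
    yield features, label

def fake_predict(features):
  """Stand-in for a perfect model: here the 'features' are the label itself."""
  return features

data = [(i, i) for i in range(20)]

# Two passes with different shuffles: labels and predictions come out in different orders.
labels_pass = [label for _, label in shuffling_source(data, seed=1)]
preds_pass = [fake_predict(x) for x, _ in shuffling_source(data, seed=2)]
two_pass_acc = sum(a == p for a, p in zip(labels_pass, preds_pass)) / len(data)

# One pass: each prediction is recorded next to its own label.
actual, predicted = [], []
for x, label in shuffling_source(data, seed=3):
  actual.append(label)
  predicted.append(fake_predict(x))
one_pass_acc = sum(a == p for a, p in zip(actual, predicted)) / len(data)

print(one_pass_acc)  # 1.0
print(two_pass_acc)  # typically far below 1.0, although the "model" is perfect
```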
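def plot_confusion_matrix(actual, predicted, labels, ds_type):
  """
    Plot a confusion matrix as an annotated heatmap.

    Args:
      actual: Ground truth class ids.
      predicted: Predicted class ids.
      labels: Class names, ordered by class id.
      ds_type: Name of the dataset split, used in the plot title.
  """
  cm = tf.math.confusion_matrix(actual, predicted, num_classes=len(labels))
  # Set the style and figure size before drawing so they apply to this plot.
  sns.set(font_scale=1.4)
  plt.figure(figsize=(12, 12))
  ax = sns.heatmap(cm, annot=True, fmt='g', xticklabels=labels, yticklabels=labels)
  ax.set_title('Confusion matrix of action recognition for ' + ds_type)
  ax.set_xlabel('Predicted Action')
  ax.set_ylabel('Actual Action')
  plt.xticks(rotation=90)
  plt.yticks(rotation=0)
  plt.show()
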
fg = FrameGenerator(subset_paths['train'], n_frames, training=True)
labels = list(fg.class_ids_for_name.keys())
actual, predicted = get_actual_predicted_labels(train_ds)
plot_confusion_matrix(actual, predicted, labels, 'training')
38/38 ━━━━━━━━━━━━━━━━━━━━ 35s 867ms/step

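The line labels = list(fg.class_ids_for_name.keys()) assumes that the dictionary's insertion order matches the class ids. A slightly more defensive sketch sorts the names by id, so position k in labels always names class id k. That index is the one used by the confusion matrix rows and columns. The mapping below is a hypothetical stand-in for fg.class_ids_for_name, and labels_in_id_order is an illustrative helper.

```python
def labels_in_id_order(class_ids_for_name):
  """Return class names ordered so that index k holds the name whose id is k."""
  return sorted(class_ids_for_name, key=class_ids_for_name.get)

# Hypothetical mapping whose insertion order differs from id order.
class_ids_for_name = {'Basketball': 2, 'ApplyEyeMakeup': 0, 'Archery': 1}
print(labels_in_id_order(class_ids_for_name))
# ['ApplyEyeMakeup', 'Archery', 'Basketball']
```

In the tutorial, this would read labels = labels_in_id_order(fg.class_ids_for_name).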

actual, predicted = get_actual_predicted_labels(test_ds)
plot_confusion_matrix(actual, predicted, labels, 'test')
13/13 ━━━━━━━━━━━━━━━━━━━━ 11s 797ms/step


The precision and recall values for each class can also be calculated using a confusion matrix.
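For reference, let C be the confusion matrix in the layout tf.math.confusion_matrix returns: row i is actual class i and column j is predicted class j. The per-class counts and metrics are then:

```latex
\begin{aligned}
\mathrm{TP}_i &= C_{ii}, \qquad
\mathrm{FP}_i = \sum_{k} C_{ki} - C_{ii}, \qquad
\mathrm{FN}_i = \sum_{k} C_{ik} - C_{ii},\\
\mathrm{precision}_i &= \frac{\mathrm{TP}_i}{\mathrm{TP}_i + \mathrm{FP}_i} = \frac{C_{ii}}{\sum_{k} C_{ki}}, \qquad
\mathrm{recall}_i = \frac{\mathrm{TP}_i}{\mathrm{TP}_i + \mathrm{FN}_i} = \frac{C_{ii}}{\sum_{k} C_{ik}}.
\end{aligned}
```

Precision divides by a column sum: everything the model predicted as class i. Recall divides by a row sum: everything that actually belongs to class i.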

def calculate_classification_metrics(y_actual, y_pred, labels):
  """
    Calculate the precision and recall of a classification model using the ground truth and
    predicted values.

    Args:
      y_actual: Ground truth labels.
      y_pred: Predicted labels.
      labels: List of classification labels, ordered by class id.

    Return:
      Precision and recall measures.
  """
  cm = tf.math.confusion_matrix(y_actual, y_pred, num_classes=len(labels)).numpy()
  tp = np.diag(cm) # Diagonal represents true positives
  precision = dict()
  recall = dict()
  for i in range(len(labels)):
    col = cm[:, i]
    fp = np.sum(col) - tp[i] # Sum of column minus true positive is false positive

    row = cm[i, :]
    fn = np.sum(row) - tp[i] # Sum of row minus true positive is false negative

    # Report 0.0 instead of dividing by zero for a class that is never predicted
    precision[labels[i]] = tp[i] / (tp[i] + fp) if tp[i] + fp > 0 else 0.0

    # Report 0.0 instead of dividing by zero for a class with no examples
    recall[labels[i]] = tp[i] / (tp[i] + fn) if tp[i] + fn > 0 else 0.0

  return precision, recall

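The per-class loop can also be written with whole-array operations. This NumPy sketch assumes cm is a square confusion matrix with rows as actual classes and columns as predicted classes. precision_recall_from_cm is a hypothetical name. Like the guarded loop, it reports 0.0 rather than dividing by zero when a class is never predicted or has no examples.

```python
import numpy as np

def precision_recall_from_cm(cm, labels):
  """Per-class precision and recall from a (num_classes, num_classes) confusion matrix."""
  cm = np.asarray(cm, dtype=np.float64)
  tp = np.diag(cm)
  predicted_per_class = cm.sum(axis=0)  # Column sums: TP + FP
  actual_per_class = cm.sum(axis=1)     # Row sums: TP + FN
  # Divide only where the denominator is non-zero; leave 0.0 elsewhere.
  precision = np.divide(tp, predicted_per_class,
                        out=np.zeros_like(tp), where=predicted_per_class > 0)
  recall = np.divide(tp, actual_per_class,
                     out=np.zeros_like(tp), where=actual_per_class > 0)
  return (dict(zip(labels, precision.tolist())),
          dict(zip(labels, recall.tolist())))

example_cm = [[1, 1, 0],
              [0, 1, 0],
              [1, 0, 2]]
p, r = precision_recall_from_cm(example_cm, ['a', 'b', 'c'])
print(p)  # {'a': 0.5, 'b': 0.5, 'c': 1.0}
print(r)
```

With the tutorial's data, you could pass tf.math.confusion_matrix(actual, predicted, num_classes=len(labels)).numpy() as cm.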
precision, recall = calculate_classification_metrics(actual, predicted, labels) # Test dataset
precision
{'ApplyEyeMakeup': 0.5,
 'ApplyLipstick': 0.5,
 'Archery': 0.6,
 'BabyCrawling': 0.8,
 'BalanceBeam': 0.75,
 'BandMarching': 0.9,
 'BaseballPitch': 0.625,
 'Basketball': 0.5,
 'BasketballDunk': 0.9,
 'BenchPress': 0.875}
recall
{'ApplyEyeMakeup': 0.2,
 'ApplyLipstick': 0.7,
 'Archery': 0.9,
 'BabyCrawling': 0.4,
 'BalanceBeam': 0.9,
 'BandMarching': 0.9,
 'BaseballPitch': 1.0,
 'Basketball': 0.3,
 'BasketballDunk': 0.9,
 'BenchPress': 0.7}
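To condense the per-class dictionaries into single numbers, you can take their unweighted (macro) average. The sketch below copies the values printed above, and macro_average is an illustrative helper. The recall values are all multiples of 0.1, which is consistent with ten test videos per class. If every class has the same number of examples, macro-averaged recall equals overall accuracy. Here it comes out to 0.69, matching the accuracy returned by model.evaluate.

```python
def macro_average(per_class):
  """Unweighted mean of per-class metric values."""
  return sum(per_class.values()) / len(per_class)

# Per-class test-set values copied from the output above.
test_precision = {'ApplyEyeMakeup': 0.5, 'ApplyLipstick': 0.5, 'Archery': 0.6,
                  'BabyCrawling': 0.8, 'BalanceBeam': 0.75, 'BandMarching': 0.9,
                  'BaseballPitch': 0.625, 'Basketball': 0.5, 'BasketballDunk': 0.9,
                  'BenchPress': 0.875}
test_recall = {'ApplyEyeMakeup': 0.2, 'ApplyLipstick': 0.7, 'Archery': 0.9,
               'BabyCrawling': 0.4, 'BalanceBeam': 0.9, 'BandMarching': 0.9,
               'BaseballPitch': 1.0, 'Basketball': 0.3, 'BasketballDunk': 0.9,
               'BenchPress': 0.7}

print(f'Macro precision: {macro_average(test_precision):.3f}')  # Macro precision: 0.695
print(f'Macro recall:    {macro_average(test_recall):.3f}')     # Macro recall:    0.690
```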

Next steps

To learn more about working with video data in TensorFlow, check out the following tutorials: