TensorFlow Addons Callbacks: TimeStopping

Overview

This notebook demonstrates how to use the TimeStopping callback in TensorFlow Addons, which stops training once a specified amount of time has elapsed.

Setup

!pip install -q --no-deps tensorflow-addons~=0.6
import tensorflow as tf
import tensorflow_addons as tfa

import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten

Import and Normalize Data

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train, x_test = x_train / 255.0, x_test / 255.0

Build Simple MNIST Model

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
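The network is a small fully connected classifier. `Flatten` turns each 28×28 image into 784 inputs, which feed a 128-unit hidden layer and then a 10-way softmax output. As a sanity check, the trainable parameter count can be worked out by hand. This assumes each `Dense` layer holds an `n_in × n_units` weight matrix plus one bias per unit. `dense_params` is a helper made up for this illustration. The total should agree with what `model.summary()` reports for this model.

```python
# Hypothetical helper: a Dense layer has one weight per (input, unit) pair
# plus one bias per unit.
def dense_params(n_in, n_units):
    return n_in * n_units + n_units


flatten_out = 28 * 28                    # Flatten only reshapes; no parameters
hidden = dense_params(flatten_out, 128)  # 784 * 128 + 128 = 100480
output = dense_params(128, 10)           # 128 * 10 + 10 = 1290
total = hidden + output                  # Dropout adds no parameters either
print(total)  # 101770
```

Almost all of the weights sit in the first `Dense` layer, which is why each epoch in the log below takes only a second or two on 60,000 samples.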

Simple TimeStopping Usage

# initialize the TimeStopping callback with a 5-second training budget
time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)

# train the model with the TimeStopping callback; epochs=100 is only an
# upper bound, since the callback ends training once the time budget
# has been used up.
model.fit(x_train, y_train,
          batch_size=64,
          epochs=100,
          callbacks=[time_stopping_callback],
          validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/100
60000/60000 [==============================] - 2s 28us/sample - loss: 0.3357 - accuracy: 0.9033 - val_loss: 0.1606 - val_accuracy: 0.9533
Epoch 2/100
60000/60000 [==============================] - 1s 23us/sample - loss: 0.1606 - accuracy: 0.9525 - val_loss: 0.1104 - val_accuracy: 0.9669
Epoch 3/100
60000/60000 [==============================] - 1s 24us/sample - loss: 0.1185 - accuracy: 0.9645 - val_loss: 0.0949 - val_accuracy: 0.9704
Epoch 4/100
60000/60000 [==============================] - 1s 25us/sample - loss: 0.0954 - accuracy: 0.9713 - val_loss: 0.0854 - val_accuracy: 0.9740
Timed stopping at epoch 4 after training for 0:00:05

<tensorflow.python.keras.callbacks.History at 0x110af0ef0>
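As the log above shows, training ended after 4 of the 100 requested epochs. The rule behind this can be sketched in plain Python. The sketch assumes that TimeStopping:

- fixes a deadline when training begins,
- compares the clock against that deadline at the end of each epoch, and
- sets the model's `stop_training` flag once the deadline has passed.

`TimeBudgetStopper`, `FakeModel`, `train`, and the simulated clock are hypothetical names for this illustration. Only the hook names (`on_train_begin`, `on_epoch_end`, `on_train_end`) and the `stop_training` flag follow the Keras callback interface. This is not the TensorFlow Addons source, and real Keras hooks also receive a `logs` argument that is omitted here.

```python
import datetime
import time


class TimeBudgetStopper:
    """Hypothetical stand-in for a time-budget callback (not the tfa source).

    Uses the same hook names as a Keras callback but needs no TensorFlow.
    """

    def __init__(self, seconds, verbose=0, clock=time.monotonic):
        self.seconds = seconds
        self.verbose = verbose
        self.clock = clock          # injectable so the rule can be simulated
        self.stopped_epoch = None
        self.model = None           # attached by the training loop

    def on_train_begin(self):
        # Fix the deadline once, when training starts.
        self.deadline = self.clock() + self.seconds

    def on_epoch_end(self, epoch):
        # Checked only between epochs, so a started epoch always finishes.
        if self.clock() >= self.deadline:
            self.model.stop_training = True
            self.stopped_epoch = epoch

    def on_train_end(self):
        # Report the configured budget (not the measured time), 1-based epoch.
        if self.stopped_epoch is not None and self.verbose > 0:
            budget = datetime.timedelta(seconds=self.seconds)
            print(f"Timed stopping at epoch {self.stopped_epoch + 1} "
                  f"after training for {budget}")


class FakeModel:
    """Hypothetical model placeholder exposing only a stop_training flag."""
    stop_training = False


def train(model, callback, epochs, run_epoch):
    """Hypothetical loop calling the hooks in roughly the order fit() does."""
    callback.model = model
    callback.on_train_begin()
    completed = 0
    for epoch in range(epochs):
        run_epoch()
        completed += 1
        callback.on_epoch_end(epoch)
        if model.stop_training:
            break
    callback.on_train_end()
    return completed


# Simulated clock: every epoch "takes" 1.5 seconds; the budget is 5 seconds.
now = [0.0]


def fake_epoch():
    now[0] += 1.5


stopper = TimeBudgetStopper(seconds=5, verbose=1, clock=lambda: now[0])
fake_model = FakeModel()
epochs_run = train(fake_model, stopper, epochs=100, run_epoch=fake_epoch)
# prints: Timed stopping at epoch 4 after training for 0:00:05
```

Under this assumption, the budget is checked only at epoch boundaries. An epoch that starts before the deadline therefore still runs to the end, so the simulated run uses 6 seconds against a 5-second budget. With long epochs, the overrun can approach one full epoch. Formatting the budget with `datetime.timedelta` renders 5 seconds as `0:00:05`, which matches the log line above.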