TensorFlow Addons Callbacks: TQDM Progress Bar

Overview

This notebook demonstrates how to use the TQDMProgressBar callback in TensorFlow Addons.

Setup

!pip install -U tensorflow-addons
!pip install -q "tqdm>=4.36.1"

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
import tqdm

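# quietly deep-reload tqdm, discarding anything the reload prints
import contextlib
import os
from IPython.lib import deepreload

# redirect stdout to the null device so no stray file is left open
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    deepreload.reload(tqdm)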

tqdm.__version__
'4.64.0'

Import and Normalize Data

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data: scale pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
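
As a small illustration of the normalization step, the sketch below uses a synthetic uint8 array in place of the MNIST images. The array `fake_images` is a stand-in introduced for this example, not part of the original notebook. It shows that dividing by 255.0 yields floating-point values in [0, 1].

```python
import numpy as np

# Hypothetical stand-in for MNIST pixels: uint8 values in [0, 255].
fake_images = np.array([[0, 128, 255]], dtype=np.uint8)

# Dividing by a Python float promotes the array to float64 and rescales it.
scaled = fake_images / 255.0

print(scaled.dtype)                # float64
print(scaled.min(), scaled.max())  # 0.0 1.0
```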

Build Simple MNIST Dense Model

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
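
To make the size of this model concrete, the following sketch counts its trainable parameters by hand. The helper `dense_params` is a hypothetical name introduced for illustration. It assumes each Dense layer holds one weight per input-output pair plus one bias per output, and that Flatten and Dropout add no parameters.

```python
def dense_params(n_inputs: int, n_outputs: int) -> int:
    """Weights (n_inputs * n_outputs) plus one bias per output unit."""
    return n_inputs * n_outputs + n_outputs

flattened = 28 * 28                    # Flatten turns each 28x28 image into 784 values
hidden = dense_params(flattened, 128)  # Dense(128, activation='relu')
output = dense_params(128, 10)         # Dense(10, activation='softmax')
total = hidden + output

print(hidden, output, total)  # 100480 1290 101770
```

Calling model.summary() on the model built above should report the same total.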

Default TQDMProgressBar Usage

# initialize tqdm callback with default parameters
tqdm_callback = tfa.callbacks.TQDMProgressBar()

# train the model with tqdm_callback
# make sure to set verbose = 0 to disable
# the default progress bar.
model.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          verbose=0,
          callbacks=[tqdm_callback],
          validation_data=(x_test, y_test))
Training:   0%|           0/10 ETA: ?s,  ?epochs/s
Epoch 1/10
0/938           ETA: ?s -
Epoch 2/10
0/938           ETA: ?s -
Epoch 3/10
0/938           ETA: ?s -
Epoch 4/10
0/938           ETA: ?s -
Epoch 5/10
0/938           ETA: ?s -
Epoch 6/10
0/938           ETA: ?s -
Epoch 7/10
0/938           ETA: ?s -
Epoch 8/10
0/938           ETA: ?s -
Epoch 9/10
0/938           ETA: ?s -
Epoch 10/10
0/938           ETA: ?s -
<keras.callbacks.History at 0x7f49161ea7f0>

Below is the expected output when you run the cell above.
[Figure: TQDM Progress Bar]

# TQDMProgressBar() also works with evaluate()
model.evaluate(x_test, y_test, batch_size=64, callbacks=[tqdm_callback], verbose=0)
0/157           ETA: ?s - Evaluating
[0.07728561758995056, 0.9760000109672546]

Below is the expected output when you run the cell above.
[Figure: TQDM Evaluate Progress Bar]
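
The totals 938 and 157 in the progress output above are the number of batches per pass. A minimal sketch of that arithmetic follows, assuming the standard MNIST split of 60,000 training and 10,000 test images and that a final partial batch counts as one step. `steps_per_epoch` is a hypothetical helper, not a TensorFlow Addons function.

```python
import math

def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of batches needed to cover num_samples, rounding up."""
    return math.ceil(num_samples / batch_size)

print(steps_per_epoch(60_000, 64))  # 938 training steps per epoch
print(steps_per_epoch(10_000, 64))  # 157 evaluation steps
```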