TensorFlow Addons Callbacks: TQDM Progress Bar


Overview

This notebook demonstrates how to use the TQDMProgressBar callback in TensorFlow Addons to display Keras training progress with tqdm.

Setup

try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass
!pip install -q --no-deps tensorflow-addons~=0.7
!pip install -q "tqdm>=4.36.1"

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
import tqdm

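# Deep-reload tqdm so the running kernel picks up the freshly installed
# version, sending the reloader's console messages to os.devnull.
import contextlib
import os
from IPython.lib import deepreload

with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    deepreload.reload(tqdm)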

tqdm.__version__
'4.42.1'
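The setup cell pins tqdm>=4.36.1, and the reload above reports 4.42.1. If you want the notebook to fail fast when an older tqdm is still loaded, a minimal stdlib-only check could compare numeric version components as tuples. This is a sketch that assumes plain major.minor.patch version strings; the helpers parse_version and meets_minimum are hypothetical, and for pre-release or local version tags the packaging library's packaging.version.Version is the more robust choice.

```python
def parse_version(version):
    """Convert a 'major.minor.patch' string into a tuple of ints.

    Hypothetical helper: assumes purely numeric, dot-separated components.
    """
    return tuple(int(part) for part in version.split("."))


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum` (tuple comparison)."""
    installed_parts = parse_version(installed)
    minimum_parts = parse_version(minimum)
    # Pad the shorter tuple with zeros so "4.36" compares like "4.36.0".
    width = max(len(installed_parts), len(minimum_parts))
    installed_parts += (0,) * (width - len(installed_parts))
    minimum_parts += (0,) * (width - len(minimum_parts))
    return installed_parts >= minimum_parts


# In the notebook this could be: assert meets_minimum(tqdm.__version__, "4.36.1")
print(meets_minimum("4.42.1", "4.36.1"))  # True
print(meets_minimum("4.28.1", "4.36.1"))  # False
```

Comparing tuples of integers avoids the string-comparison trap where "4.9" would sort after "4.36".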

Import and Normalize Data

# Load the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

Build a Simple MNIST Model

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

Default TQDMProgressBar Usage

# Initialize the TQDMProgressBar callback with default parameters.
tqdm_callback = tfa.callbacks.TQDMProgressBar()

# Train the model with tqdm_callback. Set verbose=0 to disable
# the default Keras progress bar so it does not print alongside
# the tqdm bars.
model.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          verbose=0,
          callbacks=[tqdm_callback],
          validation_data=(x_test, y_test))
[tqdm widget output: one 'Training' bar with max=10, plus one bar with max=60000 per epoch]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10

<tensorflow.python.keras.callbacks.History at 0x7f5eda433080>

Below is the expected output when you run the cell above:

[Figure: TQDM Progress Bar]
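In the original run, tqdm created an overall bar with max=10 (one step per epoch) and, for each epoch, a bar with max=60000, which suggests the per-epoch bar advanced by the number of samples in each batch rather than by batch count. The sketch below illustrates that two-level idea with a hypothetical TextProgressCallback class whose hook names (on_train_begin, on_epoch_begin, on_train_batch_end, on_epoch_end) mirror tf.keras.callbacks.Callback. It is not the TensorFlow Addons implementation: it does not subclass the Keras base class, its hook arguments are simplified (real Keras hooks receive batch or epoch indices plus a logs dict, with totals available via self.params), it renders plain text instead of tqdm widgets, and it is driven by a hypothetical simulate_fit loop so it runs without TensorFlow.

```python
import io
import math


class TextProgressCallback:
    """Hypothetical two-level progress reporter (not the TFA implementation).

    Hook names mirror tf.keras.callbacks.Callback, but the arguments are
    simplified and the class does not subclass it, so no TensorFlow is needed.
    """

    def __init__(self, stream, width=20):
        self.stream = stream
        self.width = width

    def on_train_begin(self, num_epochs, num_samples):
        self.num_epochs = num_epochs
        self.num_samples = num_samples
        self.epochs_done = 0

    def on_epoch_begin(self, epoch):
        # Epoch indices are 0-based; display them 1-based like Keras does.
        self.samples_seen = 0
        self.stream.write(f"Epoch {epoch + 1}/{self.num_epochs}\n")

    def on_train_batch_end(self, batch_size, logs):
        # Advance by samples, not batches, so the bar's total is num_samples.
        self.samples_seen += batch_size
        self._render(self.samples_seen, self.num_samples, logs)

    def on_epoch_end(self, epoch, logs):
        # The overall bar advances by one step per completed epoch.
        self.epochs_done += 1
        self.stream.write("\n")
        self._render(self.epochs_done, self.num_epochs, logs, label="Training ")
        self.stream.write("\n")

    def _render(self, current, total, logs, label=""):
        filled = int(self.width * current / total)
        bar = "#" * filled + "." * (self.width - filled)
        metrics = " - ".join(f"{name}: {value:.4f}" for name, value in logs.items())
        self.stream.write(f"\r{label}[{bar}] {current}/{total} {metrics}")


def simulate_fit(callback, num_samples, batch_size, epochs):
    """Hypothetical stand-in for model.fit that only drives the callback hooks."""
    callback.on_train_begin(epochs, num_samples)
    num_batches = math.ceil(num_samples / batch_size)
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(num_batches):
            # The final batch may be smaller than batch_size.
            size = min(batch_size, num_samples - batch * batch_size)
            callback.on_train_batch_end(size, {"loss": 1.0 / (batch + 1)})
        callback.on_epoch_end(epoch, {"loss": 1.0 / num_batches})
    return num_batches


# Usage: mimic the notebook's 60000 MNIST samples with batch_size=64.
out = io.StringIO()
cb = TextProgressCallback(out)
batches = simulate_fit(cb, num_samples=60000, batch_size=64, epochs=2)
print(batches)          # 938 (937 full batches plus one batch of 32)
print(cb.samples_seen)  # 60000
```

Counting samples rather than batches means the per-epoch bar's total matches the dataset size even when the last batch is partial.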