
TensorFlow Addons Callbacks: TQDM Progress Bar


Overview

This notebook demonstrates how to use the TQDMProgressBar callback from TensorFlow Addons.

Setup

!pip install -q --no-deps tensorflow-addons~=0.7
!pip install -q "tqdm>=4.36.1"

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
import tqdm

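# Quietly deep-reload tqdm so the version installed above is picked up
# without restarting the runtime; deepreload's messages are discarded.
import os
from contextlib import redirect_stdout
from IPython.lib import deepreload

with open(os.devnull, 'w') as devnull, redirect_stdout(devnull):
    deepreload.reload(tqdm)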

tqdm.__version__
'4.42.1'
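The Setup cell requires `tqdm>=4.36.1`. A minimal sketch of checking the reported version against that minimum, assuming plain numeric `X.Y.Z` version strings; `version_tuple` is a hypothetical helper, and `packaging.version.Version` is the more robust choice for versions with pre-release tags. Comparing integer tuples rather than raw strings matters: as strings, `"4.9.0" >= "4.36.1"` is true, which is wrong.

```python
def version_tuple(version):
    """Parse a plain 'X.Y.Z' version string into a tuple of ints.

    Assumes purely numeric components; tags such as '4.36.0rc1' are not handled.
    """
    return tuple(int(part) for part in version.split("."))


installed = "4.42.1"  # the value tqdm.__version__ reported above
required = "4.36.1"   # the minimum from the pip install line in Setup
print(version_tuple(installed) >= version_tuple(required))  # True
```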

Import and Normalize Data

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
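MNIST images are stored as 8-bit grayscale intensities from 0 to 255, so dividing by 255.0 rescales every pixel into the range [0, 1] and converts it to floating point. The cell above does this for whole NumPy arrays at once; the plain-Python sketch below shows the same element-wise operation on a made-up row of pixel values (`normalize` is a hypothetical helper).

```python
def normalize(pixels):
    """Scale 8-bit pixel intensities (0-255) to floats in [0.0, 1.0]."""
    return [p / 255.0 for p in pixels]


row = [0, 51, 128, 255]  # hypothetical pixel values from one image row
scaled = normalize(row)
print(scaled[0], scaled[-1])  # 0.0 1.0
```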

Build Simple MNIST Model

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
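The model has no convolutional layers: it is a small fully connected network (Flatten, Dense, Dropout, Dense). As a sanity check on its size, the trainable parameter count can be worked out by hand. The helper `dense_params` is hypothetical; the total should agree with what `model.summary()` reports for this model.

```python
def dense_params(n_in, n_out):
    """A Dense layer has n_in * n_out weights plus one bias per output unit."""
    return n_in * n_out + n_out


flatten_out = 28 * 28                    # Flatten: 784 features, no weights
hidden = dense_params(flatten_out, 128)  # Dense(128): 100,480 parameters
output = dense_params(128, 10)           # Dense(10): 1,290 parameters
# Dropout(0.2) has no trainable parameters.
total = hidden + output
print(total)  # 101770
```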

Default TQDMProgressBar Usage

# initialize tqdm callback with default parameters
tqdm_callback = tfa.callbacks.TQDMProgressBar()

# Train the model with tqdm_callback. Set verbose=0 to turn off
# Keras's built-in progress bar; otherwise it is printed
# alongside the TQDM bars.
model.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          verbose=0,
          callbacks=[tqdm_callback],
          validation_data=(x_test, y_test))
[Widget output: an overall "Training" progress bar (max 10, one unit per epoch), followed by "Epoch 1/10" through "Epoch 10/10", each with its own progress bar (max 60000, one unit per training sample).]

<tensorflow.python.keras.callbacks.History at 0x7f5eda433080>

The figure below shows the expected output when you run the cell above.

[Figure: TQDM Progress Bar]
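The recorded output shows two levels of progress: an overall bar sized to the 10 epochs and, within each epoch, a bar sized to the 60,000 training samples. The framework-free sketch below illustrates how a progress callback can keep those two counters purely from training hooks. `ToyProgressCallback` and `toy_fit` are hypothetical: their method names mirror `tf.keras.callbacks.Callback` hooks, but the signatures are simplified, and this is not how `TQDMProgressBar` is actually implemented. With `batch_size=64`, one epoch over 60,000 samples takes ceil(60000 / 64) = 938 steps, the last holding only 32 samples.

```python
import math


class ToyProgressCallback:
    """Hypothetical stand-in for a progress-bar callback.

    It only counts epochs and samples; it draws nothing. Hook names mirror
    keras.callbacks.Callback, but the arguments are simplified.
    """

    def __init__(self):
        self.params = None
        self.epochs_done = 0     # drives the overall "Training" bar
        self.samples_seen = []   # final per-epoch bar value for each epoch
        self._current = 0

    def on_train_begin(self, params):
        self.params = params

    def on_epoch_begin(self, epoch):
        self._current = 0        # reset the per-epoch bar

    def on_train_batch_end(self, batch, batch_size):
        self._current += batch_size  # advance by the samples in this batch

    def on_epoch_end(self, epoch):
        self.samples_seen.append(self._current)
        self.epochs_done += 1

    def on_train_end(self):
        pass


def toy_fit(num_samples, batch_size, epochs, callbacks):
    """Hypothetical training loop that only invokes callback hooks."""
    steps = math.ceil(num_samples / batch_size)
    params = {"samples": num_samples, "steps": steps, "epochs": epochs}
    for cb in callbacks:
        cb.on_train_begin(params)
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for step in range(steps):
            # The final batch may be smaller than batch_size.
            size = min(batch_size, num_samples - step * batch_size)
            for cb in callbacks:
                cb.on_train_batch_end(step, size)
        for cb in callbacks:
            cb.on_epoch_end(epoch)
    for cb in callbacks:
        cb.on_train_end()
    return params


cb = ToyProgressCallback()
params = toy_fit(num_samples=60000, batch_size=64, epochs=10, callbacks=[cb])
print(params["steps"], cb.epochs_done, cb.samples_seen[0])  # 938 10 60000
```

Counting samples rather than steps is why the per-epoch bar in the recorded output has a maximum of 60000 instead of 938.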