Rappels des modules complémentaires TensorFlow : barre de progression TQDM

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Aperçu

Ce bloc-notes montrera comment utiliser TQDMCallback dans les modules complémentaires TensorFlow.

Installer

pip install -U tensorflow-addons
!pip install -q "tqdm>=4.36.1"

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
import tqdm

# quietly deep-reload tqdm
import sys
from IPython.lib import deepreload 

stdout = sys.stdout
sys.stdout = open('junk','w')
deepreload.reload(tqdm)
sys.stdout = stdout

tqdm.__version__
'4.62.3'

Importer et normaliser les données

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train, x_test = x_train / 255.0, x_test / 255.0

Construire un modèle CNN MNIST simple

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss = 'sparse_categorical_crossentropy',
              metrics=['accuracy'])

Utilisation par défaut de TQDMCallback

# initialize tqdm callback with default parameters
tqdm_callback = tfa.callbacks.TQDMProgressBar()

# train the model with tqdm_callback
# make sure to set verbose = 0 to disable
# the default progress bar.
model.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          verbose=0,
          callbacks=[tqdm_callback],
          validation_data=(x_test, y_test))
Training:   0%|           0/10 ETA: ?s,  ?epochs/s
Epoch 1/10
0/938           ETA: ?s -
Epoch 2/10
0/938           ETA: ?s -
Epoch 3/10
0/938           ETA: ?s -
Epoch 4/10
0/938           ETA: ?s -
Epoch 5/10
0/938           ETA: ?s -
Epoch 6/10
0/938           ETA: ?s -
Epoch 7/10
0/938           ETA: ?s -
Epoch 8/10
0/938           ETA: ?s -
Epoch 9/10
0/938           ETA: ?s -
Epoch 10/10
0/938           ETA: ?s -
<keras.callbacks.History at 0x7f4a8d35aed0>

Vous trouverez ci-dessous le résultat attendu lorsque vous exécutez la cellule ci-dessus Chiffre de barre de progression TQDM

# TQDMProgressBar() also works with evaluate()
model.evaluate(x_test, y_test, batch_size=64, callbacks=[tqdm_callback], verbose=0)
0/157           ETA: ?s - Evaluating
[0.06689586490392685, 0.9805999994277954]

Vous trouverez ci-dessous le résultat attendu lorsque vous exécutez la cellule ci-dessus TQDM évaluer la barre de progression