Overview
This notebook demonstrates how to use the Moving Average optimizer along with the Model Average Checkpoint from the TensorFlow Addons package.
Moving Averaging
The advantage of moving averaging is that the averaged weights are less prone to abrupt loss shifts or to an unrepresentative latest batch. The average gives a smoother, more general picture of how the model has trained up to a given point.
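A minimal NumPy sketch of the idea, assuming the average is an exponential moving average updated once per training step. The decay values are illustrative assumptions, and this is not the TensorFlow Addons implementation; it only shows that a single outlier step barely moves the averaged weight.

```python
import numpy as np


def ema_update(avg, weights, decay=0.99):
    """One exponential-moving-average step.

    avg:     the current averaged weights
    weights: the weights produced by the latest optimizer step
    decay:   fraction of the old average to keep (0.99 is an assumed default)
    """
    return decay * avg + (1.0 - decay) * weights


# Toy trajectory for a single weight: it sits at 1.0 except for one spike,
# as if one irregular batch had pushed it far away.
trajectory = np.array([1.0, 1.0, 1.0, 5.0, 1.0])

averages = []
avg = trajectory[0]
for w in trajectory[1:]:
    # A smaller decay (0.9) is used here so the effect shows up in a few steps.
    avg = ema_update(avg, w, decay=0.9)
    averages.append(avg)

print("raw peak:", trajectory.max())
print("averaged peak:", max(averages))
```

The raw weight jumps to 5.0, while the averaged copy only rises to about 1.4 and then drifts back. As far as we know, `tfa.optimizers.MovingAverage` exposes the decay as its `average_decay` argument.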
Stochastic Averaging
Stochastic Weight Averaging (SWA) tends to converge to wider optima, and in doing so it resembles geometric ensembling. SWA is a simple way to improve model performance: it is used as a wrapper around another optimizer and averages the weights reached at different points along that inner optimizer's trajectory.
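Unlike an exponential moving average, which favors recent steps, SWA gives every sampled snapshot equal weight. A minimal NumPy sketch, assuming snapshots are taken every `average_period` steps once `start_averaging` steps have passed. The parameter names mirror arguments of `tfa.optimizers.SWA`, but the loop is a simplified illustration rather than the library's code.

```python
import numpy as np


def swa_average(trajectory, start_averaging=0, average_period=1):
    """Equal-weight average of the snapshots SWA would sample.

    trajectory:      weights after each optimizer step (step 0, 1, 2, ...)
    start_averaging: steps to skip before the first snapshot
    average_period:  number of steps between snapshots
    """
    avg = None
    n_averaged = 0
    for step, weights in enumerate(trajectory):
        if step < start_averaging:
            continue  # warm-up: the inner optimizer moves freely
        if (step - start_averaging) % average_period != 0:
            continue  # only every `average_period`-th step is sampled
        weights = np.asarray(weights, dtype=float)
        if avg is None:
            avg = weights.copy()
        else:
            # Incremental mean: avg_{n+1} = (n * avg_n + w) / (n + 1)
            avg = (avg * n_averaged + weights) / (n_averaged + 1)
        n_averaged += 1
    return avg, n_averaged


# A single weight bouncing between the walls of a wide basin centered near 2.
trajectory = [[0.0], [3.0], [1.0], [3.0], [1.0], [3.0]]
avg, n = swa_average(trajectory, start_averaging=1, average_period=1)
print(avg, n)  # averages steps 1..5: (3 + 1 + 3 + 1 + 3) / 5 = 2.2
```

The individual snapshots never settle, but their equal-weight average lands near the center of the basin, which is the intuition behind SWA finding wider optima.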
Model Average Checkpoint
`callbacks.ModelCheckpoint` doesn't give you the option to save moving average weights in the middle of training, which is why the model-averaging optimizers require a custom callback. Using the `update_weights` parameter, `AverageModelCheckpoint` allows you to:
- `update_weights=True`: assign the moving average weights to the model, and save them.
- `update_weights=False`: keep the old non-averaged weights in the model, but save the averaged weights.
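The difference between the two modes can be sketched in plain Python. The `save_checkpoint` helper and the weight dictionaries are hypothetical stand-ins for the model, the optimizer's averaged copy, and the file on disk; the sketch only shows which weights end up where.

```python
def save_checkpoint(model_weights, average_weights, update_weights):
    """Hypothetical illustration of the two `update_weights` modes.

    Returns a pair: (weights left in the model, weights written to disk).
    """
    # In both modes the averaged weights are what gets saved.
    saved = dict(average_weights)
    if update_weights:
        # update_weights=True: the model itself now holds the averages.
        model_weights = dict(average_weights)
    # update_weights=False: the model keeps its raw weights and training
    # continues from them (conceptually: swap averages in, save, swap back).
    return model_weights, saved


raw = {"dense/kernel": 0.8}
averaged = {"dense/kernel": 0.5}
print(save_checkpoint(raw, averaged, update_weights=True))
# ({'dense/kernel': 0.5}, {'dense/kernel': 0.5})
print(save_checkpoint(raw, averaged, update_weights=False))
# ({'dense/kernel': 0.8}, {'dense/kernel': 0.5})
```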
Setup
```shell
pip install -U tensorflow-addons
```

```python
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os
```
Build Model
```python
def create_model(opt):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model
```
Prepare Dataset
```python
# Load the Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images / 255.0  # scale pixel values to [0, 1]
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

test_images, test_labels = test
test_images = test_images / 255.0  # apply the same scaling as the training images
```
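Fashion MNIST has 60,000 training images and 10,000 test images, so with a batch size of 32 the step counters in the training and evaluation logs follow from ceiling division (`.batch(32)` keeps a final partial batch, which still counts as a step):

```python
import math

BATCH_SIZE = 32
train_steps = math.ceil(60_000 / BATCH_SIZE)  # batches per training epoch
test_steps = math.ceil(10_000 / BATCH_SIZE)   # batches in one evaluation pass
print(train_steps, test_steps)  # 1875 313
```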
We will compare three optimizers and see how they perform with the same model:
- Unwrapped SGD
- SGD with Moving Average
- SGD with Stochastic Weight Averaging
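```python
# Optimizers
sgd = tf.keras.optimizers.SGD(0.01)
# Each wrapper gets its own inner SGD instance, so no optimizer state
# (such as the iteration counter) is shared between the three runs.
moving_avg_sgd = tfa.optimizers.MovingAverage(tf.keras.optimizers.SGD(0.01))
stocastic_avg_sgd = tfa.optimizers.SWA(tf.keras.optimizers.SGD(0.01))
```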
Both `MovingAverage` and `SWA` wrap an inner SGD optimizer, and both are checkpointed with `AverageModelCheckpoint`.
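Both wrappers leave the gradient step to the inner optimizer and only keep an extra, averaged copy of the weights on the side, which is why a dedicated callback is needed to reach that copy. A minimal sketch of this wrapper pattern on a single scalar weight; `PlainSGD` and `AveragingWrapper` are hypothetical names, not the TensorFlow Addons classes.

```python
class PlainSGD:
    """Inner optimizer: one gradient-descent step on a scalar weight."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def step(self, weight, grad):
        return weight - self.learning_rate * grad


class AveragingWrapper:
    """Hypothetical wrapper: delegates the update, keeps an EMA on the side."""

    def __init__(self, inner, decay=0.99):
        self.inner = inner
        self.decay = decay
        self.average = None  # the averaged copy a checkpoint callback would save

    def step(self, weight, grad):
        new_weight = self.inner.step(weight, grad)  # inner optimizer does the work
        if self.average is None:
            self.average = new_weight
        else:
            self.average = self.decay * self.average + (1 - self.decay) * new_weight
        return new_weight  # the model keeps the raw, non-averaged weight


# Minimize f(w) = w**2, whose gradient is 2 * w.
opt = AveragingWrapper(PlainSGD(learning_rate=0.1), decay=0.5)
w = 1.0
for _ in range(3):
    w = opt.step(w, grad=2 * w)
print(w, opt.average)
```

The design keeps averaging out of the optimization itself: the inner optimizer always sees the raw weight, and the average only matters once it is assigned to the model, for example by a checkpoint callback.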
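```python
# Callbacks
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Saves the raw model weights to ./training at the end of every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
                                                 save_weights_only=True,
                                                 verbose=1)

# Assigns the averaged weights to the model, then saves them to the same
# location. Saving weights only (rather than a full SavedModel) keeps the
# format identical to cp_callback, so model.load_weights(checkpoint_dir)
# restores the checkpoint written by the most recent run.
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir,
                                                    update_weights=True,
                                                    save_weights_only=True)
```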
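`checkpoint_path` is a filename template: `ModelCheckpoint` fills in `{epoch}` each time it saves, so passing `checkpoint_path` as `filepath` would keep one checkpoint per epoch. The callbacks here receive the bare `checkpoint_dir` instead, so each epoch overwrites the previous checkpoint. Plain Python string formatting shows what the two values look like (Keras substitutes 1-based epoch numbers):

```python
import os

checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# What the template expands to for a 5-epoch run
per_epoch_files = [checkpoint_path.format(epoch=epoch) for epoch in range(1, 6)]
print(checkpoint_dir)       # ./training
print(per_epoch_files[0])   # ./training/cp-0001.ckpt
print(per_epoch_files[-1])  # ./training/cp-0005.ckpt
```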
Train Model
Vanilla SGD Optimizer
```python
# Build model
model = create_model(sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
```
```
Epoch 1/5
1860/1875 [============================>.] - ETA: 0s - loss: 0.7621 - accuracy: 0.7457
Epoch 1: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.7600 - accuracy: 0.7463
Epoch 2/5
1860/1875 [============================>.] - ETA: 0s - loss: 0.5012 - accuracy: 0.8256
Epoch 2: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5006 - accuracy: 0.8257
Epoch 3/5
1857/1875 [============================>.] - ETA: 0s - loss: 0.4550 - accuracy: 0.8406
Epoch 3: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4550 - accuracy: 0.8404
Epoch 4/5
1870/1875 [============================>.] - ETA: 0s - loss: 0.4279 - accuracy: 0.8488
Epoch 4: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4278 - accuracy: 0.8489
Epoch 5/5
1854/1875 [============================>.] - ETA: 0s - loss: 0.4089 - accuracy: 0.8558
Epoch 5: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4088 - accuracy: 0.8557
<keras.callbacks.History at 0x7f1d70081fd0>
```
```python
# Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
```
```
313/313 - 1s - loss: 101.1311 - accuracy: 0.7831 - 552ms/epoch - 2ms/step
Loss : 101.13106536865234
Accuracy : 0.7831000089645386
```
Moving Average SGD
```python
# Build model
model = create_model(moving_avg_sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
```
```
Epoch 1/5
1861/1875 [============================>.] - ETA: 0s - loss: 0.7915 - accuracy: 0.7391
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 2ms/step - loss: 0.7902 - accuracy: 0.7396
Epoch 2/5
1857/1875 [============================>.] - ETA: 0s - loss: 0.5112 - accuracy: 0.8227
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5112 - accuracy: 0.8225
Epoch 3/5
1868/1875 [============================>.] - ETA: 0s - loss: 0.4662 - accuracy: 0.8388
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4666 - accuracy: 0.8387
Epoch 4/5
1855/1875 [============================>.] - ETA: 0s - loss: 0.4423 - accuracy: 0.8462
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4422 - accuracy: 0.8462
Epoch 5/5
1873/1875 [============================>.] - ETA: 0s - loss: 0.4235 - accuracy: 0.8519
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4236 - accuracy: 0.8518
<keras.callbacks.History at 0x7f1d0c598430>
```
```python
# Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
```
```
313/313 - 1s - loss: 101.1311 - accuracy: 0.7831 - 535ms/epoch - 2ms/step
Loss : 101.13106536865234
Accuracy : 0.7831000089645386
```
Stochastic Weight Average SGD
```python
# Build model
model = create_model(stocastic_avg_sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
```
```
Epoch 1/5
1866/1875 [============================>.] - ETA: 0s - loss: 0.8077 - accuracy: 0.7274
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 6s 3ms/step - loss: 0.8064 - accuracy: 0.7278
Epoch 2/5
1875/1875 [==============================] - ETA: 0s - loss: 0.5710 - accuracy: 0.8050
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5710 - accuracy: 0.8050
Epoch 3/5
1857/1875 [============================>.] - ETA: 0s - loss: 0.5401 - accuracy: 0.8140
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5393 - accuracy: 0.8144
Epoch 4/5
1867/1875 [============================>.] - ETA: 0s - loss: 0.5222 - accuracy: 0.8194
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5227 - accuracy: 0.8194
Epoch 5/5
1861/1875 [============================>.] - ETA: 0s - loss: 0.5138 - accuracy: 0.8211
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5132 - accuracy: 0.8213
<keras.callbacks.History at 0x7f1ce85b3820>
```
```python
# Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
```
```
313/313 - 1s - loss: 101.1311 - accuracy: 0.7831 - 563ms/epoch - 2ms/step
Loss : 101.13106536865234
Accuracy : 0.7831000089645386
```