Overview
This notebook demonstrates how to use the Moving Average Optimizer along with the Model Average Checkpoint from the TensorFlow Addons package.
Moving Averaging
The advantage of Moving Averaging is that the averaged weights are less prone to rampant loss shifts or to irregular data representation in the latest batch. It gives a smoothed and more general idea of the model's training up to some point.
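To make the smoothing concrete, here is a minimal NumPy sketch of an exponential moving average of a single weight. This is illustrative only and not part of the notebook's training flow; the decay value is an assumption.
import numpy as np

decay = 0.99     # assumed decay rate, for illustration only
weight = 0.5     # the raw weight being trained
shadow = weight  # the smoothed, averaged copy

rng = np.random.default_rng(0)
for step in range(1000):
    weight += rng.normal(scale=0.1)                 # noisy SGD-like update
    shadow = decay * shadow + (1 - decay) * weight  # exponential moving average

# `shadow` tracks `weight` but fluctuates far less between steps
print(weight, shadow)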
Stochastic Averaging
Stochastic Weight Averaging converges to wider optima. By doing so, it resembles geometric ensembling. SWA is a simple method to improve model performance when used as a wrapper around other optimizers, averaging results from different points of the trajectory of the inner optimizer.
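In spirit, SWA replaces the final weights with an equal-weight average of snapshots taken at different points of the inner optimizer's trajectory. A minimal sketch, using random placeholder snapshots rather than real training checkpoints:
import numpy as np

# Hypothetical weight snapshots taken along the optimization trajectory
snapshots = [np.random.normal(size=3) for _ in range(5)]

# SWA-style result: the arithmetic mean of the snapshots
swa_weights = np.mean(snapshots, axis=0)
print(swa_weights)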
Model Average Checkpoint
callbacks.ModelCheckpoint doesn't give you the option to save moving average weights in the middle of training, which is why Model Average Optimizers require a custom callback. Using the update_weights parameter, AverageModelCheckpoint allows you to:
- Assign the moving average weights to the model, and save them.
- Keep the old non-averaged weights, but the saved model uses the average weights.
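A hedged sketch of the two modes is below; the filepath is a placeholder, and tfa is imported in the Setup section that follows.
# update_weights=True: the model itself receives the averaged weights,
# and the checkpoint saves them
avg_cb = tfa.callbacks.AverageModelCheckpoint(filepath='./ckpt',
                                              update_weights=True)

# update_weights=False: the model keeps its non-averaged weights,
# but the checkpoint written to disk contains the averaged ones
raw_cb = tfa.callbacks.AverageModelCheckpoint(filepath='./ckpt',
                                              update_weights=False)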
Setup
pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os
Build Model
def create_model(opt):
    # Simple fully connected classifier for Fashion MNIST
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
Prepare Dataset
# Load the Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images / 255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

test_images, test_labels = test
test_images = test_images / 255.0  # normalize the test set the same way as the training set
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
We will compare three optimizers here, and see how each performs with the same model:
- Unwrapped SGD
- SGD with Moving Average
- SGD with Stochastic Weight Averaging
# Optimizers
sgd = tf.keras.optimizers.SGD(0.01)
moving_avg_sgd = tfa.optimizers.MovingAverage(sgd)
stocastic_avg_sgd = tfa.optimizers.SWA(sgd)
Both the MovingAverage and SWA optimizers use AverageModelCheckpoint.
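Both wrapper optimizers also accept averaging hyperparameters. The values below are the library defaults written out explicitly, shown only to illustrate the available knobs:
# average_decay controls how quickly the moving average forgets old weights;
# SWA folds a snapshot into the average every `average_period` steps,
# beginning at step `start_averaging` (these are the default values)
moving_avg_sgd = tfa.optimizers.MovingAverage(sgd, average_decay=0.99)
stocastic_avg_sgd = tfa.optimizers.SWA(sgd, start_averaging=0, average_period=10)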
# Callback
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
                                                 save_weights_only=True,
                                                 verbose=1)
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir,
                                                    update_weights=True)
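If you would rather not save through the callback, the averaging wrappers also expose assign_average_vars, which copies the averaged weights into the model so you can save them however you like. A sketch, assuming a model that was compiled with moving_avg_sgd and has already been trained:
# Manually swap the averaged weights into the model after training
moving_avg_sgd.assign_average_vars(model.variables)
# ...then save with any standard mechanism (the path here is illustrative)
model.save_weights(checkpoint_path.format(epoch=0))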
Train Model
Vanilla SGD Optimizer
# Build the model
model = create_model(sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 1.1010 - accuracy: 0.6480
Epoch 00001: saving model to ./training
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5210 - accuracy: 0.8180
Epoch 00002: saving model to ./training
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4636 - accuracy: 0.8355
Epoch 00003: saving model to ./training
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4320 - accuracy: 0.8471
Epoch 00004: saving model to ./training
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4098 - accuracy: 0.8558
Epoch 00005: saving model to ./training
<tensorflow.python.keras.callbacks.History at 0x7f9c00068b00>
# Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 1s - loss: 82.8866 - accuracy: 0.7894
Loss : 82.88660430908203
Accuracy : 0.7893999814987183
Moving Average SGD
# Build the model
model = create_model(moving_avg_sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 1.1354 - accuracy: 0.6366
INFO:tensorflow:Assets written to: ./training/assets
Epoch 2/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.5312 - accuracy: 0.8170
INFO:tensorflow:Assets written to: ./training/assets
Epoch 3/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4677 - accuracy: 0.8366
INFO:tensorflow:Assets written to: ./training/assets
Epoch 4/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4407 - accuracy: 0.8447
INFO:tensorflow:Assets written to: ./training/assets
Epoch 5/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4174 - accuracy: 0.8542
INFO:tensorflow:Assets written to: ./training/assets
<tensorflow.python.keras.callbacks.History at 0x7f9c0005a358>
# Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 82.8866 - accuracy: 0.7894
Loss : 82.88660430908203
Accuracy : 0.7893999814987183
Stochastic Weight Average SGD
# Build the model
model = create_model(stocastic_avg_sgd)

# Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
1875/1875 [==============================] - 6s 3ms/step - loss: 1.0615 - accuracy: 0.6521
INFO:tensorflow:Assets written to: ./training/assets
Epoch 2/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.6078 - accuracy: 0.7925
INFO:tensorflow:Assets written to: ./training/assets
Epoch 3/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5671 - accuracy: 0.8062
INFO:tensorflow:Assets written to: ./training/assets
Epoch 4/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5468 - accuracy: 0.8115
INFO:tensorflow:Assets written to: ./training/assets
Epoch 5/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5288 - accuracy: 0.8181
INFO:tensorflow:Assets written to: ./training/assets
<tensorflow.python.keras.callbacks.History at 0x7f9b7c6faa58>
# Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 82.8866 - accuracy: 0.7894
Loss : 82.88660430908203
Accuracy : 0.7893999814987183