Model Averaging

Overview

This notebook demonstrates how to use the Moving Average and Stochastic Weight Averaging optimizers along with the Model Average Checkpoint callback from the TensorFlow Addons package.

Moving Averaging

The advantage of moving averaging is that it is less prone to rampant loss shifts or irregular data representation in the latest batch. It gives a smoothed, more general picture of the model's training up to a given point.
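
To make the smoothing concrete, here is a minimal NumPy sketch of an exponential moving average over successive weight snapshots. It assumes the standard update average = decay * average + (1 - decay) * weights with a decay of 0.99, chosen here for illustration; ema_of_weights is a hypothetical helper, not a TensorFlow Addons API.

```python
import numpy as np

def ema_of_weights(weight_snapshots, decay=0.99):
    """Exponential moving average of a sequence of weight snapshots.

    Hypothetical helper (not a TensorFlow Addons API). Each step applies
        average = decay * average + (1 - decay) * weights
    and the average is seeded with the first snapshot, an assumption made for simplicity.
    """
    average = np.asarray(weight_snapshots[0], dtype=float).copy()
    history = [average.copy()]
    for weights in weight_snapshots[1:]:
        average = decay * average + (1.0 - decay) * np.asarray(weights, dtype=float)
        history.append(average.copy())
    return history

# A single toy "weight" that jumps around 1.0 from batch to batch.
rng = np.random.default_rng(0)
raw = 1.0 + 0.5 * rng.standard_normal(2000)
smoothed = np.array(ema_of_weights(raw, decay=0.99))

print("spread of raw weights over the last 500 steps:", raw[-500:].std())
print("spread of EMA weights over the last 500 steps:", smoothed[-500:].std())
```

The averaged weight stays close to 1.0 while the raw weight keeps jumping, which is the effect described above: one unusual batch moves the average only by a factor of (1 - decay).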

Stochastic Averaging

Stochastic Weight Averaging (SWA) converges to wider optima. In doing so, it resembles geometric ensembling. SWA is a simple way to improve model performance: it wraps another optimizer and averages the weights from different points along that inner optimizer's trajectory.
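
Unlike a moving average, which lets old weights decay exponentially, SWA gives every snapshot the same weight. The sketch below assumes snapshots are taken every average_period steps once start_averaging steps have passed. These argument names echo those of tfa.optimizers.SWA, but swa_average itself is a hypothetical helper that only illustrates the idea.

```python
import numpy as np

def swa_average(weight_snapshots, start_averaging=0, average_period=10):
    """Equal-weight average of the snapshots taken every `average_period` steps.

    Hypothetical helper: the argument names echo tfa.optimizers.SWA, but this is a
    sketch of the idea, not that optimizer's implementation.
    Returns (averaged weights, number of snapshots averaged).
    """
    swa = None
    count = 0
    for step, weights in enumerate(weight_snapshots):
        # Skip the warm-up phase and every step that is not a snapshot step.
        if step < start_averaging or (step - start_averaging) % average_period != 0:
            continue
        weights = np.asarray(weights, dtype=float)
        # Incremental mean: every snapshot ends up with the same weight 1 / count.
        swa = weights.copy() if count == 0 else (swa * count + weights) / (count + 1)
        count += 1
    return swa, count

# Toy trajectory that bounces between the two walls of a valley centred at 1.0.
trajectory = [1.0 + 0.2 * (-1) ** k for k in range(100)]
swa, n = swa_average(trajectory, start_averaging=0, average_period=5)
print(n, swa)
```

No single point of the toy trajectory sits at the centre of the valley, yet the equal-weight average of the snapshots lands there, which is the intuition behind SWA finding wider optima.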

Model Average Checkpoint

tf.keras.callbacks.ModelCheckpoint doesn't give you the option to save moving average weights in the middle of training, which is why the model averaging optimizers require a custom callback. Using the update_weights parameter, tfa.callbacks.AverageModelCheckpoint allows you to:

  1. Assign the moving average weights to the model and save them (update_weights=True).
  2. Keep the old non-averaged weights in the model, while the saved model uses the average weights (update_weights=False).
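
The two modes differ only in which weights stay in the model after a save; the checkpoint holds the averaged weights in both cases. The plain-Python sketch below models that with dictionaries. average_checkpoint and the toy weight values are hypothetical and stand in for what the callback does with real model variables; no TensorFlow call is involved.

```python
import copy

def average_checkpoint(model_weights, average_weights, update_weights):
    """Return (weights left in the model, weights written to the checkpoint).

    Hypothetical helper that mimics the two update_weights modes with plain dicts.
    """
    saved = copy.deepcopy(average_weights)      # the checkpoint always holds the averages
    if update_weights:
        kept = copy.deepcopy(average_weights)   # mode 1: the model now uses the averages
    else:
        kept = copy.deepcopy(model_weights)     # mode 2: training continues from the raw weights
    return kept, saved

raw = {"dense/kernel": [0.9, -0.4]}
avg = {"dense/kernel": [0.7, -0.2]}

print(average_checkpoint(raw, avg, update_weights=True))
print(average_checkpoint(raw, avg, update_weights=False))
```

With update_weights=True, later epochs continue from the averaged weights; with update_weights=False, the optimizer keeps its own trajectory and only the saved file sees the averages.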

Setup

pip install -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os

Build Model

def create_model(opt):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),                         
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model
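
As a quick sanity check on the size of this network, its trainable parameter count can be worked out by hand, assuming the 28x28 Fashion MNIST inputs prepared below (Flatten itself adds no parameters). This is plain arithmetic rather than a Keras call; once the model has been built, model.summary() should report the same total.

```python
def dense_params(n_in, n_out):
    # A Dense layer has an (n_in x n_out) kernel plus one bias per output unit.
    return n_in * n_out + n_out

flat = 28 * 28  # each 28x28 grayscale image is flattened to 784 values
layer_sizes = [flat, 64, 64, 10]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(per_layer)       # [50240, 4160, 650]
print(sum(per_layer))  # 55050
```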

Prepare Dataset

#Load Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

test_images, test_labels = test
test_images = test_images/255.0  # scale test pixels to [0, 1], like the training images
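
The batch size also explains the step counts that appear in the training and evaluation logs. A small arithmetic sketch, assuming the standard Fashion MNIST split of 60,000 training and 10,000 test images:

```python
import math

train_examples = 60_000  # Fashion MNIST training split
test_examples = 10_000   # Fashion MNIST test split
batch_size = 32

print(math.ceil(train_examples / batch_size))  # 1875 steps per training epoch
print(math.ceil(test_examples / batch_size))   # 313 evaluation steps
print(test_examples - 312 * batch_size)        # 16 images in the final, partial test batch
```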

We will compare three optimizers here:

  • Unwrapped SGD
  • SGD with Moving Average
  • SGD with Stochastic Weight Averaging

and see how they perform with the same model.

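#Optimizers
# Give each run its own SGD instance so no optimizer state (e.g. the iteration count) is shared.
sgd = tf.keras.optimizers.SGD(0.01)
moving_avg_sgd = tfa.optimizers.MovingAverage(tf.keras.optimizers.SGD(0.01))
stocastic_avg_sgd = tfa.optimizers.SWA(tf.keras.optimizers.SGD(0.01))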

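Both the MovingAverage and SWA runs are saved with AverageModelCheckpoint, which requires the model to be trained with one of these averaging optimizers; the plain SGD run uses the standard ModelCheckpoint.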

#Callback 
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
                                                 save_weights_only=True,
                                                 verbose=1)
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir, 
                                                    update_weights=True)
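
Note that both callbacks receive checkpoint_dir rather than the {epoch} template in checkpoint_path, so every epoch writes to the same ./training location. The plain-Python sketch below shows what each string resolves to; the format(epoch=5) call only illustrates the per-epoch file name that would be produced if the template itself were passed as filepath.

```python
import os

checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

print(checkpoint_dir)                   # ./training
print(checkpoint_path.format(epoch=5))  # ./training/cp-0005.ckpt
```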

Train Model

Vanilla SGD Optimizer

#Build Model
model = create_model(sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
Epoch 1/5
1860/1875 [============================>.] - ETA: 0s - loss: 0.7621 - accuracy: 0.7457
Epoch 1: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.7600 - accuracy: 0.7463
Epoch 2/5
1860/1875 [============================>.] - ETA: 0s - loss: 0.5012 - accuracy: 0.8256
Epoch 2: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5006 - accuracy: 0.8257
Epoch 3/5
1857/1875 [============================>.] - ETA: 0s - loss: 0.4550 - accuracy: 0.8406
Epoch 3: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4550 - accuracy: 0.8404
Epoch 4/5
1870/1875 [============================>.] - ETA: 0s - loss: 0.4279 - accuracy: 0.8488
Epoch 4: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4278 - accuracy: 0.8489
Epoch 5/5
1854/1875 [============================>.] - ETA: 0s - loss: 0.4089 - accuracy: 0.8558
Epoch 5: saving model to ./training
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4088 - accuracy: 0.8557
<keras.callbacks.History at 0x7f1d70081fd0>
#Evaluate results
# Restore the weights that cp_callback saved under the checkpoint prefix "./training".
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 1s - loss: 101.1311 - accuracy: 0.7831 - 552ms/epoch - 2ms/step
Loss : 101.13106536865234
Accuracy : 0.7831000089645386

Moving Average SGD

#Build Model
model = create_model(moving_avg_sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
1861/1875 [============================>.] - ETA: 0s - loss: 0.7915 - accuracy: 0.7391
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 2ms/step - loss: 0.7902 - accuracy: 0.7396
Epoch 2/5
1857/1875 [============================>.] - ETA: 0s - loss: 0.5112 - accuracy: 0.8227
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5112 - accuracy: 0.8225
Epoch 3/5
1868/1875 [============================>.] - ETA: 0s - loss: 0.4662 - accuracy: 0.8388
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4666 - accuracy: 0.8387
Epoch 4/5
1855/1875 [============================>.] - ETA: 0s - loss: 0.4423 - accuracy: 0.8462
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4422 - accuracy: 0.8462
Epoch 5/5
1873/1875 [============================>.] - ETA: 0s - loss: 0.4235 - accuracy: 0.8519
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4236 - accuracy: 0.8518
<keras.callbacks.History at 0x7f1d0c598430>
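#Evaluate results
# avg_callback (update_weights=True) has already assigned the averaged weights to the
# model, so evaluate them directly. model.load_weights(checkpoint_dir) would instead
# restore the "./training" checkpoint prefix written by cp_callback in the vanilla SGD run.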
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 1s - loss: 101.1311 - accuracy: 0.7831 - 535ms/epoch - 2ms/step
Loss : 101.13106536865234
Accuracy : 0.7831000089645386

Stochastic Weight Average SGD

#Build Model
model = create_model(stocastic_avg_sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
1866/1875 [============================>.] - ETA: 0s - loss: 0.8077 - accuracy: 0.7274
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 6s 3ms/step - loss: 0.8064 - accuracy: 0.7278
Epoch 2/5
1875/1875 [==============================] - ETA: 0s - loss: 0.5710 - accuracy: 0.8050
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5710 - accuracy: 0.8050
Epoch 3/5
1857/1875 [============================>.] - ETA: 0s - loss: 0.5401 - accuracy: 0.8140
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5393 - accuracy: 0.8144
Epoch 4/5
1867/1875 [============================>.] - ETA: 0s - loss: 0.5222 - accuracy: 0.8194
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5227 - accuracy: 0.8194
Epoch 5/5
1861/1875 [============================>.] - ETA: 0s - loss: 0.5138 - accuracy: 0.8211
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5132 - accuracy: 0.8213
<keras.callbacks.History at 0x7f1ce85b3820>
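#Evaluate results
# avg_callback (update_weights=True) has already assigned the SWA weights to the
# model, so evaluate them directly. model.load_weights(checkpoint_dir) would instead
# restore the "./training" checkpoint prefix written by cp_callback in the vanilla SGD run.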
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 1s - loss: 101.1311 - accuracy: 0.7831 - 563ms/epoch - 2ms/step
Loss : 101.13106536865234
Accuracy : 0.7831000089645386