Model Averaging

Overview

This notebook demonstrates how to use the Moving Average optimizer along with the Model Average Checkpoint from the TensorFlow Addons package.

Moving Averaging

The advantage of Moving Averaging is that the averaged weights are less prone to sudden loss shifts or to irregular data in the latest batch. It gives a smoother and more general view of the model's training up to a given point.
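
To make the rule concrete, here is a minimal NumPy sketch of the exponential moving average that the MovingAverage wrapper maintains for each trainable variable. The decay of 0.99 matches the wrapper's default average_decay; the weight values are made up for illustration.

#Sketch of the exponential moving average update applied after each step
import numpy as np

decay = 0.99
average = np.zeros(3)  #running average of a 3-element weight vector
for step_weights in [np.array([0.5, -1.0, 2.0]),
                     np.array([0.6, -0.9, 1.8])]:
    average = decay * average + (1.0 - decay) * step_weights
print(average)  #smoothed view of the weights so far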

Stochastic Averaging

Stochastic Weight Averaging converges to wider optima, and in doing so it resembles geometric ensembling. SWA is a simple way to improve model performance: it wraps another optimizer and averages the weights from different points along the inner optimizer's trajectory.
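
For example, the wrapper's start_averaging and average_period arguments control which points of the trajectory are averaged. A minimal sketch (the step values below are illustrative, not tuned):

#Sketch: begin averaging at step 100 and add a snapshot every 10 steps
import tensorflow as tf
import tensorflow_addons as tfa

swa_example = tfa.optimizers.SWA(tf.keras.optimizers.SGD(0.01),
                                 start_averaging=100,
                                 average_period=10)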

Model Average Checkpoint

callbacks.ModelCheckpoint doesn't give you the option to save the moving average weights in the middle of training, which is why the model averaging optimizers require a custom callback. Using the update_weights parameter, AverageModelCheckpoint allows you to (see the sketch after this list):

  1. Assign the moving average weights to the model, and save them.
  2. Keep the old non-averaged weights, but the saved model uses the average weights.
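
A minimal sketch of the two configurations, assuming tensorflow_addons is imported as tfa (as in the Setup below) and the filepath is a placeholder:

#Option 1: write the averaged weights into the model, then save them
save_averaged = tfa.callbacks.AverageModelCheckpoint(
    filepath='./avg_ckpt', update_weights=True)

#Option 2: the model keeps its non-averaged weights; only the saved
#checkpoint contains the averaged weights
keep_original = tfa.callbacks.AverageModelCheckpoint(
    filepath='./avg_ckpt', update_weights=False)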

Setup

pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os

Build Model

def create_model(opt):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer=opt,
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model

Prepare Dataset

#Load Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

test_images, test_labels = test
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
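
Note that the training images were scaled to [0, 1] above, while the test images are evaluated on raw pixel values below, which likely explains the large evaluation losses reported later. A hedged sketch of matching the preprocessing, kept under new names so the recorded outputs below still reflect the original run:

#Suggested (not part of the original run): preprocess the test set the
#same way as the training set before evaluating
test_images_scaled = test_images / 255.0
test_labels_int = test_labels.astype(np.int32)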

We will be comparing three optimizers here:

  • Unwrapped SGD
  • SGD with Moving Average
  • SGD with Stochastic Weight Averaging

and see how they perform with the same model.

#Optimizers 
sgd = tf.keras.optimizers.SGD(0.01)
moving_avg_sgd = tfa.optimizers.MovingAverage(sgd)
stochastic_avg_sgd = tfa.optimizers.SWA(sgd)

Both the MovingAverage and SWA optimizers use AverageModelCheckpoint.

#Callback 
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
                                                 save_weights_only=True,
                                                 verbose=1)
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir, 
                                                    update_weights=True)
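
As an alternative to the callback, both averaging wrappers also expose assign_average_vars, which copies the running averages into the model's variables so that a plain ModelCheckpoint or model.save_weights captures them. A minimal sketch, assuming model has been trained with the moving_avg_sgd wrapper:

#Manual alternative to AverageModelCheckpoint: after training, overwrite
#the model's variables with their running averages, then save as usual
moving_avg_sgd.assign_average_vars(model.variables)
model.save_weights('./training/averaged_weights')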

Train Model

Vanilla SGD Optimizer

#Build Model
model = create_model(sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
Epoch 1/5
WARNING:tensorflow:Layer flatten is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1848/1875 [============================>.] - ETA: 0s - loss: 0.7658 - accuracy: 0.7509
Epoch 00001: saving model to ./training
1875/1875 [==============================] - 3s 2ms/step - loss: 0.7624 - accuracy: 0.7518
Epoch 2/5
1869/1875 [============================>.] - ETA: 0s - loss: 0.4991 - accuracy: 0.8263
Epoch 00002: saving model to ./training
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4991 - accuracy: 0.8263
Epoch 3/5
1867/1875 [============================>.] - ETA: 0s - loss: 0.4535 - accuracy: 0.8411
Epoch 00003: saving model to ./training
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4531 - accuracy: 0.8413
Epoch 4/5
1870/1875 [============================>.] - ETA: 0s - loss: 0.4259 - accuracy: 0.8508
Epoch 00004: saving model to ./training
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4261 - accuracy: 0.8506
Epoch 5/5
1851/1875 [============================>.] - ETA: 0s - loss: 0.4063 - accuracy: 0.8584
Epoch 00005: saving model to ./training
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4063 - accuracy: 0.8584

<tensorflow.python.keras.callbacks.History at 0x7f18ec6f9518>
#Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 73.2310 - accuracy: 0.8149
Loss : 73.23104095458984
Accuracy : 0.8148999810218811

Moving Average SGD

#Build Model
model = create_model(moving_avg_sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
WARNING:tensorflow:Layer flatten_1 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1872/1875 [============================>.] - ETA: 0s - loss: 0.7884 - accuracy: 0.7307WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 2ms/step - loss: 0.7880 - accuracy: 0.7308
Epoch 2/5
1875/1875 [==============================] - ETA: 0s - loss: 0.5053 - accuracy: 0.8230INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5053 - accuracy: 0.8230
Epoch 3/5
1856/1875 [============================>.] - ETA: 0s - loss: 0.4555 - accuracy: 0.8392INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4548 - accuracy: 0.8394
Epoch 4/5
1865/1875 [============================>.] - ETA: 0s - loss: 0.4241 - accuracy: 0.8495INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4246 - accuracy: 0.8493
Epoch 5/5
1854/1875 [============================>.] - ETA: 0s - loss: 0.4052 - accuracy: 0.8579INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4055 - accuracy: 0.8577

<tensorflow.python.keras.callbacks.History at 0x7f18ec42b128>
#Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 73.2310 - accuracy: 0.8149
Loss : 73.23104095458984
Accuracy : 0.8148999810218811

Stochastic Weight Average SGD

#Build Model
model = create_model(stochastic_avg_sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
WARNING:tensorflow:Layer flatten_2 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1871/1875 [============================>.] - ETA: 0s - loss: 0.7918 - accuracy: 0.7348INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.7915 - accuracy: 0.7348
Epoch 2/5
1864/1875 [============================>.] - ETA: 0s - loss: 0.5665 - accuracy: 0.8050INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5667 - accuracy: 0.8050
Epoch 3/5
1859/1875 [============================>.] - ETA: 0s - loss: 0.5370 - accuracy: 0.8138INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5363 - accuracy: 0.8140
Epoch 4/5
1861/1875 [============================>.] - ETA: 0s - loss: 0.5225 - accuracy: 0.8199INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5228 - accuracy: 0.8199
Epoch 5/5
1861/1875 [============================>.] - ETA: 0s - loss: 0.5139 - accuracy: 0.8218INFO:tensorflow:Assets written to: ./training/assets
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5138 - accuracy: 0.8220

<tensorflow.python.keras.callbacks.History at 0x7f18ec247c50>
#Evaluate results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 73.2310 - accuracy: 0.8149
Loss : 73.23104095458984
Accuracy : 0.8148999810218811