TensorFlow Addons Layers: WeightNormalization


Overview

This notebook demonstrates how to use the WeightNormalization layer and how it can improve convergence.

WeightNormalization

Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks

Tim Salimans, Diederik P. Kingma (2016)

By reparameterizing the weights in this way we improve the conditioning of the optimization problem and we speed up convergence of stochastic gradient descent. Our reparameterization is inspired by batch normalization but does not introduce any dependencies between the examples in a minibatch. This means that our method can also be applied successfully to recurrent models such as LSTMs and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited. Although our method is much simpler, it still provides much of the speed-up of full batch normalization. In addition, the computational overhead of our method is lower, permitting more optimization steps to be taken in the same amount of time.

https://arxiv.org/abs/1602.07868
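The core idea is to decouple the direction and the magnitude of each weight vector: instead of learning w directly, the layer learns a vector v and a scalar g and reconstructs w = g * v / ||v||. Below is a minimal sketch of this reparameterization in plain TensorFlow; the variables v and g are illustrative and not the layer's internal names.

import tensorflow as tf

# Reparameterization from the paper: w = g * v / ||v||,
# so direction (v) and magnitude (g) are learned separately.
v = tf.Variable(tf.random.normal([5]))   # direction parameters
g = tf.Variable(2.0)                     # scalar magnitude parameter

w = g * v / tf.norm(v)                   # reconstructed weight vector
print(tf.norm(w).numpy())                # ~2.0: the norm of w equals g, independent of v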



Setup

try:
  %tensorflow_version 2.x
except:
  pass

!pip install -q --no-deps tensorflow-addons~=0.7

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt

# Hyperparameters
batch_size = 32
epochs = 10
num_classes = 10

Build Models

# Standard ConvNet
reg_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])
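tfa.layers.WeightNormalization also accepts a data_init argument (the data-dependent initialization described in the paper), which defaults to True. A short sketch of wrapping a single layer with the flag set explicitly:

# data_init controls the data-dependent initialization from the paper
# (default True); set it to False to skip that initialization pass.
wn_dense = tfa.layers.WeightNormalization(
    tf.keras.layers.Dense(84, activation='relu'),
    data_init=True)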

Load Data

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step

Train Models

reg_model.compile(optimizer='adam', 
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

reg_history = reg_model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test),
                            shuffle=True)
Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 9s 172us/sample - loss: 1.6498 - accuracy: 0.3970 - val_loss: 1.5033 - val_accuracy: 0.4654
Epoch 2/10
50000/50000 [==============================] - 6s 125us/sample - loss: 1.3483 - accuracy: 0.5171 - val_loss: 1.2956 - val_accuracy: 0.5360
Epoch 3/10
50000/50000 [==============================] - 6s 113us/sample - loss: 1.2373 - accuracy: 0.5584 - val_loss: 1.2122 - val_accuracy: 0.5657
Epoch 4/10
50000/50000 [==============================] - 6s 113us/sample - loss: 1.1605 - accuracy: 0.5867 - val_loss: 1.1952 - val_accuracy: 0.5788
Epoch 5/10
50000/50000 [==============================] - 6s 112us/sample - loss: 1.1009 - accuracy: 0.6108 - val_loss: 1.2094 - val_accuracy: 0.5742
Epoch 6/10
50000/50000 [==============================] - 6s 111us/sample - loss: 1.0499 - accuracy: 0.6299 - val_loss: 1.1338 - val_accuracy: 0.5985
Epoch 7/10
50000/50000 [==============================] - 6s 114us/sample - loss: 1.0020 - accuracy: 0.6467 - val_loss: 1.1444 - val_accuracy: 0.5962
Epoch 8/10
50000/50000 [==============================] - 6s 121us/sample - loss: 0.9642 - accuracy: 0.6598 - val_loss: 1.1048 - val_accuracy: 0.6178
Epoch 9/10
50000/50000 [==============================] - 6s 120us/sample - loss: 0.9310 - accuracy: 0.6712 - val_loss: 1.1191 - val_accuracy: 0.6110
Epoch 10/10
50000/50000 [==============================] - 6s 111us/sample - loss: 0.9020 - accuracy: 0.6822 - val_loss: 1.1114 - val_accuracy: 0.6138

wn_model.compile(optimizer='adam', 
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])

wn_history = wn_model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_test, y_test),
                          shuffle=True)
Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 15s 299us/sample - loss: 1.5882 - accuracy: 0.4222 - val_loss: 1.4699 - val_accuracy: 0.4696
Epoch 2/10
50000/50000 [==============================] - 12s 235us/sample - loss: 1.3304 - accuracy: 0.5215 - val_loss: 1.3386 - val_accuracy: 0.5171
Epoch 3/10
50000/50000 [==============================] - 11s 221us/sample - loss: 1.2203 - accuracy: 0.5625 - val_loss: 1.2282 - val_accuracy: 0.5591
Epoch 4/10
50000/50000 [==============================] - 11s 214us/sample - loss: 1.1423 - accuracy: 0.5939 - val_loss: 1.1955 - val_accuracy: 0.5764
Epoch 5/10
50000/50000 [==============================] - 11s 222us/sample - loss: 1.0763 - accuracy: 0.6211 - val_loss: 1.1734 - val_accuracy: 0.5925
Epoch 6/10
50000/50000 [==============================] - 12s 237us/sample - loss: 1.0214 - accuracy: 0.6398 - val_loss: 1.1771 - val_accuracy: 0.5869
Epoch 7/10
50000/50000 [==============================] - 11s 214us/sample - loss: 0.9739 - accuracy: 0.6553 - val_loss: 1.1415 - val_accuracy: 0.6011
Epoch 8/10
50000/50000 [==============================] - 11s 215us/sample - loss: 0.9281 - accuracy: 0.6738 - val_loss: 1.1489 - val_accuracy: 0.5955
Epoch 9/10
50000/50000 [==============================] - 11s 212us/sample - loss: 0.8845 - accuracy: 0.6865 - val_loss: 1.1386 - val_accuracy: 0.6074
Epoch 10/10
50000/50000 [==============================] - 11s 212us/sample - loss: 0.8446 - accuracy: 0.7022 - val_loss: 1.1453 - val_accuracy: 0.6148

reg_accuracy = reg_history.history['accuracy']
wn_accuracy = wn_history.history['accuracy']

epochs_range = np.arange(1, epochs + 1)

plt.plot(epochs_range, reg_accuracy,
         color='red', label='Regular ConvNet')

plt.plot(epochs_range, wn_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Training accuracy')
plt.legend()
plt.grid(True)
plt.show()
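The curves above compare training accuracy. The validation accuracies recorded by fit() can be plotted the same way; a minimal sketch (not part of the original notebook output):

# Plot the held-out accuracy recorded under 'val_accuracy' during fit().
plt.plot(np.arange(1, epochs + 1), reg_history.history['val_accuracy'],
         color='red', linestyle='--', label='Regular ConvNet (val)')
plt.plot(np.arange(1, epochs + 1), wn_history.history['val_accuracy'],
         color='blue', linestyle='--', label='WeightNorm ConvNet (val)')
plt.title('WeightNorm Validation Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Validation accuracy')
plt.legend()
plt.grid(True)
plt.show()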

[Figure: training accuracy per epoch for the Regular ConvNet (red) and the WeightNorm ConvNet (blue)]