TensorFlow Addons Layers: WeightNormalization


Overview

This notebook demonstrates how to use the WeightNormalization layer and how it can improve convergence.

WeightNormalization

Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks

Tim Salimans, Diederik P. Kingma (2016)

By reparameterizing the weights in this way we improve the conditioning of the optimization problem and we speed up convergence of stochastic gradient descent. Our reparameterization is inspired by batch normalization but does not introduce any dependencies between the examples in a minibatch. This means that our method can also be applied successfully to recurrent models such as LSTMs and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited. Although our method is much simpler, it still provides much of the speed-up of full batch normalization. In addition, the computational overhead of our method is lower, permitting more optimization steps to be taken in the same amount of time.

https://arxiv.org/abs/1602.07868
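
Concretely, weight normalization decouples the direction of each weight vector from its magnitude. A weight vector $w$ is reparameterized as

$$ w = \frac{g}{\lVert v \rVert}\, v $$

where $v$ is a vector of the same dimension as $w$ and $g$ is a learned scalar, and gradient descent is then performed on $g$ and $v$ instead of on $w$ directly.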



Setup

pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt
# Hyperparameters
batch_size = 32
epochs = 10
num_classes = 10

Build Models

# Standard ConvNet
reg_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])
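
As a quick, optional sanity check (a minimal sketch, not part of the original notebook), WeightNormalization can wrap a single layer directly; here one Dense layer is wrapped and called once to build it:

# Wrap a single Dense layer and run one forward pass.
# data_init=True (the default) initializes the scale g and the bias
# from the first batch of data the layer sees.
single = tfa.layers.WeightNormalization(tf.keras.layers.Dense(3), data_init=True)
print(single(tf.random.normal((2, 4))).shape)  # (2, 3)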

Load Data

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# One-hot encode the integer class labels.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
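
An optional shape check (a small addition) confirms what was loaded and encoded:

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)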

Train Models

reg_model.compile(optimizer='adam', 
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

reg_history = reg_model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test),
                            shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.6401 - accuracy: 0.3979 - val_loss: 1.4453 - val_accuracy: 0.4744
Epoch 2/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.3549 - accuracy: 0.5157 - val_loss: 1.3057 - val_accuracy: 0.5389
Epoch 3/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.2301 - accuracy: 0.5635 - val_loss: 1.2365 - val_accuracy: 0.5606
Epoch 4/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.1502 - accuracy: 0.5914 - val_loss: 1.1722 - val_accuracy: 0.5833
Epoch 5/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.0849 - accuracy: 0.6188 - val_loss: 1.1436 - val_accuracy: 0.5881
Epoch 6/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.0332 - accuracy: 0.6342 - val_loss: 1.1441 - val_accuracy: 0.5914
Epoch 7/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.9902 - accuracy: 0.6500 - val_loss: 1.1391 - val_accuracy: 0.5968
Epoch 8/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.9517 - accuracy: 0.6627 - val_loss: 1.1052 - val_accuracy: 0.6149
Epoch 9/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.9160 - accuracy: 0.6760 - val_loss: 1.1244 - val_accuracy: 0.6162
Epoch 10/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.8858 - accuracy: 0.6859 - val_loss: 1.1506 - val_accuracy: 0.6041

wn_model.compile(optimizer='adam', 
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])

wn_history = wn_model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_test, y_test),
                          shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.6176 - accuracy: 0.4090 - val_loss: 1.4019 - val_accuracy: 0.4954
Epoch 2/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.3608 - accuracy: 0.5127 - val_loss: 1.2994 - val_accuracy: 0.5324
Epoch 3/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.2523 - accuracy: 0.5518 - val_loss: 1.2891 - val_accuracy: 0.5459
Epoch 4/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.1785 - accuracy: 0.5777 - val_loss: 1.2377 - val_accuracy: 0.5627
Epoch 5/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.1089 - accuracy: 0.6059 - val_loss: 1.1712 - val_accuracy: 0.5824
Epoch 6/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.0560 - accuracy: 0.6232 - val_loss: 1.1531 - val_accuracy: 0.5927
Epoch 7/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.0041 - accuracy: 0.6425 - val_loss: 1.1712 - val_accuracy: 0.5954
Epoch 8/10
1563/1563 [==============================] - 7s 5ms/step - loss: 0.9602 - accuracy: 0.6577 - val_loss: 1.1491 - val_accuracy: 0.6038
Epoch 9/10
1563/1563 [==============================] - 7s 5ms/step - loss: 0.9127 - accuracy: 0.6757 - val_loss: 1.1410 - val_accuracy: 0.6019
Epoch 10/10
1563/1563 [==============================] - 7s 5ms/step - loss: 0.8756 - accuracy: 0.6869 - val_loss: 1.1523 - val_accuracy: 0.6041
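
For a quick numerical comparison of the two runs (a small addition using the History objects returned by fit above):

# Final-epoch validation accuracy for each model, taken from the runs above.
print('Regular ConvNet:   ', reg_history.history['val_accuracy'][-1])
print('WeightNorm ConvNet:', wn_history.history['val_accuracy'][-1])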

reg_accuracy = reg_history.history['accuracy']
wn_accuracy = wn_history.history['accuracy']

plt.plot(np.arange(1, epochs + 1), reg_accuracy,
         color='red', label='Regular ConvNet')

plt.plot(np.arange(1, epochs + 1), wn_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Training accuracy')
plt.legend()
plt.grid(True)
plt.show()

(Figure: training accuracy per epoch for the Regular and WeightNorm ConvNets.)
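
Validation accuracy can be compared the same way; the following is a minimal variant of the plot above, not in the original notebook:

reg_val_accuracy = reg_history.history['val_accuracy']
wn_val_accuracy = wn_history.history['val_accuracy']

plt.plot(np.arange(1, epochs + 1), reg_val_accuracy,
         color='red', label='Regular ConvNet')
plt.plot(np.arange(1, epochs + 1), wn_val_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Validation Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Validation accuracy')
plt.legend()
plt.grid(True)
plt.show()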