
TensorFlow Addons Layers: WeightNormalization


Overview

This notebook demonstrates how to use the WeightNormalization layer from TensorFlow Addons and how it can improve convergence.

WeightNormalization

Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks

Tim Salimans, Diederik P. Kingma (2016)

By reparameterizing the weights in this way we improve the conditioning of the optimization problem and we speed up convergence of stochastic gradient descent. Our reparameterization is inspired by batch normalization but does not introduce any dependencies between the examples in a minibatch. This means that our method can also be applied successfully to recurrent models such as LSTMs and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited. Although our method is much simpler, it still provides much of the speed-up of full batch normalization. In addition, the computational overhead of our method is lower, permitting more optimization steps to be taken in the same amount of time.

https://arxiv.org/abs/1602.07868
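
Concretely, the paper reparameterizes each weight vector w in terms of a direction v and a scalar magnitude g, computing w = g * v / ||v||, so the optimizer can adjust direction and scale independently. A minimal NumPy sketch of the idea (an illustration, not the tfa implementation):

import numpy as np

# Weight normalization: w = g * v / ||v||.
# The direction (v / ||v||) and the magnitude (g) become separate
# trainable parameters, which improves the conditioning of the
# optimization problem.
v = np.random.randn(5)           # unconstrained direction parameter
g = 2.0                          # learned scalar magnitude

w = g * v / np.linalg.norm(v)    # effective weight vector

print(np.linalg.norm(w))         # ||w|| equals g by construction: 2.0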



Setup

!pip install -q --no-deps tensorflow-addons~=0.6

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt

# Hyperparameters
batch_size = 32
epochs = 10
num_classes = 10

Build Models

# Standard ConvNet
reg_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])
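
WeightNormalization is a layer wrapper, so it can be applied to any Dense or Conv layer as shown above. A quick standalone check of the extra parameters it adds (the layer sizes here are illustrative):

# Wrap a single layer; calling it once triggers the build and the
# data-dependent initialization, after which the wrapper exposes the
# direction weights (v) and the learned magnitude (g).
wrapped = tfa.layers.WeightNormalization(tf.keras.layers.Dense(3))
_ = wrapped(tf.random.normal((2, 4)))
for w in wrapped.weights:
    print(w.name, w.shape)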

Load Data

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
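
A quick sanity check on the resulting shapes (CIFAR-10 images are 32x32 RGB; labels are now one-hot vectors):

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)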

Train Models

reg_model.compile(optimizer='adam', 
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

reg_history = reg_model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test),
                            shuffle=True)
Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 9s 179us/sample - loss: 1.6500 - accuracy: 0.3975 - val_loss: 1.4637 - val_accuracy: 0.4771
Epoch 2/10
50000/50000 [==============================] - 6s 117us/sample - loss: 1.4010 - accuracy: 0.4943 - val_loss: 1.3549 - val_accuracy: 0.5148
Epoch 3/10
50000/50000 [==============================] - 6s 123us/sample - loss: 1.2747 - accuracy: 0.5428 - val_loss: 1.2682 - val_accuracy: 0.5464
Epoch 4/10
50000/50000 [==============================] - 7s 139us/sample - loss: 1.1879 - accuracy: 0.5759 - val_loss: 1.2014 - val_accuracy: 0.5729
Epoch 5/10
50000/50000 [==============================] - 6s 120us/sample - loss: 1.1257 - accuracy: 0.6001 - val_loss: 1.1686 - val_accuracy: 0.5922
Epoch 6/10
50000/50000 [==============================] - 6s 119us/sample - loss: 1.0697 - accuracy: 0.6210 - val_loss: 1.1347 - val_accuracy: 0.6025
Epoch 7/10
50000/50000 [==============================] - 6s 120us/sample - loss: 1.0278 - accuracy: 0.6335 - val_loss: 1.1663 - val_accuracy: 0.5902
Epoch 8/10
50000/50000 [==============================] - 6s 119us/sample - loss: 0.9841 - accuracy: 0.6502 - val_loss: 1.1561 - val_accuracy: 0.5952
Epoch 9/10
50000/50000 [==============================] - 6s 120us/sample - loss: 0.9488 - accuracy: 0.6623 - val_loss: 1.1646 - val_accuracy: 0.6040
Epoch 10/10
50000/50000 [==============================] - 6s 124us/sample - loss: 0.9162 - accuracy: 0.6734 - val_loss: 1.1418 - val_accuracy: 0.6050
wn_model.compile(optimizer='adam', 
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])

wn_history = wn_model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_test, y_test),
                          shuffle=True)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_addons/layers/wrappers.py:84: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Train on 50000 samples, validate on 10000 samples
Epoch 1/10
50000/50000 [==============================] - 15s 292us/sample - loss: 1.6051 - accuracy: 0.4168 - val_loss: 1.3976 - val_accuracy: 0.4992
Epoch 2/10
50000/50000 [==============================] - 11s 225us/sample - loss: 1.3392 - accuracy: 0.5202 - val_loss: 1.2541 - val_accuracy: 0.5550
Epoch 3/10
50000/50000 [==============================] - 11s 223us/sample - loss: 1.2187 - accuracy: 0.5668 - val_loss: 1.2583 - val_accuracy: 0.5502
Epoch 4/10
50000/50000 [==============================] - 11s 224us/sample - loss: 1.1398 - accuracy: 0.5977 - val_loss: 1.1849 - val_accuracy: 0.5850
Epoch 5/10
50000/50000 [==============================] - 11s 225us/sample - loss: 1.0731 - accuracy: 0.6203 - val_loss: 1.1568 - val_accuracy: 0.5931
Epoch 6/10
50000/50000 [==============================] - 11s 214us/sample - loss: 1.0216 - accuracy: 0.6380 - val_loss: 1.1560 - val_accuracy: 0.5972
Epoch 7/10
50000/50000 [==============================] - 10s 198us/sample - loss: 0.9723 - accuracy: 0.6572 - val_loss: 1.1362 - val_accuracy: 0.6076
Epoch 8/10
50000/50000 [==============================] - 10s 198us/sample - loss: 0.9292 - accuracy: 0.6710 - val_loss: 1.1173 - val_accuracy: 0.6165
Epoch 9/10
50000/50000 [==============================] - 10s 199us/sample - loss: 0.8874 - accuracy: 0.6856 - val_loss: 1.1088 - val_accuracy: 0.6156
Epoch 10/10
50000/50000 [==============================] - 10s 199us/sample - loss: 0.8563 - accuracy: 0.6981 - val_loss: 1.1266 - val_accuracy: 0.6202
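Before plotting, the final validation accuracies can be compared directly (exact values vary from run to run):

print('Regular ConvNet  val accuracy:', reg_history.history['val_accuracy'][-1])
print('WeightNorm ConvNet val accuracy:', wn_history.history['val_accuracy'][-1])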
reg_accuracy = reg_history.history['accuracy']
wn_accuracy = wn_history.history['accuracy']

epochs_range = np.arange(1, epochs + 1)

plt.plot(epochs_range, reg_accuracy,
         color='red', label='Regular ConvNet')
plt.plot(epochs_range, wn_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Accuracy Comparison')
plt.xlabel('Epoch')
plt.ylabel('Training accuracy')
plt.legend()
plt.grid(True)
plt.show()

[Figure: WeightNorm Accuracy Comparison — training accuracy per epoch for the regular and weight-normalized ConvNets]