tf.distribute.Strategy with Training Loops


This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset, which contains 60,000 training images and 10,000 test images, each 28 x 28 pixels.

We use a custom training loop to train the model because it gives us flexibility and greater control over training. It also makes the model and the training loop easier to debug.

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
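# This tutorial targets TensorFlow 2.0.0-alpha0; the distribution APIs used below
# (`make_dataset_iterator`, `experimental_run`) were replaced in later releases.
!pip install -q tensorflow==2.0.0-alpha0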
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.0.0-alpha0

Download the Fashion MNIST dataset

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a channel dimension: each image's shape becomes (28, 28, 1).
# The first layer of our model is a convolutional layer, which expects a 4D
# input of shape (batch_size, height, width, channels). The batch_size
# dimension is added later, when the dataset is batched.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Create a strategy to distribute the variables and the graph

How does tf.distribute.MirroredStrategy work?

  • All the variables and the model graph are replicated on each replica.
  • Input is evenly distributed across the replicas.
  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is made to the copies of the variables on each replica.
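The five steps above can be sketched in plain Python to make the synchronization concrete. This is a toy simulation, not TensorFlow code: the one-parameter linear model `y = w * x`, the squared-error loss, the learning rate, and the helper names `replica_gradient` and `mirrored_step` are all hypothetical, chosen only so the arithmetic is easy to follow.

```python
# Toy simulation of one synchronous data-parallel step (hypothetical model:
# y = w * x with squared-error loss). No TensorFlow is involved.

def replica_gradient(w, xs, ys, global_batch_size):
    """Gradient of this replica's share of the loss with respect to w.

    The per-example losses (w*x - y)**2 are summed and divided by the
    GLOBAL batch size, matching the scaling used later in this tutorial.
    """
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / global_batch_size


def mirrored_step(replica_weights, xs, ys, num_replicas, lr):
    global_batch_size = len(xs)
    per_replica = global_batch_size // num_replicas
    # Input is evenly distributed across the replicas.
    shards = [(xs[i * per_replica:(i + 1) * per_replica],
               ys[i * per_replica:(i + 1) * per_replica])
              for i in range(num_replicas)]
    # Each replica computes a gradient on its own shard with its own copy of w.
    grads = [replica_gradient(w, sx, sy, global_batch_size)
             for w, (sx, sy) in zip(replica_weights, shards)]
    # Gradients are synced across the replicas by summing them.
    total = sum(grads)
    # The same update is applied to every replica's copy of the variable.
    return [w - lr * total for w in replica_weights]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]          # targets generated with w = 2
replicas = [0.5] * 4                # one variable mirrored on 4 replicas
replicas = mirrored_step(replicas, xs, ys, num_replicas=4, lr=0.01)
print(replicas)
```

Because every replica applies the same summed gradient, the copies never drift apart, and with the global-batch scaling the update equals a single full-batch gradient step on one device.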
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Set up the input pipeline

If a model is trained on multiple GPUs, increase the batch size accordingly so that the extra computing power is used effectively. The learning rate may also need to be retuned for the larger batch.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10
train_steps_per_epoch = len(train_images) // GLOBAL_BATCH_SIZE
test_steps_per_epoch = len(test_images) // GLOBAL_BATCH_SIZE
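As a worked example of the settings above, the sketch below computes the global batch size and steps per epoch for a few replica counts (the counts other than 1 are hypothetical; this notebook ran with a single device). Because the step counts use floor division, the examples in a final partial batch are never visited by the train and test loops.

```python
# Worked arithmetic for the batch-size settings in this tutorial.
NUM_TRAIN, NUM_TEST = 60000, 10000
BATCH_SIZE_PER_REPLICA = 64


def batch_plan(num_replicas):
    global_batch = BATCH_SIZE_PER_REPLICA * num_replicas
    train_steps = NUM_TRAIN // global_batch
    test_steps = NUM_TEST // global_batch
    # Examples left in the final, partial batch (skipped by the step loops).
    skipped_train = NUM_TRAIN - train_steps * global_batch
    skipped_test = NUM_TEST - test_steps * global_batch
    return global_batch, train_steps, test_steps, skipped_train, skipped_test


for n in (1, 2, 4, 8):
    print(n, batch_plan(n))
```

With one replica, the in-loop test pass therefore covers 9,984 of the 10,000 test images, while the restore-and-evaluate loop at the end of the tutorial iterates over the whole test dataset, including the last partial batch.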

Create the iterators inside a strategy.scope:

with strategy.scope():
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
  train_iterator = strategy.make_dataset_iterator(train_dataset)

  test_dataset = tf.data.Dataset.from_tensor_slices(
      (test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
  test_iterator = strategy.make_dataset_iterator(test_dataset)

Model Creation

Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
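To see what `create_model` produces for a 28 x 28 x 1 input, the sketch below traces the output shapes and parameter counts by hand. It assumes the Keras defaults the model relies on: `'valid'` padding and stride 1 for `Conv2D`, and a 2 x 2 pool with stride 2 for `MaxPooling2D`.

```python
# Hand-computed shapes and parameter counts for create_model() on 28x28x1 input.

def conv_valid(size, kernel):       # 'valid' padding, stride 1
    return size - kernel + 1


def pool2(size):                    # 2x2 pool, stride 2, 'valid' padding
    return (size - 2) // 2 + 1


h = conv_valid(28, 3)               # Conv2D(32, 3)  -> 26 x 26 x 32
conv1_params = 3 * 3 * 1 * 32 + 32  # kernel weights + biases
h = pool2(h)                        # MaxPooling2D   -> 13 x 13 x 32
h = conv_valid(h, 3)                # Conv2D(64, 3)  -> 11 x 11 x 64
conv2_params = 3 * 3 * 32 * 64 + 64
h = pool2(h)                        # MaxPooling2D   -> 5 x 5 x 64
flat = h * h * 64                   # Flatten        -> 1600 features
dense1_params = flat * 64 + 64      # Dense(64)
dense2_params = 64 * 10 + 10        # Dense(10)

total = conv1_params + conv2_params + dense1_params + dense2_params
print(flat, total)
```

Almost all of the parameters sit in the first `Dense` layer, because it connects every one of the 1,600 flattened features to 64 units.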

Define the loss function

Normally, on a single machine with one GPU or CPU, the loss is divided by the number of examples in the input batch.

So, how should the loss be calculated when using a tf.distribute.Strategy?

  • For example, say you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (the 4 GPUs), so each replica receives 16 examples.

  • The model on each replica does a forward pass on its input and calculates the loss. Instead of dividing the loss by the number of examples in its own input (BATCH_SIZE_PER_REPLICA = 16), each replica should divide it by GLOBAL_BATCH_SIZE (64).

Why do this?

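  • After the gradients are calculated on each replica, they are synced across the replicas by summing them. If each replica divided by its own batch size (16), the summed gradient would be 4 times the gradient of the average loss over the global batch. Dividing by GLOBAL_BATCH_SIZE (64) instead makes the sum equal to that average-loss gradient, just as on a single device.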
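The effect of the two scaling choices can be checked with a few lines of plain Python. The per-example loss values below are made up; because gradients are linear in the loss, the same factors carry over to the summed gradients.

```python
# Compare two ways of scaling each replica's loss before the cross-replica sum.
NUM_REPLICAS = 4
BATCH_SIZE_PER_REPLICA = 16
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS   # 64

# Hypothetical per-example losses for one global batch of 64 examples.
per_example = [0.1 * (i % 7 + 1) for i in range(GLOBAL_BATCH_SIZE)]
shards = [per_example[r * BATCH_SIZE_PER_REPLICA:(r + 1) * BATCH_SIZE_PER_REPLICA]
          for r in range(NUM_REPLICAS)]

# What a single device would optimize: the mean over the whole global batch.
global_mean = sum(per_example) / GLOBAL_BATCH_SIZE

# Each replica divides by the GLOBAL batch size, then the results are summed.
summed_global_scaling = sum(sum(s) / GLOBAL_BATCH_SIZE for s in shards)

# Each replica divides by its PER-REPLICA batch size, then the results are summed.
summed_local_mean = sum(sum(s) / len(s) for s in shards)

print(global_mean, summed_global_scaling, summed_local_mean)
```

The first scaling reproduces the single-device mean; the second is NUM_REPLICAS times too large, which would scale every update by the number of GPUs.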

How to do this in TensorFlow?

  • If you're writing a custom training loop, as in this tutorial, sum the per-example losses and divide the sum by GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)

  • Using tf.reduce_mean is not recommended. It divides the loss by the actual per-replica batch size, which may vary from step to step.

  • This reduction and scaling is done automatically by Keras model.compile and model.fit.

  • If you use tf.keras.losses classes (as in the example below), the loss reduction must be set explicitly to either NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because the user should think explicitly about which reduction they want, to make sure it is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because it currently divides only by the per-replica batch size and leaves the division by the number of replicas to the user, which is easy to miss. Instead, the user is asked to do the reduction explicitly.
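To see why tf.reduce_mean is risky when per-replica batch sizes vary, the sketch below computes the weight each example ends up with in the summed cross-replica loss. The uneven split of a final partial batch (3 examples on one replica, 1 on the other) is hypothetical; the point is only how the two scalings weight examples.

```python
# Effective weight of each example in the summed cross-replica loss.
GLOBAL_BATCH_SIZE = 64
replica_sizes = [3, 1]   # hypothetical uneven split of a final partial batch

# tf.reduce_mean on each replica: an example's weight is 1 / (its replica's size).
mean_weights = [1 / n for n in replica_sizes for _ in range(n)]

# Sum on each replica, then divide by GLOBAL_BATCH_SIZE: every weight is 1/64.
global_weights = [1 / GLOBAL_BATCH_SIZE for n in replica_sizes for _ in range(n)]

print(mean_weights)
print(global_weights)
```

With tf.reduce_mean, the lone example on the second replica counts three times as much as each example on the first; with the sum-and-divide-by-GLOBAL_BATCH_SIZE scaling, every example counts equally.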

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.reduce_sum(per_example_loss) * (1. / GLOBAL_BATCH_SIZE)
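For readers who want to check the arithmetic, here is a plain-Python sketch of what `compute_loss` computes on one replica. It is not the Keras implementation: `sparse_categorical_crossentropy` below is a hypothetical helper that takes probability rows (as produced by the model's softmax layer) and skips any numerical-stability clipping Keras may apply. The global batch size of 4 is also an assumption made to keep the numbers small.

```python
import math

GLOBAL_BATCH_SIZE = 4   # hypothetical: 2 replicas x 2 examples each


def sparse_categorical_crossentropy(labels, probs):
    """Per-example losses, like reduction=NONE: -log(probability of true class)."""
    return [-math.log(row[label]) for label, row in zip(labels, probs)]


def compute_loss(labels, probs):
    per_example_loss = sparse_categorical_crossentropy(labels, probs)
    # Sum this replica's losses and scale by the GLOBAL batch size.
    return sum(per_example_loss) * (1. / GLOBAL_BATCH_SIZE)


# This replica's shard: 2 examples, 3 classes.
labels = [0, 2]
probs = [[0.5, 0.25, 0.25],
         [0.1, 0.1, 0.8]]
print(compute_loss(labels, probs))
```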

Define the metrics to track loss and accuracy

These metrics track the loss and the accuracy. You can use .result() to get the accumulated statistics at any time.
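The accumulate-then-read pattern can be illustrated with a small stand-in for tf.keras.metrics.Mean. The `RunningMean` class below is hypothetical and much simpler than the real metric, but it follows the same idea: each call adds values to a running total and count, `result()` returns total / count, and `reset_states()` clears both. Passing a list (as `test_step` later does with the unreduced `t_loss`) contributes every element to the average.

```python
class RunningMean:
    """Hypothetical, simplified stand-in for tf.keras.metrics.Mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def __call__(self, values):
        # Accept a scalar or a list of per-example values.
        if not isinstance(values, (list, tuple)):
            values = [values]
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Return 0.0 before any update instead of dividing by zero.
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0


test_loss = RunningMean()
test_loss([1.0, 2.0, 3.0])   # one batch of per-example losses
test_loss([6.0])             # a second batch
print(test_loss.result())    # 3.0, i.e. (1 + 2 + 3 + 6) / 4
test_loss.reset_states()
print(test_loss.result())    # 0.0
```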

with strategy.scope():
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

Training loop

# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
with strategy.scope():
  # Train step
  def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

  # Test step
  def test_step(inputs):
    images, labels = inputs

    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)
with strategy.scope():
  # `experimental_run` replicates the provided computation and runs it
  # with the distributed input.

  @tf.function
  def distributed_train():
    return strategy.experimental_run(train_step, train_iterator)

  @tf.function
  def distributed_test():
    return strategy.experimental_run(test_step, test_iterator)

  for epoch in range(EPOCHS):
    # Note: This code is expected to change in the near future.

    # TRAIN LOOP
    # Initialize the iterator
    train_iterator.initialize()
    for _ in range(train_steps_per_epoch):
      distributed_train()

    # TEST LOOP
    test_iterator.initialize()
    for _ in range(test_steps_per_epoch):
      distributed_test()

    if epoch % 2 == 0:
      checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print(template.format(epoch + 1, train_loss.result(),
                          train_accuracy.result() * 100, test_loss.result(),
                          test_accuracy.result() * 100))

    train_loss.reset_states()
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()
Epoch 1, Loss: 0.5082187652587891, Accuracy: 81.4751205444336, Test Loss: 0.3758997917175293, Test Accuracy: 86.59855651855469
Epoch 2, Loss: 0.3360634744167328, Accuracy: 87.86519622802734, Test Loss: 0.3365756869316101, Test Accuracy: 87.80047607421875
Epoch 3, Loss: 0.2890585958957672, Accuracy: 89.42436218261719, Test Loss: 0.3115336298942566, Test Accuracy: 88.70191955566406
Epoch 4, Loss: 0.25779587030410767, Accuracy: 90.59664916992188, Test Loss: 0.27641329169273376, Test Accuracy: 90.05409240722656
Epoch 5, Loss: 0.2319592386484146, Accuracy: 91.5104751586914, Test Loss: 0.2744285762310028, Test Accuracy: 90.01402282714844
Epoch 6, Loss: 0.2102610170841217, Accuracy: 92.38260650634766, Test Loss: 0.2595132291316986, Test Accuracy: 90.27444458007812
Epoch 7, Loss: 0.1924610137939453, Accuracy: 92.91121673583984, Test Loss: 0.2634430527687073, Test Accuracy: 90.36457824707031
Epoch 8, Loss: 0.1769772171974182, Accuracy: 93.5132064819336, Test Loss: 0.24882595241069794, Test Accuracy: 91.025634765625
Epoch 9, Loss: 0.16251613199710846, Accuracy: 94.0634994506836, Test Loss: 0.2549446225166321, Test Accuracy: 90.96554565429688
Epoch 10, Loss: 0.1479620635509491, Accuracy: 94.51707458496094, Test Loss: 0.2599133551120758, Test Accuracy: 90.99559020996094

Restore the latest checkpoint and test

A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.
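Which checkpoint counts as the latest? The training loop calls `checkpoint.save` only when `epoch % 2 == 0`, where `epoch` is the zero-based loop index. The sketch below lists the one-based epochs after which a checkpoint is written.

```python
EPOCHS = 10

# Mirror the training loop: `epoch` is zero-based; the printed epoch is epoch + 1.
saved_after = [epoch + 1 for epoch in range(EPOCHS) if epoch % 2 == 0]
print(saved_after)                  # [1, 3, 5, 7, 9]
print("latest:", saved_after[-1])   # latest: 9
```

So `tf.train.latest_checkpoint` restores the weights saved after epoch 9, not epoch 10, which is consistent with the restored accuracy below being close to the epoch 9 test accuracy.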

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))
Accuracy after restoring the saved model without strategy: 90.97000122070312

Next Steps

Try out the new tf.distribute.Strategy API on your models.