tf.distribute.Strategy with Training Loops

This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset, which contains 60,000 training images and 10,000 test images, each of size 28 x 28.

We use a custom training loop to train the model because it gives us flexibility and greater control over training. It also makes it easier to debug the model and the training loop.

from __future__ import absolute_import, division, print_function

# Import TensorFlow
!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.0.0-alpha0

Download the Fashion MNIST dataset

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional 
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

Create a strategy to distribute the variables and the graph

How does tf.distribute.MirroredStrategy work?

  • All the variables and the model graph are replicated on the replicas.
  • Input is evenly distributed across the replicas.
  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is made to the copies of the variables on each replica.
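
For example, to mirror across a specific set of devices, you can pass them to the constructor explicitly. This is a sketch only, assuming two GPUs are available; the tutorial itself relies on auto-detection below.

# Sketch only: pin the strategy to two hypothetical GPUs instead of
# letting it auto-detect the available devices.
explicit_strategy = tf.distribute.MirroredStrategy(
    devices=["/gpu:0", "/gpu:1"])
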
# If the list of devices is not specified in the 
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
WARNING: Logging before flag parsing goes to stderr.
W0307 18:27:09.510527 140515314800384 cross_device_ops.py:1111] Not all devices in `tf.distribute.Strategy` are visible to TensorFlow.
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Setup input pipeline

If a model is trained on multiple GPUs, the batch size should be increased accordingly to make effective use of the extra computing power, and the learning rate should be tuned to match.
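
A common heuristic (an assumption on our part; this tutorial keeps the optimizer defaults) is to scale the learning rate linearly with the number of replicas:

# Sketch only: linear learning-rate scaling. BASE_LEARNING_RATE is a
# hypothetical name, not used elsewhere in this tutorial.
BASE_LEARNING_RATE = 0.001
scaled_learning_rate = BASE_LEARNING_RATE * strategy.num_replicas_in_sync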

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10
train_steps_per_epoch = len(train_images) // BATCH_SIZE
test_steps_per_epoch = len(test_images) // BATCH_SIZE

strategy.experimental_make_numpy_iterator creates an iterator that evenly distributes the data across all the replicas.

This is more efficient than using tf.data.Dataset.from_tensor_slices directly since it avoids recording the training data as a constant in the graph.

If you are not using strategy.experimental_make_numpy_iterator, create the iterator inside a strategy.scope, like this:

train_dataset = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train_iterator = strategy.make_dataset_iterator(train_dataset)

with strategy.scope():
  train_iterator = strategy.experimental_make_numpy_iterator(
      (train_images, train_labels), BATCH_SIZE, shuffle=BUFFER_SIZE)

  test_iterator = strategy.experimental_make_numpy_iterator(
      (test_images, test_labels), BATCH_SIZE, shuffle=None)

Model Creation

Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this; a sketch of that approach follows the function below.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
    ])

  return model
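
For comparison, here is a sketch of the same architecture written with the Model Subclassing API (illustrative only; the rest of the tutorial uses create_model above):

class MyModel(tf.keras.Model):
  # Sketch of create_model rewritten as a subclass of tf.keras.Model.
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
    self.pool1 = tf.keras.layers.MaxPooling2D()
    self.conv2 = tf.keras.layers.Conv2D(64, 3, activation='relu')
    self.pool2 = tf.keras.layers.MaxPooling2D()
    self.flatten = tf.keras.layers.Flatten()
    self.d1 = tf.keras.layers.Dense(64, activation='relu')
    self.d2 = tf.keras.layers.Dense(10, activation='softmax')

  def call(self, x):
    x = self.pool1(self.conv1(x))
    x = self.pool2(self.conv2(x))
    x = self.flatten(x)
    return self.d2(self.d1(x))
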
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Define the loss function

Normally, on a single machine with one GPU/CPU, the loss is divided by the number of examples in the batch of input.

So, how is the loss calculated when using a tf.distribute.Strategy?

  • For example, let's say you have 4 GPUs and a batch size of 64. One batch of input is distributed across the replicas (4 GPUs), with each replica getting an input of size 16.

  • The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (16), the loss is divided by the global input size (64).

Why is this done?

  • This is done because after the gradients are calculated on each replica, they are synced across the replicas by summing them.

How do you handle this in TensorFlow?

  • The loss classes in tf.keras.losses handle this automatically.

  • If you are writing a custom loss function, don't implement it with tf.reduce_mean (which divides by the local, per-replica batch size); instead, divide the sum by the global batch size, as shown in the sketch below: scale_loss = tf.reduce_sum(loss) * (1. / global_batch_size)
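
As a concrete illustration, here is a minimal sketch of such a custom loss (compute_scaled_loss and global_batch_size are illustrative names; the tutorial itself uses the built-in loss below, which already scales correctly):

def compute_scaled_loss(labels, predictions, global_batch_size):
  # Per-example cross-entropy, shape [per_replica_batch_size].
  per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
      labels, predictions)
  # Divide the sum by the GLOBAL batch size so that summing gradients
  # across replicas reproduces the single-machine mean loss.
  return tf.reduce_sum(per_example_loss) * (1. / global_batch_size)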

with strategy.scope():
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

Define the metrics to track loss and accuracy

These metrics track the loss and the accuracy. You can use .result() to get the accumulated statistics at any time.
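
For example, a Mean metric accumulates values across calls until it is reset (a standalone sketch, separate from the training code):

# Standalone illustration of the update / result / reset cycle.
m = tf.keras.metrics.Mean()
m(2.0)
m(4.0)
print(m.result().numpy())  # 3.0
m.reset_states()           # start accumulating from scratch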

with strategy.scope():
  train_loss = tf.keras.metrics.Mean(name='train_loss')
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

Training loop

# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()
  
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
with strategy.scope():
  # Train step
  def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = loss_object(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

  # Test step
  def test_step(inputs):
    images, labels = inputs

    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)
with strategy.scope():
  # `experimental_run` replicates the provided computation and runs it 
  # with the distributed input.
  
  @tf.function
  def distributed_train():
    return strategy.experimental_run(train_step, train_iterator)
  
  @tf.function
  def distributed_test():
    return strategy.experimental_run(test_step, test_iterator)
    
  for epoch in range(EPOCHS):
    # Note: This code is expected to change in the near future.
    
    # TRAIN LOOP
    # Initialize the iterator
    train_iterator.initialize()
    for _ in range(train_steps_per_epoch):
      distributed_train()

    # TEST LOOP
    test_iterator.initialize()
    for _ in range(test_steps_per_epoch):
      distributed_test()
    
    if epoch % 2 == 0:
      checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print (template.format(epoch+1, train_loss.result(), 
                           train_accuracy.result()*100, test_loss.result(), 
                           test_accuracy.result()*100))
    
    train_loss.reset_states()
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()
Epoch 1, Loss: 0.507440447807312, Accuracy: 81.790283203125, Test Loss: 0.3900431990623474, Test Accuracy: 85.78726196289062
Epoch 2, Loss: 0.3392734229564667, Accuracy: 87.82183837890625, Test Loss: 0.32898980379104614, Test Accuracy: 88.37139892578125
Epoch 3, Loss: 0.29349881410598755, Accuracy: 89.33931732177734, Test Loss: 0.3083241283893585, Test Accuracy: 89.03245544433594
Epoch 4, Loss: 0.2620709538459778, Accuracy: 90.48326110839844, Test Loss: 0.30253908038139343, Test Accuracy: 89.02243041992188
Epoch 5, Loss: 0.2398786097764969, Accuracy: 91.29369354248047, Test Loss: 0.2742038369178772, Test Accuracy: 89.91386413574219
Epoch 6, Loss: 0.2202199250459671, Accuracy: 91.83898162841797, Test Loss: 0.2755086123943329, Test Accuracy: 90.16426086425781
Epoch 7, Loss: 0.203200101852417, Accuracy: 92.42762756347656, Test Loss: 0.2782629728317261, Test Accuracy: 90.12419891357422
Epoch 8, Loss: 0.18854723870754242, Accuracy: 92.9345703125, Test Loss: 0.30404072999954224, Test Accuracy: 89.443115234375
Epoch 9, Loss: 0.17325672507286072, Accuracy: 93.5082015991211, Test Loss: 0.2459276169538498, Test Accuracy: 91.37620544433594
Epoch 10, Loss: 0.15984027087688446, Accuracy: 94.00846862792969, Test Loss: 0.2482466846704483, Test Accuracy: 91.43629455566406

Restore the latest checkpoint and test

A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.
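
To restore it with a strategy, the same restore would simply happen under strategy.scope() (a sketch under that assumption; the code below performs the without-strategy variant):

# Sketch only: restoring the checkpoint *with* a strategy. This variant
# is not run in this tutorial.
with strategy.scope():
  scoped_model = create_model()
  scoped_optimizer = tf.keras.optimizers.Adam()
  scoped_checkpoint = tf.train.Checkpoint(
      optimizer=scoped_optimizer, model=scoped_model)
  scoped_checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))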

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.37999725341797

Next Steps

Try out the new tf.distribute.Strategy API on your models.