
tf.distribute.Strategy with training loops


This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset. The Fashion MNIST dataset contains 60,000 training images and 10,000 test images, each of size 28 x 28.

We are using custom training loops to train our model because they give us flexibility and greater control over training. Moreover, they make it easier to debug the model and the training loop.

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

# Helper libraries
import numpy as np
import os


Download the Fashion MNIST dataset

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Create a strategy to distribute the variables and the graph

How does tf.distribute.MirroredStrategy work?

  • All the variables and the model graph are replicated on each replica.
  • Input is evenly distributed across the replicas.
  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is made to the copies of the variables on each replica.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
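The synchronization steps listed earlier in this section can be mimicked without TensorFlow. The following NumPy sketch is only an illustration under simplifying assumptions: a hypothetical linear model with a squared-error loss stands in for the CNN, and a Python list of arrays stands in for the mirrored variables. It checks that when each replica computes gradients on its own shard (with the loss divided by the global batch size), summing those gradients and applying the same update everywhere keeps every copy identical and matches a single-device step.

```python
import numpy as np

# Hypothetical toy setup (not TensorFlow): a linear model y = x @ w with a
# squared-error loss, "mirrored" on NUM_REPLICAS replicas.
NUM_REPLICAS = 4
GLOBAL_BATCH = 64
LEARNING_RATE = 0.1

rng = np.random.default_rng(0)
x = rng.normal(size=(GLOBAL_BATCH, 3)).astype(np.float32)
y = rng.normal(size=(GLOBAL_BATCH,)).astype(np.float32)

# Every replica starts with an identical copy of the variables.
w_init = np.zeros(3, dtype=np.float32)
replica_weights = [w_init.copy() for _ in range(NUM_REPLICAS)]

# The input batch is split evenly across the replicas (16 examples each).
x_shards = np.split(x, NUM_REPLICAS)
y_shards = np.split(y, NUM_REPLICAS)


def local_gradient(w, xs, ys):
    """Gradient of sum((xs @ w - ys) ** 2) / GLOBAL_BATCH with respect to w."""
    residual = xs @ w - ys
    return 2.0 * xs.T @ residual / GLOBAL_BATCH


# 1. Each replica computes gradients for the input it received.
local_grads = [local_gradient(w, xs, ys)
               for w, xs, ys in zip(replica_weights, x_shards, y_shards)]
# 2. The gradients are synced by summing them across replicas.
synced_grad = np.sum(local_grads, axis=0)
# 3. The same update is applied to every replica's copy of the variables.
replica_weights = [w - LEARNING_RATE * synced_grad for w in replica_weights]

# Reference: one gradient step on the whole batch on a single device.
single_device_w = w_init - LEARNING_RATE * local_gradient(w_init, x, y)
```

Because every copy receives the same summed gradient, the copies never drift apart, and the result agrees with the single-device update up to floating-point rounding.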

Set up the input pipeline

Define the shuffle buffer size and the batch sizes. The global batch size is the per-replica batch size multiplied by the number of replicas in sync, so each replica still processes BATCH_SIZE_PER_REPLICA examples per step.

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10


Create the datasets and distribute them:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
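As a rough back-of-the-envelope check, the sketch below computes how many steps one epoch takes and how a global batch might be divided among replicas. It assumes a hypothetical machine with 4 replicas and uses np.array_split as a stand-in for sharding; it does not reproduce TensorFlow's exact sharding rule. It also shows that the last batch of the epoch is smaller than the others, because 60,000 is not a multiple of the global batch size.

```python
import numpy as np

NUM_TRAIN_EXAMPLES = 60000   # training images in Fashion MNIST
BATCH_SIZE_PER_REPLICA = 64
NUM_REPLICAS = 4             # hypothetical: e.g. one machine with 4 GPUs
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS

# Number of full global batches, and the size of the final partial batch.
full_batches, last_batch_size = divmod(NUM_TRAIN_EXAMPLES, GLOBAL_BATCH_SIZE)
steps_per_epoch = full_batches + (1 if last_batch_size else 0)


def shard_sizes(batch_size, num_replicas):
    """Sizes of the shards when a batch is split as evenly as possible.

    np.array_split is only a stand-in for how a distributed dataset divides
    each global batch; it is not TensorFlow's actual sharding logic.
    """
    return [len(s) for s in np.array_split(np.arange(batch_size), num_replicas)]


print(GLOBAL_BATCH_SIZE, steps_per_epoch)            # 256 235
print(shard_sizes(GLOBAL_BATCH_SIZE, NUM_REPLICAS))  # [64, 64, 64, 64]
print(shard_sizes(last_batch_size, NUM_REPLICAS))    # [24, 24, 24, 24]
```

On the final step each replica sees 24 examples rather than 64, which is one reason the per-replica batch size should not be used as the loss denominator.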

Create the model

Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      # Flatten the feature maps so the Dense layers output one
      # 10-way prediction per image.
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Define the loss function

Normally, on a single machine with 1 GPU/CPU, loss is divided by the number of examples in the batch of input.

So, how should the loss be calculated when using a tf.distribute.Strategy?

  • For example, say you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (4 GPUs), and each replica gets an input of size 16.

  • The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).

Why do this?

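  • This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by summing them. Dividing by GLOBAL_BATCH_SIZE makes that sum equal to the gradient of the average loss over the whole global batch, which is what a single device would compute.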
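A small numeric check of this argument, using NumPy with made-up per-example losses (hypothetical values, 4 replicas, a global batch of 64). Because gradients are linear in the loss, the same factor carries over to the summed gradients.

```python
import numpy as np

GLOBAL_BATCH_SIZE = 64
NUM_REPLICAS = 4

# Hypothetical per-example losses for one global batch.
rng = np.random.default_rng(42)
per_example_loss = rng.uniform(0.0, 2.0, size=GLOBAL_BATCH_SIZE)

# What one device would compute: the mean loss over the whole global batch.
single_device_loss = per_example_loss.mean()

# Each replica receives 16 of the 64 examples.
shards = np.split(per_example_loss, NUM_REPLICAS)

# Scaling by the GLOBAL batch size: the cross-replica sum equals the global mean.
global_scaled_sum = sum(shard.sum() / GLOBAL_BATCH_SIZE for shard in shards)

# Averaging per replica instead: the cross-replica sum is NUM_REPLICAS times too big.
local_mean_sum = sum(shard.mean() for shard in shards)
```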

How to do this in TensorFlow?

  • If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE). Alternatively, use tf.nn.compute_average_loss, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.

  • If you are using regularization losses in your model, you need to divide the regularization loss by the number of replicas. You can do this with the tf.nn.scale_regularization_loss function.

  • Using tf.reduce_mean is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.

  • This reduction and scaling is done automatically in Keras model.compile and model.fit.

  • If using tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified to be one of NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because the user should think explicitly about which reduction they want, to make sure it is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by the per-replica batch size and leave the division by the number of replicas to the user, which is easy to miss. So instead we ask users to do the reduction explicitly themselves.
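The arithmetic behind these rules can be sketched in NumPy. average_loss_sketch and scale_regularization_sketch are hypothetical helpers that mirror what tf.nn.compute_average_loss and tf.nn.scale_regularization_loss compute; they are not the TensorFlow functions themselves. The example also shows why dividing by the per-replica batch size, as SUM_OVER_BATCH_SIZE would, leaves the result too large by a factor of the number of replicas.

```python
import numpy as np

GLOBAL_BATCH_SIZE = 64
NUM_REPLICAS = 4


def average_loss_sketch(per_example_loss, global_batch_size, sample_weight=None):
    """Hypothetical stand-in: weighted sum of losses divided by the global batch size."""
    losses = np.asarray(per_example_loss, dtype=np.float64)
    if sample_weight is not None:
        losses = losses * sample_weight
    return losses.sum() / global_batch_size


def scale_regularization_sketch(regularization_losses, num_replicas):
    """Hypothetical stand-in: sum of regularization losses divided by replica count."""
    return np.sum(regularization_losses) / num_replicas


# One replica's share of the global batch: 16 made-up per-example losses.
replica_losses = np.full(GLOBAL_BATCH_SIZE // NUM_REPLICAS, 0.5)

# SUM_OVER_BATCH_SIZE-style reduction: divides by the per-replica size (16).
sum_over_batch = replica_losses.sum() / len(replica_losses)
# Dividing by the global batch size gives this replica's share of the global mean.
scaled = average_loss_sketch(replica_losses, GLOBAL_BATCH_SIZE)
# Optional sample weights (here: only the first of each pair counts).
weighted = average_loss_sketch(replica_losses, GLOBAL_BATCH_SIZE,
                               sample_weight=np.array([1.0, 0.0] * 8))

# A regularization term is identical on every replica, so it is divided by the
# number of replicas; the cross-replica sum then restores the original total.
reg = scale_regularization_sketch([0.04, 0.08], NUM_REPLICAS)
```

Here sum_over_batch is 0.5 while scaled is 0.125, a factor of NUM_REPLICAS apart.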

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)
  # or loss_fn = tf.keras.losses.sparse_categorical_crossentropy
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

Define the metrics to track loss and accuracy

These metrics track the test loss and training and test accuracy. You can use .result() to get the accumulated statistics at any time.
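To illustrate what "accumulated statistics" means, here is a minimal pure-Python stand-in for a streaming mean metric. RunningMean is a hypothetical class, not the Keras implementation; it only mirrors the update_state / result / reset_states pattern that the training loop relies on.

```python
class RunningMean:
    """Hypothetical streaming mean: keeps a running total and count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate every value seen since the last reset.
        for v in values:
            self.total += v
            self.count += 1

    def result(self):
        # Mean of everything accumulated so far (0.0 if nothing was seen).
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        # Start a fresh accumulation, e.g. at the beginning of an epoch.
        self.total, self.count = 0.0, 0


metric = RunningMean()
metric.update_state([1.0, 2.0, 3.0])
print(metric.result())  # 2.0
metric.update_state([4.0])
print(metric.result())  # 2.5
```

Calling result() at any point reports the statistic over everything seen since the last reset, which is why the training loop resets its metrics between epochs.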

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

Training loop

# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
with strategy.scope():
  def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_accuracy.update_state(labels, predictions)
    return loss 

  def test_step(inputs):
    images, labels = inputs

    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)
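with strategy.scope():
  # `experimental_run_v2` replicates the provided computation and runs it
  # with the distributed input.
  @tf.function
  def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.experimental_run_v2(train_step,
                                                      args=(dataset_inputs,))
    # Sum the already-scaled per-replica losses into the global-batch loss.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)

  @tf.function
  def distributed_test_step(dataset_inputs):
    return strategy.experimental_run_v2(test_step, args=(dataset_inputs,))

  for epoch in range(EPOCHS):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for x in train_dist_dataset:
      total_loss += distributed_train_step(x)
      num_batches += 1
    train_loss = total_loss / num_batches

    # TEST LOOP
    for x in test_dist_dataset:
      distributed_test_step(x)

    # Save a checkpoint every other epoch.
    if epoch % 2 == 0:
      checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print (template.format(epoch+1, train_loss,
                           train_accuracy.result()*100, test_loss.result(),
                           test_accuracy.result()*100))

    # Reset the metrics so that each epoch reports its own statistics.
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()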
Epoch 1, Loss: 0.5041059851646423, Accuracy: 81.62000274658203, Test Loss: 0.3745802938938141, Test Accuracy: 86.62999725341797
Epoch 2, Loss: 0.33168354630470276, Accuracy: 87.9566650390625, Test Loss: 0.3243101239204407, Test Accuracy: 88.43000030517578
Epoch 3, Loss: 0.2855560779571533, Accuracy: 89.45500183105469, Test Loss: 0.3009592890739441, Test Accuracy: 89.20999908447266
Epoch 4, Loss: 0.2551109194755554, Accuracy: 90.67666625976562, Test Loss: 0.28267067670822144, Test Accuracy: 89.96000671386719
Epoch 5, Loss: 0.2314356565475464, Accuracy: 91.56666564941406, Test Loss: 0.27685269713401794, Test Accuracy: 90.25
Epoch 6, Loss: 0.21108976006507874, Accuracy: 92.28666687011719, Test Loss: 0.27175289392471313, Test Accuracy: 90.41999816894531
Epoch 7, Loss: 0.19321465492248535, Accuracy: 92.99166870117188, Test Loss: 0.2708509564399719, Test Accuracy: 90.68000030517578
Epoch 8, Loss: 0.17691724002361298, Accuracy: 93.5999984741211, Test Loss: 0.27721959352493286, Test Accuracy: 90.6500015258789
Epoch 9, Loss: 0.16137677431106567, Accuracy: 94.27499389648438, Test Loss: 0.2878393232822418, Test Accuracy: 90.68000030517578
Epoch 10, Loss: 0.14655005931854248, Accuracy: 94.79833221435547, Test Loss: 0.29786813259124756, Test Accuracy: 90.72000122070312

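Things to note in the example above:

  • We iterate over train_dist_dataset and test_dist_dataset using a for x in ... construct.
  • The scaled loss is the return value of distributed_train_step. It is aggregated across replicas with strategy.reduce and then across batches by summing the values returned by strategy.reduce.
  • The tf.keras metrics are updated inside train_step and test_step, which strategy.experimental_run_v2 executes on each replica.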

Restore the latest checkpoint and test

A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
# Restore the most recent checkpoint written during training.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 90.68000030517578
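The restored accuracy is consistent with the epoch 9 test accuracy rather than epoch 10's (90.72), which follows from the checkpoint schedule. Assuming EPOCHS = 10 and a save whenever epoch % 2 == 0 (with epoch counted from 0, while the log prints epoch + 1), this short calculation lists which logged epochs were saved:

```python
EPOCHS = 10

# Saves happen when epoch % 2 == 0 (0-based), but the log prints epoch + 1.
saved_epochs = [epoch + 1 for epoch in range(EPOCHS) if epoch % 2 == 0]
latest_saved_epoch = saved_epochs[-1]

print(saved_epochs)        # [1, 3, 5, 7, 9]
print(latest_saved_epoch)  # 9
```

So tf.train.latest_checkpoint points at the weights from the epoch logged as "Epoch 9", not the final epoch.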

Alternate ways of iterating over a dataset

Using iterators

If you want to iterate over a given number of steps rather than through the entire dataset, you can create an iterator using the iter call and explicitly call next on the iterator. You can choose to iterate over the dataset both inside and outside the tf.function. Here is a small snippet demonstrating iteration over the dataset outside the tf.function using an iterator.

with strategy.scope():
  for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    train_iter = iter(train_dist_dataset)

    # Run a fixed number of steps instead of a full pass over the dataset.
    for _ in range(10):
      total_loss += distributed_train_step(next(train_iter))
      num_batches += 1
    average_train_loss = total_loss / num_batches

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
    train_accuracy.reset_states()
Epoch 1, Loss: 0.10446622222661972, Accuracy: 96.09375
Epoch 2, Loss: 0.0791325718164444, Accuracy: 97.65625
Epoch 3, Loss: 0.04644235968589783, Accuracy: 98.75
Epoch 4, Loss: 0.03152443468570709, Accuracy: 99.6875
Epoch 5, Loss: 0.023476367816329002, Accuracy: 99.6875
Epoch 6, Loss: 0.017614154145121574, Accuracy: 100.0
Epoch 7, Loss: 0.01347125880420208, Accuracy: 100.0
Epoch 8, Loss: 0.01140991784632206, Accuracy: 100.0
Epoch 9, Loss: 0.009441525675356388, Accuracy: 100.0
Epoch 10, Loss: 0.008261207491159439, Accuracy: 100.0
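For readers less familiar with Python iterators, the pattern used above (call iter once, then next a fixed number of times) can be shown on an ordinary list. take_steps and toy_dataset are hypothetical names used only for this illustration; a plain Python list iterator raises StopIteration when exhausted, which the sketch catches to stop early.

```python
def take_steps(dataset, num_steps):
    """Pull up to num_steps items from dataset through an explicit iterator."""
    iterator = iter(dataset)
    batches = []
    for _ in range(num_steps):
        try:
            batches.append(next(iterator))
        except StopIteration:  # the data ran out before num_steps
            break
    return batches


toy_dataset = [[0, 1], [2, 3], [4, 5], [6, 7]]  # hypothetical: 4 "batches"
print(take_steps(toy_dataset, 2))        # [[0, 1], [2, 3]]
print(len(take_steps(toy_dataset, 10)))  # 4
```

Creating a fresh iterator each epoch, as the training snippet does, restarts consumption from the beginning of the dataset.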

Iterating inside a tf.function

You can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct or by creating iterators like we did above. The example below demonstrates wrapping one epoch of training in a tf.function and iterating over train_dist_dataset inside the function.

with strategy.scope():
  @tf.function
  def distributed_train_epoch(dataset):
    total_loss = 0.0
    num_batches = 0
    for x in dataset:
      per_replica_losses = strategy.experimental_run_v2(train_step,
                                                        args=(x,))
      total_loss += strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
      num_batches += 1
    return total_loss / tf.cast(num_batches, dtype=tf.float32)

  for epoch in range(EPOCHS):
    train_loss = distributed_train_epoch(train_dist_dataset)

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

    train_accuracy.reset_states()

Epoch 1, Loss: 0.13604184985160828, Accuracy: 95.12332916259766
Epoch 2, Loss: 0.13949018716812134, Accuracy: 94.71666717529297
Epoch 3, Loss: 0.12360842525959015, Accuracy: 95.36332702636719
Epoch 4, Loss: 0.11401603370904922, Accuracy: 95.77000427246094
Epoch 5, Loss: 0.10236617922782898, Accuracy: 96.1883316040039
Epoch 6, Loss: 0.0942242443561554, Accuracy: 96.54499816894531
Epoch 7, Loss: 0.08623343706130981, Accuracy: 96.79000091552734
Epoch 8, Loss: 0.08089885115623474, Accuracy: 97.02166748046875
Epoch 9, Loss: 0.07643488049507141, Accuracy: 97.12999725341797
Epoch 10, Loss: 0.06888115406036377, Accuracy: 97.44833374023438

Tracking training loss across replicas

We do not recommend using tf.metrics.Mean to track the training loss across different replicas, because of the loss scaling computation that is carried out.

For example, suppose you run a training job with the following characteristics:

  • Two replicas
  • Two samples are processed on each replica
  • Resulting loss values: [2, 3] on one replica and [4, 5] on the other
  • Global batch size = 4

With loss scaling, you calculate the per-sample value of loss on each replica by adding the loss values, and then dividing by the global batch size. In this case: (2 + 3) / 4 = 1.25 and (4 + 5) / 4 = 2.25.

If you use tf.metrics.Mean to track loss across the two replicas, the result is different. In this example, you end up with a total of 3.50 and a count of 2, which results in total/count = 1.75 when result() is called on the metric, whereas the actual average loss over the global batch is 3.50. Loss calculated with tf.keras.Metrics is therefore divided by an additional factor equal to the number of replicas in sync.
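The numbers in this example can be checked directly in plain Python:

```python
# Values from the example: two replicas, two samples each, global batch of 4.
GLOBAL_BATCH_SIZE = 4
replica_losses = [[2.0, 3.0], [4.0, 5.0]]

# Loss scaling on each replica: sum the losses, divide by the global batch size.
scaled_per_replica = [sum(l) / GLOBAL_BATCH_SIZE for l in replica_losses]  # [1.25, 2.25]

# Summing the scaled values across replicas gives the true global average.
summed = sum(scaled_per_replica)                                       # 3.5
true_mean = sum(sum(l) for l in replica_losses) / GLOBAL_BATCH_SIZE    # 3.5

# A Mean metric fed the two scaled values reports total / count instead.
metric_result = sum(scaled_per_replica) / len(scaled_per_replica)      # 1.75
```

The ratio summed / metric_result is 2, the number of replicas in sync.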

Examples and Tutorials

Here are some examples for using distribution strategy with custom training loops:

  1. Tutorial to train MNIST using MirroredStrategy.
  2. DenseNet example using MirroredStrategy.
  3. BERT example trained using MirroredStrategy and TPUStrategy. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training etc.
  4. NCF example trained using MirroredStrategy that can be enabled using the keras_use_ctl flag.
  5. NMT example trained using MirroredStrategy.

More examples are listed in the Distribution strategy guide.

Next steps

Try out the new tf.distribute.Strategy API on your models.