Custom training with tf.distribute.Strategy


This tutorial demonstrates how to use tf.distribute.Strategy—a TensorFlow API that provides an abstraction for distributing your training across multiple processing units (GPUs, multiple machines, or TPUs)—with custom training loops. In this example, you will train a simple convolutional neural network on the Fashion MNIST dataset containing 70,000 images of size 28 x 28.

Custom training loops provide flexibility and greater control over training. They also make it easier to debug the model and the training loop.

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os


Download the Fashion MNIST dataset

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Create a strategy to distribute the variables and the graph

How does tf.distribute.MirroredStrategy work?

  • All the variables and the model graph are replicated across the replicas.
  • Input is evenly distributed across the replicas.
  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is made to the copies of the variables on each replica.
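To make the five steps above concrete, here is a pure-Python simulation of one synchronous training step. It does not use TensorFlow. It assumes a hypothetical one-weight linear model y = w * x with a squared-error loss and plain gradient descent. The "replicas" are list entries, and the "all-reduce" is an ordinary sum. Each replica's loss is divided by the global batch size, as explained under "Define the loss function" below.

```python
import math

# Pure-Python simulation of one synchronous data-parallel step (no TensorFlow).
# Hypothetical model: y = w * x, squared-error loss, plain gradient descent.
NUM_REPLICAS = 4
LEARNING_RATE = 0.01

# A global batch of 8 (x, y) pairs generated with a "true" weight of 2.0.
global_batch = [(float(x), 2.0 * x) for x in range(1, 9)]
GLOBAL_BATCH_SIZE = len(global_batch)
PER_REPLICA = GLOBAL_BATCH_SIZE // NUM_REPLICAS


def grad(w, examples):
  """d/dw of sum((w*x - y)**2 over examples) / GLOBAL_BATCH_SIZE."""
  return sum(2.0 * (w * x - y) * x for x, y in examples) / GLOBAL_BATCH_SIZE


# Step 1: every replica holds its own copy of the variable.
replica_w = [0.5] * NUM_REPLICAS

# Step 2: the input is split evenly across the replicas.
shards = [global_batch[r * PER_REPLICA:(r + 1) * PER_REPLICA]
          for r in range(NUM_REPLICAS)]

# Step 3: each replica computes gradients on its own shard only.
local_grads = [grad(w, shard) for w, shard in zip(replica_w, shards)]

# Step 4: the gradients are synced by summing them across replicas.
synced_grad = sum(local_grads)

# Step 5: every replica applies the same update to its copy.
replica_w = [w - LEARNING_RATE * synced_grad for w in replica_w]

# Reference: a single device processing the whole global batch.
single_w = 0.5 - LEARNING_RATE * grad(0.5, global_batch)

print(replica_w)
print(single_w)
print(all(math.isclose(w, single_w) for w in replica_w))  # True
```

After the update, all replica copies are identical and match the single-device result. This is the property the strategy relies on to keep the mirrored variables in sync.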
# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4

Set up the input pipeline

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10


Create the datasets and distribute them:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
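As a rough guide to what these datasets yield, the plain-Python arithmetic below counts the batches per pass. It uses the 60,000 training and 10,000 test images of Fashion MNIST, 64 examples per replica, and the 4 replicas of the run shown above. The replica count is an assumption taken from that run and will differ on other machines.

```python
import math

NUM_TRAIN_EXAMPLES = 60_000  # Fashion MNIST training split
NUM_TEST_EXAMPLES = 10_000   # Fashion MNIST test split
NUM_REPLICAS = 4             # as in the run above; auto-detected in practice
BATCH_SIZE_PER_REPLICA = 64

GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS

# Dataset.batch() keeps a final partial batch unless drop_remainder=True.
train_steps = math.ceil(NUM_TRAIN_EXAMPLES / GLOBAL_BATCH_SIZE)
last_train_batch = NUM_TRAIN_EXAMPLES - (train_steps - 1) * GLOBAL_BATCH_SIZE
test_steps = math.ceil(NUM_TEST_EXAMPLES / GLOBAL_BATCH_SIZE)
last_test_batch = NUM_TEST_EXAMPLES - (test_steps - 1) * GLOBAL_BATCH_SIZE

# With drop_remainder=True, only the full batches remain.
full_train_steps = NUM_TRAIN_EXAMPLES // GLOBAL_BATCH_SIZE

print(GLOBAL_BATCH_SIZE)              # 256
print(train_steps, last_train_batch)  # 235 96
print(test_steps, last_test_batch)    # 40 16
print(full_train_steps)               # 234
```

On such a setup, one training epoch takes 235 steps, and the last step sees only 96 examples. Short batches like this are discussed under "Special cases" below.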

Create the model

Create a model using tf.keras.Sequential. You can also use the Model Subclassing API or the functional API to do this.

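def create_model():
  regularizer = tf.keras.regularizers.L2(1e-5)
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', kernel_regularizer=regularizer),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu', kernel_regularizer=regularizer),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu', kernel_regularizer=regularizer),
      # No softmax: the model outputs logits.
      tf.keras.layers.Dense(10, kernel_regularizer=regularizer)
  ])

  return model
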
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

Define the loss function

Recall that the loss function consists of one or two parts:

  • The prediction loss measures how far off the model's predictions are from the training labels for a batch of training examples. It is computed for each labeled example and then reduced across the batch by computing the average value.
  • Optionally, regularization loss terms can be added to the prediction loss, to steer the model away from overfitting the training data. A common choice is L2 regularization, which adds a small fixed multiple of the sum of squares of all model weights, independent of the number of examples. The model above uses L2 regularization to demonstrate its handling in the training loop below.

For training on a single machine with a single GPU/CPU, this works as follows:

  • The prediction loss is computed for each example in the batch, summed across the batch, and then divided by the batch size.
  • The regularization loss is added to the prediction loss.
  • The gradient of the total loss is computed w.r.t. each model weight, and the optimizer updates each model weight from the corresponding gradient.

With tf.distribute.Strategy, the input batch is split between replicas. For example, say you have 4 GPUs, each with one replica of the model. One batch of 256 input examples is distributed evenly across the 4 replicas, so each replica gets a batch of size 64: 256 = 4 * 64, or more generally GLOBAL_BATCH_SIZE = num_replicas_in_sync * BATCH_SIZE_PER_REPLICA.

Each replica computes the loss from the training examples it gets and computes the gradients of the loss w.r.t. each model weight. The optimizer takes care that these gradients are summed up across replicas before using them to update the copies of the model weights on each replica.

So, how should the loss be calculated when using a tf.distribute.Strategy?

  • Each replica computes the prediction loss for all examples distributed to it, sums up the results, and divides them by num_replicas_in_sync * BATCH_SIZE_PER_REPLICA, or equivalently, GLOBAL_BATCH_SIZE.
  • Each replica computes the regularization loss(es) and divides them by num_replicas_in_sync.

Compared to non-distributed training, all per-replica loss terms are scaled down by a factor of 1/num_replicas_in_sync. On the other hand, all loss terms (or rather, their gradients) are summed across that number of replicas before the optimizer applies them. In effect, the optimizer on each replica uses the same gradients as if a non-distributed computation with GLOBAL_BATCH_SIZE had happened. This is consistent with the distributed and non-distributed behavior of Keras Model.fit. See the Distributed training with Keras tutorial on how a larger global batch size lets you scale up the learning rate.
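This bookkeeping can be checked with plain numbers. The pure-Python sketch below uses no TensorFlow, and all loss values are made up. It takes 4 replicas with 2 examples each, applies the two divisions from the bullets above, and compares the sum over replicas with the non-distributed loss. It also shows the overcount that results from dividing only by the per-replica batch size.

```python
import math

NUM_REPLICAS = 4
BATCH_SIZE_PER_REPLICA = 2
GLOBAL_BATCH_SIZE = NUM_REPLICAS * BATCH_SIZE_PER_REPLICA  # 8

# Made-up per-example prediction losses, already split across the replicas.
per_replica_losses = [[0.2, 0.4], [0.6, 0.8], [1.0, 1.2], [1.4, 1.6]]
# Made-up regularization loss; it is the same on every replica because the
# replicas hold identical weights.
regularization_loss = 0.05


def replica_loss(example_losses):
  # Prediction loss: local sum divided by the GLOBAL batch size.
  prediction = sum(example_losses) / GLOBAL_BATCH_SIZE
  # Regularization loss: divided by the number of replicas.
  return prediction + regularization_loss / NUM_REPLICAS


# Summing the per-replica losses (as their gradients get summed) ...
distributed_total = sum(replica_loss(ex) for ex in per_replica_losses)

# ... matches the non-distributed loss: mean over all 8 examples + full penalty.
all_examples = [loss for ex in per_replica_losses for loss in ex]
single_device_total = sum(all_examples) / len(all_examples) + regularization_loss

# Dividing only by the per-replica batch size overcounts the prediction
# loss by a factor of NUM_REPLICAS once the replicas are summed.
overcounted = sum(sum(ex) / len(ex) for ex in per_replica_losses)

print(math.isclose(distributed_total, single_device_total))  # True
print(overcounted / (single_device_total - regularization_loss))
```

Because gradients are linear in the loss, the summed per-replica gradients match the single-device gradients in the same way.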

How to do this in TensorFlow?

  • Loss reduction and scaling are done automatically in Keras Model.compile and Model.fit.

  • If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the global batch size using tf.nn.compute_average_loss, which takes the per-example losses and optional sample weights as arguments and returns the scaled loss.

  • If using tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified to be one of NONE or SUM. The default AUTO and SUM_OVER_BATCH_SIZE are disallowed outside Model.fit.

    • AUTO is disallowed because the user should explicitly think about what reduction they want to make sure it is correct in the distributed case.
    • SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by per replica batch size, and leave the dividing by number of replicas to the user, which might be easy to miss. So, instead, you need to do the reduction yourself explicitly.
  • If you're writing a custom training loop for a model with a non-empty list of Model.losses (e.g., weight regularizers), you should sum them up and divide the sum by the number of replicas. You can do this by using the tf.nn.scale_regularization_loss function. The model code itself remains unaware of the number of replicas.

    However, models can define input-dependent regularization losses with Keras APIs such as Layer.add_loss(...) and Layer(activity_regularizer=...). For Layer.add_loss(...), it falls on the modeling code to perform the division of the summed per-example terms by the per-replica(!) batch size, e.g., by using tf.math.reduce_mean().
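For input-dependent losses added with Layer.add_loss(...), the pure-Python sketch below follows the two divisions described above. It uses made-up numbers and two replicas with equal batch sizes. The model averages over the per-replica batch, and the training loop then divides by the number of replicas. Summed over replicas, the result is the mean over the global batch.

```python
NUM_REPLICAS = 2
# Made-up per-example activity penalties on each replica (equal batch sizes).
per_replica_terms = [[1.0, 3.0], [5.0, 7.0]]

# Inside the model, e.g. via Layer.add_loss(tf.math.reduce_mean(...)):
# average over the PER-REPLICA batch.
model_losses = [sum(terms) / len(terms) for terms in per_replica_terms]

# In the training loop: divide by the number of replicas, as
# tf.nn.scale_regularization_loss does.
scaled = [loss / NUM_REPLICAS for loss in model_losses]

# Summed across replicas, this equals the mean over the global batch.
total = sum(scaled)
global_mean = (sum(sum(t) for t in per_replica_terms)
               / sum(len(t) for t in per_replica_terms))
print(model_losses, scaled, total, global_mean)
# [2.0, 6.0] [1.0, 3.0] 4.0 4.0
```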

with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction yourself.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,  # the last Dense layer outputs logits
      reduction=tf.keras.losses.Reduction.NONE)

  def compute_loss(labels, predictions, model_losses):
    per_example_loss = loss_object(labels, predictions)
    loss = tf.nn.compute_average_loss(per_example_loss)
    if model_losses:
      loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    return loss

Special cases

Advanced users should also consider the following special cases.

  • Input batches shorter than GLOBAL_BATCH_SIZE create unpleasant corner cases in several places. In practice, it often works best to avoid them by allowing batches to span epoch boundaries using Dataset.repeat().batch() and defining approximate epochs by step counts, not dataset ends. Alternatively, Dataset.batch(drop_remainder=True) maintains the notion of epoch but drops the last few examples.

    For illustration, this example goes the harder route and allows short batches, so that each training epoch contains each training example exactly once.

    Which denominator should be used by tf.nn.compute_average_loss()?

    • By default, in the example code above and equivalently in Keras Model.fit(), the sum of prediction losses is divided by num_replicas_in_sync times the actual batch size seen on the replica (with empty batches silently ignored). This preserves the balance between the prediction loss on the one hand and the regularization losses on the other hand. It is particularly appropriate for models that use input-dependent regularization losses. Plain L2 regularization just superimposes weight decay onto the gradients of the prediction loss and is less in need of such a balance.
    • In practice, many custom training loops pass the global batch size as a constant Python value, as in tf.nn.compute_average_loss(..., global_batch_size=GLOBAL_BATCH_SIZE), to use it as the denominator. This preserves the relative weighting of training examples between batches. Without it, the smaller denominator in short batches effectively upweights the examples in those batches. (Before TensorFlow 2.13, this was also needed to avoid NaNs in case some replica received an actual batch size of zero.)

    Both options are equivalent if short batches are avoided, as suggested above.

  • Multi-dimensional labels require you to average the per_example_loss across the number of predictions in each example. Consider a classification task for all pixels of an input image, with predictions of shape (batch_size, H, W, n_classes) and labels of shape (batch_size, H, W). You will need to update per_example_loss like: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
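The two denominators for short batches can be compared with the pure-Python stand-in below. average_loss is a hypothetical helper written only for this illustration. It reproduces the division described in the two options above, not the actual implementation of tf.nn.compute_average_loss, which also handles sample weights. The per-example losses are made up.

```python
def average_loss(per_example_losses, num_replicas, global_batch_size=None):
  """Hypothetical stand-in for the division tf.nn.compute_average_loss does.

  Default: divide by num_replicas * local batch size (an empty batch gives 0).
  With global_batch_size: divide by that constant instead.
  """
  if global_batch_size is None:
    if not per_example_losses:
      return 0.0  # an empty local batch contributes nothing
    denominator = num_replicas * len(per_example_losses)
  else:
    denominator = global_batch_size
  return sum(per_example_losses) / denominator


NUM_REPLICAS = 4
GLOBAL_BATCH_SIZE = 256

full_batch = [0.5] * 64   # a regular per-replica batch
short_batch = [0.5] * 24  # a short final batch on the same replica

# Default denominator: the short batch yields the same loss value as a full
# one, so each of its 24 examples carries more weight.
print(average_loss(full_batch, NUM_REPLICAS),
      average_loss(short_batch, NUM_REPLICAS))
# 0.125 0.125

# Constant denominator: every example carries weight 1/256, so the short
# batch contributes proportionally less.
print(average_loss(full_batch, NUM_REPLICAS, GLOBAL_BATCH_SIZE),
      average_loss(short_batch, NUM_REPLICAS, GLOBAL_BATCH_SIZE))
# 0.125 0.046875
```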

Define the metrics to track loss and accuracy

These metrics track the test loss and training and test accuracy. You can use .result() to get the accumulated statistics at any time.

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

Training loop

# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions, model.losses)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
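
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  # Summing the already-scaled per-replica losses gives the global batch loss.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))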

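for epoch in range(EPOCHS):
  # Train loop: accumulate the reduced loss of every step.
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # Test loop: the test metrics are updated inside test_step.
  for x in test_dist_dataset:
    distributed_test_step(x)

  # Save a checkpoint every other epoch.
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))

  # Reset the metrics so that each epoch reports its own statistics.
  test_loss.reset_state()
  train_accuracy.reset_state()
  test_accuracy.reset_state()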

Epoch 1, Loss: 0.6486169099807739, Accuracy: 76.63666534423828, Test Loss: 0.4479253888130188, Test Accuracy: 83.91999816894531
Epoch 2, Loss: 0.39812561869621277, Accuracy: 85.82833099365234, Test Loss: 0.38462167978286743, Test Accuracy: 86.44000244140625
Epoch 3, Loss: 0.3495272696018219, Accuracy: 87.51666259765625, Test Loss: 0.3450538218021393, Test Accuracy: 87.70999908447266
Epoch 4, Loss: 0.32059356570243835, Accuracy: 88.5566635131836, Test Loss: 0.3286792039871216, Test Accuracy: 88.5
Epoch 5, Loss: 0.3009909689426422, Accuracy: 89.29000091552734, Test Loss: 0.3245093822479248, Test Accuracy: 88.0999984741211
Epoch 6, Loss: 0.28188374638557434, Accuracy: 89.9383316040039, Test Loss: 0.30204612016677856, Test Accuracy: 89.1500015258789
Epoch 7, Loss: 0.26679542660713196, Accuracy: 90.52999877929688, Test Loss: 0.28874218463897705, Test Accuracy: 89.67000579833984
Epoch 8, Loss: 0.25406745076179504, Accuracy: 90.9183349609375, Test Loss: 0.279201865196228, Test Accuracy: 89.96000671386719
Epoch 9, Loss: 0.24111337959766388, Accuracy: 91.44166564941406, Test Loss: 0.27962371706962585, Test Accuracy: 90.2300033569336
Epoch 10, Loss: 0.23263928294181824, Accuracy: 91.6866683959961, Test Loss: 0.28959372639656067, Test Accuracy: 89.4000015258789

Things to note in the example above

Restore the latest checkpoint and test

A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.

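eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

# Restore the latest checkpoint, this time without a strategy scope.
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))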

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))
Accuracy after restoring the saved model without strategy: 90.2300033569336
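The restored accuracy, 90.2300033569336, equals the test accuracy printed for epoch 9, not epoch 10. The training loop checkpoints only when the zero-based epoch satisfies epoch % 2 == 0, while it prints epoch + 1. The plain-Python check below lists which printed epochs get a checkpoint:

```python
EPOCHS = 10

# The training loop counts epochs from 0 but prints epoch + 1.
saved_epochs = [epoch + 1 for epoch in range(EPOCHS) if epoch % 2 == 0]
print(saved_epochs)  # [1, 3, 5, 7, 9]

# tf.train.latest_checkpoint() picks the most recent save.
print(saved_epochs[-1])  # 9
```

The weights after epoch 10 were never saved, so the restored model reproduces the epoch 9 test accuracy.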

Alternate ways of iterating over a dataset

Using iterators

If you want to iterate over a given number of steps and not through the entire dataset, you can create an iterator using the iter call and explicitly call next on the iterator. You can choose to iterate over the dataset both inside and outside the tf.function. Here is a small snippet demonstrating iteration of the dataset outside the tf.function using an iterator.
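The iter/next pattern is ordinary Python iteration. In the plain-Python analogy below, a list stands in for train_dist_dataset and no TensorFlow is involved. Each "epoch" takes a fixed number of steps, and calling iter() again starts over from the first element, just as the loop below creates a new iterator per epoch. A shuffled tf.data pipeline may also produce a different order for each new iterator; the list does not model that.

```python
STEPS_PER_EPOCH = 3
batches = ["batch0", "batch1", "batch2", "batch3", "batch4"]  # stand-in dataset

seen = []
for epoch in range(2):
  batch_iter = iter(batches)  # a fresh iterator starts from the beginning
  for _ in range(STEPS_PER_EPOCH):
    seen.append((epoch, next(batch_iter)))

print(seen)
# [(0, 'batch0'), (0, 'batch1'), (0, 'batch2'), (1, 'batch0'), (1, 'batch1'), (1, 'batch2')]

# Asking for more steps than there are batches raises StopIteration.
exhausted = iter(batches)
for _ in range(len(batches)):
  next(exhausted)
try:
  next(exhausted)
except StopIteration:
  print("dataset exhausted")
```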

for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  # Run a fixed number of steps instead of a full pass over the dataset.
  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_state()
Epoch 1, Loss: 0.24379579722881317, Accuracy: 90.9375
Epoch 2, Loss: 0.22639629244804382, Accuracy: 91.4453125
Epoch 3, Loss: 0.23340332508087158, Accuracy: 91.328125
Epoch 4, Loss: 0.21611782908439636, Accuracy: 92.3046875
Epoch 5, Loss: 0.21506652235984802, Accuracy: 93.0078125
Epoch 6, Loss: 0.20497587323188782, Accuracy: 92.5390625
Epoch 7, Loss: 0.20436997711658478, Accuracy: 93.203125
Epoch 8, Loss: 0.20990486443042755, Accuracy: 92.3828125
Epoch 9, Loss: 0.2148672640323639, Accuracy: 91.953125
Epoch 10, Loss: 0.22497150301933289, Accuracy: 91.953125

Iterating inside a tf.function

You can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct or by creating iterators like you did above. The example below demonstrates wrapping one epoch of training with a @tf.function decorator and iterating over train_dist_dataset inside the function.

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

  train_accuracy.reset_state()

Epoch 1, Loss: 0.21882569789886475, Accuracy: 92.2249984741211
Epoch 2, Loss: 0.20959770679473877, Accuracy: 92.61166381835938
Epoch 3, Loss: 0.2009064108133316, Accuracy: 92.90666198730469
Epoch 4, Loss: 0.19423046708106995, Accuracy: 93.19000244140625
Epoch 5, Loss: 0.18592748045921326, Accuracy: 93.44000244140625
Epoch 6, Loss: 0.18024654686450958, Accuracy: 93.788330078125
Epoch 7, Loss: 0.17136560380458832, Accuracy: 94.11833190917969
Epoch 8, Loss: 0.1671149581670761, Accuracy: 94.1933364868164
Epoch 9, Loss: 0.16002388298511505, Accuracy: 94.3949966430664
Epoch 10, Loss: 0.15144744515419006, Accuracy: 94.75

Tracking training loss across replicas

Because of the loss scaling computation that is carried out, it's not recommended to use tf.keras.metrics.Mean to track the training loss across different replicas.

For example, if you run a training job with the following characteristics:

  • Two replicas
  • Two samples are processed on each replica
  • Resulting loss values: [2, 3] and [4, 5] on each replica
  • Global batch size = 4

With loss scaling, you calculate the per-sample value of loss on each replica by adding the loss values, and then dividing by the global batch size. In this case: (2 + 3) / 4 = 1.25 and (4 + 5) / 4 = 2.25.

If you use tf.keras.metrics.Mean to track loss across the two replicas, the result is different. In this example, you end up with a total of 3.50 and a count of 2, which results in total/count = 1.75 when result() is called on the metric. In other words, the loss tracked with tf.keras.metrics.Mean is smaller than the actual loss by a factor equal to the number of replicas in sync.
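The numbers in this example can be reproduced in plain Python. The sketch below contrasts two ways of combining the per-replica scaled losses. The first sums them, as strategy.reduce(tf.distribute.ReduceOp.SUM, ...) does in distributed_train_step. The second averages them the way a Mean metric would, by dividing the total by the count.

```python
GLOBAL_BATCH_SIZE = 4
per_replica_example_losses = [[2, 3], [4, 5]]  # two replicas, two samples each

# Each replica's scaled loss: local sum / global batch size.
scaled = [sum(losses) / GLOBAL_BATCH_SIZE for losses in per_replica_example_losses]
print(scaled)  # [1.25, 2.25]

# Summing across replicas recovers the true mean over all four samples.
summed = sum(scaled)
true_mean = sum(sum(l) for l in per_replica_example_losses) / GLOBAL_BATCH_SIZE
print(summed, true_mean)  # 3.5 3.5

# A Mean metric fed one value per replica divides by the count instead.
metric_like = sum(scaled) / len(scaled)
print(metric_like)  # 1.75
```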

Guide and examples

Here are some examples for using distribution strategy with custom training loops:

  1. Distributed training guide
  2. DenseNet example using MirroredStrategy.
  3. BERT example trained using MirroredStrategy and TPUStrategy. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training.
  4. NCF example trained using MirroredStrategy that can be enabled using the keras_use_ctl flag.
  5. NMT example trained using MirroredStrategy.

You can find more examples listed under Examples and tutorials in the Distribution strategy guide.

Next steps