Multi-worker training with Keras



This tutorial demonstrates multi-worker distributed training with a Keras model using the tf.distribute.Strategy API, specifically tf.distribute.experimental.MultiWorkerMirroredStrategy. With the help of this strategy, a Keras model that was designed to run on a single worker can seamlessly work on multiple workers with minimal code changes.

For a deeper understanding of the tf.distribute.Strategy APIs, the Distributed Training in TensorFlow guide gives an overview of the distribution strategies TensorFlow supports.


First, set up TensorFlow and the necessary imports.

import os
import tensorflow as tf
import numpy as np

Preparing dataset

Now, let's prepare the MNIST dataset. The MNIST dataset comprises 60,000 training examples and 10,000 test examples of the handwritten digits 0–9, formatted as 28x28-pixel monochrome images. This example uses only the training split.

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # We need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset

Build the Keras model

Here we use the tf.keras.Sequential API to build and compile a simple convolutional neural network (CNN) Keras model to train on our MNIST dataset.

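def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)  # One logit per digit class.
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),  # Fixed learning rate.
      metrics=['accuracy'])
  return model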

Let's first try training the model for a small number of epochs on a single worker and observe the results to make sure everything works correctly. You should see the loss dropping and the accuracy increasing as the epochs advance.

per_worker_batch_size = 64
single_worker_dataset = mnist_dataset(per_worker_batch_size)
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
Downloading data from
11493376/11490434 [==============================] - 0s 0us/step
Epoch 1/3
70/70 [==============================] - 0s 2ms/step - loss: 2.2701 - accuracy: 0.2451
Epoch 2/3
70/70 [==============================] - 0s 2ms/step - loss: 2.1827 - accuracy: 0.4777
Epoch 3/3
70/70 [==============================] - 0s 2ms/step - loss: 2.0865 - accuracy: 0.5955

<tensorflow.python.keras.callbacks.History at 0x7fc59381ac50>

Multi-worker Configuration

Now let's enter the world of multi-worker training. In TensorFlow, the TF_CONFIG environment variable is required for training on multiple machines, each of which possibly has a different role. TF_CONFIG is a JSON string used to specify the cluster configuration on each worker that is part of the cluster.

There are two components of TF_CONFIG: cluster and task. cluster provides information about the training cluster. It is a dict consisting of different types of jobs, such as worker. In multi-worker training with MultiWorkerMirroredStrategy, there is usually one worker that takes on a little more responsibility in addition to what a regular worker does, such as saving checkpoints and writing summary files for TensorBoard. Such a worker is referred to as the chief worker, and it is customary for the worker with index 0 to be appointed as the chief (in fact, this is how tf.distribute.Strategy is implemented). task, on the other hand, provides information about the current task. The first component, cluster, is the same for all workers. The second component, task, is different on each worker and specifies the type and index of that worker.
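
As a sketch of how those two components are used, the pure-Python helper below reads a TF_CONFIG string and reports the current task and whether it acts as the chief. It follows the convention described above: a task of type 'chief', or otherwise the worker with index 0. The function name describe_task is hypothetical. In TensorFlow, the strategy's cluster resolver does this parsing for you.

```python
import json


def describe_task(tf_config_json):
    """Return (task_type, task_index, is_chief) for one TF_CONFIG string.

    Hypothetical helper for illustration; TensorFlow parses TF_CONFIG itself.
    """
    config = json.loads(tf_config_json)
    cluster = config["cluster"]  # identical on every worker
    task = config["task"]        # different on every worker
    task_type, task_index = task["type"], task["index"]
    # Sanity check: the task must point at a slot that exists in the cluster.
    if task_index >= len(cluster[task_type]):
        raise ValueError(f"no {task_type} with index {task_index} in cluster")
    is_chief = task_type == "chief" or (task_type == "worker" and task_index == 0)
    return task_type, task_index, is_chief


config = json.dumps({
    "cluster": {"worker": ["localhost:12345", "localhost:23456"]},
    "task": {"type": "worker", "index": 1},
})
print(describe_task(config))  # ('worker', 1, False)
```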

In this example, we set the task type to "worker" and the task index to 0. This means the machine with this setting is the first worker, which will be appointed as the chief worker and will do more work than the other workers. Note that the other machines also need the TF_CONFIG environment variable set. It should have the same cluster dict but a different task type or task index, depending on the roles of those machines.

For illustration purposes, this tutorial shows how one may set a TF_CONFIG with 2 workers on localhost. In practice, users would create multiple workers on external IP addresses/ports, and set TF_CONFIG on each worker appropriately.

import json

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
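
Every machine needs the same cluster dict but its own task. It can therefore help to generate all the TF_CONFIG strings in one place, for example in a launcher script. The sketch below uses the hypothetical helper tf_configs_for_cluster to build one string per worker for the two localhost addresses used above. On a real cluster, you would export each string as TF_CONFIG on the corresponding machine before starting the same training program.

```python
import json


def tf_configs_for_cluster(worker_addresses):
    """Build one TF_CONFIG JSON string per worker (hypothetical helper).

    All strings share the same 'cluster' dict; only task.index differs.
    """
    cluster = {"worker": list(worker_addresses)}
    return [
        json.dumps({"cluster": cluster, "task": {"type": "worker", "index": i}})
        for i in range(len(worker_addresses))
    ]


configs = tf_configs_for_cluster(["localhost:12345", "localhost:23456"])
for i, cfg in enumerate(configs):
    print(f"worker {i}: TF_CONFIG={cfg}")
```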

Note that while the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
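
One widely used heuristic for this adjustment is the linear scaling rule, which this tutorial does not apply. It multiplies the learning rate by the ratio of the new global batch size to the batch size the rate was tuned for. The sketch below assumes a base learning rate of 0.001 tuned at a batch size of 64 and a two-worker cluster. The function name is hypothetical, and the result is a starting point to validate rather than a guarantee.

```python
def scaled_learning_rate(base_lr, base_batch_size, global_batch_size):
    """Linear scaling rule: multiply the learning rate by the batch-size ratio."""
    if base_batch_size <= 0 or global_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return base_lr * global_batch_size / base_batch_size


# Assumed values: a learning rate of 0.001 tuned for one worker at batch size 64.
per_worker_batch_size = 64
num_workers = 2
global_batch_size = per_worker_batch_size * num_workers  # 128
print(scaled_learning_rate(0.001, per_worker_batch_size, global_batch_size))  # 0.002
```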

Choose the right strategy

In TensorFlow, distributed training comes in two forms. In synchronous training, the training steps are synced across the workers and replicas. In asynchronous training, the training steps are not strictly synced.

MultiWorkerMirroredStrategy, which is the recommended strategy for synchronous multi-worker training, will be demonstrated in this guide. To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.
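
The synchronization idea can be illustrated without TensorFlow. The pure-Python toy below is only a stand-in and does not show how CollectiveOps is implemented. Each replica holds a copy of one weight and computes a gradient on its own data shard. The gradients are then averaged (the job of the all-reduce), and every replica applies the same averaged update, so the copies never drift apart. The one-parameter model y = w * x, the data, and the learning rate are assumptions chosen for illustration.

```python
def local_gradient(w, shard):
    """d/dw of the mean squared error of y ~ w * x over one replica's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Stand-in for a collective all-reduce: every replica receives the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def synchronous_step(replica_weights, shards, learning_rate):
    grads = [local_gradient(w, s) for w, s in zip(replica_weights, shards)]
    averaged = all_reduce_mean(grads)
    return [w - learning_rate * g for w, g in zip(replica_weights, averaged)]


# Two replicas, each seeing different points drawn from y = 3x.
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
weights = [0.0, 0.0]  # mirrored copies start identical
for _ in range(50):
    weights = synchronous_step(weights, shards, learning_rate=0.02)
print(weights)  # both copies are equal and close to 3.0
```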

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

MultiWorkerMirroredStrategy provides multiple implementations via the CollectiveCommunication parameter. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.
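
To see why a ring works, here is a textbook ring all-reduce simulated in pure Python. It runs a reduce-scatter phase followed by an all-gather phase. This is a simplified model for intuition, not TensorFlow's RING implementation. Real workers exchange chunks concurrently over the network, whereas this sketch loops over simulated workers and collects each step's messages before applying them, so every send in a step reads the pre-step state.

```python
def ring_allreduce(buffers):
    """Sum equal-length vectors held by simulated workers using a ring.

    buffers[w] is worker w's local vector. Its length must be divisible by the
    number of workers so that it splits into one chunk per worker.
    """
    n = len(buffers)
    size = len(buffers[0])
    if any(len(b) != size for b in buffers) or size % n:
        raise ValueError("buffers must share a length divisible by the worker count")
    chunk_len = size // n
    data = [list(b) for b in buffers]  # each worker's working copy

    def chunk(c):
        return range(c * chunk_len, (c + 1) * chunk_len)

    def ring_pass(chunk_to_send, combine):
        # Collect every message first so all sends in a step read pre-step state.
        messages = []
        for i in range(n):
            c = chunk_to_send(i)
            messages.append(((i + 1) % n, c, [data[i][k] for k in chunk(c)]))
        for dst, c, values in messages:
            for k, v in zip(chunk(c), values):
                data[dst][k] = combine(data[dst][k], v)

    # Phase 1, reduce-scatter: worker i ends up owning the full sum of chunk (i + 1) % n.
    for step in range(n - 1):
        ring_pass(lambda i: (i - step) % n, lambda old, new: old + new)
    # Phase 2, all-gather: the finished chunks travel once around the ring.
    for step in range(n - 1):
        ring_pass(lambda i: (i + 1 - step) % n, lambda old, new: new)
    return data


workers = [[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60], [100, 200, 300, 400, 500, 600]]
for row in ring_allreduce(workers):
    print(row)  # every worker prints [111, 222, 333, 444, 555, 666]
```

Each worker sends 2(N-1) messages of size/N elements. Per-worker traffic therefore stays below twice the buffer size however many workers join, which is the main appeal of ring-based collectives.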

Train the model with MultiWorkerMirroredStrategy

With the integration of the tf.distribute.Strategy API into tf.keras, the only change you need to make to distribute training across multiple workers is to enclose the model building and the model.compile() call inside strategy.scope(). The distribution strategy's scope dictates how and where the variables are created. In the case of MultiWorkerMirroredStrategy, the variables created are MirroredVariables, and they are replicated on each of the workers.

num_workers = 2  # Matches the two workers in the TF_CONFIG example above.

# Here the batch size scales up by the number of workers since
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64,
# and now this becomes 128.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = build_and_compile_cnn_model()

# Keras' `model.fit()` trains the model with the specified number of epochs and
# number of steps per epoch. Note that the numbers here are for demonstration
# purposes only and may not be enough to produce a good-quality model.
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
Epoch 1/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2682 - accuracy: 0.2265
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1714 - accuracy: 0.4954
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.0638 - accuracy: 0.6232

<tensorflow.python.keras.callbacks.History at 0x7fc5f4f062e8>

Dataset sharding and batch size

In multi-worker training with MultiWorkerMirroredStrategy, sharding the dataset is needed to ensure convergence and performance. However, note that in the code snippet above, the datasets are passed directly to model.fit() without being sharded. This works because the tf.distribute.Strategy API takes care of dataset sharding automatically. It shards the dataset at the file level, which may create skewed shards. In the extreme case where there is only one file, only the first shard (i.e. worker) will get training or evaluation data, and as a result all workers will get errors.
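
To see why file-level sharding can be skewed, here is a simplified pure-Python model of it that hands input files to workers round-robin. It is an approximation for intuition, not TensorFlow's exact sharding logic, and the file names are hypothetical.

```python
def assign_files(filenames, num_workers):
    """Give worker i every num_workers-th file, starting at file i."""
    return {worker: filenames[worker::num_workers] for worker in range(num_workers)}


# Three files over two workers: one worker reads twice as much data.
print(assign_files(["train-0.tfrecord", "train-1.tfrecord", "train-2.tfrecord"], 2))
# {0: ['train-0.tfrecord', 'train-2.tfrecord'], 1: ['train-1.tfrecord']}

# A single file over two workers: worker 1 gets nothing to read.
print(assign_files(["train.tfrecord"], 2))
# {0: ['train.tfrecord'], 1: []}
```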

If you prefer manual sharding for your training, you can turn off automatic sharding via the tf.data.experimental.DistributeOptions API. Concretely,

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
dataset_no_auto_shard = multi_worker_dataset.with_options(options)
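
With automatic sharding off, each worker has to select its own subset of the data. One way to do this is tf.data.Dataset.shard, which keeps every num_shards-th element starting at index. The sketch below uses a toy range dataset. NUM_WORKERS and WORKER_INDEX are hard-coded assumptions; a real job would derive them from TF_CONFIG or from strategy.cluster_resolver, as shown later in this tutorial.

```python
import tensorflow as tf

NUM_WORKERS = 2   # assumption: the two-worker cluster from the TF_CONFIG example
WORKER_INDEX = 1  # assumption: on a real cluster, take this from TF_CONFIG

# Turn off automatic sharding, as in the snippet above.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)

# Shard early (before shuffling or batching) so workers see disjoint examples.
dataset = tf.data.Dataset.range(8)
worker_dataset = dataset.shard(num_shards=NUM_WORKERS, index=WORKER_INDEX)
worker_dataset = worker_dataset.with_options(options)

print([int(x) for x in worker_dataset.as_numpy_iterator()])  # [1, 3, 5, 7]
```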

Another thing to notice is the batch size for the datasets. In the code snippet above, we use global_batch_size = per_worker_batch_size * num_workers, which is num_workers times as large as in the single-worker case. The reason is that the effective per-worker batch size is the global batch size (the parameter passed to tf.data.Dataset.batch()) divided by the number of workers. With this change, the per-worker batch size stays the same as before.
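
The arithmetic is easy to check. The pure-Python sketch below compares two approaches: the tutorial's (scaling the global batch size with the number of workers) and leaving the global batch size at the single-worker value of 64. It assumes the global batch is divided evenly among workers. The function name is hypothetical.

```python
PER_WORKER_BATCH_SIZE = 64


def batch_size_table(per_worker_batch_size, worker_counts):
    """Compare scaling the global batch size with leaving it unscaled.

    Assumes the global batch is split evenly across workers.
    """
    rows = []
    for num_workers in worker_counts:
        scaled_global = per_worker_batch_size * num_workers
        rows.append({
            "num_workers": num_workers,
            "global_batch_size": scaled_global,
            "per_worker_if_scaled": scaled_global // num_workers,
            # What each worker would get if the global batch stayed at 64:
            "per_worker_if_unscaled": per_worker_batch_size // num_workers,
        })
    return rows


for row in batch_size_table(PER_WORKER_BATCH_SIZE, [1, 2, 4]):
    print(row)
```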

Evaluation

If you pass validation_data into model.fit(), it will alternate between training and evaluation for each epoch. The evaluation on validation_data is distributed across the same set of workers, and the evaluation results are aggregated and made available to all workers. As with training, the validation dataset is automatically sharded at the file level. You need to set a global batch size in the validation dataset and set validation_steps. A repeated dataset is also recommended for evaluation.
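
The sketch below shows the shape of such a call on a single machine. It assumes a small hypothetical model and random data standing in for MNIST, so it runs offline. Under MultiWorkerMirroredStrategy, you would build and compile the model inside strategy.scope(), exactly as in the training example. Both datasets are batched with the global batch size and repeated, and validation_steps bounds each evaluation pass.

```python
import numpy as np
import tensorflow as tf

GLOBAL_BATCH_SIZE = 128  # assumption: 64 per worker x 2 workers


def synthetic_dataset(num_examples, global_batch_size, seed):
    """Hypothetical random stand-in for MNIST so the sketch runs offline."""
    rng = np.random.default_rng(seed)
    x = rng.random((num_examples, 28, 28), dtype=np.float32)
    y = rng.integers(0, 10, size=num_examples).astype(np.int64)
    # Repeat and batch with the *global* batch size, as recommended above.
    return tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(global_batch_size)


model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
    metrics=['accuracy'])

history = model.fit(
    synthetic_dataset(512, GLOBAL_BATCH_SIZE, seed=0),
    epochs=2,
    steps_per_epoch=4,
    validation_data=synthetic_dataset(256, GLOBAL_BATCH_SIZE, seed=1),
    validation_steps=2,  # required because the validation dataset repeats forever
    verbose=0)
print(sorted(history.history))
```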

Alternatively, you can create another task that periodically reads checkpoints and runs the evaluation. This is what Estimator does. However, this is not a recommended way to perform evaluation, so its details are omitted.

Prediction

Currently model.predict doesn't work with MultiWorkerMirroredStrategy.

Performance

You now have a Keras model that is all set up to run in multiple workers with MultiWorkerMirroredStrategy. You can try the following techniques to tweak performance of multi-worker training with MultiWorkerMirroredStrategy.

  • MultiWorkerMirroredStrategy provides multiple collective communication implementations (RING, NCCL, and AUTO, described above), and the best choice depends on the number and kind of GPUs and the network interconnect in the cluster. To override the automatic choice, specify a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
  • Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.

Fault tolerance

In synchronous training, the cluster would fail if one of the workers fails and no failure-recovery mechanism exists. Using Keras with tf.distribute.Strategy comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. We do this by preserving training state in the distributed file system of your choice, such that upon restart of the instance that previously failed or preempted, the training state is recovered.

Since all the workers are kept in sync in terms of training epochs and steps, other workers would need to wait for the failed or preempted worker to restart to continue.

ModelCheckpoint callback

The ModelCheckpoint callback no longer provides fault tolerance functionality; please use the BackupAndRestore callback instead.

The ModelCheckpoint callback can still be used to save checkpoints. With it, however, if training was interrupted or finished successfully, you are responsible for loading the model manually in order to continue training from the checkpoint. Optionally, you can choose to save and restore the model or weights outside the ModelCheckpoint callback.

Model saving and loading

To save your model using model.save or tf.saved_model.save, the destination for saving needs to be different for each worker. On the non-chief workers, you will need to save the model to a temporary directory, and on the chief, you will need to save to the provided model directory. The temporary directories on the workers need to be unique to prevent errors caused by multiple workers trying to write to the same location. The models saved in all the directories are identical, and typically only the model saved by the chief should be referenced for restoring or serving. We recommend adding some cleanup logic that deletes the temporary directories created by the workers once your training has completed.

You need to save on the chief and the workers at the same time because you might be aggregating variables during checkpointing, which requires both the chief and the workers to participate in the allreduce communication protocol. On the other hand, letting the chief and the workers save to the same model directory will result in errors due to contention.

With MultiWorkerMirroredStrategy, the program is run on every worker, and in order to know whether the current worker is chief, we take advantage of the cluster resolver object that has attributes task_type and task_id. task_type tells you what the current job is (e.g. 'worker'), and task_id tells you the identifier of the worker. The worker with id 0 is designated as the chief worker.

In the code snippet below, write_filepath provides the file path to write, which depends on the worker id. In the case of chief (worker with id 0), it writes to the original file path; for others, it creates a temporary directory (with id in the directory path) to write in:

model_path = '/tmp/keras-model'

def _is_chief(task_type, task_id):
  # If `task_type` is None, this may be operating as single worker, which works 
  # effectively as chief.
  return task_type is None or task_type == 'chief' or (
            task_type == 'worker' and task_id == 0)

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  return temp_dir

def write_filepath(filepath, task_type, task_id):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
write_model_path = write_filepath(model_path, task_type, task_id)

With that, you're now ready to save:

multi_worker_model.save(write_model_path)
INFO:tensorflow:Assets written to: /tmp/keras-model/assets

As we described above, later on the model should only be loaded from the path chief saved to, so let's remove the temporary ones the non-chief workers saved:

if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(os.path.dirname(write_model_path))
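
If the workers write to a shared filesystem, an alternative is for a single process, such as the chief, to sweep up the leftover temporary directories once training has finished on every worker. The helper below is a hypothetical pure-Python sketch. It assumes the workertemp_<id> naming produced by _get_temp_dir, and that those directories sit in the parent directory of model_path (/tmp in this tutorial). For remote filesystems, tf.io.gfile.glob and tf.io.gfile.rmtree offer the equivalent operations.

```python
import glob
import os
import shutil
import tempfile


def remove_worker_temp_dirs(model_path):
    """Delete every 'workertemp_<id>' directory next to model_path.

    Hypothetical helper; returns the names of the directories it removed.
    """
    parent = os.path.dirname(model_path)
    removed = []
    for path in sorted(glob.glob(os.path.join(parent, "workertemp_*"))):
        if os.path.isdir(path):
            shutil.rmtree(path)
            removed.append(os.path.basename(path))
    return removed


# Demo in a scratch directory that mimics the layout write_filepath produces.
root = tempfile.mkdtemp()
demo_model_path = os.path.join(root, "keras-model")
for worker_id in (1, 2):
    os.makedirs(os.path.join(root, f"workertemp_{worker_id}", "keras-model"))
os.makedirs(demo_model_path)  # the chief's copy must survive
removed = remove_worker_temp_dirs(demo_model_path)
print(removed)                    # ['workertemp_1', 'workertemp_2']
print(sorted(os.listdir(root)))  # ['keras-model']
```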

Now, when it's time to load, let's use the convenient tf.keras.models.load_model API and continue with further work. Here, we assume that only a single worker loads the model and continues training. In that case, you do not call tf.keras.models.load_model within another strategy.scope().

loaded_model = tf.keras.models.load_model(model_path)

# Now that we have the model restored, we can continue with the training.
loaded_model.fit(single_worker_dataset, epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9825 - accuracy: 0.1102
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 1.9367 - accuracy: 0.1117

<tensorflow.python.keras.callbacks.History at 0x7fc5f4b0d8d0>

Checkpoint saving and restoring

On the other hand, checkpointing allows you to save the model's weights and restore them without having to save the whole model. Here, you'll create one tf.train.Checkpoint that tracks the model. A tf.train.CheckpointManager manages it so that only the latest checkpoint is preserved.

checkpoint_dir = '/tmp/ckpt'

checkpoint = tf.train.Checkpoint(model=multi_worker_model)
write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id)
checkpoint_manager = tf.train.CheckpointManager(
  checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

Once the CheckpointManager is set up, you're ready to save the checkpoint and then remove the checkpoints the non-chief workers saved.

checkpoint_manager.save()
if not _is_chief(task_type, task_id):
  tf.io.gfile.rmtree(write_checkpoint_dir)

Now, when you need to restore, you can find the latest saved checkpoint using the convenient tf.train.latest_checkpoint function. After restoring the checkpoint, you can continue with training.

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
checkpoint.restore(latest_checkpoint)
multi_worker_model.fit(multi_worker_dataset, epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9841 - accuracy: 0.6561
Epoch 2/2
20/20 [==============================] - 0s 3ms/step - loss: 1.9445 - accuracy: 0.6805

<tensorflow.python.keras.callbacks.History at 0x7fc5f49d9d30>

BackupAndRestore callback

The BackupAndRestore callback provides fault tolerance functionality by backing up the model and the current epoch number in a temporary checkpoint file under the backup_dir argument to BackupAndRestore. This is done at the end of each epoch.

Once jobs get interrupted and restart, the callback restores the last checkpoint, and training continues from the beginning of the interrupted epoch. Any partial training already done in the unfinished epoch before interruption will be thrown away, so that it doesn't affect the final model state.
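
The resume-from-epoch-start behavior can be modeled in a few lines of pure Python. This is a toy, not how the Keras callback is implemented. A JSON file stands in for the checkpoint, and a counter stands in for the model weights. State is backed up only at the end of each epoch, so a restart redoes the interrupted epoch from its first step and the partial work is discarded.

```python
import json
import os
import tempfile


class SimulatedPreemption(Exception):
    """Raised to mimic a worker being killed in the middle of an epoch."""


def train(backup_path, num_epochs, steps_per_epoch, fail_at=None):
    """Toy training loop with end-of-epoch backups (not the Keras callback).

    Returns (executed, weight): the (epoch, step) pairs run in this call and
    the final value of a counter standing in for the model's weights.
    """
    start_epoch, weight = 0, 0
    if os.path.exists(backup_path):  # restart: pick up the last backup
        with open(backup_path) as f:
            state = json.load(f)
        start_epoch, weight = state["next_epoch"], state["weight"]
    executed = []
    for epoch in range(start_epoch, num_epochs):
        epoch_weight = weight  # progress inside the epoch is not backed up
        for step in range(steps_per_epoch):
            if fail_at == (epoch, step):
                raise SimulatedPreemption(f"lost at epoch {epoch}, step {step}")
            epoch_weight += 1  # stand-in for one optimizer update
            executed.append((epoch, step))
        weight = epoch_weight
        with open(backup_path, "w") as f:  # back up at the end of the epoch
            json.dump({"next_epoch": epoch + 1, "weight": weight}, f)
    if os.path.exists(backup_path):  # successful finish: drop the backup
        os.remove(backup_path)
    return executed, weight


backup_file = os.path.join(tempfile.mkdtemp(), "backup.json")
try:
    train(backup_file, num_epochs=3, steps_per_epoch=4, fail_at=(1, 2))
except SimulatedPreemption as err:
    print("preempted:", err)
executed, final_weight = train(backup_file, num_epochs=3, steps_per_epoch=4)
print(executed[0], final_weight)  # (1, 0) 12
```

The restarted run begins at step 0 of epoch 1, and the final counter equals 3 epochs x 4 steps. The two updates made before the simulated preemption were thrown away rather than counted twice.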

To use it, provide an instance of tf.keras.callbacks.experimental.BackupAndRestore to the model.fit() call.

With MultiWorkerMirroredStrategy, if a worker gets interrupted, the whole cluster pauses until the interrupted worker is restarted. Other workers will also restart, and the interrupted worker rejoins the cluster. Then, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Then the training continues.

BackupAndRestore callback uses CheckpointManager to save and restore the training state, which generates a file called checkpoint that tracks existing checkpoints together with the latest one. For this reason, backup_dir should not be re-used to store other checkpoints in order to avoid name collision.
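
You can see this state file with a small CheckpointManager. The sketch below saves a single tracked variable three times into a fresh temporary directory, which stands in for a dedicated backup_dir. Alongside the checkpoint data, the directory gets a state file named checkpoint, which is also what tf.train.latest_checkpoint reads. The variable and directory are assumptions for the demo.

```python
import os
import tempfile

import tensorflow as tf

backup_dir = tempfile.mkdtemp()  # stand-in for a dedicated backup_dir
step = tf.Variable(0)
manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(step=step), directory=backup_dir, max_to_keep=1)

for _ in range(3):
    step.assign_add(1)
    manager.save()  # writes ckpt-N files and updates the 'checkpoint' state file

print(sorted(os.listdir(backup_dir)))

# Restoring through the state file recovers the last saved value.
restored_step = tf.Variable(0)
tf.train.Checkpoint(step=restored_step).restore(
    tf.train.latest_checkpoint(backup_dir)).expect_partial()
print(int(restored_step.numpy()))  # 3
```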

Currently, the BackupAndRestore callback supports single-worker training with no strategy, MirroredStrategy, and multi-worker training with MultiWorkerMirroredStrategy. Below is an example of multi-worker training with MultiWorkerMirroredStrategy.

# Multi-worker training with MultiWorkerMirroredStrategy.

callbacks = [tf.keras.callbacks.experimental.BackupAndRestore(backup_dir='/tmp/backup')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70, callbacks=callbacks)
Epoch 1/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2837 - accuracy: 0.1836
Epoch 2/3
70/70 [==============================] - 0s 3ms/step - loss: 2.2131 - accuracy: 0.4091
Epoch 3/3
70/70 [==============================] - 0s 3ms/step - loss: 2.1310 - accuracy: 0.5485

<tensorflow.python.keras.callbacks.History at 0x7fc5f49a3080>

If you inspect the backup_dir directory you specified in BackupAndRestore, you may notice some temporarily generated checkpoint files. Those files are needed to recover previously lost instances, and the library removes them at the end of model.fit() once your training exits successfully.

See also

  1. Distributed Training in TensorFlow guide provides an overview of the available distribution strategies.
  2. Official models, many of which can be configured to run multiple distribution strategies.
  3. The Performance section in the guide provides information about other strategies and tools you can use to optimize the performance of your TensorFlow models.