Multi-worker training with Keras

Overview

This tutorial demonstrates multi-worker distributed training of a Keras model using the tf.distribute.Strategy API. With strategies designed specifically for multi-worker training, a Keras model written for a single worker can run on multiple workers with minimal code changes.

The Distributed Training in TensorFlow guide gives an overview of the distribution strategies TensorFlow supports, for those who want a deeper understanding of the tf.distribute.Strategy APIs.

Setup

First, set up TensorFlow and the necessary imports.

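from __future__ import absolute_import, division, print_function, unicode_literals
import json  # used below to serialize TF_CONFIG
import os  # used below to set the TF_CONFIG environment variable
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()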

Preparing dataset

Now, let's prepare the MNIST dataset from TensorFlow Datasets. The MNIST dataset comprises 60,000 training examples and 10,000 test examples of the handwritten digits 0–9, formatted as 28x28-pixel monochrome images.

BUFFER_SIZE = 10000
BATCH_SIZE = 64

def make_datasets_unbatched():
  # Scale MNIST pixel values from [0, 255] to [0., 1.]
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  datasets, info = tfds.load(name='mnist',
                            with_info=True,
                            as_supervised=True)

  return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

train_datasets = make_datasets_unbatched().batch(BATCH_SIZE)
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...

/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)

WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

Build the Keras model

Here we use the tf.keras.Sequential API to build and compile a simple convolutional neural network to train on the MNIST dataset.

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])
  model.compile(
      loss=tf.keras.losses.sparse_categorical_crossentropy,
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model
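To get a feel for the size of this model, the trainable parameter count can be worked out by hand. The sketch below is plain Python arithmetic, not TensorFlow code, and it assumes the Keras defaults for the layers above: no padding and stride 1 for Conv2D, and a 2x2 window for MaxPooling2D. The helper names are made up for illustration.

```python
def conv2d_output(height, width, kernel, filters):
    """Output shape of a stride-1 convolution without padding."""
    return height - kernel + 1, width - kernel + 1, filters


def count_parameters():
    # Conv2D(32, 3) on a 28x28x1 input: one 3x3x1 kernel plus one bias per filter.
    conv_params = 3 * 3 * 1 * 32 + 32
    height, width, channels = conv2d_output(28, 28, kernel=3, filters=32)

    # MaxPooling2D() with a 2x2 window halves height and width; it has no weights.
    height, width = height // 2, width // 2

    # Flatten() only reshapes, so it has no weights either.
    flat_units = height * width * channels

    # Each Dense layer has inputs * units weights plus one bias per unit.
    dense_64_params = flat_units * 64 + 64
    dense_10_params = 64 * 10 + 10

    return {
        'conv2d': conv_params,
        'dense_64': dense_64_params,
        'dense_10': dense_10_params,
        'total': conv_params + dense_64_params + dense_10_params,
    }


parameter_counts = count_parameters()
print(parameter_counts)
```

Almost all of the weights sit in the first Dense layer. These are the variables that a multi-worker strategy has to keep in sync on every step.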

Let's first train the model for a small number of epochs on a single worker to make sure everything works correctly. You should see the loss drop and the accuracy approach 1.0 as the epochs advance.

single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(x=train_datasets, epochs=3)
Epoch 1/3
938/938 [==============================] - 11s 12ms/step - loss: 2.0667 - accuracy: 0.4719
Epoch 2/3
938/938 [==============================] - 2s 3ms/step - loss: 1.1673 - accuracy: 0.7724
Epoch 3/3
938/938 [==============================] - 2s 3ms/step - loss: 0.6323 - accuracy: 0.8469

<tensorflow.python.keras.callbacks.History at 0x7fc3b85f7978>

Multi-worker Configuration

Now let's enter the world of multi-worker training. In TensorFlow, the TF_CONFIG environment variable is required for training on multiple machines, each of which may have a different role. TF_CONFIG specifies the cluster configuration on each worker that is part of the cluster.

There are two components of TF_CONFIG: cluster and task. cluster provides information about the training cluster; it is a dict listing the different types of jobs, such as worker. In multi-worker training, one worker usually takes on a little more responsibility than the others, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular worker does. This worker is referred to as the 'chief' worker, and it is customary for the worker with index 0 to be appointed chief (in fact, this is how tf.distribute.Strategy is implemented). task, on the other hand, provides information about the current task.

In this example, we set the task type to "worker" and the task index to 0. The machine with this setting is therefore the first worker, which is appointed chief and does more work than the other workers. Note that the other machines need the TF_CONFIG environment variable set as well, with the same cluster dict but a different task type or task index, depending on the role of each machine.

For illustration purposes, this tutorial shows how one may set a TF_CONFIG with 2 workers on localhost. In practice, users would create multiple workers on external IP addresses/ports, and set TF_CONFIG on each worker appropriately.

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
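To make the "same cluster, different task" rule concrete, here is a minimal sketch in plain Python that builds one TF_CONFIG string per worker and checks which one would act as chief. The helpers `make_tf_config` and `is_chief` are hypothetical names for this illustration, not TensorFlow APIs, and the branch handling a dedicated 'chief' job is an assumption that goes beyond the worker-only cluster used in this tutorial.

```python
import json

WORKER_ADDRESSES = ["localhost:12345", "localhost:23456"]


def make_tf_config(worker_addresses, task_index, task_type="worker"):
    """Return the TF_CONFIG JSON string for one task of the cluster."""
    return json.dumps({
        "cluster": {"worker": list(worker_addresses)},
        "task": {"type": task_type, "index": task_index},
    })


def is_chief(tf_config_json):
    """Return True if this task plays the chief role (illustrative logic only)."""
    config = json.loads(tf_config_json)
    task = config["task"]
    # Assumption: a cluster may declare a dedicated 'chief' job.
    if task["type"] == "chief":
        return True
    # Otherwise, by the convention described above, worker 0 is the chief.
    return (task["type"] == "worker"
            and task["index"] == 0
            and "chief" not in config["cluster"])


# One configuration per machine: identical cluster dict, different task index.
configs = [make_tf_config(WORKER_ADDRESSES, index)
           for index in range(len(WORKER_ADDRESSES))]

for tf_config in configs:
    print(tf_config, "chief" if is_chief(tf_config) else "regular worker")
```

On a real cluster, each machine would export its own string as the TF_CONFIG environment variable before creating the strategy.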

Note that while the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
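One common heuristic for such an adjustment is to scale the learning rate linearly with the global batch size. The sketch below shows that arithmetic only. The linear rule is an assumption brought in for illustration, not something this tutorial prescribes, and `scaled_learning_rate` is a hypothetical helper. The numbers take 64 as the single-worker batch size and 128 as the global batch size for two workers.

```python
def scaled_learning_rate(base_learning_rate, base_batch_size, global_batch_size):
    """Linear scaling heuristic: grow the learning rate with the batch size."""
    return base_learning_rate * global_batch_size / base_batch_size


# The tutorial trains with learning_rate=0.001 at a batch size of 64.
print(scaled_learning_rate(0.001, base_batch_size=64, global_batch_size=128))
```

Whether linear scaling helps depends on the model and optimizer, so treat the result as a starting point for tuning rather than a rule.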

Choose the right strategy

In TensorFlow, distributed training consists of synchronous training, where the steps of training are synced across the workers and replicas, and asynchronous training, where the training steps are not strictly synced.

MultiWorkerMirroredStrategy, which is the recommended strategy for synchronous multi-worker training, will be demonstrated in this guide. To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.

INFO:tensorflow:Single-worker CollectiveAllReduceStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

MultiWorkerMirroredStrategy provides multiple implementations via the CollectiveCommunication parameter. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.
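To give an intuition for what a ring-based collective does when it aggregates gradients, the following is a toy, single-process simulation of the classic ring all-reduce pattern in plain Python. It is not TensorFlow's implementation: there is no gRPC or NCCL here, and it assumes each worker's gradient has exactly one element per worker. The function name `ring_all_reduce` is made up for this sketch.

```python
def ring_all_reduce(worker_gradients):
    """Sum gradients across workers by passing chunks around a ring.

    worker_gradients[i][c] is chunk c of worker i's gradient; this toy
    version requires as many chunks as there are workers.
    """
    n = len(worker_gradients)
    data = [list(chunks) for chunks in worker_gradients]

    # Phase 1 (scatter-reduce): after n - 1 steps, worker i holds the
    # complete sum for chunk (i + 1) % n.
    for step in range(n - 1):
        # Snapshot all sends first so the step behaves as if simultaneous.
        sends = [(i, (i - step) % n, data[i][(i - step) % n]) for i in range(n)]
        for sender, chunk, value in sends:
            data[(sender + 1) % n][chunk] += value

    # Phase 2 (all-gather): circulate the finished chunks so that every
    # worker ends up with every complete sum.
    for step in range(n - 1):
        sends = [(i, (i + 1 - step) % n, data[i][(i + 1 - step) % n])
                 for i in range(n)]
        for sender, chunk, value in sends:
            data[(sender + 1) % n][chunk] = value

    return data


gradients = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
reduced = ring_all_reduce(gradients)
print(reduced)
```

Every worker ends up with the same summed gradient, which is what lets each replica apply an identical update and keep its copy of the variables in sync.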

Train the model with MultiWorkerMirroredStrategy

With the tf.distribute.Strategy API integrated into tf.keras, the only change you need to make to distribute training across multiple workers is to enclose the model building and the model.compile() call inside strategy.scope(). The distribution strategy's scope dictates how and where the variables are created. In the case of MultiWorkerMirroredStrategy, the variables created are MirroredVariables, and they are replicated on each of the workers.

NUM_WORKERS = 2
# Here the batch size scales up by number of workers since 
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 
# and now this becomes 128.
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
with strategy.scope():
  # Creation of dataset, and model building/compiling need to be within 
  # `strategy.scope()`.
  train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE)
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=3)
Epoch 1/3
469/469 [==============================] - 9s 18ms/step - loss: 2.2174 - accuracy: 0.2142
Epoch 2/3
469/469 [==============================] - 1s 3ms/step - loss: 1.9502 - accuracy: 0.5165
Epoch 3/3
469/469 [==============================] - 1s 3ms/step - loss: 1.4742 - accuracy: 0.7417

<tensorflow.python.keras.callbacks.History at 0x7fc3a0123e48>

Dataset sharding and batch size

In multi-worker training, sharding the data into multiple parts is needed to ensure convergence and performance. However, note that in the code snippet above, the datasets are passed directly to model.fit() without any explicit sharding; this is because the tf.distribute.Strategy API takes care of dataset sharding automatically in multi-worker training.

If you prefer manual sharding for your training, automatic sharding can be turned off via the tf.data.experimental.DistributeOptions API. Concretely,

options = tf.data.Options()
options.experimental_distribute.auto_shard = False
train_datasets_no_auto_shard = train_datasets.with_options(options)
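As a rough picture of what manual sharding means, the sketch below splits a list of example indices so that each worker sees a disjoint slice. It is plain Python with a hypothetical `shard` helper. It mirrors the element-wise pattern of tf.data.Dataset.shard(num_shards, index) and is not a description of what automatic sharding does internally.

```python
def shard(examples, num_workers, worker_index):
    """Keep every num_workers-th example, starting at worker_index."""
    return [example for position, example in enumerate(examples)
            if position % num_workers == worker_index]


examples = list(range(10))
shards = [shard(examples, num_workers=2, worker_index=i) for i in range(2)]
print(shards)
```

Together the shards cover the dataset exactly once, so no example is trained on twice within an epoch and none is skipped.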

Another thing to notice is the batch size for the datasets. The code snippet above uses GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS, which is NUM_WORKERS times as large as in the single-worker case. The effective per-worker batch size is the global batch size (the argument passed to tf.data.Dataset.batch()) divided by the number of workers, so this change keeps the per-worker batch size the same as before.
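The arithmetic behind this can be checked in a few lines of plain Python. The sketch assumes the 60,000 MNIST training examples mentioned earlier and a final partial batch that still counts as a step; the helper names are made up for illustration.

```python
import math

NUM_TRAIN_EXAMPLES = 60000


def per_worker_batch_size(global_batch_size, num_workers):
    """Each worker processes an equal share of every global batch."""
    return global_batch_size // num_workers


def steps_per_epoch(global_batch_size):
    """Number of batches needed to cover the training set once."""
    return math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size)


print(per_worker_batch_size(128, num_workers=2))
print(steps_per_epoch(64), steps_per_epoch(128))
```

The two step counts match the 938 and 469 steps per epoch shown in the training logs above: doubling the global batch size halves the number of steps while each worker still handles 64 examples per step.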

Performance

You now have a Keras model that is all set up to run on multiple workers with MultiWorkerMirroredStrategy. You can try the following techniques to tune the performance of multi-worker training.

  • MultiWorkerMirroredStrategy provides multiple collective communication implementations. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
  • Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.

Fault tolerance

In synchronous training, the cluster fails if one of the workers fails and no failure-recovery mechanism exists. Using Keras with tf.distribute.Strategy comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. This is done by preserving the training state in the distributed file system of your choice, so that when an instance that previously failed or was preempted restarts, the training state is recovered.

Since all the workers are kept in sync in terms of training epochs and steps, other workers would need to wait for the failed or preempted worker to restart to continue.

ModelCheckpoint callback

To take advantage of fault tolerance in multi-worker training, pass an instance of tf.keras.callbacks.ModelCheckpoint to the tf.keras.Model.fit() call. The callback stores the checkpoint and training state in the directory given by the filepath argument to ModelCheckpoint.

# Replace the `filepath` argument with a path in the file system
# accessible by all workers.
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=3, callbacks=callbacks)
Epoch 1/3
    469/Unknown - 8s 18ms/step - loss: 2.2049 - accuracy: 0.2318
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets

469/469 [==============================] - 9s 19ms/step - loss: 2.2049 - accuracy: 0.2318
Epoch 2/3
451/469 [===========================>..] - ETA: 0s - loss: 1.9195 - accuracy: 0.5715
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets

469/469 [==============================] - 2s 4ms/step - loss: 1.9113 - accuracy: 0.5767
Epoch 3/3
450/469 [===========================>..] - ETA: 0s - loss: 1.4175 - accuracy: 0.7550
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets

469/469 [==============================] - 2s 4ms/step - loss: 1.4078 - accuracy: 0.7561

<tensorflow.python.keras.callbacks.History at 0x7fc38fdfee80>

If a worker gets preempted, the whole cluster pauses until the preempted worker is restarted. Once the worker rejoins the cluster, other workers will also restart. Now, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Then the training continues.
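The save-then-resume idea can be illustrated with a toy loop in plain Python. This is only a stand-in under stated assumptions: the "training state" is a single epoch counter written to a JSON file, and the `train` function and its `preempt_after_epoch` parameter are hypothetical. Keras stores far more state, in its own format, and coordinates recovery across workers.

```python
import json
import os
import tempfile


def train(checkpoint_path, total_epochs, preempt_after_epoch=None):
    """Run toy 'epochs', resuming from checkpoint_path if it exists.

    Returns the list of epoch indices executed during this call.
    """
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        # Pick up the state saved before the interruption.
        with open(checkpoint_path) as checkpoint_file:
            start_epoch = json.load(checkpoint_file)["next_epoch"]

    epochs_run = []
    for epoch in range(start_epoch, total_epochs):
        epochs_run.append(epoch)  # Stand-in for one epoch of real training.
        with open(checkpoint_path, "w") as checkpoint_file:
            json.dump({"next_epoch": epoch + 1}, checkpoint_file)
        if epoch == preempt_after_epoch:
            break  # Simulated preemption right after the checkpoint is saved.
    return epochs_run


with tempfile.TemporaryDirectory() as checkpoint_dir:
    path = os.path.join(checkpoint_dir, "state.json")
    first_run = train(path, total_epochs=3, preempt_after_epoch=0)
    second_run = train(path, total_epochs=3)

print(first_run, second_run)
```

The second call skips the epoch that was already completed and checkpointed, which is the behavior the restarted cluster relies on to get back in sync.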

If you inspect the directory containing the filepath you specified in ModelCheckpoint, you may notice some temporarily generated checkpoint files. Those files are needed to recover previously lost instances, and the library removes them at the end of tf.keras.Model.fit() once your multi-worker training exits successfully.

See also

  1. Distributed Training in TensorFlow guide provides an overview of the available distribution strategies.
  2. Official ResNet50 model, which can be trained using either MirroredStrategy or MultiWorkerMirroredStrategy.