Overview
This tutorial demonstrates multi-worker distributed training of a Keras model using the tf.distribute.Strategy API. With the strategies designed specifically for multi-worker training, a Keras model written for a single worker can run on multiple workers with minimal code changes.
For an overview of the distribution strategies TensorFlow supports and a deeper look at the tf.distribute.Strategy APIs, see the Distributed Training in TensorFlow guide.
Setup
First, set up TensorFlow and the necessary imports.
from __future__ import absolute_import, division, print_function, unicode_literals
# json and os are needed later to set the TF_CONFIG environment variable.
import json
import os
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
Preparing dataset
Now, let's prepare the MNIST dataset from TensorFlow Datasets. The MNIST dataset comprises 60,000 training examples and 10,000 test examples of the handwritten digits 0–9, formatted as 28x28-pixel monochrome images.
BUFFER_SIZE = 10000
BATCH_SIZE = 64
def make_datasets_unbatched():
    # Scale MNIST pixel values from [0, 255] to [0., 1.].
    def scale(image, label):
        image = tf.cast(image, tf.float32)
        image /= 255
        return image, label

    datasets, info = tfds.load(name='mnist',
                               with_info=True,
                               as_supervised=True)
    return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

train_datasets = make_datasets_unbatched().batch(BATCH_SIZE)
Build the Keras model
Here we use the tf.keras.Sequential API to build and compile a simple convolutional neural network Keras model to train on our MNIST dataset.
def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model
Let's first train the model for a small number of epochs on a single worker to make sure everything works correctly. You should see the loss dropping and the accuracy approaching 1.0 as the epochs advance.
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(x=train_datasets, epochs=3)
Epoch 1/3
938/938 [==============================] - 14s 15ms/step - loss: 2.1411 - accuracy: 0.3374
Epoch 2/3
938/938 [==============================] - 10s 10ms/step - loss: 1.4281 - accuracy: 0.7325
Epoch 3/3
938/938 [==============================] - 10s 10ms/step - loss: 0.7333 - accuracy: 0.8328
<tensorflow.python.keras.callbacks.History at 0x7f661023d160>
Multi-worker Configuration
Now let's enter the world of multi-worker training. In TensorFlow, the TF_CONFIG environment variable is required for training on multiple machines, each of which may have a different role. TF_CONFIG specifies the cluster configuration on each worker that is part of the cluster.
TF_CONFIG has two components: cluster and task. cluster provides information about the training cluster; it is a dict consisting of different types of jobs, such as worker. In multi-worker training, there is usually one worker that takes on a little more responsibility, such as saving checkpoints and writing summary files for TensorBoard, in addition to what a regular worker does. This worker is referred to as the 'chief' worker, and it is customary for the worker with index 0 to be appointed as the chief (in fact, this is how tf.distribute.Strategy is implemented). task, on the other hand, provides information about the current task.
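To make the two components concrete, here is a minimal sketch that parses a TF_CONFIG-style JSON string and reports the task's role. The helper name describe_task is hypothetical, and the chief rule it applies is only the "worker with index 0" convention described above; clusters that use other job types are not handled.

```python
import json


def describe_task(tf_config_json):
    """Summarize the role of the task described by a TF_CONFIG JSON string.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    config = json.loads(tf_config_json)
    workers = config["cluster"].get("worker", [])
    task = config["task"]
    # Convention from the text: the worker with index 0 acts as chief.
    is_chief = task["type"] == "worker" and task["index"] == 0
    return {
        "num_workers": len(workers),
        "task_type": task["type"],
        "task_index": task["index"],
        "is_chief": is_chief,
    }


first_worker = json.dumps({
    "cluster": {"worker": ["localhost:12345", "localhost:23456"]},
    "task": {"type": "worker", "index": 0},
})
second_worker = json.dumps({
    "cluster": {"worker": ["localhost:12345", "localhost:23456"]},
    "task": {"type": "worker", "index": 1},
})

print(describe_task(first_worker))
print(describe_task(second_worker))
```

Both tasks see the same cluster dict; only the task entry differs, and only the first one is treated as chief.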
In this example, we set the task type to "worker" and the task index to 0. This means the machine with this setting is the first worker, which will be appointed as the chief worker and do more work than the other workers. Note that the other machines need to have the TF_CONFIG environment variable set as well, with the same cluster dict but a different task type or task index, depending on the roles of those machines.
For illustration purposes, this tutorial shows how one may set a TF_CONFIG with 2 workers on localhost. In practice, users would create multiple workers on external IP addresses/ports and set TF_CONFIG on each worker appropriately.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
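Since every machine needs the same cluster dict but its own task index, one way to avoid copy-paste mistakes is to generate the value for each worker from a single address list. The sketch below assumes a cluster made up only of "worker" jobs; make_tf_config is a hypothetical helper, not a TensorFlow function.

```python
import json


def make_tf_config(worker_addresses, task_index):
    """Build the TF_CONFIG JSON string for one worker (hypothetical helper)."""
    if not 0 <= task_index < len(worker_addresses):
        raise ValueError("task_index must refer to an entry in worker_addresses")
    return json.dumps({
        "cluster": {"worker": list(worker_addresses)},
        "task": {"type": "worker", "index": task_index},
    })


worker_addresses = ["localhost:12345", "localhost:23456"]
configs = [make_tf_config(worker_addresses, i)
           for i in range(len(worker_addresses))]

for config in configs:
    # On each machine, this string would be assigned to os.environ['TF_CONFIG'].
    print(config)
```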
Note that while the learning rate is fixed in this example, in general it may be necessary to adjust the learning rate based on the global batch size.
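As one possible way to adjust the learning rate, the sketch below applies a linear scaling heuristic: the learning rate grows in proportion to the global batch size. This heuristic is an assumption chosen for illustration and is not prescribed by this tutorial; the function name is hypothetical, and the right adjustment depends on the model and optimizer.

```python
def scale_learning_rate(base_lr, base_batch_size, global_batch_size):
    """Scale a learning rate linearly with the global batch size.

    Illustrative heuristic only; tune for your own model.
    """
    return base_lr * global_batch_size / base_batch_size


# Single-worker settings used in this tutorial: lr 0.001 with batch size 64.
# With 2 workers the global batch size becomes 128.
scaled_lr = scale_learning_rate(0.001, base_batch_size=64, global_batch_size=128)
print(scaled_lr)
```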
Choose the right strategy
In TensorFlow, distributed training consists of synchronous training, where the steps of training are synced across the workers and replicas, and asynchronous training, where the training steps are not strictly synced.
MultiWorkerMirroredStrategy, which is the recommended strategy for synchronous multi-worker training, is demonstrated in this guide.
To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/device:CPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CollectiveCommunication.AUTO
MultiWorkerMirroredStrategy provides multiple implementations via the CollectiveCommunication parameter. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster.
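The idea of aggregating gradients so that mirrored variables stay in sync can be illustrated without TensorFlow. The toy sketch below averages per-worker gradients and applies the same update to every replica; it only mimics the outcome of a collective all-reduce and says nothing about how RING or NCCL move data. All names and numbers are made up for the example.

```python
def all_reduce_mean(per_worker_grads):
    """Average gradients element-wise across workers."""
    num_workers = len(per_worker_grads)
    return [sum(values) / num_workers for values in zip(*per_worker_grads)]


def synchronous_step(replicas, per_worker_grads, learning_rate):
    """Apply the same averaged gradient to every replica's weights."""
    averaged = all_reduce_mean(per_worker_grads)
    return [
        [w - learning_rate * g for w, g in zip(weights, averaged)]
        for weights in replicas
    ]


# Two workers start with identical copies of the weights.
replicas = [[1.0, 2.0], [1.0, 2.0]]
# Each worker computes different gradients from its own slice of the batch.
per_worker_grads = [[0.25, 0.5], [0.75, 0.0]]

replicas = synchronous_step(replicas, per_worker_grads, learning_rate=0.5)
print(replicas)  # [[0.75, 1.875], [0.75, 1.875]]
```

Because every replica applies the same averaged gradient, the copies remain identical after the step even though each worker saw different data.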
Train the model with MultiWorkerMirroredStrategy
With the integration of the tf.distribute.Strategy API into tf.keras, the only change you need to make to distribute training across multiple workers is to enclose the model building and model.compile() call inside strategy.scope(). The distribution strategy's scope dictates how and where the variables are created. In the case of MultiWorkerMirroredStrategy, the variables created are MirroredVariables, and they are replicated on each of the workers.
NUM_WORKERS = 2
# Here the batch size scales up by number of workers since
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64,
# and now this becomes 128.
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS
with strategy.scope():
    # Creation of dataset, and model building/compiling need to be within
    # `strategy.scope()`.
    train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE)
    multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=3)
Epoch 1/3
469/469 [==============================] - 14s 29ms/step - loss: 2.2063 - accuracy: 0.2959
Epoch 2/3
469/469 [==============================] - 8s 16ms/step - loss: 1.9151 - accuracy: 0.5978
Epoch 3/3
469/469 [==============================] - 8s 16ms/step - loss: 1.4234 - accuracy: 0.7405
<tensorflow.python.keras.callbacks.History at 0x7f65dc760470>
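The step counts in the two training logs (938 for the single-worker run, 469 here) follow from the 60,000 MNIST training examples and the batch size. A small arithmetic check, assuming the final partial batch still counts as a step:

```python
import math

NUM_TRAIN_EXAMPLES = 60000  # size of the MNIST training split


def steps_per_epoch(num_examples, batch_size):
    """Number of batches per epoch, counting a final partial batch."""
    return math.ceil(num_examples / batch_size)


print(steps_per_epoch(NUM_TRAIN_EXAMPLES, 64))   # 938
print(steps_per_epoch(NUM_TRAIN_EXAMPLES, 128))  # 469
```

Doubling the global batch size halves the number of steps per epoch.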
Dataset sharding and batch size
In multi-worker training, the data must be sharded into multiple parts to ensure convergence and performance. Note, however, that in the code snippet above the datasets are passed directly to model.fit() without any sharding; this is because the tf.distribute.Strategy API takes care of dataset sharding automatically in multi-worker training.
If you prefer manual sharding for your training, automatic sharding can be turned off via the tf.data.experimental.DistributeOptions API. Concretely:
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
train_datasets_no_auto_shard = train_datasets.with_options(options)
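With automatic sharding turned off, each worker has to select its own portion of the data. The plain-Python sketch below illustrates index-based sharding, where worker i keeps every element whose position modulo the number of workers equals i. This is the same selection rule that tf.data.Dataset.shard(num_shards, index) applies; the list-based shard function here is only a stand-in for illustration, not TensorFlow code.

```python
def shard(elements, num_shards, index):
    """Keep every num_shards-th element, starting at position `index`."""
    return [x for position, x in enumerate(elements)
            if position % num_shards == index]


NUM_WORKERS = 2
examples = list(range(10))  # stand-in for a dataset of 10 examples

shards = [shard(examples, NUM_WORKERS, worker_index)
          for worker_index in range(NUM_WORKERS)]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

Each example lands on exactly one worker, so no data is processed twice within an epoch.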
Another thing to notice is the batch size for the datasets. In the code snippet above, we use GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS, which is NUM_WORKERS times as large as in the single-worker case. This is because the effective per-worker batch size is the global batch size (the parameter passed to tf.data.Dataset.batch()) divided by the number of workers; with this change, the per-worker batch size stays the same as before.
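The relationship between the two batch sizes can be written out directly. This sketch assumes the global batch size divides evenly among the workers, which is true for the values in this tutorial; the function name is hypothetical.

```python
def per_worker_batch_size(global_batch_size, num_workers):
    """Effective batch size seen by each worker.

    Assumes an even split; other cases are out of scope for this sketch.
    """
    if global_batch_size % num_workers != 0:
        raise ValueError("global batch size is not evenly divisible")
    return global_batch_size // num_workers


NUM_WORKERS = 2
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS

print(per_worker_batch_size(GLOBAL_BATCH_SIZE, NUM_WORKERS))  # 64
```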
Performance
You now have a Keras model that is all set up to run on multiple workers with MultiWorkerMirroredStrategy. You can try the following techniques to tweak the performance of multi-worker training.
- MultiWorkerMirroredStrategy provides multiple collective communication implementations. RING implements ring-based collectives using gRPC as the cross-host communication layer. NCCL uses Nvidia's NCCL to implement collectives. AUTO defers the choice to the runtime. The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify a valid value for the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL.
- Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.
Fault tolerance
In synchronous training, the cluster fails if one of the workers fails and no failure-recovery mechanism exists. Using Keras with tf.distribute.Strategy comes with the advantage of fault tolerance in cases where workers die or are otherwise unstable. This is done by preserving the training state in the distributed file system of your choice, so that when an instance that previously failed or was preempted restarts, the training state is recovered.
Since all the workers are kept in sync in terms of training epochs and steps, the other workers need to wait for the failed or preempted worker to restart before continuing.
ModelCheckpoint callback
To take advantage of fault tolerance in multi-worker training, provide an instance of tf.keras.callbacks.ModelCheckpoint in the tf.keras.Model.fit() call. The callback stores the checkpoint and training state in the directory corresponding to the filepath argument of ModelCheckpoint.
# Replace the `filepath` argument with a path in the file system
# accessible by all workers.
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets, epochs=3, callbacks=callbacks)
Epoch 1/3
469/Unknown - 13s 29ms/step - loss: 2.2116 - accuracy: 0.2893
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1788: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
469/469 [==============================] - 14s 30ms/step - loss: 2.2116 - accuracy: 0.2893
Epoch 2/3
465/469 [============================>.] - ETA: 0s - loss: 1.9749 - accuracy: 0.4954
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
469/469 [==============================] - 8s 17ms/step - loss: 1.9734 - accuracy: 0.4960
Epoch 3/3
467/469 [============================>.] - ETA: 0s - loss: 1.5652 - accuracy: 0.6418
INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets
469/469 [==============================] - 8s 17ms/step - loss: 1.5639 - accuracy: 0.6423
<tensorflow.python.keras.callbacks.History at 0x7f65dc6170f0>
If a worker gets preempted, the whole cluster pauses until the preempted worker is restarted. Once the worker rejoins the cluster, other workers will also restart. Now, every worker reads the checkpoint file that was previously saved and picks up its former state, thereby allowing the cluster to get back in sync. Then the training continues.
If you inspect the directory containing the filepath you specified in ModelCheckpoint, you may notice some temporarily generated checkpoint files. Those files are needed for recovering previously lost instances, and the library removes them at the end of tf.keras.Model.fit() once your multi-worker training exits successfully.
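The checkpoint-and-resume behaviour described in this section can be illustrated with a toy loop that has nothing to do with Keras internals. The sketch below records the next epoch to run in a small JSON file, simulates a preemption, and then restarts from the saved state. The function name, the file layout, and the fail_before_epoch parameter are all hypothetical; ModelCheckpoint also saves model weights, which this sketch omits.

```python
import json
import os
import tempfile


def train_with_resume(total_epochs, state_path, fail_before_epoch=None):
    """Run epochs, persisting progress so a restart can pick up where it left off."""
    start_epoch = 0
    if os.path.exists(state_path):
        # A previous run left state behind: resume from it.
        with open(state_path) as f:
            start_epoch = json.load(f)["next_epoch"]

    epochs_run = []
    for epoch in range(start_epoch, total_epochs):
        if epoch == fail_before_epoch:
            raise RuntimeError("simulated preemption")
        # One epoch of real training would happen here.
        epochs_run.append(epoch)
        with open(state_path, "w") as f:
            json.dump({"next_epoch": epoch + 1}, f)
    return epochs_run


state_path = os.path.join(tempfile.mkdtemp(), "train_state.json")

try:
    # The first attempt completes epochs 0 and 1, then is preempted.
    train_with_resume(3, state_path, fail_before_epoch=2)
except RuntimeError:
    pass

# After the restart, only the remaining epoch is run.
resumed_epochs = train_with_resume(3, state_path)
print(resumed_epochs)  # [2]
```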
See also
- The Distributed Training in TensorFlow guide provides an overview of the available distribution strategies.
- Official models, many of which can be configured to run multiple distribution strategies.