
Distributed training in TensorFlow


Overview

The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to allow users to enable distributed training using existing models and training code, with minimal changes.

This tutorial uses tf.distribute.MirroredStrategy, which performs in-graph replication with synchronous training on multiple GPUs on one machine. It copies all of the model's variables to each processor, then uses all-reduce to combine the gradients from all processors and applies the combined value to every copy of the model.
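The following framework-free sketch illustrates the synchronous pattern on a toy problem: two hypothetical replicas each fit y = w * x on half of a batch, average their gradients with a stand-in for all-reduce, and apply the same update. It only illustrates the idea and is not how MirroredStrategy is implemented; all function names and data here are made up for the example.

```python
# Toy synchronous data-parallel training (illustration only, not TensorFlow code).
# Each "replica" holds its own copy of a scalar weight w and fits y = w * x
# with a mean squared error loss.

def local_gradient(w, shard):
    """Gradient of mean((w*x - y)**2) over one replica's shard of the batch."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for all-reduce: every replica ends up with the same average."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

def train_step(replica_weights, global_batch, learning_rate=0.01):
    num_replicas = len(replica_weights)
    # Split the global batch evenly (any remainder is dropped in this toy version).
    per_replica = len(global_batch) // num_replicas
    shards = [global_batch[i * per_replica:(i + 1) * per_replica]
              for i in range(num_replicas)]
    # Each replica computes a gradient on its own shard.
    grads = [local_gradient(w, s) for w, s in zip(replica_weights, shards)]
    # Combine the gradients so every replica sees the same value...
    reduced = all_reduce_mean(grads)
    # ...and apply the identical update to every copy of the weight.
    return [w - learning_rate * g for w, g in zip(replica_weights, reduced)]

batch = [(x, 3.0 * x) for x in range(1, 9)]  # 8 examples of y = 3x
weights = [0.0, 0.0]                         # two replicas, identical start
for _ in range(200):
    weights = train_step(weights, batch)
print(weights)  # both copies stay identical and approach 3.0
```

Because every replica applies the same averaged gradient, the copies never drift apart, which is what "synchronous" means here.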

MirroredStrategy is one of several distribution strategies available in TensorFlow core. You can read about the others in the distribution strategy guide.

Keras API

This example uses the tf.keras API to build the model and training loop. For custom training loops, see the Distributed Training with Custom Training Loops tutorial.

Import Dependencies

from __future__ import absolute_import, division, print_function
# Import TensorFlow
!pip install -q tensorflow==2.0.0-alpha0 
import tensorflow_datasets as tfds
import tensorflow as tf

import os

Download the dataset

Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

Setting with_info to True also returns metadata for the entire dataset, which is saved here to ds_info. Among other things, this metadata object includes the number of train and test examples.

datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Define Distribution Strategy

Create a MirroredStrategy object. It handles the distribution and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.

strategy = tf.distribute.MirroredStrategy()
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Set up the input pipeline

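When a model is trained on multiple GPUs, increase the batch size accordingly so that the extra computing power is used effectively. The usual approach is to keep the per-replica batch size fixed and multiply it by the number of replicas, as the code below does. The learning rate may also need to be retuned for the larger batch.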

# You can also do ds_info.splits.total_num_examples to get the total 
# number of examples in the dataset.

num_train_examples = ds_info.splits['train'].num_examples
num_test_examples = ds_info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
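The arithmetic behind this advice can be sketched in plain Python. With one replica the global batch is 64, so one pass over the 60,000 training images takes ceil(60000 / 64) = 938 steps and the 10,000 test images take 157 steps, matching the step counts in the training and evaluation logs. The learning-rate function shows the commonly used linear scaling heuristic; it is an assumption for illustration, and this tutorial itself does not scale its learning rate.

```python
import math

# Values from this tutorial (MNIST sizes and the per-replica batch size).
NUM_TRAIN_EXAMPLES = 60000
NUM_TEST_EXAMPLES = 10000
BATCH_SIZE_PER_REPLICA = 64
# Assumed single-replica learning rate (Adam's default, also decay()'s first value).
BASE_LEARNING_RATE = 1e-3

def global_batch_size(num_replicas):
    """The global batch grows linearly with the number of replicas."""
    return BATCH_SIZE_PER_REPLICA * num_replicas

def steps_per_epoch(num_examples, batch_size):
    """Number of batches, counting a final partial batch (batch() keeps it by default)."""
    return math.ceil(num_examples / batch_size)

def linearly_scaled_lr(num_replicas):
    """Linear scaling heuristic (assumption): scale the LR like the batch size."""
    return BASE_LEARNING_RATE * num_replicas

for replicas in (1, 2, 4, 8):
    bs = global_batch_size(replicas)
    print('replicas={} global_batch={} train_steps={} lr={}'.format(
        replicas, bs, steps_per_epoch(NUM_TRAIN_EXAMPLES, bs),
        linearly_scaled_lr(replicas)))
```

Doubling the replicas halves the number of steps per epoch, which is why the learning rate often has to change along with the batch size.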

Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  
  return image, label

Apply this function to the training and test data, shuffle the training data, and batch it for training.

train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
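To make BUFFER_SIZE concrete: Dataset.shuffle fills a buffer with buffer_size elements, randomly samples from it, and replaces each sampled element with the next input element, so a buffer smaller than the dataset only shuffles locally. Dataset.batch keeps a final partial batch unless drop_remainder=True; here 60000 % 64 leaves 32 images in the last training batch. The pure-Python generators below are a sketch of those semantics with made-up names, not the tf.data implementation.

```python
import random

def buffered_shuffle(elements, buffer_size, seed=0):
    """Approximate shuffle with a fixed-size buffer, mimicking Dataset.shuffle."""
    rng = random.Random(seed)
    buffer = []
    for element in elements:
        buffer.append(element)
        if len(buffer) == buffer_size:
            # Emit a random element; the next input element refills its slot.
            index = rng.randrange(len(buffer))
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # Input exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(elements, batch_size):
    """Group elements into lists of batch_size; the last batch may be smaller."""
    current = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

# Same order as the tutorial's pipeline: shuffle first, then batch.
batches = list(batch(buffered_shuffle(range(150), buffer_size=100), 64))
print([len(b) for b in batches])  # [64, 64, 22]
```

Note that the first element emitted can only come from the first buffer_size inputs, which is why a larger buffer gives a more thorough shuffle.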

Create the model

Create and compile the Keras model in the context of strategy.scope.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])
  
  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
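Every variable created inside strategy.scope() is mirrored, so each replica holds its own copy of all of the model's weights. Counting them by hand is a quick sanity check. The sketch below assumes the Keras defaults of 'valid' padding and stride 1 for Conv2D and a 2x2 pool for MaxPooling2D; model.summary() should report the same total.

```python
# Hand-count the trainable parameters of the model defined above.

def conv2d_params(kernel, in_channels, filters):
    return kernel * kernel * in_channels * filters + filters  # weights + biases

def dense_params(inputs, units):
    return inputs * units + units  # weights + biases

conv_out = 28 - 3 + 1         # 26: 'valid' 3x3 convolution on a 28x28 image
pooled = conv_out // 2        # 13: 2x2 max pooling
flat = pooled * pooled * 32   # 5408 features after Flatten

param_counts = {
    'conv': conv2d_params(3, 1, 32),       # 320
    'hidden_dense': dense_params(flat, 64),  # 346176
    'output_dense': dense_params(64, 10),    # 650
}
total = sum(param_counts.values())
print(total)  # 347146
```

Almost all of the parameters sit in the first Dense layer, so that layer dominates the size of each mirrored copy.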

Define the callbacks.

The callbacks used here are:

  • TensorBoard: This callback writes a log for TensorBoard, which allows you to visualize the graphs.
  • Model Checkpoint: This callback saves the model after every epoch.
  • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

For illustrative purposes, add a print callback to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print ('\nLearning rate for epoch {} is {}'.format(epoch + 1, 
                                                       model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, 
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
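A quick way to check decay() before training is to print the schedule it produces. Keras passes a 0-based epoch index to the scheduler, while the progress bar and PrintLR report 1-based epochs, so epochs 1-3 use 1e-3, epochs 4-7 use 1e-4, and epochs 8-10 use 1e-5. The long decimals in the training log, such as 0.0010000000474974513, are those values stored as 32-bit floats; the as_float32 helper (a name made up for this sketch) reproduces that rounding with the standard struct module.

```python
import struct

def decay(epoch):
    """Same schedule as the tutorial's decay(); epoch is 0-based."""
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5

def as_float32(value):
    """Round a Python float to the nearest 32-bit float and back."""
    return struct.unpack('f', struct.pack('f', value))[0]

# Print the 1-based epoch number the way PrintLR does.
for epoch in range(10):
    lr = decay(epoch)
    print('Epoch {}: {} (as float32: {})'.format(epoch + 1, lr, as_float32(lr)))
```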

Train and evaluate

Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

model.fit(train_dataset, epochs=10, callbacks=callbacks)

Epoch 1/10
    938/Unknown - 10s 10ms/step - loss: 0.2095 - accuracy: 0.9394
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 10s 10ms/step - loss: 0.2095 - accuracy: 0.9394
Epoch 2/10
933/938 [============================>.] - ETA: 0s - loss: 0.0673 - accuracy: 0.9804
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 9s 9ms/step - loss: 0.0673 - accuracy: 0.9804
Epoch 3/10
934/938 [============================>.] - ETA: 0s - loss: 0.0483 - accuracy: 0.9855
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 8s 9ms/step - loss: 0.0484 - accuracy: 0.9855
Epoch 4/10
933/938 [============================>.] - ETA: 0s - loss: 0.0278 - accuracy: 0.9922
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 8s 9ms/step - loss: 0.0277 - accuracy: 0.9923
Epoch 5/10
934/938 [============================>.] - ETA: 0s - loss: 0.0240 - accuracy: 0.9933
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 8s 9ms/step - loss: 0.0240 - accuracy: 0.9933
Epoch 6/10
934/938 [============================>.] - ETA: 0s - loss: 0.0221 - accuracy: 0.9940
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 8s 9ms/step - loss: 0.0221 - accuracy: 0.9940
Epoch 7/10
937/938 [============================>.] - ETA: 0s - loss: 0.0203 - accuracy: 0.9949
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 8s 9ms/step - loss: 0.0203 - accuracy: 0.9949
Epoch 8/10
934/938 [============================>.] - ETA: 0s - loss: 0.0178 - accuracy: 0.9958
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 9s 9ms/step - loss: 0.0178 - accuracy: 0.9958
Epoch 9/10
936/938 [============================>.] - ETA: 0s - loss: 0.0174 - accuracy: 0.9960
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 9s 9ms/step - loss: 0.0174 - accuracy: 0.9959
Epoch 10/10
937/938 [============================>.] - ETA: 0s - loss: 0.0172 - accuracy: 0.9960
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 9s 9ms/step - loss: 0.0172 - accuracy: 0.9961

<tensorflow.python.keras.callbacks.History at 0x7f6348d2f358>

As you can see below, a checkpoint is saved after every epoch.

# check the checkpoint directory
!ls {checkpoint_dir}
checkpoint           ckpt_5.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_5.index
ckpt_1.index             ckpt_6.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_6.index
ckpt_10.index            ckpt_7.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_7.index
ckpt_2.index             ckpt_8.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_8.index
ckpt_3.index             ckpt_9.data-00000-of-00001
ckpt_4.data-00000-of-00001   ckpt_9.index
ckpt_4.index
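tf.train.latest_checkpoint does not sort these filenames; it reads the small checkpoint state file in the same directory, which records the most recent save (here ckpt_10). Sorting the names as strings would pick the wrong one, because string order puts ckpt_10 between ckpt_1 and ckpt_2, as the listing shows. If you ever need to choose a checkpoint from names alone, compare epoch numbers numerically, as in this sketch (the names are copied from the listing; epoch_of is a hypothetical helper).

```python
import re

# Checkpoint prefixes from the listing (each has .index and .data-* files).
names = ['ckpt_1', 'ckpt_10', 'ckpt_2', 'ckpt_3', 'ckpt_4',
         'ckpt_5', 'ckpt_6', 'ckpt_7', 'ckpt_8', 'ckpt_9']

def epoch_of(prefix):
    """Extract the epoch number from a 'ckpt_{epoch}' prefix."""
    return int(re.fullmatch(r'ckpt_(\d+)', prefix).group(1))

print(max(names))                # ckpt_9  (string order: not the latest)
print(max(names, key=epoch_of))  # ckpt_10 (numeric order: the latest)
```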

To see how the model performs, load the latest checkpoint and call evaluate on the test dataset, as before.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 1s 10ms/step - loss: 0.0395 - accuracy: 0.9870
Eval loss: 0.0394766087422094, Eval Accuracy: 0.9869999885559082

To visualize training, download the TensorBoard logs and view them from a terminal:

$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 12K
4.0K plugins  4.0K train  4.0K validation

Export to SavedModel

To export the graph and the variables, use the SavedModel format. A SavedModel is platform-agnostic, and the model can be loaded back either with or without strategy.scope().

path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)

Load the model without strategy.scope.

unreplicated_model = tf.keras.experimental.load_from_saved_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy', 
    optimizer=tf.keras.optimizers.Adam(), 
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 1s 9ms/step - loss: 0.0395 - accuracy: 0.9870
Eval loss: 0.0394766087422094, Eval Accuracy: 0.9869999885559082

Load the model with strategy.scope.

with strategy.scope():
  replicated_model = tf.keras.experimental.load_from_saved_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
    157/Unknown - 1s 9ms/step - loss: 0.0395 - accuracy: 0.9870
Eval loss: 0.0394766087422094, Eval Accuracy: 0.9869999885559082

What's next?

Read the distribution strategy guide.

Try the Distributed Training with Custom Training Loops tutorial.