Distributed training with Keras


Overview

The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. It allows you to carry out distributed training using existing models and training code with minimal changes.

This tutorial demonstrates how to use the tf.distribute.MirroredStrategy to perform in-graph replication with synchronous training on many GPUs on one machine. The strategy essentially copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.
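
To make the mechanics concrete, here is a minimal, self-contained sketch of the replicate-then-all-reduce pattern that Model.fit manages for you. It is not part of this tutorial's training code, and the names sketch_strategy, layer, and replica_fn are purely illustrative:

import tensorflow as tf

# Variables created inside the scope are mirrored (copied) to every device.
sketch_strategy = tf.distribute.MirroredStrategy()
with sketch_strategy.scope():
  layer = tf.keras.layers.Dense(1)
  layer.build((None, 4))  # Create the mirrored variables up front.

@tf.function
def replica_fn(x):
  # Runs once on each replica. During real training, each replica receives
  # its own shard of the global batch from a distributed dataset.
  return tf.reduce_sum(layer(x))

per_replica = sketch_strategy.run(replica_fn, args=(tf.ones([8, 4]),))
# All-reduce: combine the per-replica results into a single value.
total = sketch_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)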

You will use the tf.keras APIs to build the model and Model.fit for training it. (To learn about distributed training with a custom training loop and the MirroredStrategy, check out this tutorial.)

MirroredStrategy trains your model on multiple GPUs on a single machine. For synchronous training on many GPUs on multiple workers, use the tf.distribute.MultiWorkerMirroredStrategy with the Keras Model.fit or a custom training loop. For other options, refer to the Distributed training guide.

To learn about other distribution strategies, check out the Distributed training with TensorFlow guide.

Setup

import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.
%load_ext tensorboard
2023-12-07 02:42:10.684424: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-07 02:42:10.684469: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-07 02:42:10.685994: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
print(tf.__version__)
2.15.0

Download the dataset

Load the MNIST dataset from TensorFlow Datasets. This returns a dataset in the tf.data format.

Setting the with_info argument to True includes the metadata for the entire dataset, which is saved here to info. Among other things, this metadata object includes the number of train and test examples.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

Define the distribution strategy

Create a MirroredStrategy object. This will handle distribution and provide a context manager (MirroredStrategy.scope) to build your model inside.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4
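
By default, MirroredStrategy uses all GPUs visible to TensorFlow. If you only want to use a subset of them, you can pass an explicit devices list when creating the strategy; the device names and the variable name below are only examples:

# Optional: mirror variables only across an explicit list of devices.
# `two_gpu_strategy` is an illustrative name; the rest of this tutorial uses `strategy` above.
two_gpu_strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0', '/gpu:1'])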

Set up the input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory and tune the learning rate accordingly.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

Define a function that normalizes the image pixel values from the [0, 255] range to the [0, 1] range (feature scaling):

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Apply this scale function to the training and test data, and then use the tf.data.Dataset APIs to shuffle the training data (Dataset.shuffle), and batch it (Dataset.batch). Notice that you are also keeping an in-memory cache of the training data to improve performance (Dataset.cache).

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
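
As an optional tweak (not used in the rest of this tutorial), you can also prefetch batches so that input preprocessing overlaps with training on the GPUs; this is a standard tf.data optimization:

# Optional: overlap input preprocessing with model execution by prefetching batches.
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
eval_dataset = eval_dataset.prefetch(tf.data.AUTOTUNE)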

Create the model and instantiate the optimizer

Within the context of Strategy.scope, create and compile the model using the Keras API:

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                metrics=['accuracy'])

For this toy example with the MNIST dataset, you will be using the Adam optimizer's default learning rate of 0.001.

For larger datasets, the key benefit of distributed training is to learn more in each training step, because each step processes more training data in parallel, which allows for a larger learning rate (within the limits of the model and dataset).
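
A common heuristic, shown below only as an illustration (this tutorial keeps the default rate of 0.001), is to scale the base learning rate linearly with the number of replicas, since the global batch size grows by the same factor:

# Illustrative only: linear learning-rate scaling with the number of replicas.
# This tutorial uses the Adam default of 1e-3 instead.
BASE_LEARNING_RATE = 1e-3
scaled_learning_rate = BASE_LEARNING_RATE * strategy.num_replicas_in_sync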

Define the callbacks

Define the following Keras Callbacks:

  - tf.keras.callbacks.TensorBoard: writes a log for TensorBoard, which allows you to visualize the graphs.
  - tf.keras.callbacks.ModelCheckpoint: saves the model at a certain frequency, such as after every epoch.
  - tf.keras.callbacks.LearningRateScheduler: schedules the learning rate to change after, for example, every epoch/batch.

For illustrative purposes, add a custom callback called PrintLR to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(
        epoch + 1, model.optimizer.lr.numpy()))
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
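
If you also need fault tolerance (for example, to resume an interrupted run from the last completed epoch), you could additionally register tf.keras.callbacks.BackupAndRestore and add it to the callbacks list above. It is not used in the rest of this tutorial, and the backup directory name below is only an example:

# Optional: back up the model and epoch counter at the end of each epoch so that
# an interrupted training run can resume where it left off.
backup_callback = tf.keras.callbacks.BackupAndRestore(backup_dir='./backup')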

Train and evaluate

Now, train the model in the usual way by calling Keras Model.fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
2023-12-07 02:42:17.017489: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/12
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1701916944.003594   37459 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1/235 [..............................] - ETA: 26:24 - loss: 2.2859 - accuracy: 0.1914WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0080s vs `on_train_batch_end` time: 0.0138s). Check your callbacks.
235/235 [==============================] - ETA: 0s - loss: 0.3395 - accuracy: 0.9071INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Learning rate for epoch 1 is 0.0010000000474974513
235/235 [==============================] - 9s 9ms/step - loss: 0.3395 - accuracy: 0.9071 - lr: 0.0010
Epoch 2/12
235/235 [==============================] - ETA: 0s - loss: 0.0976 - accuracy: 0.9720
Learning rate for epoch 2 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.0976 - accuracy: 0.9720 - lr: 0.0010
Epoch 3/12
232/235 [============================>.] - ETA: 0s - loss: 0.0671 - accuracy: 0.9807
Learning rate for epoch 3 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.0668 - accuracy: 0.9808 - lr: 0.0010
Epoch 4/12
229/235 [============================>.] - ETA: 0s - loss: 0.0468 - accuracy: 0.9874
Learning rate for epoch 4 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0467 - accuracy: 0.9874 - lr: 1.0000e-04
Epoch 5/12
235/235 [==============================] - ETA: 0s - loss: 0.0443 - accuracy: 0.9879
Learning rate for epoch 5 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0443 - accuracy: 0.9879 - lr: 1.0000e-04
Epoch 6/12
232/235 [============================>.] - ETA: 0s - loss: 0.0422 - accuracy: 0.9886
Learning rate for epoch 6 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0428 - accuracy: 0.9885 - lr: 1.0000e-04
Epoch 7/12
231/235 [============================>.] - ETA: 0s - loss: 0.0414 - accuracy: 0.9885
Learning rate for epoch 7 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0413 - accuracy: 0.9884 - lr: 1.0000e-04
Epoch 8/12
230/235 [============================>.] - ETA: 0s - loss: 0.0394 - accuracy: 0.9895
Learning rate for epoch 8 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0392 - accuracy: 0.9895 - lr: 1.0000e-05
Epoch 9/12
233/235 [============================>.] - ETA: 0s - loss: 0.0391 - accuracy: 0.9895
Learning rate for epoch 9 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0390 - accuracy: 0.9895 - lr: 1.0000e-05
Epoch 10/12
234/235 [============================>.] - ETA: 0s - loss: 0.0388 - accuracy: 0.9895
Learning rate for epoch 10 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0388 - accuracy: 0.9895 - lr: 1.0000e-05
Epoch 11/12
230/235 [============================>.] - ETA: 0s - loss: 0.0386 - accuracy: 0.9897
Learning rate for epoch 11 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0387 - accuracy: 0.9896 - lr: 1.0000e-05
Epoch 12/12
230/235 [============================>.] - ETA: 0s - loss: 0.0386 - accuracy: 0.9896
Learning rate for epoch 12 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0385 - accuracy: 0.9897 - lr: 1.0000e-05
<keras.src.callbacks.History at 0x7f52740d3340>

Check for saved checkpoints:

# Check the checkpoint directory.
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

To check how well the model performs, load the latest checkpoint and call Model.evaluate on the test data:

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
2023-12-07 02:42:49.914159: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 8ms/step - loss: 0.0474 - accuracy: 0.9837
Eval loss: 0.04736212641000748, Eval accuracy: 0.9836999773979187

To visualize the output, launch TensorBoard and view the logs:

%tensorboard --logdir=logs

ls -sh ./logs
total 4.0K
4.0K train

Save the model

Save the model to a .keras zip archive using Model.save. After your model is saved, you can load it with or without the Strategy.scope.

path = 'my_model.keras'
model.save(path)

Now, load the model without Strategy.scope:

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
40/40 [==============================] - 0s 4ms/step - loss: 0.0474 - accuracy: 0.9837
Eval loss: 0.047362130135297775, Eval Accuracy: 0.9836999773979187

Load the model with Strategy.scope:

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
2023-12-07 02:42:53.484749: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 5ms/step - loss: 0.0474 - accuracy: 0.9837
Eval loss: 0.04736212641000748, Eval Accuracy: 0.9836999773979187

Additional resources

More examples that use different distribution strategies with the Keras Model.fit API:

  1. The Solve GLUE tasks using BERT on TPU tutorial uses tf.distribute.MirroredStrategy for training on GPUs and tf.distribute.TPUStrategy on TPUs.
  2. The Save and load a model using a distribution strategy tutorial demonstrates how to use the SavedModel APIs with tf.distribute.Strategy.
  3. The official TensorFlow models can be configured to run multiple distribution strategies.

To learn more about TensorFlow distribution strategies:

  1. The Custom training with tf.distribute.Strategy tutorial shows how to use the tf.distribute.MirroredStrategy for single-worker training with a custom training loop.
  2. The Multi-worker training with Keras tutorial shows how to use the MultiWorkerMirroredStrategy with Model.fit.
  3. The Custom training loop with Keras and MultiWorkerMirroredStrategy tutorial shows how to use the MultiWorkerMirroredStrategy with Keras and a custom training loop.
  4. The Distributed training in TensorFlow guide provides an overview of the available distribution strategies.
  5. The Better performance with tf.function guide provides information about other strategies and tools, such as the TensorFlow Profiler, which you can use to optimize the performance of your TensorFlow models.