This tutorial demonstrates how to use tf.distribute.Strategy with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset, which contains 60000 training images and 10000 test images, each of size 28 x 28.
We use custom training loops because they give us flexibility and greater control over training. They also make it easier to debug the model and the training loop.
# Import TensorFlow
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
2.3.0
Download the fashion MNIST dataset
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 1s 0us/step
Create a strategy to distribute the variables and the graph
How does tf.distribute.MirroredStrategy work?
- All the variables and the model graph are replicated across the replicas.
- Input is evenly distributed across the replicas.
- Each replica calculates the loss and gradients for the input it received.
- The gradients are synced across all the replicas by summing them.
- After the sync, the same update is made to the copies of the variables on each replica.
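The steps above can be sketched in plain Python, without TensorFlow. This is only a toy simulation under stated assumptions: the one-parameter model `y = w * x`, the squared-error loss, and the names `grad_sum` and `mirrored_step` are all hypothetical and are not part of the tf.distribute API. It shows that summing per-replica gradients, each scaled by the global batch size, yields the same update a single device would compute on the whole batch.

```python
# Toy simulation of synchronous data-parallel training (not TensorFlow code).
# Hypothetical model: y = w * x, per-example loss = (w * x - y) ** 2.

def grad_sum(w, shard):
    """Sum of per-example gradients d/dw (w * x - y) ** 2 over one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in shard)

def mirrored_step(replica_weights, batch, learning_rate):
    num_replicas = len(replica_weights)
    global_batch_size = len(batch)
    shard_size = global_batch_size // num_replicas
    # Input is evenly distributed across the replicas.
    shards = [batch[i * shard_size:(i + 1) * shard_size]
              for i in range(num_replicas)]
    # Each replica computes gradients for its own shard,
    # scaled by the global batch size.
    local_grads = [grad_sum(w, shard) / global_batch_size
                   for w, shard in zip(replica_weights, shards)]
    # The gradients are synced across replicas by summing them.
    synced_grad = sum(local_grads)
    # The same update is applied to every copy of the variable.
    return [w - learning_rate * synced_grad for w in replica_weights]

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
replicas = mirrored_step([0.0, 0.0], batch, learning_rate=0.01)

# Reference: one device processing the whole batch.
single = 0.0 - 0.01 * grad_sum(0.0, batch) / len(batch)
print(replicas, single)
```

Because every replica applies the same synced gradient, the variable copies stay identical after each step.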
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
Set up the input pipeline
Define the shuffle buffer size, the per-replica and global batch sizes, and the number of epochs. The global batch size is the per-replica batch size multiplied by the number of replicas in sync.
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
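To make the effect of these constants concrete, here is a small plain-Python calculation of the global batch size, the number of steps per epoch, and the size of the final batch for a few replica counts. It assumes the 60000 training images mentioned earlier and a dataset batched without dropping the remainder; the helper name `batching_summary` is hypothetical.

```python
import math

NUM_TRAIN_EXAMPLES = 60000      # Fashion MNIST training set size
BATCH_SIZE_PER_REPLICA = 64

def batching_summary(num_replicas):
    """Return (global batch size, steps per epoch, size of the last batch)."""
    global_batch_size = BATCH_SIZE_PER_REPLICA * num_replicas
    # The final partial batch is kept, so round up.
    steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / global_batch_size)
    last_batch_size = (NUM_TRAIN_EXAMPLES
                       - (steps_per_epoch - 1) * global_batch_size)
    return global_batch_size, steps_per_epoch, last_batch_size

for num_replicas in (1, 2, 4):
    print(num_replicas, batching_summary(num_replicas))
```

Note that the last batch of an epoch is smaller than the global batch size in each of these cases, which matters for the loss scaling discussed later.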
Create the datasets and distribute them:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
Create the model
Create a model using tf.keras.Sequential. You can also use the Model Subclassing API to do this.
def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])
  return model
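As a sanity check on this architecture, the following plain-Python sketch traces the tensor shape through each layer for a 28 x 28 x 1 input and counts the trainable parameters. It assumes the Keras layer defaults (stride-1 convolutions with 'valid' padding and 2 x 2 max pooling); the helper functions are hypothetical and only do arithmetic.

```python
def conv2d_valid(height, width, in_channels, filters, kernel=3):
    """Shape and parameter count of a stride-1, 'valid'-padded convolution."""
    out_shape = (height - kernel + 1, width - kernel + 1, filters)
    params = kernel * kernel * in_channels * filters + filters  # weights + biases
    return out_shape, params

def max_pool(height, width, channels, pool=2):
    """Shape after non-overlapping pool x pool max pooling."""
    return (height // pool, width // pool, channels)

def dense(in_units, out_units):
    """Output width and parameter count of a fully connected layer."""
    return out_units, in_units * out_units + out_units

shape = (28, 28, 1)
shape, conv1_params = conv2d_valid(*shape, filters=32)   # (26, 26, 32)
shape = max_pool(*shape)                                 # (13, 13, 32)
shape, conv2_params = conv2d_valid(*shape, filters=64)   # (11, 11, 64)
shape = max_pool(*shape)                                 # (5, 5, 64)
flat_units = shape[0] * shape[1] * shape[2]              # Flatten
units, dense1_params = dense(flat_units, 64)
units, dense2_params = dense(units, 10)

total_params = conv1_params + conv2_params + dense1_params + dense2_params
print(flat_units, total_params)
```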
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
Define the loss function
Normally, on a single machine with 1 GPU/CPU, the loss is divided by the number of examples in the input batch.
So, how should the loss be calculated when using a tf.distribute.Strategy?
For example, let's say you have 4 GPUs and a batch size of 64. One batch of input is distributed across the replicas (4 GPUs), with each replica getting an input of size 16.
The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).
Why do this?
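- This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by summing them. Dividing each replica's loss by the global batch size makes the summed gradients equal to the gradient of the mean loss over the whole global batch.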
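A quick numeric check of this reasoning, in plain Python. The loss values below are made up for illustration, and a small global batch of 8 split over 4 replicas is assumed. Since gradients are linear in the loss, whatever holds for the summed losses also holds for the summed gradients.

```python
GLOBAL_BATCH_SIZE = 8
NUM_REPLICAS = 4

# Hypothetical per-example losses for one global batch.
per_example_losses = [0.5, 1.5, 1.0, 2.0, 0.25, 0.75, 3.0, 1.0]

shard_size = GLOBAL_BATCH_SIZE // NUM_REPLICAS
shards = [per_example_losses[i * shard_size:(i + 1) * shard_size]
          for i in range(NUM_REPLICAS)]

# What a single device would compute: the mean over the whole batch.
true_mean = sum(per_example_losses) / GLOBAL_BATCH_SIZE

# Each replica divides by the GLOBAL batch size; summing recovers the mean.
summed_global_scaled = sum(sum(shard) / GLOBAL_BATCH_SIZE for shard in shards)

# Each replica divides by its LOCAL batch size; summing overshoots.
summed_local_means = sum(sum(shard) / len(shard) for shard in shards)

print(true_mean, summed_global_scaled, summed_local_means)
```

With local means, the summed value is too large by a factor equal to the number of replicas, which would silently scale the effective learning rate.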
How to do this in TensorFlow?
- If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE:
scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)
or you can use tf.nn.compute_average_loss, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.
- If you are using regularization losses in your model, then you need to scale the loss value by the number of replicas. You can do this by using the tf.nn.scale_regularization_loss function.
- Using tf.reduce_mean is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.
- This reduction and scaling is done automatically in Keras model.compile and model.fit.
- If using tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified to be one of NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because the user should explicitly think about what reduction they want to make sure it is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by the per-replica batch size and leave the dividing by the number of replicas to the user, which might be easy to miss. So instead we ask the user to do the reduction themselves explicitly.
- If labels is multi-dimensional, then average the per_example_loss across the number of elements in each sample. For example, if the shape of predictions is (batch_size, H, W, n_classes) and labels is (batch_size, H, W), you will need to update per_example_loss like:
per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
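The multi-dimensional case can be illustrated with nested Python lists standing in for a (batch_size, H, W) loss tensor. This sketch assumes the scaled loss is computed as in the formula above (sum every element, then divide by the global batch size); the loss values are made up.

```python
GLOBAL_BATCH_SIZE = 2

# Hypothetical per-element losses with shape (batch_size=2, H=2, W=2).
per_example_loss = [
    [[1.0, 2.0], [3.0, 2.0]],   # sample 0: mean over elements = 2.0
    [[0.5, 0.5], [1.5, 1.5]],   # sample 1: mean over elements = 1.0
]

height = len(per_example_loss[0])
width = len(per_example_loss[0][0])
elements_per_sample = height * width   # plays the role of reduce_prod(shape[1:])

total = sum(value
            for sample in per_example_loss
            for row in sample
            for value in row)

# Without the extra division, the loss grows with H * W.
unscaled = total / GLOBAL_BATCH_SIZE

# Dividing by the number of elements per sample first gives the
# average per-sample loss over the global batch.
scaled = (total / elements_per_sample) / GLOBAL_BATCH_SIZE

print(unscaled, scaled)
```

Here `scaled` equals the mean of the two per-sample means, (2.0 + 1.0) / 2, while `unscaled` is larger by a factor of H * W.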
with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
Define the metrics to track loss and accuracy
These metrics track the test loss and the training and test accuracy. You can use .result() to get the accumulated statistics at any time.
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
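Conceptually, a streaming metric keeps a running total and a count, and .result() divides one by the other. The class below is a toy stand-in written for illustration only; `RunningMean` is a hypothetical name and this is not the Keras implementation, although it mirrors the update_state / result / reset_states pattern used in this tutorial.

```python
class RunningMean:
    """Toy streaming mean: accumulates a total and a count across updates."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate one batch of values.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Return the mean of everything seen since the last reset
        # (0.0 if nothing has been accumulated yet).
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0


metric = RunningMean()
metric.update_state([0.5, 1.5])   # first batch
metric.update_state([1.0])        # second, smaller batch
mean_over_epoch = metric.result()
metric.reset_states()             # start fresh for the next epoch
mean_after_reset = metric.result()
print(mean_over_epoch, mean_after_reset)
```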
Training loop
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_accuracy.update_state(labels, predictions)
  return loss
def test_step(inputs):
  images, labels = inputs
  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)
  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)
@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))
for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches
  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)
  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))
  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.50295090675354, Accuracy: 82.1116714477539, Test Loss: 0.3852590322494507, Test Accuracy: 86.5999984741211
Epoch 2, Loss: 0.32958829402923584, Accuracy: 88.20333862304688, Test Loss: 0.3391425311565399, Test Accuracy: 87.6500015258789
Epoch 3, Loss: 0.2872008979320526, Accuracy: 89.57167053222656, Test Loss: 0.2974696457386017, Test Accuracy: 89.31000518798828
Epoch 4, Loss: 0.255713552236557, Accuracy: 90.58499908447266, Test Loss: 0.2988712787628174, Test Accuracy: 89.31999969482422
Epoch 5, Loss: 0.23122134804725647, Accuracy: 91.41667175292969, Test Loss: 0.27742496132850647, Test Accuracy: 89.99000549316406
Epoch 6, Loss: 0.212575763463974, Accuracy: 92.17333221435547, Test Loss: 0.2573488652706146, Test Accuracy: 90.75
Epoch 7, Loss: 0.1963273137807846, Accuracy: 92.77166748046875, Test Loss: 0.2587501108646393, Test Accuracy: 90.66000366210938
Epoch 8, Loss: 0.1779220998287201, Accuracy: 93.46666717529297, Test Loss: 0.267805814743042, Test Accuracy: 90.55999755859375
Epoch 9, Loss: 0.16410504281520844, Accuracy: 93.91333770751953, Test Loss: 0.25632956624031067, Test Accuracy: 91.00999450683594
Epoch 10, Loss: 0.14829590916633606, Accuracy: 94.47833251953125, Test Loss: 0.25820475816726685, Test Accuracy: 91.00999450683594
Things to note in the example above:
- We are iterating over the train_dist_dataset and test_dist_dataset using a for x in ... construct.
- The scaled loss is the return value of distributed_train_step. This value is aggregated across replicas using the tf.distribute.Strategy.reduce call and then across batches by summing the return values of the tf.distribute.Strategy.reduce calls.
- tf.keras.Metrics should be updated inside train_step and test_step, which are executed by tf.distribute.Strategy.run.
- tf.distribute.Strategy.run returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can do tf.distribute.Strategy.reduce to get an aggregated value. You can also do tf.distribute.Strategy.experimental_local_results to get the list of values contained in the result, one per local replica.
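The two levels of aggregation described above (across replicas within a step, then across batches within an epoch) can be mimicked with plain lists. The per-replica loss values here are invented, and the built-in `sum` merely plays the role of a SUM reduction; no tf.distribute calls are involved.

```python
# Hypothetical scaled losses returned by each of 2 replicas over 3 steps.
per_replica_losses_by_step = [
    [0.5, 0.25],
    [0.25, 0.25],
    [0.125, 0.125],
]

total_loss = 0.0
num_batches = 0
step_losses = []
for per_replica_losses in per_replica_losses_by_step:
    # Aggregate across replicas (the role of a SUM reduction).
    step_loss = sum(per_replica_losses)
    step_losses.append(step_loss)
    # Aggregate across batches by summing the reduced values.
    total_loss += step_loss
    num_batches += 1

train_loss = total_loss / num_batches
print(step_losses, train_loss)
```

Keeping the unreduced per-replica list around instead of summing it corresponds to consuming the local results directly, one value per replica.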
Restore the latest checkpoint and test
A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
for images, labels in test_dataset:
  eval_step(images, labels)
print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.00999450683594
Alternate ways of iterating over a dataset
Using iterators
If you want to iterate over a given number of steps rather than through the entire dataset, you can create an iterator using the iter call and explicitly call next on the iterator. You can choose to iterate over the dataset both inside and outside the tf.function. Here is a small snippet demonstrating iteration of the dataset outside the tf.function using an iterator.
for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)
  # Run a fixed number of steps (10) instead of a full pass over the dataset.
  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches
  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 1, Loss: 0.12157603353261948, Accuracy: 95.0
Epoch 2, Loss: 0.1367541253566742, Accuracy: 94.6875
Epoch 3, Loss: 0.14902949333190918, Accuracy: 93.90625
Epoch 4, Loss: 0.12149540334939957, Accuracy: 95.625
Epoch 5, Loss: 0.13160167634487152, Accuracy: 94.6875
Epoch 6, Loss: 0.13297739624977112, Accuracy: 95.3125
Epoch 7, Loss: 0.16038034856319427, Accuracy: 94.53125
Epoch 8, Loss: 0.1035340279340744, Accuracy: 96.40625
Epoch 9, Loss: 0.11846740543842316, Accuracy: 95.625
Epoch 10, Loss: 0.09006750583648682, Accuracy: 96.71875
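The iter/next pattern itself is ordinary Python. The sketch below uses a plain range as a stand-in for the dataset and a hypothetical helper, `run_steps`, to show one way to guard against the iterator running out before the requested number of steps. Whether the same guard is needed depends on how many batches your dataset actually yields.

```python
def run_steps(batches, steps):
    """Consume at most `steps` items from `batches` using iter/next."""
    iterator = iter(batches)
    processed = []
    for _ in range(steps):
        try:
            processed.append(next(iterator))
        except StopIteration:
            # The dataset ran out before the requested number of steps.
            break
    return processed

enough = run_steps(range(5), steps=3)      # plenty of batches available
too_few = run_steps(range(2), steps=3)     # fewer batches than steps
print(enough, too_few)
```

Creating a fresh iterator at the start of each epoch, as the snippet above does, restarts iteration from the beginning of the dataset.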
Iterating inside a tf.function
You can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct or by creating iterators like we did above. The example below demonstrates wrapping one epoch of training in a tf.function and iterating over train_dist_dataset inside the function.
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)
for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)
  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 1, Loss: 0.13680464029312134, Accuracy: 94.90499877929688
Epoch 2, Loss: 0.12503673136234283, Accuracy: 95.33499908447266
Epoch 3, Loss: 0.11472766101360321, Accuracy: 95.71333312988281
Epoch 4, Loss: 0.10419528931379318, Accuracy: 96.13500213623047
Epoch 5, Loss: 0.09566374123096466, Accuracy: 96.44833374023438
Epoch 6, Loss: 0.08704081922769547, Accuracy: 96.82499694824219
Epoch 7, Loss: 0.08157625794410706, Accuracy: 96.96333312988281
Epoch 8, Loss: 0.07562965154647827, Accuracy: 97.11000061035156
Epoch 9, Loss: 0.0676642507314682, Accuracy: 97.47999572753906
Epoch 10, Loss: 0.06430575996637344, Accuracy: 97.58333587646484
Tracking training loss across replicas
We do not recommend using tf.metrics.Mean to track the training loss across different replicas, because of the loss scaling computation that is carried out.
For example, if you run a training job with the following characteristics:
- Two replicas
- Two samples are processed on each replica
- Resulting loss values: [2, 3] and [4, 5] on each replica
- Global batch size = 4
With loss scaling, you calculate the scaled loss on each replica by adding the per-sample loss values and then dividing by the global batch size. In this case: (2 + 3) / 4 = 1.25 and (4 + 5) / 4 = 2.25.
If you use tf.metrics.Mean to track loss across the two replicas, the result is different. In this example, you end up with a total of 3.50 and a count of 2, which results in total/count = 1.75 when result() is called on the metric. The correct global loss is the sum of the scaled values, 1.25 + 2.25 = 3.50, so the loss reported by the metric is off by a factor equal to the number of replicas in sync.
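The arithmetic of this example can be reproduced in a few lines of plain Python. No TensorFlow metric is used here: the "mean metric" is simply total divided by count, as described above.

```python
GLOBAL_BATCH_SIZE = 4
replica_losses = [[2.0, 3.0], [4.0, 5.0]]   # two replicas, two samples each

# Scaled loss computed on each replica.
scaled = [sum(losses) / GLOBAL_BATCH_SIZE for losses in replica_losses]

# Summing across replicas gives the true mean loss over all 4 samples.
summed = sum(scaled)

# A mean-style metric sees two values and averages them instead.
metric_total = sum(scaled)
metric_count = len(scaled)
metric_result = metric_total / metric_count

print(scaled, summed, metric_result)
```

The ratio between `summed` and `metric_result` is 2, the number of replicas in this example.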
Guide and examples
Here are some examples of using distribution strategies with custom training loops:
- Distributed training guide
- DenseNet example using MirroredStrategy.
- BERT example trained using MirroredStrategy and TPUStrategy. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training.
- NCF example trained using MirroredStrategy that can be enabled using the keras_use_ctl flag.
- NMT example trained using MirroredStrategy.
More examples are listed in the Distribution strategy guide.
Next steps
- Try out the new tf.distribute.Strategy API on your models.
- Visit the Performance section in the guide to learn more about other strategies and tools you can use to optimize the performance of your TensorFlow models.