This tutorial demonstrates how to use tf.distribute.Strategy, a TensorFlow API that provides an abstraction for distributing your training across multiple processing units (GPUs, multiple machines, or TPUs), with custom training loops. In this example, you will train a simple convolutional neural network on the Fashion MNIST dataset, which contains 70,000 images of size 28 x 28.
Custom training loops provide flexibility and greater control over training. They also make it easier to debug the model and the training loop.
# Import TensorFlow
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
2.11.0
Download the Fashion MNIST dataset
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
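The two preprocessing steps above can be checked on a small stand-in array. This sketch uses random uint8 data in place of the real Fashion MNIST images, which is an assumption made only for illustration. It shows the shape change produced by [..., None] and confirms that the scaled pixels fall in [0, 1].

```python
import numpy as np

# Stand-in for a few Fashion MNIST images: uint8 pixels in [0, 255].
# (Random data for illustration; the real arrays come from fashion_mnist.load_data().)
rng = np.random.default_rng(0)
images = rng.integers(0, 255, size=(5, 28, 28), dtype=np.uint8, endpoint=True)

# `[..., None]` appends a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
images = images[..., None]

# Dividing by a float32 scalar turns the uint8 pixels into floats in [0, 1].
images = images / np.float32(255)

print(images.shape, images.min() >= 0.0, images.max() <= 1.0)
```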
Create a strategy to distribute the variables and the graph
How does the tf.distribute.MirroredStrategy strategy work?
- All the variables and the model graph are replicated across the replicas.
- Input is evenly distributed across the replicas.
- Each replica calculates the loss and gradients for the input it received.
- The gradients are synced across all the replicas by summing them.
- After the sync, the same update is made to the copies of the variables on each replica.
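The five steps above can be sketched without TensorFlow. The toy simulation below illustrates the idea under stated assumptions and is not how MirroredStrategy is implemented. It uses a hypothetical one-parameter linear model, squared-error loss scaled by the global batch size, and plain SGD, with a Python sum standing in for the cross-replica all-reduce. It checks two things: every replica ends the step with the same weight, and that weight matches a single-device update over the whole batch.

```python
# Toy simulation of one synchronous data-parallel step (no TensorFlow).
# Assumptions: model y_hat = w * x, squared-error loss scaled by 1 / GLOBAL_BATCH,
# plain SGD, and a Python sum standing in for the cross-replica all-reduce.
NUM_REPLICAS = 4
GLOBAL_BATCH = 8
LEARNING_RATE = 0.01

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # targets produced by a "true" weight of 2.0


def grad(w, batch_x, batch_y):
    """d/dw of sum((w*x - y)**2) / GLOBAL_BATCH over the given examples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(batch_x, batch_y)) / GLOBAL_BATCH


# 1. Every replica holds an identical copy of the variable.
replica_w = [0.5] * NUM_REPLICAS

# 2. The global batch is split evenly across the replicas.
per_replica = GLOBAL_BATCH // NUM_REPLICAS
shards = [(xs[r * per_replica:(r + 1) * per_replica],
           ys[r * per_replica:(r + 1) * per_replica]) for r in range(NUM_REPLICAS)]

# 3. Each replica computes a gradient on its own shard.
local_grads = [grad(w, sx, sy) for w, (sx, sy) in zip(replica_w, shards)]

# 4. The gradients are synced by summing them.
synced = sum(local_grads)

# 5. The same update is applied to every copy.
replica_w = [w - LEARNING_RATE * synced for w in replica_w]

# Reference: a single device taking one step on the full global batch.
single_w = 0.5 - LEARNING_RATE * grad(0.5, xs, ys)
print(replica_w, single_w)
```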
# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4
Set up the input pipeline
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
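With 4 replicas (the device count printed above), these settings give a global batch of 64 * 4 = 256 examples. The datasets below are batched with .batch(GLOBAL_BATCH_SIZE) and no drop_remainder=True, so the last batch of each pass is shorter whenever the dataset size is not a multiple of 256. The plain-Python sketch below works out those sizes. It assumes the standard split of 60,000 training and 10,000 test images, and the helper batch_sizes is hypothetical.

```python
# Batch arithmetic for the settings above. Assumptions: 4 replicas (as printed
# earlier) and the standard 60,000 / 10,000 Fashion MNIST train/test split.
BATCH_SIZE_PER_REPLICA = 64
NUM_REPLICAS = 4
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS


def batch_sizes(num_examples, global_batch_size):
    """Hypothetical helper: (number of full batches, size of the last partial batch)."""
    return divmod(num_examples, global_batch_size)


train_full, train_rest = batch_sizes(60_000, GLOBAL_BATCH_SIZE)
test_full, test_rest = batch_sizes(10_000, GLOBAL_BATCH_SIZE)

print(GLOBAL_BATCH_SIZE)        # 256
print(train_full, train_rest)   # 234 full batches, then one batch of 96
print(test_full, test_rest)     # 39 full batches, then one batch of 16
```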
Create the datasets and distribute them:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
Create the model
Create a model using tf.keras.Sequential. You can also use the Model Subclassing API or the functional API to do this.
def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])
  return model
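The Sequential model has no explicit input shape, so Keras builds it on the first call with 28 x 28 x 1 images. The sketch below traces the resulting shapes and trainable parameter counts by hand. It assumes the Keras defaults noted in the comments: padding='valid' and stride 1 for Conv2D, and a 2 x 2 window with stride 2 for MaxPooling2D. The numbers should agree with model.summary() once the model has been built, but treat them as a hand calculation.

```python
# Hand trace of output shapes and parameter counts for create_model(),
# assuming Keras defaults: Conv2D padding='valid', stride 1; MaxPooling2D
# 2x2 window with stride 2; every Conv2D/Dense layer has a bias.

def conv_out(size, kernel):
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after 'valid' pooling with stride equal to the window."""
    return (size - pool) // pool + 1


h = w = 28
channels = 1
params = []

# Conv2D(32, 3): 3*3*in_channels*32 weights + 32 biases
params.append(3 * 3 * channels * 32 + 32)
h, w, channels = conv_out(h, 3), conv_out(w, 3), 32   # 26 x 26 x 32
h, w = pool_out(h), pool_out(w)                        # 13 x 13 x 32

# Conv2D(64, 3)
params.append(3 * 3 * channels * 64 + 64)
h, w, channels = conv_out(h, 3), conv_out(w, 3), 64   # 11 x 11 x 64
h, w = pool_out(h), pool_out(w)                        # 5 x 5 x 64

flat = h * w * channels                                # Flatten -> 1600 features
params.append(flat * 64 + 64)                          # Dense(64)
params.append(64 * 10 + 10)                            # Dense(10) logits

print(flat, params, sum(params))  # 1600 [320, 18496, 102464, 650] 121930
```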
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
Define the loss function
Normally, on a single machine with a single GPU/CPU, the loss is divided by the number of examples in the batch of input.
So, how should the loss be calculated when using a tf.distribute.Strategy?
For example, let's say you have 4 GPUs and a batch size of 64. One batch of input is distributed across the replicas (4 GPUs), and each replica gets an input of size 16.
The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).
Why do this?
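- This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by summing them. If each replica divides its loss by GLOBAL_BATCH_SIZE, the summed gradient equals the gradient of the average loss over the whole global batch, which is what a single device would compute. Dividing by BATCH_SIZE_PER_REPLICA instead would make the summed gradient as many times too large as there are replicas.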
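To see why dividing by GLOBAL_BATCH_SIZE is the right choice, compare the two scalings numerically. Gradients are linear in the loss, so whatever holds for the summed losses below also holds for the summed gradients. This plain-Python sketch assumes 4 replicas, a global batch of 64, and arbitrary made-up per-example losses.

```python
# Why divide by GLOBAL_BATCH_SIZE? Assumptions: 4 replicas, a global batch of 64,
# and made-up per-example losses.
NUM_REPLICAS = 4
GLOBAL_BATCH_SIZE = 64
BATCH_SIZE_PER_REPLICA = GLOBAL_BATCH_SIZE // NUM_REPLICAS  # 16

per_example = [float(i % 7) for i in range(GLOBAL_BATCH_SIZE)]
shards = [per_example[r * BATCH_SIZE_PER_REPLICA:(r + 1) * BATCH_SIZE_PER_REPLICA]
          for r in range(NUM_REPLICAS)]

# What a single device would compute on the whole batch.
single_device_mean = sum(per_example) / GLOBAL_BATCH_SIZE

# Each replica divides by GLOBAL_BATCH_SIZE; the sync step then sums the results.
summed_global = sum(sum(s) / GLOBAL_BATCH_SIZE for s in shards)

# Each replica divides by its own batch size; the sync step then sums the results.
summed_per_replica = sum(sum(s) / BATCH_SIZE_PER_REPLICA for s in shards)

print(single_device_mean, summed_global, summed_per_replica)  # 2.953125 2.953125 11.8125
```

Only the first scaling reproduces the single-device value. The second is inflated by a factor of NUM_REPLICAS.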
How to do this in TensorFlow?
- Loss reduction and scaling are done automatically in Keras Model.compile and Model.fit.
- If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE). Alternatively, you can use tf.nn.compute_average_loss, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.
- If you're writing a custom training loop for a model with a non-empty list of Model.losses (e.g., weight regularizers), you should sum them up and divide the sum by the number of replicas. You can do this by using the tf.nn.scale_regularization_loss function.
- Be careful about batches that are shorter than the GLOBAL_BATCH_SIZE, if your training data allows them: dividing the prediction loss by GLOBAL_BATCH_SIZE (instead of using tf.reduce_mean over the actual batch size) avoids overweighting examples from short batches. However, this does not apply to regularization losses.
- If using tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified to be one of NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because the user should explicitly think about which reduction they want, to make sure it is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because currently it would only divide by the per-replica batch size and leave the division by the number of replicas to the user, which is easy to miss. So, instead, you need to do the reduction yourself explicitly.
- If labels is multi-dimensional, average the per_example_loss across the number of elements in each sample. For example, if the shape of predictions is (batch_size, H, W, n_classes) and labels is (batch_size, H, W), you will need to update per_example_loss like this: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
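The scaling rules above can be written out in plain Python. The helper names below (average_loss, scale_reg_loss, per_sample_mean) are hypothetical stand-ins written for this illustration. They mirror the arithmetic the text attributes to tf.nn.compute_average_loss and tf.nn.scale_regularization_loss, but they are not TensorFlow APIs. The numbers assume 2 replicas, a global batch size of 4, and a made-up regularization term.

```python
# Hypothetical plain-Python stand-ins for the loss-scaling rules described above.
# They are NOT TensorFlow APIs; they only mirror the arithmetic.

def average_loss(per_example_loss, global_batch_size, sample_weight=None):
    """Like tf.nn.compute_average_loss: weighted sum / global batch size."""
    if sample_weight is None:
        sample_weight = [1.0] * len(per_example_loss)
    weighted = sum(l * w for l, w in zip(per_example_loss, sample_weight))
    return weighted / global_batch_size


def scale_reg_loss(regularization_loss, num_replicas):
    """Like tf.nn.scale_regularization_loss: divide by the replica count."""
    return regularization_loss / num_replicas


def per_sample_mean(loss_maps):
    """Multi-dimensional labels: average each sample's H x W loss map,
    i.e. divide by the number of elements per sample."""
    return [sum(sum(row) for row in m) / (len(m) * len(m[0])) for m in loss_maps]


NUM_REPLICAS = 2
GLOBAL_BATCH_SIZE = 4

# One replica's share of a full batch (2 of the 4 examples), plus an assumed
# regularization term of 0.8 that every replica computes identically; after
# the cross-replica sum it contributes 0.8 exactly once.
replica_loss = average_loss([2.0, 3.0], GLOBAL_BATCH_SIZE) + scale_reg_loss(0.8, NUM_REPLICAS)

# A short final batch holding only 2 examples in total. Summed over replicas,
# the scaled contributions equal sum(all losses) / GLOBAL_BATCH_SIZE.
short_batch = [3.0, 5.0]
scaled_short = average_loss(short_batch, GLOBAL_BATCH_SIZE)  # each example weighted 1/4
mean_short = sum(short_batch) / len(short_batch)             # reduce_mean: each weighted 1/2

# Two samples with 2 x 2 per-pixel losses.
loss_maps = [[[1.0, 3.0], [5.0, 7.0]], [[0.0, 0.0], [2.0, 2.0]]]
per_sample = per_sample_mean(loss_maps)

print(replica_loss, scaled_short, mean_short, per_sample)
```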
with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions, model_losses):
    per_example_loss = loss_object(labels, predictions)
    loss = tf.nn.compute_average_loss(per_example_loss,
                                      global_batch_size=GLOBAL_BATCH_SIZE)
    if model_losses:
      loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    return loss
Define the metrics to track loss and accuracy
These metrics track the test loss and the training and test accuracy. You can use .result() to get the accumulated statistics at any time.
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
Training loop
# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs
  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions, model.losses)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_accuracy.update_state(labels, predictions)
  return loss
def test_step(inputs):
  images, labels = inputs
  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)
  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)
@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))
for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches
  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)
  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))
  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.6628355383872986, Accuracy: 76.33332824707031, Test Loss: 0.4561820924282074, Test Accuracy: 82.81999969482422
Epoch 2, Loss: 0.39365142583847046, Accuracy: 85.75666809082031, Test Loss: 0.4020243287086487, Test Accuracy: 85.23999786376953
Epoch 3, Loss: 0.3401503264904022, Accuracy: 87.65333557128906, Test Loss: 0.35901474952697754, Test Accuracy: 86.7699966430664
Epoch 4, Loss: 0.31314024329185486, Accuracy: 88.69166564941406, Test Loss: 0.33639273047447205, Test Accuracy: 87.66999816894531
Epoch 5, Loss: 0.2936164140701294, Accuracy: 89.21833038330078, Test Loss: 0.3192780017852783, Test Accuracy: 88.05999755859375
Epoch 6, Loss: 0.2765035033226013, Accuracy: 89.8933334350586, Test Loss: 0.29890647530555725, Test Accuracy: 89.3499984741211
Epoch 7, Loss: 0.26154735684394836, Accuracy: 90.46833801269531, Test Loss: 0.28919824957847595, Test Accuracy: 89.5
Epoch 8, Loss: 0.24878241121768951, Accuracy: 90.86833953857422, Test Loss: 0.2841128408908844, Test Accuracy: 89.76000213623047
Epoch 9, Loss: 0.2387547492980957, Accuracy: 91.1816635131836, Test Loss: 0.27830588817596436, Test Accuracy: 89.70000457763672
Epoch 10, Loss: 0.22312773764133453, Accuracy: 91.86499786376953, Test Loss: 0.27480950951576233, Test Accuracy: 90.26000213623047
Things to note in the example above:
- Iterate over the train_dist_dataset and test_dist_dataset using a for x in ... construct.
- The scaled loss is the return value of distributed_train_step. This value is aggregated across replicas using the tf.distribute.Strategy.reduce call and then across batches by summing the return values of the tf.distribute.Strategy.reduce calls.
- tf.keras.metrics.Metric objects should be updated inside train_step and test_step, which are executed by tf.distribute.Strategy.run.
- tf.distribute.Strategy.run returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can call tf.distribute.Strategy.reduce to get an aggregated value, or tf.distribute.Strategy.experimental_local_results to get the list of values contained in the result, one per local replica.
Restore the latest checkpoint and test
A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
for images, labels in test_dataset:
  eval_step(images, labels)
print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))
Accuracy after restoring the saved model without strategy: 89.70000457763672
Alternate ways of iterating over a dataset
Using iterators
If you want to iterate over a given number of steps rather than through the entire dataset, you can create an iterator using the iter call and explicitly call next on the iterator. You can choose to iterate over the dataset both inside and outside the tf.function. Here is a small snippet demonstrating iteration of the dataset outside the tf.function using an iterator.
for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)
  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches
  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_states()
Epoch 1, Loss: 0.20553798973560333, Accuracy: 92.5
Epoch 2, Loss: 0.19978073239326477, Accuracy: 92.8125
Epoch 3, Loss: 0.2043170928955078, Accuracy: 92.5
Epoch 4, Loss: 0.20364265143871307, Accuracy: 92.4609375
Epoch 5, Loss: 0.20627591013908386, Accuracy: 92.4609375
Epoch 6, Loss: 0.21071460843086243, Accuracy: 92.1484375
Epoch 7, Loss: 0.2239733189344406, Accuracy: 92.1484375
Epoch 8, Loss: 0.21312539279460907, Accuracy: 92.6953125
Epoch 9, Loss: 0.21038644015789032, Accuracy: 92.03125
Epoch 10, Loss: 0.21483521163463593, Accuracy: 92.0703125
Iterating inside a tf.function
You can also iterate over the entire input train_dist_dataset inside a tf.function, using the for x in ... construct or by creating iterators as you did above. The example below wraps one epoch of training in a @tf.function decorator and iterates over train_dist_dataset inside the function.
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
        tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)
for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)
  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_states()
Epoch 1, Loss: 0.21304954588413239, Accuracy: 92.1866683959961
Epoch 2, Loss: 0.20434944331645966, Accuracy: 92.57833099365234
Epoch 3, Loss: 0.193551167845726, Accuracy: 92.84666442871094
Epoch 4, Loss: 0.18342040479183197, Accuracy: 93.31666564941406
Epoch 5, Loss: 0.17504367232322693, Accuracy: 93.53333282470703
Epoch 6, Loss: 0.16769950091838837, Accuracy: 93.788330078125
Epoch 7, Loss: 0.16176080703735352, Accuracy: 94.05999755859375
Epoch 8, Loss: 0.15285256505012512, Accuracy: 94.41000366210938
Epoch 9, Loss: 0.14692439138889313, Accuracy: 94.70333099365234
Epoch 10, Loss: 0.1429813951253891, Accuracy: 94.81666564941406
Tracking training loss across replicas
Because of the loss scaling computation that is carried out, it's not recommended to use tf.keras.metrics.Mean to track the training loss across different replicas.
For example, if you run a training job with the following characteristics:
- Two replicas
- Two samples are processed on each replica
- Resulting loss values: [2, 3] and [4, 5] on each replica
- Global batch size = 4
With loss scaling, you calculate the per-sample value of loss on each replica by adding the loss values and then dividing by the global batch size. In this case: (2 + 3) / 4 = 1.25 and (4 + 5) / 4 = 2.25.
If you use tf.keras.metrics.Mean to track loss across the two replicas, the result is different. In this example, you end up with a total of 3.50 and a count of 2, which results in total / count = 1.75 when result() is called on the metric. The correct loss for the global batch, however, is 1.25 + 2.25 = 3.50, so the value reported by tf.keras.metrics.Mean is off by an additional factor equal to the number of replicas in sync (here, 2).
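The arithmetic of this example can be reproduced in a few lines of plain Python. Like the text, the sketch assumes that the Mean metric receives one already-scaled loss per replica. SUM-reducing the same values, as distributed_train_step does with tf.distribute.ReduceOp.SUM, recovers the correct global-batch loss.

```python
# Worked example from the text: 2 replicas, 2 samples each, global batch size 4.
GLOBAL_BATCH_SIZE = 4
replica_losses = [[2.0, 3.0], [4.0, 5.0]]
num_replicas = len(replica_losses)

# Loss-scaled value returned by each replica: sum of its losses / global batch size.
scaled = [sum(losses) / GLOBAL_BATCH_SIZE for losses in replica_losses]

# SUM-reducing across replicas gives the loss of the whole global batch.
global_loss = sum(scaled)

# A Mean metric fed one scaled value per replica accumulates a total and a count.
metric_total, metric_count = sum(scaled), len(scaled)
metric_result = metric_total / metric_count

print(scaled, global_loss, metric_result)  # [1.25, 2.25] 3.5 1.75
```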
Guide and examples
Here are some examples for using distribution strategy with custom training loops:
- Distributed training guide
- DenseNet example using MirroredStrategy.
- BERT example trained using MirroredStrategy and TPUStrategy. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training.
- NCF example trained using MirroredStrategy that can be enabled using the keras_use_ctl flag.
- NMT example trained using MirroredStrategy.
You can find more examples listed under Examples and tutorials in the Distribution strategy guide.
Next steps
- Try out the new tf.distribute.Strategy API on your models.
- Visit the Better performance with tf.function and TensorFlow Profiler guides to learn more about tools to optimize the performance of your TensorFlow models.
- Check out the Distributed training in TensorFlow guide, which provides an overview of the available distribution strategies.