
Distributed Training in TensorFlow



The tf.distribute.Strategy API is an easy way to distribute your training across multiple devices/machines. Our goal is to allow users to use existing models and training code with minimal changes to enable distributed training.

Currently, in core TensorFlow, we support tf.distribute.MirroredStrategy. This does in-graph replication with synchronous training on many GPUs on one machine. Essentially, we create copies of all variables in the model's layers on each device. We then use all-reduce to combine gradients across the devices before applying them to the variables to keep them in sync.
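The synchronous all-reduce step can be illustrated with a small framework-free sketch (the function name `all_reduce_mean` is ours for illustration, not part of the TensorFlow API): each replica computes gradients on its shard of the batch, the per-variable gradients are summed across replicas and averaged, and every replica then applies the identical result.

```python
def all_reduce_mean(per_replica_grads):
    """Average corresponding gradient entries across replicas.

    per_replica_grads: one list of gradient values per replica, e.g.
    [[g0_var0, g0_var1], [g1_var0, g1_var1]] for two replicas.
    """
    num_replicas = len(per_replica_grads)
    # Sum each variable's gradient across replicas, then divide by the
    # replica count so every device applies the same update.
    return [sum(grads) / num_replicas for grads in zip(*per_replica_grads)]

# Two replicas, each holding gradients for two variables:
combined = all_reduce_mean([[0.2, 1.0], [0.4, 3.0]])
print(combined)  # roughly [0.3, 2.0] -- identical on every device
```

Because every replica applies the same averaged gradients, the mirrored copies of each variable stay bit-for-bit in sync without any parameter server.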

Many other strategies are available in TensorFlow 1.12+ contrib and will soon be available in core TensorFlow. You can find more information about them in the contrib README. You can also read the public design review for updating the tf.distribute.Strategy API as part of the move from contrib to core TF.

Example with Keras API

Let's see how to scale to multiple GPUs on one machine using MirroredStrategy with tf.keras.

We will take a very simple model consisting of a single layer. First, we will import TensorFlow.

!pip install -q tf-nightly
from __future__ import absolute_import, division, print_function

import tensorflow as tf

To distribute a Keras model on multiple GPUs using MirroredStrategy, we first instantiate a MirroredStrategy object.

strategy = tf.distribute.MirroredStrategy()
WARNING: Logging before flag parsing goes to stderr.
W0305 20:03:27.688184 139919445513984] Not all devices in `tf.distribute.Strategy` are visible to TensorFlow.

We then create and compile the Keras model in strategy.scope.

with strategy.scope():
  inputs = tf.keras.layers.Input(shape=(1,))
  predictions = tf.keras.layers.Dense(1)(inputs)
  model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
W0305 20:03:27.721799 139919445513984] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/ calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor

Let's also define a simple input dataset for training this model.

train_dataset =[1.], [1.])).repeat(10000).batch(10)

To train the model, we call the Keras fit API using the input dataset that we created earlier, same as in the non-distributed case., epochs=5, steps_per_epoch=10)
W0305 20:03:27.918351 139919445513984] Expected a shuffled dataset but input dataset `x` is not shuffled. Please invoke `shuffle()` on input dataset.

Epoch 1/5
10/10 [==============================] - 0s 8ms/step - loss: 0.0815
Epoch 2/5
10/10 [==============================] - 0s 1ms/step - loss: 1.4211e-15
Epoch 3/5
10/10 [==============================] - 0s 1ms/step - loss: 0.0000e+00
Epoch 4/5
10/10 [==============================] - 0s 1ms/step - loss: 0.0000e+00
Epoch 5/5
10/10 [==============================] - 0s 1ms/step - loss: 0.0000e+00

<tensorflow.python.keras.callbacks.History at 0x7f4054693eb8>

Similarly, we can also call evaluate and predict as before using appropriate datasets.

eval_dataset =[1.], [1.])).repeat(100).batch(10)
model.evaluate(eval_dataset, steps=10)
10/10 [==============================] - 0s 6ms/step - loss: 0.0000e+00

predict_dataset =[1.])).repeat(10).batch(2)
model.predict(predict_dataset, steps=5)
array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)

That's all you need to train your model with Keras on multiple GPUs with MirroredStrategy. It will take care of splitting up the input dataset, replicating layers and variables on each device, and combining and applying gradients.

The model and input code do not have to change because we have changed the underlying components of TensorFlow (such as the optimizer, batch norm, and summaries) to be strategy-aware. That means those components know how to combine their state across devices. Further, saving and checkpointing work seamlessly, so you can save with one or no distribute strategy and resume with another.
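As a sketch of that portability (assuming a working tf.keras install; the checkpoint path is arbitrary), you can save weights from a model built under a strategy scope and load them into an identically shaped model built with no strategy at all:

```python
import tensorflow as tf

def build_model():
  # Same single-layer model as above.
  inputs = tf.keras.layers.Input(shape=(1,))
  outputs = tf.keras.layers.Dense(1)(inputs)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  distributed_model = build_model()
# Save the mirrored variables; a single copy is written to the checkpoint.
distributed_model.save_weights('/tmp/distributed.ckpt')

# Rebuild the same architecture outside any strategy and restore the weights.
plain_model = build_model()
plain_model.load_weights('/tmp/distributed.ckpt')
```

The reverse direction works the same way: a checkpoint written without a strategy can be restored inside a strategy scope.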

Example with Estimator API

You can also use the tf.distribute.Strategy API with Estimator. Let's see a simple example of its usage with MirroredStrategy.

We will use the LinearRegressor premade estimator as an example. To use MirroredStrategy with Estimator, all we need to do is:

  • Create an instance of the MirroredStrategy class.
  • Pass it to RunConfig as the train_distribute (and optionally eval_distribute) argument, and pass that config to the custom or premade Estimator.
strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=strategy, eval_distribute=strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
W0305 20:03:29.054260 139919445513984] Not all devices in `tf.distribute.Strategy` are visible to TensorFlow.
W0305 20:03:29.057693 139919445513984] Using temporary folder as model directory: /tmp/tmpjc86jvft

We will define a simple input function to feed data for training this model.

def input_fn():
  return{"feats": [1.]}, [1.])).repeat(10000).batch(10)

Then we can call train on the regressor instance to train the model.

regressor.train(input_fn=input_fn, steps=10)
W0305 20:03:29.186089 139898593982208] Partitioned variables are disabled when using current tf.distribute.Strategy.
W0305 20:03:29.205616 139898593982208] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/feature_column/ to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.

<tensorflow_estimator.python.estimator.canned.linear.LinearRegressor at 0x7f40543bf860>

And we can call evaluate to evaluate the trained model.

regressor.evaluate(input_fn=input_fn, steps=10)
W0305 20:03:30.710714 139919445513984] From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/ checkpoint_exists (from is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.

{'average_loss': 7.979226e+16,
 'global_step': 10,
 'label/mean': 1.0,
 'loss': 7.979226e+17,
 'prediction/mean': -282475260.0}

That's it! This change will now configure the estimator to run on all GPUs on your machine.

Customization and Performance Tips

Above, we showed the easiest way to use MirroredStrategy. There are a few things you can customize in practice:

  • You can specify a list of specific GPUs (using the devices parameter), in case you don't want auto-detection.
  • You can configure all-reduce via the cross_device_ops parameter, for example the all-reduce algorithm to use and gradient repacking.
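As a sketch of both options together (assuming a two-GPU machine; HierarchicalCopyAllReduce is one of the available cross-device ops, and num_packs controls how gradients are repacked before the all-reduce):

```python
strategy = tf.distribute.MirroredStrategy(
    # Skip auto-detection and mirror across exactly these two devices.
    devices=["/gpu:0", "/gpu:1"],
    # Choose the all-reduce implementation and pack all gradients into
    # a single tensor before reducing.
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce(num_packs=1))
```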

We've tried to make it such that you get the best performance for your existing model without having to specify any additional options. We also recommend you follow the tips from Input Pipeline Performance Guide. Specifically, we found using map_and_batch and dataset.prefetch in the input function gives a solid boost in performance. When using dataset.prefetch, use to let it detect optimal buffer size.


This API is still a work in progress, and there are many improvements forthcoming:

  • Summaries are only computed in the first replica in MirroredStrategy.
  • PartitionedVariables are not supported yet.
  • Performance improvements, especially around input handling and eager execution.

What's next?

tf.distribute.Strategy is actively under development, and we will be adding more examples and tutorials in the near future. Please give it a try; we welcome your feedback via issues on GitHub.