
Multi-worker training with Estimator


Overview

This tutorial demonstrates how tf.distribute.Strategy can be used for distributed multi-worker training with tf.estimator. If you write your code using tf.estimator, and you're interested in scaling beyond a single machine with high performance, this tutorial is for you.

Before getting started, please read the distribution strategy guide. The multi-GPU training tutorial is also relevant, because this tutorial uses the same model.

Setup

First, set up TensorFlow and the necessary imports.

from __future__ import absolute_import, division, print_function, unicode_literals

import os
import json

import tensorflow_datasets as tfds
import tensorflow as tf

tfds.disable_progress_bar()

Input function

This tutorial uses the MNIST dataset from TensorFlow Datasets. The code here is similar to the multi-GPU training tutorial, with one key difference: when using Estimator for multi-worker training, you must shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes a distinct 1/num_workers portion of the dataset.

BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                             with_info=True,
                             as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    # Normalize pixel values from the [0, 255] range to [0, 1].
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    # Shard the dataset so that each worker reads a distinct 1/num_workers slice.
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds at each worker.
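
A rough sketch of that alternative is shown below, reusing the imports and constants from the cells above. The function name input_fn_shuffle_only is hypothetical, and this variant is not used in the rest of the tutorial.

def input_fn_shuffle_only(mode, input_context=None):
  # Sketch of the alternative: shuffle with a per-worker seed instead of sharding.
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  # Each worker shuffles the full dataset with a distinct seed derived from its
  # input pipeline id, rather than reading a shard of the data.
  seed = input_context.input_pipeline_id if input_context else 0
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE, seed=seed).batch(BATCH_SIZE)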

Multi-worker configuration

One of the key differences in this tutorial (compared to the multi-GPU training tutorial) is the multi-worker setup. The TF_CONFIG environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.

There are two components of TF_CONFIG: cluster and task. cluster provides information about the entire cluster, namely the workers and parameter servers in the cluster. task provides information about the current task. In this example, the task type is worker and the task index is 0.

For illustration purposes, this tutorial shows how to set a TF_CONFIG with 2 workers on localhost. In practice, you would create multiple workers on external IP addresses and ports, and set TF_CONFIG on each worker appropriately, that is, with the correct task index.

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
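
For example, the second worker would use the same cluster specification with only the task index changed. The following is a sketch of that configuration (this notebook only configures worker 0, so the cell below is not executed as part of it):

# Sketch only: the TF_CONFIG the second worker would use. The cluster is the
# same; only the task index differs. This notebook sets TF_CONFIG for worker 0 only.
tf_config_worker_1 = {
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 1}
}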

Define the model

Write the layers, the optimizer, and the loss function for training. This tutorial defines the model with Keras layers, similar to the multi-GPU training tutorial.

LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  # A simple CNN defined with Keras layers, as in the multi-GPU training tutorial.
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  # Compute the per-example loss, then scale by the batch size.
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

MultiWorkerMirroredStrategy

To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO

Train and evaluate the model

Next, specify the distribution strategy in the RunConfig for the estimator, then train and evaluate by invoking tf.estimator.train_and_evaluate. This tutorial distributes only the training by specifying the strategy via train_distribute. It is also possible to distribute the evaluation via eval_distribute (a sketch follows the training output below).

config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy.CollectiveAllReduceStrategy object at 0x7f4da0037240>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1635: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:loss = 2.3605726, step = 0
INFO:tensorflow:global_step/sec: 184.665
INFO:tensorflow:loss = 2.3276105, step = 100 (0.543 sec)
INFO:tensorflow:global_step/sec: 199.093
INFO:tensorflow:loss = 2.28269, step = 200 (0.503 sec)
INFO:tensorflow:global_step/sec: 196.765
INFO:tensorflow:loss = 2.2883375, step = 300 (0.508 sec)
INFO:tensorflow:global_step/sec: 196.866
INFO:tensorflow:loss = 2.3219037, step = 400 (0.507 sec)
INFO:tensorflow:global_step/sec: 192.743
INFO:tensorflow:loss = 2.2974806, step = 500 (0.520 sec)
INFO:tensorflow:global_step/sec: 197.844
INFO:tensorflow:loss = 2.2768953, step = 600 (0.506 sec)
INFO:tensorflow:global_step/sec: 192.525
INFO:tensorflow:loss = 2.2761128, step = 700 (0.519 sec)
INFO:tensorflow:global_step/sec: 214.582
INFO:tensorflow:loss = 2.2815404, step = 800 (0.465 sec)
INFO:tensorflow:global_step/sec: 712.033
INFO:tensorflow:loss = 2.289053, step = 900 (0.140 sec)
INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-01-14T02:27:42Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Inference Time : 1.09692s
INFO:tensorflow:Finished evaluation at 2020-01-14-02:27:43
INFO:tensorflow:Saving dict for global step 938: global_step = 938, loss = 2.2668731
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 938: /tmp/multiworker/model.ckpt-938
INFO:tensorflow:Loss for final step: 1.123378.

({'loss': 2.2668731, 'global_step': 938}, [])
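
As mentioned above, evaluation can also be distributed by passing the strategy to eval_distribute. The following is a minimal sketch, not part of the code this tutorial runs; config_with_eval is a hypothetical name.

# Sketch: distribute evaluation as well by setting eval_distribute in RunConfig.
config_with_eval = tf.estimator.RunConfig(
    train_distribute=strategy,
    eval_distribute=strategy)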

Optimize training performance

You now have a model and a multi-worker capable Estimator powered by tf.distribute.Strategy. You can try the following techniques to optimize the performance of multi-worker training:

  • Increase the batch size: The batch size specified here is per-GPU. In general, the largest batch size that fits in GPU memory is advisable.
  • Cast variables: Cast the variables to tf.float16 if possible. The official ResNet model includes an example of how this can be done.
  • Use collective communication: MultiWorkerMirroredStrategy provides multiple collective communication implementations.

    • RING implements ring-based collectives using gRPC as the cross-host communication layer.
    • NCCL uses NVIDIA's NCCL library to implement collectives.
    • AUTO defers the choice to the runtime.

    The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, pass a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL, as sketched below.
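
For example, a minimal sketch that requests NCCL collectives explicitly is shown below. The name strategy_nccl is hypothetical, and this assumes NCCL and GPUs are available on every worker.

# Sketch: explicitly request NCCL collectives instead of letting AUTO decide.
strategy_nccl = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)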

Other code examples

  1. An end-to-end example of multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.
  2. Official TensorFlow models, many of which can be configured to run with multiple distribution strategies.