Overview
This tutorial demonstrates how tf.distribute.Strategy can be used for distributed multi-worker training with tf.estimator. If you write your code using tf.estimator, and you're interested in scaling beyond a single machine with high performance, this tutorial is for you.
Before getting started, please read the distribution strategy guide. The multi-GPU training tutorial is also relevant, because this tutorial uses the same model.
Setup
First, set up TensorFlow and the necessary imports.
import tensorflow_datasets as tfds
import tensorflow as tf
import os, json
tf.compat.v1.disable_eager_execution()
Input function
This tutorial uses the MNIST dataset from TensorFlow Datasets. The code here is similar to the multi-GPU training tutorial, with one key difference: when using Estimator for multi-worker training, it is necessary to shard the dataset by the number of workers to ensure model convergence. The input data is sharded by worker index, so that each worker processes 1/num_workers distinct portions of the dataset.
BUFFER_SIZE = 10000
BATCH_SIZE = 64
def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                             with_info=True,
                             as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
Another reasonable approach to achieve convergence would be to shuffle the dataset with distinct seeds at each worker.
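As a rough sketch of that alternative (input_fn_shuffled is a hypothetical name, not part of the tutorial's code), each worker could keep the full dataset and shuffle it with a seed derived from its input pipeline id:

def input_fn_shuffled(mode, input_context=None):
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  # Instead of sharding, give every worker a different random order
  # by seeding the shuffle with the worker's input pipeline id.
  seed = input_context.input_pipeline_id if input_context else 0
  return (mnist_dataset.map(scale).cache()
          .shuffle(BUFFER_SIZE, seed=seed).batch(BATCH_SIZE))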
Multi-worker configuration
One of the key differences in this tutorial (compared to the multi-GPU training tutorial) is the multi-worker setup. The TF_CONFIG environment variable is the standard way to specify the cluster configuration to each worker that is part of the cluster.

There are two components of TF_CONFIG: cluster and task. cluster provides information about the entire cluster, namely the workers and parameter servers in the cluster. task provides information about the current task. The first component, cluster, is the same for all workers and parameter servers in the cluster, and the second component, task, is different on each worker and parameter server and specifies its own type and index. In this example, the task type is worker and the task index is 0.

For illustration purposes, this tutorial shows how to set a TF_CONFIG with 2 workers on localhost. In practice, you would create multiple workers on an external IP address and port, and set TF_CONFIG on each worker appropriately, i.e. modify the task index.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
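Each additional worker exports the same cluster component but its own task. As an illustration (hypothetical; this tutorial only sets the config for worker 0), the second worker would set:

# Hypothetical TF_CONFIG for the second worker: the 'cluster'
# component is identical on every worker, only the task index differs.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 1}
})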
Define the model
Write the layers, the optimizer, and the loss function for training. This tutorial defines the model with Keras layers, similar to the multi-GPU training tutorial.
LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))
MultiWorkerMirroredStrategy
To train the model, use an instance of tf.distribute.experimental.MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy creates copies of all variables in the model's layers on each device across all workers. It uses CollectiveOps, a TensorFlow op for collective communication, to aggregate gradients and keep the variables in sync. The tf.distribute.Strategy guide has more details about this strategy.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:_CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version. Instructions for updating: use distribute.MultiWorkerMirroredStrategy instead
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3')
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO
Train and evaluate the model
Next, specify the distribution strategy in the RunConfig for the estimator, and train and evaluate by invoking tf.estimator.train_and_evaluate. This tutorial distributes only the training by specifying the strategy via train_distribute. It is also possible to distribute the evaluation via eval_distribute.
config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:loss = 2.3046074, step = 0
INFO:tensorflow:global_step/sec: 164.069
INFO:tensorflow:loss = 2.3018882, step = 100 (0.612 sec)
INFO:tensorflow:global_step/sec: 224.868
INFO:tensorflow:loss = 2.3000152, step = 200 (0.444 sec)
INFO:tensorflow:Saving checkpoints for 234 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Starting evaluation at 2023-09-01T04:34:19
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Inference Time : 1.26623s
INFO:tensorflow:Finished evaluation at 2023-09-01-04:34:21
INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.302158
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Loss for final step: 2.3058825.
({'loss': 2.302158, 'global_step': 234}, [])
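As noted above, evaluation can also be distributed. A minimal sketch (not executed in this tutorial) would pass the same strategy to eval_distribute when building the RunConfig; the rest of the train_and_evaluate call stays unchanged:

config = tf.estimator.RunConfig(train_distribute=strategy,
                                eval_distribute=strategy)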
Optimize training performance
You now have a model and a multi-worker capable Estimator powered by tf.distribute.Strategy. You can try the following techniques to optimize the performance of multi-worker training:

- Increase the batch size: The batch size specified here is per-GPU. In general, the largest batch size that fits the GPU memory is advisable.
- Cast variables: Cast the variables to tf.float if possible. The official ResNet model includes an example of how this can be done.
- Use collective communication: MultiWorkerMirroredStrategy provides multiple collective communication implementations.
  - RING implements ring-based collectives using gRPC as the cross-host communication layer.
  - NCCL uses Nvidia's NCCL to implement collectives.
  - AUTO defers the choice to the runtime.

  The best choice of collective implementation depends upon the number and kind of GPUs, and the network interconnect in the cluster. To override the automatic choice, specify a valid value to the communication parameter of MultiWorkerMirroredStrategy's constructor, e.g. communication=tf.distribute.experimental.CollectiveCommunication.NCCL, as in the sketch below.
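For instance, a minimal sketch (assuming NVIDIA GPUs are available on each worker) that forces NCCL collectives:

# Override the automatic choice of collective implementation with NCCL.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)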
Visit the Performance section in the guide to learn more about other strategies and tools you can use to optimize the performance of your TensorFlow models.
Other code examples
- End-to-end example for multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.
- Official models, many of which can be configured to run multiple distribution strategies.