This guide demonstrates how to migrate single-worker multiple-GPU workflows from TensorFlow 1 to TensorFlow 2.
To perform synchronous training across multiple GPUs on one machine:
- In TensorFlow 1, you use the tf.estimator.Estimator APIs with tf.distribute.MirroredStrategy.
- In TensorFlow 2, you can use Keras Model.fit or a custom training loop with tf.distribute.MirroredStrategy. Learn more in the Distributed training with TensorFlow guide.
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
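Since the examples below mirror training across every GPU that TensorFlow can see, it can help to confirm what is visible before creating a strategy. The following is a minimal, optional check; the two-GPU device list in the comment is illustrative only:

# List the GPUs TensorFlow can use; MirroredStrategy mirrors across all of them by default.
print("GPUs available:", tf.config.list_physical_devices('GPU'))

# To restrict mirroring to specific devices, pass an explicit device list, for example:
# strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])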
TensorFlow 1: Single-worker distributed training with tf.estimator.Estimator
This example demonstrates the TensorFlow 1 canonical workflow of single-worker multiple-GPU training. You need to set the distribution strategy (tf.distribute.MirroredStrategy) through the config parameter of tf.estimator.Estimator:
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  # A single dense layer trained with mean squared error and Adagrad.
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
strategy = tf1.distribute.MirroredStrategy()
config = tf1.estimator.RunConfig(
train_distribute=strategy, eval_distribute=strategy)
estimator = tf1.estimator.Estimator(model_fn=_model_fn, config=config)
train_spec = tf1.estimator.TrainSpec(input_fn=_input_fn)
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpck6hpnzc INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpck6hpnzc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategyV1 object at 0x7f78f00fc520>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategyV1 object at 0x7f78f00fc520>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1244: StrategyBase.configure (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version. Instructions for updating: use `update_config_proto` instead. /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. 
INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpck6hpnzc/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... 2022-12-14 03:45:08.836664: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:08.837978: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' 2022-12-14 03:45:08.850093: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:08.850568: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' 2022-12-14 03:45:08.860478: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:08.860959: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' 2022-12-14 03:45:08.868287: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:08.868767: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. 
INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Starting evaluation at 2022-12-14T03:45:10 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpck6hpnzc/model.ckpt-0 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Inference Time : 0.46764s INFO:tensorflow:Finished evaluation at 2022-12-14-03:45:11 INFO:tensorflow:Saving dict for global step 0: global_step = 0, loss = 0.00045113903 2022-12-14 03:45:10.896895: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:10.898253: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' 2022-12-14 03:45:10.901212: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:10.901715: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' 2022-12-14 03:45:10.916974: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:10.917460: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' 2022-12-14 03:45:10.926336: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} } . Registered: device='CPU' 2022-12-14 03:45:10.926867: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} } . Registered: device='CPU' INFO:tensorflow:Saving 'checkpoint_path' summary for global step 0: /tmpfs/tmp/tmpck6hpnzc/model.ckpt-0 INFO:tensorflow:Loss for final step: None. ({'loss': 0.00045113903, 'global_step': 0}, [])
TensorFlow 2: Single-worker training with Keras
When migrating to TensorFlow 2, you can use the Keras APIs with tf.distribute.MirroredStrategy.

If you use the tf.keras APIs for model building and Keras Model.fit for training, the main difference is instantiating the Keras model, an optimizer, and metrics in the context of Strategy.scope, instead of defining a config for tf.estimator.Estimator.
If you need to use a custom training loop, check out the Using tf.distribute.Strategy with custom training loops guide; a minimal sketch also follows the Keras example below.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Create (and compile) the model and optimizer under the strategy's scope.
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  model.compile(optimizer=optimizer, loss='mse')

model.fit(dataset)
model.evaluate(eval_dataset, return_dict=True)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 2022-12-14 03:45:11.304365: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 3 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\025TensorSliceDataset:24" } } attr { key: "output_shapes" value { list { shape { dim { size: 2 } } shape { dim { size: 1 } } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1 3/3 [==============================] - 4s 8ms/step - loss: 4.7602 2022-12-14 03:45:15.677855: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_FLOAT } } } attr { key: "_cardinality" value { i: 3 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\025TensorSliceDataset:26" } } attr { key: "output_shapes" value { list { shape { dim { size: 2 } } shape { dim { size: 1 } } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } } } } 3/3 [==============================] - 2s 6ms/step - loss: 12.6974 {'loss': 12.697421073913574}
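If you prefer a custom training loop instead of Model.fit, the same migration idea applies: create the model, optimizer, and metrics inside Strategy.scope, distribute the dataset with Strategy.experimental_distribute_dataset, and run the per-replica step with Strategy.run. The sketch below reuses the features, labels, and strategy defined above; names such as loop_model, train_step, and GLOBAL_BATCH_SIZE are illustrative and not part of the original example (see the Using tf.distribute.Strategy with custom training loops guide for the full treatment):

GLOBAL_BATCH_SIZE = 3  # Illustrative: all three training examples form one global batch.
loop_dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(GLOBAL_BATCH_SIZE)

with strategy.scope():
  # Model, optimizer, and metrics are created under the strategy's scope,
  # just as with Model.fit.
  loop_model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  loop_optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  loss_metric = tf.keras.metrics.Mean(name='train_loss')

# Split each global batch across the replicas (GPUs).
dist_dataset = strategy.experimental_distribute_dataset(loop_dataset)

@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      predictions = loop_model(x, training=True)
      per_example_loss = tf.keras.losses.mean_squared_error(y, predictions)
      # Scale by the global batch size so gradients sum correctly across replicas.
      loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    gradients = tape.gradient(loss, loop_model.trainable_variables)
    loop_optimizer.apply_gradients(zip(gradients, loop_model.trainable_variables))
    loss_metric.update_state(loss)
    return loss

  per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for epoch in range(3):
  loss_metric.reset_state()
  for batch in dist_dataset:
    train_step(batch)
  print('Epoch', epoch, 'mean per-step loss:', loss_metric.result().numpy())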
Next steps
To learn more about distributed training with tf.distribute.MirroredStrategy in TensorFlow 2, check out the following documentation:
- The Distributed training on one machine with Keras tutorial
- The Distributed training on one machine with a custom training loop tutorial
- The Distributed training with TensorFlow guide
- The Using multiple GPUs guide
- The Optimize the performance on the multi-GPU single host (with the TensorFlow Profiler) guide