Estimator を使ったマルチワーカートレーニング

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

警告: 新しいコードには Estimators は推奨されません。Estimators は v1.Session スタイルのコードを実行しますが、これは正しく記述するのはより難しく、特に TF 2 コードと組み合わせると予期しない動作をする可能性があります。Estimators は、互換性保証の対象となりますが、セキュリティの脆弱性以外の修正は行われません。詳細については、移行ガイドを参照してください。

概要

注意: tf.distribute API で Estimator を使用することもできますが、tf.distributeで Keras を使用することを推奨します。Keras を使ったマルチワーカートレーニングをご覧ください。tf.distribute.Strategy を使った Estimator のトレーニングのサポートは制限されています。

このチュートリアルでは、tf.estimatorを使った分散型マルチワーカートレーニングに tf.distribute.Strategy を使用する方法を実演します。tf.estimator を使って独自のコードを記述し、高性能のマシン一台以上のスケーリングを希望する場合は、このチュートリアルをご利用ください。

始める前に、分散ストラテジーガイドをお読みください。マルチ GPU トレーニングのチュートリアルも関連しています。このチュートリアルでは同じモデルが使用されています。

セットアップ

最初に、TensorFlow と必要なインポートをセットアップします。

import tensorflow_datasets as tfds
import tensorflow as tf

import os, json
2024-01-11 18:09:27.183559: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:09:27.183611: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:09:27.185190: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

注意: TF2.4 以降、マルチワーカーミラーリングストラテジーは、eager を有効にして実行すると (デフォルト)、Estimator で失敗します。TF2.4 のエラーは TypeError: cannot pickle '_thread.lock' object です。詳細については、課題 #46556 を参照してください。回避策は、eager の実行を無効にすることです。

tf.compat.v1.disable_eager_execution()

入力関数

このチュートリアルでは、TensorFlow Datasets の MNIST データセットを使用しています。このコードはマルチ GPU トレーニングのチュートリアルのコードに似ていますが、大きな違いが 1 つあります。マルチワーカートレーニングに Estimator を使用する際は、モデルのコンバージェンスを可能にするように、ワーカーの数でデータセットをシャーディングする必要があります。入力データは、ワーカーインデックスでシャーディングされるため、各ワーカーは、データセットの各 1/num_workers の部分を処理します。

BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                                with_info=True,
                                as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

コンバージェンスを達成するためのもう 1 つの合理的なアプローチとして、各ワーカーで異なるシードを使ってデータベースをシャッフルする方法があります。

マルチワーカー構成

このチュートリアルの主な違いの 1 つに (マルチ GPU トレーニングのチュートリアル と比較)、マルチワーカーのセットアップがあります。TF_CONFIG 環境変数は、クラスタの一部である各ワーカーにクラスタ構成を指定する標準的な方法です。

TF_CONFIG には、clustertask の 2 つのコンポーネントがあります。cluster は、クラスタのワーカーとパラメータサーバーを含むクラスタ全体に関する情報を提供するのに対し、task は、現在のタスクに関する情報を提供します。最初のコンポーネント cluster は、クラスタ内のすべてのワーカーとパラメータサーバーで同一であり、2 つ目のコンポーネント task は、各ワーカーとパラメータサーバー間で異なり、それぞれに typeindex を指定します。この例では、タスクの typeworker で、タスクの index0 です。

例示を目的として、このチュートリアルでは、localhost 上に 2 つのワーカーを持つ TF_CONFIG の設定方法を示しています。実際には、外部 IP アドレスとポートに複数のワーカーを作成し、各ワーカーに適切に TF_CONFIG を設定します (タスクの index を変更します)。

警告: Colab で次のコードを実行しないでください。TensorFlow のランタイムは、指定された IP アドレスとポートに gRPC サーバーを作成しようとしますが、失敗する可能性があります。1 台のマシンで複数のワーカーをテスト実行する方法の例については、このチュートリアルの keras バージョンを参照してください。

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})

モデルを定義する

トレーニング用にレイヤー、オプティマイザ、および損失関数を記述します。このチュートリアルでは、マルチ GPU トレーニングのチュートリアルと同様に、Keras レイヤーを使ったモデルを定義しています。

LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(labels=labels, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

注意: この例の学習速度は固定されていますが、一般的に、グローバルバッチサイズに基づいて学習速度を調整する必要があります。

MultiWorkerMirroredStrategy

モデルをトレーニングするために、tf.distribute.experimental.MultiWorkerMirroredStrategy のインスタンスを使用します。MultiWorkerMirroredStrategy は、すべてのワーカーの各デバイスにあるモデルのレイヤーにすべての変数のコピーを作成します。集合通信に使用する TensorFlow 演算子 CollectiveOps を使用して勾配を集め、変数の同期を維持します。このストラテジーの詳細は、tf.distribute.Strategy ガイドで説明されています。

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_41770/349189047.py:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3')
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1', '/device:GPU:2', '/device:GPU:3'), communication = CommunicationImplementation.AUTO

モデルをトレーニングして評価する

次に、分散ストラテジーを Estimator の RunConfig に指定し、tf.estimator.train_and_evaluate を呼び出してトレーニングと評価を行います。このチュートリアルでは、train_distribute 経由でストラテジーを指定してトレーニングのみを分散しています。eval_distribute を使って評価を分散することもできます。

config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_41770/2557501124.py:1: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_41770/2557501124.py:3: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental object at 0x7ff408226c70>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_41770/2557501124.py:7: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_41770/2557501124.py:8: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_41770/2557501124.py:5: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1246: StrategyBase.configure (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
use `update_config_proto` instead.
INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py:459: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 1 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.AUTO, num_packs = 1
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
2024-01-11 18:09:37.808408: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2024-01-11 18:09:37.809724: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2024-01-11 18:09:37.816927: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2024-01-11 18:09:37.817487: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
2024-01-11 18:09:37.839310: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2024-01-11 18:09:37.839826: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2024-01-11 18:09:37.846436: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2024-01-11 18:09:37.846962: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 2.3079295, step = 0
INFO:tensorflow:loss = 2.3079295, step = 0
INFO:tensorflow:global_step/sec: 165.101
INFO:tensorflow:global_step/sec: 165.101
INFO:tensorflow:loss = 2.2927566, step = 100 (0.608 sec)
INFO:tensorflow:loss = 2.2927566, step = 100 (0.608 sec)
INFO:tensorflow:global_step/sec: 223.16
INFO:tensorflow:global_step/sec: 223.16
INFO:tensorflow:loss = 2.3051863, step = 200 (0.447 sec)
INFO:tensorflow:loss = 2.3051863, step = 200 (0.447 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 234...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 234...
INFO:tensorflow:Saving checkpoints for 234 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Saving checkpoints for 234 into /tmp/multiworker/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 234...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 234...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2024-01-11T18:09:42
INFO:tensorflow:Starting evaluation at 2024-01-11T18:09:42
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [80/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [90/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Inference Time : 1.34798s
INFO:tensorflow:Inference Time : 1.34798s
INFO:tensorflow:Finished evaluation at 2024-01-11-18:09:43
INFO:tensorflow:Finished evaluation at 2024-01-11-18:09:43
INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.2969956
INFO:tensorflow:Saving dict for global step 234: global_step = 234, loss = 2.2969956
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 234: /tmp/multiworker/model.ckpt-234
INFO:tensorflow:Loss for final step: 2.2864075.
INFO:tensorflow:Loss for final step: 2.2864075.
({'loss': 2.2969956, 'global_step': 234}, [])

トレーニングのパフォーマンスを最適化する

tf.distribute.Strategy により、モデルとマルチワーカー対応 Estimator の準備が整いました。次のテクニックに従って、マルチワーカートレーニングのパフォーマンスを最適化することができます。

  • バッチサイズの増加: ここで指定されるバッチサイズは、GPU 単位のサイズです。一般的に、GPU メモリに収まる最大バッチサイズの指定が推奨されます。

  • 変数のキャスト: 可能であれば、tf.float に変数をキャストしてください。公式の ResNet モデルには、どのようにしてこれを行うかのが示されています。

  • 集合通信の使用: MultiWorkerMirroredStrategy は、複数の集合通信実装を提供しています。

    • RING は、クロスホスト通信レイヤーとして、gRPC を使用したリング状の集合体を実装します。
    • NCCL は、Nvidia の NCCL を使用して集合体を実装します。
    • AUTO は、選択をランタイムに持ち越します。

    最適な集合体実装の選択肢は、GPU 数と種類、および、クラスタ内の相互接続ネットワークにより異なります。自動選択をオーバーライドするには、MultiWorkerMirroredStrategy コンストラクタの communication パラメータに、 communication=tf.distribute.experimental.CollectiveCommunication.NCCL のように有効な値を指定します。

ガイドのパフォーマンスのセクションに目を通し、独自の TensorFlow モデルのパフォーマンス最適化に使用できるほかのストラテジーやツールについてさらに詳しく学習しましょう。

その他のコード例

  1. Kubernetes テンプレートを使った tensorflow/ecosystem でマルチワーカートレーニングを行うためのエンドツーエンドの例。この例は最初に Keras モデルを使用し、それを tf.keras.estimator.model_to_estimator API を使って Estimator に変換します。
  2. 公式モデル。この多くは、複数の分散ストラテジーで実行するように構成できます。