今日のローカルTensorFlowEverywhereイベントの出欠確認!

Estimator を使ったマルチワーカートレーニング

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

注意: tf.distribute API で Estimator を使用する際、tf.distribute で Keras を使用することを推奨します。Keras を使ったマルチワーカートレーニングをご覧ください。tf.distribute.Strategy を使った Estimator トレーニングのサポートは制限されています。

このチュートリアルでは、tf.estimator を使った分散型マルチワーカートレーニングに tf.distribute.Strategy を使用する方法を実演しています。tf.estimator を使って独自のコードを記述しており、高性能の単一の機械を超えるスケーリングに関心がある場合は、このチュートリアルをご利用ください。

始める前に、分散ストラテジーガイドをお読みください。マルチ GPU トレーニングのチュートリアルも関連しています。このチュートリアルでは同じモデルが使用されています。

セットアップ

最初に、TensorFlow と必要なインポートをセットアップします。

import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os, json

入力関数

このチュートリアルでは、TensorFlow Datasets の MNIST データセットを使用しています。このコードはマルチ GPU トレーニングのチュートリアルのコードに似ていますが、大きな違いが 1 つあります。マルチワーカートレーニングに Estimator を使用する際は、モデルのコンバージェンスを可能にできるよう、ワーカーの数でデータセットをシャーディングする必要があります。入力データは、ワーカーインデックスでシャーディングされるため、各ワーカーは、データセットの各 1/num_workers の部分を処理します。

BUFFER_SIZE = 10000
BATCH_SIZE = 64

def input_fn(mode, input_context=None):
  datasets, info = tfds.load(name='mnist',
                                with_info=True,
                                as_supervised=True)
  mnist_dataset = (datasets['train'] if mode == tf.estimator.ModeKeys.TRAIN else
                   datasets['test'])

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  if input_context:
    mnist_dataset = mnist_dataset.shard(input_context.num_input_pipelines,
                                        input_context.input_pipeline_id)
  return mnist_dataset.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

コンバージェンスを達成するためのもう 1 つの合理的なアプローチとして、各ワーカーで異なるシードを使ってデータベースをシャッフルする方法があります。

マルチワーカー構成

このチュートリアルの主な違いの 1 つに(マルチ GPU トレーニングのチュートリアル と比較)、マルチワーカーのセットアップがあります。TF_CONFIG 環境変数は、クラスタの一部である各ワーカーにクラスタ構成を指定する標準的な方法です。

TF_CONFIG には、clustertask の 2 つのコンポーネントがあります。cluster は、クラスタのワーカーとパラメータサーバーを含むクラスタ全体に関する情報を提供するのに対し、task は、現在のタスクに関する情報を提供します。最初のコンポーネント cluster は、クラスタ内のすべてのワーカーとパラメータサーバーで同一であり、2 つ目のコンポーネント task は、各ワーカーとパラメータサーバー間で異なり、それぞれに typeindex を指定します。この例では、タスクの typeworker で、タスクの index0 です。

説明の目的により、このチュートリアルでは、localhost 上に 2 つのワーカーを持つ TF_CONFIG の設定方法を示しています。実践として、外部 IP アドレスとポートに複数のワーカーを作成し、各ワーカーに適切に TF_CONFIG を設定します(タスクの index を変更します)。

警告: 次のコードを Colab で実行しないでください。 TensorFlow のランタイムは、指定された IP アドレスとポートに gRPC サーバーを作成しようとしますが、失敗する可能性があります。

os.environ['TF_CONFIG'] = json.dumps({     'cluster': {         'worker': ["localhost:12345", "localhost:23456"]     },     'task': {'type': 'worker', 'index': 0} })

モデルを定義する

トレーニング用にレイヤー、オプティマイザ、および損失関数を記述します。このチュートリアルでは、マルチ GPU トレーニングのチュートリアルと同様に、Keras レイヤーを使ったモデルを定義しています。

LEARNING_RATE = 1e-4
def model_fn(features, labels, mode):
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  logits = model(features, training=False)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'logits': logits}
    return tf.estimator.EstimatorSpec(labels=labels, predictions=predictions)

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(
      learning_rate=LEARNING_RATE)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(labels, logits)
  loss = tf.reduce_sum(loss) * (1. / BATCH_SIZE)
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=optimizer.minimize(
          loss, tf.compat.v1.train.get_or_create_global_step()))

注意: この例の学習速度は固定されていますが、一般的に、グローバルバッチサイズに基づいて学習速度を調整する必要があります。

MultiWorkerMirroredStrategy

モデルをトレーニングするために、 tf.distribute.experimental.MultiWorkerMirroredStrategy のインスタンスを作成します。 MultiWorkerMirroredStrategy は、すべてのワーカーの各装置にあるモデルのレイヤーにすべての変数のコピーを作成します。集合通信に使用する TensorFlow 演算子 CollectiveOps を使用して勾配を集め、変数の同期を維持します。このストラテジーの詳細は、tf.distribute.Strategy ガイドで説明されています。

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
WARNING:tensorflow:From <ipython-input-1-f1f424df316e>:1: _CollectiveAllReduceStrategyExperimental.__init__ (from tensorflow.python.distribute.collective_all_reduce_strategy) is deprecated and will be removed in a future version.
Instructions for updating:
use distribute.MultiWorkerMirroredStrategy instead
INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CommunicationImplementation.AUTO

モデルをトレーニングして評価する

次に、分散ストラテジーを Estimator の RunConfig に指定し、tf.estimator.train_and_evaluate を呼び出してトレーニングと評価を行います。このチュートリアルでは、train_distribute 経由でストラテジーを指定してトレーニングのみを分散しています。eval_distribute を使って評価を分散することもできます。

config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/multiworker', config=config)
tf.estimator.train_and_evaluate(
    classifier,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn)
)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/multiworker', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.collective_all_reduce_strategy._CollectiveAllReduceStrategyExperimental object at 0x7f1b1c1a1a90>, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:The `input_fn` accepts an `input_context` which will be given by DistributionStrategy
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/multiworker/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 2.2981951, step = 0

INFO:tensorflow:loss = 2.2981951, step = 0

INFO:tensorflow:global_step/sec: 199.399

INFO:tensorflow:global_step/sec: 199.399

INFO:tensorflow:loss = 2.2770095, step = 100 (0.504 sec)

INFO:tensorflow:loss = 2.2770095, step = 100 (0.504 sec)

INFO:tensorflow:global_step/sec: 214.822

INFO:tensorflow:global_step/sec: 214.822

INFO:tensorflow:loss = 2.2760954, step = 200 (0.466 sec)

INFO:tensorflow:loss = 2.2760954, step = 200 (0.466 sec)

INFO:tensorflow:global_step/sec: 215.119

INFO:tensorflow:global_step/sec: 215.119

INFO:tensorflow:loss = 2.2578058, step = 300 (0.465 sec)

INFO:tensorflow:loss = 2.2578058, step = 300 (0.465 sec)

INFO:tensorflow:global_step/sec: 218.735

INFO:tensorflow:global_step/sec: 218.735

INFO:tensorflow:loss = 2.2710721, step = 400 (0.457 sec)

INFO:tensorflow:loss = 2.2710721, step = 400 (0.457 sec)

INFO:tensorflow:global_step/sec: 223.869

INFO:tensorflow:global_step/sec: 223.869

INFO:tensorflow:loss = 2.264247, step = 500 (0.446 sec)

INFO:tensorflow:loss = 2.264247, step = 500 (0.446 sec)

INFO:tensorflow:global_step/sec: 225.988

INFO:tensorflow:global_step/sec: 225.988

INFO:tensorflow:loss = 2.257546, step = 600 (0.442 sec)

INFO:tensorflow:loss = 2.257546, step = 600 (0.442 sec)

INFO:tensorflow:global_step/sec: 219.498

INFO:tensorflow:global_step/sec: 219.498

INFO:tensorflow:loss = 2.2384143, step = 700 (0.456 sec)

INFO:tensorflow:loss = 2.2384143, step = 700 (0.456 sec)

INFO:tensorflow:global_step/sec: 240.676

INFO:tensorflow:global_step/sec: 240.676

INFO:tensorflow:loss = 2.2514377, step = 800 (0.415 sec)

INFO:tensorflow:loss = 2.2514377, step = 800 (0.415 sec)

INFO:tensorflow:global_step/sec: 583.017

INFO:tensorflow:global_step/sec: 583.017

INFO:tensorflow:loss = 2.2368863, step = 900 (0.171 sec)

INFO:tensorflow:loss = 2.2368863, step = 900 (0.171 sec)

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 938...

INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.

INFO:tensorflow:Saving checkpoints for 938 into /tmp/multiworker/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 938...

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2021-02-12T23:03:24Z

INFO:tensorflow:Starting evaluation at 2021-02-12T23:03:24Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938

INFO:tensorflow:Restoring parameters from /tmp/multiworker/model.ckpt-938

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [10/100]

INFO:tensorflow:Evaluation [10/100]

INFO:tensorflow:Evaluation [20/100]

INFO:tensorflow:Evaluation [20/100]

INFO:tensorflow:Evaluation [30/100]

INFO:tensorflow:Evaluation [30/100]

INFO:tensorflow:Evaluation [40/100]

INFO:tensorflow:Evaluation [40/100]

INFO:tensorflow:Evaluation [50/100]

INFO:tensorflow:Evaluation [50/100]

INFO:tensorflow:Evaluation [60/100]

INFO:tensorflow:Evaluation [60/100]

INFO:tensorflow:Evaluation [70/100]

INFO:tensorflow:Evaluation [70/100]

INFO:tensorflow:Evaluation [80/100]

INFO:tensorflow:Evaluation [80/100]

INFO:tensorflow:Evaluation [90/100]

INFO:tensorflow:Evaluation [90/100]

INFO:tensorflow:Evaluation [100/100]

INFO:tensorflow:Evaluation [100/100]

INFO:tensorflow:Inference Time : 1.24596s

INFO:tensorflow:Inference Time : 1.24596s

INFO:tensorflow:Finished evaluation at 2021-02-12-23:03:26

INFO:tensorflow:Finished evaluation at 2021-02-12-23:03:26

INFO:tensorflow:Saving dict for global step 938: global_step = 938, loss = 2.2316883

INFO:tensorflow:Saving dict for global step 938: global_step = 938, loss = 2.2316883

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 938: /tmp/multiworker/model.ckpt-938

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 938: /tmp/multiworker/model.ckpt-938

INFO:tensorflow:Loss for final step: 1.1226506.

INFO:tensorflow:Loss for final step: 1.1226506.

({'loss': 2.2316883, 'global_step': 938}, [])

トレーニングのパフォーマンスを最適化する

tf.distribute.Strategy により、モデルとマルチワーカー対応 Estimator の準備が整いました。次のテクニックに従って、マルチワーカートレーニングのパフォーマンスを最適化することができます。

  • バッチサイズの増加: ここで指定されるバッチサイズは、GPU 単位のサイズです。一般的に、GPU メモリに収まる最大バッチサイズの指定が推奨されます。

  • 変数のキャスト: 可能であれば、tf.float に変数をキャストしてください。公式の ResNet モデルには、どのようにしてこれを行うかのが示されています。

  • 集合通信の使用: MultiWorkerMirroredStrategy は、複数の集合通信実装を提供しています。

    • RING は、クロスホスト通信レイヤーとして、gRPC を使用したリング状の集合体を実装します。
    • NCCL は、Nvidia の NCCL を使用して集合体を実装します。
    • AUTO は、選択をランタイムに持ち越します。

    最適な集合体実装の選択肢は、GPU 数と種類によって異なり、ネットワークはクラスタ内で相互接続します。自動選択をオーバーライドするには、MultiWorkerMirroredStrategy コンストラクタの communication パラメータに、 communication=tf.distribute.experimental.CollectiveCommunication.NCCL のように有効な値を指定します。

ガイドのパフォーマンスのセクションに目を通し、独自の TensorFlow モデルのパフォーマンス最適化に使用できるほかのストラテジーやツールについてさらに詳しく学習しましょう。

その他のコード例

  1. Kubernetes テンプレートを使った tensorflow/ecosystem でマルチワーカートレーニングを行うためのエンドツーエンドの例。この例は最初に Keras モデルを使用し、それを tf.keras.estimator.model_to_estimator API を使って Estimator に変換します。
  2. 公式モデル。この多くは、複数の分散ストラテジーで実行するように構成できます。