このページは Cloud Translation API によって翻訳されました。
Switch to English

Kerasによる分散トレーニング

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する ノートブックをダウンロード

概観

tf.distribute.Strategy APIは、トレーニングを複数の処理ユニットに分散するための抽象化を提供します。目標は、ユーザーが既存のモデルとトレーニングコードを使用して、最小限の変更で分散トレーニングを有効にできるようにすることです。

このチュートリアルでは、 tf.distribute.MirroredStrategy使用しtf.distribute.MirroredStrategy 。これは、1台のマシン上の多くのGPUで同期トレーニングを行うグラフ内レプリケーションを実行します。基本的に、モデルのすべての変数を各プロセッサにコピーします。次に、 all-reduceを使用してすべてのプロセッサからの勾配を結合し、結合した値をモデルのすべてのコピーに適用します。

MirroredStrategyは、TensorFlowコアで利用可能ないくつかの配布戦略の1つです。詳細については、 配布戦略ガイドをご覧ください

Keras API

この例では、 tf.keras APIを使用してモデルとトレーニングループを構築します。カスタムトレーニングループについては、トレーニングループ付きtf.distribute.Strategyチュートリアルご覧ください。

依存関係をインポートする

 # Import TensorFlow and TensorFlow Datasets

import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os
 
 print(tf.__version__)
 
2.2.0

データセットをダウンロードする

MNISTデータセットをダウンロードし、TensorFlow Datasetsからロードします。これは、データセットをtf.data形式でtf.dataます。

with_infoTrueに設定すると、データセット全体のメタデータが含まれ、ここでinfoに保存されます。特に、このメタデータオブジェクトには、トレーニングとテストの例の数が含まれています。

 datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']
 
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...

Warning:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

流通戦略を定義する

MirroredStrategyオブジェクトを作成します。これは配布を処理し、内部でモデルを構築するためのコンテキストマネージャー( tf.distribute.MirroredStrategy.scope )を提供します。

 strategy = tf.distribute.MirroredStrategy()
 
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

 print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
 
Number of devices: 1

入力パイプラインのセットアップ

複数のGPUを使用してモデルをトレーニングする場合、バッチサイズを増やすことにより、追加のコンピューティング能力を効果的に使用できます。一般に、GPUメモリに適合する最大のバッチサイズを使用し、それに応じて学習率を調整します。

 # You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
 

ピクセル値は0〜255で、0〜1の範囲に正規化する必要があります 。このスケールを関数で定義します。

 def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label
 

この関数をトレーニングデータとテストデータに適用し、トレーニングデータをシャッフルし、 バッチ処理してトレーニングします。パフォーマンスを向上させるために、トレーニングデータのメモリ内キャッシュも保持していることに注意してください。

 train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
 

モデルを作成する

strategy.scopeのコンテキストでKerasモデルを作成してコンパイルします。

 with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
 

コールバックを定義する

ここで使用されるコールバックは次のとおりです。

  • TensorBoard :このコールバックは、グラフを視覚化できるTensorBoardのログを書き込みます。
  • モデルチェックポイント :このコールバックは、エポックごとにモデルを保存します。
  • 学習率スケジューラ :このコールバックを使用して、すべてのエポック/バッチの後で変化するように学習率をスケジュールできます。

説明のために、ノートブックに学習率を表示する印刷コールバックを追加します。

 # Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
 
 # Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
 
 # Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
 
 callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
 

トレーニングと評価

次に、通常の方法でモデルをトレーニングし、モデルのfitを呼び出して、チュートリアルの最初に作成したデータセットを渡します。このステップは、トレーニングを配布するかどうかにかかわらず同じです。

 model.fit(train_dataset, epochs=12, callbacks=callbacks)
 
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

936/938 [============================>.] - ETA: 0s - accuracy: 0.9422 - loss: 0.2016
Learning rate for epoch 1 is 0.0010000000474974513
938/938 [==============================] - 5s 5ms/step - accuracy: 0.9422 - loss: 0.2015 - lr: 0.0010
Epoch 2/12
936/938 [============================>.] - ETA: 0s - accuracy: 0.9807 - loss: 0.0662
Learning rate for epoch 2 is 0.0010000000474974513
938/938 [==============================] - 3s 4ms/step - accuracy: 0.9807 - loss: 0.0662 - lr: 0.0010
Epoch 3/12
933/938 [============================>.] - ETA: 0s - accuracy: 0.9863 - loss: 0.0464
Learning rate for epoch 3 is 0.0010000000474974513
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9863 - loss: 0.0464 - lr: 0.0010
Epoch 4/12
933/938 [============================>.] - ETA: 0s - accuracy: 0.9933 - loss: 0.0252
Learning rate for epoch 4 is 9.999999747378752e-05
938/938 [==============================] - 3s 4ms/step - accuracy: 0.9933 - loss: 0.0252 - lr: 1.0000e-04
Epoch 5/12
932/938 [============================>.] - ETA: 0s - accuracy: 0.9946 - loss: 0.0220
Learning rate for epoch 5 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9945 - loss: 0.0220 - lr: 1.0000e-04
Epoch 6/12
929/938 [============================>.] - ETA: 0s - accuracy: 0.9951 - loss: 0.0200
Learning rate for epoch 6 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9951 - loss: 0.0201 - lr: 1.0000e-04
Epoch 7/12
928/938 [============================>.] - ETA: 0s - accuracy: 0.9955 - loss: 0.0186
Learning rate for epoch 7 is 9.999999747378752e-05
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9955 - loss: 0.0186 - lr: 1.0000e-04
Epoch 8/12
934/938 [============================>.] - ETA: 0s - accuracy: 0.9965 - loss: 0.0161
Learning rate for epoch 8 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9965 - loss: 0.0161 - lr: 1.0000e-05
Epoch 9/12
932/938 [============================>.] - ETA: 0s - accuracy: 0.9965 - loss: 0.0157
Learning rate for epoch 9 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9965 - loss: 0.0156 - lr: 1.0000e-05
Epoch 10/12
934/938 [============================>.] - ETA: 0s - accuracy: 0.9966 - loss: 0.0155
Learning rate for epoch 10 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9966 - loss: 0.0154 - lr: 1.0000e-05
Epoch 11/12
934/938 [============================>.] - ETA: 0s - accuracy: 0.9967 - loss: 0.0153
Learning rate for epoch 11 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9967 - loss: 0.0153 - lr: 1.0000e-05
Epoch 12/12
924/938 [============================>.] - ETA: 0s - accuracy: 0.9967 - loss: 0.0152
Learning rate for epoch 12 is 9.999999747378752e-06
938/938 [==============================] - 3s 3ms/step - accuracy: 0.9967 - loss: 0.0151 - lr: 1.0000e-05

<tensorflow.python.keras.callbacks.History at 0x7fc7cc1ce5f8>

以下に示すように、チェックポイントが保存されています。

 # check the checkpoint directory
!ls {checkpoint_dir}
 
checkpoint           ckpt_4.data-00000-of-00002
ckpt_1.data-00000-of-00002   ckpt_4.data-00001-of-00002
ckpt_1.data-00001-of-00002   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00002
ckpt_10.data-00000-of-00002  ckpt_5.data-00001-of-00002
ckpt_10.data-00001-of-00002  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00002
ckpt_11.data-00000-of-00002  ckpt_6.data-00001-of-00002
ckpt_11.data-00001-of-00002  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00002
ckpt_12.data-00000-of-00002  ckpt_7.data-00001-of-00002
ckpt_12.data-00001-of-00002  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00002
ckpt_2.data-00000-of-00002   ckpt_8.data-00001-of-00002
ckpt_2.data-00001-of-00002   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00002
ckpt_3.data-00000-of-00002   ckpt_9.data-00001-of-00002
ckpt_3.data-00001-of-00002   ckpt_9.index
ckpt_3.index

モデルのパフォーマンスを確認するには、最新のチェックポイントを読み込み、テストデータに対してevaluateを呼び出します。

適切なデータセットを使用する前と同じようにevaluateを呼び出します。

 model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
 
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

157/157 [==============================] - 1s 7ms/step - accuracy: 0.9861 - loss: 0.0393
Eval loss: 0.039307601749897, Eval Accuracy: 0.9861000180244446

出力を確認するには、ターミナルでTensorBoardログをダウンロードして表示します。

 $ tensorboard --logdir=path/to/log-directory
 
ls -sh ./logs
total 4.0K
4.0K train

SavedModelにエクスポート

グラフと変数をプラットフォームに依存しないSavedModel形式にエクスポートします。モデルが保存された後、スコープの有無にかかわらずそれをロードできます。

 path = 'saved_model/'
 
 model.save(path, save_format='tf')
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Assets written to: saved_model/assets

INFO:tensorflow:Assets written to: saved_model/assets

strategy.scopeなしでモデルをロードします。

 unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
 
157/157 [==============================] - 1s 5ms/step - loss: 0.0393 - accuracy: 0.9861
Eval loss: 0.039307601749897, Eval Accuracy: 0.9861000180244446

strategy.scopeてモデルをロードします。

 with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
 
157/157 [==============================] - 1s 6ms/step - accuracy: 0.9861 - loss: 0.0393
Eval loss: 0.039307601749897, Eval Accuracy: 0.9861000180244446

例とチュートリアル

keras fit / compileで配布戦略を使用する例をいくつか示します:

  1. tf.distribute.MirroredStrategyを使用してトレーニングされたTransformerの
  2. tf.distribute.MirroredStrategyを使用してトレーニングされたNCFの例。

配布戦略ガイドに記載されているその他の例

次のステップ