tf.distribute.Strategy を使用したカスタムトレーニング

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

このチュートリアルでは、tf.distribute.Strategyをカスタムトレーニングループで使用する方法を示します。Fashion MNIST データセットで単純な CNN モデルをトレーニングします。Fashion MNIST データセットには、サイズ 28×28 のトレーニング画像 60,000 枚とサイズ 28×28 のテスト画像 10,000 枚が含まれています。

モデルのトレーニングにカスタムトレーニングループを使用するのは、柔軟性があり、トレーニングを容易に制御できるからです。それに加えて、モデルとトレーニングループのデバッグも容易になります。

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.4.1

Fashion MNIST データセットをダウンロードする

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

変数とグラフを分散させるストラテジーを作成する

tf.distribute.MirroredStrategyストラテジーはどのように機能するのでしょう?

  • 全ての変数とモデルグラフはレプリカ上に複製されます。
  • 入力はレプリカ全体に均等に分散されます。
  • 各レプリカは受け取った入力の損失と勾配を計算します。
  • 勾配は加算して全てのレプリカ間で同期されます。
  • 同期後、各レプリカ上の変数のコピーにも同じ更新が行われます。

注意: 下記のコードは全て 1 つのスコープ内に入れることができます。説明しやすいように、ここでは幾つかのコードセルに分割しています。

# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

入力パイプラインをセットアップする

グラフと変数をプラットフォームに依存しない SavedModel 形式にエクスポートします。モデルが保存された後、スコープの有無に関わらずそれを読み込むことができます。

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

データセットを作成して、それを配布します。

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

モデルを作成する

tf.keras.Sequentialを使用してモデルを作成します。これは Model Subclassing API を使用しても作成できます。

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

損失関数を定義する

通常、GPU/CPU を 1 つ搭載した単一のマシンでは、損失は入力バッチ内の例の数で除算されます。

では、tf.distribute.Strategyを使用する場合、どのように損失を計算すればよいのでしょうか?

  • 例えば、4 つの GPU と 64 のバッチサイズがあるとします。1 つの入力バッチは(4 つの GPU の)レプリカに分散されるので、各レプリカはサイズ 16 の入力を取得します。

  • 各レプリカのモデルは、それぞれの入力でフォワードパスを実行し、損失を計算します。ここでは、損失をそれぞれの入力の例の数(BATCH_SIZE_PER_REPLICA = 16)で除算するのではなく、損失を GLOBAL_BATCH_SIZE (64) で除算する必要があります。

なぜそうするのでしょう?

  • 勾配を各レプリカで計算した後にそれらを加算してレプリカ間で同期するため、これを行う必要があります。

TensorFlow でこれを行うには?

  • このチュートリアルにもあるように、カスタムトレーニングループを書く場合は、例ごとの損失を加算し、その合計を GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)で除算する必要があります。または、tf.nn.compute_average_lossを使用することも可能です。これは例ごとの損失、オプションのサンプルの重み、そしてGLOBAL_BATCH_SIZE を引数として取り、スケーリングされた損失を返します。

  • モデルで正則化損失を使用している場合は、損失値をレプリカの数でスケーリングする必要があります。これを行うには、tf.nn.scale_regularization_loss関数を使用します。

  • tf.reduce_meanの使用は推奨されません。これを使用すると、損失がレプリカごとの実際のバッチサイズで除算され、ステップごとに変化する場合があります。

  • この縮小とスケーリングは、kerasmodel.compilemodel.fitで自動的に行われます。

  • 以下の例のようにtf.keras.lossesクラスを使用する場合、損失削減はNONEまたはSUMのいずれかになるよう、明示的に指定する必要があります。AUTOおよびSUM_OVER_BATCH_SIZEtf.distribute.Strategyとの併用は許可されません。AUTOは、分散型のケースでユーザーがどの削減を正しいと確認するか明示的に考える必要があるため、許可されていません。SUM_OVER_BATCH_SIZEは、現時点ではレプリカのバッチサイズのみで除算され、レプリカの数に基づく除算はユーザーに任されており見落としがちなため、許可されていません。そのため、その代わりにユーザー自身が明示的に削減を行うようにお願いしています。

  • もしlabelsが多次元である場合は、各サンプルの要素数全体でper_example_lossを平均化します。例えば、predictionsの形状が(batch_size, H, W, n_classes)で、labels(batch_size, H, W)の場合、per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)のようにper_example_lossを更新する必要があります。

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

損失と精度を追跡するメトリクスを定義する

これらのメトリクスは、テストの損失、トレーニング、テストの精度を追跡します。.result()を使用して、いつでも累積統計を取得できます。

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

トレーニングループ

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.5086057186126709, Accuracy: 81.62666320800781, Test Loss: 0.3796250820159912, Test Accuracy: 86.69000244140625
Epoch 2, Loss: 0.32958245277404785, Accuracy: 88.27333068847656, Test Loss: 0.32727015018463135, Test Accuracy: 88.44000244140625
Epoch 3, Loss: 0.2872839868068695, Accuracy: 89.56000518798828, Test Loss: 0.29971158504486084, Test Accuracy: 89.20999908447266
Epoch 4, Loss: 0.25545117259025574, Accuracy: 90.6483383178711, Test Loss: 0.29015499353408813, Test Accuracy: 89.72000122070312
Epoch 5, Loss: 0.23445327579975128, Accuracy: 91.50999450683594, Test Loss: 0.27410945296287537, Test Accuracy: 90.16999816894531
Epoch 6, Loss: 0.21617336571216583, Accuracy: 92.07999420166016, Test Loss: 0.32005295157432556, Test Accuracy: 87.61000061035156
Epoch 7, Loss: 0.1997239589691162, Accuracy: 92.61666870117188, Test Loss: 0.25581517815589905, Test Accuracy: 90.66999816894531
Epoch 8, Loss: 0.18616704642772675, Accuracy: 93.07833099365234, Test Loss: 0.24862809479236603, Test Accuracy: 90.95999908447266
Epoch 9, Loss: 0.17132586240768433, Accuracy: 93.7266616821289, Test Loss: 0.2547084391117096, Test Accuracy: 90.83999633789062
Epoch 10, Loss: 0.15842965245246887, Accuracy: 94.12000274658203, Test Loss: 0.2527276575565338, Test Accuracy: 91.05999755859375

上記の例における注意点:

  • for文(for x in ...)を使用して、train_dist_datasettest_dist_datasetに対してイテレーションしています。
  • スケーリングされた損失は distributed_train_stepの戻り値です。この値はtf.distribute.Strategy.reduce呼び出しを使用してレプリカ間で集約され、次にtf.distribute.Strategy.reduce呼び出しの戻り値を加算してバッチ間で集約されます。
  • tf.keras.Metricsは、tf.distribute.Strategy.runによって実行されるtrain_steptest_step内で更新される必要があります。*tf.distribute.Strategy.runはストラテジー内の各ローカルレプリカの結果を返し、この結果の消費方法は多様です。tf.distribute.Strategy.reduceを実行して、集約された値を取得することができます。また、tf.distribute.Strategy.experimental_local_resultsを実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。

最新のチェックポイントを復元してテストする

tf.distribute.Strategyでチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 90.83999633789062

データセットのイテレーションの代替方法

イテレータを使用する

データセット全体ではなく、任意のステップ数のイテレーションを行いたい場合は、iter呼び出しを使用してイテレータを作成し、そのイテレータ上でnextを明示的に呼び出すことができます。tf.function の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し tf.function 外側のデータセットのイテレーションを実行する小さなスニペットを示します。

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.15624000132083893, Accuracy: 94.0625
Epoch 10, Loss: 0.16303464770317078, Accuracy: 93.4375
Epoch 10, Loss: 0.15605169534683228, Accuracy: 93.75
Epoch 10, Loss: 0.1290276050567627, Accuracy: 94.53125
Epoch 10, Loss: 0.13734309375286102, Accuracy: 94.53125
Epoch 10, Loss: 0.12921670079231262, Accuracy: 96.40625
Epoch 10, Loss: 0.16179904341697693, Accuracy: 94.0625
Epoch 10, Loss: 0.11322519928216934, Accuracy: 96.09375
Epoch 10, Loss: 0.17364104092121124, Accuracy: 93.125
Epoch 10, Loss: 0.1457490175962448, Accuracy: 94.84375

tf.function 内でイテレーションする

tf.function の内側で for 文(for x in ...)を使用して、あるいは上記で行ったようにイテレータを作成して、入力train_dist_dataset全体をイテレーションすることもできます。次の例では、トレーニングの 1 つのエポックを tf.function でラップし、関数内でtrain_dist_datasetをイテレーションする方法を示します。

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
Epoch 1, Loss: 0.14632149040699005, Accuracy: 94.64666748046875
Epoch 2, Loss: 0.1365729421377182, Accuracy: 94.99666595458984
Epoch 3, Loss: 0.12348174303770065, Accuracy: 95.45166778564453
Epoch 4, Loss: 0.116549551486969, Accuracy: 95.65499877929688
Epoch 5, Loss: 0.10783953219652176, Accuracy: 95.99666595458984
Epoch 6, Loss: 0.09882526099681854, Accuracy: 96.32333374023438
Epoch 7, Loss: 0.08941978216171265, Accuracy: 96.6816635131836
Epoch 8, Loss: 0.08444732427597046, Accuracy: 96.8949966430664
Epoch 9, Loss: 0.07905271649360657, Accuracy: 97.02833557128906
Epoch 10, Loss: 0.07040875405073166, Accuracy: 97.4183349609375

レプリカ間でトレーニング損失を追跡する

注意: 一般的なルールとして、サンプルごとの値の追跡にはtf.keras.Metricsを使用し、レプリカ内で集約された値を避ける必要があります。

tf.metrics.Meanを使用すると損失スケーリングが計算されるため、異なるレプリカ間のトレーニング損失の追跡には推奨できません

例えば、次のような特徴を持つトレーニングジョブを実行するとします。

  • レプリカ 2 つ
  • 各レプリカで 2 つのサンプルを処理
  • 結果の損失値 : 各レプリカで [2, 3] および [4, 5]
  • グローバルバッチサイズ = 4

損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、(2 + 3) / 4 = 1.25および(4 + 5) / 4 = 2.25となります。

tf.metrics.Meanを使用して 2 つのレプリカ間の損失を追跡する場合、結果は異なります。この例では、totalは 3.50、countは 2 となり、メトリック上でresult()が呼び出された場合の結果はtotal/count = 1.75 となります。tf.keras.Metricsで計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。

ガイドと例

カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。

  1. 分散型トレーニングガイド
  2. MirroredStrategyを使用した DenseNet の例。
  3. MirroredStrategyTPUStrategyを使用してトレーニングされた BERT の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。
  4. MirroredStrategyを使用してトレーニングされ、keras_use_ctlフラグを使用した有効化が可能な、NCFの例。
  5. MirroredStrategyを使用してトレーニングされた、NMT の例。

分散ストラテジーガイドには他の例も記載されています。

次のステップ