tf.distribute.Strategy を使用したカスタムトレーニング

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

このチュートリアルでは、複数の処理ユニット(GPU、複数のマシン、または TPU)にトレーニングを分散するための抽象化を提供する tf.distribute.Strategy という TensorFlow API をカスタムトレーニングループで使用する方法を説明します。この例では、70,000 個の 28 x 28 のサイズの画像を含む Fashion MNIST データセットで、単純な畳み込みニューラルネットワークをトレーニングします。

カスタムトレーニングループを使用すると、より優れた制御によってトレーニングを柔軟に実行できます。また、モデルとトレーニングループのデバックもより簡単に行えるようになります。

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2024-01-11 18:11:25.309748: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:11:25.309791: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:11:25.311321: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2.15.0

Fashion MNIST データセットをダウンロードする

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step

変数とグラフを分散させるストラテジーを作成する

tf.distribute.MirroredStrategyストラテジーはどのように機能するのでしょう?

  • すべての変数とモデルグラフはレプリカ上に複製されます。
  • 入力はレプリカ全体に均等に分散されます。
  • 各レプリカは受け取った入力の損失と勾配を計算します。
  • 勾配は加算して全てのレプリカ間で同期されます。
  • 同期後、各レプリカ上の変数のコピーにも同じ更新が行われます。

注意: 下のコードはすべて 1 つのスコープ内に入れることができます。説明しやすいように、この例では複数のコードセルに分割しています。

# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4

入力パイプラインをセットアップする

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

データセットを作成して、それを分散します。

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

モデルを作成する

tf.keras.Sequential を使用してモデルを作成します。これには、Model Subclassing APIfunctional API も使用できます。

def create_model():
  regularizer = tf.keras.regularizers.L2(1e-5)
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3,
                             activation='relu',
                             kernel_regularizer=regularizer),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3,
                             activation='relu',
                             kernel_regularizer=regularizer),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64,
                            activation='relu',
                            kernel_regularizer=regularizer),
      tf.keras.layers.Dense(10, kernel_regularizer=regularizer)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

損失関数を定義する

損失関数は以下の 2 つの部分で構成されていることを思い出しましょう。

  • 予測損失は、モデルの予測が、トレーニングサンプルのバッチに対するトレーニングラベルからどれくらい外れているかを測定します。ラベル付きのサンプルごとに計算されてから、平均値を計算してバッチ全体で縮小されます。
  • オプションの 正則化損失の項を予測損失に追加して、モデルがトレーニングデータを過学習しないように誘導します。一般的には L2 正則化が使用されます。これは、サンプルの数に関係なく、すべてのモデルの重みの二乗和の小さな固定倍数を追加します。上記のモデルは L2 正則化を使用して、以下のトレーニングループでの処理を示しています。

単一の GPU/CPU を使った単一のマシンでのトレーニングでは、次のように動作します。

  • 予測損失は、バッチのサンプルごとに計算され、バッチ全体で加算され、バッチサイズで除算されます。
  • 正則化損失は、予測損失に追加されます。
  • 合計損失の勾配は各モデルの重みに関して計算され、オプティマイザが、対応する勾配から各モデルの重みを更新します。

tf.distribute.Strategy では、入力バッチはレプリカ間で分割されます。たとえば、GPU が 4 つあり、それぞれにモデルのレプリカが 1 つあるとします。1 つのバッチの 256 の入力サンプルは 4 つのレプリカで均等に分散されるため、各レプリカのバッチサイズは 64 となります。したがって、256 = 4*64、または一般に GLOBAL_BATCH_SIZE = num_replicas_in_sync * BATCH_SIZE_PER_REPLICA があることになります。

各レプリカは、それが得るトレーニングサンプルから損失を計算し、各モデルの重みに関する損失の勾配を計算します。オプティマイザは、これらの勾配をレプリカ全体で加算してから、レプリカごとにモデルの重みのコピーを更新します。

では、tf.distribute.Strategy を使用する場合、どのように損失を計算すればよいのでしょうか。

  • 各レプリカは、それに分散されたすべてのサンプルの予測損失を計算し、結果を加算して、num_replicas_in_sync * BATCH_SIZE_PER_REPLICA または GLOBAL_BATCH_SIZE で除算します。
  • 各レプリカは正則化損失を計算し、それを num_replicas_in_sync で除算します。

非分散型トレーニングに桑ベルト、すべてのレプリカ単位の損失項は 1/num_replicas_in_sync の計数でスケールダウンされます。一方、すべての損失項、または勾配は、オプティマイザが適用する前にレプリカの数で加算されます。実際、各レプリカのオプティマイザは、GLOBAL_BATCH_SIZE による非分散型計算が行われなかったかのようにして、同じ勾配を使用します。これは、分散型と非分散型の Keras Model.fit の動作と同じです。より大きなグローバルバッチサイズによって学習率のスケールアップが可能になるかについて、Keras による分散型トレーニングをご覧ください。

TensorFlow では次のようにします。

  • この縮小とスケーリングは、Keras Model.compileModel.fit で自動的に行われます。

  • If you're writing a custom training loop, as in this tutorial, you should sum the per example losses and divide the sum by the GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) or you can use tf.nn.compute_average_loss which takes the per example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.

  • tf.keras.losses クラスを使用すると(以下の例)、損失の縮小を NONE または SUM のいずれかになるように明示的に指定する必要があります。デフォルトの AUTOSUM_OVER_BATCH_SIZEModel.fit の外では使用できません。

    • AUTO は、分散型のケースで正しくなるようにどの縮小を使用するかを明示的に考える必要があるため、使用できません。
    • SUM_OVER_BATCH_SIZE は、現在、レプリカごとのバッチサイズでのみ除算し、レプリカ数による除算をユーザーが処理しなければならないようになっていますが、見逃す可能性があるため、使用できなくなっています。したがって、ユーザー自身が縮小を明示的に行う必要があります。
  • 空でない Model.losses リストのカスタムトレーニングループを書いている場合は(重みレギュラライザなど)、加算して、レプリカ数で除算する必要があります。これは、tf.nn.scale_regularization_loss 関数を使って行えます。モデルコード自体は、レプリカの数を認識していません。

ただし、モデルは、Layer.add_loss(...)Layer(activity_regularizer=...) などの Keras API によって入力に依存する正則化損失を定義できます。Layer.add_loss(...) の場合、モデリングコードが加算されたサンプルごとの項を tf.math.reduce_mean() などを使ってレプリカ単位(!) のバッチサイズで除算します。

with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction yourself.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions, model_losses):
    per_example_loss = loss_object(labels, predictions)
    loss = tf.nn.compute_average_loss(per_example_loss)
    if model_losses:
      loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    return loss

特殊ケース

高度なユーザーは、以下の特殊ケースについても考慮することをお勧めします。

  • GLOBAL_BATCH_SIZE よりも短い入力バッチが原因で、いくつかの場所で好ましくない例外が発生します。実際には、Dataset.repeat().batch() を使用してエポックの境界をまたぐバッチを許可し、データセットの終了ではなくステップ数でおおよそのエポック数を定義することで、例外を回避することがよくあります。または、Dataset.batch(drop_remainder=True) は、エポックの表記を維持しながら、最後の数個のサンプルを除外します。

説明のために、この例ではより困難なルートを選択し、短いバッチを許可するため、トレーニングエポックごとに各トレーニング サンプルが 1 回だけ含まれます。

どのデノミネーターを tf.nn.compute_average_loss() で使用すればよいでしょうか。

* By default, in the example code above and equivalently in `Keras.fit()`, the sum of prediction losses is divided by `num_replicas_in_sync` times the actual batch size seen on the replica (with empty batches silently ignored). This preserves the balance between the prediction loss on the one hand and the regularization losses on the other hand. It is particularly appropriate for models that use input-dependent regularization losses. Plain L2 regularization just superimposes weight decay onto the gradients of the prediction loss and is less in need of such a balance.
* In practice, many custom training loops pass as a constant Python value into `tf.nn.compute_average_loss(..., global_batch_size=GLOBAL_BATCH_SIZE)` to use it as the denominator. This preserves the relative weighting of training examples between batches. Without it, the smaller denominator in short batches effectively upweights the examples in those. (Before TensorFlow 2.13, this was also needed to avoid NaNs in case some replica received an actual batch size of zero.)

上記で説明するように、いずれのオプションも、短いバッチが回避されるのであれば同等です。

  • 多次元 labels では、各サンプルの予測数全体で per_example_loss を平均化する必要があります。形状が (batch_size, H, W, n_classes)predictions と形状が (batch_size, H, W)labels を持つ入力画像のすべてのピクセルに対する分類タスクがあるとした場合、per_example_loss は次のようにして更新する必要があります: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

注意:損失の形状を確認してくださいtf.losses/tf.keras.lossesの損失関数は、通常、入力の最後の次元の平均を返します。損失クラスはこれらの関数をラップします。 損失クラスのインスタンスを作成するときにreduction=Reduction.NONEを渡すことは、「追加の縮小がない」ことを意味します。[batch, W, H, n_classes]の入力形状の例を使用したカテゴリ損失の場合、n_classes次元が縮小されます。losses.mean_squared_errorまたはlosses.binary_crossentropyのような点ごとの損失の場合、ダミー軸を用いて、[batch, W, H, 1][batch, W, H]に縮小します。ダミー軸がないと、[batch, W, H]は誤って[batch, W]に縮小されます。

損失と精度を追跡するメトリクスを定義する

これらのメトリクスは、テストの損失、トレーニング、テストの精度を追跡します。.result()を使用して、いつでも累積統計を取得できます。

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

トレーニングループ

# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions, model.losses)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704996698.665771   44862 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.6599556803703308, Accuracy: 77.09000396728516, Test Loss: 0.4591159522533417, Test Accuracy: 83.42000579833984
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 2, Loss: 0.4002712368965149, Accuracy: 85.9866714477539, Test Loss: 0.4018000364303589, Test Accuracy: 85.7699966430664
Epoch 3, Loss: 0.34986022114753723, Accuracy: 87.6066665649414, Test Loss: 0.3515324294567108, Test Accuracy: 87.4800033569336
Epoch 4, Loss: 0.32334816455841064, Accuracy: 88.48666381835938, Test Loss: 0.32882973551750183, Test Accuracy: 88.13999938964844
Epoch 5, Loss: 0.30147460103034973, Accuracy: 89.28166961669922, Test Loss: 0.31937628984451294, Test Accuracy: 88.80000305175781
Epoch 6, Loss: 0.28456664085388184, Accuracy: 89.92833709716797, Test Loss: 0.30238109827041626, Test Accuracy: 89.41000366210938
Epoch 7, Loss: 0.268964558839798, Accuracy: 90.45833587646484, Test Loss: 0.30147919058799744, Test Accuracy: 89.45999908447266
Epoch 8, Loss: 0.2548156976699829, Accuracy: 90.91500091552734, Test Loss: 0.29818081855773926, Test Accuracy: 89.13999938964844
Epoch 9, Loss: 0.24469958245754242, Accuracy: 91.3550033569336, Test Loss: 0.28745248913764954, Test Accuracy: 89.8800048828125
Epoch 10, Loss: 0.23541080951690674, Accuracy: 91.68000030517578, Test Loss: 0.2748521864414215, Test Accuracy: 89.94000244140625

上記の例における注意点

  • for x in ... コンストラクトを使用して、train_dist_datasettest_dist_dataset をイテレーションします。
  • スケーリングされた損失は distributed_train_step の戻り値です。この値は tf.distribute.Strategy.reduce 呼び出しを使用してレプリカ間で集約され、次に tf.distribute.Strategy.reduce 呼び出しの戻り値を加算してバッチ間で集約されます。
  • tf.keras.Metrics は、tf.distribute.Strategy.run によって実行される train_step および test_step 内で更新する必要があります。
  • tf.distribute.Strategy.run はストラテジー内の各ローカルレプリカの結果を返し、この結果の使用方法は多様です。reduce で、集約された値を取得することができます。また、tf.distribute.Strategy.experimental_local_results を実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。

最新のチェックポイントを復元してテストする

tf.distribute.Strategyでチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))
Accuracy after restoring the saved model without strategy: 89.8800048828125

データセットのイテレーションの代替方法

イテレータを使用する

データセット全体ではなく、任意のステップ数のイテレーションを行う場合は、iter 呼び出しを使用してイテレータを作成し、そのイテレータ上で next を明示的に呼び出すことができます。tf.function の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し tf.function の外側のデータセットのイテレーションを実行する小さなスニペットを示します。

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.22530433535575867, Accuracy: 92.0703125
Epoch 10, Loss: 0.21155035495758057, Accuracy: 92.421875
Epoch 10, Loss: 0.23270802199840546, Accuracy: 91.09375
Epoch 10, Loss: 0.2111983597278595, Accuracy: 92.421875
Epoch 10, Loss: 0.2315395325422287, Accuracy: 91.7578125
Epoch 10, Loss: 0.22891399264335632, Accuracy: 91.640625
Epoch 10, Loss: 0.23187729716300964, Accuracy: 91.4453125
Epoch 10, Loss: 0.23954670131206512, Accuracy: 91.4453125
Epoch 10, Loss: 0.21727390587329865, Accuracy: 92.5390625
Epoch 10, Loss: 0.2208312749862671, Accuracy: 92.3046875

tf.function 内でイテレーションする

for x in ... コンストラクトを使用して、または上記で行ったようにイテレータを作成して、tf.function 内で train_dist_dataset の入力全体をイテレートすることもできます。以下の例では、1 エポックのトレーニングを @tf.function デコレータでラップし、関数内で train_dist_dataset をイテレーションする方法を示します。

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

  train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:462: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
Epoch 1, Loss: 0.22321248054504395, Accuracy: 92.12667083740234
Epoch 2, Loss: 0.21352693438529968, Accuracy: 92.40499877929688
Epoch 3, Loss: 0.2031208723783493, Accuracy: 92.74666595458984
Epoch 4, Loss: 0.19914129376411438, Accuracy: 92.96666717529297
Epoch 5, Loss: 0.18742477893829346, Accuracy: 93.40333557128906
Epoch 6, Loss: 0.18182916939258575, Accuracy: 93.59166717529297
Epoch 7, Loss: 0.17676156759262085, Accuracy: 93.77666473388672
Epoch 8, Loss: 0.16894836723804474, Accuracy: 94.06333923339844
Epoch 9, Loss: 0.1639356017112732, Accuracy: 94.2683334350586
Epoch 10, Loss: 0.1561119258403778, Accuracy: 94.56999969482422

レプリカ間でトレーニング損失を追跡する

注意: 一般的なルールとして、サンプルごとの値の追跡にはtf.keras.Metricsを使用し、レプリカ内で集約された値を避ける必要があります。

損失スケーリングの計算が実行されるため、レプリカ間でトレーニング損失を追跡するために tf.keras.metrics.Mean を使用することは推奨されません。

例えば、次のような特徴を持つトレーニングジョブを実行するとします。

  • レプリカ 2 つ
  • 各レプリカで 2 つのサンプルを処理
  • 結果の損失値 : 各レプリカで [2, 3] および [4, 5]
  • グローバルバッチサイズ = 4

損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、(2 + 3) / 4 = 1.25および(4 + 5) / 4 = 2.25となります。

tf.keras.metrics.Mean を使用して 2 つのレプリカ間の損失を追跡すると、異なる結果が得られます。この例では、total は 3.50、count は 2 となるため、メトリックで result() が呼び出されると、total/count = 1.75 となります。tf.keras.Metrics で計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。

ガイドと例

カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。

  1. 分散型トレーニングガイド
  2. MirroredStrategyを使用した DenseNet の例。
  3. MirroredStrategyTPUStrategyを使用してトレーニングされた BERT の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。
  4. MirroredStrategy を使用してトレーニングされ、keras_use_ctl フラグを使用した有効化が可能な、NCF の例。
  5. MirroredStrategyを使用してトレーニングされた、NMT の例。

その他の例は、分散型ストラテジーガイドの「例とチュートリアル」に記載されています。

次のステップ