tf.distribute.Strategy を使用したカスタムトレーニング

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

このチュートリアルでは、tf.distribute.Strategyをカスタムトレーニングループで使用する方法を示します。Fashion MNIST データセットで単純な CNN モデルをトレーニングします。Fashion MNIST データセットには、サイズ 28×28 のトレーニング画像 60,000 枚とサイズ 28×28 のテスト画像 10,000 枚が含まれています。

モデルのトレーニングにカスタムトレーニングループを使用するのは、柔軟性があり、トレーニングを容易に制御できるからです。それに加えて、モデルとトレーニングループのデバッグも容易になります。

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2022-08-08 21:42:38.094554: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-08 21:42:38.856546: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:42:38.856806: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-08 21:42:38.856818: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2.10.0-rc0

Fashion MNIST データセットをダウンロードする

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step

変数とグラフを分散させるストラテジーを作成する

tf.distribute.MirroredStrategyストラテジーはどのように機能するのでしょう?

  • 全ての変数とモデルグラフはレプリカ上に複製されます。
  • 入力はレプリカ全体に均等に分散されます。
  • 各レプリカは受け取った入力の損失と勾配を計算します。
  • 勾配は加算して全てのレプリカ間で同期されます。
  • 同期後、各レプリカ上の変数のコピーにも同じ更新が行われます。

注意: 下記のコードは全て 1 つのスコープ内に入れることができます。説明しやすいように、ここでは幾つかのコードセルに分割しています。

# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4

入力パイプラインをセットアップする

グラフと変数をプラットフォームに依存しない SavedModel 形式にエクスポートします。モデルが保存された後、スコープの有無に関わらずそれを読み込むことができます。

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

データセットを作成して、それを配布します。

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-08-08 21:42:45.954268: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 60000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

2022-08-08 21:42:46.007276: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:3"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

モデルを作成する

tf.keras.Sequentialを使用してモデルを作成します。これは Model Subclassing API を使用しても作成できます。

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

損失関数を定義する

通常、GPU/CPU を 1 つ搭載した単一のマシンでは、損失は入力バッチ内の例の数で除算されます。

では、tf.distribute.Strategy を使用する場合、どのように損失を計算すればよいのでしょうか。

  • 例えば、4 つの GPU と 64 のバッチサイズがあるとします。1 つの入力バッチは(4 つの GPU の)レプリカに分散されるので、各レプリカはサイズ 16 の入力を取得します。

  • 各レプリカのモデルは、それぞれの入力でフォワードパスを実行し、損失を計算します。ここでは、損失をそれぞれの入力の例の数(BATCH_SIZE_PER_REPLICA = 16)で除算するのではなく、損失を GLOBAL_BATCH_SIZE (64) で除算する必要があります。

なぜそうするのでしょう?

  • 勾配を各レプリカで計算した後にそれらを加算してレプリカ間で同期するためです。

TensorFlow では次のようにします。

  • このチュートリアルにもあるように、カスタムトレーニングループを書く場合は、サンプルごとの損失を加算し、その合計を GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) で除算する必要があります。または、tf.nn.compute_average_loss を使用することも可能です。これはサンプルごとの損失、オプションのサンプルの重み、そしてGLOBAL_BATCH_SIZE を引数として取り、スケーリングされた損失を返します。

  • モデルで正則化損失を使用している場合は、損失値をレプリカの数でスケーリングする必要があります。これを行うには、tf.nn.scale_regularization_loss 関数を使用します。

  • tf.reduce_mean の使用は推奨されません。これを使用すると、損失がレプリカごとの実際のバッチサイズで除算され、ステップごとに変化する場合があります。

  • この縮小とスケーリングは、kerasmodel.compilemodel.fit で自動的に行われます。

  • 以下の例のように tf.keras.losses を使用する場合、損失削減は NONE または SUM のいずれかになるよう、明示的に指定する必要があります。AUTO および SUM_OVER_BATCH_SIZEtf.distribute.Strategy との併用は許可されません。AUTO は、分散型のケースでユーザーがどの削減を正しいと確認したいか明示的に考える必要があるため、許可されていません。SUM_OVER_BATCH_SIZE は、現時点ではレプリカのバッチサイズのみで除算され、レプリカの数に基づく除算はユーザーに任されており見落としがちなため、許可されていません。そのため、その代わりにユーザー自身が明示的に削減を行うようにお願いしています。

  • もし labels が多次元である場合は、各サンプルの要素数全体で per_example_loss を平均化します。例えば、predictions の形状が (batch_size, H, W, n_classes) で、labels(batch_size, H, W) の場合、per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32) のように per_example_loss を更新する必要があります。

    注意:損失の形状を確認してくださいtf.losses/tf.keras.lossesの損失関数は、通常、入力の最後の次元の平均を返します。損失クラスはこれらの関数をラップします。 損失クラスのインスタンスを作成するときにreduction=Reduction.NONEを渡すことは、「追加の縮小がない」ことを意味します。[batch, W, H, n_classes]の入力形状の例を使用したカテゴリ損失の場合、n_classes次元が縮小されます。losses.mean_squared_errorまたはlosses.binary_crossentropyのような点ごとの損失の場合、ダミー軸を用いて、[batch, W, H, 1][batch, W, H]に縮小します。ダミー軸がないと、[batch, W, H]は誤って[batch, W]に縮小されます。

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

損失と精度を追跡するメトリクスを定義する

これらのメトリクスは、テストの損失、トレーニング、テストの精度を追跡します。.result()を使用して、いつでも累積統計を取得できます。

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

トレーニングループ

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
Epoch 1, Loss: 0.6534212231636047, Accuracy: 76.38333129882812, Test Loss: 0.5108022689819336, Test Accuracy: 80.36000061035156
Epoch 2, Loss: 0.40542471408843994, Accuracy: 85.3133316040039, Test Loss: 0.3908401131629944, Test Accuracy: 85.95999908447266
Epoch 3, Loss: 0.3536377251148224, Accuracy: 87.32333374023438, Test Loss: 0.37182164192199707, Test Accuracy: 86.6199951171875
Epoch 4, Loss: 0.31891700625419617, Accuracy: 88.51000213623047, Test Loss: 0.33639442920684814, Test Accuracy: 87.95999908447266
Epoch 5, Loss: 0.30060118436813354, Accuracy: 89.13999938964844, Test Loss: 0.3324870467185974, Test Accuracy: 88.2300033569336
Epoch 6, Loss: 0.2808188796043396, Accuracy: 89.8949966430664, Test Loss: 0.3083915412425995, Test Accuracy: 88.9000015258789
Epoch 7, Loss: 0.26761382818222046, Accuracy: 90.30166625976562, Test Loss: 0.29076263308525085, Test Accuracy: 89.70999908447266
Epoch 8, Loss: 0.2525147497653961, Accuracy: 90.8699951171875, Test Loss: 0.30839911103248596, Test Accuracy: 88.75
Epoch 9, Loss: 0.23873022198677063, Accuracy: 91.38500213623047, Test Loss: 0.28844621777534485, Test Accuracy: 89.36000061035156
Epoch 10, Loss: 0.22912032902240753, Accuracy: 91.6866683959961, Test Loss: 0.2780952453613281, Test Accuracy: 89.96000671386719

上記の例における注意点:

  • for文(for x in ...)を使用して、train_dist_datasettest_dist_datasetに対してイテレーションしています。
  • スケーリングされた損失は distributed_train_step の戻り値です。この値は tf.distribute.Strategy.reduce 呼び出しを使用してレプリカ間で集約され、次に tf.distribute.Strategy.reduce 呼び出しの戻り値を加算してバッチ間で集約されます。
  • tf.keras.Metrics は、train_steptf.distribute.Strategy.run によって実行される test_step 内で更新される必要があります。*tf.distribute.Strategy.run はストラテジー内の各ローカルレプリカの結果を返し、この結果の使用方法は多様です。tf.distribute.Strategy.reduce を実行して、集約された値を取得することができます。また、tf.distribute.Strategy.experimental_local_results を実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。

最新のチェックポイントを復元してテストする

tf.distribute.Strategyでチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 89.36000061035156

データセットのイテレーションの代替方法

イテレータを使用する

データセット全体ではなく、任意のステップ数のイテレーションを行いたい場合は、iter 呼び出しを使用してイテレータを作成し、そのイテレータ上で next を明示的に呼び出すことができます。tf.function の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し tf.function の外側のデータセットのイテレーションを実行する小さなスニペットを示します。

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.24030835926532745, Accuracy: 90.9765625
Epoch 10, Loss: 0.22981488704681396, Accuracy: 91.3671875
Epoch 10, Loss: 0.23154428601264954, Accuracy: 91.8359375
Epoch 10, Loss: 0.19440846145153046, Accuracy: 93.203125
Epoch 10, Loss: 0.1993107795715332, Accuracy: 93.0859375
Epoch 10, Loss: 0.21297237277030945, Accuracy: 92.1484375
Epoch 10, Loss: 0.22526805102825165, Accuracy: 91.875
Epoch 10, Loss: 0.2222786247730255, Accuracy: 92.3828125
Epoch 10, Loss: 0.21100978553295135, Accuracy: 92.03125
Epoch 10, Loss: 0.21175627410411835, Accuracy: 92.421875

tf.function 内でイテレーションする

tf.function の内側で for 文(for x in ...)を使用して、あるいは上記で行ったようにイテレータを作成して、入力train_dist_dataset全体をイテレーションすることもできます。次の例では、トレーニングの 1 つのエポックを tf.function でラップし、関数内でtrain_dist_datasetをイテレーションする方法を示します。

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1
Epoch 1, Loss: 0.21878942847251892, Accuracy: 92.05333709716797
Epoch 2, Loss: 0.20524615049362183, Accuracy: 92.46666717529297
Epoch 3, Loss: 0.19754621386528015, Accuracy: 92.84666442871094
Epoch 4, Loss: 0.1915532350540161, Accuracy: 92.94000244140625
Epoch 5, Loss: 0.182777538895607, Accuracy: 93.26000213623047
Epoch 6, Loss: 0.17301584780216217, Accuracy: 93.67666625976562
Epoch 7, Loss: 0.16668958961963654, Accuracy: 93.94166564941406
Epoch 8, Loss: 0.16198676824569702, Accuracy: 93.9800033569336
Epoch 9, Loss: 0.1535317748785019, Accuracy: 94.38499450683594
Epoch 10, Loss: 0.14288969337940216, Accuracy: 94.76499938964844

レプリカ間でトレーニング損失を追跡する

注意: 一般的なルールとして、サンプルごとの値の追跡にはtf.keras.Metricsを使用し、レプリカ内で集約された値を避ける必要があります。

tf.metrics.Meanを使用すると損失スケーリングが計算されるため、異なるレプリカ間のトレーニング損失の追跡には推奨できません

例えば、次のような特徴を持つトレーニングジョブを実行するとします。

  • レプリカ 2 つ
  • 各レプリカで 2 つのサンプルを処理
  • 結果の損失値 : 各レプリカで [2, 3] および [4, 5]
  • グローバルバッチサイズ = 4

損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、(2 + 3) / 4 = 1.25および(4 + 5) / 4 = 2.25となります。

tf.metrics.Meanを使用して 2 つのレプリカ間の損失を追跡する場合、結果は異なります。この例では、totalは 3.50、countは 2 となり、メトリック上でresult()が呼び出された場合の結果はtotal/count = 1.75 となります。tf.keras.Metricsで計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。

ガイドと例

カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。

  1. 分散型トレーニングガイド
  2. MirroredStrategyを使用した DenseNet の例。
  3. MirroredStrategyTPUStrategyを使用してトレーニングされた BERT の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。
  4. MirroredStrategy を使用してトレーニングされ、keras_use_ctl フラグを使用した有効化が可能な、NCF の例。
  5. MirroredStrategyを使用してトレーニングされた、NMT の例。

分散ストラテジーガイドには他の例も記載されています。

次のステップ