![]() |
![]() |
![]() |
![]() |
このチュートリアルでは、tf.distribute.Strategy
をカスタムトレーニングループで使用する方法を示します。Fashion MNIST データセットで単純な CNN モデルをトレーニングします。Fashion MNIST データセットには、サイズ 28×28 のトレーニング画像 60,000 枚とサイズ 28×28 のテスト画像 10,000 枚が含まれています。
モデルのトレーニングにカスタムトレーニングループを使用するのは、柔軟性があり、トレーニングを容易に制御できるからです。それに加えて、モデルとトレーニングループのデバッグも容易になります。
# Import TensorFlow
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
2022-08-08 21:42:38.094554: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2022-08-08 21:42:38.856546: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-08 21:42:38.856806: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory 2022-08-08 21:42:38.856818: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. 2.10.0-rc0
Fashion MNIST データセットをダウンロードする
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz 29515/29515 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz 26421880/26421880 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz 5148/5148 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz 4422102/4422102 [==============================] - 0s 0us/step
変数とグラフを分散させるストラテジーを作成する
tf.distribute.MirroredStrategy
ストラテジーはどのように機能するのでしょう?
- 全ての変数とモデルグラフはレプリカ上に複製されます。
- 入力はレプリカ全体に均等に分散されます。
- 各レプリカは受け取った入力の損失と勾配を計算します。
- 勾配は加算して全てのレプリカ間で同期されます。
- 同期後、各レプリカ上の変数のコピーにも同じ更新が行われます。
注意: 下記のコードは全て 1 つのスコープ内に入れることができます。説明しやすいように、ここでは幾つかのコードセルに分割しています。
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4
入力パイプラインをセットアップする
グラフと変数をプラットフォームに依存しない SavedModel 形式にエクスポートします。モデルが保存された後、スコープの有無に関わらずそれを読み込むことができます。
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
データセットを作成して、それを配布します。
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-08-08 21:42:45.954268: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_UINT8 } } } attr { key: "_cardinality" value { i: 60000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 28 } dim { size: 28 } dim { size: 1 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_UINT8 } } } } } 2022-08-08 21:42:46.007276: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_UINT8 } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:3" } } attr { key: "output_shapes" value { list { shape { dim { size: 28 } dim { size: 28 } dim { size: 1 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_UINT8 } } } } }
モデルを作成する
tf.keras.Sequential
を使用してモデルを作成します。これは Model Subclassing API を使用しても作成できます。
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
損失関数を定義する
通常、GPU/CPU を 1 つ搭載した単一のマシンでは、損失は入力バッチ内の例の数で除算されます。
では、tf.distribute.Strategy
を使用する場合、どのように損失を計算すればよいのでしょうか。
例えば、4 つの GPU と 64 のバッチサイズがあるとします。1 つの入力バッチは(4 つの GPU の)レプリカに分散されるので、各レプリカはサイズ 16 の入力を取得します。
各レプリカのモデルは、それぞれの入力でフォワードパスを実行し、損失を計算します。ここでは、損失をそれぞれの入力の例の数(BATCH_SIZE_PER_REPLICA = 16)で除算するのではなく、損失を GLOBAL_BATCH_SIZE (64) で除算する必要があります。
なぜそうするのでしょう?
- 勾配を各レプリカで計算した後にそれらを加算してレプリカ間で同期するためです。
TensorFlow では次のようにします。
このチュートリアルにもあるように、カスタムトレーニングループを書く場合は、サンプルごとの損失を加算し、その合計を GLOBAL_BATCH_SIZE:
scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)
で除算する必要があります。または、tf.nn.compute_average_loss
を使用することも可能です。これはサンプルごとの損失、オプションのサンプルの重み、そしてGLOBAL_BATCH_SIZE を引数として取り、スケーリングされた損失を返します。モデルで正則化損失を使用している場合は、損失値をレプリカの数でスケーリングする必要があります。これを行うには、
tf.nn.scale_regularization_loss
関数を使用します。tf.reduce_mean
の使用は推奨されません。これを使用すると、損失がレプリカごとの実際のバッチサイズで除算され、ステップごとに変化する場合があります。この縮小とスケーリングは、keras
model.compile
とmodel.fit
で自動的に行われます。以下の例のように
tf.keras.losses
を使用する場合、損失削減はNONE
またはSUM
のいずれかになるよう、明示的に指定する必要があります。AUTO
およびSUM_OVER_BATCH_SIZE
のtf.distribute.Strategy
との併用は許可されません。AUTO
は、分散型のケースでユーザーがどの削減を正しいと確認したいか明示的に考える必要があるため、許可されていません。SUM_OVER_BATCH_SIZE
は、現時点ではレプリカのバッチサイズのみで除算され、レプリカの数に基づく除算はユーザーに任されており見落としがちなため、許可されていません。そのため、その代わりにユーザー自身が明示的に削減を行うようにお願いしています。もし
labels
が多次元である場合は、各サンプルの要素数全体でper_example_loss
を平均化します。例えば、predictions
の形状が(batch_size, H, W, n_classes)
で、labels
が(batch_size, H, W)
の場合、per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
のようにper_example_loss
を更新する必要があります。注意:損失の形状を確認してください。
tf.losses
/tf.keras.losses
の損失関数は、通常、入力の最後の次元の平均を返します。損失クラスはこれらの関数をラップします。 損失クラスのインスタンスを作成するときにreduction=Reduction.NONE
を渡すことは、「追加の縮小がない」ことを意味します。[batch, W, H, n_classes]
の入力形状の例を使用したカテゴリ損失の場合、n_classes
次元が縮小されます。losses.mean_squared_error
またはlosses.binary_crossentropy
のような点ごとの損失の場合、ダミー軸を用いて、[batch, W, H, 1]
を[batch, W, H]
に縮小します。ダミー軸がないと、[batch, W, H]
は誤って[batch, W]
に縮小されます。
with strategy.scope():
# Set reduction to `none` so we can do the reduction afterwards and divide by
# global batch size.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)
def compute_loss(labels, predictions):
per_example_loss = loss_object(labels, predictions)
return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
損失と精度を追跡するメトリクスを定義する
これらのメトリクスは、テストの損失、トレーニング、テストの精度を追跡します。.result()
を使用して、いつでも累積統計を取得できます。
with strategy.scope():
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
トレーニングループ
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
model = create_model()
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
images, labels = inputs
with tf.GradientTape() as tape:
predictions = model(images, training=True)
loss = compute_loss(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_accuracy.update_state(labels, predictions)
return loss
def test_step(inputs):
images, labels = inputs
predictions = model(images, training=False)
t_loss = loss_object(labels, predictions)
test_loss.update_state(t_loss)
test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)
@tf.function
def distributed_test_step(dataset_inputs):
return strategy.run(test_step, args=(dataset_inputs,))
for epoch in range(EPOCHS):
# TRAIN LOOP
total_loss = 0.0
num_batches = 0
for x in train_dist_dataset:
total_loss += distributed_train_step(x)
num_batches += 1
train_loss = total_loss / num_batches
# TEST LOOP
for x in test_dist_dataset:
distributed_test_step(x)
if epoch % 2 == 0:
checkpoint.save(checkpoint_prefix)
template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
"Test Accuracy: {}")
print (template.format(epoch+1, train_loss,
train_accuracy.result()*100, test_loss.result(),
test_accuracy.result()*100))
test_loss.reset_states()
train_accuracy.reset_states()
test_accuracy.reset_states()
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 Epoch 1, Loss: 0.6534212231636047, Accuracy: 76.38333129882812, Test Loss: 0.5108022689819336, Test Accuracy: 80.36000061035156 Epoch 2, Loss: 0.40542471408843994, Accuracy: 85.3133316040039, Test Loss: 0.3908401131629944, Test Accuracy: 85.95999908447266 Epoch 3, Loss: 0.3536377251148224, Accuracy: 87.32333374023438, Test Loss: 0.37182164192199707, Test Accuracy: 86.6199951171875 Epoch 4, Loss: 0.31891700625419617, Accuracy: 88.51000213623047, Test Loss: 0.33639442920684814, Test Accuracy: 87.95999908447266 Epoch 5, Loss: 0.30060118436813354, Accuracy: 89.13999938964844, Test Loss: 0.3324870467185974, Test Accuracy: 88.2300033569336 Epoch 6, Loss: 0.2808188796043396, Accuracy: 89.8949966430664, Test Loss: 0.3083915412425995, Test Accuracy: 88.9000015258789 Epoch 7, Loss: 0.26761382818222046, Accuracy: 90.30166625976562, Test Loss: 0.29076263308525085, Test Accuracy: 89.70999908447266 Epoch 8, Loss: 0.2525147497653961, Accuracy: 90.8699951171875, Test Loss: 0.30839911103248596, Test Accuracy: 88.75 Epoch 9, Loss: 0.23873022198677063, Accuracy: 91.38500213623047, Test Loss: 0.28844621777534485, Test Accuracy: 89.36000061035156 Epoch 10, Loss: 0.22912032902240753, Accuracy: 91.6866683959961, Test Loss: 0.2780952453613281, Test Accuracy: 89.96000671386719
上記の例における注意点:
- for文(
for x in ...
)を使用して、train_dist_dataset
とtest_dist_dataset
に対してイテレーションしています。 - スケーリングされた損失は
distributed_train_step
の戻り値です。この値はtf.distribute.Strategy.reduce
呼び出しを使用してレプリカ間で集約され、次にtf.distribute.Strategy.reduce
呼び出しの戻り値を加算してバッチ間で集約されます。 tf.keras.Metrics
は、train_step
とtf.distribute.Strategy.run
によって実行されるtest_step
内で更新される必要があります。*tf.distribute.Strategy.run
はストラテジー内の各ローカルレプリカの結果を返し、この結果の使用方法は多様です。tf.distribute.Strategy.reduce
を実行して、集約された値を取得することができます。また、tf.distribute.Strategy.experimental_local_results
を実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。
最新のチェックポイントを復元してテストする
tf.distribute.Strategy
でチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
predictions = new_model(images, training=False)
eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
for images, labels in test_dataset:
eval_step(images, labels)
print ('Accuracy after restoring the saved model without strategy: {}'.format(
eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 89.36000061035156
データセットのイテレーションの代替方法
イテレータを使用する
データセット全体ではなく、任意のステップ数のイテレーションを行いたい場合は、iter
呼び出しを使用してイテレータを作成し、そのイテレータ上で next
を明示的に呼び出すことができます。tf.function の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し tf.function の外側のデータセットのイテレーションを実行する小さなスニペットを示します。
for _ in range(EPOCHS):
total_loss = 0.0
num_batches = 0
train_iter = iter(train_dist_dataset)
for _ in range(10):
total_loss += distributed_train_step(next(train_iter))
num_batches += 1
average_train_loss = total_loss / num_batches
template = ("Epoch {}, Loss: {}, Accuracy: {}")
print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
train_accuracy.reset_states()
Epoch 10, Loss: 0.24030835926532745, Accuracy: 90.9765625 Epoch 10, Loss: 0.22981488704681396, Accuracy: 91.3671875 Epoch 10, Loss: 0.23154428601264954, Accuracy: 91.8359375 Epoch 10, Loss: 0.19440846145153046, Accuracy: 93.203125 Epoch 10, Loss: 0.1993107795715332, Accuracy: 93.0859375 Epoch 10, Loss: 0.21297237277030945, Accuracy: 92.1484375 Epoch 10, Loss: 0.22526805102825165, Accuracy: 91.875 Epoch 10, Loss: 0.2222786247730255, Accuracy: 92.3828125 Epoch 10, Loss: 0.21100978553295135, Accuracy: 92.03125 Epoch 10, Loss: 0.21175627410411835, Accuracy: 92.421875
tf.function 内でイテレーションする
tf.function の内側で for 文(for x in ...
)を使用して、あるいは上記で行ったようにイテレータを作成して、入力train_dist_dataset
全体をイテレーションすることもできます。次の例では、トレーニングの 1 つのエポックを tf.function でラップし、関数内でtrain_dist_dataset
をイテレーションする方法を示します。
@tf.function
def distributed_train_epoch(dataset):
total_loss = 0.0
num_batches = 0
for x in dataset:
per_replica_losses = strategy.run(train_step, args=(x,))
total_loss += strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
num_batches += 1
return total_loss / tf.cast(num_batches, dtype=tf.float32)
for epoch in range(EPOCHS):
train_loss = distributed_train_epoch(train_dist_dataset)
template = ("Epoch {}, Loss: {}, Accuracy: {}")
print (template.format(epoch+1, train_loss, train_accuracy.result()*100))
train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 Epoch 1, Loss: 0.21878942847251892, Accuracy: 92.05333709716797 Epoch 2, Loss: 0.20524615049362183, Accuracy: 92.46666717529297 Epoch 3, Loss: 0.19754621386528015, Accuracy: 92.84666442871094 Epoch 4, Loss: 0.1915532350540161, Accuracy: 92.94000244140625 Epoch 5, Loss: 0.182777538895607, Accuracy: 93.26000213623047 Epoch 6, Loss: 0.17301584780216217, Accuracy: 93.67666625976562 Epoch 7, Loss: 0.16668958961963654, Accuracy: 93.94166564941406 Epoch 8, Loss: 0.16198676824569702, Accuracy: 93.9800033569336 Epoch 9, Loss: 0.1535317748785019, Accuracy: 94.38499450683594 Epoch 10, Loss: 0.14288969337940216, Accuracy: 94.76499938964844
レプリカ間でトレーニング損失を追跡する
注意: 一般的なルールとして、サンプルごとの値の追跡にはtf.keras.Metrics
を使用し、レプリカ内で集約された値を避ける必要があります。
tf.metrics.Mean
を使用すると損失スケーリングが計算されるため、異なるレプリカ間のトレーニング損失の追跡には推奨できません。
例えば、次のような特徴を持つトレーニングジョブを実行するとします。
- レプリカ 2 つ
- 各レプリカで 2 つのサンプルを処理
- 結果の損失値 : 各レプリカで [2, 3] および [4, 5]
- グローバルバッチサイズ = 4
損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、(2 + 3) / 4 = 1.25
および(4 + 5) / 4 = 2.25
となります。
tf.metrics.Mean
を使用して 2 つのレプリカ間の損失を追跡する場合、結果は異なります。この例では、total
は 3.50、count
は 2 となり、メトリック上でresult()
が呼び出された場合の結果はtotal
/count
= 1.75 となります。tf.keras.Metrics
で計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。
ガイドと例
カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。
- 分散型トレーニングガイド
MirroredStrategy
を使用した DenseNet の例。MirroredStrategy
とTPUStrategy
を使用してトレーニングされた BERT の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。MirroredStrategy
を使用してトレーニングされ、keras_use_ctl
フラグを使用した有効化が可能な、NCF の例。MirroredStrategy
を使用してトレーニングされた、NMT の例。
分散ストラテジーガイドには他の例も記載されています。
次のステップ
- 新しい
tf.distribute.Strategy
API を独自のモデルで試してみましょう。 - 他のストラテジーや独自の TensorFlow モデルのパフォーマンス最適化に使用できるツールについての詳細は、ガイドのパフォーマンスのセクションをご覧ください。