![]() |
![]() |
![]() |
![]() |
이 예제는 사용자 정의 훈련 루프(custom training loops)와 함께 tf.distribute.Strategy
를 사용하는 법을 보여드립니다. 우리는 간단한 CNN 모델을 패션 MNIST 데이터셋에 대해 훈련을 할 것입니다. 패션 MNIST 데이터셋은 60000개의 28 x 28 크기의 훈련 이미지들과 10000개의 28 x 28 크기의 테스트 이미지들을 포함하고 있습니다.
이 예제는 유연성을 높이고, 훈련을 더 잘 제어할 수 있도록 사용자 정의 훈련 루프를 사용합니다. 또한, 사용자 훈련 루프를 사용하는 것은 모델과 훈련 루프를 디버깅하기 쉬워집니다.
# Import TensorFlow
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
2.6.0
패션 MNIST 데이터셋 다운로드
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
변수와 그래프를 분산하는 전략 만들기
tf.distribute.MirroredStrategy
전략이 어떻게 동작할까요?
- 모든 변수와 모델 그래프는 장치(replicas, 다른 문서에서는 replica가 분산 훈련에서 장치 등에 복제된 모델을 의미하는 경우가 있으나 이 문서에서는 장치 자체를 의미합니다)에 복제됩니다.
- 입력은 장치에 고르게 분배되어 들어갑니다.
- 각 장치는 주어지는 입력에 대해서 손실(loss)과 그래디언트를 계산합니다.
- 그래디언트들을 전부 더함으로써 모든 장치들 간에 그래디언트들이 동기화됩니다.
- 동기화된 후에, 동일한 업데이트가 각 장치에 있는 변수의 복사본(copies)에 동일하게 적용됩니다.
노트: 하나의 범위를 지정해서 모든 코드를 집어넣을 수 있습니다. 자, 같이 살펴보시죠!
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
입력 파이프라인 설정하기
플랫폼에 무관한 SavedModel 형식으로 그래프와 변수들을 내보냅니다. 모델을 저장한 후에는 범위 없이 불러올 수도 있고, 범위를 지정하여 불러올 수도 있습니다.
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
데이터세트를 만들고 배포합니다.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2021-08-25 20:38:13.157244: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_UINT8 } } } attr { key: "output_shapes" value { list { shape { dim { size: 28 } dim { size: 28 } dim { size: 1 } } shape { } } } } 2021-08-25 20:38:13.199332: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:695] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_UINT8 } } } attr { key: "output_shapes" value { list { shape { dim { size: 28 } dim { size: 28 } dim { size: 1 } } shape { } } } }
모델 만들기
tf.keras.Sequential
을 사용해서 모델을 생성합니다. Model Subclassing API로도 모델 생성을 할 수 있습니다.
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
손실 함수 정의하기
일반적으로 1 GPU/CPU가 있는 단일 시스템에서 손실은 입력 배치의 예제 수로 나뉩니다.
그렇다면 tf.distribute.Strategy
를 사용할 때 손실을 어떻게 계산해야 할까요?
예를 들어 GPU가 4개 있고 배치 크기가 64라고 가정해 보겠습니다. 하나의 입력 배치가 전체 복제본(4개의 GPU)에 걸쳐 분배되고 각 복제본은 크기 16의 입력을 받습니다.
각 복제본의 모델은 해당 입력으로 순방향 전달을 수행하고 손실을 계산합니다. 이제 손실을 해당 입력의 예제 수로 나누는 대신(BATCH_SIZE_PER_REPLICA = 16), 손실을 GLOBAL_BATCH_SIZE(64)로 나누어야 합니다.
이렇게 하는 이유는 무엇일까요?
- 각 복제본에서 그래디언트가 계산된 후 이를 합산하여 전체 복제본에 걸쳐 동기화되기 때문에 이렇게 해야 합니다.
TensorFlow에서 이 작업을 어떻게 수행할까요?
이 튜토리얼에서와 같이 사용자 지정 훈련 루프를 작성하는 경우 예제당 손실을 합산하고 합계를 GLOBAL_BATCH_SIZE로 나누어야 합니다:
scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)
또는tf.nn.compute_average_loss
를 사용하여 예제당 손실, 선택적 샘플 가중치, GLOBAL_BATCH_SIZE를 인수로 사용하고 조정된 손실을 반환할 수 있습니다.모델에서 정규화 손실을 사용하는 경우 복제본 수에 따라 손실 값을 확장해야 합니다.
tf.nn.scale_regularization_loss
함수를 사용하여 이를 수행할 수 있습니다.tf.reduce_mean
을 사용하는 것은 권장하지 않습니다. 이렇게 하면 손실이 실제 복제본 배치 크기별로 나눠지는데, 이는 단계별로 다를 수 있습니다.이 축소 및 크기 조정은 keras
model.compile
및model.fit
에서 자동으로 수행됩니다.tf.keras.losses
클래스를 사용하는 경우(아래 예와 같이) 손실 축소는NONE
또는SUM
중 하나로 명시적으로 지정되어야 합니다.AUTO
및SUM_OVER_BATCH_SIZE
는tf.distribute.Strategy
와 함께 사용할 때 허용되지 않습니다.AUTO
는 사용자가 어떤 축소가 필요한지 명시적으로 생각해야 하므로(분산된 경우에 이러한 축소가 올바른지 확인하기 위해) 허용되지 않습니다.SUM_OVER_BATCH_SIZE
는 현재 복제본 배치 크기별로만 나누고 복제본 수로 나누는 일은 사용자에게 맡기는 데, 이를 쉽게 놓칠 수 있다는 점 때문에 허용되지 않습니다. 따라서 사용자가 직접 명시적으로 축소를 수행할 것을 권장합니다.labels
이 다차원인 경우 각 샘플의 요소 수에 걸쳐per_example_loss
의 평균을 구합니다. 예를 들어predictions
의 형상이(batch_size, H, W, n_classes)
이고labels
이(batch_size, H, W)
인 경우, 다음과 같이per_example_loss
를 업데이트해야 합니다:per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
주의: 손실의 형상을 확인하세요.
tf.losses
/tf.keras.losses
의 손실 함수는 일반적으로 입력의 마지막 차원에 대한 평균을 반환합니다. 손실 클래스는 이러한 함수를 래핑합니다. 손실 클래스의 인스턴스를 생성할 때reduction=Reduction.NONE
을 전달하는 것은 "추가적인 축소가 없음"을 의미합니다.[batch, W, H, n_classes]
의 예제 입력 형상을 갖는 범주형 손실의 경우,n_classes
차원이 축소됩니다.losses.mean_squared_error
또는losses.binary_crossentropy
와 같은 포인트별 손실의 경우, 더미 축을 포함시켜[batch, W, H, 1]
이[batch, W, H]
로 축소되도록 합니다. 더미 축이 없으면[batch, W, H]
가[batch, W]
로 잘못 축소됩니다.
with strategy.scope():
# Set reduction to `none` so we can do the reduction afterwards and divide by
# global batch size.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)
def compute_loss(labels, predictions):
per_example_loss = loss_object(labels, predictions)
return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
손실과 정확도를 기록하기 위한 지표 정의하기
이 지표(metrics)는 테스트 손실과 훈련 정확도, 테스트 정확도를 기록합니다. .result()
를 사용해서 누적된 통계값들을 언제나 볼 수 있습니다.
with strategy.scope():
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
훈련 루프
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
model = create_model()
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
images, labels = inputs
with tf.GradientTape() as tape:
predictions = model(images, training=True)
loss = compute_loss(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_accuracy.update_state(labels, predictions)
return loss
def test_step(inputs):
images, labels = inputs
predictions = model(images, training=False)
t_loss = loss_object(labels, predictions)
test_loss.update_state(t_loss)
test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)
@tf.function
def distributed_test_step(dataset_inputs):
return strategy.run(test_step, args=(dataset_inputs,))
for epoch in range(EPOCHS):
# TRAIN LOOP
total_loss = 0.0
num_batches = 0
for x in train_dist_dataset:
total_loss += distributed_train_step(x)
num_batches += 1
train_loss = total_loss / num_batches
# TEST LOOP
for x in test_dist_dataset:
distributed_test_step(x)
if epoch % 2 == 0:
checkpoint.save(checkpoint_prefix)
template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
"Test Accuracy: {}")
print (template.format(epoch+1, train_loss,
train_accuracy.result()*100, test_loss.result(),
test_accuracy.result()*100))
test_loss.reset_states()
train_accuracy.reset_states()
test_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:4907: UserWarning: "`sparse_categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a sigmoid or softmax activation and thus does not represent logits. Was this intended?" '"`sparse_categorical_crossentropy` received `from_logits=True`, but ' Epoch 1, Loss: 0.5048843026161194, Accuracy: 81.97167205810547, Test Loss: 0.37391185760498047, Test Accuracy: 86.55999755859375 Epoch 2, Loss: 0.3322712481021881, Accuracy: 87.94833374023438, Test Loss: 0.3292912542819977, Test Accuracy: 88.12999725341797 Epoch 3, Loss: 0.2867870628833771, Accuracy: 89.44000244140625, Test Loss: 0.30785882472991943, Test Accuracy: 88.84000396728516 Epoch 4, Loss: 0.2577211558818817, Accuracy: 90.625, Test Loss: 0.2837044298648834, Test Accuracy: 89.94000244140625 Epoch 5, Loss: 0.23549160361289978, Accuracy: 91.29166412353516, Test Loss: 0.2737273573875427, Test Accuracy: 90.13999938964844 Epoch 6, Loss: 0.21584177017211914, Accuracy: 92.04999542236328, Test Loss: 0.2593098282814026, Test Accuracy: 90.94000244140625 Epoch 7, Loss: 0.19817812740802765, Accuracy: 92.74166870117188, Test Loss: 0.25772857666015625, Test Accuracy: 90.75 Epoch 8, Loss: 0.18283750116825104, Accuracy: 93.17666625976562, Test Loss: 0.2609715163707733, Test Accuracy: 90.80999755859375 Epoch 9, Loss: 0.16872815787792206, Accuracy: 93.68499755859375, Test Loss: 0.2585345506668091, Test Accuracy: 90.88999938964844 Epoch 10, Loss: 0.15464162826538086, Accuracy: 94.35166931152344, Test Loss: 0.28132060170173645, Test Accuracy: 90.48999786376953
위의 예제에서 주목해야 하는 부분
- 이 예제는
train_dist_dataset
과test_dist_dataset
을for x in ...
구조를 통해서 반복합니다. - 스케일이 조정된 손실은
distributed_train_step
의 반환값입니다.tf.distribute.Strategy.reduce
호출을 사용해서 장치들 간의 스케일이 조정된 손실 값을 전부 합칩니다. 그리고 나서tf.distribute.Strategy.reduce
반환 값을 더하는 식으로 배치 간의 손실을 모읍니다. tf.keras.Metrics
는tf.distribute.Strategy.run
에 의해서 실행되는train_step
과test_step
함수 안에서 업데이트되어야 합니다.*tf.distribute.Strategy.run
은 전략의 각 로컬 복제본에서 결과를 반환하며, 이 결과를 소비하는 여러 방법이 있습니다.tf.distribute.Strategy.reduce
를 수행하여 집계된 값을 얻을 수 있습니다.tf.distribute.Strategy.experimental_local_results
를 수행하여 결과에 포함된 값 목록을 로컬 복제당 하나씩 가져올 수도 있습니다.
최신 체크포인트를 불러와서 테스트하기
tf.distribute.Strategy
를 사용해서 체크포인트가 만들어진 모델은 전략 사용 여부에 상관없이 불러올 수 있습니다.
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
predictions = new_model(images, training=False)
eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
for images, labels in test_dataset:
eval_step(images, labels)
print ('전략을 사용하지 않고, 저장된 모델을 복원한 후의 정확도: {}'.format(
eval_accuracy.result()*100))
전략을 사용하지 않고, 저장된 모델을 복원한 후의 정확도: 90.88999938964844
데이터셋에 대해 반복작업을 하는 다른 방법들
반복자(iterator)를 사용하기
만약 주어진 스텝의 수에 따라서 반복하기 원하면서 전체 데이터셋을 보는 것을 원치 않는다면, iter
를 호출하여 반복자를 만들 수 있습니다. 그 다음 명시적으로 next
를 호출합니다. 또한, tf.funtion
내부 또는 외부에서 데이터셋을 반복하도록 설정 할 수 있습니다. 다음은 반복자를 사용하여 tf.function
외부에서 데이터셋을 반복하는 코드 예제입니다.
for _ in range(EPOCHS):
total_loss = 0.0
num_batches = 0
train_iter = iter(train_dist_dataset)
for _ in range(10):
total_loss += distributed_train_step(next(train_iter))
num_batches += 1
average_train_loss = total_loss / num_batches
template = ("Epoch {}, Loss: {}, Accuracy: {}")
print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
train_accuracy.reset_states()
Epoch 10, Loss: 0.1242208480834961, Accuracy: 96.25 Epoch 10, Loss: 0.12924882769584656, Accuracy: 94.53125 Epoch 10, Loss: 0.1058281883597374, Accuracy: 96.25 Epoch 10, Loss: 0.13499769568443298, Accuracy: 95.0 Epoch 10, Loss: 0.10734955221414566, Accuracy: 96.40625 Epoch 10, Loss: 0.15909847617149353, Accuracy: 94.0625 Epoch 10, Loss: 0.13288195431232452, Accuracy: 95.15625 Epoch 10, Loss: 0.12446471303701401, Accuracy: 95.9375 Epoch 10, Loss: 0.13399545848369598, Accuracy: 94.84375 Epoch 10, Loss: 0.14339366555213928, Accuracy: 95.0
tf.function 내부에서 반복하기
전체 입력 train_dist_dataset
에 대해서 tf.function
내부에서 for x in ...
생성자를 사용함으로써 반복을 하거나, 위에서 사용했던 것처럼 반복자를 사용함으로써 반복을 할 수 있습니다. 아래의 예제에서는 tf.function
로 한 훈련의 에포크를 감싸고 그 함수에서 train_dist_dataset
를 반복하는 것을 보여 줍니다.
@tf.function
def distributed_train_epoch(dataset):
total_loss = 0.0
num_batches = 0
for x in dataset:
per_replica_losses = strategy.run(train_step, args=(x,))
total_loss += strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
num_batches += 1
return total_loss / tf.cast(num_batches, dtype=tf.float32)
for epoch in range(EPOCHS):
train_loss = distributed_train_epoch(train_dist_dataset)
template = ("Epoch {}, Loss: {}, Accuracy: {}")
print (template.format(epoch+1, train_loss, train_accuracy.result()*100))
train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " Epoch 1, Loss: 0.1427239030599594, Accuracy: 94.64167022705078 Epoch 2, Loss: 0.13270048797130585, Accuracy: 95.09666442871094 Epoch 3, Loss: 0.12277950346469879, Accuracy: 95.52166748046875 Epoch 4, Loss: 0.11346586793661118, Accuracy: 95.75 Epoch 5, Loss: 0.10177541524171829, Accuracy: 96.23666381835938 Epoch 6, Loss: 0.09605162590742111, Accuracy: 96.413330078125 Epoch 7, Loss: 0.0875953882932663, Accuracy: 96.82666778564453 Epoch 8, Loss: 0.08049703389406204, Accuracy: 97.08833312988281 Epoch 9, Loss: 0.0770588144659996, Accuracy: 97.1500015258789 Epoch 10, Loss: 0.0700046569108963, Accuracy: 97.37999725341797
장치 간의 훈련 손실 기록하기
노트: 일반적인 규칙으로, tf.keras.Metrics
를 사용하여 샘플당 손실 값을 기록하고 장치 내부에서 값이 합쳐지는 것을 피해야 합니다.
tf.metrics.Mean
을 사용하여 여러 장치의 훈련 손실을 기록하는 것을 추천하지 않습니다. 왜냐하면 손실의 스케일을 조정하는 계산이 수행되기 때문입니다.
예를 들어, 다음과 같은 조건의 훈련을 수행한다고 합시다.
- 두개의 장치
- 두개의 샘플들이 각 장치에 의해 처리됩니다.
- 손실 값을 산출합니다: 각각의 장치에 대해 [2, 3]과 [4, 5]
- Global batch size = 4
손실의 스케일 조정을 하면, 손실 값을 더하고 전역 배치 크기로 나누어 각 장치에 대한 샘플당 손실값을 계산할 수 있습니다. 이 경우에는 (2 + 3) / 4 = 1.25
및 (4 + 5) / 4 = 2.25
입니다.
만약 tf.metrics.Mean
을 사용해서 두 개의 장치에 대해 손실값을 계산한다면, 결과값이 다릅니다. 이 예제에서는, 측정 지표의 result()
가 메서드가 호출될 때 total
이 3.50이고 count
가 2입니다. 결과값은 total/count
가 1.75가 됩니다. tf.keras.Metrics
를 이용해서 계산한 손실값이 추가적인 요인에 의해서 크기조정되며, 이 추가적인 요인은 동기화되는 장치의 개수입니다.
예제와 튜토리얼
사용자 정의 훈련루프를 포함한 분산 전략을 사용하는 몇 가지 예제가 있습니다.
- 분산 교육 가이드
MirroredStrategy
를 사용하는 DenseNet 예제.MirroredStrategy
와TPUStrategy
를 사용해서 훈련하는 BERT 예제. 이 예제는 분산 훈련 중에 어떻게 체크포인트로부터 불러오는지와 어떻게 주기적으로 체크포인트들을 생성해 내는지를 이해하기에 정말 좋습니다.keras_use_ctl flag
를 사용해서 활성화 할 수 있는 MirroredStrategy를 이용해서 훈련되는 NCF 예제MirroredStrategy
을 사용해서 훈련되는 NMT 예제.
더 많은 예제는 여기에 있습니다. Distribution strategy guide