케라스를 사용한 분산 훈련

TensorFlow.org에서 보기 구글 코랩(Colab)에서 실행하기 깃허브(GitHub) 소스 보기

개요

tf.distribute.Strategy API는 훈련을 여러 처리 장치들로 분산시키는 것을 추상화한 것입니다. 기존의 모델이나 훈련 코드를 조금만 바꾸어 분산 훈련을 할 수 있게 하는 것이 분산 전략 API의 목표입니다.

이 튜토리얼에서는 tf.distribute.MirroredStrategy를 사용합니다. 이 전략은 동기화된 훈련 방식을 활용하여 한 장비에 있는 여러 개의 GPU로 그래프 내 복제를 수행합니다. 다시 말하자면, 모델의 모든 변수를 각 프로세서에 복사합니다. 그리고 각 프로세서의 그래디언트(gradient)를 올 리듀스(all-reduce)를 사용하여 모읍니다. 그다음 모아서 계산한 값을 각 프로세서의 모델 복사본에 적용합니다.

MirroredStategy는 텐서플로에서 기본으로 제공하는 몇 가지 분산 전략 중 하나입니다. 다른 전략들에 대해서는 분산 전략 가이드를 참고하십시오.

케라스 API

이 예는 모델과 훈련 루프를 만들기 위해 tf.keras API를 사용합니다. 직접 훈련 코드를 작성하는 방법은 사용자 정의 훈련 루프로 분산 훈련하기 튜토리얼을 참고하십시오.

필요한 패키지 가져오기

from __future__ import absolute_import, division, print_function, unicode_literals

# 텐서플로와 텐서플로 데이터셋 패키지 가져오기
!pip install -q tensorflow-gpu==2.0.0-rc1
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os

데이터셋 다운로드

MNIST 데이터셋을 TensorFlow Datasets에서 다운로드받은 후 불러옵니다. 이 함수는 tf.data 형식을 반환합니다.

with_infoTrue로 설정하면 전체 데이터에 대한 메타 정보도 함께 불러옵니다. 이 정보는 info 변수에 저장됩니다. 여기에는 훈련과 테스트 샘플 수를 비롯한 여러가지 정보들이 들어있습니다.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

분산 전략 정의하기

분산과 관련된 처리를 하는 MirroredStrategy 객체를 만듭니다. 이 객체가 컨텍스트 관리자(tf.distribute.MirroredStrategy.scope)도 제공하는데, 이 안에서 모델을 만들어야 합니다.

strategy = tf.distribute.MirroredStrategy()
print('장치의 수: {}'.format(strategy.num_replicas_in_sync))
장치의 수: 1

입력 파이프라인 구성하기

다중 GPU로 모델을 훈련할 때는 배치 크기를 늘려야 컴퓨팅 자원을 효과적으로 사용할 수 있습니다. 기본적으로는 GPU 메모리에 맞추어 가능한 가장 큰 배치 크기를 사용하십시오. 이에 맞게 학습률도 조정해야 합니다.

# 데이터셋 내 샘플의 수는 info.splits.total_num_examples 로도
# 얻을 수 있습니다.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

픽셀의 값은 0~255 사이이므로 0-1 범위로 정규화해야 합니다. 정규화 함수를 정의합니다.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

이 함수를 훈련과 테스트 데이터에 적용합니다. 훈련 데이터 순서를 섞고, 훈련을 위해 배치로 묶습니다.

train_dataset = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

모델 만들기

strategy.scope 컨텍스트 안에서 케라스 모델을 만들고 컴파일합니다.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

콜백 정의하기

여기서 사용하는 콜백은 다음과 같습니다.

  • 텐서보드(TensorBoard): 이 콜백은 텐서보드용 로그를 남겨서, 텐서보드에서 그래프를 그릴 수 있게 해줍니다.
  • 모델 체크포인트(Checkpoint): 이 콜백은 매 에포크(epoch)가 끝난 후 모델을 저장합니다.
  • 학습률 스케줄러: 이 콜백을 사용하면 매 에포크 혹은 배치가 끝난 후 학습률을 바꿀 수 있습니다.

콜백을 추가하는 방법을 보여드리기 위하여 노트북에 학습률을 표시하는 콜백도 추가하겠습니다.

# 체크포인트를 저장할 체크포인트 디렉터리를 지정합니다.
checkpoint_dir = './training_checkpoints'
# 체크포인트 파일의 이름
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# 학습률을 점점 줄이기 위한 함수
# 필요한 함수를 직접 정의하여 사용할 수 있습니다.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# 에포크가 끝날 때마다 학습률을 출력하는 콜백.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\n에포크 {}의 학습률은 {}입니다.'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

훈련과 평가

이제 평소처럼 모델을 학습합시다. 모델의 fit 함수를 호출하고 튜토리얼의 시작 부분에서 만든 데이터셋을 넘깁니다. 이 단계는 분산 훈련 여부와 상관없이 동일합니다.

model.fit(train_dataset, epochs=12, callbacks=callbacks)
Epoch 1/12
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

    938/Unknown - 12s 13ms/step - loss: 0.2026 - accuracy: 0.9410
에포크 1의 학습률은 0.0010000000474974513입니다.
938/938 [==============================] - 12s 13ms/step - loss: 0.2026 - accuracy: 0.9410
Epoch 2/12
937/938 [============================>.] - ETA: 0s - loss: 0.0698 - accuracy: 0.9790
에포크 2의 학습률은 0.0010000000474974513입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0698 - accuracy: 0.9790
Epoch 3/12
923/938 [============================>.] - ETA: 0s - loss: 0.0485 - accuracy: 0.9854
에포크 3의 학습률은 0.0010000000474974513입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0485 - accuracy: 0.9854
Epoch 4/12
936/938 [============================>.] - ETA: 0s - loss: 0.0267 - accuracy: 0.9928
에포크 4의 학습률은 9.999999747378752e-05입니다.
938/938 [==============================] - 8s 8ms/step - loss: 0.0267 - accuracy: 0.9928
Epoch 5/12
927/938 [============================>.] - ETA: 0s - loss: 0.0236 - accuracy: 0.9935
에포크 5의 학습률은 9.999999747378752e-05입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0235 - accuracy: 0.9936
Epoch 6/12
935/938 [============================>.] - ETA: 0s - loss: 0.0216 - accuracy: 0.9946
에포크 6의 학습률은 9.999999747378752e-05입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0216 - accuracy: 0.9946
Epoch 7/12
921/938 [============================>.] - ETA: 0s - loss: 0.0198 - accuracy: 0.9951
에포크 7의 학습률은 9.999999747378752e-05입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0200 - accuracy: 0.9951
Epoch 8/12
932/938 [============================>.] - ETA: 0s - loss: 0.0173 - accuracy: 0.9960
에포크 8의 학습률은 9.999999747378752e-06입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0172 - accuracy: 0.9960
Epoch 9/12
928/938 [============================>.] - ETA: 0s - loss: 0.0170 - accuracy: 0.9961
에포크 9의 학습률은 9.999999747378752e-06입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0170 - accuracy: 0.9962
Epoch 10/12
921/938 [============================>.] - ETA: 0s - loss: 0.0168 - accuracy: 0.9963
에포크 10의 학습률은 9.999999747378752e-06입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0168 - accuracy: 0.9962
Epoch 11/12
921/938 [============================>.] - ETA: 0s - loss: 0.0167 - accuracy: 0.9962
에포크 11의 학습률은 9.999999747378752e-06입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0166 - accuracy: 0.9962
Epoch 12/12
934/938 [============================>.] - ETA: 0s - loss: 0.0164 - accuracy: 0.9962
에포크 12의 학습률은 9.999999747378752e-06입니다.
938/938 [==============================] - 7s 8ms/step - loss: 0.0164 - accuracy: 0.9962

<tensorflow.python.keras.callbacks.History at 0x7f0ae002d748>

아래에서 볼 수 있듯이 체크포인트가 저장되고 있습니다.

# 체크포인트 디렉터리 확인하기
!ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00002
ckpt_10.data-00000-of-00002  ckpt_4.data-00001-of-00002
ckpt_10.data-00001-of-00002  ckpt_4.index
ckpt_10.index            ckpt_5.data-00000-of-00002
ckpt_11.data-00000-of-00002  ckpt_5.data-00001-of-00002
ckpt_11.data-00001-of-00002  ckpt_5.index
ckpt_11.index            ckpt_6.data-00000-of-00002
ckpt_12.data-00000-of-00002  ckpt_6.data-00001-of-00002
ckpt_12.data-00001-of-00002  ckpt_6.index
ckpt_12.index            ckpt_7.data-00000-of-00002
ckpt_1.data-00000-of-00002   ckpt_7.data-00001-of-00002
ckpt_1.data-00001-of-00002   ckpt_7.index
ckpt_1.index             ckpt_8.data-00000-of-00002
ckpt_2.data-00000-of-00002   ckpt_8.data-00001-of-00002
ckpt_2.data-00001-of-00002   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00002
ckpt_3.data-00000-of-00002   ckpt_9.data-00001-of-00002
ckpt_3.data-00001-of-00002   ckpt_9.index
ckpt_3.index

모델의 성능이 어떤지 확인하기 위하여, 가장 최근 체크포인트를 불러온 후 테스트 데이터에 대하여 evaluate를 호출합니다.

평소와 마찬가지로 적절한 데이터셋과 함께 evaluate를 호출하면 됩니다.

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('평가 손실: {}, 평가 정확도: {}'.format(eval_loss, eval_acc))
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

157/157 [==============================] - 3s 18ms/step - loss: 0.0420 - accuracy: 0.9854
평가 손실: 0.042024929147617074, 평가 정확도: 0.9854000210762024

텐서보드 로그를 다운로드받은 후 터미널에서 다음과 같이 텐서보드를 실행하여 훈련 결과를 확인할 수 있습니다.

$ tensorboard --logdir=path/to/log-directory
!ls -sh ./logs
total 4.0K
4.0K train

SavedModel로 내보내기

플랫폼에 무관한 SavedModel 형식으로 그래프와 변수들을 내보냅니다. 모델을 내보낸 후에는, 전략 범위(scope) 없이 불러올 수도 있고, 전략 범위와 함께 불러올 수도 있습니다.

path = 'saved_model/'
tf.keras.experimental.export_saved_model(model, path)
WARNING:tensorflow:From <ipython-input-19-7f22af6799f5>:1: export_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `model.save(..., save_format="tf")` or `tf.keras.models.save_model(..., save_format="tf")`.

WARNING:tensorflow:From <ipython-input-19-7f22af6799f5>:1: export_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `model.save(..., save_format="tf")` or `tf.keras.models.save_model(..., save_format="tf")`.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow_core/python/saved_model/signature_def_utils_impl.py:253: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Train: ['train']

INFO:tensorflow:Signatures INCLUDED in export for Train: ['train']

INFO:tensorflow:Signatures INCLUDED in export for Predict: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

WARNING:tensorflow:Export includes no default signature!

WARNING:tensorflow:Export includes no default signature!

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']

INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

WARNING:tensorflow:Export includes no default signature!

WARNING:tensorflow:Export includes no default signature!

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Regress: None

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:Signatures INCLUDED in export for Classify: None

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to save.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: saved_model/saved_model.pb

INFO:tensorflow:SavedModel written to: saved_model/saved_model.pb

strategy.scope 없이 모델 불러오기.

unreplicated_model = tf.keras.experimental.load_from_saved_model(path)

unreplicated_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('평가 손실: {}, 평가 정확도: {}'.format(eval_loss, eval_acc))
WARNING:tensorflow:From <ipython-input-20-2f23d81b2b21>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been  deprecated. Please switch to `tf.keras.models.load_model`.

WARNING:tensorflow:From <ipython-input-20-2f23d81b2b21>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been  deprecated. Please switch to `tf.keras.models.load_model`.

157/157 [==============================] - 2s 10ms/step - loss: 0.0420 - accuracy: 0.9854
평가 손실: 0.042024929147617074, 평가 정확도: 0.9854000210762024

strategy.scope와 함께 모델 불러오기.

with strategy.scope():
  replicated_model = tf.keras.experimental.load_from_saved_model(path)
  replicated_model.compile(loss='sparse_categorical_crossentropy',
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('평가 손실: {}, 평가 정확도: {}'.format(eval_loss, eval_acc))
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

157/157 [==============================] - 2s 12ms/step - loss: 0.0420 - accuracy: 0.9854
평가 손실: 0.042024929147617074, 평가 정확도: 0.9854000210762024

예제와 튜토리얼

케라스 적합/컴파일과 함께 분산 전략을 쓰는 예제들이 더 있습니다.

  1. tf.distribute.MirroredStrategy를 사용하여 학습한 Transformer 예제.
  2. tf.distribute.MirroredStrategy를 사용하여 학습한 NCF 예제.

분산 전략 가이드에 더 많은 예제 목록이 있습니다.

다음 단계