Keras 및 MultiWorkerMirroredStrategy를 사용한 사용자 정의 훈련 루프

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드하기

개요

이 튜토리얼에서는 tf.distribute.Strategy API를 사용하여 Keras 모델 및 사용자 정의 훈련 루프로 다중 작업자 분산 훈련을 수행하는 방법을 보여줍니다. 훈련 루프는 tf.distribute.MultiWorkerMirroredStrategy를 통해 배포되므로 단일 작업자에서 실행되도록 설계된 tf.keras 모델은 최소한의 코드 변경만으로 여러 작업자에서 원활하게 작동할 수 있습니다. 사용자 정의 훈련 루프는 훈련에 대한 유연성과 더 많은 통제력을 제공하는 동시에 모델을 디버그하기 더 쉽게 만들어줍니다. 기본 훈련 루프, 처음부터 훈련 루프 작성하기사용자 정의 훈련에 대해 자세히 알아보세요.

tf.keras.Model.fitMultiWorkerMirroredStrategy과 함께 사용하는 방법을 찾고 있다면 이 튜토리얼을 대신 참조하세요.

tf.distribute.Strategy API를 심층적으로 이해하는 데 관심이 있는 분들은 TensorFlow로 분산 훈련하기 가이드에서 TensorFlow가 제공하는 분산 훈련 전략들을 훑어보실 수 있습니다.

설정

먼저, 몇 가지 필요한 패키지를 가져옵니다.

import json
import os
import sys

TensorFlow를 가져오기 전에 환경을 일부 변경합니다.

  • 모든 GPU를 비활성화합니다. 그러면 모든 작업자가 동일한 GPU를 사용하려고 하여 발생하는 오류가 방지됩니다. 실제 애플리케이션에서는 각 작업자가 다른 시스템에 있습니다.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  • 'TF_CONFIG' 환경 변수를 재설정합니다(나중에 이에 대해 자세히 볼 수 있음).
os.environ.pop('TF_CONFIG', None)
  • 현재 디렉터리가 Python의 경로에 있는지 확인합니다. 그렇게 되어 있으면 나중에 %%writefile이 작성한 파일을 노트북에서 가져올 수 있습니다.
if '.' not in sys.path:
  sys.path.insert(0, '.')

이제 TensorFlow를 가져옵니다.

import tensorflow as tf
2022-12-15 02:05:55.039887: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:05:55.040002: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:05:55.040014: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

데이터세트 및 모델 정의

다음으로, 간단한 모델 및 데이터세트 설정으로 mnist.py 파일을 만듭니다. 이 Python 파일은 이 튜토리얼의 작업자 프로세스에서 사용됩니다.

%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the range [0, 255].
  # You need to convert them to float32 with values in the range [0, 1]
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000)
  return train_dataset

def dataset_fn(global_batch_size, input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = mnist_dataset(batch_size)
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  return dataset

def build_cnn_model():
  return tf.keras.Sequential([
      tf.keras.Input(shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
Writing mnist.py

다중 작업자 구성

이제 다중 작업자 훈련의 세계로 들어가 보겠습니다. TensorFlow에서 'TF_CONFIG' 환경 변수는 여러 시스템에 대한 훈련에 필요합니다. 시스템마다 역할이 다를 수 있습니다. 아래에 사용된 'TF_CONFIG' 변수는 클러스터의 일부인 각 작업자에 대한 클러스터 구성을 지정하는 JSON 문자열입니다. 이것은 cluster_resolver.TFConfigClusterResolver를 사용하여 클러스터를 지정하는 기본 방법이지만, distribute.cluster_resolver 모듈에서 사용할 수 있는 다른 옵션들도 있습니다. 분산 훈련 가이드에서 'TF_CONFIG' 변수 설정에 대해 자세히 알아보세요.

클러스터 설명하기

다음은 구성의 예입니다.

tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456']
    },
    'task': {'type': 'worker', 'index': 0}
}

tf_config는 단순히 Python의 지역 변수입니다. 훈련 구성에 사용하려면 JSON으로 직렬화하고 'TF_CONFIG' 환경 변수에 배치합니다. 다음은 JSON 문자열로 직렬화된 동일한 'TF_CONFIG'입니다.

json.dumps(tf_config)
'{"cluster": {"worker": ["localhost:12345", "localhost:23456"]}, "task": {"type": "worker", "index": 0} }'

'TF_CONFIG'에는 'cluster''task'라는 두 가지 구성 요소가 있습니다.

  • 'cluster'는 모든 작업자에게 동일하며 'worker'와 같은 다양한 유형의 작업으로 구성된 사전인 훈련 클러스터에 대한 정보를 제공합니다. MultiWorkerMirroredStrategy를 사용한 다중 작업자 훈련에는 일반적으로 일반 'worker'가 수행하는 작업 외에 TensorBoard에 대한 체크포인트를 저장하고 요약 파일을 작성하는 등 약간 더 많은 일을 처리하는 'worker'가 하나 있습니다. 이러한 작업자를 'chief' 작업자라고 하며, 'index'가 0인 'worker'를 메인 worker로 지정하는 것이 관례입니다.

  • 'task'는 현재 작업에 대한 정보를 제공하며 작업자마다 다릅니다. 이를 통해 해당 작업자의 'type''index'가 지정됩니다.

이 예에서는 'type' 작업을 'worker'로 설정하고 'index' 작업을 0으로 설정합니다. 이 시스템은 첫 번째 작업자이며 메인 작업자로 지정되어 다른 작업자보다 더 많은 일을 처리하게 됩니다. 다른 시스템에도 'TF_CONFIG' 환경 변수가 설정되어 있어야 하며, 동일한 'cluster' 사전이 있어야 하지만 해당 시스템의 역할에 따라 다른 작업 'type' 또는 작업 'index'가 있어야 합니다.

설명을 위해 이 튜토리얼에서는 'localhost'에 두 작업자가 있는 'TF_CONFIG'를 설정하는 방법을 보여줍니다. 실제로 사용자는 외부 IP 주소/포트에 여러 작업자를 만들고 각 작업자에 'TF_CONFIG'를 적절하게 설정합니다.

이 예제에서는 두 작업자를 사용합니다. 첫 번째 작업자의 'TF_CONFIG'는 위와 같습니다. 두 번째 작업자의 경우 tf_config['task']['index']=1을 설정합니다.

노트북의 환경 변수 및 하위 프로세스

하위 프로세스는 상위 요소로부터 환경 변수를 상속합니다. 따라서 이 Jupyter Notebook 프로세스에서 다음과 같이 환경 변수를 설정하는 경우:

os.environ['GREETINGS'] = 'Hello TensorFlow!'

하위 프로세스에서 환경 변수에 액세스할 수 있습니다.

echo ${GREETINGS}
Hello TensorFlow!

다음 섹션에서는 이를 사용하여 작업자 하위 프로세스에 'TF_CONFIG'를 전달합니다. 이런 식으로 작업을 시작하지는 않겠지만 이 튜토리얼의 목적인 최소 다중 작업자 예제를 보여주는 데는 충분합니다.

MultiWorkerMirroredStrategy

모델을 훈련하기 전에 먼저 tf.distribute.MultiWorkerMirroredStrategy의 인스턴스를 만듭니다.

strategy = tf.distribute.MultiWorkerMirroredStrategy()
INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:CPU:0',), communication = CommunicationImplementation.AUTO
2022-12-15 02:05:56.272689: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

참고: 'TF_CONFIG'가 구문 분석되고 TensorFlow의 GRPC 서버는 tf.distribute.MultiWorkerMirroredStrategy를 호출할 때 시작됩니다. 따라서 tf.distribute.Strategy를 인스턴스화하기 전에 'TF_CONFIG' 환경 변수를 설정해야 합니다. 이 튜토리얼의 예제에서는 시간을 절약하기 위해 이를 보여주지 않으므로 서버를 시작할 필요가 없습니다. 이 튜토리얼의 마지막 섹션에 전체 예제가 나와 있습니다.

tf.distribute.Strategy.scope를 사용하여 모델을 빌드할 때 사용해야 하는 전략을 지정합니다. 이를 통해 전략에서 변수 배치와 같은 사항을 제어할 수 있습니다. 모든 작업자에 걸쳐 각 장치의 모델 레이어에 있는 모든 변수의 복사본을 생성합니다.

import mnist
with strategy.scope():
  # Model building needs to be within `strategy.scope()`.
  multi_worker_model = mnist.build_cnn_model()

작업자 간에 데이터 자동 샤딩하기

다중 작업자 훈련에서는 수렴과 재현성을 보장하기 위해 데이터세트 샤딩이 필요합니다. 샤딩은 각 작업자에게 전체 데이터세트의 일부를 전달하는 것을 의미합니다. 이는 단일 작업자에 대한 훈련과 유사한 경험을 만드는 데 도움이 됩니다. 아래 예에서는 tf.distribute의 기본 자동 샤딩 정책을 이용하고 있습니다. tf.data.experimental.AutoShardPolicytf.data.experimental.DistributeOptions를 설정하여 이를 사용자 정의할 수도 있습니다. 자세한 내용은 분산 입력 튜토리얼샤딩 섹션을 참조하세요.

per_worker_batch_size = 64
num_workers = len(tf_config['cluster']['worker'])
global_batch_size = per_worker_batch_size * num_workers

with strategy.scope():
  multi_worker_dataset = strategy.distribute_datasets_from_function(
      lambda input_context: mnist.dataset_fn(global_batch_size, input_context))

사용자 정의 훈련 루프 정의 및 모델 훈련하기

옵티마이저 지정:

with strategy.scope():
  # The creation of optimizer and train_accuracy needs to be in
  # `strategy.scope()` as well, since they create variables.
  optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')

tf.function을 사용하여 훈련 단계를 정의합니다.

@tf.function
def train_step(iterator):
  """Training step function."""

  def step_fn(inputs):
    """Per-Replica step function."""
    x, y = inputs
    with tf.GradientTape() as tape:
      predictions = multi_worker_model(x, training=True)
      per_batch_loss = tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True,
          reduction=tf.keras.losses.Reduction.NONE)(y, predictions)
      loss = tf.nn.compute_average_loss(
          per_batch_loss, global_batch_size=global_batch_size)

    grads = tape.gradient(loss, multi_worker_model.trainable_variables)
    optimizer.apply_gradients(
        zip(grads, multi_worker_model.trainable_variables))
    train_accuracy.update_state(y, predictions)
    return loss

  per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

체크포인트 저장 및 복원

사용자 정의 훈련 루프를 작성할 때 Keras 콜백에 의존하는 대신 수동으로 체크포인트 저장을 처리해야 합니다. MultiWorkerMirroredStrategy의 경우 체크포인트 또는 전체 모델을 저장하려면 모든 작업자가 참여해야 합니다. 메인 작업자만 저장하려고 하면 교착 상태가 발생할 수 있기 때문입니다. 또한 작업자는 서로 덮어쓰지 않도록 다른 경로에 작성해야 합니다. 다음은 디렉터리를 구성하는 방법의 예입니다.

from multiprocessing import util
checkpoint_dir = os.path.join(util.get_temp_dir(), 'ckpt')

def _is_chief(task_type, task_id, cluster_spec):
  return (task_type is None
          or task_type == 'chief'
          or (task_type == 'worker'
              and task_id == 0
              and "chief" not in cluster_spec.as_dict()))

def _get_temp_dir(dirpath, task_id):
  base_dirpath = 'workertemp_' + str(task_id)
  temp_dir = os.path.join(dirpath, base_dirpath)
  tf.io.gfile.makedirs(temp_dir)
  return temp_dir

def write_filepath(filepath, task_type, task_id, cluster_spec):
  dirpath = os.path.dirname(filepath)
  base = os.path.basename(filepath)
  if not _is_chief(task_type, task_id, cluster_spec):
    dirpath = _get_temp_dir(dirpath, task_id)
  return os.path.join(dirpath, base)

tf.train.CheckpointManager에 의해 관리되는 모델을 추적하는 하나의 tf.train.Checkpoint를 생성하여 최신 체크포인트만 보존되도록 합니다.

epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64), name='epoch')
step_in_epoch = tf.Variable(
    initial_value=tf.constant(0, dtype=tf.dtypes.int64),
    name='step_in_epoch')
task_type, task_id = (strategy.cluster_resolver.task_type,
                      strategy.cluster_resolver.task_id)
# Normally, you don't need to manually instantiate a `ClusterSpec`, but in this 
# illustrative example you did not set `'TF_CONFIG'` before initializing the
# strategy. Check out the next section for "real-world" usage.
cluster_spec = tf.train.ClusterSpec(tf_config['cluster'])

checkpoint = tf.train.Checkpoint(
    model=multi_worker_model, epoch=epoch, step_in_epoch=step_in_epoch)

write_checkpoint_dir = write_filepath(checkpoint_dir, task_type, task_id,
                                      cluster_spec)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

이제 체크포인트를 복원해야 할 때 편리한 tf.train.latest_checkpoint 함수를 사용하거나 tf.train.CheckpointManager.restore_or_initialize를 호출하여 저장된 최신 체크포인트를 찾을 수 있습니다.

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
  checkpoint.restore(latest_checkpoint)

체크포인트를 복원한 후, 사용자 정의 훈련 루프의 훈련을 계속할 수 있습니다.

num_epochs = 3
num_steps_per_epoch = 70

while epoch.numpy() < num_epochs:
  iterator = iter(multi_worker_dataset)
  total_loss = 0.0
  num_batches = 0

  while step_in_epoch.numpy() < num_steps_per_epoch:
    total_loss += train_step(iterator)
    num_batches += 1
    step_in_epoch.assign_add(1)

  train_loss = total_loss / num_batches
  print('Epoch: %d, accuracy: %f, train_loss: %f.'
                %(epoch.numpy(), train_accuracy.result(), train_loss))

  train_accuracy.reset_states()

  # Once the `CheckpointManager` is set up, you're now ready to save, and remove
  # the checkpoints non-chief workers saved.
  checkpoint_manager.save()
  if not _is_chief(task_type, task_id, cluster_spec):
    tf.io.gfile.rmtree(write_checkpoint_dir)

  epoch.assign_add(1)
  step_in_epoch.assign(0)
2022-12-15 02:05:57.289580: W tensorflow/core/framework/dataset.cc:769] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Epoch: 0, accuracy: 0.819531, train_loss: 0.572494.
Epoch: 1, accuracy: 0.929129, train_loss: 0.238400.
Epoch: 2, accuracy: 0.956250, train_loss: 0.157569.

전체 코드 한 눈에 보기

지금까지 논의된 모든 절차를 요약하면 다음과 같습니다.

  1. 작업자 프로세스를 만듭니다.
  2. 작업자 프로세스에 'TF_CONFIG'를 전달합니다.
  3. 각 작업 프로세스에서 훈련 코드가 포함된 아래 스크립트를 실행하도록 합니다.

File: main.py

Writing main.py

현재 디렉터리에는 이제 두 Python 파일이 모두 포함됩니다.

ls *.py
main.py
mnist.py

따라서 'TF_CONFIG'를 JSON 직렬화하고 환경 변수에 추가합니다.

os.environ['TF_CONFIG'] = json.dumps(tf_config)

이제 main.py를 실행하고 'TF_CONFIG'를 사용하는 작업자 프로세스를 시작할 수 있습니다.

# first kill any previous runs
%killbgscripts
All background processes were killed.
python main.py &> job_0.log

위의 명령에서 몇 가지 주의할 사항이 있습니다.

  1. 일부 bash 명령을 실행하기 위해 노트북 "매직"%%bash가 사용됩니다.
  2. --bg 플래그를 사용하여 백그라운드에서 bash 프로세스를 실행합니다. 이 작업자는 종료되지 않기 때문입니다. 시작하기 전에 모든 작업자를 기다립니다.

백그라운드 작업자 프로세스는 이 노트북에 출력을 인쇄하지 않습니다. &>는 출력을 파일로 리디렉션하여 발생한 상황을 검사할 수 있도록 합니다.

프로세스가 시작될 때까지 몇 초 동안 기다립니다.

import time
time.sleep(20)

이제 지금까지 작업자의 로그 파일에 대한 출력을 확인합니다.

cat job_0.log
2022-12-15 02:06:04.951967: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:06:04.952071: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:06:04.952083: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2022-12-15 02:06:06.050491: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

로그 파일의 마지막 줄은 다음과 같아야 합니다: Started server with target: grpc://localhost:12345. 이제 첫 번째 작업자가 준비되었으며 다른 모든 작업자가 계속 진행할 준비가 되기를 기다립니다.

두 번째 작업자 프로세스가 선택하도록 tf_config를 업데이트합니다.

tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

이제 두 번째 작업자를 시작합니다. 모든 작업자가 활성 상태이므로 훈련이 시작됩니다(따라서 이 프로세스를 백그라운드에 둘 필요가 없음).

python main.py > /dev/null 2>&1

첫 번째 작업자가 작성한 로그를 다시 확인하면 작업자가 해당 모델 훈련에 참여했음을 알 수 있습니다.

cat job_0.log
2022-12-15 02:06:04.951967: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:06:04.952071: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:06:04.952083: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2022-12-15 02:06:06.050491: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2022-12-15 02:06:27.412697: W tensorflow/core/framework/dataset.cc:769] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Epoch: 0, accuracy: 0.823996, train_loss: 0.581225.
Epoch: 1, accuracy: 0.925558, train_loss: 0.249244.
Epoch: 2, accuracy: 0.946317, train_loss: 0.186680.
# Delete the `'TF_CONFIG'`, and kill any background tasks so they don't affect the next section.
os.environ.pop('TF_CONFIG', None)
%killbgscripts
All background processes were killed.

심층적 다중 작업자 훈련

이 튜토리얼에서는 다중 작업자 설정의 사용자 정의 훈련 루프 워크플로를 보여주었습니다. 다른 주제에 대한 자세한 설명은 사용자 정의 훈련 루프에 적용할 수 있는 Keras를 사용한 다중 작업자 훈련(tf.keras.Model.fit) 튜토리얼에서 확인할 수 있습니다.

자세히 알아보기

  1. TensorFlow에서 분산 훈련하기 가이드는 사용 가능한 분산 전략을 간략히 소개합니다.
  2. 많은 공식 모델은 다양한 분산 전략을 실행하도록 설정할 수 있습니다.
  3. tf.function 가이드의 성능 섹션에서 TensorFlow 모델 성능 최적화를 위해 사용할 수 있는 다른 전략과 도구를 소개합니다.