ParameterServerStrategy를 사용한 매개변수 서버 훈련

TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서 소스 보기 노트북 다운로드

개요

매개변수 서버 훈련은 여러 머신에서 모델 훈련을 확장하는 일반적인 데이터 병렬 방법입니다.

매개변수 서버 훈련 클러스터는 작업자매개변수 서버로 구성됩니다. 변수는 매개변수 서버에서 생성되며 각 단계에서 작업자가 읽고 업데이트합니다. 기본적으로 작업자는 이러한 변수를 서로 동기화하지 않고 독립적으로 읽고 업데이트합니다. 이 때문에 때로 매개변수 서버 형태의 훈련을 비동기 훈련이라고 부릅니다.

TensorFlow 2에서 매개변수 서버 훈련은 tf.distribute.ParameterServerStrategy 클래스를 기반으로 하는데, 이 클래스는 훈련 단계를 최대 수천 개의 작업자(매개변수 서버와 함께)로 확장되는 클러스터에 배포합니다.

지원되는 훈련 방법

지원되는 훈련 방법에는 크게 두 가지가 있습니다.

작업 및 태스크가 있는 클러스터

선택한 API(Model.fit 또는 사용자 정의 훈련 루프)에 관계없이 TensorFlow 2의 분산 훈련에는 여러 'jobs'이 있는 'cluster'가 포함되며 각 작업에는 하나 이상의 'tasks'이 있을 수 있습니다.

매개변수 서버 훈련을 사용할 때 다음이 있는 것이 좋습니다.

  • 하나의 코디네이터 작업(작업 이름이 chief)
  • 여러 작업자 작업(작업 이름 worker)
  • 다중 매개변수 서버 작업(작업 이름 ps)

코디네이터는 리소스를 생성하고, 훈련 작업을 전달하고, 체크포인트를 작성하고, 작업 실패를 처리합니다. 작업자매개변수 서버는 코디네이터의 요청에 수신 대기하는 tf.distribute.Server 인스턴스를 실행합니다.

Model.fit API를 사용한 매개변수 서버 훈련

Model.fit API를 사용한 매개변수 서버 훈련을 위해서는 코디네이터가 tf.distribute.ParameterServerStrategy 객체를 사용해야 합니다. 전략이 없거나 다른 전략이 있는 Model.fit 사용과 유사하게 워크플로에는 모델 생성과 컴파일, 콜백 준비, 및 Model.fit 호출이 포함됩니다.

사용자 정의 훈련 루프를 사용한 매개변수 서버 훈련

사용자 정의 훈련 루프에서 tf.distribute.coordinator.ClusterCoordinator 클래스는 코디네이터에 사용되는 핵심 구성 요소입니다.

ClusterCoordinator 객체가 제공하는 가장 중요한 API는 schedule입니다.

  • schedule API는 tf.function을 대기열에 넣고 미래와 같은 RemoteValue를 즉시 반환합니다.
  • 대기 중인 함수는 백그라운드 스레드의 원격 작업자에게 전달되고 해당 RemoteValue는 비동기식으로 채워집니다.
  • schedule에는 작업자 할당이 필요하지 않으므로 전달된 tf.function은 사용 가능한 모든 작업자에서 실행할 수 있습니다.
  • 실행된 작업자가 완료되기 전에 사용할 수 없게 되면 사용 가능한 다른 작업자에서 함수가 재시도됩니다.
  • 이것과 함께 함수 실행이 원자적이지 않다는 사실 때문에 단일 함수 호출이 두 번 이상 실행될 수 있습니다.

원격 함수를 전달하는 외에도 ClusterCoordinator는 모든 작업자에 대한 데이터세트를 생성하고 작업자가 장애에서 복구될 때 이러한 데이터세트를 재구축하는 도움을 줍니다.

튜토리얼 설정

이 튜토리얼은 Model.fit 및 사용자 정의 훈련 루프 경로로 분기되며 필요에 맞는 경로를 선택할 수 있습니다. "X를 이용한 훈련" 이외의 섹션은 두 경로 모두에 적용할 수 있습니다.

pip install portpicker

2022-12-15 02:11:31.990663: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:11:31.990770: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 02:11:31.990781: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

클러스터 설정

위에서 언급했듯이 매개변수 서버 훈련 클러스터에는 훈련 프로그램을 실행하는 코디네이터 작업, TensorFlow 서버 tf.distribute.Server를 실행하는 하나 이상의 작업자 및 매개변수 서버 작업, 그리고 사이드카 평가를 실행하는 추가 평가 작업이 필요합니다(아래 사이드카 평가 섹션 참조). 설정 요구 사항은 다음과 같습니다.

  • 코디네이터 작업은 평가자를 제외한 다른 모든 TensorFlow 서버의 주소와 포트를 알아야 합니다.
  • 작업자와 매개변수 서버는 수신 대기해야 하는 포트를 알아야 합니다. 단순하게 하기 위해 일반적으로 이러한 작업에서 TensorFlow 서버를 생성할 때 전체 클러스터 정보를 전달할 수 있습니다.
  • 평가자 작업은 훈련 클러스터의 설정을 알 필요가 없습니다. 알고 있다면 훈련 클러스터에 연결을 시도해서는 안 됩니다.
  • 작업자 및 매개변수 서버는 각각 "worker""ps"의 작업 유형을 가져야 합니다. 코디네이터는 레거시 문제로 인해 작업 유형으로 "chief"를 사용해야 합니다.

이 튜토리얼에서는 전체 매개변수 서버 훈련을 Colab에서 실행할 수 있도록 in-process 클러스터를 생성합니다. 이후 섹션에서 실제 클러스터를 설정하는 방법을 배웁니다.

In-process 클러스터

먼저 몇 개의 TensorFlow 서버를 만들고 나중에 연결할 것입니다. 이것은 이 튜토리얼에서 시연을 하기 위한 것이며 실제 훈련에서 서버는 "worker""ps" 머신에서 시작됩니다.

def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns the cluster_resolver."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  cluster_dict["worker"] = ["localhost:%s" % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]

  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for i in range(num_workers):
    tf.distribute.Server(
        cluster_spec,
        job_name="worker",
        task_index=i,
        config=worker_config,
        protocol="grpc")

  for i in range(num_ps):
    tf.distribute.Server(
        cluster_spec,
        job_name="ps",
        task_index=i,
        protocol="grpc")

  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer="grpc")
  return cluster_resolver

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a workaround and won't be necessary in the future.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

NUM_WORKERS = 3
NUM_PS = 2
cluster_resolver = create_in_process_cluster(NUM_WORKERS, NUM_PS)

In-process 클러스터 설정은 여기에서와 같은 단위 테스트에 자주 사용됩니다.

로컬 테스트를 위한 또 다른 옵션은 로컬 머신에서 프로세스를 시작하는 것입니다. 이 접근 방식의 예를 보려면 Keras를 사용한 다중 작업자 훈련을 확인하세요.

ParameterServerStrategy 인스턴스화

훈련 코드를 살펴보기 전에 tf.distribute.ParameterServerStrategy 객체를 인스턴스화하겠습니다. 이 작업은 Model.fit 또는 사용자 정의 훈련 루프를 진행하는지 여부에 관계없이 필요합니다. variable_partitioner 인수는 변수 샤딩 섹션에서 설명합니다.

variable_partitioner = (
    tf.distribute.experimental.partitioners.MinSizePartitioner(
        min_shard_bytes=(256 << 10),
        max_shards=NUM_PS))

strategy = tf.distribute.ParameterServerStrategy(
    cluster_resolver,
    variable_partitioner=variable_partitioner)
INFO:tensorflow:`tf.distribute.experimental.ParameterServerStrategy` is initialized with cluster_spec: ClusterSpec({'ps': ['localhost:37305', 'localhost:42421'], 'worker': ['localhost:37173', 'localhost:43011', 'localhost:37121']})
INFO:tensorflow:ParameterServerStrategyV2 is now connecting to cluster with cluster_spec: ClusterSpec({'ps': ['localhost:37305', 'localhost:42421'], 'worker': ['localhost:37173', 'localhost:43011', 'localhost:37121']})
INFO:tensorflow:ParameterServerStrategy (CentralStorageStrategy if you are using a single machine) with compute_devices = ['/job:chief/replica:0/task:0/device:GPU:0', '/job:chief/replica:0/task:0/device:GPU:1', '/job:chief/replica:0/task:0/device:GPU:2', '/job:chief/replica:0/task:0/device:GPU:3'], variable_device = '/device:CPU:0'
INFO:tensorflow:Number of GPUs on workers: 4

훈련에 GPU를 사용하려면 각 작업자에 표시되는 GPU를 할당합니다. ParameterServerStrategy는 각 작업자에서 사용 가능한 모든 GPU를 사용하며 모든 작업자가 동일한 수의 GPU를 사용할 수 있어야 한다는 제한이 있습니다.

가변 샤딩

변수 샤딩은 변수를 샤드라고 하는 여러 개의 작은 변수로 나누는 것을 말합니다. 변수 샤딩은 이러한 샤드에 액세스할 때 네트워크 부하를 분산하는 데 유용할 수 있습니다. 예를 들어 단일 머신의 메모리에 맞지 않을 수 있는 매우 큰 임베딩을 사용할 때 일반 변수의 계산과 저장을 여러 매개변수 서버에 분산하는 데도 유용합니다.

변수 샤딩을 사용하려면 ParameterServerStrategy 객체를 생성할 때 variable_partitioner를 전달할 수 있습니다. variable_partitioner는 변수가 생성될 때마다 호출되며 변수의 각 차원을 따라 샤드 수를 반환할 것으로 예상됩니다. tf.distribute.experimental.partitioners.MinSizePartitioner와 같은 몇 가지 기본 variable_partitioner가 제공됩니다. tf.distribute.experimental.partitioners.MinSizePartitioner와 같은 크기 기반 파티셔너를 사용하여 모델 훈련 속도에 부정적인 영향을 줄 수 있는 작은 변수의 분할을 방지하는 것이 좋습니다.

variable_partitioner가 전달되고 Strategy.scope 바로 아래에 변수를 생성하면 변수는 샤드 목록에 대한 액세스를 제공하는 variables 속성이 있는 컨테이너 유형이 됩니다. 대부분의 경우 이 컨테이너는 모든 샤드를 연결하여 자동으로 텐서로 변환됩니다. 결과적으로 일반 변수로 사용할 수 있습니다. 반면에 tf.nn.embedding_lookup과 같은 일부 TensorFlow 메서드는 이 컨테이너 유형에 대한 효율적인 구현을 제공하며 이러한 메서드에서는 자동 연결이 방지됩니다.

자세한 내용은 tf.distribute.ParameterServerStrategy의 API 문서를 참조하세요.

Model.fit으로 훈련하기

Keras는 재정의 가능한 train_step의 유연성, 그리고 TensorBoard에 대한 체크포인트 저장 또는 요약 저장과 같은 기능을 제공하는 콜백과 함께 막후에서 훈련 루프를 처리하는 Model.fit을 통해 사용하기 쉬운 훈련 API를 제공합니다. Model.fit을 사용하면 간단히 전략 객체만 교체하여 다른 전략에서 동일한 훈련 코드를 사용할 수 있습니다.

입력 데이터

Model.fit을 포함한 tf.distribute.ParameterServerStrategytf.data.Dataset, tf.distribute.DistributedDataset 또는 tf.keras.utils.experimental.DatasetCreator 형식으로 입력 데이터를 받을 수 있으며 Dataset은 사용 편리성을 위해 권장되는 옵션입니다. 그러나 Dataset을 사용하여 메모리 문제가 발생하면 호출 가능한 dataset_fn 인수와 함께 DatasetCreator를 사용해야 할 수도 있습니다(자세한 내용은 tf.keras.utils.experimental.DatasetCreator API 설명서 참조).

데이터세트를 tf.data.Dataset으로 변환하는 경우 아래 코드 예제와 같이 Dataset.shuffleDataset.repeat를 사용해야 합니다.

  • 매개변수 서버 훈련이 포함된 Model.fit은 다르게 섞인 경우를 제외하고 각 작업자가 동일한 데이터세트를 받는다고 가정합니다. 따라서 Dataset.shuffle을 호출하여 데이터에 대해 더 균일한 반복을 보장할 수 있습니다.
  • 작업자는 동기화하지 않기 때문에 서로 다른 시간에 데이터세트 처리를 완료할 수 있습니다. 따라서 매개변수 서버 훈련으로 epoch를 정의하는 가장 쉬운 방법은 Dataset.repeat(인수 없이 호출될 때 데이터세트를 무한히 반복함)를 사용하고 Model.fit 호출에서 steps_per_epoch 인수를 지정하는 것입니다.

shufflerepeat에 대한 자세한 내용은 tf.data 가이드의 "훈련 워크플로" 섹션을 참조하세요.

global_batch_size = 64

x = tf.random.uniform((10, 10))
y = tf.random.uniform((10,))

dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(10).repeat()
dataset = dataset.batch(global_batch_size)
dataset = dataset.prefetch(2)

대신 tf.keras.utils.experimental.DatasetCreator를 사용하여 데이터세트를 생성하면 각 작업자 머신의 입력 장치(일반적으로 CPU)에서 dataset_fn의 코드가 호출됩니다.

모델 구성 및 컴파일

이제 데모 목적의 간단한 tf.keras.models.Sequential 모델인 tf.keras.Model을 만든 다음 옵티마이저와 같은 구성 요소 및 steps_per_execution과 같은 기타 매개변수를 도입하기 위한 Model.compile 호출이 이루어집니다.

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])

  model.compile(tf.keras.optimizers.legacy.SGD(), loss="mse", steps_per_execution=10)

콜백 및 훈련

실제 훈련을 위해 Keras Model.fit을 호출하기 전에 다음과 같은 일반적인 작업에 필요한 콜백을 준비합니다.

  • tf.keras.callbacks.ModelCheckpoint: 매 epoch 후와 같이 특정 빈도로 모델을 저장합니다.
  • tf.keras.callbacks.BackupAndRestore: 클러스터에 사용 불가능한 상황(중단 또는 선점 등)이 발생하는 경우 모델과 현재 epoch 번호를 백업하여 내결함성을 제공합니다. 그런 다음 작업 실패 후 다시 시작할 때 훈련 상태를 복원하고 중단된 epoch의 시작 부분부터 훈련을 계속할 수 있습니다.
  • tf.keras.callbacks.TensorBoard: 요약 파일에 TensorBoard 도구에서 시각화할 수 있는 모델 로그를 주기적으로 작성합니다.

참고: 성능 고려 사항으로 인해 사용자 정의 콜백은 ParameterServerStrategy와 함께 사용될 때 일괄 처리 수준 콜백을 재정의할 수 없습니다. 사용자 정의 콜백을 수정하여 epoch 수준 호출이 되도록 하고 steps_per_epoch를 적절한 값으로 조정합니다. 또한 steps_per_epochParameterServerStrategy와 함께 사용할 때 Model.fit에 대한 필수 인수입니다.

working_dir = "/tmp/my_working_dir"
log_dir = os.path.join(working_dir, "log")
ckpt_filepath = os.path.join(working_dir, "ckpt")
backup_dir = os.path.join(working_dir, "backup")

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    tf.keras.callbacks.ModelCheckpoint(filepath=ckpt_filepath),
    tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir),
]

model.fit(dataset, epochs=5, steps_per_epoch=20, callbacks=callbacks)
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
2022-12-15 02:11:37.794389: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 10
        }
      }
      shape {
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

2022-12-15 02:11:37.794487: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 10
        }
      }
      shape {
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}

2022-12-15 02:11:37.810542: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 10
        }
      }
      shape {
      }
    }
  }
}
attr {
  key: "replicate_on_split"
  value {
    b: false
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
    }
  }
}
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Waiting for all global closures to be finished.
INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets
20/20 - 8s - loss: 0.4865 - 8s/epoch - 419ms/step
Epoch 2/5
INFO:tensorflow:Waiting for all global closures to be finished.
INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets
20/20 - 1s - loss: 0.4072 - 946ms/epoch - 47ms/step
Epoch 3/5
INFO:tensorflow:Waiting for all global closures to be finished.
WARNING:tensorflow:5 out of the last 5 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7efefc19c280> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets
WARNING:tensorflow:6 out of the last 6 calls to <function MultiDeviceSaver.save.<locals>.tf_function_save at 0x7eff28420670> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
20/20 - 1s - loss: 0.3492 - 532ms/epoch - 27ms/step
Epoch 4/5
INFO:tensorflow:Waiting for all global closures to be finished.
INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets
20/20 - 1s - loss: 0.3043 - 543ms/epoch - 27ms/step
Epoch 5/5
INFO:tensorflow:Waiting for all global closures to be finished.
INFO:tensorflow:Assets written to: /tmp/my_working_dir/ckpt/assets
20/20 - 1s - loss: 0.2698 - 543ms/epoch - 27ms/step
<keras.callbacks.History at 0x7eff63bb1b80>

ClusterCoordinator로 직접 사용(선택 사항)

Model.fit 훈련 경로를 선택하더라도 tf.distribute.coordinator.ClusterCoordinator 객체를 선택적으로 인스턴스화하여 작업자에서 실행하려는 다른 함수를 예약할 수 있습니다. 자세한 내용과 예제는 사용자 정의 훈련 루프를 사용한 훈련 섹션을 참조하세요.

사용자 정의 훈련 루프를 사용한 훈련

tf.distribute.Strategy와 함께 사용자 정의 훈련 루프를 사용하면 훈련 루프를 정의하는 큰 유연성이 생깁니다. 위에서 정의한 ParameterServerStrategy(strategy)로 tf.distribute.coordinator.ClusterCoordinator를 사용하여 훈련 단계 실행을 원격 작업자에게 전달합니다.

그런 다음 다른 tf.distribute.Strategy를 사용하여 훈련 루프에서 수행한 것처럼 모델을 생성하고, 데이터세트를 정의하고, 단계 함수를 정의합니다. tf.distribute.Strategy를 사용한 사용자 정의 훈련 튜토리얼에서 자세한 내용을 찾을 수 있습니다.

효율적인 데이터세트 미리 가져오기를 위해 아래 원격 작업자에게 훈련 단계 전달하기 섹션에 언급된 권장 분산 데이터세트 생성 API를 사용하세요. 또한 작업자에게 할당된 GPU를 최대한 활용하려면 worker_fn 내에서 Strategy.run을 호출해야 합니다. 나머지 단계는 GPU가 있거나 없는 훈련에 대해 동일합니다.

다음 단계에서 이러한 구성 요소를 만들어 보겠습니다.

데이터 설정하기

먼저 데이터세트를 생성하는 함수를 작성합니다.

Keras 전처리 레이어 또는 Tensorflow Transform 레이어를 사용하여 데이터를 전처리하려면 다른 Keras 레이어의 경우와 마찬가지로 dataset_fn 외부Strategy.scope 아래에 이러한 레이어를 생성합니다. 이는 dataset_fntf.function으로 래핑된 다음 각 작업자에서 실행되어 데이터 파이프라인을 생성하기 때문입니다.

위의 절차를 따르지 않으면 레이어를 생성할 때 tf.function에서 코디네이터로 해제되는 Tensorflow 상태가 생성될 수 있습니다. 따라서 작업자에서 여기에 액세스하면 코디네이터와 작업자 간에 반복적인 RPC 호출이 발생하고 상당한 속도 저하가 발생합니다.

Strategy.scope 아래에 레이어를 배치하면 대신 모든 작업자에서 레이어가 생성됩니다. 그러면 tf.data.Dataset.map을 통해 dataset_fn 내에 변환을 적용합니다. 분산 입력을 사용한 데이터 전처리에 대한 자세한 내용은 분산 입력 튜토리얼의 데이터 전처리를 참조하세요.

feature_vocab = [
    "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong", "wonder_woman"
]
label_vocab = ["yes", "no"]

with strategy.scope():
  feature_lookup_layer = tf.keras.layers.StringLookup(
      vocabulary=feature_vocab,
      mask_token=None)
  label_lookup_layer = tf.keras.layers.StringLookup(
      vocabulary=label_vocab,
      num_oov_indices=0,
      mask_token=None)

  raw_feature_input = tf.keras.layers.Input(
      shape=(3,),
      dtype=tf.string,
      name="feature")
  feature_id_input = feature_lookup_layer(raw_feature_input)
  feature_preprocess_stage = tf.keras.Model(
      {"features": raw_feature_input},
      feature_id_input)

  raw_label_input = tf.keras.layers.Input(
      shape=(1,),
      dtype=tf.string,
      name="label")
  label_id_input = label_lookup_layer(raw_label_input)

  label_preprocess_stage = tf.keras.Model(
      {"label": raw_label_input},
      label_id_input)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/numpy/core/numeric.py:2468: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  return bool(asarray(a1 == a2).all())

데이터세트에서 장난감 예제 생성:

def feature_and_label_gen(num_examples=200):
  examples = {"features": [], "label": []}
  for _ in range(num_examples):
    features = random.sample(feature_vocab, 3)
    label = ["yes"] if "avenger" in features else ["no"]
    examples["features"].append(features)
    examples["label"].append(label)
  return examples

examples = feature_and_label_gen()

그런 다음 dataset_fn에 래핑된 훈련 데이터세트를 생성합니다.

def dataset_fn(_):
  raw_dataset = tf.data.Dataset.from_tensor_slices(examples)

  train_dataset = raw_dataset.map(
      lambda x: (
          {"features": feature_preprocess_stage(x["features"])},
          label_preprocess_stage(x["label"])
      )).shuffle(200).batch(32).repeat()
  return train_dataset

모델 구축하기

다음으로 모델 및 기타 객체를 만듭니다. Strategy.scope 아래에 모든 변수를 생성해야 합니다.

# These variables created under the `Strategy.scope` will be placed on parameter
# servers in a round-robin fashion.
with strategy.scope():
  # Create the model. The input needs to be compatible with Keras processing layers.
  model_input = tf.keras.layers.Input(
      shape=(3,), dtype=tf.int64, name="model_input")

  emb_layer = tf.keras.layers.Embedding(
      input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=16384)
  emb_output = tf.reduce_mean(emb_layer(model_input), axis=1)
  dense_output = tf.keras.layers.Dense(units=1, activation="sigmoid")(emb_output)
  model = tf.keras.Model({"features": model_input}, dense_output)

  optimizer = tf.keras.optimizers.legacy.RMSprop(learning_rate=0.1)
  accuracy = tf.keras.metrics.Accuracy()

FixedShardsPartitioner의 사용으로 모든 변수가 두 개의 샤드로 분할되고 각 샤드가 다른 매개변수 서버에 할당되었는지 확인하겠습니다.

assert len(emb_layer.weights) == 2
assert emb_layer.weights[0].shape == (4, 16384)
assert emb_layer.weights[1].shape == (4, 16384)

print(emb_layer.weights[0].device)
print(emb_layer.weights[1].device)
/job:ps/replica:0/task:1/device:CPU:0
/job:ps/replica:0/task:0/device:CPU:0

훈련 단계 정의하기

셋째, tf.function에 래핑된 훈련 단계를 생성합니다.

@tf.function
def step_fn(iterator):

  def replica_fn(batch_data, labels):
    with tf.GradientTape() as tape:
      pred = model(batch_data, training=True)
      per_example_loss = tf.keras.losses.BinaryCrossentropy(
          reduction=tf.keras.losses.Reduction.NONE)(labels, pred)
      loss = tf.nn.compute_average_loss(per_example_loss)
      gradients = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
    accuracy.update_state(labels, actual_pred)
    return loss

  batch_data, labels = next(iterator)
  losses = strategy.run(replica_fn, args=(batch_data, labels))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)

위의 훈련 단계 함수에서 step_fn에서 Strategy.runStrategy.reduce를 호출하면 작업자당 여러 GPU를 지원할 수 있습니다. 작업자에게 GPU가 할당된 경우 Strategy.run은 데이터세트를 여러 복제본에 배포합니다.

원격 작업자에게 교육 단계 전달하기

모든 계산이 ParameterServerStrategy에 의해 정의된 후 tf.distribute.coordinator.ClusterCoordinator 클래스를 사용하여 리소스를 생성하고 훈련 단계를 원격 작업자에게 배포합니다.

먼저 ClusterCoordinator 객체를 만들고 전략 객체를 전달해 보겠습니다.

coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy)

그런 다음 모든 작업자에게 데이터세트를 복제하는 ClusterCoordinator.create_per_worker_dataset API를 사용하여 작업자별 데이터세트와 반복자를 만듭니다. GPU로 미리 가져오기가 원활하고 효율적으로 수행되도록 아래 per_worker_dataset_fn에서 dataset_fnstrategy.distribute_datasets_from_function으로 래핑하는 것이 좋습니다.

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.string, name='feature'), name='feature', description="created by layer 'feature'"), but it was called on an input with incompatible shape (3,).

마지막 단계는 ClusterCoordinator.schedule을 사용하여 원격 작업자에게 계산을 배포하는 것입니다.

  • schedule 메서드는 tf.function을 대기열에 넣고 미래와 같은 RemoteValue를 즉시 반환합니다. 대기열에 놓인 함수는 백그라운드 스레드의 원격 작업자에게 전달되고 RemoteValue는 비동기식으로 채워집니다.
  • join 메서드(ClusterCoordinator.join)는 예약된 모든 함수가 실행될 때까지 대기하는 데 사용할 수 있습니다.
num_epochs = 4
steps_per_epoch = 5
for i in range(num_epochs):
  accuracy.reset_states()
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  # Wait at epoch boundaries.
  coordinator.join()
  print("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Reduce to /device:CPU:0 then broadcast to ('/replica:0/device:CPU:0',).
INFO:tensorflow:Waiting for all global closures to be finished.
Finished epoch 0, accuracy is 0.357143.
INFO:tensorflow:Waiting for all global closures to be finished.
Finished epoch 1, accuracy is 1.000000.
INFO:tensorflow:Waiting for all global closures to be finished.
Finished epoch 2, accuracy is 1.000000.
INFO:tensorflow:Waiting for all global closures to be finished.
Finished epoch 3, accuracy is 1.000000.

다음은 RemoteValue의 결과를 가져오는 방법입니다.

loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print("Final loss is %f" % loss.fetch())
Final loss is 0.000030

또는 모든 단계를 시작하고 완료를 기다리는 동안 작업을 수행할 수 있습니다.

for _ in range(total_steps):
  coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
  time.sleep(10)
  # Do something like logging metrics or writing checkpoints.

이 특정 예제에 대한 전체 훈련 및 제공 워크플로는 이 테스트를 확인하세요.

데이터세트 생성에 대한 추가 사항

위 코드의 데이터세트는 ClusterCoordinator.create_per_worker_dataset API를 사용하여 생성됩니다. 작업자당 하나의 데이터세트를 생성하고 컨테이너 객체를 반환합니다. 여기서 iter 메서드를 호출하여 작업자별 반복자를 만들 수 있습니다. 작업자별 반복자는 작업자당 하나의 반복자를 포함하고 작업자의 해당 슬라이스는 특정 작업자에서 함수가 실행되기 전에 ClusterCoordinator.schedule 메서드에 전달된 함수의 입력 인수에서 대체됩니다.

ClusterCoordinator.schedule 메서드는 작업자가 동등하다고 가정하므로 서로 다른 작업자의 데이터세트가 동일하다고 가정합니다(다르게 섞일 수 있다는 점은 제외). 이 때문에 데이터세트를 반복하고, 데이터세트에서 OutOfRangeError를 수신하는 대신 유한한 수의 단계를 예약하는 것이 좋습니다.

또 다른 중요한 점은 tf.data 데이터세트가 작업 경계를 넘어 암시적 직렬화 및 역직렬화를 지원하지 않는다는 것입니다. 따라서 ClusterCoordinator.create_per_worker_dataset에 전달된 함수 내부에 전체 데이터세트를 생성하는 것이 중요합니다. create_per_worker_dataset API는 tf.data.Dataset 또는 tf.distribute.DistributedDataset을 직접 입력으로 사용할 수도 있습니다.

평가

tf.distribute.ParameterServerStrategy 훈련으로 평가를 수행하는 두 가지 주요 접근 방식은 인라인 평가와 사이드카 평가입니다. 각각은 아래와 같이 장단점이 있습니다. 특별한 선호도가 없다면 인라인 평가 방법을 권장합니다.

인라인 평가

이 방법에서 코디네이터는 훈련과 평가를 번갈아 가며 수행하므로 인라인 평가라고 합니다.

인라인 평가에는 다음과 같은 몇 가지 이점이 있습니다.

  • 단일 작업으로 보유할 수 없는 대규모 평가 모델 및 평가 데이터세트를 지원할 수 있습니다.
  • 평가 결과는 예를 들어 훈련을 조기에 중단할지 여부와 같이 다음 epoch를 훈련하기 위한 결정을 내리는 데 사용할 수 있습니다.

인라인 평가를 구현하는 방법에는 직접 평가와 분산 평가의 두 가지가 있습니다.

  • 직접 평가: 소규모 모델 및 평가 데이터세트의 경우 코디네이터는 코디네이터의 평가 데이터세트를 사용하여 분산 모델에서 직접 평가를 실행할 수 있습니다.
eval_dataset = tf.data.Dataset.from_tensor_slices(
    feature_and_label_gen(num_examples=16)).map(
          lambda x: (
              {"features": feature_preprocess_stage(x["features"])},
              label_preprocess_stage(x["label"])
          )).batch(8)

eval_accuracy = tf.keras.metrics.Accuracy()

for batch_data, labels in eval_dataset:
  pred = model(batch_data, training=False)
  actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
  eval_accuracy.update_state(labels, actual_pred)

print("Evaluation accuracy: %f" % eval_accuracy.result())
WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.string, name='feature'), name='feature', description="created by layer 'feature'"), but it was called on an input with incompatible shape (3,).
Evaluation accuracy: 1.000000
  • 분산 평가: 코디네이터에서 직접 실행할 수 없는 대규모 모델 또는 데이터세트의 경우, 코디네이터 작업으로 ClusterCoordinator.schedule/ClusterCoordinator.join 메서드를 통해 작업자에게 평가 작업을 배포할 수 있습니다.
with strategy.scope():
  # Define the eval metric on parameter servers.
  eval_accuracy = tf.keras.metrics.Accuracy()

@tf.function
def eval_step(iterator):
  def replica_fn(batch_data, labels):
    pred = model(batch_data, training=False)
    actual_pred = tf.cast(tf.greater(pred, 0.5), tf.int64)
    eval_accuracy.update_state(labels, actual_pred)
  batch_data, labels = next(iterator)
  strategy.run(replica_fn, args=(batch_data, labels))

def eval_dataset_fn():
  return tf.data.Dataset.from_tensor_slices(
      feature_and_label_gen(num_examples=16)).map(
          lambda x: (
              {"features": feature_preprocess_stage(x["features"])},
              label_preprocess_stage(x["label"])
          )).shuffle(16).repeat().batch(8)

per_worker_eval_dataset = coordinator.create_per_worker_dataset(eval_dataset_fn)
per_worker_eval_iterator = iter(per_worker_eval_dataset)

eval_steps_per_epoch = 2
for _ in range(eval_steps_per_epoch):
  coordinator.schedule(eval_step, args=(per_worker_eval_iterator,))
coordinator.join()
print("Evaluation accuracy: %f" % eval_accuracy.result())
WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.string, name='feature'), name='feature', description="created by layer 'feature'"), but it was called on an input with incompatible shape (3,).
WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources
WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources
WARNING:tensorflow:4 GPUs are allocated per worker. Please use DistributedDataset by calling strategy.experimental_distribute_dataset or strategy.distribute_datasets_from_function to make best use of GPU resources
INFO:tensorflow:Waiting for all global closures to be finished.
Evaluation accuracy: 1.000000

참고: tf.distribute.coordinator.ClusterCoordinatorschedulejoin 메서드는 방문 보장 또는 정확히 한 번의 의미 체계를 지원하지 않습니다. 즉, 데이터세트의 모든 평가 예제가 정확히 한 번 평가된다는 보장은 없습니다. 일부는 방문되지 않을 수 있고 일부는 여러 번 평가될 수 있습니다. tf.data 서비스 API는 ParameterServerStrategy를 사용할 때 평가를 위해 정확히 한 번 방문을 제공하는 데 사용할 수 있습니다(tf.data.experimental.service API 문서의 동적 샤딩 섹션 참조).

사이드카 평가

tf.distribute.ParameterServerStrategy 훈련에서 평가 루프를 정의하고 실행하는 또 다른 방법은 체크포인트를 반복적으로 읽고 최신 체크포인트에서 평가를 실행하는 전용 평가자 작업을 생성하는 사이드카 평가입니다(체크포인트에 대한 자세한 내용은 이 가이드 참조). 수석 및 작업자 작업은 평가에 시간을 들이지 않으므로 고정된 반복 횟수에 대해 전체 훈련 시간은 다른 평가 방법을 사용하는 것보다 짧습니다. 그러나 평가를 트리거하려면 추가적인 평가자 작업과 주기적인 체크포인트 절차가 필요합니다.

사이드카 평가를 위한 평가 루프를 작성할 때 두 가지 옵션이 있습니다.

  1. tf.keras.utils.SidecarEvaluator API를 사용합니다.
  2. 사용자 정의 평가 루프를 만듭니다.

옵션 1에 대한 자세한 내용은 tf.keras.utils.SidecarEvaluator API 문서를 참조하세요.

사이드카 평가는 단일 작업에서만 지원됩니다. 이것은 다음을 의미합니다.

  • 각 예제는 확실하게 한 번만 평가됩니다. 평가자가 선점되거나 다시 시작되는 경우 최근 체크포인트에서 평가 루프를 다시 시작하고 다시 시작하기 전에 이루어진 평가 진행 부분은 폐기됩니다.

  • 그러나 단일 작업에 대해 평가를 실행하면 전체 평가에 시간이 오래 걸릴 수 있습니다.

  • 모델의 크기가 너무 커서 평가자의 메모리에 맞지 않는 경우 단일 사이드카 평가가 적용되지 않습니다.

또 다른 주의 사항은 tf.keras.utils.SidecarEvaluator 구현과 아래의 사용자 정의 평가 루프가 항상 사용 가능한 최신 체크포인트를 선택하고 평가 epoch 동안 훈련 클러스터에서 여러 체크포인트를 생성할 수 있기 때문에 일부 체크포인트를 건너뛸 수 있다는 것입니다. 모든 체크포인트를 평가하는 사용자 정의 평가 루프를 작성할 수 있지만 이 튜토리얼에서는 다루지 않습니다. 반면에 체크포인트가 평가를 실행하는 데 걸리는 시간보다 덜 자주 생성되면 유휴 상태로 있을 수 있습니다.

사용자 정의 평가 루프는 평가할 체크포인트를 선택하거나 평가와 함께 실행할 추가 논리를 제공하는 등 세부 사항에 대한 더 많은 통제력을 제공합니다. 다음은 가능한 사용자 정의 사이드카 평가 루프입니다.

checkpoint_dir = ...
eval_model = ...
eval_data = ...
checkpoint = tf.train.Checkpoint(model=eval_model)

for latest_checkpoint in tf.train.checkpoints_iterator(
    checkpoint_dir):
  try:
    checkpoint.restore(latest_checkpoint).expect_partial()
  except (tf.errors.OpError,) as e:
    # checkpoint may be deleted by training when it is about to read it.
    continue

  # Optionally add callbacks to write summaries.
  eval_model.evaluate(eval_data)

  # Evaluation finishes when it has evaluated the last epoch.
  if latest_checkpoint.endswith('-{}'.format(train_epochs)):
    break

실제 상황에서 클러스터

참고: 이 섹션은 이 페이지의 튜토리얼 코드를 실행하는 데 필요하지 않습니다.

실제 프로덕션 환경에서는 서로 다른 머신의 서로 다른 프로세스에서 모든 작업을 실행합니다. 각 작업에 대한 클러스터 정보를 구성하는 가장 간단한 방법은 "TF_CONFIG" 환경 변수를 설정하고 tf.distribute.cluster_resolver.TFConfigClusterResolver를 사용하여 "TF_CONFIG"를 구문 분석하는 것입니다.

"TF_CONFIG" 환경 변수에 대한 일반적인 설명은 분산 훈련 가이드의 "TF_CONFIG 환경 변수 설정"을 참조하세요.

Kubernetes 또는 기타 구성 템플릿을 사용하여 훈련 작업을 시작하는 경우 이러한 템플릿이 이미 “TF_CONFIG"를 설정했을 가능성이 높습니다.

"TF_CONFIG" 환경 변수 설정하기

3개의 작업자와 2개의 매개변수 서버가 있다고 가정합니다. 그러면 작업자 1의 "TF_CONFIG"는 다음과 같을 수 있습니다.

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["host1:port", "host2:port", "host3:port"],
        "ps": ["host4:port", "host5:port"],
        "chief": ["host6:port"]
    },
    "task": {"type": "worker", "index": 1}
})

평가자의 "TF_CONFIG"는 다음과 같을 수 있습니다.

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "evaluator": ["host7:port"]
    },
    "task": {"type": "evaluator", "index": 0}
})

위의 평가자에 대한 "TF_CONFIG" 문자열 중 "cluster" 부분은 선택 사항입니다.

모든 작업에 동일한 바이너리를 사용하는 경우

단일 바이너리를 사용하여 이러한 모든 작업을 실행하려면 처음에 프로그램이 여러 역할로 분기되도록 해야 합니다.

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
if cluster_resolver.task_type in ("worker", "ps"):
  # Start a TensorFlow server and wait.
elif cluster_resolver.task_type == "evaluator":
  # Run sidecar evaluation
else:
  # Run the coordinator.

TensorFlow 서버를 시작하고 대기하는 다음 코드는 "worker""ps" 역할에 유용합니다.

# Set the environment variable to allow reporting worker and ps failure to the
# coordinator. This is a workaround and won't be necessary in the future.
os.environ["GRPC_FAIL_FAST"] = "use_caller"

server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or "grpc",
    start=True)
server.join()

작업 오류 처리하기

작업자 오류

tf.distribute.coordinator.ClusterCoordinator 사용자 정의 훈련 루프와 Model.fit 접근 방식은 모두 작업자 오류에 대한 내결함성을 기본적으로 제공합니다. 작업자 복구 시 ClusterCoordinator는 작업자에 대한 데이터세트 재생성을 호출합니다.

매개변수 서버 또는 코디네이터 오류

그러나 코디네이터는 매개변수 서버 오류를 발견하면 즉시 UnavailableError 또는 AbortedError를 발생시킵니다. 이 경우 코디네이터를 다시 시작할 수 있습니다. 코디네이터 자체도 사용할 수 없게 될 수 있습니다. 따라서 훈련 진행 상황을 잃지 않기 위해 특정 도구를 사용하는 것이 좋습니다.

  • Model.fit의 경우 진행률 저장 및 복원을 자동으로 처리하는 BackupAndRestore 콜백을 사용해야 합니다. 예제는 위의 콜백 및 훈련 섹션을 참조하세요.

  • 사용자 정의 훈련 루프의 경우 훈련이 시작되기 전에 주기적으로 모델 변수를 체크포인트하고 체크포인트에서 모델 변수를 로드해야 합니다(있는 경우). 옵티마이저가 체크포인트된 경우 훈련 진행 상황은 optimizer.iterations에서 대략적으로 추론할 수 있습니다.

checkpoint_manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(model=model, optimizer=optimizer),
    checkpoint_dir,
    max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
  checkpoint = checkpoint_manager.checkpoint
  checkpoint.restore(
      checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()

global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch

for _ in range(starting_epoch, num_epochs):
  for _ in range(steps_per_epoch):
    coordinator.schedule(step_fn, args=(per_worker_iterator,))
  coordinator.join()
  checkpoint_manager.save()

RemoteValue 가져오기

함수가 성공적으로 실행되면 RemoteValue 가져오기의 성공이 보장됩니다. 현재는 함수가 실행된 후 반환 값이 즉시 코디네이터에 복사되기 때문입니다. 복사하는 동안 작업자 오류가 발생하면 사용 가능한 다른 작업자에서 함수가 재시도됩니다. 따라서 성능을 최적화하려면 반환 값 없이 함수를 예약할 수 있습니다.

오류 보고

코디네이터가 매개변수 서버의 UnavailableError와 같은 오류 또는 tf.debugging.check_numericsInvalidArgument와 같은 기타 애플리케이션 오류를 확인하면 오류를 발생시키기 전에 보류 중 및 대기열에 있는 모든 함수를 취소합니다. 해당 RemoteValue를 가져오면 CancelledError가 발생합니다.

오류가 발생한 후 코디네이터는 동일한 오류 또는 취소된 함수의 오류를 발생시키지 않습니다.

성능 향상

tf.distribute.ParameterServerStrategytf.distribute.coordinator.ClusterCoordinator로 훈련할 때 성능 문제가 발생할 수 있는 몇 가지 가능한 이유가 있습니다.

한 가지 일반적인 이유는 매개변수 서버의 로드 균형이 맞지 않고 로드가 심한 일부 매개변수 서버가 용량에 도달했기 때문입니다. 또한 여러 근본 원인이 있을 수 있습니다. 다음과 같은 몇 가지 간단한 방법으로 이 문제를 완화할 수 있습니다.

  1. ParameterServerStrategy를 구성할 때 variable_partitioner를 지정하여 큰 모델 변수를 분할합니다.
  2. 가능하면 모든 매개변수 서버에 필요한 핫스팟 변수를 단일 단계로 생성하지 않습니다. 예를 들어, 옵티마이저에서 일정한 학습률 또는 하위 클래스 tf.keras.optimizers.schedules.LearningRateSchedule을 사용합니다. 학습률이 특정 매개변수 서버에 배치되고 각 단계에서 다른 모든 매개변수 서버에서 요청하는 변수가 되는 것이 기본 동작이기 때문입니다.
  3. Keras 전처리 레이어에 전달하기 전에 큰 어휘를 섞습니다.

성능 문제의 또 다른 가능한 이유는 코디네이터에 있습니다. schedule/join의 구현은 Python 기반이므로 스레딩 오버헤드가 있을 수 있습니다. 또한 코디네이터와 작업자 간의 대기 시간이 클 수 있습니다. 이러한 경우라면 다음과 같이 할 수 있습니다.

  • Model.fit의 경우 Model.compile에 제공된 steps_per_execution 인수를 1보다 큰 값으로 설정할 수 있습니다.

  • 사용자 정의 훈련 루프의 경우 여러 단계를 단일 tf.function으로 묶을 수 있습니다.

steps_per_invocation = 10

@tf.function
def step_fn(iterator):
  for _ in range(steps_per_invocation):
    features, labels = next(iterator)
    def replica_fn(features, labels):
      ...

    strategy.run(replica_fn, args=(features, labels))

라이브러리가 더욱 최적화됨에 따라 앞으로 대부분의 사용자가 수동으로 단계를 묶을 필요가 없게 되기를 바랍니다.

또한 성능 향상을 위한 약간의 요령은 위의 작업 오류 처리하기 섹션에서 설명한 대로 반환 값 없이 함수를 예약하는 것입니다.

알려진 제한 사항

알려진 대부분의 제한 사항은 이미 위 섹션에서 다루었습니다. 이 섹션에서는 요약을 제공합니다.

ParameterServerStrategy 일반

  • 내결함성이 제대로 작동하려면 코디네이터를 포함한 모든 작업에 os.environment["grpc_fail_fast"]="use_caller"가 필요합니다.
  • 동기 매개변수 서버 훈련은 지원되지 않습니다.
  • 최적의 성능을 얻으려면 일반적으로 여러 단계를 단일 함수로 압축해야 합니다.
  • 샤딩된 변수를 포함하는 tf.saved_model.load를 통해 saved_model을 로드하는 것은 지원되지 않습니다. TensorFlow Serving을 사용하여 이러한 stored_model을 로드하는 것은 작동할 것으로 예상됩니다(자세한 내용은 제공 튜토리얼 참조).
  • 코디네이터 작업을 다시 시작하지 않고 매개변수 서버 오류에서 복구하는 것은 지원되지 않습니다.
  • tf.keras.layers.IntegerLookup, tf.keras.layers.StringLookuptf.keras.layers.TextVectorization과 같은 일부 Keras 전처리 레이어에서 일반적으로 사용되는 tf.lookup.StaticHashTable의 생성은 Strategy.scope 아래에 배치해야 합니다. 그렇지 않으면 리소스가 코디네이터에 배치되고 작업자에서 코디네이터로의 조회 RPC가 성능에 영향을 미칩니다.

Model.fit 특이 사항

  • steps_per_epoch 인수는 Model.fit에 필요합니다. Epoch에서 적절한 간격을 제공하는 값을 선택할 수 있습니다.
  • ParameterServerStrategy는 성능상의 이유로 배치 수준 호출이 있는 사용자 정의 콜백을 지원하지 않습니다. 이러한 호출을 적절하게 선택된 steps_per_epoch를 이용해 epoch 수준 호출로 변환하여 steps_per_epoch 단계 수마다 호출되도록 해야 합니다. 내장 콜백은 영향을 받지 않습니다(해당 배치 수준 호출이 성능을 발휘하도록 수정되었음). ParameterServerStrategy에 대한 배치 수준 호출을 지원할 계획에 있습니다.
  • 같은 이유로, 다른 전략과 달리 진행률 표시줄과 메트릭은 epoch 경계에서만 기록됩니다.
  • run_eagerly는 지원되지 않습니다.

사용자 정의 훈련 루프 특이 사항