트레이너 TFX 파이프라인 구성요소

Trainer TFX 파이프라인 구성요소는 TensorFlow 모델을 학습시킵니다.

트레이너와 TensorFlow

Trainer는 모델 학습을 위해 Python TensorFlow API를 광범위하게 사용합니다.

요소

트레이너는 다음을 수행합니다.

  • tf.훈련 및 평가에 사용되는 예제입니다.
  • 트레이너 로직을 정의하는 사용자 제공 모듈 파일입니다.
  • train args 및 eval args의 Protobuf 정의입니다.
  • (선택 사항) SchemaGen 파이프라인 구성 요소에 의해 생성되고 선택적으로 개발자에 의해 변경되는 데이터 스키마입니다.
  • (선택 사항) 업스트림 변환 구성 요소에서 생성된 변환 그래프입니다.
  • (선택 사항) 웜스타트와 같은 시나리오에 사용되는 사전 학습된 모델입니다.
  • (선택 사항) 사용자 모듈 함수에 전달되는 하이퍼파라미터. Tuner와의 통합에 대한 자세한 내용은 여기에서 확인할 수 있습니다.

Trainer는 추론/제공을 위한 하나 이상의 모델(일반적으로 SavedModelFormat에 있음)과 선택적으로 평가를 위한 다른 모델(일반적으로 EvalSavedModel)을 내보냅니다.

우리는 모델 재작성 라이브러리를 통해 TFLite 와 같은 대체 모델 형식을 지원합니다. Estimator와 Keras 모델을 모두 변환하는 방법에 대한 예는 Model Rewriting Library 링크를 참조하세요.

일반 트레이너

일반 트레이너를 사용하면 개발자가 Trainer 구성요소와 함께 모든 TensorFlow 모델 API를 사용할 수 있습니다. TensorFlow Estimator 외에도 개발자는 Keras 모델 또는 사용자 정의 훈련 루프를 사용할 수 있습니다. 자세한 내용은 일반 트레이너에 대한 RFC를 참조하세요.

트레이너 구성요소 구성

일반 Trainer의 일반적인 파이프라인 DSL 코드는 다음과 같습니다.

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer는 module_file 매개변수에 지정된 교육 모듈을 호출합니다. custom_executor_specGenericExecutor 지정된 경우 trainer_fn 대신 run_fn 이 모듈 파일에 필요합니다. trainer_fn 이 모델 생성을 담당했습니다. 그 외에도 run_fn 훈련 부분을 처리하고 훈련된 모델을 FnArgs 에서 제공하는 원하는 위치로 출력해야 합니다.

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

다음은 run_fn 이 포함된 예제 모듈 파일 입니다.

Transform 구성 요소가 파이프라인에서 사용되지 않으면 Trainer는 exampleGen에서 직접 예제를 가져옵니다.

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

자세한 내용은 Trainer API 참조 에서 확인할 수 있습니다.