SIG TFX-Addons 커뮤니티에 가입하고 TFX를 더욱 향상시키는 데 도움을 주세요! SIG TFX 애드온 가입

Trainer TFX 파이프 라인 구성 요소

Trainer TFX 파이프 라인 구성 요소는 TensorFlow 모델을 학습시킵니다.

트레이너와 TensorFlow

Trainer는 모델 학습을 위해 Python TensorFlow API를 광범위하게 사용합니다.

구성 요소

트레이너는 다음을 수행합니다.

  • tf. 훈련 및 평가에 사용되는 예제.
  • 트레이너 로직을 정의하는 사용자 제공 모듈 파일입니다.
  • train args 및 eval args의 Protobuf 정의.
  • (선택 사항) SchemaGen 파이프 라인 구성 요소에 의해 생성되고 개발자가 선택적으로 변경하는 데이터 스키마입니다.
  • (선택 사항) 업스트림 변환 구성 요소에서 생성 한 변환 그래프.
  • (선택 사항) 웜 스타트와 같은 시나리오에 사용되는 사전 학습 된 모델.
  • (선택 사항) 사용자 모듈 함수에 전달 될 하이퍼 파라미터. Tuner와의 통합에 대한 자세한 내용은 여기 에서 찾을 수 있습니다 .

트레이너가 내 보냅니다 : 추론 / 제공을위한 하나 이상의 모델 (일반적으로 SavedModelFormat에 있음) 및 선택적으로 평가를위한 다른 모델 (일반적으로 EvalSavedModel).

모델 재 작성 라이브러리를 통해 TFLite 와 같은 대체 모델 형식에 대한 지원을 제공합니다. Estimator 및 Keras 모델을 모두 변환하는 방법에 대한 예제는 모델 재 작성 라이브러리 링크를 참조하십시오.

일반 트레이너

Generic trainer를 통해 개발자는 Trainer 구성 요소와 함께 모든 TensorFlow 모델 API를 사용할 수 있습니다. TensorFlow Estimator 외에도 개발자는 Keras 모델 또는 맞춤 학습 루프를 사용할 수 있습니다. 자세한 내용 은 일반 트레이너를위한 RFC 를 참조하세요.

Trainer 구성 요소 구성

일반 Trainer의 일반적인 파이프 라인 DSL 코드는 다음과 같습니다.

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer는 module_file 매개 변수에 지정된 교육 모듈을 호출합니다. GenericExecutorcustom_executor_spec 지정된 경우 trainer_fn 대신 모듈 파일에 run_fn 이 필요합니다. trainer_fn 은 모델 생성을 담당했습니다. 또한 run_fn 은 훈련 부분을 처리하고 훈련 된 모델을 FnArgs에서 지정한 원하는 위치로 출력해야합니다.

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

다음은 run_fn 사용한예제 모듈 파일 입니다.

Transform 구성 요소가 파이프 라인에서 사용되지 않는 경우 Trainer는 ExampleGen에서 직접 예제를 가져옵니다.

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

자세한 내용은 Trainer API 참조 에서 확인할 수 있습니다.