Trainer TFX 파이프라인 구성 요소는 TensorFlow 모델을 훈련합니다.
트레이너와 TensorFlow
트레이너는 파이썬의 광범위한 사용하게 TensorFlow의 교육 모델에 대한 API를.
요소
트레이너는 다음을 수행합니다.
- 훈련 및 평가에 사용되는 tf.Examples.
- 트레이너 로직을 정의하는 사용자 제공 모듈 파일입니다.
- Protobuf 기차 인수 및 평가 인수의 정의.
- (선택 사항) SchemaGen 파이프라인 구성 요소에 의해 생성되고 개발자가 선택적으로 변경한 데이터 스키마입니다.
- (선택 사항) 업스트림 Transform 구성 요소에서 생성된 변환 그래프.
- (선택 사항) 웜스타트와 같은 시나리오에 사용되는 사전 훈련된 모델.
- (선택 사항) 사용자 모듈 함수에 전달될 하이퍼파라미터. 튜너와 통합의 세부 사항은 찾을 수 있습니다 여기에 .
트레이너는 다음을 내보냅니다. 추론/서빙을 위한 하나 이상의 모델(일반적으로 SavedModelFormat) 및 선택적으로 eval용 다른 모델(일반적으로 EvalSavedModel).
We provide support for alternate model formats such as TFLite through the Model Rewriting Library . Estimator 및 Keras 모델을 모두 변환하는 방법에 대한 예제는 Model Rewriting Library 링크를 참조하십시오.
일반 트레이너
일반 트레이너를 사용하면 개발자가 Trainer 구성 요소와 함께 모든 TensorFlow 모델 API를 사용할 수 있습니다. TensorFlow Estimators 외에도 개발자는 Keras 모델 또는 사용자 지정 교육 루프를 사용할 수 있습니다. 자세한 내용은 참조하십시오 일반적인 트레이너에 대한 RFC를 .
트레이너 구성 요소 구성
일반 Trainer의 일반적인 파이프라인 DSL 코드는 다음과 같습니다.
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer invokes a training module, which is specified in the module_file
parameter. 대신에 trainer_fn
하는 run_fn
경우 생성 모듈 파일에 필요한 GenericExecutor
에 지정된되어 custom_executor_spec
. trainer_fn
모델을 만들기위한 책임. 그뿐만 아니라, run_fn
또한 교육 부분과 출력에 의해 주어진 원하는 위치로 훈련 모델을 처리해야 FnArgs를 :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
여기이다 예 모듈 파일 과 run_fn
.
Transform 구성 요소가 파이프라인에서 사용되지 않는 경우 Trainer는 ExampleGen에서 직접 예제를 가져옵니다.
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
자세한 내용은에서 사용할 수있는 트레이너 API 참조 .