Trainer TFXパイプラインコンポーネントは、TensorFlowモデルをトレーニングします。
トレーナーとTensorFlow
トレーナーは、Pythonの広範な使用可能TensorFlowのトレーニングモデルのAPIを。
成分
トレーナーが取る:
- tf。トレーニングと評価に使用される例。
- トレーナーロジックを定義するユーザー提供のモジュールファイル。
- いるProtobuf列車の引数とevalの引数の定義。
- (オプション)SchemaGenパイプラインコンポーネントによって作成され、オプションで開発者によって変更されたデータスキーマ。
- (オプション)アップストリームのTransformコンポーネントによって生成された変換グラフ。
- (オプション)ウォームスタートなどのシナリオに使用される事前トレーニング済みモデル。
- (オプション)ハイパーパラメータ。ユーザーモジュール関数に渡されます。チューナーとの統合の詳細は見つけることができるここに。
トレーナーが発行するもの:推論/サービング用の少なくとも1つのモデル(通常はSavedModelFormat)と、オプションでeval用の別のモデル(通常はEvalSavedModel)。
私たちは、次のような代替モデル形式のサポートを提供TFLite通じモデル書き換えライブラリを。 EstimatorモデルとKerasモデルの両方を変換する方法の例については、モデル書き換えライブラリへのリンクを参照してください。
ジェネリックトレーナー
Generic Trainerを使用すると、開発者はTensorFlowモデルAPIをTrainerコンポーネントで使用できます。 TensorFlow Estimatorに加えて、開発者はKerasモデルまたはカスタムトレーニングループを使用できます。詳細については、以下を参照してください一般的なトレーナーのためのRFCを。
トレーナーコンポーネントの構成
一般的なトレーナーの一般的なパイプラインDSLコードは次のようになります。
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
トレーナーに指定されているトレーニングモジュール、呼び出しmodule_file
パラメータを。代わりにtrainer_fn
、 run_fn
あればモジュールファイルに必要とされるGenericExecutor
で指定されcustom_executor_spec
。 trainer_fn
モデルの作成を担当していました。それに加えて、 run_fn
によっても与えられ、所望の位置に訓練モデルトレーニング部と出力を処理する必要FnArgs 。
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
ここで例のモジュールファイルとrun_fn
。
Transformコンポーネントがパイプラインで使用されていない場合、トレーナーはExampleGenから直接例を取得することに注意してください。
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
詳細はで利用可能なトレーナーのAPIリファレンス。