トレーナーTFXパイプラインコンポーネント

Trainer TFXパイプラインコンポーネントは、TensorFlowモデルをトレーニングします。

トレーナーとTensorFlow

トレーナーは、Pythonの広範な使用可能TensorFlowのトレーニングモデルのAPIを。

成分

トレーナーが取る:

  • tf。トレーニングと評価に使用される例。
  • トレーナーロジックを定義するユーザー提供のモジュールファイル。
  • いるProtobuf列車の引数とevalの引数の定義。
  • (オプション)SchemaGenパイプラインコンポーネントによって作成され、オプションで開発者によって変更されたデータスキーマ。
  • (オプション)アップストリームのTransformコンポーネントによって生成された変換グラフ。
  • (オプション)ウォームスタートなどのシナリオに使用される事前トレーニング済みモデル。
  • (オプション)ハイパーパラメータ。ユーザーモジュール関数に渡されます。チューナーとの統合の詳細は見つけることができるここに

トレーナーが発行するもの:推論/サービング用の少なくとも1つのモデル(通常はSavedModelFormat)と、オプションでeval用の別のモデル(通常はEvalSavedModel)。

私たちは、次のような代替モデル形式のサポートを提供TFLite通じモデル書き換えライブラリを。 EstimatorモデルとKerasモデルの両方を変換する方法の例については、モデル書き換えライブラリへのリンクを参照してください。

ジェネリックトレーナー

Generic Trainerを使用すると、開発者はTensorFlowモデルAPIをTrainerコンポーネントで使用できます。 TensorFlow Estimatorに加えて、開発者はKerasモデルまたはカスタムトレーニングループを使用できます。詳細については、以下を参照してください一般的なトレーナーのためのRFCを

トレーナーコンポーネントの構成

一般的なトレーナーの一般的なパイプラインDSLコードは次のようになります。

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

トレーナーに指定されているトレーニングモジュール、呼び出しmodule_fileパラメータを。代わりにtrainer_fnrun_fnあればモジュールファイルに必要とされるGenericExecutorで指定されcustom_executor_spectrainer_fnモデルの作成を担当していました。それに加えて、 run_fnによっても与えられ、所望の位置に訓練モデルトレーニング部と出力を処理する必要FnArgs

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

ここで例のモジュールファイルrun_fn

Transformコンポーネントがパイプラインで使用されていない場合、トレーナーはExampleGenから直接例を取得することに注意してください。

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

詳細はで利用可能なトレーナーのAPIリファレンス