SIG TFX-Addonsコミュニティに参加して、TFXをさらに改善するのを手伝ってください! SIGTFXに参加-アドオン

トレーナー TFX パイプライン コンポーネント

Trainer TFX パイプライン コンポーネントは、TensorFlow モデルをトレーニングします。

トレーナーと TensorFlow

Trainer は、モデルのトレーニングに Python TensorFlow API を広範囲に使用します。

成分

トレーナーが取る:

  • tf.トレーニングと評価に使用される例。
  • トレーナー ロジックを定義するユーザー提供のモジュール ファイル。
  • train args と eval args のProtobuf定義。
  • (オプション) SchemaGen パイプライン コンポーネントによって作成され、開発者によってオプションで変更されるデータ スキーマ。
  • (オプション) 上流の Transform コンポーネントによって生成された変換グラフ。
  • (オプション) ウォームスタートなどのシナリオに使用される事前トレーニング済みモデル。
  • (オプション) ユーザーモジュール関数に渡されるハイパーパラメータ。 Tuner との統合の詳細については、こちらを参照してください

トレーナーが出力する: 少なくとも 1 つの推論/サービング (通常は SavedModelFormat 内) 用のモデル、およびオプションで eval 用の別のモデル (通常は EvalSavedModel)。

モデル書き換えライブラリを通じて、 TFLiteなどの代替モデル形式のサポートを提供します。 Estimator モデルと Keras モデルの両方を変換する方法の例については、モデル書き換えライブラリへのリンクを参照してください。

ジェネリックトレーナー

汎用トレーナーにより、開発者は Trainer コンポーネントで任意の TensorFlow モデル API を使用できます。 TensorFlow Estimator に加えて、開発者は Keras モデルまたはカスタム トレーニング ループを使用できます。詳細については、 汎用トレーナーRFCを参照してください。

トレーナー コンポーネントの構成

一般的な Trainer の典型的なパイプライン DSL コードは次のようになります。

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

トレーナーは、 module_fileパラメータで指定されたトレーニング モジュールを呼び出します。代わりにtrainer_fnrun_fnあればモジュールファイルに必要とされるGenericExecutorで指定されcustom_executor_spectrainer_fnはモデルの作成を担当しました。それに加えて、 run_fnはトレーニング部分を処理し、トレーニング済みモデルをFnArgsによって指定された目的の場所に出力する必要もあります。

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

run_fn使用したモジュール ファイルの例を次に示します

パイプラインで変換コンポーネントが使用されていない場合、トレーナーは ExampleGen から直接例を取得することに注意してください。

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

詳細については、 Trainer API リファレンスをご覧ください。