SIG TFX-Addonsコミュニティに参加して、TFXをさらに改善してください!
このページは Cloud Translation API によって翻訳されました。
Switch to English

トレーナーTFXパイプラインコンポーネント

Trainer TFXパイプラインコンポーネントは、TensorFlowモデルをトレーニングします。

トレーナーとTensorFlow

トレーナーは、モデルのトレーニングにPython TensorFlowAPIを幅広く利用しています。

成分

トレーナーが取る:

  • tf。トレーニングと評価に使用される例。
  • トレーナーロジックを定義するユーザー提供のモジュールファイル。
  • trainargsとevalargsのProtobuf定義。
  • (オプション)SchemaGenパイプラインコンポーネントによって作成され、オプションで開発者によって変更されたデータスキーマ。
  • (オプション)アップストリームのTransformコンポーネントによって生成された変換グラフ。
  • (オプション)ウォームスタートなどのシナリオに使用される事前トレーニング済みモデル。
  • (オプション)ハイパーパラメータ。ユーザーモジュール関数に渡されます。チューナーとの統合の詳細については、こちらをご覧ください

トレーナーが発行するもの:推論/提供用の少なくとも1つのモデル(通常はSavedModelFormat)と、オプションでeval用の別のモデル(通常はEvalSavedModel)。

モデル書き換えライブラリを通じて、 TFLiteなどの代替モデル形式のサポートを提供します。 EstimatorモデルとKerasモデルの両方を変換する方法の例については、モデル書き換えライブラリへのリンクを参照してください。

ジェネリックトレーナー

Generic Trainerを使用すると、開発者はTensorFlowモデルAPIをTrainerコンポーネントで使用できます。 TensorFlow Estimatorに加えて、開発者はKerasモデルまたはカスタムトレーニングループを使用できます。詳細については、 ジェネリックトレーナーRFCを参照してください。

トレーナーコンポーネントの構成

一般的なトレーナーの一般的なパイプラインDSLコードは次のようになります。

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

トレーナーは、 module_fileパラメーターで指定されたトレーニングモジュールを呼び出します。代わりにtrainer_fnrun_fnあればモジュールファイルに必要とされるGenericExecutorで指定されcustom_executor_spectrainer_fnはモデルの作成を担当しました。それに加えて、 run_fnはトレーニング部分を処理し、トレーニングされたモデルをFnArgsによって指定された目的の場所に出力する必要もあります。

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

run_fnrun_fnモジュールファイルの例を次に示します

Transformコンポーネントがパイプラインで使用されていない場合、トレーナーはExampleGenから直接例を取得することに注意してください。

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Estimatorベースのトレーナー(非推奨)

TFXおよびTrainerでのEstimatorベースのモデルの使用については、TFX用のtf.Estimatorを使用したTensorFlowモデリングコードの設計を参照してください。

Estimatorベースのエグゼキューターを使用するためのトレーナーコンポーネントの構成

典型的なパイプラインPythonDSLコードは次のようになります。

from tfx.components import Trainer
from tfx.components.trainer.executor import Executor
from tfx.dsl.components.base import executor_spec

...

trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      base_model=latest_model_resolver.outputs['latest_model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

トレーナーは、 module_fileパラメーターで指定されたトレーニングモジュールを呼び出します。一般的なトレーニングモジュールは次のようになります。

# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  warm_start_from = trainer_fn_args.base_model[
      0] if trainer_fn_args.base_model else None

  estimator = _build_estimator(
      # Construct layers sizes with exponetial decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }