Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

Trainer TFX Ardışık Düzeni Bileşeni

Trainer TFX işlem hattı bileşeni bir TensorFlow modelini eğitir.

Eğitmen ve TensorFlow

Trainer, eğitim modelleri için Python TensorFlow API'sinden kapsamlı bir şekilde yararlanır.

Bileşen

Eğitmen şunları alır:

  • tf.Eğitim ve değerlendirme için kullanılan örnekler.
  • Eğitmen mantığını tanımlayan, kullanıcı tarafından sağlanan bir modül dosyası.
  • Bir SchemaGen işlem hattı bileşeni tarafından oluşturulan ve isteğe bağlı olarak geliştirici tarafından değiştirilen bir veri şeması.
  • Tren parametrelerinin protobuf tanımı ve argümanlarının değerlendirilmesi.
  • (İsteğe bağlı) yukarı akış Transform bileşeni tarafından üretilen dönüşüm grafiği.
  • (İsteğe bağlı) sıcak başlangıç ​​gibi senaryolar için kullanılan önceden eğitilmiş modeller.
  • (İsteğe bağlı) kullanıcı modülü işlevine aktarılacak olan hiperparametreler. Tuner ile entegrasyonun ayrıntıları burada bulunabilir.

Eğitmen yayar: Çıkarım / sunum için en az bir model (tipik olarak SavedModelFormat'ta) ve isteğe bağlı olarak değerlendirme için başka bir model (tipik olarak bir EvalSavedModel).

Model Yeniden Yazma Kitaplığı aracılığıyla TFLite gibi alternatif model biçimleri için destek sağlıyoruz. Hem Estimator hem de Keras modellerinin nasıl dönüştürüleceğine ilişkin örnekler için Model Yeniden Yazma Kitaplığı bağlantısına bakın.

Tahminciye dayalı Eğitmen

TFX ve Trainer ile Estimator tabanlı bir model kullanma hakkında bilgi edinmek için, bkz. Tf.Estimator for TFX ile TensorFlow modelleme kodu tasarlama .

Bir Trainer Bileşenini Yapılandırma

Tipik ardışık düzen Python DSL kodu şuna benzer:

from tfx.components import Trainer

...

trainer = Trainer(
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      base_model=latest_model_resolver.outputs['latest_model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Eğitmen, module_file parametresinde belirtilen bir eğitim modülünü çağırır. Tipik bir eğitim modülü şuna benzer:

# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  warm_start_from = trainer_fn_args.base_model[
      0] if trainer_fn_args.base_model else None

  estimator = _build_estimator(
      # Construct layers sizes with exponetial decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }

Genel Eğitmen

Genel eğitici, geliştiricilerin herhangi bir TensorFlow model API'sini Trainer bileşeniyle kullanmasına olanak tanır. TensorFlow Estimators'a ek olarak, geliştiriciler Keras modellerini veya özel eğitim döngülerini kullanabilir. Ayrıntılar için lütfen genel eğitmen için RFC'ye bakın.

GenericExecutor'u kullanmak için Trainer Bileşenini Yapılandırma

Genel Eğitmen için tipik iletişim hattı DSL kodu şöyle görünür:

from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor

...

trainer = Trainer(
    module_file=module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Eğitmen, module_file parametresinde belirtilen bir eğitim modülünü çağırır. trainer_fn yerine, eğer GenericExecutor custom_executor_spec belirtilmişse modül dosyasında bir run_fn gereklidir.

Dönüşüm bileşeni ardışık düzen içinde kullanılmazsa, Eğitmen doğrudan ExampleGen'deki örnekleri alır:

trainer = Trainer(
    module_file=module_file,
    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

İşte run_fn ile örnek bir modül dosyası .