คำถามเกี่ยวกับ TFX? เข้าร่วมกับเราที่ Google I / O!

ส่วนประกอบท่อเทรนเนอร์ TFX

ส่วนประกอบไปป์ไลน์ Trainer TFX ฝึกโมเดล TensorFlow

เทรนเนอร์และ TensorFlow

เทรนเนอร์ใช้ประโยชน์จาก Python TensorFlow API อย่างกว้างขวางสำหรับโมเดลการฝึกอบรม

ส่วนประกอบ

เทรนเนอร์ใช้เวลา:

  • tf ตัวอย่างที่ใช้สำหรับการฝึกอบรมและการประเมิน
  • ผู้ใช้จัดเตรียมไฟล์โมดูลที่กำหนดตรรกะของเทรนเนอร์
  • คำจำกัดความของ Protobuf ของ args train และ eval args
  • (ไม่บังคับ) สคีมาข้อมูลที่สร้างโดยคอมโพเนนต์ไปป์ไลน์ของ SchemaGen และผู้พัฒนาอาจเปลี่ยนแปลงได้
  • (ไม่บังคับ) กราฟการแปลงที่สร้างโดยส่วนประกอบการแปลงต้นน้ำ
  • (ไม่บังคับ) แบบจำลองที่ได้รับการฝึกฝนมาแล้วซึ่งใช้สำหรับสถานการณ์เช่นการเริ่มต้นในช่วงวอร์มอัพ
  • (ไม่บังคับ) ไฮเปอร์พารามิเตอร์ซึ่งจะส่งผ่านไปยังฟังก์ชันโมดูลผู้ใช้ สามารถดูรายละเอียดการทำงานร่วมกับ Tuner ได้ ที่นี่

เทรนเนอร์ส่งเสียง: อย่างน้อยหนึ่งโมเดลสำหรับการอนุมาน / การแสดงผล (โดยทั่วไปใน SavedModelFormat) และอีกโมเดลหนึ่งสำหรับ eval (โดยทั่วไปคือ EvalSavedModel)

เราให้การสนับสนุนรูปแบบโมเดลทางเลือกเช่น TFLite ผ่าน ไลบรารีการเขียนแบบจำลอง ดูลิงก์ไปยัง Model Rewriting Library สำหรับตัวอย่างวิธีการแปลงทั้งแบบจำลอง Estimator และ Keras

เทรนเนอร์ทั่วไป

เทรนเนอร์ทั่วไปช่วยให้นักพัฒนาสามารถใช้ TensorFlow model API กับส่วนประกอบ Trainer ได้ นอกจาก TensorFlow Estimators แล้วนักพัฒนายังสามารถใช้โมเดล Keras หรือลูปการฝึกอบรมที่กำหนดเองได้ สำหรับรายละเอียดโปรดดู RFC สำหรับเทรนเนอร์ทั่วไป

การกำหนดคอนฟิกส่วนประกอบเทรนเนอร์

รหัส DSL ไปป์ไลน์ทั่วไปสำหรับ Trainer ทั่วไปจะมีลักษณะดังนี้:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

เทรนเนอร์เรียกใช้โมดูลการฝึกอบรมซึ่งระบุไว้ในพารามิเตอร์ module_file แทนที่จะ trainer_fn เป็น run_fn เป็นสิ่งจำเป็นในไฟล์โมดูลถ้า GenericExecutor ระบุไว้ใน custom_executor_spec trainer_fn เป็นผู้รับผิดชอบในการสร้างแบบจำลอง นอกจากนั้น run_fn ยังต้องจัดการส่วนการฝึกอบรมและส่งออกโมเดลที่ฝึกแล้วไปยังตำแหน่งที่ต้องการโดย FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

นี่คือ ไฟล์โมดูลตัวอย่างที่ มี run_fn

โปรดทราบว่าหากไม่ได้ใช้ส่วนประกอบ Transform ในไปป์ไลน์ Trainer จะนำตัวอย่างจาก ExampleGen โดยตรง:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

ผู้ฝึกสอนตามเครื่องมือประมาณการ (เลิกใช้งานแล้ว)

หากต้องการเรียนรู้เกี่ยวกับการใช้แบบจำลองตาม เครื่องมือประมาณการ กับ TFX และ Trainer โปรดดู การออกแบบโค้ดการสร้างแบบจำลอง TensorFlow ด้วย tf.Estimator สำหรับ TFX

การกำหนดคอนฟิก Trainer Component เพื่อใช้ Estimator based Executor

รหัส Python DSL ไปป์ไลน์ทั่วไปมีลักษณะดังนี้:

from tfx.components import Trainer
from tfx.components.trainer.executor import Executor
from tfx.dsl.components.base import executor_spec

...

trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      base_model=latest_model_resolver.outputs['latest_model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

เทรนเนอร์เรียกใช้โมดูลการฝึกอบรมซึ่งระบุไว้ในพารามิเตอร์ module_file โมดูลการฝึกอบรมทั่วไปมีลักษณะดังนี้:

# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  warm_start_from = trainer_fn_args.base_model[
      0] if trainer_fn_args.base_model else None

  estimator = _build_estimator(
      # Construct layers sizes with exponetial decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }