ส่วนประกอบไปป์ไลน์ TFX ของผู้ฝึกสอน

คอมโพเนนต์ไปป์ไลน์ Trainer TFX ฝึกโมเดล TensorFlow

เทรนเนอร์และ TensorFlow

เทรนเนอร์ทำให้การใช้งานที่กว้างขวางของงูหลาม TensorFlow API สำหรับแบบจำลองการฝึกอบรม

ส่วนประกอบ

เทรนเนอร์ใช้เวลา:

  • tf.ตัวอย่างที่ใช้ในการฝึกอบรมและประเมินผล
  • ผู้ใช้จัดเตรียมไฟล์โมดูลที่กำหนดตรรกะของผู้ฝึกสอน
  • protobuf นิยามของ args รถไฟและ args EVAL
  • (ไม่บังคับ) สคีมาข้อมูลที่สร้างโดยคอมโพเนนต์ไปป์ไลน์ SchemaGen และผู้พัฒนาสามารถเลือกแก้ไขได้
  • (ไม่บังคับ) กราฟการแปลงที่สร้างโดยองค์ประกอบการแปลงต้นน้ำ
  • (ไม่บังคับ) รุ่นก่อนการฝึกอบรมที่ใช้สำหรับสถานการณ์ต่างๆ เช่น warmstart
  • (ไม่บังคับ) ไฮเปอร์พารามิเตอร์ ซึ่งจะถูกส่งไปยังฟังก์ชันโมดูลผู้ใช้ รายละเอียดของการทำงานร่วมกับจูนเนอร์ที่สามารถพบได้ ที่นี่

Trainer ปล่อย: อย่างน้อยหนึ่งโมเดลสำหรับการอนุมาน/การแสดงผล (โดยทั่วไปใน SavedModelFormat) และอีกโมเดลหนึ่งสำหรับ eval (โดยทั่วไปคือ EvalSavedModel)

เราให้การสนับสนุนสำหรับรูปแบบทางเลือกรูปแบบเช่น TFLite ผ่าน การเขียนใหม่ห้องสมุดรุ่น ดูลิงก์ไปยัง Model Rewrite Library สำหรับตัวอย่างวิธีการแปลงทั้งโมเดล Estimator และ Keras

ผู้ฝึกสอนทั่วไป

ผู้ฝึกสอนทั่วไปช่วยให้นักพัฒนาใช้ API รุ่น TensorFlow ใดๆ กับองค์ประกอบ Trainer นอกจาก TensorFlow Estimators แล้ว นักพัฒนายังสามารถใช้โมเดล Keras หรือลูปการฝึกอบรมแบบกำหนดเองได้อีกด้วย สำหรับรายละเอียดโปรดดู RFC ฝึกทั่วไป

การกำหนดค่าส่วนประกอบเทรนเนอร์

รหัส DSL ไปป์ไลน์ทั่วไปสำหรับ Trainer ทั่วไปจะมีลักษณะดังนี้:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

เทรนเนอร์เรียกโมดูลการฝึกอบรมซึ่งระบุไว้ใน module_file พารามิเตอร์ Instead of trainer_fn , a run_fn is required in the module file if the GenericExecutor is specified in the custom_executor_spec . The trainer_fn was responsible for creating the model. นอกจากนั้น run_fn ยังต้องการที่จะจัดการส่วนการฝึกอบรมและการส่งออกรูปแบบการฝึกอบรมไปยังสถานที่ที่ต้องการได้รับจาก FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

นี่คือ ไฟล์โมดูลตัวอย่าง กับ run_fn

โปรดทราบว่าหากไม่ได้ใช้องค์ประกอบ Transform ในไปป์ไลน์ Trainer จะดึงตัวอย่างจาก ExampleGen โดยตรง:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

More details are available in the Trainer API reference .