คอมโพเนนต์ไปป์ไลน์ Trainer TFX ฝึกโมเดล TensorFlow
เทรนเนอร์และ TensorFlow
เทรนเนอร์ทำให้การใช้งานที่กว้างขวางของงูหลาม TensorFlow API สำหรับแบบจำลองการฝึกอบรม
ส่วนประกอบ
เทรนเนอร์ใช้เวลา:
- tf.ตัวอย่างที่ใช้ในการฝึกอบรมและประเมินผล
- ผู้ใช้จัดเตรียมไฟล์โมดูลที่กำหนดตรรกะของผู้ฝึกสอน
- protobuf นิยามของ args รถไฟและ args EVAL
- (ไม่บังคับ) สคีมาข้อมูลที่สร้างโดยคอมโพเนนต์ไปป์ไลน์ SchemaGen และผู้พัฒนาสามารถเลือกแก้ไขได้
- (ไม่บังคับ) กราฟการแปลงที่สร้างโดยองค์ประกอบการแปลงต้นน้ำ
- (ไม่บังคับ) รุ่นก่อนการฝึกอบรมที่ใช้สำหรับสถานการณ์ต่างๆ เช่น warmstart
- (ไม่บังคับ) ไฮเปอร์พารามิเตอร์ ซึ่งจะถูกส่งไปยังฟังก์ชันโมดูลผู้ใช้ รายละเอียดของการทำงานร่วมกับจูนเนอร์ที่สามารถพบได้ ที่นี่
Trainer ปล่อย: อย่างน้อยหนึ่งโมเดลสำหรับการอนุมาน/การแสดงผล (โดยทั่วไปใน SavedModelFormat) และอีกโมเดลหนึ่งสำหรับ eval (โดยทั่วไปคือ EvalSavedModel)
เราให้การสนับสนุนสำหรับรูปแบบทางเลือกรูปแบบเช่น TFLite ผ่าน การเขียนใหม่ห้องสมุดรุ่น ดูลิงก์ไปยัง Model Rewrite Library สำหรับตัวอย่างวิธีการแปลงทั้งโมเดล Estimator และ Keras
ผู้ฝึกสอนทั่วไป
ผู้ฝึกสอนทั่วไปช่วยให้นักพัฒนาใช้ API รุ่น TensorFlow ใดๆ กับองค์ประกอบ Trainer นอกจาก TensorFlow Estimators แล้ว นักพัฒนายังสามารถใช้โมเดล Keras หรือลูปการฝึกอบรมแบบกำหนดเองได้อีกด้วย สำหรับรายละเอียดโปรดดู RFC ฝึกทั่วไป
การกำหนดค่าส่วนประกอบเทรนเนอร์
รหัส DSL ไปป์ไลน์ทั่วไปสำหรับ Trainer ทั่วไปจะมีลักษณะดังนี้:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
เทรนเนอร์เรียกโมดูลการฝึกอบรมซึ่งระบุไว้ใน module_file
พารามิเตอร์ Instead of trainer_fn
, a run_fn
is required in the module file if the GenericExecutor
is specified in the custom_executor_spec
. The trainer_fn
was responsible for creating the model. นอกจากนั้น run_fn
ยังต้องการที่จะจัดการส่วนการฝึกอบรมและการส่งออกรูปแบบการฝึกอบรมไปยังสถานที่ที่ต้องการได้รับจาก FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
นี่คือ ไฟล์โมดูลตัวอย่าง กับ run_fn
โปรดทราบว่าหากไม่ได้ใช้องค์ประกอบ Transform ในไปป์ไลน์ Trainer จะดึงตัวอย่างจาก ExampleGen โดยตรง:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
More details are available in the Trainer API reference .