جزء Trainer TFX Pipeline

جزء خط لوله Trainer TFX یک مدل TensorFlow را آموزش می دهد.

ترینر و تنسورفلو

ترینر به طور گسترده از Python TensorFlow API برای مدل های آموزشی استفاده می کند.

جزء

مربی می گیرد:

  • tf. مثال هایی که برای آموزش و ارزیابی استفاده می شوند.
  • یک فایل ماژول ارائه شده توسط کاربر که منطق مربی را تعریف می کند.
  • تعریف Protobuf از قطار ارگ و ارگ eval.
  • (اختیاری) یک طرح داده ایجاد شده توسط یک جزء خط لوله SchemaGen و به صورت اختیاری توسط توسعه دهنده تغییر می کند.
  • (اختیاری) تبدیل گراف تولید شده توسط یک جزء Transform بالادست.
  • (اختیاری) مدل های از پیش آموزش دیده مورد استفاده برای سناریوهایی مانند شروع گرما.
  • هایپرپارامترهای (اختیاری)، که به تابع ماژول کاربر ارسال می شود. جزئیات ادغام با Tuner را می توانید در اینجا پیدا کنید.

Trainer منتشر می کند: حداقل یک مدل برای استنتاج/خدمت (معمولا در SavedModelFormat) و به صورت اختیاری مدل دیگری برای eval (معمولا EvalSavedModel).

ما از قالب‌های مدل جایگزین مانند TFLite از طریق کتابخانه بازنویسی مدل پشتیبانی می‌کنیم. برای نمونه‌هایی از نحوه تبدیل هر دو مدل تخمین‌گر و کراس، پیوند کتابخانه بازنویسی مدل را ببینید.

مربی عمومی

Generic trainer توسعه دهندگان را قادر می سازد تا از هر API مدل TensorFlow با جزء Trainer استفاده کنند. علاوه بر برآوردگرهای TensorFlow، توسعه‌دهندگان می‌توانند از مدل‌های Keras یا حلقه‌های آموزشی سفارشی استفاده کنند. برای جزئیات، لطفاً RFC برای مربی عمومی را ببینید.

پیکربندی کامپوننت Trainer

کد DSL خط لوله معمولی برای Trainer عمومی به شکل زیر است:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer یک ماژول آموزشی را فراخوانی می کند که در پارامتر module_file مشخص شده است. اگر GenericExecutor در custom_executor_spec مشخص شده باشد، به جای trainer_fn ، یک run_fn در فایل ماژول مورد نیاز است. trainer_fn مسئول ایجاد مدل بود. علاوه بر آن، run_fn همچنین باید بخش آموزشی را مدیریت کند و مدل آموزش‌دیده شده را در محل مورد نظر ارائه شده توسط FnArgs خروجی دهد:

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

در اینجا یک نمونه فایل ماژول با run_fn آمده است.

توجه داشته باشید که اگر کامپوننت Transform در خط لوله استفاده نشود، Trainer مثال‌ها را مستقیماً از ExampleGen می‌گیرد:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

جزئیات بیشتر در مرجع Trainer API موجود است.