به انجمن SIG TFX-Addons بپیوندید و به پیشرفت TFX کمک کنید!
این صفحه به‌وسیله ‏Cloud Translation API‏ ترجمه شده است.
Switch to English

جزer خط لوله Trainer TFX

جز component خط لوله Trainer TFX یک مدل TensorFlow را آموزش می دهد.

مربی و TensorFlow

Trainer از Python TensorFlow API به طور گسترده برای مدل های آموزشی استفاده می کند.

مولفه

مربی:

  • tf. نمونه هایی که برای آموزش و ارزیابی استفاده می شود.
  • یک کاربر پرونده ماژولی را ارائه می دهد که منطق مربی را تعریف می کند.
  • تعریف Protobuf از قوس قطار و قوس eval.
  • (اختیاری) طرحواره داده ای که توسط یک جز component خط لوله SchemaGen ایجاد شده و به طور اختیاری توسط توسعه دهنده تغییر می یابد.
  • نمودار تبدیل (اختیاری) تولید شده توسط یک جز an بالادست تبدیل.
  • (اختیاری) مدلهای از قبل آموزش دیده که برای سناریوهایی مانند شروع گرم استفاده می شود.
  • ابر پارامترهای (اختیاری) که به عملکرد ماژول کاربر منتقل می شوند. جزئیات ادغام با تیونر را می توانید در اینجا پیدا کنید .

Trainer emits: حداقل یک مدل برای استنتاج / خدمت (معمولاً در SavedModelFormat) و به صورت اختیاری مدل دیگری برای eval (معمولاً یک EvalSavedModel).

ما از طریق کتابخانه بازنویسی مدل از قالب های جایگزین مدل مانند TFLite پشتیبانی می کنیم. برای مثالهایی در مورد نحوه تبدیل هر دو مدل برآوردگر و Keras ، به پیوند به کتابخانه بازنویسی مدل مراجعه کنید.

مربی عمومی

مربی عمومی برنامه نویسان را قادر می سازد تا از هر API مدل TensorFlow با م Trainلفه Trainer استفاده کنند. علاوه بر TensorFlow Estimators ، توسعه دهندگان می توانند از مدل های Keras یا حلقه های آموزش سفارشی استفاده کنند. برای جزئیات بیشتر ، لطفاً به مربی عمومی در RFC مراجعه کنید .

پیکربندی م Trainلفه مربی

کد DSL خط لوله معمولی برای مربی عمومی به صورت زیر است:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer یک ماژول آموزشی را فراخوانی می کند ، که در پارامتر module_file مشخص شده است. به جای trainer_fn ، یک run_fn در فایل واحد مورد نیاز است اگر GenericExecutor در مشخص custom_executor_spec . trainer_fn مسئول ایجاد مدل بود. علاوه بر این ، run_fn همچنین باید قسمت آموزش را مدیریت کند و مدل آموزش دیده را به مکان دلخواه FnArgs منتقل کند :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

در اینجا یک نمونه پرونده ماژول با run_fn .

توجه داشته باشید که اگر از م Transلفه Transform در خط لوله استفاده نشده باشد ، مربی نمونه ها را از مثالGen به طور مستقیم می گیرد:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

مربی مبتنی بر برآورد (منسوخ شده)

برای کسب اطلاعات در مورد استفاده از مدل مبتنی بر برآورد با TFX و Trainer ، به طراحی کد مدل سازی TensorFlow با tf.Estimator برای TFX مراجعه کنید .

پیکربندی م Trainلفه مربی برای استفاده از مجری مبتنی برآورد

کد خط لوله معمولی Python DSL به این شکل است:

from tfx.components import Trainer
from tfx.components.trainer.executor import Executor
from tfx.dsl.components.base import executor_spec

...

trainer = Trainer(
      custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
      module_file=module_file,
      examples=transform.outputs['transformed_examples'],
      schema=infer_schema.outputs['schema'],
      base_model=latest_model_resolver.outputs['latest_model'],
      transform_graph=transform.outputs['transform_graph'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer یک ماژول آموزشی را فراخوانی می کند ، که در پارامتر module_file مشخص شده است. یک ماژول آموزشی معمولی به این شکل است:

# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.

  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.

  Returns:
    A dict of the following:

      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7

  train_batch_size = 40
  eval_batch_size = 40

  tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)

  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      tf_transform_output,
      batch_size=train_batch_size)

  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      tf_transform_output,
      batch_size=eval_batch_size)

  train_spec = tf.estimator.TrainSpec(  # pylint: disable=g-long-lambda
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)

  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')

  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)

  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  warm_start_from = trainer_fn_args.base_model[
      0] if trainer_fn_args.base_model else None

  estimator = _build_estimator(
      # Construct layers sizes with exponetial decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=warm_start_from)

  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_output, schema)

  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }