Dołącz do społeczności SIG TFX-Addons i pomóż ulepszyć TFX! Dołącz do dodatków SIG TFX

Składnik potoku Trainer TFX

Składnik potoku Trainer TFX szkoli model TensorFlow.

Trener i TensorFlow

Trainer w szerokim zakresie wykorzystuje interfejs API Python TensorFlow do trenowania modeli.

Składnik

Trener bierze:

  • tf.Przykłady wykorzystywane do szkolenia i oceny.
  • Dostarczony przez użytkownika plik modułu, który definiuje logikę trenera.
  • Protobufowa definicja argumentów pociągu i argumentów eval.
  • (Opcjonalnie) Schemat danych utworzony przez składnik potoku SchemaGen i opcjonalnie zmieniony przez dewelopera.
  • (Opcjonalne) wykres transformacji utworzony przez składnik transformacji poprzedzający.
  • (Opcjonalnie) wstępnie wytrenowane modele używane w scenariuszach, takich jak warmstart.
  • (Opcjonalnie) hiperparametry, które zostaną przekazane do funkcji modułu użytkownika. Szczegóły integracji z Tunerem znajdziesz tutaj .

Trainer emituje: co najmniej jeden model do wnioskowania/serwowania (zazwyczaj w SavedModelFormat) i opcjonalnie inny model do eval (zazwyczaj EvalSavedModel).

Zapewniamy obsługę alternatywnych formatów modeli, takich jak TFLite, za pośrednictwem Biblioteki przepisywania modeli . Zobacz link do Biblioteki przepisywania modeli, aby zapoznać się z przykładami konwersji modeli Estimator i Keras.

Trener ogólny

Trener ogólny umożliwia programistom korzystanie z dowolnego interfejsu API modelu TensorFlow z komponentem Trainer. Oprócz narzędzi TensorFlow Estimators programiści mogą korzystać z modeli Keras lub niestandardowych pętli szkoleniowych. Szczegółowe informacje można znaleźć w dokumencie RFC dotyczącym ogólnego trenera .

Konfiguracja komponentu trenera

Typowy kod DSL potoku dla ogólnego Trainera wyglądałby tak:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer wywołuje moduł szkoleniowy, który jest określony w parametrze module_file . Zamiast trainer_fn , w pliku modułu wymagany jest run_fn jeśli GenericExecutor jest określony w custom_executor_spec . trainer_fn był odpowiedzialny za stworzenie modelu. Oprócz tego run_fn musi również obsłużyć część szkoleniową i wyprowadzić wyszkolony model do żądanej lokalizacji podanej przez FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Otoprzykładowy plik modułu z run_fn .

Zwróć uwagę, że jeśli komponent Transform nie jest używany w potoku, to Trainer weźmie przykłady bezpośrednio z ExampleGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Więcej szczegółów można znaleźć w instrukcji Trainer API .