Unisciti alla comunità SIG TFX-Addons e contribuisci a rendere TFX ancora migliore!

Il componente della pipeline TFX Trainer

Il componente della pipeline Trainer TFX addestra un modello TensorFlow.

Trainer e TensorFlow

Trainer fa ampio uso dell'API Python TensorFlow per i modelli di addestramento.

Componente

Il formatore prende:

  • tf.Esempi utilizzati per l'addestramento e la valutazione.
  • Un file di modulo fornito dall'utente che definisce la logica del trainer.
  • Definizione protobuf di args train e eval args.
  • (Facoltativo) Uno schema di dati creato da un componente della pipeline SchemaGen e facoltativamente modificato dallo sviluppatore.
  • (Facoltativo) grafico di trasformazione prodotto da un componente Trasforma a monte.
  • (Facoltativo) Modelli pre-addestrati utilizzati per scenari come l'avvio a caldo.
  • (Facoltativo) Iperparametri, che verranno passati alla funzione del modulo utente. I dettagli dell'integrazione con Tuner possono essere trovati qui .

Il trainer emette: almeno un modello per l'inferenza/elaborazione (tipicamente in SavedModelFormat) e facoltativamente un altro modello per eval (tipicamente un EvalSavedModel).

Forniamo supporto per formati di modello alternativi come TFLite attraverso la libreria di riscrittura dei modelli . Vedere il collegamento alla libreria di riscrittura dei modelli per esempi su come convertire entrambi i modelli Estimator e Keras.

Formatore generico

Il trainer generico consente agli sviluppatori di utilizzare qualsiasi API del modello TensorFlow con il componente Trainer. Oltre a TensorFlow Estimators, gli sviluppatori possono utilizzare modelli Keras o cicli di formazione personalizzati. Per i dettagli, consultare l' RFC per il trainer generico .

Configurazione del componente Trainer

Il tipico codice DSL della pipeline per il Trainer generico sarebbe simile a questo:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer richiama un modulo di addestramento, che è specificato nel parametro module_file . Invece di trainer_fn , è richiesto un run_fn nel file del modulo se GenericExecutor è specificato in custom_executor_spec . Il trainer_fn era responsabile della creazione del modello. Oltre a ciò, run_fn deve anche gestire la parte di addestramento e restituire il modello addestrato nella posizione desiderata fornita da FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Ecco unfile di modulo di esempio con run_fn .

Nota che se il componente Transform non viene utilizzato nella pipeline, il Trainer prenderà direttamente gli esempi da ExampleGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Maggiori dettagli sono disponibili nel riferimento API Trainer .