Il componente della pipeline TFX Trainer

Il componente della pipeline Trainer TFX addestra un modello TensorFlow.

Trainer e TensorFlow

Trainer fa ampio uso di Python tensorflow API per modelli di formazione.

Componente

Il formatore prende:

  • tf.Esempi utilizzati per l'addestramento e la valutazione.
  • Un file di modulo fornito dall'utente che definisce la logica del trainer.
  • Protobuf definizione di args treni e args eval.
  • (Facoltativo) Uno schema di dati creato da un componente della pipeline SchemaGen e facoltativamente modificato dallo sviluppatore.
  • (Facoltativo) grafico di trasformazione prodotto da un componente Trasforma a monte.
  • (Facoltativo) Modelli pre-addestrati utilizzati per scenari come l'avvio a caldo.
  • (Facoltativo) Iperparametri, che verranno passati alla funzione del modulo utente. Dettagli della integrazione con sintonizzatore può essere trovato qui .

Il trainer emette: almeno un modello per l'inferenza/elaborazione (tipicamente in SavedModelFormat) e facoltativamente un altro modello per eval (tipicamente un EvalSavedModel).

Forniamo supporto per i formati modello alternativo, come TFLite attraverso la riscrittura libreria di modelli . Vedere il collegamento alla libreria di riscrittura dei modelli per esempi su come convertire entrambi i modelli Estimator e Keras.

Formatore generico

Il trainer generico consente agli sviluppatori di utilizzare qualsiasi API del modello TensorFlow con il componente Trainer. Oltre a TensorFlow Estimators, gli sviluppatori possono utilizzare modelli Keras o cicli di formazione personalizzati. Per i dettagli, consultare la RFC per allenatore generici .

Configurazione del componente Trainer

Il tipico codice DSL della pipeline per il Trainer generico sarebbe simile a questo:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Trainer richiama un modulo di formazione, che è specificato nel module_file parametro. Invece di trainer_fn , un run_fn è necessaria nel file di modulo se il GenericExecutor è specificato nel custom_executor_spec . Il trainer_fn era responsabile della creazione del modello. In aggiunta a ciò, run_fn ha anche bisogno di gestire la parte di formazione e l'uscita del modello addestrato per una posizione desiderata data dal FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Ecco un file di modulo di esempio con run_fn .

Nota che se il componente Transform non viene utilizzato nella pipeline, il Trainer prenderà direttamente gli esempi da ExampleGen:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Maggiori dettagli sono disponibili nel riferimento API Trainer .