Il componente della pipeline Trainer TFX addestra un modello TensorFlow.
Trainer e TensorFlow
Trainer fa ampio uso di Python tensorflow API per modelli di formazione.
Componente
Il formatore prende:
- tf.Esempi utilizzati per l'addestramento e la valutazione.
- Un file di modulo fornito dall'utente che definisce la logica del trainer.
- Protobuf definizione di args treni e args eval.
- (Facoltativo) Uno schema di dati creato da un componente della pipeline SchemaGen e facoltativamente modificato dallo sviluppatore.
- (Facoltativo) grafico di trasformazione prodotto da un componente Trasforma a monte.
- (Facoltativo) Modelli pre-addestrati utilizzati per scenari come l'avvio a caldo.
- (Facoltativo) Iperparametri, che verranno passati alla funzione del modulo utente. Dettagli della integrazione con sintonizzatore può essere trovato qui .
Il trainer emette: almeno un modello per l'inferenza/elaborazione (tipicamente in SavedModelFormat) e facoltativamente un altro modello per eval (tipicamente un EvalSavedModel).
Forniamo supporto per i formati modello alternativo, come TFLite attraverso la riscrittura libreria di modelli . Vedere il collegamento alla libreria di riscrittura dei modelli per esempi su come convertire entrambi i modelli Estimator e Keras.
Formatore generico
Il trainer generico consente agli sviluppatori di utilizzare qualsiasi API del modello TensorFlow con il componente Trainer. Oltre a TensorFlow Estimators, gli sviluppatori possono utilizzare modelli Keras o cicli di formazione personalizzati. Per i dettagli, consultare la RFC per allenatore generici .
Configurazione del componente Trainer
Il tipico codice DSL della pipeline per il Trainer generico sarebbe simile a questo:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Trainer richiama un modulo di formazione, che è specificato nel module_file
parametro. Invece di trainer_fn
, un run_fn
è necessaria nel file di modulo se il GenericExecutor
è specificato nel custom_executor_spec
. Il trainer_fn
era responsabile della creazione del modello. In aggiunta a ciò, run_fn
ha anche bisogno di gestire la parte di formazione e l'uscita del modello addestrato per una posizione desiderata data dal FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Ecco un file di modulo di esempio con run_fn
.
Nota che se il componente Transform non viene utilizzato nella pipeline, il Trainer prenderà direttamente gli esempi da ExampleGen:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Maggiori dettagli sono disponibili nel riferimento API Trainer .