El componente de canalización Trainer TFX entrena un modelo de TensorFlow.
Entrenador y TensorFlow
Entrenador hace un amplio uso de la Python TensorFlow API para los modelos de formación.
Componente
El entrenador toma:
- tf.Ejemplos utilizados para entrenamiento y evaluación.
- Un archivo de módulo proporcionado por el usuario que define la lógica del entrenador.
- Protobuf definición de args args tren y eval.
- (Opcional) Un esquema de datos creado por un componente de canalización de SchemaGen y modificado opcionalmente por el desarrollador.
- (Opcional) gráfico de transformación producido por un componente Transform ascendente.
- (Opcional) modelos preentrenados utilizados para escenarios como arranque en caliente.
- (Opcional) hiperparámetros, que se pasarán a la función del módulo de usuario. Los detalles de la integración con el sintonizador se pueden encontrar aquí .
El entrenador emite: al menos un modelo para inferencia/servicio (normalmente en formato de modelo guardado) y, opcionalmente, otro modelo para evaluación (normalmente, un modelo guardado de evaluación).
Proporcionamos soporte para formatos alternativos como modelo TFLite a través de la Biblioteca Modelo reescritura . Consulte el enlace a la biblioteca de reescritura de modelos para ver ejemplos de cómo convertir modelos de Estimator y Keras.
Entrenador genérico
El entrenador genérico permite a los desarrolladores usar cualquier modelo de API de TensorFlow con el componente Entrenador. Además de TensorFlow Estimators, los desarrolladores pueden usar modelos Keras o ciclos de entrenamiento personalizados. Para más detalles, ver el RFC genérico para el entrenador .
Configuración del componente de entrenador
El código DSL de canalización típico para el Entrenador genérico se vería así:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Entrenador invoca un módulo de formación, que se especifica en el module_file
parámetro. En lugar de trainer_fn
, un run_fn
es necesaria en el archivo de módulo si el GenericExecutor
se especifica en el custom_executor_spec
. El trainer_fn
fue responsable de crear el modelo. Además de eso, run_fn
también tiene que manejar la parte de entrenamiento y la salida del modelo entrenado a un lugar deseado dada por FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
He aquí un archivo de módulo ejemplo con run_fn
.
Tenga en cuenta que si el componente Transform no se usa en la canalización, el Entrenador tomaría los ejemplos de ExampleGen directamente:
trainer = Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
Más detalles están disponibles en la referencia de la API entrenador .