O componente do pipeline Trainer TFX treina um modelo TensorFlow.
Trainer e TensorFlow
O Trainer faz uso extensivo da API Python TensorFlow para modelos de treinamento.
Componente
O treinador leva:
- tf.Exemplos usados para treinamento e avaliação.
- Um arquivo de módulo fornecido pelo usuário que define a lógica do treinador.
- Definição de protobuf de args de trem e args de eval.
- (Opcional) Um esquema de dados criado por um componente de pipeline SchemaGen e opcionalmente alterado pelo desenvolvedor.
- (Opcional) gráfico de transformação produzido por um componente Transform upstream.
- (Opcional) modelos pré-treinados usados para cenários como o warmstart.
- (Opcional) hiperparâmetros, que serão passados para a função do módulo do usuário. Os detalhes da integração com o Tuner podem ser encontrados aqui .
O treinador emite: pelo menos um modelo para inferência / veiculação (normalmente em SavedModelFormat) e, opcionalmente, outro modelo para eval (normalmente um EvalSavedModel).
Oferecemos suporte para formatos de modelo alternativos, como TFLite, por meio da Biblioteca de Reescrita de Modelos . Veja o link para a Biblioteca de Reescrita de Modelos para exemplos de como converter os modelos Estimator e Keras.
Treinador baseado em estimador
Para saber como usar um modelo baseado em Estimator com TFX e Trainer, consulte Projetando o código de modelagem do TensorFlow com tf.Estimator para TFX .
Configurando um componente do instrutor
O código DSL do pipeline Python típico se parece com este:
from tfx.components import Trainer
...
trainer = Trainer(
module_file=module_file,
examples=transform.outputs['transformed_examples'],
schema=infer_schema.outputs['schema'],
base_model=latest_model_resolver.outputs['latest_model'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
O Trainer invoca um módulo de treinamento, que é especificado no parâmetro module_file
. Um módulo de treinamento típico se parece com este:
# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
"""Build the estimator using the high level API.
Args:
trainer_fn_args: Holds args used to train the model as name/value pairs.
schema: Holds the schema of the training examples.
Returns:
A dict of the following:
- estimator: The estimator that will be used for training and eval.
- train_spec: Spec for training.
- eval_spec: Spec for eval.
- eval_input_receiver_fn: Input function for eval.
"""
# Number of nodes in the first layer of the DNN
first_dnn_layer_size = 100
num_dnn_layers = 4
dnn_decay_factor = 0.7
train_batch_size = 40
eval_batch_size = 40
tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output)
train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda
trainer_fn_args.train_files,
tf_transform_output,
batch_size=train_batch_size)
eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda
trainer_fn_args.eval_files,
tf_transform_output,
batch_size=eval_batch_size)
train_spec = tf.estimator.TrainSpec( # pylint: disable=g-long-lambda
train_input_fn,
max_steps=trainer_fn_args.train_steps)
serving_receiver_fn = lambda: _example_serving_receiver_fn( # pylint: disable=g-long-lambda
tf_transform_output, schema)
exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
eval_spec = tf.estimator.EvalSpec(
eval_input_fn,
steps=trainer_fn_args.eval_steps,
exporters=[exporter],
name='chicago-taxi-eval')
run_config = tf.estimator.RunConfig(
save_checkpoints_steps=999, keep_checkpoint_max=1)
run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
warm_start_from = trainer_fn_args.base_model[
0] if trainer_fn_args.base_model else None
estimator = _build_estimator(
# Construct layers sizes with exponetial decay
hidden_units=[
max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
for i in range(num_dnn_layers)
],
config=run_config,
warm_start_from=warm_start_from)
# Create an input receiver for TFMA processing
receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda
tf_transform_output, schema)
return {
'estimator': estimator,
'train_spec': train_spec,
'eval_spec': eval_spec,
'eval_input_receiver_fn': receiver_fn
}
Treinador Genérico
O treinador genérico permite que os desenvolvedores usem qualquer API de modelo do TensorFlow com o componente Trainer. Além dos TensorFlow Estimators, os desenvolvedores podem usar modelos Keras ou loops de treinamento personalizados. Para obter detalhes, consulte o RFC para treinador genérico .
Configurando o componente Trainer para usar o GenericExecutor
O código DSL de pipeline típico para o Trainer genérico seria assim:
from tfx.components import Trainer
from tfx.dsl.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
...
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))
O Trainer invoca um módulo de treinamento, que é especificado no parâmetro module_file
. Em vez de trainer_fn
, um run_fn
é necessário no arquivo do módulo se o GenericExecutor
for especificado no custom_executor_spec
. O trainer_fn
foi responsável pela criação do modelo. Além disso, run_fn
também precisa lidar com a parte de treinamento e enviar o modelo treinado para um local desejado fornecido pelo FnArgs :
from tfx.components.trainer.fn_args_utils import FnArgs
def run_fn(fn_args: FnArgs) -> None:
"""Build the TF model and train it."""
model = _build_keras_model()
model.fit(...)
# Save model to fn_args.serving_model_dir.
model.save(fn_args.serving_model_dir, ...)
Aqui está um arquivo de módulo de exemplo com run_fn
.
Observe que, se o componente Transform não for usado no pipeline, o Trainer pegará os exemplos de ExampleGen diretamente:
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))