O Dia da Comunidade de ML é dia 9 de novembro! Junte-nos para atualização de TensorFlow, JAX, e mais Saiba mais

O componente Trainer TFX Pipeline

O componente do pipeline Trainer TFX treina um modelo TensorFlow.

Trainer e TensorFlow

O Trainer faz uso extensivo da API Python TensorFlow para modelos de treinamento.

Componente

O treinador leva:

  • tf.Exemplos usados ​​para treinamento e avaliação.
  • Um arquivo de módulo fornecido pelo usuário que define a lógica do treinador.
  • Definição de protobuf de args de trem e args de eval.
  • (Opcional) Um esquema de dados criado por um componente de pipeline SchemaGen e opcionalmente alterado pelo desenvolvedor.
  • (Opcional) gráfico de transformação produzido por um componente Transform upstream.
  • (Opcional) modelos pré-treinados usados ​​para cenários como o warmstart.
  • (Opcional) hiperparâmetros, que serão passados ​​para a função do módulo do usuário. Os detalhes da integração com o Tuner podem ser encontrados aqui .

O treinador emite: pelo menos um modelo para inferência / veiculação (normalmente em SavedModelFormat) e, opcionalmente, outro modelo para eval (normalmente um EvalSavedModel).

Oferecemos suporte para formatos de modelos alternativos, como TFLite, por meio da Biblioteca de Reescrita de Modelos . Veja o link para a Biblioteca de Reescrita de Modelos para exemplos de como converter os modelos Estimator e Keras.

Treinador Genérico

O treinador genérico permite que os desenvolvedores usem qualquer API de modelo do TensorFlow com o componente Trainer. Além do TensorFlow Estimators, os desenvolvedores podem usar modelos Keras ou loops de treinamento personalizados. Para obter detalhes, consulte o RFC para treinador genérico .

Configurando o componente Trainer

O código DSL de pipeline típico para o Trainer genérico seria assim:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

O Trainer invoca um módulo de treinamento, que é especificado no parâmetro module_file . Em vez de trainer_fn , um run_fn é necessário no arquivo do módulo se o GenericExecutor for especificado no custom_executor_spec . O trainer_fn foi responsável pela criação do modelo. Além disso, run_fn também precisa lidar com a parte do treinamento e enviar o modelo treinado para o local desejado fornecido pelo FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Aqui está umarquivo de módulo de exemplo com run_fn .

Observe que, se o componente Transform não for usado no pipeline, o Trainer pegará os exemplos de ExampleGen diretamente:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Mais detalhes estão disponíveis na referência da API Trainer .