Komponen Saluran TFX Pelatih

Komponen pipeline Trainer TFX melatih model TensorFlow.

Pelatih dan TensorFlow

Trainer banyak menggunakan Python TensorFlow API untuk model pelatihan.

Komponen

Pelatih mengambil:

  • tf.Contoh yang digunakan untuk pelatihan dan evaluasi.
  • File modul yang disediakan pengguna yang mendefinisikan logika pelatih.
  • Definisi Protobuf dari argumen kereta dan argumen eval.
  • (Opsional) Skema data yang dibuat oleh komponen alur SchemaGen dan diubah secara opsional oleh pengembang.
  • (Opsional) grafik transformasi yang dihasilkan oleh komponen Transform upstream.
  • (Opsional) model terlatih yang digunakan untuk skenario seperti pemanasan.
  • (Opsional) hyperparameter, yang akan diteruskan ke fungsi modul pengguna. Detail integrasi dengan Tuner dapat ditemukan di sini .

Pelatih memancarkan: Setidaknya satu model untuk inferensi/penyajian (biasanya dalam SavedModelFormat) dan secara opsional model lain untuk eval (biasanya EvalSavedModel).

Kami menyediakan dukungan untuk format model alternatif seperti TFLite melalui Model Rewriting Library . Lihat link ke Model Rewriting Library untuk contoh cara mengonversi model Estimator dan Keras.

Pelatih Generik

Pelatih generik memungkinkan pengembang menggunakan API model TensorFlow apa pun dengan komponen Pelatih. Selain Estimator TensorFlow, pengembang dapat menggunakan model Keras atau loop pelatihan khusus. Untuk detailnya, silakan lihat RFC untuk pelatih umum .

Mengonfigurasi Komponen Pelatih

Kode DSL pipeline yang umum untuk Trainer generik akan terlihat seperti ini:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Pelatih memanggil modul pelatihan, yang ditentukan dalam parameter module_file . Daripada trainer_fn , run_fn diperlukan dalam file modul jika GenericExecutor ditentukan dalam custom_executor_spec . trainer_fn bertanggung jawab untuk membuat model. Selain itu, run_fn juga perlu menangani bagian pelatihan dan mengeluarkan model yang dilatih ke lokasi yang diinginkan yang diberikan oleh FnArgs :

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Berikut adalah contoh file modul dengan run_fn .

Perlu diperhatikan bahwa jika komponen Transform tidak digunakan dalam pipeline, maka Pelatih akan mengambil contoh dari ContohGen secara langsung:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Detail selengkapnya tersedia di referensi API Pelatih .