Thành phần đường dẫn TFX dành cho huấn luyện viên

Thành phần đường dẫn Trainer TFX huấn luyện mô hình TensorFlow.

Huấn luyện viên và TensorFlow

Huấn luyện viên sử dụng rộng rãi API Python TensorFlow cho các mô hình đào tạo.

Thành phần

Huấn luyện viên thực hiện:

  • tf.Examples được sử dụng để đào tạo và đánh giá.
  • Tệp mô-đun do người dùng cung cấp xác định logic huấn luyện.
  • Định nghĩa Protobuf của train args và eval args.
  • (Tùy chọn) Lược đồ dữ liệu được tạo bởi thành phần đường dẫn SchemaGen và được nhà phát triển thay đổi tùy ý.
  • (Tùy chọn) biểu đồ biến đổi được tạo bởi thành phần Biến đổi ngược dòng.
  • (Tùy chọn) các mô hình được đào tạo trước được sử dụng cho các tình huống như khởi động.
  • (Tùy chọn) siêu tham số, sẽ được chuyển đến chức năng mô-đun người dùng. Chi tiết về việc tích hợp với Tuner có thể được tìm thấy ở đây .

Huấn luyện viên phát ra: Ít nhất một mô hình để suy luận/cung cấp (thường là ở SavingModelFormat) và một mô hình khác tùy chọn cho eval (thường là EvalSavedModel).

Chúng tôi cung cấp hỗ trợ cho các định dạng mô hình thay thế như TFLite thông qua Thư viện viết lại mô hình . Xem liên kết đến Thư viện viết lại mô hình để biết ví dụ về cách chuyển đổi cả mô hình Ước tính và Keras.

Huấn luyện viên chung

Trình huấn luyện chung cho phép các nhà phát triển sử dụng bất kỳ API mô hình TensorFlow nào với thành phần Huấn luyện viên. Ngoài Công cụ ước tính TensorFlow, nhà phát triển có thể sử dụng mô hình Keras hoặc vòng đào tạo tùy chỉnh. Để biết chi tiết, vui lòng xem RFC dành cho huấn luyện viên chung .

Cấu hình thành phần huấn luyện viên

Mã DSL đường dẫn điển hình cho Trainer chung sẽ trông như thế này:

from tfx.components import Trainer

...

trainer = Trainer(
    module_file=module_file,
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Huấn luyện viên gọi một mô-đun đào tạo được chỉ định trong tham số module_file . Thay vì trainer_fn , bắt buộc phải có run_fn trong tệp mô-đun nếu GenericExecutor được chỉ định trong custom_executor_spec . trainer_fn chịu trách nhiệm tạo mô hình. Ngoài ra, run_fn còn cần xử lý phần huấn luyện và xuất mô hình đã huấn luyện đến vị trí mong muốn do FnArgs đưa ra:

from tfx.components.trainer.fn_args_utils import FnArgs

def run_fn(fn_args: FnArgs) -> None:
  """Build the TF model and train it."""
  model = _build_keras_model()
  model.fit(...)
  # Save model to fn_args.serving_model_dir.
  model.save(fn_args.serving_model_dir, ...)

Đây là một tệp mô-đun ví dụ với run_fn .

Lưu ý rằng nếu thành phần Biến đổi không được sử dụng trong quy trình thì Huấn luyện viên sẽ lấy trực tiếp các ví dụ từ Ví dụ:

trainer = Trainer(
    module_file=module_file,
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))

Thông tin chi tiết hơn có sẵn trong tài liệu tham khảo API dành cho huấn luyện viên .