Đã lưuModel Warmup

Sử dụng bộ sưu tập để sắp xếp ngăn nắp các trang Lưu và phân loại nội dung dựa trên lựa chọn ưu tiên của bạn.

Giới thiệu

Thời gian chạy TensorFlow có các thành phần được khởi tạo một cách lười biếng, có thể gây ra độ trễ cao cho / s yêu cầu đầu tiên được gửi đến một mô hình sau khi nó được tải. Độ trễ này có thể cao hơn vài bậc so với độ trễ của một yêu cầu suy luận đơn lẻ.

Để giảm tác động của việc khởi tạo lười biếng đối với độ trễ của yêu cầu, bạn có thể kích hoạt quá trình khởi tạo hệ thống con và thành phần tại thời điểm tải mô hình bằng cách cung cấp một tập hợp mẫu các yêu cầu suy luận cùng với SavedModel. Quá trình này được gọi là "hâm nóng" mô hình.

Cách sử dụng

SavedModel Warmup được hỗ trợ cho Hồi quy, Phân loại, Đa tham khảo và Dự đoán. Để kích hoạt khởi động mô hình tại thời điểm tải, hãy đính kèm tệp dữ liệu khởi động trong thư mục con asset.extra của thư mục SavedModel.

Yêu cầu để khởi động mô hình hoạt động chính xác:

  • Tên tệp khởi động: 'tf_serving_warmup_requests'
  • Vị trí tệp: tài sản.extra /
  • Định dạng file: TFRecord với mỗi bản ghi là một PredictionLog .
  • Số lượng bản ghi khởi động <= 1000.
  • Dữ liệu khởi động phải đại diện cho các yêu cầu suy luận được sử dụng khi phân phát.

Đoạn mã mẫu tạo ra dữ liệu khởi động:

import tensorflow as tf
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import inference_pb2
from tensorflow_serving.apis import model_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
from tensorflow_serving.apis import regression_pb2

def main():
    with tf.io.TFRecordWriter("tf_serving_warmup_requests") as writer:
        # replace <request> with one of:
        # predict_pb2.PredictRequest(..)
        # classification_pb2.ClassificationRequest(..)
        # regression_pb2.RegressionRequest(..)
        # inference_pb2.MultiInferenceRequest(..)
        log = prediction_log_pb2.PredictionLog(
            predict_log=prediction_log_pb2.PredictLog(request=<request>))
        writer.write(log.SerializeToString())

if __name__ == "__main__":
    main()