Treten Sie der SIG TFX-Addons-Community bei und helfen Sie mit, TFX noch besser zu machen! SIG TFX-Addons beitreten

SavedModel Warmup

Einführung

Die TensorFlow-Laufzeit verfügt über Komponenten, die träge initialisiert werden. Dies kann zu einer hohen Latenz für die ersten Anforderungen führen, die nach dem Laden an ein Modell gesendet werden. Diese Latenz kann mehrere Größenordnungen höher sein als die einer einzelnen Inferenzanforderung.

Um die Auswirkungen der verzögerten Initialisierung auf die Anforderungslatenz zu verringern, können Sie die Initialisierung der Subsysteme und Komponenten zur Modellladezeit auslösen, indem Sie zusammen mit dem SavedModel einen Beispielsatz von Inferenzanforderungen bereitstellen. Dieser Vorgang wird als "Aufwärmen" des Modells bezeichnet.

Verwendung

SavedModel Warmup wird für Regress, Classify, MultiInference und Predict unterstützt. Um das Aufwärmen des Modells beim Laden auszulösen, fügen Sie eine Aufwärmdatendatei unter dem Unterordner assets.extra des SavedModel-Verzeichnisses hinzu.

Voraussetzungen, damit das Aufwärmen des Modells ordnungsgemäß funktioniert:

  • Name der Aufwärmdatei: 'tf_serving_warmup_requests'
  • Dateispeicherort: assets.extra /
  • Dateiformat: TFRecord mit jedem Datensatz als PredictionLog .
  • Anzahl der Aufwärmdatensätze <= 1000.
  • Die Aufwärmdaten müssen repräsentativ für die beim Serving verwendeten Inferenzanforderungen sein.

Beispiel eines Code-Snippets, das Aufwärmdaten erzeugt:

import tensorflow as tf
from tensorflow_serving.apis import classification_pb2
from tensorflow_serving.apis import inference_pb2
from tensorflow_serving.apis import model_pb2
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_log_pb2
from tensorflow_serving.apis import regression_pb2

def main():
    with tf.io.TFRecordWriter("tf_serving_warmup_requests") as writer:
        # replace <request> with one of:
        # predict_pb2.PredictRequest(..)
        # classification_pb2.ClassificationRequest(..)
        # regression_pb2.RegressionRequest(..)
        # inference_pb2.MultiInferenceRequest(..)
        log = prediction_log_pb2.PredictionLog(
            predict_log=prediction_log_pb2.PredictLog(request=<request>))
        writer.write(log.SerializeToString())

if __name__ == "__main__":
    main()