TF Hub 的通用 SavedModel API

簡介

TensorFlow Hub 代管各種工作的模型。相同工作的模型最好能導入通用 API,讓模型使用者無須修改使用模型的程式碼即可輕鬆進行交換,即使模型來自不同的發布者也沒關係。

目標是要讓交換相同工作的不同模型能夠像切換字串值的超參數一樣簡單。如此一來,模型使用者便可輕鬆找到最適合自己問題的模型。

這個目錄收錄了 TF2 SavedModel 格式模型的通用 API 規格 (這會取代現已淘汰的 TF1 Hub 格式通用簽名)。

可重複使用的 SavedModel:共同基礎

Reusable SavedModel API 會定義一般慣例,規定將 SavedModel 載入 Python 程式,並在更大的 TensorFlow 模型中重複使用的方式。

基本用法:

obj = hub.load("path/to/model")  # That's tf.saved_model.load() after download.
outputs = obj(inputs, training=False)  # Invokes the tf.function obj.__call__.

如果是 Keras 使用者,hub.KerasLayer 類別會使用這個 API 將可重複使用的 SavedModel 包裝為 Keras 層 (讓 Keras 使用者無須處理細節),其中包含下列工作專屬 API 對應的輸入和輸出。

工作專屬的 API

這些 API 可透過特定的機器學習工作和資料類型慣例,修正 Reusable SavedModel API