本頁面由 Cloud Translation API 翻譯而成。
Switch to English

使用TensorFlow Hub從TF1遷移到TF2

此頁面說明瞭如何在將TensorFlow代碼從TensorFlow 1遷移到TensorFlow 2時繼續使用TensorFlow Hub。它補充了TensorFlow的常規遷移指南

對於TF2,TF Hub已從傳統的hub.Module API轉移到了構建tf.compat.v1.Graph hub.Module API,就像tf.contrib.v1.layers一樣。相反,現在有一個hub.KerasLayer與其他hub.KerasLayer層一起使用來構建tf.keras.Model (通常在TF2的新的急切執行環境中 ),以及用於底層TensorFlow代碼的基礎hub.load()方法。

tensorflow_hub庫中仍然可以使用hub.Module API,以在TF1和TF2的TF1兼容模式下使用。它只能加載TF1 Hub格式的模型。

hub.load()hub.KerasLayer的新API適用於TensorFlow 1.15(在急切和圖形模式下)和TensorFlow2。此新API可以加載新的TF2 SavedModel資產,並且具有模型中列出的限制兼容性指南 ,是TF1 Hub格式的舊模型。

通常,建議盡可能使用新的API。

新API摘要

hub.load()是新的低級函數,用於從TensorFlow Hub(或兼容的服務)加載SavedModel。它包裝了TF2的tf.saved_model.load() ; TensorFlow的SavedModel指南描述了您可以對結果執行的操作。

 m = hub.load(handle)
outputs = m(inputs)
 

hub.KerasLayer類調用hub.load()並調整結果以使其與其他Keras層一起在hub.load()中使用。 (對於以其他方式使用的已加載SavedModel,它甚至可能是一個方便的包裝器。)

 model = tf.keras.Sequential([
    hub.KerasLayer(handle),
    ...])
 

許多教程都展示了這些API的作用。特別看到

在Estimator培訓中使用新的API

如果您在Estimator中使用TF2 SavedModel進行參數服務器訓練(或者在TF1會話中將變量放置在遠程設備上),則需要在tf.Session的ConfigProto中設置experimental.share_cluster_devices_in_session ,否則您將收到錯誤消息例如“分配的設備'/ job:ps /副本:0 /任務:0 /設備:CPU:0'與任何設備都不匹配。”

可以像這樣設置必要的選項

 session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)
 

從TF2.2開始,此選項不再是實驗性的,可以刪除.experimental部分。

加載TF1 Hub格式的舊模型

可能發生的情況是,新的TF2 SavedModel尚不適用於您的用例,您需要以TF1 Hub格式加載舊模型。從tensorflow_hub 0.7版開始,您可以將TF1 Hub格式的舊模型與hub.KerasLayer一起使用,如下所示:

 m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)
 

此外, KerasLayer了指定tagssignatureoutput_keysignature_outputs_as_dict以便更特定地使用TF1 Hub格式的舊模型和舊版SavedModels。

有關TF1集線器格式兼容性的更多信息,請參見型號兼容性指南

使用較低級別的API

可以通過tf.saved_model.load加載舊版TF1 Hub格式模型。代替

 # DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)
 

建議使用:

 # TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)
 

在這些示例中, m.signatures是由簽名名稱鍵控的TensorFlow 具體功能的決定 。調用此類函數會計算其所有輸出,即使未使用也是如此。 (這與TF1的圖形模式的惰性評估不同。)