RSVP for your your local TensorFlow Everywhere event today!

在使用 TensorFlow Hub 的情况下从 TF1 迁移到 TF2

本页介绍了在将 TensorFlow 代码从 TensorFlow 1 迁移到 TensorFlow 2 时如何继续使用 TensorFlow Hub,旨在补充 TensorFlow 的常规迁移指南

对于 TF2,TF Hub 已经从旧版 hub.Module API 转为用于构建 tf.compat.v1.Graph,与 tf.contrib.v1.layers 类似。现在提供 hub.KerasLayer 与其他 Keras 层用于构建 tf.keras.Model(通常在 TF2 的新 Eager Eexecution 环境中),其底层 hub.load() 方法用于低级 TensorFlow 代码。

tensorflow_hub 库中仍包含 hub.Module API,可在 TF1 以及 TF2 的 TF1 兼容模式下使用。该 API 只能加载 TF1 Hub 格式的模型。

hub.load()hub.KerasLayer 的新 API 适用于 TensorFlow 1.15(在 Eager 和计算图模式下)以及 TensorFlow 2。这一新版 API 可以加载新的 TF2 SavedModel 资源,在模型兼容性指南所述限制下,也可以加载 TF1 Hub 格式的模型。

通常情况下,建议尽可能使用新版 API。

新版 API 摘要

hub.load() 是新的低级函数,用于从 TensorFlow Hub (或兼容的服务)加载 SavedModel。它可以包装 TF2 的 tf.saved_model.load();TensorFlow 的 SavedModel 指南介绍了您可以对结果执行的操作。

m = hub.load(handle)
outputs = m(inputs)

hub.KerasLayer 类可调用 hub.load() 并调整结果以与其他 Keras 层共同用于 Keras 中。(它甚至可以方便地包装以其他方式使用的加载 SavedModel。)

model = tf.keras.Sequential([
    hub.KerasLayer(handle),
    ...])

许多教程都展示了这些 API 的实际运行。具体请参阅:

在 Estimator 训练中使用新版 API

如果您在 Estimator 中通过参数服务器训练 TF2 SavedModel(或者在将变量置于远程设备上的 TF1 会话中),则需要在 tf.Session 的 ConfigProto 中设置 experimental.share_cluster_devices_in_session,否则您将收到错误消息,例如“Assigned device '/job:ps/replica:0/task:0/device:CPU:0' does not match any device.”

可按以下方式设置所需选项:

session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)

自 TF2.2 起,此选项已不再为实验性选项,可以删除 .experimental 部分。

加载 TF1 Hub 格式的旧版模型

有可能会出现新的 TF2 SavedModel 尚不支持您的用例的情况,此时您需要加载 TF1 Hub 格式的旧版模型。自 tensorflow_hub 0.7 版起,您可以将 TF1 Hub 格式的旧版模型与 hub.KerasLayer API 配合使用,如下所示:

m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)

此外,KerasLayer 还提供了指定 tagssignatureoutput_keysignature_outputs_as_dict 的功能,从而可以使用 TF1 Hub 格式的旧版模型和旧版 SavedModel 实现更具体的用途。

有关 TF1 Hub 格式兼容性的更多信息,请参阅模型兼容性指南

使用低级 API

旧版 TF1 Hub 格式模型可以通过 tf.saved_model.load 加载。不建议使用:

# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)

而是建议使用:

# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)

在以上示例中,m.signatures 是一个由签名名称键控的 TensorFlow 具体函数字典。调用此类函数会计算其所有输出,即使某些并未使用。(这与 TF1 的计算图模式的惰性评估不同。)