Save and load a model using a distribution strategy

Overview

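Saving and loading a model during training is a common need. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how to use the SavedModel APIs when using tf.distribute.Strategy. To learn about SavedModel and serialization in general, see the saved model guide and the Keras model serialization guide. Let's start with a simple example: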

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
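
As a quick check on the batching arithmetic in get_data(), the sketch below computes the global batch size and the number of steps per epoch in plain Python. It assumes the 60,000-example MNIST training split and that the last partial batch is kept, which is the default for Dataset.batch(). The helper names are illustrative and are not part of TensorFlow.

```python
import math

BATCH_SIZE_PER_REPLICA = 64
MNIST_TRAIN_EXAMPLES = 60_000  # size of the MNIST training split


def global_batch_size(per_replica_batch_size, num_replicas_in_sync):
    """Global batch = per-replica batch times the number of replicas."""
    return per_replica_batch_size * num_replicas_in_sync


def steps_per_epoch(num_examples, batch_size):
    """Batches per epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# Show how the step count shrinks as replicas are added.
for replicas in (1, 2, 4):
    batch = global_batch_size(BATCH_SIZE_PER_REPLICA, replicas)
    print(replicas, batch, steps_per_epoch(MNIST_TRAIN_EXAMPLES, batch))
```

With a single replica, as in the log line above, this gives 938 steps per epoch, which matches the 938/938 progress bars in this tutorial's training output.
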

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

938/938 [==============================] - 4s 5ms/step - loss: 0.1971 - accuracy: 0.9421
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0662 - accuracy: 0.9801

<tensorflow.python.keras.callbacks.History at 0x7f96501659e8>

Save and load the model

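Now that you have a simple model to work with, let's look at the saving and loading APIs. There are two sets of APIs available: the high-level Keras API (model.save and tf.keras.models.load_model) and the low-level tf.saved_model API (tf.saved_model.save and tf.saved_model.load).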

Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called out of strategy scope
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0480 - accuracy: 0.0990
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0334 - accuracy: 0.0989

<tensorflow.python.keras.callbacks.History at 0x7f96c54d0a58>

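After restoring the model, you can continue training it without calling compile() again, because it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, see the saved_model format guide.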
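
To make the on-disk format more concrete, here is a minimal sketch that checks whether a directory has the basic SavedModel layout: a saved_model.pb file next to a variables/ directory (the assets/ directory mentioned in the save log is optional). The helper name looks_like_saved_model is hypothetical, and the check runs against a mock directory so that the snippet does not depend on TensorFlow.

```python
import pathlib
import tempfile


def looks_like_saved_model(path):
    """Return True if `path` has the basic SavedModel directory layout."""
    root = pathlib.Path(path)
    has_graph = (root / "saved_model.pb").is_file()
    has_variables = (root / "variables").is_dir()
    return has_graph and has_variables


# Build a mock export directory instead of saving a real model.
with tempfile.TemporaryDirectory() as tmp:
    mock = pathlib.Path(tmp) / "keras_save"
    (mock / "variables").mkdir(parents=True)
    (mock / "assets").mkdir()
    (mock / "saved_model.pb").touch()
    mock_ok = looks_like_saved_model(mock)
    empty_ok = looks_like_saved_model(tmp)

print(mock_ok, empty_ok)
```

In this tutorial's setup you could call the same helper on keras_model_path after model.save() as a quick sanity check before loading.
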

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0481 - accuracy: 0.0989
Epoch 2/2
938/938 [==============================] - 9s 9ms/step - loss: 0.0329 - accuracy: 0.0990

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same one used before saving.

tf.saved_model API

Now let's look at the lower-level API. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

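Loading can be done with tf.saved_model.load(). However, because this API is lower level (and therefore covers a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example: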

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

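The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To run inference with this function, run the following code: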

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 0.17218862,  0.07492599, -0.0548683 ,  0.03503785, -0.03743191,
        -0.05301537,  0.01267872, -0.02870197, -0.33800656,  0.17991678],
       [ 0.12937182, -0.21557797, -0.09474514,  0.39076763, -0.22147779,
        -0.1787742 ,  0.2154337 ,  0.00788027, -0.14960325,  0.43123117],
       [ 0.04755233, -0.20264567, -0.17308846,  0.19781005, -0.11123425,
        -0.4295108 ,  0.05442019,  0.01459119, -0.17129104,  0.04688327],
       [ 0.09866484,  0.01627818, -0.08671301,  0.05742932, -0.20312837,
        -0.38836166, -0.06952551,  0.05141062, -0.03084616,  0.05498504],
       [ 0.00565811, -0.04239772,  0.04898138,  0.06162139, -0.16708252,
        -0.12976539, -0.00474121,  0.05431085, -0.14715545,  0.07582194],
       [ 0.17589626,  0.19629489, -0.2076093 ,  0.02031662, -0.1619812 ,
        -0.24300966, -0.0310282 , -0.00850905, -0.18514219,  0.23665032],
       [-0.02653   , -0.17737214, -0.24494407,  0.20125583, -0.17153463,
        -0.18641792,  0.11408111,  0.01489197, -0.099539  ,  0.41159016],
       [ 0.1903163 ,  0.1697292 , -0.14116906,  0.1588785 , -0.04286646,
        -0.19863203, -0.04836996, -0.00679918, -0.14634813,  0.14979276],
       [ 0.12109621,  0.03313948, -0.1955429 ,  0.23528968, -0.12369496,
        -0.20725062,  0.06024174,  0.05078189, -0.158943  ,  0.16846842],
       [ 0.16227934,  0.06379895, -0.08847713,  0.08261362, -0.03925761,
        -0.17770812, -0.043965  ,  0.02072081, -0.07430968,  0.05749936],
       [ 0.05508922, -0.14091367, -0.1887006 ,  0.12903523, -0.13182093,
        -0.11879301,  0.20175044,  0.11686974, -0.1616871 ,  0.2226192 ],
       [ 0.18285918, -0.01880376, -0.15778637,  0.04477023, -0.22364017,
        -0.23864916, -0.06328501,  0.04380857, -0.04448643,  0.40406597],
       [ 0.04721744,  0.06619421, -0.10837474,  0.1292499 , -0.17490903,
        -0.17313394, -0.06603841,  0.15658481, -0.09657097, -0.04059617],
       [-0.04412666,  0.02258963,  0.08539917,  0.2561011 , -0.18279126,
        -0.2519745 , -0.00787598,  0.08598025, -0.21961546,  0.10189874],
       [ 0.05089861,  0.06746367, -0.13205   ,  0.09160744, -0.30171782,
        -0.25160635,  0.08317091,  0.03015741, -0.10570806,  0.28686398],
       [ 0.13625176, -0.109529  ,  0.04985618,  0.08199271, -0.24280871,
        -0.22908798,  0.17737128,  0.09937412, -0.31234092,  0.2290439 ],
       [ 0.13812706,  0.10425253,  0.0128724 ,  0.12191941, -0.09126505,
        -0.13897963, -0.17568447,  0.16489705, -0.26533198,  0.06911667],
       [ 0.16982701,  0.087276  , -0.17102191,  0.06745699, -0.06239565,
        -0.17226742, -0.02450407,  0.10939141, -0.13510445,  0.04026298],
       [-0.05762933,  0.03908077,  0.0729831 ,  0.12001946, -0.12699135,
        -0.37191632, -0.10294843,  0.1815257 , -0.10121268,  0.06880292],
       [ 0.07649058, -0.03354908, -0.06362928, -0.00831218, -0.24217641,
        -0.11137463,  0.01944396,  0.0310707 ,  0.0093919 ,  0.34353036],
       [ 0.16107717, -0.04705916, -0.14095825,  0.05297582, -0.1485554 ,
        -0.12321693,  0.07225874,  0.07695273, -0.17055047,  0.22460693],
       [ 0.02565719, -0.05495968, -0.11961621,  0.03014402, -0.1645109 ,
        -0.26333475,  0.07536604,  0.04426918, -0.12448484,  0.04142715],
       [ 0.02295595,  0.01484419, -0.28111714,  0.05291839, -0.09908111,
        -0.22002876,  0.00388122,  0.06801579, -0.03227042,  0.04201593],
       [ 0.01293404, -0.15113808, -0.05814568,  0.29754263, -0.13849238,
        -0.02268202,  0.16958144,  0.12881759, -0.13463333,  0.3364867 ],
       [ 0.19805974, -0.01798259, -0.12835501,  0.26842418, -0.04154617,
        -0.19442351, -0.08115683,  0.08586816,  0.00582654,  0.04328927],
       [ 0.09159922,  0.12617984, -0.15028486,  0.23344447, -0.06932314,
        -0.1483246 , -0.02017963,  0.03262286, -0.2800941 ,  0.18364596],
       [ 0.1528    ,  0.13280275, -0.09938447,  0.03614349, -0.1096218 ,
        -0.19335787, -0.04933339, -0.02397237, -0.13356304, -0.01165973],
       [ 0.13618907,  0.14891617, -0.16118397,  0.10435603, -0.1831438 ,
        -0.16405147, -0.14186187,  0.12581114, -0.15762964,  0.13493878],
       [ 0.05534358, -0.0916103 ,  0.0352111 ,  0.0020496 , -0.19224274,
        -0.17663556,  0.08702807, -0.08016825, -0.14833373,  0.10739949],
       [ 0.02660379, -0.04472145,  0.01165188,  0.0219909 , -0.16059823,
        -0.26817566, -0.09790543,  0.10905766, -0.01595427,  0.304615  ],
       [ 0.08248052, -0.09962849, -0.02325149,  0.04280585, -0.20835052,
        -0.2023199 , -0.0130603 ,  0.07936736,  0.0494375 ,  0.27143508],
       [ 0.00310345,  0.04583906, -0.20415008,  0.1876276 , -0.06600557,
        -0.19580218, -0.02222047,  0.07650423, -0.08899002,  0.10885157],
       [ 0.0783096 , -0.01651647, -0.09479928,  0.07058451, -0.14990349,
        -0.33366078,  0.0564964 ,  0.01118498, -0.14589244,  0.22603557],
       [ 0.04565446,  0.05590308, -0.02989801, -0.07578284, -0.09796432,
        -0.20807403, -0.00954358,  0.02622838, -0.10276475, -0.05590656],
       [ 0.07286316,  0.01376749, -0.18262148,  0.28560585, -0.18269306,
        -0.06166455,  0.12229253,  0.11880912, -0.08595768,  0.17080015],
       [ 0.12635507, -0.0836257 ,  0.03501946,  0.30507207, -0.34584454,
        -0.29186884,  0.26327768,  0.18378039, -0.09220086,  0.16707191],
       [ 0.11742169,  0.02937749, -0.16469768,  0.31997636, -0.1280521 ,
        -0.17700416,  0.05593231,  0.05017062, -0.31535   ,  0.15465745],
       [ 0.08975917,  0.01203279,  0.09783987,  0.06205256, -0.05648104,
        -0.27429107, -0.12651348,  0.09195078, -0.2890005 ,  0.08270936],
       [ 0.09477694,  0.10097383, -0.05783979,  0.11597094, -0.05375554,
        -0.04229444, -0.09689695,  0.08121311, -0.05716637,  0.09075539],
       [-0.04117738, -0.06426363, -0.0629988 ,  0.00692648, -0.30303234,
        -0.28447956, -0.01935545,  0.159902  , -0.10399745,  0.17079492],
       [-0.01080875, -0.04450692, -0.19694453,  0.15313052, -0.11790004,
        -0.21164687,  0.16064486,  0.05443045,  0.04431828,  0.18498638],
       [ 0.16398555,  0.21772492, -0.03592323,  0.15181649, -0.02455682,
        -0.28267485, -0.12445807,  0.17047536, -0.19300474, -0.01467199],
       [ 0.04904355, -0.0152067 ,  0.09667489, -0.01841408, -0.08439851,
        -0.2905228 , -0.0541675 ,  0.07489735, -0.13492545,  0.1839124 ],
       [ 0.2369909 ,  0.08534706, -0.12017098,  0.04527019, -0.05781246,
        -0.1196178 , -0.09442404,  0.01685349, -0.26979008,  0.17579612],
       [ 0.04441281, -0.09139308,  0.00063404,  0.02085789, -0.17478338,
        -0.1746104 ,  0.21254838,  0.07575508, -0.19009903,  0.26038024],
       [ 0.23913413,  0.13267268, -0.11951514,  0.13184579, -0.11442515,
        -0.1563474 , -0.13503158,  0.1639925 , -0.11313978,  0.05294855],
       [ 0.11768216,  0.12213368, -0.00641227,  0.1983034 , -0.10263431,
        -0.10918278, -0.06888436,  0.26294842, -0.1041921 ,  0.09731302],
       [ 0.16183744, -0.14602011, -0.17195675,  0.1428874 , -0.26739907,
        -0.3048862 ,  0.06860068,  0.03065268, -0.13347332,  0.4117231 ],
       [-0.02206257,  0.00734324,  0.003649  ,  0.12295016, -0.22801307,
        -0.23414296, -0.03367008,  0.11127277, -0.01726604, -0.0447302 ],
       [ 0.10106434,  0.09055474, -0.12789255,  0.1377592 , -0.05564225,
        -0.21510065, -0.09061419, -0.0219887 , -0.14411387, -0.03950592],
       [ 0.12847602, -0.09453006, -0.04503661,  0.27597424, -0.17524761,
        -0.05134012,  0.16526361,  0.08649909, -0.22461002,  0.45229536],
       [ 0.04311011,  0.09949236, -0.04975891,  0.22421105, -0.12030718,
        -0.09846736, -0.1408607 ,  0.2384947 , -0.21582088,  0.01464934],
       [-0.03788627,  0.04636163,  0.07747708,  0.0814044 , -0.12896554,
        -0.31223392, -0.0578138 ,  0.1859979 , -0.10911787,  0.15140374],
       [ 0.08929176, -0.02551255, -0.06947158,  0.25500187, -0.18166143,
        -0.1110489 ,  0.0658811 ,  0.23209906, -0.00346252,  0.27463445],
       [ 0.12721871, -0.05336493, -0.01648436,  0.23337078, -0.22428553,
        -0.17424905,  0.03487325,  0.28687072,  0.04055911,  0.30594033],
       [ 0.18656036, -0.00513786, -0.16282284,  0.02530107, -0.17092519,
        -0.24259233,  0.05227455,  0.19966123, -0.28181344,  0.14443643],
       [ 0.02111852, -0.04639132, -0.01641255,  0.20416623, -0.11734181,
        -0.08085347,  0.13685697,  0.10490854, -0.09023371,  0.32988763],
       [ 0.06382357,  0.02803485,  0.03532831,  0.07898249, -0.10290041,
        -0.2603921 , -0.03376516,  0.09166428, -0.14019875,  0.19503292],
       [ 0.15105441,  0.0064583 , -0.1603775 ,  0.16818096, -0.22179885,
        -0.36698502,  0.12694073, -0.1294238 , -0.21702135,  0.34743598],
       [ 0.11475793, -0.08016841, -0.19020993,  0.27748483, -0.13198294,
        -0.22254312,  0.19926155,  0.19124901, -0.08933976,  0.25242418],
       [ 0.09380357, -0.02989926, -0.01782445,  0.00312767, -0.02519768,
        -0.43802148, -0.00290839,  0.04753356, -0.02965541,  0.10304467],
       [ 0.20286047, -0.07675526, -0.03217752,  0.17366095, -0.13799758,
        -0.27491322,  0.00279245,  0.14233288, -0.05951798,  0.36937428],
       [ 0.01445094, -0.07265921,  0.10096341,  0.17594802, -0.17472097,
        -0.2958681 ,  0.0036519 ,  0.03119059, -0.2027646 , -0.01793122],
       [-0.02391969, -0.10441571, -0.00624696,  0.06563509, -0.14965585,
        -0.3743796 ,  0.0422266 ,  0.04684277,  0.05023851, -0.07264638]],
      dtype=float32)>}
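
The tensor printed above holds raw logits, because the final Dense(10) layer has no activation and the loss was configured with from_logits=True. As a minimal sketch, the plain-Python snippet below turns the first row of that output into class probabilities with a softmax and picks the most likely class. The softmax helper is written here for illustration; with TensorFlow available, tf.nn.softmax would do the same job.

```python
import math

# First row of the logits printed above (one example, 10 classes).
logits = [0.17218862, 0.07492599, -0.0548683, 0.03503785, -0.03743191,
          -0.05301537, 0.01267872, -0.02870197, -0.33800656, 0.17991678]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


probabilities = softmax(logits)
# The predicted class is the index of the largest logit.
predicted_class = max(range(len(logits)), key=lambda i: logits[i])
print(predicted_class, [round(p, 3) for p in probabilities])
```

Because the model saved with tf.saved_model.save() was a fresh, untrained one, the probabilities come out close to uniform, so the predicted class is not meaningful here.
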

You can also load the model and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Warning:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `experimental_run_v2` inside a tf.function to get the best performance.
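
The warning above suggests wrapping the per-replica call in a tf.function. The following is a minimal, self-contained sketch of that pattern. It makes two assumptions so that it runs on any machine: it uses tf.distribute.OneDeviceStrategy on the CPU instead of MirroredStrategy, and replica_fn is a hypothetical stand-in for the loaded inference function.

```python
import tensorflow as tf

# Assumption: OneDeviceStrategy on CPU keeps the sketch runnable anywhere.
strategy = tf.distribute.OneDeviceStrategy("/cpu:0")


def replica_fn(batch):
    """Hypothetical stand-in for the loaded inference function."""
    return tf.reduce_sum(batch, axis=-1)


@tf.function
def distributed_step(batch):
    # strategy.run is traced into a graph instead of executing eagerly.
    return strategy.run(replica_fn, args=(batch,))


dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

results = []
for batch in dist_dataset:
    per_replica = distributed_step(batch)
    # Unpack the per-replica values into a tuple of tensors.
    results.append(strategy.experimental_local_results(per_replica))

print(len(results), results[0][0].numpy())
```

In the tutorial's code, the same idea would mean moving the another_strategy.run(inference_func, args=(batch,)) call into a function decorated with tf.function and calling that function inside the loop.
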

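Calling the restored function is just a forward pass (prediction) on the saved model. What if you want to continue training the loaded function, or embed it into a larger model? A common practice is to wrap the loaded object in a Keras layer. Luckily, TF Hub provides hub.KerasLayer for this purpose, as shown here: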

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.2059 - accuracy: 0.9393
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0681 - accuracy: 0.9799

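As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.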

Which API should I use?

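For saving, if you are working with a Keras model, Keras's model.save() API is the recommended choice in almost all cases (the Caveats section covers the exception). If what you are saving is not a Keras model, the lower-level API is your only choice.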
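
As an illustration of the non-Keras case, the sketch below saves a plain tf.Module with tf.saved_model.save() and loads it back with tf.saved_model.load(). The Doubler class, its double method, and the "doubled" output key are hypothetical names chosen for this example. The signature is registered under the "serving_default" key so that it can be looked up the same way as in the inference example earlier in this tutorial.

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical non-Keras object with a single traced function."""

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def double(self, x):
        return {"doubled": 2.0 * x}


module = Doubler()
export_dir = tempfile.mkdtemp()

# A tf.Module has no model.save(), so the low-level API is used directly.
tf.saved_model.save(
    module, export_dir, signatures={"serving_default": module.double})

loaded = tf.saved_model.load(export_dir)
signature_keys = list(loaded.signatures.keys())
outputs = loaded.signatures["serving_default"](x=tf.constant([1.0, 2.0]))
print(signature_keys, outputs["doubled"].numpy())
```

The loaded object is not a Keras model. It only exposes the functions that were traced and saved, which is why the choice of loading API depends on what you need to get back.
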

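For loading, which API you use depends on what you want to get back from the loading API. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.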

It is possible to mix and match the APIs. You can save a Keras model with model.save and load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

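A special case is when your Keras model does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs for both saving and loading; otherwise you will get an error.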

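To check whether your model has well-defined inputs, check whether model.inputs is None. If it is not None, you are all good. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when the model is called (model(inputs)).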
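
The check described above can be sketched as a small helper. To keep the snippet runnable without TensorFlow, it uses SimpleNamespace objects as hypothetical stand-ins for Keras models; in real use you would pass a tf.keras.Model instead. The helper names has_defined_inputs and pick_save_api are illustrative only, and getattr is used defensively in case the attribute is missing altogether.

```python
from types import SimpleNamespace


def has_defined_inputs(model):
    """True when `model.inputs` exists and is not None."""
    return getattr(model, "inputs", None) is not None


def pick_save_api(model):
    """Name the save API suggested by the guidance above."""
    if has_defined_inputs(model):
        return "model.save"
    return "tf.saved_model.save"


# Hypothetical stand-ins for an unbuilt and a built Keras model.
unbuilt_model = SimpleNamespace(inputs=None)
built_model = SimpleNamespace(inputs=["input_1"])

print(pick_save_api(unbuilt_model), pick_save_api(built_model))
```
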

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f96b1c92320>, because it is not built.

Warning:tensorflow:Skipping full serialization of Keras layer <tensorflow.python.keras.layers.core.Dense object at 0x7f96b1c92b70>, because it is not built.

INFO:tensorflow:Assets written to: /tmp/tf_save/assets
