Save and load a model using a distribution strategy

Overview

It is common to save and load a model during training. There are two sets of APIs for saving and loading a Keras model: a high-level API and a low-level API. This tutorial demonstrates how you can use the SavedModel APIs when using tf.distribute.Strategy. For an overview of SavedModel and serialization in general, refer to the Saved Model guide and the Keras model serialization guide. Let's start with a simple example:

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf
tfds.disable_progress_bar()

Prepare the data and model using tf.distribute.Strategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 10s 4ms/step - loss: 0.2064 - accuracy: 0.9403
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0656 - accuracy: 0.9811
<keras.callbacks.History at 0x7f3f100027d0>

Save and load the model

Now that you have a simple model to work with, let's take a look at the saving/loading APIs. There are two sets of APIs available:

Keras API

Here is an example of saving and loading a model with the Keras APIs:

keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)  # save() should be called out of strategy scope
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0489 - accuracy: 0.9851
Epoch 2/2
938/938 [==============================] - 2s 2ms/step - loss: 0.0338 - accuracy: 0.9898
<keras.callbacks.History at 0x7f3eb40dba10>

After restoring the model, you can continue training on it, even without needing to call compile() again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to the SavedModel format.
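
Because the loss, optimizer, and metrics are restored along with the weights, you can also evaluate the restored model right away. This is a minimal sketch, assuming the eval_dataset returned by get_data() above is still in scope:

# The restored model is already compiled, so evaluate() works without calling compile() again.
eval_loss, eval_acc = restored_keras_model.evaluate(eval_dataset)
print('Restored model accuracy: {:.4f}'.format(eval_acc))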

Now load the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0473 - accuracy: 0.9855
Epoch 2/2
938/938 [==============================] - 8s 9ms/step - loss: 0.0332 - accuracy: 0.9901

As you can see, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.

tf.saved_model API

Now let's take a look at the lower-level APIs. Saving the model is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

You can load it with tf.saved_model.load(). However, since this is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
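
If you are not sure which keys a given SavedModel exposes, you can list them on the loaded object. The value shown in the comment is only what you would typically expect for a saved Keras model:

print(list(loaded.signatures.keys()))  # e.g. ['serving_default']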

The loaded object may contain multiple functions, each associated with a key. "serving_default" is the default key for the inference function of a saved Keras model. To do inference with this function, run the following code:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 3.65724899e-02, -7.95967132e-03, -1.00796238e-01, ...,
         4.17762101e-02,  1.82279497e-02,  1.25376523e-01],
       ...,
       [-2.23004520e-02,  4.08236533e-02,  1.43895429e-02, ...,
         1.38778061e-01, -1.51736483e-01,  9.62266326e-03]], dtype=float32)>}

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    another_strategy.run(inference_func, args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.

Calling the restored function is just a forward pass on the saved model (prediction). What if you want to continue training the loaded function, or embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
938/938 [==============================] - 5s 3ms/step - loss: 0.2110 - accuracy: 0.9372
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0700 - accuracy: 0.9795

As you can see, hub.KerasLayer wraps the result loaded back from tf.saved_model.load() into a Keras layer that can be used to build another model. This is very useful for transfer learning.
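
For instance, in a transfer-learning setting you might keep the loaded model frozen and train only a new head on top of its outputs. The following is a minimal sketch rather than part of this tutorial's workflow; the trainable=False flag, the extra Dense head, and the transfer_model name are illustrative choices:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  # Freeze the loaded model and stack a new, trainable classification head on top.
  frozen_features = hub.KerasLayer(loaded, trainable=False)
  inputs = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  outputs = tf.keras.layers.Dense(10)(frozen_features(inputs))
  transfer_model = tf.keras.Model(inputs, outputs)

  transfer_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(),
      metrics=['accuracy'])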

Which API should I use?

For saving, if you are working with a Keras model, it is almost always recommended to use the Keras model.save() API. If what you are saving is not a Keras model, then the lower-level API is your only choice.

For loading, which API you use depends on what you want to get from the loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load(). Otherwise, use tf.keras.models.load_model(). Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with model.save, and load a non-Keras model with the low-level API tf.saved_model.load.

model = get_model()

# Saving the model using Keras's save() API
model.save(keras_model_path) 

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving and loading from a local I/O device while running remotely, for example using a Cloud TPU, you must use the option experimental_io_device to set the I/O device to localhost.

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case is when you have a Keras model that does not have well-defined inputs. For example, a Sequential model can be created without any input shapes (Sequential([Dense(3), ...])). Subclassed models also do not have well-defined inputs after initialization. In this case, you should stick with the lower-level APIs on both saving and loading, otherwise you will get an error.

To check whether your model has well-defined inputs, just check if model.inputs is None. If it is not None, you are all set. Input shapes are defined automatically when the model is used in .fit, .evaluate, or .predict, or when calling the model (model(inputs)).

Here is an example:

class SubclassedModel(tf.keras.Model):

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
# my_model.save(keras_model_path)  # ERROR! 
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3e58640c50>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.Dense object at 0x7f3e585a4250>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
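
As noted above, this subclassed model has no well-defined inputs, so the lower-level API is also the one to use when loading it back. A minimal sketch (the loaded_subclassed variable name is only illustrative):

print(my_model.inputs)  # None: the inputs of this subclassed model are not well defined

# Load the SavedModel back with the lower-level API as well.
loaded_subclassed = tf.saved_model.load(saved_model_path)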