保存和恢复模型

在 tensorflow.google.cn 上查看 在 Google Colab 运行 在 Github 上查看源代码 下载此 notebook

模型可以在训练期间和训练完成后进行保存。这意味着模型可以从任意中断中恢复,并避免耗费比较长的时间在训练上。保存也意味着您可以共享您的模型,而其他人可以通过您的模型来重新创建工作。在发布研究模型和技术时,大多数机器学习从业者分享:

  • 用于创建模型的代码
  • 模型训练的权重 (weight) 和参数 (parameters) 。

共享数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。

注意:小心不受信任的代码——Tensorflow 模型是代码。有关详细信息,请参阅 安全使用Tensorflow

选项

保存 Tensorflow 的模型有许多方法——具体取决于您使用的 API。本指南使用 tf.keras, 一个高级 API 用于在 Tensorflow 中构建和训练模型。有关其他方法的实现,请参阅 TensorFlow 保存和恢复指南或保存到 eager

配置

安装并导入

安装并导入Tensorflow和依赖项:

pip install -q pyyaml h5py  # 以 HDF5 格式保存模型所必须
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2.3.0

获取示例数据集

要演示如何保存和加载权重,您将使用 MNIST 数据集. 要加快运行速度,请使用前1000个示例:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

定义模型

首先构建一个简单的序列(sequential)模型:

# 定义一个简单的序列模型
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

# 创建一个基本的模型实例
model = create_model()

# 显示模型的结构
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

在训练期间保存模型(以 checkpoints 形式保存)

您可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。 tf.keras.callbacks.ModelCheckpoint 允许在训练的过程中结束时回调保存的模型。

Checkpoint 回调用法

创建一个只在训练期间保存权重的 tf.keras.callbacks.ModelCheckpoint 回调:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 创建一个保存模型权重的回调
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# 使用新的回调训练模型
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images,test_labels),
          callbacks=[cp_callback])  # 通过回调训练

# 这可能会生成与保存优化程序状态相关的警告。
# 这些警告(以及整个笔记本中的类似警告)
# 是防止过时使用,可以忽略。
Epoch 1/10
27/32 [========================>.....] - ETA: 0s - loss: 1.2342 - accuracy: 0.6470
Epoch 00001: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 9ms/step - loss: 1.1524 - accuracy: 0.6730 - val_loss: 0.7059 - val_accuracy: 0.7870
Epoch 2/10
29/32 [==========================>...] - ETA: 0s - loss: 0.4293 - accuracy: 0.8825
Epoch 00002: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.4217 - accuracy: 0.8860 - val_loss: 0.5437 - val_accuracy: 0.8330
Epoch 3/10
29/32 [==========================>...] - ETA: 0s - loss: 0.2772 - accuracy: 0.9310
Epoch 00003: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 4ms/step - loss: 0.2826 - accuracy: 0.9260 - val_loss: 0.4935 - val_accuracy: 0.8370
Epoch 4/10
30/32 [===========================>..] - ETA: 0s - loss: 0.2164 - accuracy: 0.9490
Epoch 00004: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.2146 - accuracy: 0.9490 - val_loss: 0.4454 - val_accuracy: 0.8580
Epoch 5/10
27/32 [========================>.....] - ETA: 0s - loss: 0.1527 - accuracy: 0.9734
Epoch 00005: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.1554 - accuracy: 0.9680 - val_loss: 0.4259 - val_accuracy: 0.8640
Epoch 6/10
27/32 [========================>.....] - ETA: 0s - loss: 0.1192 - accuracy: 0.9815
Epoch 00006: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 10ms/step - loss: 0.1172 - accuracy: 0.9820 - val_loss: 0.4190 - val_accuracy: 0.8690
Epoch 7/10
27/32 [========================>.....] - ETA: 0s - loss: 0.0821 - accuracy: 0.9884
Epoch 00007: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0872 - accuracy: 0.9850 - val_loss: 0.4246 - val_accuracy: 0.8670
Epoch 8/10
27/32 [========================>.....] - ETA: 0s - loss: 0.0617 - accuracy: 0.9965
Epoch 00008: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 9ms/step - loss: 0.0638 - accuracy: 0.9970 - val_loss: 0.4135 - val_accuracy: 0.8710
Epoch 9/10
28/32 [=========================>....] - ETA: 0s - loss: 0.0498 - accuracy: 0.9944
Epoch 00009: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 9ms/step - loss: 0.0508 - accuracy: 0.9940 - val_loss: 0.4446 - val_accuracy: 0.8600
Epoch 10/10
27/32 [========================>.....] - ETA: 0s - loss: 0.0396 - accuracy: 0.9988
Epoch 00010: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0406 - accuracy: 0.9990 - val_loss: 0.4366 - val_accuracy: 0.8690

<tensorflow.python.keras.callbacks.History at 0x7fbc7d0fb6d8>

这将创建一个 TensorFlow checkpoint 文件集合,这些文件在每个 epoch 结束时更新:

ls {checkpoint_dir}
checkpoint  cp.ckpt.data-00000-of-00001  cp.ckpt.index

创建一个新的未经训练的模型。仅恢复模型的权重时,必须具有与原始模型具有相同网络结构的模型。由于模型具有相同的结构,您可以共享权重,尽管它是模型的不同实例。 现在重建一个新的未经训练的模型,并在测试集上进行评估。未经训练的模型将在机会水平(chance levels)上执行(准确度约为10%):

# 创建一个基本模型实例
model = create_model()

# 评估模型
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 2.3549 - accuracy: 0.0560
Untrained model, accuracy:  5.60%

然后从 checkpoint 加载权重并重新评估:

# 加载权重
model.load_weights(checkpoint_path)

# 重新评估模型
loss,acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 0.4366 - accuracy: 0.8690
Restored model, accuracy: 86.90%

checkpoint 回调选项

回调提供了几个选项,为 checkpoint 提供唯一名称并调整 checkpoint 频率。

训练一个新模型,每五个 epochs 保存一次唯一命名的 checkpoint :

# 在文件名中包含 epoch (使用 `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# 创建一个回调,每 5 个 epochs 保存模型的权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    period=5)

# 创建一个新的模型实例
model = create_model()

# 使用 `checkpoint_path` 格式保存权重
model.save_weights(checkpoint_path.format(epoch=0))

# 使用新的回调训练模型
model.fit(train_images, 
          train_labels,
          epochs=50, 
          callbacks=[cp_callback],
          validation_data=(test_images,test_labels),
          verbose=0)
WARNING:tensorflow:`period` argument is deprecated. Please use `save_freq` to specify the frequency in number of batches seen.

Epoch 00005: saving model to training_2/cp-0005.ckpt

Epoch 00010: saving model to training_2/cp-0010.ckpt

Epoch 00015: saving model to training_2/cp-0015.ckpt

Epoch 00020: saving model to training_2/cp-0020.ckpt

Epoch 00025: saving model to training_2/cp-0025.ckpt

Epoch 00030: saving model to training_2/cp-0030.ckpt

Epoch 00035: saving model to training_2/cp-0035.ckpt

Epoch 00040: saving model to training_2/cp-0040.ckpt

Epoch 00045: saving model to training_2/cp-0045.ckpt

Epoch 00050: saving model to training_2/cp-0050.ckpt

<tensorflow.python.keras.callbacks.History at 0x7fbcd4c01588>

现在查看生成的 checkpoint 并选择最新的 checkpoint :

ls {checkpoint_dir}
checkpoint            cp-0025.ckpt.index
cp-0000.ckpt.data-00000-of-00001  cp-0030.ckpt.data-00000-of-00001
cp-0000.ckpt.index        cp-0030.ckpt.index
cp-0005.ckpt.data-00000-of-00001  cp-0035.ckpt.data-00000-of-00001
cp-0005.ckpt.index        cp-0035.ckpt.index
cp-0010.ckpt.data-00000-of-00001  cp-0040.ckpt.data-00000-of-00001
cp-0010.ckpt.index        cp-0040.ckpt.index
cp-0015.ckpt.data-00000-of-00001  cp-0045.ckpt.data-00000-of-00001
cp-0015.ckpt.index        cp-0045.ckpt.index
cp-0020.ckpt.data-00000-of-00001  cp-0050.ckpt.data-00000-of-00001
cp-0020.ckpt.index        cp-0050.ckpt.index
cp-0025.ckpt.data-00000-of-00001

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

注意: 默认的 tensorflow 格式仅保存最近的5个 checkpoint 。

如果要进行测试,请重置模型并加载最新的 checkpoint :

# 创建一个新的模型实例
model = create_model()

# 加载以前保存的权重
model.load_weights(latest)

# 重新评估模型
loss, acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 0.5060 - accuracy: 0.8730
Restored model, accuracy: 87.30%

这些文件是什么?

上述代码将权重存储到 checkpoint—— 格式化文件的集合中,这些文件仅包含二进制格式的训练权重。 Checkpoints 包含:

  • 一个或多个包含模型权重的分片。
  • 索引文件,指示哪些权重存储在哪个分片中。

如果你只在一台机器上训练一个模型,你将有一个带有后缀的碎片: .data-00000-of-00001

手动保存权重

您将了解如何将权重加载到模型中。使用 Model.save_weights 方法手动保存它们同样简单。默认情况下, tf.kerassave_weights 特别使用 TensorFlow checkpoints 格式 .ckpt 扩展名和 ( 保存在 HDF5 扩展名为 .h5 保存并序列化模型 ):

# 保存权重
model.save_weights('./checkpoints/my_checkpoint')

# 创建模型实例
model = create_model()

# 恢复权重
model.load_weights('./checkpoints/my_checkpoint')

# 评估模型
loss,acc = model.evaluate(test_images,  test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
32/32 - 0s - loss: 0.5060 - accuracy: 0.8730
Restored model, accuracy: 87.30%

保存整个模型

调用 model.save 将保存模型的结构,权重和训练配置保存在单个文件/文件夹中。这可以让您导出模型,以便在不访问原始 Python 代码*的情况下使用它。因为优化器状态(optimizer-state)已经恢复,您可以从中断的位置恢复训练。

整个模型可以以两种不同的文件格式(SavedModelHDF5)进行保存。需要注意的是 TensorFlow 的 SavedModel 格式是 TF2.x. 中的默认文件格式。但是,模型仍可以以 HDF5 格式保存。下面介绍了以两种文件格式保存整个模型的更多详细信息。

保存完整模型会非常有用——您可以在 TensorFlow.js(Saved Model, HDF5)加载它们,然后在 web 浏览器中训练和运行它们,或者使用 TensorFlow Lite 将它们转换为在移动设备上运行(Saved Model, HDF5

*自定义对象(例如,子类化模型或层)在保存和加载时需要特别注意。请参阅下面的保存自定义对象部分

SavedModel 格式

SavedModel 格式是序列化模型的另一种方法。以这种格式保存的模型,可以使用 tf.keras.models.load_model 还原,并且模型与 TensorFlow Serving 兼容。SavedModel 指南详细介绍了如何提供/检查 SavedModel。以下部分说明了保存和还原模型的步骤。

# 创建并训练一个新的模型实例。
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# 将整个模型另存为 SavedModel。
!mkdir -p saved_model
model.save('saved_model/my_model') 
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1045 - accuracy: 0.7010
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4216 - accuracy: 0.8790
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2697 - accuracy: 0.9310
Epoch 4/5
32/32 [==============================] - 0s 1ms/step - loss: 0.2198 - accuracy: 0.9440
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1498 - accuracy: 0.9710
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/my_model/assets

SavedModel 格式是一个包含 protobuf 二进制文件和 Tensorflow 检查点(checkpoint)的目录。检查保存的模型目录:

# my_model 文件夹
!ls saved_model

# 包含一个 assets 文件夹,saved_model.pb,和变量文件夹。
!ls saved_model/my_model
my_model
assets  saved_model.pb  variables

从保存的模型重新加载新的 Keras 模型:

new_model = tf.keras.models.load_model('saved_model/my_model')

# 检查其架构
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_10 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_5 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

还原的模型使用与原始模型相同的参数进行编译。 尝试使用加载的模型运行评估和预测:

# 评估还原的模型
loss, acc = new_model.evaluate(test_images,  test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100*acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4209 - accuracy: 0.0910
Restored model, accuracy:  9.10%
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
(1000, 10)

HDF5 格式

Keras使用 HDF5 标准提供了一种基本的保存格式。

# 创建并训练一个新的模型实例
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# 将整个模型保存为 HDF5 文件。
# '.h5' 扩展名指示应将模型保存到 HDF5。
model.save('my_model.h5') 
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1949 - accuracy: 0.6500
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4308 - accuracy: 0.8800
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2764 - accuracy: 0.9330
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2088 - accuracy: 0.9410
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1520 - accuracy: 0.9730

现在,从该文件重新创建模型:

# 重新创建完全相同的模型,包括其权重和优化程序
new_model = tf.keras.models.load_model('my_model.h5')

# 显示网络结构
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_12 (Dense)             (None, 512)               401920    
_________________________________________________________________
dropout_6 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_13 (Dense)             (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

检查其准确率(accuracy):

loss, acc = new_model.evaluate(test_images,  test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100*acc))
32/32 - 0s - loss: 0.4305 - accuracy: 0.0840
Restored model, accuracy:  8.40%

Keras 通过检查网络结构来保存模型。这项技术可以保存一切:

  • 权重值
  • 模型的架构
  • 模型的训练配置(您传递给编译的内容)
  • 优化器及其状态(如果有的话)(这使您可以在中断的地方重新开始训练)

Keras 无法保存 v1.x 优化器(来自 tf.compat.v1.train),因为它们与检查点不兼容。对于 v1.x 优化器,您需要在加载-失去优化器的状态后,重新编译模型。

保存自定义对象

如果使用的是 SavedModel 格式,则可以跳过此部分。HDF5 和 SavedModel 之间的主要区别在于,HDF5 使用对象配置保存模型结构,而 SavedModel 保存执行图。因此,SavedModel 能够保存自定义对象,例如子类化模型和自定义层,而无需原始代码。

要将自定义对象保存到 HDF5,必须执行以下操作:

  1. 在对象中定义一个 get_config 方法,以及可选的 from_config 类方法。
    • get_config(self) 返回重新创建对象所需的参数的 JSON 可序列化字典。
    • from_config(cls, config) 使用从 get_config 返回的 config 来创建一个新对象。默认情况下,此函数将使用 config 作为初始化 kwargs(return cls(**config))。
  2. 加载模型时,将对象传递给 custom_objects 参数。参数必须是将字符串类名称映射到 Python 类的字典。例如,tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})

有关自定义对象和 get_config 的示例,请参见从头开始编写层和模型教程。


#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.