迁移学习和微调

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

设置

import numpy as np
import tensorflow as tf
from tensorflow import keras
2022-08-17 00:06:13.894189: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-17 00:06:14.445821: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-17 00:06:14.446069: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-17 00:06:14.446082: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

简介

迁移学习包括获取从迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。

对于数据集中的数据太少而无法从头开始训练完整模型的任务,通常会执行迁移学习。

在深度学习情境中,迁移学习最常见的形式是以下工作流:

  1. 从之前训练的模型中获取层。
  2. 冻结这些层,以避免在后续训练轮次中破坏它们包含的任何信息。
  3. 在已冻结层的顶部添加一些新的可训练层。这些层会学习将旧特征转换为对新数据集的预测。
  4. 在您的数据集上训练新层。

最后一个可选步骤是微调,包括解冻上面获得的整个模型(或模型的一部分),然后在新数据上以极低的学习率对该模型进行重新训练。以增量方式使预训练特征适应新数据,有可能实现有意义的改进。

首先,我们将详细介绍 Keras trainable API,它是大多数迁移学习和微调工作流的基础。

随后,我们将演示一个典型工作流:先获得一个在 ImageNet 数据集上预训练的模型,然后在 Kaggle Dogs vs. Cats 分类数据集上对该模型进行重新训练。

此工作流改编自 Python 深度学习 和 2016 年的博文“使用极少的数据构建强大的图像分类模型”

冻结层:了解 trainable 特性

层和模型具有三个权重特性:

  • weights 是层的所有权重变量的列表。
  • trainable_weights 是需要进行更新(通过梯度下降)以尽可能减少训练过程中损失的权重列表。
  • non_trainable_weights 是不适合训练的权重列表。它们通常在正向传递过程中由模型更新。

示例:Dense 层具有 2 个可训练权重(内核与偏差)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般而言,所有权重都是可训练权重。唯一具有不可训练权重的内置层是 BatchNormalization 层。在训练期间,它使用不可训练权重跟踪其输入的平均值和方差。要了解如何在您自己的自定义层中使用不可训练权重,请参阅从头开始编写新层的指南。

示例:BatchNormalization 层具有 2 个可训练权重和 2 个不可训练权重

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

层和模型还具有布尔特性 trainable。此特性的值可以更改。将 layer.trainable 设置为 False 会将层的所有权重从可训练移至不可训练。这一过程称为“冻结”层:已冻结层的状态在训练期间不会更新(无论是使用 fit() 进行训练,还是使用依赖于 trainable_weights 来应用梯度更新的任何自定义循环进行训练时)。

示例:将 trainable 设置为 False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

当可训练权重变为不可训练时,它的值在训练期间不再更新。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 535ms/step - loss: 0.0700

请勿将 layer.trainable 特性与 layer.__call__() 中的 training 参数(此参数控制层是在推断模式还是训练模式下运行其前向传递)混淆。有关更多信息,请参阅 Keras常见问题解答

trainable 特性的递归设置

如果在模型或具有子层的任何层上设置 trainable = False,则所有子层也将变为不可训练。

示例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的迁移学习工作流

下面将介绍如何在 Keras 中实现典型的迁移学习工作流:

  1. 实例化一个基础模型并加载预训练权重。
  2. 通过设置 trainable = False 冻结基础模型中的所有层。
  3. 根据基础模型中一个(或多个)层的输出创建一个新模型。
  4. 在您的新数据集上训练新模型。

请注意,另一种更轻量的工作流如下:

  1. 实例化一个基础模型并加载预训练权重。
  2. 通过该模型运行新的数据集,并记录基础模型中一个(或多个)层的输出。这一过程称为特征提取。
  3. 使用该输出作为新的较小模型的输入数据。

第二种工作流有一个关键优势,即您只需在自己的数据上运行一次基础模型,而不是每个训练周期都运行一次。因此,它的速度更快,开销也更低。

但是,第二种工作流存在一个问题,即它不允许您在训练期间动态修改新模型的输入数据,在进行数据扩充时,这种修改必不可少。当新数据集的数据太少而无法从头开始训练完整模型时,任务通常会使用迁移学习,在这种情况下,数据扩充非常重要。因此,在接下来的篇幅中,我们将专注于第一种工作流。

下面是 Keras 中第一种工作流的样子:

首先,实例化一个具有预训练权重的基础模型。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

随后,冻结该基础模型。

base_model.trainable = False

根据基础模型创建一个新模型。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新数据上训练该模型。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微调

一旦模型在新数据上收敛,您就可以尝试解冻全部或部分基础模型,并以极低的学习率端到端地重新训练整个模型。

这是可选的最后一个步骤,可能给您带来增量式改进。不过,它也可能导致快速过拟合,请牢记这一点。

重要的是,只有在将具有冻结层的模型训练至收敛后,才能执行此步骤。如果将随机初始化的可训练层与包含预训练特征的可训练层混合使用,则随机初始化的层将在训练过程中引起非常大的梯度更新,而这将破坏您的预训练特征。

在此阶段使用极低的学习率也很重要,因为与第一轮训练相比,您正在一个通常非常小的数据集上训练一个大得多的模型。因此,如果您应用较大的权重更新,则存在很快过拟合的风险。在这里,您只需要以增量方式重新调整预训练权重。

下面是实现整个基础模型微调的方法:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

关于 compile()trainable 的重要说明

在模型上调用 compile() 意味着“冻结”该模型的行为。这意味着编译模型时的 trainable 特性值应当在该模型的整个生命周期中保留,直到再次调用 compile。因此,如果您更改任何 trainable 值,请确保在您的模型上再次调用 compile() 以将您的变更考虑在内。

关于 BatchNormalization 层的重要说明

许多图像模型包含 BatchNormalization 层。该层在任何方面都是一个特例。下面是一些注意事项。

  • BatchNormalization 包含 2 个会在训练期间更新的不可训练权重。它们是跟踪输入的均值和方差的变量。
  • 当您设置 bn_layer.trainable = False 时,BatchNormalization 层将以推断模式运行,并且不会更新其均值和方差统计信息。一般而言,其他层的情况并非如此,因为权重可训练性和推断/训练模式是两个正交概念。但是,对于 BatchNormalization 层,两者是关联的。
  • 当您解冻包含 BatchNormalization 层的模型以进行微调时,您应当通过在调用基础模型时传递 training=False 以将 BatchNormalization 层保持在推断模式。否则,应用于不可训练权重的更新会突然破坏模型已经学习的内容。

您将在本指南结尾处的端到端示例中看到这种模式的实际运行。

使用自定义训练循环进行迁移学习和微调

如果您使用自己的低级训练循环而不是 fit(),则工作流基本保持不变。在应用梯度更新时,您应当注意只考虑清单 model.trainable_weights

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

对于微调同样如此。

端到端示例:基于 Dogs vs. Cats 数据集微调图像分类模型

为了巩固这些概念,我们先介绍一个具体的端到端迁移学习和微调示例。我们将加载在 ImageNet 上预训练的 Xception 模型,并将其用于 Kaggle Dogs vs. Cats 分类数据集。

获取数据

首先,我们使用 TFDS 来获取 Dogs vs. Cats 数据集。如果您拥有自己的数据集,则可能需要使用效用函数 tf.keras.preprocessing.image_dataset_from_directory 从磁盘上存档到类特定的文件夹中的一组图像来生成相似的有标签数据集对象。

使用非常小的数据集时,迁移学习最实用。为了使数据集保持较小状态,我们将原始训练数据(25,000 个图像)的 40% 用于训练,10% 用于验证,10% 用于测试。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

下面是训练数据集中的前 9 个图像。如您所见,它们具有不同的大小。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

我们还可以看到标签 1 是“狗”,标签 0 是“猫”。

标准化数据

我们的原始图像有各种大小。另外,每个像素由 0 到 255 之间的 3 个整数值(RGB 色阶值)组成。这不太适合馈送神经网络。我们需要做下面两件事:

  • 标准化为固定图像大小。我们选择 150x150。
  • 在 -1 至 1 之间归一化像素值。我们将使用 Normalization 层作为模型本身的一部分来进行此操作。

一般而言,与采用已预处理数据的模型相反,开发以原始数据作为输入的模型是一种良好的做法。原因在于,如果模型需要预处理的数据,则每次导出模型以在其他地方(在网络浏览器、移动应用中)使用时,都需要重新实现完全相同的预处理流水线。这很快就会变得非常棘手。因此,在命中模型之前,我们应当尽可能少地进行预处理。

在这里,我们将在数据流水线中进行图像大小调整(因为深度神经网络只能处理连续的数据批次),并在创建模型时将其作为模型的一部分进行输入值缩放。

我们将图像的大小调整为 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

此外,我们对数据进行批处理并使用缓存和预提取来优化加载速度。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

使用随机数据扩充

当您没有较大的图像数据集时,通过将随机但现实的转换(例如随机水平翻转或小幅随机旋转)应用于训练图像来人为引入样本多样性是一种良好的做法。这有助于使模型暴露于训练数据的不同方面,同时减慢过拟合的速度。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

我们看一下经过各种随机转换后第一个批次的第一个图像是什么样:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7efc6c59ac10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7efc6c59ac10> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7efc6c6854c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7efc6c6854c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
2022-08-17 00:06:31.878163: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

构建模型

现在,我们来构建一个遵循我们先前解释的蓝图的模型。

注意:

  • 我们添加 Normalization 层以将输入值(最初在 [0, 255] 范围内)缩放到 [-1, 1] 范围。
  • 我们在分类层之前添加一个 Dropout 层,以进行正则化。
  • 我们确保在调用基础模型时传递 training=False,使其在推断模式下运行,这样,即使在我们解冻基础模型以进行微调后,batchnorm 统计信息也不会更新。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 [==============================] - 1s 0us/step
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

训练顶层

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
147/291 [==============>...............] - ETA: 11s - loss: 0.1942 - binary_accuracy: 0.9075
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
262/291 [==========================>...] - ETA: 2s - loss: 0.1670 - binary_accuracy: 0.9251
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
276/291 [===========================>..] - ETA: 1s - loss: 0.1637 - binary_accuracy: 0.9269
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
279/291 [===========================>..] - ETA: 0s - loss: 0.1629 - binary_accuracy: 0.9273
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1627 - binary_accuracy: 0.9277
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 37s 102ms/step - loss: 0.1627 - binary_accuracy: 0.9277 - val_loss: 0.0817 - val_binary_accuracy: 0.9682
Epoch 2/20
291/291 [==============================] - 26s 90ms/step - loss: 0.1167 - binary_accuracy: 0.9507 - val_loss: 0.0752 - val_binary_accuracy: 0.9712
Epoch 3/20
291/291 [==============================] - 26s 90ms/step - loss: 0.1061 - binary_accuracy: 0.9568 - val_loss: 0.0753 - val_binary_accuracy: 0.9716
Epoch 4/20
291/291 [==============================] - 26s 90ms/step - loss: 0.1011 - binary_accuracy: 0.9577 - val_loss: 0.0715 - val_binary_accuracy: 0.9742
Epoch 5/20
291/291 [==============================] - 26s 90ms/step - loss: 0.1034 - binary_accuracy: 0.9589 - val_loss: 0.0782 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 26s 90ms/step - loss: 0.1002 - binary_accuracy: 0.9625 - val_loss: 0.0743 - val_binary_accuracy: 0.9721
Epoch 7/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0955 - binary_accuracy: 0.9601 - val_loss: 0.0714 - val_binary_accuracy: 0.9721
Epoch 8/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0972 - binary_accuracy: 0.9599 - val_loss: 0.0696 - val_binary_accuracy: 0.9729
Epoch 9/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0945 - binary_accuracy: 0.9605 - val_loss: 0.0728 - val_binary_accuracy: 0.9733
Epoch 10/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0918 - binary_accuracy: 0.9610 - val_loss: 0.0740 - val_binary_accuracy: 0.9725
Epoch 11/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0929 - binary_accuracy: 0.9616 - val_loss: 0.0703 - val_binary_accuracy: 0.9746
Epoch 12/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0971 - binary_accuracy: 0.9615 - val_loss: 0.0747 - val_binary_accuracy: 0.9738
Epoch 13/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0926 - binary_accuracy: 0.9630 - val_loss: 0.0725 - val_binary_accuracy: 0.9712
Epoch 14/20
291/291 [==============================] - 26s 91ms/step - loss: 0.0839 - binary_accuracy: 0.9650 - val_loss: 0.0763 - val_binary_accuracy: 0.9708
Epoch 15/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0912 - binary_accuracy: 0.9632 - val_loss: 0.0745 - val_binary_accuracy: 0.9733
Epoch 16/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0864 - binary_accuracy: 0.9653 - val_loss: 0.0775 - val_binary_accuracy: 0.9721
Epoch 17/20
291/291 [==============================] - 26s 89ms/step - loss: 0.0955 - binary_accuracy: 0.9597 - val_loss: 0.0749 - val_binary_accuracy: 0.9716
Epoch 18/20
291/291 [==============================] - 26s 89ms/step - loss: 0.0910 - binary_accuracy: 0.9640 - val_loss: 0.0748 - val_binary_accuracy: 0.9742
Epoch 19/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0904 - binary_accuracy: 0.9629 - val_loss: 0.0749 - val_binary_accuracy: 0.9725
Epoch 20/20
291/291 [==============================] - 26s 90ms/step - loss: 0.0879 - binary_accuracy: 0.9636 - val_loss: 0.0752 - val_binary_accuracy: 0.9725
<keras.callbacks.History at 0x7efc6c5dd5e0>

对整个模型进行一轮微调

最后,我们解冻基础模型,并以较低的学习率端到端地训练整个模型。

重要的是,尽管基础模型变得可训练,但在构建模型过程中,由于我们在调用该模型时传递了 training=False,因此它仍在推断模式下运行。这意味着内部的批次归一化层不会更新其批次统计信息。如果它们更新了这些统计信息,则会破坏该模型到目前为止所学习的表示。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting ImageProjectiveTransformV3 cause there is no registered converter for this op.
291/291 [==============================] - 62s 193ms/step - loss: 0.0752 - binary_accuracy: 0.9702 - val_loss: 0.0569 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0560 - binary_accuracy: 0.9788 - val_loss: 0.0539 - val_binary_accuracy: 0.9789
Epoch 3/10
291/291 [==============================] - 55s 190ms/step - loss: 0.0436 - binary_accuracy: 0.9839 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 4/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0307 - binary_accuracy: 0.9872 - val_loss: 0.0511 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0280 - binary_accuracy: 0.9901 - val_loss: 0.0512 - val_binary_accuracy: 0.9815
Epoch 6/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0150 - binary_accuracy: 0.9944 - val_loss: 0.0541 - val_binary_accuracy: 0.9837
Epoch 7/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0144 - binary_accuracy: 0.9946 - val_loss: 0.0576 - val_binary_accuracy: 0.9815
Epoch 8/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0168 - binary_accuracy: 0.9943 - val_loss: 0.0524 - val_binary_accuracy: 0.9828
Epoch 9/10
291/291 [==============================] - 55s 189ms/step - loss: 0.0130 - binary_accuracy: 0.9967 - val_loss: 0.0515 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 55s 188ms/step - loss: 0.0087 - binary_accuracy: 0.9972 - val_loss: 0.0550 - val_binary_accuracy: 0.9845
<keras.callbacks.History at 0x7efc6c30fa00>

经过 10 个周期后,微调在这里为我们提供了出色的改进。