転移学習とファインチューニング

コレクションでコンテンツを整理 必要に応じて、コンテンツの保存と分類を行います。

TensorFlow.org で実行 Google Colabで実行 GitHubでソースを表示 ノートブックをダウンロード

セットアップ

import numpy as np
import tensorflow as tf
from tensorflow import keras
2022-08-09 04:52:49.097989: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-09 04:52:49.651158: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 04:52:49.651389: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 04:52:49.651401: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

はじめに

ある問題で学習した特徴量を取り入れ、それを新しい類似した問題に利用する方法を転移学習と呼びます。たとえば、アライグマの識別を学習したモデルの特徴量がある場合、それを使用してタヌキの識別を学習するモデルに取り組むことができます。

通常、転移学習はデータセットのデータが少なすぎてフルスケールモデルをゼロからトレーニングできないようなタスクで行われます。

ディープラーニングの文脈では、転移学習は次のワークフローで行われるのが最も一般的です。

  1. 以前にトレーニングされたモデルからレイヤーを取得します。
  2. 以降のトレーニングラウンドでそれらのレイヤーに含まれる情報が破損しないように凍結します。
  3. 凍結したレイヤーの上にトレーニング対象のレイヤーを新たに追加します。これらのレイヤーは古い特徴量を新しいデータセットの予測に変換することを学習します。
  4. データセットで新しいレイヤーをトレーニングします。

最後に任意でファインチューニングを実施できます。ファインチューニングでは、上記で取得したモデル全体(または一部)を解凍し、新しいデータに対して非常に低い学習率で再トレーニングします。これを実施すると、事前トレーニング済みの特徴量を徐々に新しいデータに適応させ、意味のある改善を得られることがあります。

まずは、転移学習とファインチューニングのほとんどのワークフローの基礎である、Keras の trainable API について詳しく見てみましょう。

次に、一般的なワークフローを説明します。ImageNet データセットで事前にトレーニングされたモデルを取得してそれを Kaggle の「犬と猫」分類データセットで再トレーニングしてみましょう。

これは、Deep Learning with Python および 2016 年のブログ記事「少ないデータで強力な画像分類モデルを構築する」を基にしています。

レイヤーの凍結: trainable 属性を理解する

レイヤーとモデルには 3 つの重み属性があります。

  • weights は、レイヤーのすべての重み変数のリストです。
  • trainable_weights は、トレーニング中の損失を最小限に抑えるために(勾配降下を介して)更新を意図した重みのリストです。
  • non_trainable_weights は、トレーニングを意図していない重みのリストです。通常は、フォワードパスの間にモデルによって更新されます。

例: Dense レイヤーに、トレーニング対象の重みが 2 つ(カーネルとバイアス)がある

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般的に、重みはすべてトレーニング対象の重みです。トレーニング対象外の重みレイヤーを持つ組み込みレイヤーは、BatchNormalization レイヤーしかありません。これはトレーニング対象外の重みを使用して、トレーニング中の入力の平均と分散を追跡します。独自のカスタムレイヤーでトレーニング対象外の重みを使用する方法については、レイヤーの新規作成ガイドをご覧ください。

例: BatchNormalization レイヤーに 、トレーニング対象の重みとトレーニング対象外の重みが 2 つずつある

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

レイヤーとモデルには、ブール属性の trainable もあり、その値を変更することができます。layer.trainableFalse に設定すると、すべてのレイヤーの重みがトレーニング対象からトレーニング対象外に移動されます。これはレイヤーの「凍結」と呼ばれるもので、凍結されたレイヤーの状態はトレーニング中に更新されません(fit() でトレーニングする場合、またはtrainable_weights に依存して勾配の更新を適用するカスタムループでトレーニングする場合)。

例: trainableFalseに設定する

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

トレーニング対象の重みがトレーニング対象外の重みになると、その値はトレーニング中に更新されなくなります。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 2s 2s/step - loss: 0.0516

layer.trainable 属性を layer.__call__() の引数 training と混同しないようにしてください(後者は、レイヤーがフォワードパスを推論モードで実行するか、トレーニングモードで実行するかを制御します)。詳細については、Keras よくある質問をご覧ください。

trainable 属性を再帰的に設定する

モデルや、サブレイヤーのあるレイヤーで trainable = False を設定すると、すべての子レイヤーもトレーニング対象外になります。

例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的な転移学習のワークフロー

ここでは、典型的な転移学習のワークフローを Keras に実装する方法を示します。

  1. ベースモデルをインスタンス化し、それに事前トレーニング済みの重みを読み込みます。
  2. trainable = False を設定して、ベースモデルのすべてのレイヤーを凍結します。
  3. ベースモデルの 1 つ以上のレイヤーの出力上に新しいモデルを作成します。
  4. 新しいデータセットで新しいモデルをトレーニングします。

これに代わる、より軽量なワークフローとして、以下のようなものも考えられます。

  1. ベースモデルをインスタンス化し、それに事前トレーニング済みの重みを読み込みます。
  2. 新しいデータセットを実行して、ベースモデルの 1 つ以上のレイヤーの出力を記録します。特徴量抽出と呼ばれる作業です。
  3. その出力を新しい小さなモデルの入力データとして使用します。

この 2 番目のワークフローには、トレーニングのエポックごとに 1 回ではなく、データに対して 1 回だけベースモデルを実行するため、はるかに高速で安価になるというメリットがあります。

ただし、トレーニング中に新しいモデルの入力データを動的に変更することができないという問題があります。これはデータを拡張する際などに必要なことです。転移学習は通常、フルスケールでモデルを新規にトレーニングするには新しいデータセットのデータ量が少なすぎる場合に使用しますが、そのような場合、データの拡張が非常に重要になります。そこで以降では、1 番目のワークフローに焦点を当てます。

1 番目のワークフローは、Keras では以下のようになります。

まず最初に、事前トレーニング済みの重みを使用してベースモデルをインスタンス化します。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

次に、ベースモデルを凍結します。

base_model.trainable = False

その上に新しいモデルを作成します。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

新しいデータでモデルをトレーニングします。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

ファインチューニング

モデルが新しいデータで収束したら、ベースモデルのすべてまたは一部を解凍して、非常に低い学習率でエンドツーエンドでモデル全体を再トレーニングすることができます。

これは任意に行える最後のステップではありますが、段階的な改善を期待することができます。ただし、すぐに過適合になる可能性もあることに注意してください。

このステップは、凍結レイヤーのあるモデルが収束するまでトレーニングされた後にのみ行うことが重要です。ランダムに初期化されたトレーニング対象レイヤーと事前にトレーニングされた特徴量を持つトレーニング対象レイヤーを混ぜると、トレーニング中に、ランダムに初期化されたレイヤーによって非常に大きな勾配の更新が発生し、事前にトレーニングされた特徴量が破損してしまうことになります。

また、この段階では学習率が非常に低いことも重要です。1 回目のトレーニングよりもはるかに大きなモデルを、非常に小さなデータセットでトレーニングするからです。その結果、大量の重みの更新を適用すると、あっという間に過適合が起きてしまう危険性があります。ここでは、事前トレーニング済みの重みを段階的に適応し直します。

ベースモデル全体のファインチューニングを実装するには、以下のようにします。

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

compile() および trainable に関する重要な注意事項

モデルで compile() を呼び出すと、そのモデルの動作が「凍結」されます。これは、モデルがコンパイルされたときの trainable 属性の値は、compile が再び呼び出されるまで、そのモデルの寿命が続く限り保持されるということです。したがって、trainable の値を変更した場合には、その内容が考慮されるように必ずモデルでもう一度 compile() を呼び出してください。

BatchNormalization レイヤーに関する重要な注意事項

多くの画像モデルには BatchNormalization レイヤーが含まれています。このレイヤーは、あらゆる点において特殊なケースです。ここにいくつかの注意点を示します。

  • BatchNormalization には、トレーニング中に更新されるトレーニング対象外の重みが 2 つ含まれています。これらは入力の平均と分散を追跡する変数です。
  • bn_layer.trainable = False を設定すると、BatchNormalization レイヤーは推論モードで実行されるため、その平均と分散の統計は更新されません。重みのトレーナビリティと推論/トレーニングモードは 2 つの直交する概念であるため、これは一般的には他のレイヤーには当てはまりませんが、BatchNormalization レイヤーの場合は、この 2 つは関連しています。
  • ファインチューニングを行うために BatchNormalization レイヤーを含むモデルを解凍する場合、ベースモデルを呼び出す際に training = False を渡して BatchNormalization レイヤーを推論モードにしておく必要があります。推論モードになっていない場合、トレーニング対象外の重みに適用された更新によって、モデルが学習したものが突然破壊されてしまいます。

このガイドの最後にあるエンドツーエンドの例で、このパターンの動作を確認することができます。

カスタムトレーニングループで転移学習とファインチューニングをする

fit() の代わりに独自の低レベルのトレーニングループを使用している場合でも、ワークフローは基本的に同じです。勾配の更新を適用する際には、model.trainable_weights のリストのみを考慮するように注意する必要があります。

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

ファインチューニングの場合も同様です。

エンドツーエンドの例: 猫と犬データセットの画像分類モデルのファインチューニング

この概念を固めるために、具体的なエンドツーエンドの転移学習とファインチューニングの例を見てみましょう。ImageNet で事前トレーニングされた Xception モデルを読み込み、Kaggleの 「犬と猫」分類データセットで使用します。

データを取得する

まず、TFDS を使用して犬と猫のデータセットを取得してみましょう。独自のデータセットをお持ちの場合は、tf.keras.preprocessing.image_dataset_from_directory ユーティリティを使用して、クラス固有のフォルダにファイル作成されたディスク上の画像集合から類似のラベル付きデータセットオブジェクトを生成することもできます。

転移学習は、非常に小さなデータセットを扱う場合に最も有用です。データセットを小さく保つために、元のトレーニングデータ(画像 25,000 枚)の 40 %をトレーニングに、10 %を検証に、10 %をテストに使用します。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

ここにトレーニングデータセットの最初の 9 枚の画像があります。ご覧の通り、サイズはバラバラです。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

また、ラベル 1 が「犬」、ラベル 0 が「猫」であることもわかります。

データを標準化する

生の画像には様々なサイズがあります。さらに、各ピクセルは 0 ~ 255 の 3 つの整数値(RGB レベル値)で構成されています。これは、ニューラルネットワークへの供給には適しません。次の 2 つを行う必要があります。

  • 標準化して画像サイズを固定します。 150x150 を選択します。
  • ピクセル値を -1 〜 1 に正規化します。これはモデル自体の一部としてNormalizationレイヤーを使用して行います。

一般的に、すでに処理済みのデータを使用するモデルとは対照的に、入力に生のデータを使用するモデルを開発するのは良い実践です。その理由は、モデルが前処理されたデータを期待していると、モデルをエクスポートして他の場所(ウェブブラウザやモバイルアプリ)で使用する際には、まったく同じ前処理パイプラインを常に再実装する必要が生じるからです。これはすぐに非常に面倒なことになります。だからこそ、モデルを使用する前に可能な限りの前処理を行う必要があるのです。

ここでは、(ディープニューラルネットワークは連続したデータバッチしか処理できないので)データパイプラインで画像のリサイズを行い、モデルを作成する際にモデルの一部として入力値のスケーリングを行います。

画像を 150×150 にリサイズしてみましょう。

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

さらに、データをバッチ処理して、キャッシング&プリフェッチを使用し、読み込み速度を最適化してみましょう。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

ランダムデータ拡張を使用する

大規模な画像データセットを持っていない場合には、ランダムに水平反転や少し回転を加えるなど、ランダムでありながら現実的な変換をトレーニング画像に適用し、サンプルの多様性を人為的に導入するのが良い実践です。これによって、過適合を遅らせると同時にトレーニングデータの異なった側面にモデルを公開することができます。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

さまざまなランダム変換の後、最初のバッチの最初の画像がどのように見えるかを可視化してみましょう。

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2022-08-09 04:53:00.691811: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

モデルを構築する

では、先ほど説明した青写真に沿ってモデルを構築してみましょう。

注意点:

  • Rescaling レイヤーを追加して、入力値(最初の範囲は[0, 255])を[-1, 1]の範囲にスケーリングします。
  • 正則化のために、分類レイヤーの前に Dropout レイヤーを追加します。
  • ベースモデルを呼び出す際に training=False を渡して推論モードで動作するようにし、ファインチューニングを実行するためにベースモデルを解凍した後でも BatchNorm の統計が更新されないようにします。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83683744/83683744 [==============================] - 0s 0us/step
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

トップレイヤーをトレーニングする

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
148/291 [==============>...............] - ETA: 7s - loss: 0.2223 - binary_accuracy: 0.8986
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
262/291 [==========================>...] - ETA: 1s - loss: 0.1835 - binary_accuracy: 0.9195
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
275/291 [===========================>..] - ETA: 0s - loss: 0.1805 - binary_accuracy: 0.9212
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1776 - binary_accuracy: 0.9225
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 26s 65ms/step - loss: 0.1776 - binary_accuracy: 0.9225 - val_loss: 0.0864 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 16s 56ms/step - loss: 0.1229 - binary_accuracy: 0.9491 - val_loss: 0.0765 - val_binary_accuracy: 0.9682
Epoch 3/20
291/291 [==============================] - 16s 56ms/step - loss: 0.1085 - binary_accuracy: 0.9537 - val_loss: 0.0781 - val_binary_accuracy: 0.9699
Epoch 4/20
291/291 [==============================] - 16s 56ms/step - loss: 0.1084 - binary_accuracy: 0.9546 - val_loss: 0.0724 - val_binary_accuracy: 0.9733
Epoch 5/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0992 - binary_accuracy: 0.9582 - val_loss: 0.0726 - val_binary_accuracy: 0.9690
Epoch 6/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0976 - binary_accuracy: 0.9589 - val_loss: 0.0693 - val_binary_accuracy: 0.9729
Epoch 7/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0988 - binary_accuracy: 0.9591 - val_loss: 0.0805 - val_binary_accuracy: 0.9699
Epoch 8/20
291/291 [==============================] - 16s 57ms/step - loss: 0.0999 - binary_accuracy: 0.9588 - val_loss: 0.0715 - val_binary_accuracy: 0.9729
Epoch 9/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0931 - binary_accuracy: 0.9629 - val_loss: 0.0717 - val_binary_accuracy: 0.9725
Epoch 10/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0940 - binary_accuracy: 0.9624 - val_loss: 0.0777 - val_binary_accuracy: 0.9716
Epoch 11/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0981 - binary_accuracy: 0.9602 - val_loss: 0.0745 - val_binary_accuracy: 0.9729
Epoch 12/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0954 - binary_accuracy: 0.9607 - val_loss: 0.0698 - val_binary_accuracy: 0.9729
Epoch 13/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0976 - binary_accuracy: 0.9593 - val_loss: 0.0679 - val_binary_accuracy: 0.9755
Epoch 14/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0959 - binary_accuracy: 0.9634 - val_loss: 0.0699 - val_binary_accuracy: 0.9759
Epoch 15/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0909 - binary_accuracy: 0.9625 - val_loss: 0.0682 - val_binary_accuracy: 0.9759
Epoch 16/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0968 - binary_accuracy: 0.9622 - val_loss: 0.0678 - val_binary_accuracy: 0.9755
Epoch 17/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0911 - binary_accuracy: 0.9622 - val_loss: 0.0728 - val_binary_accuracy: 0.9733
Epoch 18/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0906 - binary_accuracy: 0.9648 - val_loss: 0.0729 - val_binary_accuracy: 0.9721
Epoch 19/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0937 - binary_accuracy: 0.9622 - val_loss: 0.0708 - val_binary_accuracy: 0.9738
Epoch 20/20
291/291 [==============================] - 16s 56ms/step - loss: 0.0945 - binary_accuracy: 0.9608 - val_loss: 0.0746 - val_binary_accuracy: 0.9716
<keras.callbacks.History at 0x7fc2b0512a90>

モデル全体のファインチューニングを行う

最後に、ベースモデルを解凍して、モデル全体のエンドツーエンドを低い学習率でトレーニングしてみましょう。

重要なのは、モデル構築時の呼び出しで training=False を渡しているため、ベースモデルがトレーニング対象になっても、推論モードで動作しているということです。つまり、内側のバッチ正則化レイヤーのバッチ統計は更新されません。更新してしまうと、それまでにモデルが学習してきた表現が破壊されてしまいます。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_5 (InputLayer)        [(None, 150, 150, 3)]     0         
                                                                 
 sequential_3 (Sequential)   (None, 150, 150, 3)       0         
                                                                 
 rescaling (Rescaling)       (None, 150, 150, 3)       0         
                                                                 
 xception (Functional)       (None, 5, 5, 2048)        20861480  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dropout (Dropout)           (None, 2048)              0         
                                                                 
 dense_7 (Dense)             (None, 1)                 2049      
                                                                 
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 52s 159ms/step - loss: 0.0764 - binary_accuracy: 0.9702 - val_loss: 0.0486 - val_binary_accuracy: 0.9785
Epoch 2/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0549 - binary_accuracy: 0.9792 - val_loss: 0.0573 - val_binary_accuracy: 0.9794
Epoch 3/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0463 - binary_accuracy: 0.9816 - val_loss: 0.0446 - val_binary_accuracy: 0.9819
Epoch 4/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0334 - binary_accuracy: 0.9870 - val_loss: 0.0454 - val_binary_accuracy: 0.9828
Epoch 5/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0294 - binary_accuracy: 0.9884 - val_loss: 0.0528 - val_binary_accuracy: 0.9824
Epoch 6/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0267 - binary_accuracy: 0.9903 - val_loss: 0.0452 - val_binary_accuracy: 0.9832
Epoch 7/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0194 - binary_accuracy: 0.9929 - val_loss: 0.0434 - val_binary_accuracy: 0.9841
Epoch 8/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0152 - binary_accuracy: 0.9952 - val_loss: 0.0576 - val_binary_accuracy: 0.9824
Epoch 9/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0139 - binary_accuracy: 0.9956 - val_loss: 0.0480 - val_binary_accuracy: 0.9819
Epoch 10/10
291/291 [==============================] - 45s 154ms/step - loss: 0.0144 - binary_accuracy: 0.9941 - val_loss: 0.0523 - val_binary_accuracy: 0.9828
<keras.callbacks.History at 0x7fc314002490>

10エポック後、ファインチューニングによって有益な改善が得られます。