転移学習と微調整

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロードする

セットアップ

import numpy as np
import tensorflow as tf
from tensorflow import keras

前書き

転移学習は、1つの問題で学習した機能を取得し、それらを新しい同様の問題で活用することで構成されます。たとえば、タヌキを識別することを学んだモデルの機能は、タヌキを識別することを目的としたモデルをキックスタートするのに役立つ場合があります。

転移学習は通常、データセットのデータが少なすぎて実物大のモデルを最初からトレーニングできないタスクに対して行われます。

ディープラーニングのコンテキストでの転移学習の最も一般的な化身は、次のワークフローです。

  1. 以前にトレーニングされたモデルからレイヤーを取得します。
  2. 将来のトレーニングラウンド中に含まれる情報が破壊されないように、それらをフリーズします。
  3. 凍結したレイヤーの上に、トレーニング可能な新しいレイヤーをいくつか追加します。彼らは、古い機能を新しいデータセットの予測に変換する方法を学びます。
  4. データセットの新しいレイヤーをトレーニングします。

最後のオプションのステップは微調整です。これは、上記で取得したモデル全体(またはその一部)のフリーズを解除し、非常に低い学習率で新しいデータでモデルを再トレーニングすることで構成されます。これは、事前にトレーニングされた機能を新しいデータに段階的に適応させることにより、意味のある改善を実現できる可能性があります。

最初に、Kerasのtrainable APIについて詳しく説明します。これは、ほとんどの転送学習と微調整のワークフローの基礎になっています。

次に、ImageNetデータセットで事前トレーニングされたモデルを取得し、Kaggleの「猫vs犬」分類データセットで再トレーニングすることにより、一般的なワークフローを示します。

これは、Pythonを使用したディープラーニングと2016年のブログ投稿「非常に少ないデータを使用して強力な画像分類モデルを構築する」から採用されています

凍結層: trainable属性を理解する

レイヤーとモデルには、次の3つの重み属性があります。

  • weightsは、レイヤーのすべての重み変数のリストです。
  • trainable_weightsは、トレーニング中の損失を最小限に抑えるために(勾配降下法を介して)更新されることを意図したもののリストです。
  • non_trainable_weightsは、トレーニングされることを意図していないもののリストです。通常、これらはフォワードパス中にモデルによって更新されます。

例: Denseレイヤーには2つのトレーニング可能なウェイト(カーネルとバイアス)があります

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般に、すべてのウェイトはトレーニング可能なウェイトです。トレーニング不可能な重みを持つ唯一の組み込みレイヤーは、 BatchNormalizationレイヤーです。トレーニング中の入力の平均と分散を追跡するために、トレーニング不可能な重みを使用します。独自のカスタムレイヤーでトレーニング不可能なウェイトを使用する方法については、新しいレイヤーを最初から作成するためのガイドを参照してください

例: BatchNormalizationレイヤーには、2つのトレーニング可能なウェイトと2つのトレーニング不可能なウェイトがあります

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

レイヤーとモデルには、 trainableブール属性もあります。その値は変更できます。 layer.trainableFalse設定すると、すべてのレイヤーの重みが訓練可能から訓練不可能に移動します。これはレイヤーの「フリーズ」と呼ばれます。フリーズされたレイヤーの状態は、トレーニング中に更新されません( fit()使用してトレーニングする場合、またはtrainable_weightsに依存して勾配更新を適用するカスタムループを使用してtrainable_weightsする場合)。

例: trainableFalse設定

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

訓練可能なウェイトが訓練不可能になると、その値は訓練中に更新されなくなります。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 664ms/step - loss: 0.1025

layer.trainable属性をlayer.__call__()の引数trainingと混同しないでください(レイヤーが推論モードとトレーニングモードのどちらでフォワードパスを実行するかを制御します)。詳細については、 KerasFAQを参照してください

trainable属性の再帰的設定

モデルまたはサブレイヤーを持つレイヤーでtrainable = Falseを設定すると、すべての子レイヤーもトレーニング不能になります。

例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的な転送学習ワークフロー

これにより、典型的な転移学習ワークフローをKerasで実装する方法がわかります。

  1. ベースモデルをインスタンス化し、事前にトレーニングされたウェイトをモデルにロードします。
  2. trainable = False設定して、ベースモデルのすべてのレイヤーをフリーズします。
  3. ベースモデルからの1つ(または複数)のレイヤーの出力の上に新しいモデルを作成します。
  4. 新しいデータセットで新しいモデルをトレーニングします。

別のより軽量なワークフローも次のようになる可能性があることに注意してください。

  1. ベースモデルをインスタンス化し、事前にトレーニングされたウェイトをモデルにロードします。
  2. 新しいデータセットを実行し、ベースモデルからの1つ(または複数)のレイヤーの出力を記録します。これは特徴抽出と呼ばれます
  3. その出力を、新しい、より小さなモデルの入力データとして使用します。

この2番目のワークフローの主な利点は、トレーニングのエポックごとに1回ではなく、データに対して1回だけ基本モデルを実行することです。そのため、はるかに高速で安価です。

ただし、この2番目のワークフローの問題は、トレーニング中に新しいモデルの入力データを動的に変更できないことです。これは、たとえばデータ拡張を行うときに必要になります。転移学習は通常、新しいデータセットのデータが少なすぎてフルスケールモデルを最初からトレーニングできない場合のタスクに使用されます。このようなシナリオでは、データ拡張が非常に重要です。したがって、以下では、最初のワークフローに焦点を当てます。

Kerasでの最初のワークフローは次のようになります。

まず、事前にトレーニングされた重みを使用してベースモデルをインスタンス化します。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

次に、ベースモデルをフリーズします。

base_model.trainable = False

上に新しいモデルを作成します。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

新しいデータでモデルをトレーニングします。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微調整

モデルが新しいデータに収束したら、基本モデルの全部または一部のフリーズを解除し、非常に低い学習率でモデル全体をエンドツーエンドで再トレーニングすることができます。

これはオプションの最後のステップであり、段階的な改善をもたらす可能性があります。また、急速な過剰適合につながる可能性もあります。そのことに注意してください。

このステップは、フリーズされたレイヤーを持つモデルが収束するようにトレーニングされた後でのみ実行することが重要です。ランダムに初期化されたトレーニング可能なレイヤーを、事前にトレーニングされた機能を保持するトレーニング可能なレイヤーと混合すると、ランダムに初期化されたレイヤーは、トレーニング中に非常に大きな勾配の更新を引き起こし、事前にトレーニングされた機能を破壊します。

また、この段階では非常に低い学習率を使用することも重要です。これは、通常は非常に小さいデータセットで、トレーニングの最初のラウンドよりもはるかに大きなモデルをトレーニングしているためです。その結果、大規模なウェイト更新を適用すると、すぐに過剰適合するリスクがあります。ここでは、事前にトレーニングされた重みを段階的に再適応させるだけです。

基本モデル全体の微調整を実装する方法は次のとおりです。

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

compile()trainableに関する重要な注意

モデルでcompile()を呼び出すことは、そのモデルの動作を「フリーズ」することを意味します。これは、モデルがコンパイルされたときのtrainable属性値は、 compileが再度呼び出されるまで、そのモデルの存続期間を通じて保持される必要があることを意味します。したがって、 trainable値を変更する場合は、変更を考慮に入れるために、モデルでcompile()再度呼び出すようにしてください。

BatchNormalizationレイヤーに関する重要な注意事項

多くの画像モデルには、 BatchNormalizationレイヤーが含まれてBatchNormalizationます。その層は、考えられるすべての点で特別な場合です。ここに覚えておくべきいくつかの事柄があります。

  • BatchNormalizationは、トレーニング中に更新される2つのトレーニング不可能な重みが含まれています。これらは、入力の平均と分散を追跡する変数です。
  • bn_layer.trainable = Falseに設定すると、 BatchNormalizationレイヤーは推論モードで実行され、平均と分散の統計は更新されません。 重みのトレーニング可能性と推論/トレーニングモードは2つの直交する概念であるため、これは一般に他のレイヤーには当てはまりません。ただし、 BatchNormalizationレイヤーの場合、この2つは結びついています。
  • 微調整を行うためにBatchNormalizationレイヤーを含むモデルのBatchNormalizationを解除する場合、ベースモデルを呼び出すときにtraining=Falseを渡して、 BatchNormalizationレイヤーを推論モードに保つ必要があります。そうしないと、訓練不可能な重みに適用された更新によって、モデルが学習した内容が突然破壊されます。

このガイドの最後にあるエンドツーエンドの例で、このパターンの動作を確認できます。

カスタムトレーニングループを使用した学習と微調整の転送

fit()代わりに、独自の低レベルのトレーニングループを使用している場合、ワークフローは基本的に同じままです。グラデーションの更新を適用するときは、リストmodel.trainable_weightsのみを考慮するように注意する必要があります。

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

同様に微調整のために。

エンドツーエンドの例:猫と犬のデータセットでの画像分類モデルの微調整

これらの概念を固めるために、具体的なエンドツーエンドの転送学習と微調整の例を見ていきましょう。 ImageNetで事前トレーニングされたXceptionモデルをロードし、Kaggleの「猫と犬」の分類データセットで使用します。

データの取得

まず、TFDSを使用して猫と犬のデータセットを取得しましょう。独自のデータセットがある場合は、ユーティリティtf.keras.preprocessing.image_dataset_from_directoryを使用して、クラス固有のフォルダーにtf.keras.preprocessing.image_dataset_from_directoryされたディスク上の一連の画像から同様のラベル付きデータセットオブジェクトを生成することをお勧めします。

転移学習は、非常に小さなデータセットを扱うときに最も役立ちます。データセットを小さく保つために、元のトレーニングデータ(25,000枚の画像)の40%をトレーニングに、10%を検証に、10%をテストに使用します。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

これらはトレーニングデータセットの最初の9つの画像です。ご覧のとおり、すべて異なるサイズです。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

また、ラベル1が「犬」でラベル0が「猫」であることがわかります。

データの標準化

RAW画像にはさまざまなサイズがあります。さらに、各ピクセルは0〜255の3つの整数値(RGBレベル値)で構成されます。これは、ニューラルネットワークへの給電には最適ではありません。私たちは2つのことをする必要があります:

  • 固定画像サイズに標準化します。 150x150を選択します。
  • -1から1までのピクセル値を正規化します。これは、モデル自体の一部としてNormalizationレイヤーを使用して行います。

一般に、すでに前処理されたデータを取得するモデルではなく、生データを入力として取得するモデルを開発することをお勧めします。その理由は、モデルが前処理されたデータを予期している場合、モデルをエクスポートして他の場所(Webブラウザーやモバイルアプリ)で使用するときはいつでも、まったく同じ前処理パイプラインを再実装する必要があるためです。これは非常にすぐにトリッキーになります。したがって、モデルに到達する前に、可能な限り最小限の前処理を実行する必要があります。

ここでは、データパイプラインで画像のサイズ変更を行い(ディープニューラルネットワークはデータの連続したバッチしか処理できないため)、モデルを作成するときに、モデルの一部として入力値のスケーリングを行います。

画像のサイズを150x150に変更しましょう。

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

さらに、データをバッチ処理し、キャッシュとプリフェッチを使用して読み込み速度を最適化しましょう。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

ランダムデータ拡張の使用

大きな画像データセットがない場合は、ランダムな水平反転や小さなランダムな回転など、ランダムでありながら現実的な変換をトレーニング画像に適用して、サンプルの多様性を人為的に導入することをお勧めします。これは、過剰適合を遅くしながら、モデルをトレーニングデータのさまざまな側面に公開するのに役立ちます。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

さまざまなランダム変換後の最初のバッチの最初の画像がどのように見えるかを視覚化してみましょう。

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

png

モデルを構築する

それでは、前に説明した青写真に従ったモデルを作成しましょう。

ご了承ください:

  • Normalizationレイヤーを追加して、入力値(最初は[0, 255]範囲)を[-1, 1] [0, 255]範囲にスケーリングします。
  • 正則化のために、分類レイヤーの前にDropoutレイヤーを追加します。
  • ベースモデルを呼び出すときにtraining=Falseを渡すようにして、推論モードで実行するようにします。これにより、微調整のためにベースモデルをフリーズ解除した後でも、batchnorm統計が更新されません。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 1s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

最上層をトレーニングする

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 20s 49ms/step - loss: 0.2226 - binary_accuracy: 0.8972 - val_loss: 0.0805 - val_binary_accuracy: 0.9703
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1246 - binary_accuracy: 0.9464 - val_loss: 0.0757 - val_binary_accuracy: 0.9712
Epoch 3/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1153 - binary_accuracy: 0.9480 - val_loss: 0.0724 - val_binary_accuracy: 0.9733
Epoch 4/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1055 - binary_accuracy: 0.9575 - val_loss: 0.0753 - val_binary_accuracy: 0.9721
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1026 - binary_accuracy: 0.9589 - val_loss: 0.0750 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1022 - binary_accuracy: 0.9587 - val_loss: 0.0723 - val_binary_accuracy: 0.9716
Epoch 7/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1009 - binary_accuracy: 0.9570 - val_loss: 0.0731 - val_binary_accuracy: 0.9708
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0947 - binary_accuracy: 0.9576 - val_loss: 0.0726 - val_binary_accuracy: 0.9716
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0872 - binary_accuracy: 0.9624 - val_loss: 0.0720 - val_binary_accuracy: 0.9712
Epoch 10/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0892 - binary_accuracy: 0.9622 - val_loss: 0.0711 - val_binary_accuracy: 0.9716
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0987 - binary_accuracy: 0.9608 - val_loss: 0.0752 - val_binary_accuracy: 0.9712
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0962 - binary_accuracy: 0.9595 - val_loss: 0.0715 - val_binary_accuracy: 0.9738
Epoch 13/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0972 - binary_accuracy: 0.9606 - val_loss: 0.0700 - val_binary_accuracy: 0.9725
Epoch 14/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9568 - val_loss: 0.0779 - val_binary_accuracy: 0.9690
Epoch 15/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0929 - binary_accuracy: 0.9614 - val_loss: 0.0700 - val_binary_accuracy: 0.9729
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0937 - binary_accuracy: 0.9610 - val_loss: 0.0698 - val_binary_accuracy: 0.9742
Epoch 17/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0945 - binary_accuracy: 0.9613 - val_loss: 0.0671 - val_binary_accuracy: 0.9759
Epoch 18/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0868 - binary_accuracy: 0.9612 - val_loss: 0.0692 - val_binary_accuracy: 0.9738
Epoch 19/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0871 - binary_accuracy: 0.9647 - val_loss: 0.0691 - val_binary_accuracy: 0.9746
Epoch 20/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0922 - binary_accuracy: 0.9603 - val_loss: 0.0721 - val_binary_accuracy: 0.9738
<tensorflow.python.keras.callbacks.History at 0x7fb73f231860>

モデル全体の微調整を行います

最後に、基本モデルのフリーズを解除し、低い学習率でモデル全体をエンドツーエンドでトレーニングしましょう。

重要なのは、基本モデルはトレーニング可能になりますが、モデルの作成時に呼び出すときにtraining=Falseを渡したため、推論モードで実行されていることです。これは、内部のバッチ正規化レイヤーがバッチ統計を更新しないことを意味します。もしそうなら、彼らはこれまでのモデルによって学んだ表現に大混乱をもたらすでしょう。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 133ms/step - loss: 0.0814 - binary_accuracy: 0.9677 - val_loss: 0.0527 - val_binary_accuracy: 0.9776
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0544 - binary_accuracy: 0.9796 - val_loss: 0.0537 - val_binary_accuracy: 0.9776
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0481 - binary_accuracy: 0.9822 - val_loss: 0.0471 - val_binary_accuracy: 0.9789
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0324 - binary_accuracy: 0.9871 - val_loss: 0.0551 - val_binary_accuracy: 0.9807
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0298 - binary_accuracy: 0.9899 - val_loss: 0.0447 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0262 - binary_accuracy: 0.9901 - val_loss: 0.0469 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0242 - binary_accuracy: 0.9918 - val_loss: 0.0539 - val_binary_accuracy: 0.9798
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0153 - binary_accuracy: 0.9935 - val_loss: 0.0644 - val_binary_accuracy: 0.9794
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0175 - binary_accuracy: 0.9934 - val_loss: 0.0496 - val_binary_accuracy: 0.9819
Epoch 10/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0171 - binary_accuracy: 0.9936 - val_loss: 0.0496 - val_binary_accuracy: 0.9828
<tensorflow.python.keras.callbacks.History at 0x7fb74f74f940>

10エポックの後、微調整により、ここで素晴らしい改善が得られます。