このページは Cloud Translation API によって翻訳されました。
Switch to English

転移学習と微調整

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する ノートブックをダウンロード

セットアップ

import numpy as np
import tensorflow as tf
from tensorflow import keras

前書き

転移学習は、1つの問題について学習した機能を取り入れ、それらを新しい同様の問題に活用することで構成されます。たとえば、アライグマを特定することを学習したモデルの特徴は、タヌキを特定することを目的としたモデルをキックスタートするのに役立ちます。

転移学習は通常、データセットのデータが少なすぎて、フルスケールモデルを最初からトレーニングできないタスクに対して実行されます。

ディープラーニングのコンテキストでの転移学習の最も一般的な具体化は、次のワーフクロウです。

  1. 以前にトレーニングしたモデルからレイヤーを取得します。
  2. それらを凍結して、将来のトレーニングラウンド中にそれらに含まれる情報の破壊を回避します。
  3. フリーズしたレイヤーの上に新しいトレーニング可能なレイヤーを追加します。彼らは古い特徴を新しいデータセットの予測に変えることを学びます。
  4. データセットの新しいレイヤーをトレーニングします。

最後のオプションのステップは、上記で取得したモデル全体(またはその一部)のフリーズを解除し、非常に低い学習率で新しいデータで再トレーニングすることからなる微調整です。これにより、事前トレーニング済みの機能を新しいデータに段階的に適応させることで、意味のある改善を実現できる可能性があります。

最初に、ほとんどの転移学習と微調整のワークフローの基礎となるKerasのtrainable APIについて詳しく説明します。

次に、ImageNetデータセットで事前トレーニングされたモデルを取得し、Kaggleの「猫対犬」分類データセットで再トレーニングすることにより、典型的なワークフローを示します。

これは、Pythonを使用したディープラーニングと 2016年のブログの投稿「非常に少ないデータを使用して強力な画像分類モデルを構築する」から採用されました

レイヤーのフリーズ: trainable属性を理解する

レイヤーとモデルには3つのウェイト属性があります。

  • weightsは、レイヤーのすべての重み変数のリストです。
  • trainable_weightsは、トレーニング中の損失を最小限に抑えるために(勾配降下法を介して)更新することを意図したもののリストです。
  • non_trainable_weightsは、トレーニングすることを意図していないもののリストです。通常、これらはフォワードパス中にモデルによって更新されます。

例: Denseレイヤーには2つのトレーニング可能な重み(カーネルとバイアス)があります

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般に、すべてのウェイトはトレーニング可能なウェイトです。トレーニング不可能な重みを持つ唯一の組み込みレイヤーは、 BatchNormalizationレイヤーです。トレーニングできない重みを使用して、トレーニング中の入力の平均と分散を追跡します。独自のカスタムレイヤーでトレーニング不可能なウェイトを使用する方法については、 新しいレイヤーを最初から作成するためのガイドをご覧ください。

例: BatchNormalizationレイヤーには2つのトレーニング可能な重みと2つのトレーニング不可能な重みがあります

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 4
trainable_weights: 2
non_trainable_weights: 2

レイヤーとモデルには、 trainableブール属性もtrainableます。その値は変更できます。 layer.trainableFalse設定すると、すべてのレイヤーのウェイトがトレーニング可能からトレーニングlayer.trainableに移動します。これはレイヤーの「フリーズ」と呼ばれます。フリーズされたレイヤーの状態は、トレーニング中に更新されません( fit()使用してトレーニングするとき、またはtrainable_weightsに依存してグラデーションの更新を適用するカスタムループを使用してトレーニングするとき)。

例: trainableFalse設定する

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 0
non_trainable_weights: 2

トレーニング可能なウェイトがトレーニング不可になると、その値はトレーニング中に更新されなくなります。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)

1/1 [==============================] - 0s 1ms/step - loss: 0.0946

layer.trainable属性を、 layer.__call__() trainingの引数と混同しないでください(レイヤーが推論モードまたはトレーニングモードでフォワードパスを実行するかどうかを制御します)。詳細については、 Keras FAQを参照してください

trainable属性の再帰的な設定

モデルまたはサブレイヤーを持つレイヤーでtrainable = Falseを設定すると、すべての子レイヤーも同様にトレーニング不可能になります。

例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的な転移学習ワークフロー

これにより、Kerasで一般的な転移学習ワークフローを実装する方法がわかります。

  1. ベースモデルをインスタンス化し、事前トレーニング済みのウェイトをそれにロードします。
  2. trainable = False設定して、ベースモデルのすべてのレイヤーをフリーズします。
  3. ベースモデルからの1つ(または複数)のレイヤーの出力の上に新しいモデルを作成します。
  4. 新しいデータセットで新しいモデルをトレーニングします。

代替の、より軽量なワークフローは次のようになることもあります:

  1. ベースモデルをインスタンス化し、事前トレーニング済みのウェイトをそれにロードします。
  2. 新しいデータセットを実行し、ベースモデルからの1つ(または複数)のレイヤーの出力を記録します。これを特徴抽出と呼びます
  3. その出力を新しい小さなモデルの入力データとして使用します。

2番目のワークフローの主な利点は、トレーニングのエポックごとに1回ではなく、データに対して1回だけベースモデルを実行することです。したがって、はるかに高速で安価です。

ただし、この2番目のワークフローの問題は、トレーニング中に新しいモデルの入力データを動的に変更できないことです。これは、たとえばデータ拡張を行うときに必要になります。転移学習は、通常、新しいデータセットのデータが少なすぎてフルスケールモデルをゼロからトレーニングできない場合に使用されます。このようなシナリオでは、データの増強が非常に重要です。したがって、以下では、最初のワークフローに焦点を当てます。

Kerasでの最初のワークフローは次のようになります。

最初に、事前トレーニング済みの重みでベースモデルをインスタンス化します。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

次に、ベースモデルをフリーズします。

base_model.trainable = False

上に新しいモデルを作成します。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

新しいデータでモデルをトレーニングします。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微調整

モデルが新しいデータに収束したら、ベースモデルのすべてまたは一部のフリーズを解除して、モデル全体を非常に低い学習率でエンドツーエンドで再トレーニングできます。

これはオプションの最後のステップであり、段階的に改善できる可能性があります。また、急速な過剰適合につながる可能性もあります。

このステップは、フリーズされたレイヤーのあるモデルが収束するようにトレーニングされたにのみ実行することが重要です。ランダムに初期化されたトレーニング可能なレイヤーと事前トレーニングされた機能を保持するトレーニング可能なレイヤーを混在させると、ランダムに初期化されたレイヤーにより、トレーニング中に非常に大きなグラデーションが更新され、事前トレーニングされた機能が破壊されます。

また、通常は非常に小さいデータセットで、最初のトレーニングよりもはるかに大きなモデルをトレーニングするため、この段階では非常に低い学習率を使用することも重要です。その結果、大きな重みの更新を適用すると、非常に迅速に過剰適合のリスクがあります。ここでは、事前トレーニングされた重みを増分的に再適応させたいだけです。

これは、ベースモデル全体の微調整を実装する方法です。

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

compile()trainableに関する重要な注意

モデルでcompile()を呼び出すことは、そのモデルの動作を「凍結」することを意味します。これは、モデルがコンパイルされたときのtrainable属性値が、 compileが再度呼び出されるまで、そのモデルの存続期間を通じて保持される必要があることを意味します。したがって、 trainable値を変更する場合は、変更を考慮に入れるために、モデルでcompile()再度呼び出してください。

BatchNormalizationレイヤーに関する重要な注意事項

多くの画像モデルにはBatchNormalizationレイヤーが含まれてBatchNormalizationます。その層は、考えられる限りのあらゆる点で特別なケースです。覚えておくべきことがいくつかあります。

  • BatchNormalizationは、トレーニング中に更新される2つのトレーニング不可能な重みが含まれています。これらは、入力の平均と分散を追跡する変数です。
  • bn_layer.trainable = Falseを設定すると、 BatchNormalizationレイヤーは推論モードで実行され、その平均と分散の統計は更新されません。これは、 ウェイトトレーナビリティと推論/トレーニングモードが2つの直交する概念であるため、一般的に他のレイヤーには当てはまりません。ただし、 BatchNormalizationレイヤーの場合、この2つは関連しています。
  • 微調整を行うためにBatchNormalizationレイヤーを含むモデルをBatchNormalization解除するときは、ベースモデルを呼び出すときにtraining=Falseを渡して、 BatchNormalizationレイヤーを推論モードに保つ必要があります。そうしないと、トレーニング不可能な重みに適用された更新により、モデルが学習した内容が突然破壊されます。

このガイドの最後にあるエンドツーエンドの例で、このパターンの動作を確認できます。

カスタムトレーニングループによる転移学習と微調整

fit()代わりに、独自の低レベルのトレーニングループを使用している場合、ワークフローは基本的に同じままです。グラデーションの更新を適用するときは、リストmodel.trainable_weightsのみを考慮するように注意する必要があります。

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

同様に微調整用。

エンドツーエンドの例:猫と犬の画像分類モデルの微調整

データセット

これらの概念を固めるために、エンドツーエンドの具体的な転移学習と微調整の例を紹介します。 ImageNetで事前トレーニングされたXceptionモデルをロードし、Kaggleの「猫対犬」分類データセットで使用します。

データを取得する

最初に、TFDSを使用して猫対犬のデータセットをフェッチします。独自のデータセットがある場合は、 tf.keras.preprocessing.image_dataset_from_directoryユーティリティを使用して、クラス固有のフォルダーにtf.keras.preprocessing.image_dataset_from_directoryされたディスク上の一連の画像から、同様のラベル付きデータセットオブジェクトを生成することをお勧めします。

Tansfer学習は、非常に小さなデータを扱う場合に最も役立ちます。データセットを小さく保つために、元のトレーニングデータの40%(25,000画像)をトレーニングに、10%を検証に、10%をテストに使用します。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))

Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteDA6BAB/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

これらは、トレーニングデータセットの最初の9つの画像です。ご覧のとおり、これらの画像はすべてサイズが異なります。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

また、ラベル1は「犬」であり、ラベル0は「猫」であることがわかります。

データの標準化

生の画像にはさまざまなサイズがあります。さらに、各ピクセルは0〜255(RGBレベル値)の3つの整数値で構成されます。これは、ニューラルネットワークのフィードには適していません。 2つのことを行う必要があります。

  • 固定画像サイズに標準化します。 150x150を選択します。
  • -1から1の間のピクセル値を正規化します。これは、モデル自体の一部としてNormalizationレイヤーを使用して行います。

一般に、前処理済みのデータを使用するモデルとは対照的に、生データを入力として使用するモデルを開発することをお勧めします。その理由は、モデルが前処理されたデータを想定している場合、モデルをエクスポートして他の場所(Webブラウザー、モバイルアプリ)で使用するときは常に、まったく同じ前処理パイプラインを再実装する必要があるためです。これは非常にトリッキーになります。したがって、モデルにアクセスする前に、できる限り前処理を行う必要があります。

ここでは、データパイプラインで画像のサイズ変更を行い(ディープニューラルネットワークは連続したデータのバッチしか処理できないため)、モデルを作成するときに、モデルの一部として入力値のスケーリングを行います。

画像のサイズを150x150に変更しましょう:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

さらに、データをバッチ処理し、キャッシュとプリフェッチを使用して読み込み速度を最適化しましょう。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

ランダムデータ拡張の使用

大きな画像データセットがない場合は、ランダムな水平反転や小さなランダムな回転など、ランダムで現実的な変換をトレーニング画像に適用して、サンプルの多様性を人為的に導入することをお勧めします。これは、モデルをトレーニングデータのさまざまな側面に公開するのに役立ちますが、過剰適合を遅くします。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

さまざまなランダム変換の後、最初のバッチの最初の画像がどのように見えるかを視覚化しましょう。

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[i]))
        plt.axis("off")

png

モデルを構築する

次に、前に説明した青写真に従うモデルを作成します。

ご了承ください:

  • Normalizationレイヤーを追加して、入力値(最初は[0, 255]範囲)を[-1, 1] [0, 255]範囲にスケーリングします。
  • 正規化のために、分類レイヤーの前にDropoutレイヤーを追加します。
  • 基本モデルを呼び出すときに必ずtraining=Falseを渡して、推論モードで実行されるようにします。これにより、微調整のために基本モデルをフリーズ解除した後でも、batchnorm統計が更新されません。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

最上層をトレーニングする

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

Epoch 1/20
291/291 [==============================] - 9s 31ms/step - loss: 0.1607 - binary_accuracy: 0.9313 - val_loss: 0.0872 - val_binary_accuracy: 0.9703
Epoch 2/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1181 - binary_accuracy: 0.9501 - val_loss: 0.0869 - val_binary_accuracy: 0.9690
Epoch 3/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1122 - binary_accuracy: 0.9525 - val_loss: 0.0809 - val_binary_accuracy: 0.9703
Epoch 4/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1065 - binary_accuracy: 0.9563 - val_loss: 0.0743 - val_binary_accuracy: 0.9733
Epoch 5/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1075 - binary_accuracy: 0.9560 - val_loss: 0.0746 - val_binary_accuracy: 0.9746
Epoch 6/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1031 - binary_accuracy: 0.9584 - val_loss: 0.0768 - val_binary_accuracy: 0.9708
Epoch 7/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1014 - binary_accuracy: 0.9598 - val_loss: 0.0787 - val_binary_accuracy: 0.9708
Epoch 8/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1030 - binary_accuracy: 0.9567 - val_loss: 0.0732 - val_binary_accuracy: 0.9729
Epoch 9/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1003 - binary_accuracy: 0.9594 - val_loss: 0.0762 - val_binary_accuracy: 0.9712
Epoch 10/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0985 - binary_accuracy: 0.9606 - val_loss: 0.0744 - val_binary_accuracy: 0.9725
Epoch 11/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1014 - binary_accuracy: 0.9573 - val_loss: 0.0777 - val_binary_accuracy: 0.9703
Epoch 12/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0990 - binary_accuracy: 0.9584 - val_loss: 0.0788 - val_binary_accuracy: 0.9712
Epoch 13/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0955 - binary_accuracy: 0.9612 - val_loss: 0.0978 - val_binary_accuracy: 0.9686
Epoch 14/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0946 - binary_accuracy: 0.9592 - val_loss: 0.0795 - val_binary_accuracy: 0.9699
Epoch 15/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0947 - binary_accuracy: 0.9622 - val_loss: 0.0846 - val_binary_accuracy: 0.9686
Epoch 16/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0906 - binary_accuracy: 0.9609 - val_loss: 0.0902 - val_binary_accuracy: 0.9678
Epoch 17/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0979 - binary_accuracy: 0.9610 - val_loss: 0.0741 - val_binary_accuracy: 0.9738
Epoch 18/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0952 - binary_accuracy: 0.9629 - val_loss: 0.0764 - val_binary_accuracy: 0.9721
Epoch 19/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0946 - binary_accuracy: 0.9636 - val_loss: 0.0784 - val_binary_accuracy: 0.9699
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0914 - binary_accuracy: 0.9649 - val_loss: 0.0875 - val_binary_accuracy: 0.9695

<tensorflow.python.keras.callbacks.History at 0x7fd0c1a6e3c8>

モデル全体を微調整します

最後に、基本モデルをフリーズ解除し、モデル全体を低い学習率でエンドツーエンドでトレーニングします。

重要なことに、ベースモデルはトレーニング可能になりますが、モデルの構築時に呼び出したときにtraining=Falseを渡したため、推論モードで実行されています。これは、内部のバッチ正規化レイヤーがバッチ統計を更新しないことを意味します。そうした場合、彼らはこれまでにモデルが学習した表現に大混乱をもたらすでしょう。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
  2/291 [..............................] - ETA: 17s - loss: 0.0891 - binary_accuracy: 0.9688WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0347s vs `on_train_batch_end` time: 0.0879s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0347s vs `on_train_batch_end` time: 0.0879s). Check your callbacks.

291/291 [==============================] - 38s 130ms/step - loss: 0.0758 - binary_accuracy: 0.9685 - val_loss: 0.0598 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0582 - binary_accuracy: 0.9772 - val_loss: 0.0531 - val_binary_accuracy: 0.9789
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0460 - binary_accuracy: 0.9818 - val_loss: 0.0511 - val_binary_accuracy: 0.9794
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0358 - binary_accuracy: 0.9856 - val_loss: 0.0461 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 129ms/step - loss: 0.0299 - binary_accuracy: 0.9893 - val_loss: 0.0555 - val_binary_accuracy: 0.9819
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0270 - binary_accuracy: 0.9894 - val_loss: 0.0511 - val_binary_accuracy: 0.9819
Epoch 7/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0206 - binary_accuracy: 0.9918 - val_loss: 0.0487 - val_binary_accuracy: 0.9832
Epoch 8/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0143 - binary_accuracy: 0.9946 - val_loss: 0.0553 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0174 - binary_accuracy: 0.9936 - val_loss: 0.0542 - val_binary_accuracy: 0.9811
Epoch 10/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0135 - binary_accuracy: 0.9951 - val_loss: 0.0521 - val_binary_accuracy: 0.9819

<tensorflow.python.keras.callbacks.History at 0x7fd0c1a6e550>

10エポックの後、微調整によってここで素晴らしい改善が得られます。