転移学習と微調整

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示 ノートブックをダウンロード

設定

import numpy as np
import tensorflow as tf
from tensorflow import keras

序章

転移学習は一つの問題で学習機能を取り出し、そして新しい、同様の問題にそれらを活用して構成されています。たとえば、アライグマを識別することを学んだモデルの機能は、タヌキを識別することを目的としたモデルをキックスタートするのに役立つ場合があります。

転移学習は通常、データセットのデータが少なすぎて実物大のモデルを最初からトレーニングできないタスクに対して行われます。

ディープラーニングのコンテキストでの転移学習の最も一般的な化身は、次のワークフローです。

  1. 以前にトレーニングされたモデルからレイヤーを取得します。
  2. 将来のトレーニングラウンド中に含まれる情報が破壊されないように、それらをフリーズします。
  3. 凍結したレイヤーの上に、トレーニング可能な新しいレイヤーをいくつか追加します。彼らは、古い機能を新しいデータセットの予測に変換する方法を学びます。
  4. データセットの新しいレイヤーをトレーニングします。

最後の、オプションのステップは、全体のあなたは上記で得られたモデル(またはその一部)を解凍し、非常に低い学習率を持つ新しいデータでそれを再訓練で構成されて微調整、です。これにより、事前にトレーニングされた機能を新しいデータに段階的に適応させることで、意味のある改善を実現できる可能性があります。

まず、我々はKeras以上行くtrainableほとんどの転移学習&微調整ワークフローの基礎となる詳細にAPIを、。

次に、ImageNetデータセットで事前トレーニングされたモデルを取得し、Kaggleの「猫vs犬」分類データセットで再トレーニングすることにより、一般的なワークフローを示します。

これは、から構成されているPythonのディープラーニングや2016年のブログ記事「非常に少ないデータを使用して、強力な画像分類モデルを構築します」

凍結層:理解するtrainable属性を

レイヤーとモデルには、次の3つの重み属性があります。

  • weights層のすべての重み変数のリストです。
  • trainable_weightsトレーニング中の損失を最小限にするために(勾配降下を経由して)更新されることを意図されているもののリストです。
  • non_trainable_weights訓練を受けたことを意味するものではないもののリストです。通常、これらはフォワードパス中にモデルによって更新されます。

例: Dense層2点の訓練可能な重みを有する(カーネル・バイアス)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般に、すべてのウェイトはトレーニング可能なウェイトです。のみ内蔵の非トレーニング可能な重みを持っている層であるBatchNormalization層。トレーニング不可能な重みを使用して、トレーニング中の入力の平均と分散を追跡します。独自のカスタム層に非トレーニング可能な重みを使用する方法については、参照ゼロから新しいレイヤーを書くにガイドを

例: BatchNormalization層は、2つの訓練可能な重み及び2、非訓練可能な重みを有します

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

レイヤー&モデルもブール属性機能trainable 。その値は変更できます。設定layer.trainableFalse非トレーニング可能に仕込むから、すべての層の重み動きます。これは、「凍結」層と呼ばれる:凍結された層の状態は、トレーニング中に更新されることはありません(どちらかのときとトレーニングfit()またはときに依存している任意のカスタムループを備えた訓練trainable_weights勾配の更新プログラムを適用します)。

例:設定trainableするFalse

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

トレーニング可能なウェイトがトレーニング不可能になると、トレーニング中にその値が更新されなくなります。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

混同しないでくださいlayer.trainable引数を持つ属性traininglayer.__call__()層が推論モードやトレーニングモードでの往路を実行する必要があるかどうかを制御し)。詳細については、 Kerasよくある質問を

再帰的な設定trainable属性

あなたが設定した場合trainable = Falseモデル上またはサブ層を持つ任意の層の上に、すべての子供層は、非トレーニング可能になります。

例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的な転送学習ワークフロー

これにより、典型的な転移学習ワークフローをKerasで実装する方法がわかります。

  1. ベースモデルをインスタンス化し、事前にトレーニングされたウェイトをモデルにロードします。
  2. 設定することで、ベースモデル内のすべてのレイヤをフリーズtrainable = False
  3. ベースモデルからの1つ(または複数)のレイヤーの出力の上に新しいモデルを作成します。
  4. 新しいデータセットで新しいモデルをトレーニングします。

別のより軽量なワークフローも次のようになる可能性があることに注意してください。

  1. ベースモデルをインスタンス化し、事前にトレーニングされたウェイトをモデルにロードします。
  2. 新しいデータセットを実行し、ベースモデルからの1つ(または複数)のレイヤーの出力を記録します。これは、特徴抽出と呼ばれています。
  3. その出力を、新しい、より小さなモデルの入力データとして使用します。

この2番目のワークフローの主な利点は、トレーニングのエポックごとに1回ではなく、データに対して1回だけ基本モデルを実行することです。そのため、はるかに高速で安価です。

ただし、この2番目のワークフローの問題は、トレーニング中に新しいモデルの入力データを動的に変更できないことです。これは、たとえばデータ拡張を行うときに必要になります。転送学習は通常、新しいデータセットのデータが少なすぎてフルスケールモデルを最初からトレーニングできない場合のタスクに使用されます。このようなシナリオでは、データ拡張が非常に重要です。したがって、以下では、最初のワークフローに焦点を当てます。

Kerasでの最初のワークフローは次のようになります。

まず、事前にトレーニングされた重みを使用して基本モデルをインスタンス化します。

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

次に、ベースモデルをフリーズします。

base_model.trainable = False

上に新しいモデルを作成します。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

新しいデータでモデルをトレーニングします。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微調整

モデルが新しいデータに収束したら、基本モデルの全部または一部のフリーズを解除し、非常に低い学習率でモデル全体をエンドツーエンドで再トレーニングすることができます。

これはオプションの最後のステップであり、段階的な改善をもたらす可能性があります。また、急速な過剰適合につながる可能性もあります。この点に注意してください。

凍結された層を持つモデルが収束に訓練された後にのみ、この手順を実行することが重要です。ランダムに初期化されたトレーニング可能なレイヤーを、事前にトレーニングされた機能を保持するトレーニング可能なレイヤーと混合すると、ランダムに初期化されたレイヤーは、トレーニング中に非常に大きな勾配の更新を引き起こし、事前にトレーニングされた機能を破壊します。

また、この段階では非常に低い学習率を使用することも重要です。これは、通常は非常に小さいデータセットで、トレーニングの最初のラウンドよりもはるかに大きなモデルをトレーニングしているためです。その結果、大規模なウェイト更新を適用すると、すぐに過剰適合するリスクがあります。ここでは、事前にトレーニングされた重みを段階的に再適応させるだけです。

基本モデル全体の微調整を実装する方法は次のとおりです。

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

重要な注意compile()およびtrainable

呼び出すとcompile()モデルにすることは、そのモデルの振る舞いを「凍結」するためのものです。これは、ことを意味しtrainableになるまでのモデルがコンパイルされた時点での属性値は、そのモデルの生涯を通じて保存されるべきcompile再び呼び出されます。あなたが任意の変更された場合そのため、 trainable値を、コールに確認してcompile()変更内容を考慮されるためにお使いのモデルに再び。

重要事項BatchNormalization

多くのイメージモデルには含まれていBatchNormalization層を。その層は、考えられるすべての点で特別な場合です。覚えておくべきことがいくつかあります。

  • BatchNormalizationトレーニング中に更新されます2非トレーニング可能な重みが含まれています。これらは、入力の平均と分散を追跡する変数です。
  • あなたが設定されている場合bn_layer.trainable = FalseBatchNormalization層は、推論モードで実行され、その平均値&分散統計を更新しません。 、これは、一般的には他の層には当てはまらない重量訓練可能性および推論/トレーニングモードは、2つの直交する概念です。しかし、二人はの場合に結びついているBatchNormalization層。
  • あなたが含まれているモデルの凍結を解除するとBatchNormalization微調整を行うためにレイヤーを、あなたは維持する必要がありBatchNormalization渡すことで、推論モードでレイヤーをtraining=Falseベースモデルを呼び出すとき。そうしないと、トレーニング不可能な重みに適用された更新によって、モデルが学習した内容が突然破壊されます。

このガイドの最後にあるエンドツーエンドの例で、このパターンの動作を確認できます。

カスタムトレーニングループを使用した学習と微調整の転送

代わりの場合はfit()ワークフローの滞在は、本質的に同じ、独自の低レベルのトレーニングのループを使用しています。あなたのアカウントに、リストだけ取るように注意する必要がありますmodel.trainable_weights勾配のアップデートを適用する場合:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

同様に微調整のために。

エンドツーエンドの例:猫と犬のデータセットでの画像分類モデルの微調整

これらの概念を固めるために、具体的なエンドツーエンドの転送学習と微調整の例を見ていきましょう。 ImageNetで事前トレーニングされたXceptionモデルをロードし、Kaggleの「猫と犬」の分類データセットで使用します。

データの取得

まず、TFDSを使用して猫と犬のデータセットを取得しましょう。あなたがあなた自身のデータセットを持っている場合は、おそらくユーティリティを使用したいと思うでしょうtf.keras.preprocessing.image_dataset_from_directoryクラス固有のフォルダに提出されたディスク上のイメージのセットから同様のラベル付きデータセットオブジェクトを生成します。

転移学習は、非常に小さなデータセットを扱うときに最も役立ちます。データセットを小さく保つために、元のトレーニングデータ(25,000枚の画像)の40%をトレーニングに、10%を検証に、10%をテストに使用します。

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

これらはトレーニングデータセットの最初の9つの画像です。ご覧のとおり、すべて異なるサイズです。

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

また、ラベル1が「犬」でラベル0が「猫」であることがわかります。

データの標準化

RAW画像にはさまざまなサイズがあります。さらに、各ピクセルは0〜255の3つの整数値(RGBレベル値)で構成されます。これは、ニューラルネットワークへの給電には最適ではありません。私たちは2つのことをする必要があります:

  • 固定画像サイズに標準化します。 150x150を選択します。
  • ノーマライズ画素値の-1と1。我々は、使用してこれをやるNormalizationモデル自体の一部として層を。

一般に、すでに前処理されたデータを取得するモデルではなく、生データを入力として取得するモデルを開発することをお勧めします。その理由は、モデルが前処理されたデータを予期している場合、モデルをエクスポートして他の場所(Webブラウザーやモバイルアプリ)で使用するときはいつでも、まったく同じ前処理パイプラインを再実装する必要があるためです。これは非常にすぐにトリッキーになります。したがって、モデルをヒットする前に、可能な限り最小限の前処理を実行する必要があります。

ここでは、データパイプラインで画像のサイズ変更を行い(ディープニューラルネットワークはデータの連続したバッチしか処理できないため)、モデルを作成するときに、モデルの一部として入力値のスケーリングを行います。

画像のサイズを150x150に変更しましょう。

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

さらに、データをバッチ処理し、キャッシュとプリフェッチを使用して読み込み速度を最適化しましょう。

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

ランダムデータ拡張の使用

大きな画像データセットがない場合は、ランダムな水平反転や小さなランダムな回転など、ランダムでありながら現実的な変換をトレーニング画像に適用して、サンプルの多様性を人為的に導入することをお勧めします。これは、過剰適合を遅くしながら、モデルをトレーニングデータのさまざまな側面に公開するのに役立ちます。

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

さまざまなランダム変換後の最初のバッチの最初の画像がどのように見えるかを視覚化してみましょう。

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

モデルを構築する

それでは、前に説明した青写真に従ったモデルを作成しましょう。

ご了承ください:

  • 我々は、追加Rescaling (最初にスケール入力値に層を[0, 255]の範囲) [-1, 1]範囲。
  • 私たちは、追加Dropout正則のため、分類層の前層を。
  • 私たちは、合格することを確認してくださいtraining=Falseそのbatchnorm統計は、我々は微調整のためのベースモデルの凍結を解除した後も更新されませんので、それは、推論モードで実行さそうという、基本モデルを呼び出すとき。
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

最上層をトレーニングする

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

モデル全体の微調整を行います

最後に、基本モデルのフリーズを解除し、低い学習率でモデル全体をエンドツーエンドでトレーニングしましょう。

ベースモデルはトレーニング可能になるが、我々が渡されたので、重要なのは、それはまだ推論モードで実行されているtraining=False我々はモデルを構築したときに、それを呼び出すとき。これは、内部のバッチ正規化レイヤーがバッチ統計を更新しないことを意味します。もしそうなら、彼らはこれまでのモデルによって学んだ表現に大混乱をもたらすでしょう。

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

10エポックの後、微調整により、ここで素晴らしい改善が得られます。