Eine Frage haben? Verbinden Sie sich mit der Community im TensorFlow Forum Visit Forum

Transferlernen und Feinabstimmung

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Einrichten

import numpy as np
import tensorflow as tf
from tensorflow import keras

Einführung

Transferlernen besteht darin, die für ein Problem erlernten Funktionen zu nutzen und sie für ein neues, ähnliches Problem zu nutzen. Zum Beispiel können Merkmale eines Modells, das gelernt hat, Waschbären zu identifizieren, nützlich sein, um ein Modell zur Identifizierung von Tanukis zu starten.

Das Transferlernen wird normalerweise für Aufgaben durchgeführt, bei denen Ihr Datensatz zu wenig Daten enthält, um ein vollständiges Modell von Grund auf neu zu trainieren.

Die häufigste Inkarnation des Transferlernens im Kontext des Tiefenlernens ist der folgende Workflow:

  1. Nehmen Sie Schichten von einem zuvor trainierten Modell.
  2. Frieren Sie sie ein, um zu vermeiden, dass die in zukünftigen Trainingsrunden enthaltenen Informationen zerstört werden.
  3. Fügen Sie einige neue, trainierbare Schichten über die gefrorenen Schichten hinzu. Sie lernen, die alten Funktionen in Vorhersagen für einen neuen Datensatz umzuwandeln.
  4. Trainieren Sie die neuen Ebenen in Ihrem Datensatz.

Ein letzter optionaler Schritt ist die Feinabstimmung , bei der das gesamte oben erhaltene Modell (oder ein Teil davon) freigegeben und mit einer sehr geringen Lernrate auf die neuen Daten umgeschult wird. Dies kann möglicherweise zu bedeutenden Verbesserungen führen, indem die vorab trainierten Funktionen schrittweise an die neuen Daten angepasst werden.

Zunächst werden wir die trainable Keras-API im Detail behandeln, die den meisten Transfer-Lern- und Feinabstimmungs-Workflows zugrunde liegt.

Anschließend demonstrieren wir den typischen Workflow, indem wir ein im ImageNet-Dataset vorab trainiertes Modell und im Kaggle-Klassifizierungsdatensatz "Katzen gegen Hunde" neu trainieren.

Dies wurde aus Deep Learning mit Python und dem Blog-Beitrag von 2016 "Erstellen leistungsfähiger Bildklassifizierungsmodelle mit sehr wenig Daten" übernommen .

Einfrieren von Schichten: Verständnis des trainable Attributs

Ebenen und Modelle haben drei Gewichtsattribute:

  • weights ist die Liste aller Gewichtungsvariablen der Ebene.
  • trainable_weights ist die Liste derjenigen, die aktualisiert werden sollen (über Gradientenabstieg), um den Verlust während des Trainings zu minimieren.
  • non_trainable_weights ist die Liste derjenigen, die nicht trainiert werden sollen. In der Regel werden sie vom Modell während des Vorwärtsdurchlaufs aktualisiert.

Beispiel: Die Dense Schicht hat 2 trainierbare Gewichte (Kernel & Bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Im Allgemeinen sind alle Gewichte trainierbare Gewichte. Die einzige integrierte Ebene mit nicht trainierbaren Gewichten ist die BatchNormalization Ebene. Es verwendet nicht trainierbare Gewichte, um den Mittelwert und die Varianz seiner Eingaben während des Trainings zu verfolgen. Informationen zum Verwenden nicht trainierbarer Gewichte in Ihren eigenen benutzerdefinierten Ebenen finden Sie in der Anleitung zum Schreiben neuer Ebenen von Grund auf neu .

Beispiel: Die BatchNormalization Schicht hat 2 trainierbare Gewichte und 2 nicht trainierbare Gewichte

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Ebenen und Modelle verfügen außerdem über ein boolesches Attribut, das trainable . Sein Wert kann geändert werden. Wenn Sie layer.trainable auf False alle Gewichte der Ebene von trainierbar auf nicht trainierbar verschoben. Dies wird als "Einfrieren" der Ebene bezeichnet: Der Status einer eingefrorenen Ebene wird während des Trainings nicht aktualisiert (entweder beim Training mit fit() oder beim Training mit einer benutzerdefinierten Schleife, die auf trainable_weights angewiesen ist, um Gradientenaktualisierungen anzuwenden).

Beispiel: Einstellung trainable auf False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Wenn ein trainierbares Gewicht nicht mehr trainierbar ist, wird sein Wert während des Trainings nicht mehr aktualisiert.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 664ms/step - loss: 0.1025

Verwechseln Sie das Attribut layer.trainable mit dem Argument training in layer.__call__() (das steuert, ob der Layer seinen Vorwärtsdurchlauf im Inferenzmodus oder im Trainingsmodus ausführen soll). Weitere Informationen finden Sie in den Keras-FAQ .

Rekursive Einstellung des trainable Attributs

Wenn Sie trainable = False für ein Modell oder eine Ebene mit Unterebenen festlegen, werden alle untergeordneten Ebenen ebenfalls nicht trainierbar.

Beispiel:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Der typische Workflow für Transfer-Learning

Dies führt uns dazu, wie ein typischer Transfer-Learning-Workflow in Keras implementiert werden kann:

  1. Instanziieren Sie ein Basismodell und laden Sie vorab trainierte Gewichte hinein.
  2. Frieren Sie alle Ebenen im Basismodell ein, indem Sie trainable = False .
  3. Erstellen Sie ein neues Modell über der Ausgabe einer (oder mehrerer) Ebenen aus dem Basismodell.
  4. Trainieren Sie Ihr neues Modell mit Ihrem neuen Datensatz.

Beachten Sie, dass ein alternativer, leichterer Workflow auch Folgendes sein könnte:

  1. Instanziieren Sie ein Basismodell und laden Sie vorab trainierte Gewichte hinein.
  2. Führen Sie Ihr neues Dataset durch und zeichnen Sie die Ausgabe einer (oder mehrerer) Ebenen aus dem Basismodell auf. Dies wird als Merkmalsextraktion bezeichnet .
  3. Verwenden Sie diese Ausgabe als Eingabedaten für ein neues, kleineres Modell.

Ein wesentlicher Vorteil dieses zweiten Workflows besteht darin, dass Sie das Basismodell nur einmal für Ihre Daten ausführen und nicht einmal pro Trainingsepoche. Es ist also viel schneller und billiger.

Ein Problem bei diesem zweiten Workflow besteht jedoch darin, dass Sie die Eingabedaten Ihres neuen Modells während des Trainings nicht dynamisch ändern können, was beispielsweise bei der Datenerweiterung erforderlich ist. Transferlernen wird normalerweise für Aufgaben verwendet, bei denen Ihr neuer Datensatz zu wenig Daten enthält, um ein vollständiges Modell von Grund auf neu zu trainieren. In solchen Szenarien ist die Datenerweiterung sehr wichtig. Im Folgenden konzentrieren wir uns auf den ersten Workflow.

So sieht der erste Workflow in Keras aus:

Instanziieren Sie zunächst ein Basismodell mit vorab trainierten Gewichten.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Frieren Sie dann das Basismodell ein.

base_model.trainable = False

Erstellen Sie oben ein neues Modell.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Trainieren Sie das Modell mit neuen Daten.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Feintuning

Sobald Ihr Modell mit den neuen Daten konvergiert hat, können Sie versuchen, das Basismodell ganz oder teilweise freizugeben und das gesamte Modell mit einer sehr geringen Lernrate durchgängig zu trainieren.

Dies ist ein optionaler letzter Schritt, der möglicherweise zu inkrementellen Verbesserungen führen kann. Dies kann möglicherweise auch zu einer schnellen Überanpassung führen - denken Sie daran.

Es ist wichtig, diesen Schritt erst auszuführen, nachdem das Modell mit gefrorenen Schichten auf Konvergenz trainiert wurde. Wenn Sie zufällig initialisierte trainierbare Ebenen mit trainierbaren Ebenen mischen, die vorab trainierte Funktionen enthalten, verursachen die zufällig initialisierten Ebenen während des Trainings sehr große Gradientenaktualisierungen, die Ihre vorab trainierten Funktionen zerstören.

Es ist auch wichtig, in dieser Phase eine sehr niedrige Lernrate zu verwenden, da Sie ein viel größeres Modell als in der ersten Trainingsrunde mit einem Datensatz trainieren, der normalerweise sehr klein ist. Infolgedessen besteht die Gefahr, dass Sie sehr schnell überanpassen, wenn Sie große Gewichtsaktualisierungen vornehmen. Hier möchten Sie die vorab trainierten Gewichte nur inkrementell neu anpassen.

So implementieren Sie die Feinabstimmung des gesamten Basismodells:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Wichtiger Hinweis zu compile() und trainable

Das Aufrufen von compile() für ein Modell soll das Verhalten dieses Modells "einfrieren". Dies bedeutet, dass die trainable Attributwerte zum Zeitpunkt der Kompilierung des Modells während der gesamten Lebensdauer dieses Modells erhalten bleiben sollten, bis die compile erneut aufgerufen wird. Wenn Sie einen trainable Wert ändern, rufen Sie compile() in Ihrem Modell erneut auf, damit Ihre Änderungen berücksichtigt werden.

Wichtige Hinweise zur BatchNormalization Ebene

Viele BatchNormalization enthalten BatchNormalization Ebenen. Diese Schicht ist in jeder erdenklichen Hinsicht ein Sonderfall. Hier sind einige Dinge zu beachten.

  • BatchNormalization enthält 2 nicht trainierbare Gewichte, die während des Trainings aktualisiert werden. Dies sind die Variablen, die den Mittelwert und die Varianz der Eingaben verfolgen.
  • Wenn Sie bn_layer.trainable = False BatchNormalization , wird die BatchNormalization Ebene im Inferenzmodus ausgeführt und aktualisiert ihre Mittelwert- und Varianzstatistik nicht. Dies ist bei anderen Schichten im Allgemeinen nicht der Fall, da Gewichtstrainings- und Inferenz- / Trainingsmodi zwei orthogonale Konzepte sind . Bei der BatchNormalization Schicht sind die beiden jedoch miteinander verbunden.
  • Wenn Sie ein Modell, das BatchNormalization Ebenen enthält, BatchNormalization , um eine Feinabstimmung BatchNormalization , sollten Sie die BatchNormalization Ebenen im Inferenzmodus BatchNormalization , indem Sie beim Aufrufen des Basismodells training=False . Andernfalls zerstören die Aktualisierungen der nicht trainierbaren Gewichte plötzlich das, was das Modell gelernt hat.

Sie werden dieses Muster im End-to-End-Beispiel am Ende dieses Handbuchs in Aktion sehen.

Übertragen Sie Lernen und Feinabstimmung mit einer benutzerdefinierten Trainingsschleife

Wenn Sie anstelle von fit() eine eigene Low-Level-Trainingsschleife verwenden, bleibt der Workflow im Wesentlichen gleich. Sie sollten darauf achten, nur die Liste model.trainable_weights zu berücksichtigen, wenn Sie Verlaufsaktualisierungen anwenden:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Ebenso zur Feinabstimmung.

Ein End-to-End-Beispiel: Feinabstimmung eines Bildklassifizierungsmodells für einen Datensatz zwischen Katzen und Hunden

Um diese Konzepte zu festigen, führen wir Sie durch ein konkretes Beispiel für ein umfassendes Transferlernen und eine Feinabstimmung. Wir werden das in ImageNet vorab trainierte Xception-Modell laden und es im Kaggle-Klassifizierungsdatensatz "Katzen gegen Hunde" verwenden.

Daten abrufen

Lassen Sie uns zunächst den Datensatz Katzen gegen Hunde mit TFDS abrufen. Wenn Sie über ein eigenes Dataset verfügen, möchten Sie wahrscheinlich das Dienstprogramm tf.keras.preprocessing.image_dataset_from_directory , um ähnlich beschriftete Dataset-Objekte aus einer Reihe von Bildern auf der Festplatte zu generieren, die in klassenspezifischen Ordnern abgelegt sind.

Transferlernen ist am nützlichsten, wenn Sie mit sehr kleinen Datensätzen arbeiten. Um unseren Datensatz klein zu halten, verwenden wir 40% der ursprünglichen Trainingsdaten (25.000 Bilder) für das Training, 10% für die Validierung und 10% für Tests.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Dies sind die ersten 9 Bilder im Trainingsdatensatz - wie Sie sehen können, sind sie alle unterschiedlich groß.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Wir können auch sehen, dass Etikett 1 "Hund" und Etikett 0 "Katze" ist.

Standardisierung der Daten

Unsere Rohbilder haben verschiedene Größen. Zusätzlich besteht jedes Pixel aus 3 ganzzahligen Werten zwischen 0 und 255 (RGB-Pegelwerte). Dies ist nicht besonders geeignet, um ein neuronales Netzwerk zu versorgen. Wir müssen zwei Dinge tun:

  • Standardisieren Sie auf eine feste Bildgröße. Wir wählen 150x150.
  • Normalisieren Sie Pixelwerte zwischen -1 und 1. Wir verwenden dazu eine Normalization als Teil des Modells.

Im Allgemeinen empfiehlt es sich, Modelle zu entwickeln, die Rohdaten als Eingabe verwenden, im Gegensatz zu Modellen, die bereits vorverarbeitete Daten verwenden. Der Grund dafür ist, dass Sie, wenn Ihr Modell vorverarbeitete Daten erwartet, jedes Mal, wenn Sie Ihr Modell exportieren, um es an anderer Stelle zu verwenden (in einem Webbrowser, in einer mobilen App), genau dieselbe Vorverarbeitungspipeline neu implementieren müssen. Dies wird sehr schnell sehr schwierig. Wir sollten also die geringstmögliche Vorverarbeitung durchführen, bevor wir das Modell treffen.

Hier werden wir die Bildgröße in der Datenpipeline ändern (da ein tiefes neuronales Netzwerk nur zusammenhängende Datenstapel verarbeiten kann) und die Skalierung der Eingabewerte als Teil des Modells durchführen, wenn wir es erstellen.

Lassen Sie uns die Größe der Bilder auf 150 x 150 ändern:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Lassen Sie uns außerdem die Daten stapeln und Caching & Prefetching verwenden, um die Ladegeschwindigkeit zu optimieren.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Zufällige Datenerweiterung verwenden

Wenn Sie keinen großen Bilddatensatz haben, empfiehlt es sich, die Probendiversität künstlich einzuführen, indem Sie zufällige, aber realistische Transformationen auf die Trainingsbilder anwenden, z. B. zufälliges horizontales Spiegeln oder kleine zufällige Rotationen. Dies hilft, das Modell verschiedenen Aspekten der Trainingsdaten auszusetzen und gleichzeitig die Überanpassung zu verlangsamen.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

Lassen Sie uns visualisieren, wie das erste Bild des ersten Stapels nach verschiedenen zufälligen Transformationen aussieht:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

png

Ein Modell bauen

Lassen Sie uns nun ein Modell erstellen, das dem zuvor erläuterten Entwurf folgt.

Beachten Sie, dass:

  • Wir fügen eine Normalization , um Eingabewerte (anfänglich im Bereich [0, 255] ) auf den Bereich [-1, 1] zu skalieren.
  • Dropout Regularisierung fügen wir vor der Klassifizierungsebene eine Dropout Ebene hinzu.
  • Wir stellen sicher, dass training=False beim Aufrufen des Basismodells bestanden wird, damit es im Inferenzmodus ausgeführt wird, damit die Batchnorm-Statistiken auch dann nicht aktualisiert werden, wenn wir das Basismodell zur Feinabstimmung freigeben.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 1s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

Trainiere die oberste Schicht

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 20s 49ms/step - loss: 0.2226 - binary_accuracy: 0.8972 - val_loss: 0.0805 - val_binary_accuracy: 0.9703
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1246 - binary_accuracy: 0.9464 - val_loss: 0.0757 - val_binary_accuracy: 0.9712
Epoch 3/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1153 - binary_accuracy: 0.9480 - val_loss: 0.0724 - val_binary_accuracy: 0.9733
Epoch 4/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1055 - binary_accuracy: 0.9575 - val_loss: 0.0753 - val_binary_accuracy: 0.9721
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1026 - binary_accuracy: 0.9589 - val_loss: 0.0750 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1022 - binary_accuracy: 0.9587 - val_loss: 0.0723 - val_binary_accuracy: 0.9716
Epoch 7/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1009 - binary_accuracy: 0.9570 - val_loss: 0.0731 - val_binary_accuracy: 0.9708
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0947 - binary_accuracy: 0.9576 - val_loss: 0.0726 - val_binary_accuracy: 0.9716
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0872 - binary_accuracy: 0.9624 - val_loss: 0.0720 - val_binary_accuracy: 0.9712
Epoch 10/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0892 - binary_accuracy: 0.9622 - val_loss: 0.0711 - val_binary_accuracy: 0.9716
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0987 - binary_accuracy: 0.9608 - val_loss: 0.0752 - val_binary_accuracy: 0.9712
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0962 - binary_accuracy: 0.9595 - val_loss: 0.0715 - val_binary_accuracy: 0.9738
Epoch 13/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0972 - binary_accuracy: 0.9606 - val_loss: 0.0700 - val_binary_accuracy: 0.9725
Epoch 14/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9568 - val_loss: 0.0779 - val_binary_accuracy: 0.9690
Epoch 15/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0929 - binary_accuracy: 0.9614 - val_loss: 0.0700 - val_binary_accuracy: 0.9729
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0937 - binary_accuracy: 0.9610 - val_loss: 0.0698 - val_binary_accuracy: 0.9742
Epoch 17/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0945 - binary_accuracy: 0.9613 - val_loss: 0.0671 - val_binary_accuracy: 0.9759
Epoch 18/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0868 - binary_accuracy: 0.9612 - val_loss: 0.0692 - val_binary_accuracy: 0.9738
Epoch 19/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0871 - binary_accuracy: 0.9647 - val_loss: 0.0691 - val_binary_accuracy: 0.9746
Epoch 20/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0922 - binary_accuracy: 0.9603 - val_loss: 0.0721 - val_binary_accuracy: 0.9738
<tensorflow.python.keras.callbacks.History at 0x7fb73f231860>

Führen Sie eine Feinabstimmung des gesamten Modells durch

Lassen Sie uns abschließend das Basismodell auftauen und das gesamte Modell mit einer geringen Lernrate durchgängig trainieren.

Obwohl das Basismodell trainierbar wird, wird es immer noch im Inferenzmodus ausgeführt, da wir beim Aufrufen des Modells beim Aufrufen von training=False . Dies bedeutet, dass die darin enthaltenen Batch-Normalisierungsebenen ihre Batch-Statistiken nicht aktualisieren. Wenn sie dies tun würden, würden sie die Darstellungen, die das Modell bisher gelernt hat, zerstören.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 133ms/step - loss: 0.0814 - binary_accuracy: 0.9677 - val_loss: 0.0527 - val_binary_accuracy: 0.9776
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0544 - binary_accuracy: 0.9796 - val_loss: 0.0537 - val_binary_accuracy: 0.9776
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0481 - binary_accuracy: 0.9822 - val_loss: 0.0471 - val_binary_accuracy: 0.9789
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0324 - binary_accuracy: 0.9871 - val_loss: 0.0551 - val_binary_accuracy: 0.9807
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0298 - binary_accuracy: 0.9899 - val_loss: 0.0447 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0262 - binary_accuracy: 0.9901 - val_loss: 0.0469 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0242 - binary_accuracy: 0.9918 - val_loss: 0.0539 - val_binary_accuracy: 0.9798
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0153 - binary_accuracy: 0.9935 - val_loss: 0.0644 - val_binary_accuracy: 0.9794
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0175 - binary_accuracy: 0.9934 - val_loss: 0.0496 - val_binary_accuracy: 0.9819
Epoch 10/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0171 - binary_accuracy: 0.9936 - val_loss: 0.0496 - val_binary_accuracy: 0.9828
<tensorflow.python.keras.callbacks.History at 0x7fb74f74f940>

Nach 10 Epochen bringt uns die Feinabstimmung hier eine schöne Verbesserung.