Google I/O to frajda! Nadrobić zaległości w sesjach TensorFlow Zobacz sesje

Transfer uczenia się i dostrajania

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ustawiać

import numpy as np
import tensorflow as tf
from tensorflow import keras

Wstęp

Learning transferu polega na pobraniu funkcje wyuczone na jeden problem, i wykorzystanie ich w nowym, podobnym problemem. Na przykład cechy modelu, który nauczył się rozpoznawać szopy pracze, mogą być przydatne do uruchomienia modelu przeznaczonego do identyfikacji tanuki.

Nauka transferu jest zwykle wykonywana w przypadku zadań, w których zestaw danych zawiera zbyt mało danych, aby wytrenować model w pełnej skali od podstaw.

Najczęstszym wcieleniem transfer learning w kontekście deep learningu jest następujący przepływ pracy:

  1. Pobierz warstwy z wcześniej wytrenowanego modelu.
  2. Zamroź je, aby podczas przyszłych rund treningowych nie zniszczyć jakichkolwiek zawartych w nich informacji.
  3. Dodaj kilka nowych, nadających się do trenowania warstw na wierzchu zamrożonych warstw. Nauczą się przekształcać stare funkcje w prognozy na nowym zbiorze danych.
  4. Trenuj nowe warstwy w swoim zbiorze danych.

Ostatnim, opcjonalny krok, to dostrajanie, który składa się z odmrożenie cały model uzyskany powyżej (lub jego część) i przekwalifikowanie go na nowych danych z bardzo małą szybkością uczenia się. Może to potencjalnie osiągnąć znaczące ulepszenia, stopniowo dostosowując wstępnie wytrenowane funkcje do nowych danych.

Najpierw pojedziemy nad Keras trainable API w szczegółach, które leży u podstaw większości uczenia Transfer & dostrajających przepływów pracy.

Następnie zademonstrujemy typowy przepływ pracy, biorąc model wstępnie przeszkolony w zestawie danych ImageNet i przeszkolając go ponownie w zestawie danych klasyfikacji „koty kontra psy” Kaggle.

To jest adaptacją głębokie nauki z Python i 2016 blogu „budowanie potężnych modeli klasyfikacyjnych obraz za pomocą bardzo mało danych” .

Zamrażanie warstw: Zrozumienie trainable atrybut

Warstwy i modele mają trzy atrybuty wagi:

  • weights lista wszystkich wag zmiennych warstwy.
  • trainable_weights lista tych, które mają być aktualizowane (poprzez zejście gradientu), aby zminimalizować straty podczas treningu.
  • non_trainable_weights lista tych, które nie mają być przeszkoleni. Zazwyczaj są one aktualizowane przez model podczas przejścia do przodu.

Przykład: Dense warstwa ma 2 nadającego masy (jądro i polaryzacji)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Ogólnie rzecz biorąc, wszystkie ciężary są ciężarami, które można trenować. Jedyny wbudowany w warstwę, która ma zakaz wyszkolić ciężarów jest BatchNormalization warstwa. Wykorzystuje ciężary, których nie można trenować, aby śledzić średnią i wariancję danych wejściowych podczas treningu. Aby dowiedzieć się, jak używać non-wyszkolić ciężarów we własnych niestandardowych warstw, zobacz przewodnik pisanie nowych warstw od zera .

Przykład: BatchNormalization warstwa ma 2 nadającego wagi i 2 nie nadającego wag

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Warstwy i modele są również wyposażone w logiczną atrybutu trainable . Jego wartość można zmienić. Ustawianie layer.trainable do False ruchów wszystkie ciężary warstwa jest z wyszkolić do nieprzestrzegania wyszkolić. Nazywa się to „zamrożenie” warstwa: stan zamarzniętej warstwy nie będą aktualizowane w trakcie szkolenia (zarówno podczas szkolenia z fit() lub gdy szkolenie z dowolnej niestandardowej pętli, która opiera się na trainable_weights zastosować aktualizacje gradient).

Przykład: ustawiania trainable do False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Kiedy waga możliwa do trenowania staje się niemożliwa do wytrenowania, jej wartość nie jest już aktualizowana podczas treningu.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

Nie należy mylić layer.trainable atrybut z argumentem training w layer.__call__() (która określa, czy warstwa powinna prowadzić swoje podaniu w trybie wnioskowania lub trybu treningowego). Aby uzyskać więcej informacji, zobacz Keras nas .

Rekurencyjne ustawienie trainable atrybutu

Jeśli ustawisz trainable = False na modelu lub na dowolnej warstwy, która ma podwarstwy, wszystkie dzieci warstwy stać non-wyszkolić również.

Przykład:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Typowy przepływ pracy typu transfer-learning

To prowadzi nas do tego, jak typowy przepływ uczenia się transferowego można wdrożyć w Keras:

  1. Utwórz wystąpienie modelu podstawowego i załaduj do niego wstępnie wytrenowane wagi.
  2. Zamrozić wszystkie warstwy w modelu bazowym poprzez ustawienie trainable = False .
  3. Utwórz nowy model na podstawie danych wyjściowych jednej (lub kilku) warstw z modelu podstawowego.
  4. Wytrenuj nowy model na nowym zbiorze danych.

Zwróć uwagę, że alternatywnym, lżejszym przepływem pracy może być również:

  1. Utwórz wystąpienie modelu podstawowego i załaduj do niego wstępnie wytrenowane wagi.
  2. Przeprowadź przez niego nowy zestaw danych i zapisz dane wyjściowe jednej (lub kilku) warstw z modelu podstawowego. Jest to tak zwana funkcja ekstrakcji.
  3. Użyj tych danych wyjściowych jako danych wejściowych dla nowego, mniejszego modelu.

Kluczową zaletą tego drugiego przepływu pracy jest to, że model podstawowy jest uruchamiany tylko raz na danych, a nie raz na epokę uczenia. Więc jest o wiele szybciej i taniej.

Problem z tym drugim przepływem pracy polega jednak na tym, że nie pozwala on na dynamiczną modyfikację danych wejściowych nowego modelu podczas uczenia, co jest wymagane na przykład podczas rozszerzania danych. Uczenie się przenoszenia jest zwykle używane w przypadku zadań, w których nowy zestaw danych zawiera zbyt mało danych, aby można było wytrenować model w pełnej skali od podstaw, a w takich scenariuszach bardzo ważne jest rozszerzanie danych. W dalszej części skupimy się na pierwszym przepływie pracy.

Oto jak wygląda pierwszy przepływ pracy w Keras:

Najpierw utwórz wystąpienie modelu podstawowego ze wstępnie wytrenowanymi wagami.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Następnie zamroź model podstawowy.

base_model.trainable = False

Utwórz nowy model na górze.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Trenuj model na nowych danych.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Strojenie

Gdy model osiągnie zbieżność na nowych danych, możesz spróbować odblokować całość lub część modelu podstawowego i przeszkolić cały model od początku do końca z bardzo niskim współczynnikiem uczenia.

Jest to opcjonalny ostatni krok, który może potencjalnie zapewnić stopniową poprawę. Może to również potencjalnie prowadzić do szybkiego overfittingu – miej to na uwadze.

Bardzo ważne jest, aby tylko zrobić ten krok po model z zamrożonych warstw został przeszkolony do konwergencji. Jeśli zmieszasz losowo inicjowane warstwy możliwe do trenowania z warstwami możliwymi do trenowania, które zawierają wstępnie wytrenowane funkcje, losowo zainicjowane warstwy spowodują bardzo duże aktualizacje gradientu podczas treningu, co zniszczy wstępnie wytrenowane funkcje.

Bardzo ważne jest również użycie bardzo niskiego współczynnika uczenia się na tym etapie, ponieważ trenujesz znacznie większy model niż w pierwszej rundzie uczenia, na zestawie danych, który jest zwykle bardzo mały. W rezultacie istnieje ryzyko bardzo szybkiego przeciążenia, jeśli zastosujesz duże aktualizacje wagi. Tutaj chcesz tylko dostosować wstępnie wytrenowane wagi w sposób przyrostowy.

Oto jak zaimplementować dostrojenie całego modelu podstawowego:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Ważna uwaga o compile() i trainable

Wywoływanie compile() na modelu rozumie się „zamrożenia” zachowanie tego modelu. Oznacza to, że trainable wartości atrybutów w czasie model jest kompilowany powinny być zachowane przez cały okres użytkowania tego modelu, aż do compile nazywa się ponownie. Stąd, jeśli zmienić dowolny trainable wartość, upewnij się, aby zadzwonić do compile() ponownie w modelu na zmiany mają być brane pod uwagę.

Ważne informacje o BatchNormalization warstwie

Wiele modeli graficznych zawierają BatchNormalization warstw. Ta warstwa jest szczególnym przypadkiem pod każdym możliwym względem. Oto kilka rzeczy, o których należy pamiętać.

  • BatchNormalization zawiera 2 non-wyszkolić ciężary, które aktualizowane w czasie treningu. Są to zmienne śledzące średnią i wariancję danych wejściowych.
  • Po ustawieniu bn_layer.trainable = False The BatchNormalization warstwa będzie działać w trybie wnioskowania i nie aktualizuje swoich średnich i wariancji statystyk. To nie jest sprawa dla innych warstw Ogólnie, jak waga trainability & wnioskowania tryby szkolenia / są dwie prostopadłe koncepcje . Ale dwa są związane w przypadku BatchNormalization warstwy.
  • Kiedy odmrozić model, który zawiera BatchNormalization warstw w tym celu dostrajania, należy zachować BatchNormalization warstwy w trybie wnioskowania o przejściu training=False podczas wywoływania modelu bazowego. W przeciwnym razie aktualizacje zastosowane do wag, których nie można wyszkolić, nagle zniszczą to, czego nauczył się model.

Zobaczysz ten wzorzec w akcji w kompletnym przykładzie na końcu tego przewodnika.

Przenieś naukę i dostrajanie za pomocą niestandardowej pętli treningowej

Jeśli zamiast fit() , używasz własną pętlę szkoleniowy niskim poziomie, pobyty workflow w zasadzie takie same. Należy uważać, aby wziąć pod uwagę tylko listy model.trainable_weights podczas stosowania aktualizacji gradient:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Podobnie do dostrajania.

Kompleksowy przykład: dopracowanie modelu klasyfikacji obrazów w zestawie danych koty i psy

Aby utrwalić te koncepcje, przeprowadźmy Cię przez konkretny przykład kompleksowego uczenia się i dostrajania. Załadujemy model Xception, wstępnie wytrenowany w ImageNet, i użyjemy go w zestawie danych klasyfikacji „koty kontra psy” Kaggle.

Uzyskiwanie danych

Najpierw pobierzmy zestaw danych koty kontra psy za pomocą TFDS. Jeśli masz swój własny zestaw danych, prawdopodobnie będziesz chciał użyć narzędzia tf.keras.preprocessing.image_dataset_from_directory generować podobne obiekty oznaczone zestaw danych ze zbioru obrazów na dysku złożone w foldery klasy specyficzne.

Uczenie się transferu jest najbardziej przydatne podczas pracy z bardzo małymi zestawami danych. Aby nasz zestaw danych był niewielki, użyjemy 40% oryginalnych danych treningowych (25 000 obrazów) do trenowania, 10% do walidacji i 10% do testowania.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Oto pierwsze 9 obrazów w treningowym zbiorze danych — jak widać, wszystkie mają różne rozmiary.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Widzimy również, że etykieta 1 to „pies”, a etykieta 0 to „kot”.

Standaryzacja danych

Nasze surowe obrazy mają różne rozmiary. Ponadto każdy piksel składa się z 3 wartości całkowitych z zakresu od 0 do 255 (wartości na poziomie RGB). Nie jest to idealne rozwiązanie do zasilania sieci neuronowej. Musimy zrobić 2 rzeczy:

  • Standaryzuj do stałego rozmiaru obrazu. Wybieramy 150x150.
  • Znormalizować wartości pikseli pomiędzy -1 a 1. Będziemy to robić za pomocą Normalization warstwę jako część samego modelu.

Ogólnie rzecz biorąc, dobrą praktyką jest tworzenie modeli, które pobierają nieprzetworzone dane jako dane wejściowe, w przeciwieństwie do modeli, które przyjmują już wstępnie przetworzone dane. Powodem jest to, że jeśli model oczekuje wstępnie przetworzonych danych, za każdym razem, gdy eksportujesz model, aby użyć go w innym miejscu (w przeglądarce internetowej, w aplikacji mobilnej), będziesz musiał ponownie zaimplementować dokładnie ten sam potok przetwarzania wstępnego. To bardzo szybko staje się trudne. Powinniśmy więc wykonać najmniejszą możliwą ilość wstępnego przetwarzania przed uderzeniem w model.

Tutaj dokonamy zmiany rozmiaru obrazu w potoku danych (ponieważ głęboka sieć neuronowa może przetwarzać tylko ciągłe partie danych) i wykonamy skalowanie wartości wejściowej jako część modelu podczas jego tworzenia.

Zmieńmy rozmiar obrazów na 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Poza tym zbierzmy dane i użyjmy buforowania i pobierania z wyprzedzeniem, aby zoptymalizować prędkość ładowania.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Korzystanie z losowego zwiększania danych

Jeśli nie masz dużego zestawu danych obrazu, dobrą praktyką jest sztuczne wprowadzenie różnorodności próbek przez zastosowanie losowych, ale realistycznych przekształceń do obrazów szkoleniowych, takich jak losowe odwracanie w poziomie lub małe losowe obroty. Pomaga to wystawić model na różne aspekty danych uczących, jednocześnie spowalniając nadmierne dopasowanie.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Wyobraźmy sobie, jak wygląda pierwszy obraz pierwszej partii po różnych losowych przekształceniach:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Zbudować model

Teraz zbudujmy model zgodny z planem, który wyjaśniliśmy wcześniej.

Zwróć uwagę, że:

  • Dodajmy do Rescaling warstwy do wartości wejściowych skalę (początkowo w [0, 255] zakresu) w [-1, 1] zakresu.
  • Dodamy Dropout warstwy przed nałożeniem warstwy klasyfikacji, dla uregulowania.
  • Dbamy o to, aby przejść training=False podczas wywoływania modelu bazowego, tak, że działa w trybie wnioskowania, dzięki czemu statystyki batchnorm nie aktualizowane nawet po odmrozić model bazowy dostrajających.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Trenuj górną warstwę

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Wykonaj rundę dostrajania całego modelu

Na koniec odblokujmy model podstawowy i wytrenujmy cały model od początku do końca z niskim współczynnikiem uczenia się.

Co ważne, mimo że model podstawowy staje się wyszkolić, to nadal działa w trybie wnioskowania odkąd przeszedł training=False , gdy dzwoni, gdy zbudowaliśmy model. Oznacza to, że znajdujące się wewnątrz warstwy normalizacji wsadowej nie będą aktualizować swoich statystyk wsadowych. Gdyby to zrobili, zniszczyliby reprezentacje poznane przez model do tej pory.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Po 10 epokach dostrajanie daje nam tutaj niezłą poprawę.