Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Przenieś naukę i dostrajanie

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ustawiać

import numpy as np
import tensorflow as tf
from tensorflow import keras

Wprowadzenie

Uczenie się transferowe polega na przejęciu cech poznanych w ramach jednego problemu i wykorzystaniu ich w nowym, podobnym problemie. Na przykład cechy modelu, który nauczył się identyfikować szopy, mogą być przydatne do uruchomienia modelu służącego do identyfikacji tanuki.

Uczenie się transferu jest zwykle wykonywane w przypadku zadań, w których zestaw danych zawiera zbyt mało danych, aby wytrenować model w pełnej skali od podstaw.

Najczęstszym wcieleniem uczenia się transferowego w kontekście uczenia głębokiego jest następujący przepływ pracy:

  1. Weź warstwy z wcześniej wytrenowanego modelu.
  2. Zamrozić je, aby uniknąć zniszczenia jakichkolwiek zawartych w nich informacji podczas kolejnych rund szkoleniowych.
  3. Dodaj nowe, możliwe do nauczenia warstwy na wierzchu zamrożonych warstw. Nauczą się przekształcać stare funkcje w przewidywania dotyczące nowego zbioru danych.
  4. Trenuj nowe warstwy w swoim zbiorze danych.

Ostatnim, opcjonalnym krokiem jest dostrojenie , które polega na odmrożeniu całego modelu, który uzyskałeś powyżej (lub jego części) i ponownym trenowaniu go na nowych danych z bardzo niskim współczynnikiem uczenia się. Może to potencjalnie przynieść znaczące ulepszenia poprzez stopniowe dostosowywanie wstępnie wytrenowanych funkcji do nowych danych.

Najpierw omówimy szczegółowo interfejs API Keras, który można trainable , który leży u podstaw większości procesów uczenia się i dostosowywania transferu.

Następnie zademonstrujemy typowy przepływ pracy, pobierając model wstępnie wytrenowany w zbiorze danych ImageNet i przekwalifikowując go w zestawie danych klasyfikacji Kaggle „koty kontra psy”.

Jest to zaczerpnięte z Deep Learning with Python i post na blogu z 2016 r. „Tworzenie potężnych modeli klasyfikacji obrazów przy użyciu bardzo małej ilości danych” .

Warstwy zamrażania: zrozumienie atrybutu trainable do trainable

Warstwy i modele mają trzy atrybuty wagi:

  • weights to lista wszystkich zmiennych wagi warstwy.
  • trainable_weights to lista tych, które mają zostać zaktualizowane (poprzez spadek gradientu), aby zminimalizować straty podczas treningu.
  • non_trainable_weights to lista tych, które nie mają być trenowane. Zazwyczaj są one aktualizowane przez model podczas przejścia do przodu.

Przykład: warstwa Dense ma 2 możliwe do trenowania wagi (jądro i odchylenie)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Ogólnie wszystkie ciężary są ciężarkami, które można trenować. Jedyną wbudowaną warstwą, która ma wagi, których nie można BatchNormalization warstwa BatchNormalization . Wykorzystuje ciężary, których nie można trenować, aby śledzić średnią i wariancję danych wejściowych podczas treningu. Aby dowiedzieć się, jak używać ciężarków, których nie można wytrenować we własnych warstwach niestandardowych, zapoznaj się z przewodnikiem dotyczącym pisania nowych warstw od podstaw .

Przykład: warstwa BatchNormalization ma 2 ciężary, które można BatchNormalization , i 2 ciężary, których nie można wytrenować

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Warstwy i modele mają również atrybut logiczny, który można trainable . Jego wartość można zmienić. Ustawienie parametru layer.trainable na False powoduje przesunięcie wszystkich obciążeń warstwy z layer.trainable do layer.trainable na layer.trainable . Nazywa się to „zamrażaniem” warstwy: stan zamrożonej warstwy nie będzie aktualizowany podczas treningu (ani podczas treningu z fit() ani podczas treningu z dowolną niestandardową pętlą, która polega na trainable_weights do stosowania aktualizacji gradientu).

Przykład: ustawienie trainable na False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Kiedy ciężar, który można trenować, staje się niemożliwy do wytrenowania, jego wartość nie jest już aktualizowana podczas treningu.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 1ms/step - loss: 0.1275

Nie należy mylić layer.trainable atrybut z argumentem training w layer.__call__() (która określa, czy warstwa powinna prowadzić swoje podaniu w trybie wnioskowania lub trybu treningowego). Aby uzyskać więcej informacji, zobacz FAQ Keras .

Rekurencyjne ustawienie atrybutu trainable do trainable

Jeśli ustawisz trainable = False na modelu lub na dowolnej warstwie, która ma podwarstwy, wszystkie warstwy potomne również staną się niemożliwe do trenowania.

Przykład:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Typowy proces uczenia się transferowego

To prowadzi nas do tego, jak typowy przepływ pracy uczenia transferowego można zaimplementować w Keras:

  1. Utwórz wystąpienie modelu podstawowego i załaduj do niego wstępnie wytrenowane wagi.
  2. Zablokuj wszystkie warstwy w modelu podstawowym, ustawiając trainable = False .
  3. Utwórz nowy model na podstawie wyniku jednej (lub kilku) warstw z modelu podstawowego.
  4. Wytrenuj nowy model w nowym zbiorze danych.

Zwróć uwagę, że alternatywnym, lżejszym przepływem pracy może być również:

  1. Utwórz wystąpienie modelu podstawowego i załaduj do niego wstępnie wytrenowane wagi.
  2. Przeprowadź przez niego nowy zestaw danych i zapisz wynik jednej (lub kilku) warstw z modelu podstawowego. Nazywa się to wyodrębnianiem cech .
  3. Użyj tych danych wyjściowych jako danych wejściowych dla nowego, mniejszego modelu.

Główną zaletą tego drugiego przepływu pracy jest to, że model podstawowy jest uruchamiany tylko raz na danych, a nie raz na okres uczenia. Jest to więc dużo szybsze i tańsze.

Problem z tym drugim przepływem pracy polega jednak na tym, że nie pozwala on na dynamiczne modyfikowanie danych wejściowych nowego modelu podczas uczenia, co jest wymagane na przykład podczas rozszerzania danych. Uczenie się z transferu jest zwykle używane w przypadku zadań, w których nowy zestaw danych ma zbyt mało danych, aby wytrenować model w pełnej skali od podstaw, a w takich scenariuszach rozszerzanie danych jest bardzo ważne. W dalszej części skupimy się na pierwszym przepływie pracy.

Oto jak wygląda pierwszy przepływ pracy w Keras:

Najpierw utwórz wystąpienie modelu podstawowego z wstępnie wytrenowanymi wagami.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Następnie zamroź model podstawowy.

base_model.trainable = False

Utwórz nowy model na górze.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Wytrenuj model na nowych danych.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Strojenie

Gdy model osiągnie zbieżność z nowymi danymi, możesz spróbować odblokować cały model podstawowy lub jego część i ponownie przeszkolić cały model od końca do końca z bardzo niskim współczynnikiem uczenia się.

Jest to opcjonalny ostatni krok, który może potencjalnie przynieść stopniowe ulepszenia. Może to również potencjalnie prowadzić do szybkiego overfittingu - miej to na uwadze.

Ważne jest, aby wykonać ten krok dopiero po przeszkoleniu modelu z zamrożonymi warstwami w celu uzyskania konwergencji. Jeśli zmieszasz losowo zainicjowane warstwy, które można trenować z warstwami, które można trenować, które zawierają wstępnie wyuczone funkcje, losowo zainicjowane warstwy spowodują bardzo duże aktualizacje gradientu podczas treningu, co zniszczy wstępnie wyuczone funkcje.

Na tym etapie niezwykle ważne jest również użycie bardzo niskiego wskaźnika uczenia się, ponieważ trenujesz znacznie większy model niż w pierwszej rundzie szkolenia, na zbiorze danych, który jest zwykle bardzo mały. W rezultacie istnieje ryzyko bardzo szybkiego nadmiernego dopasowania, jeśli zastosujesz duże aktualizacje wagi. Tutaj chcesz tylko ponownie dostosować wstępnie wytrenowane odważniki w sposób przyrostowy.

Oto jak zaimplementować dostrajanie całego modelu podstawowego:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Ważna uwaga o compile() i trainable

Wywołanie compile() na modelu ma na celu „zamrożenie” zachowania tego modelu. Oznacza to, że możliwe do trainable wartości atrybutów w momencie kompilowania modelu powinny być zachowane przez cały okres istnienia tego modelu, aż do ponownego wywołania compile . W związku z tym, jeśli zmienisz jakąkolwiek trainable wartość, upewnij się, że ponownie wywołałeś compile() w swoim modelu, aby zmiany zostały uwzględnione.

Ważne uwagi o warstwie BatchNormalization

Wiele modeli obrazów zawiera warstwy BatchNormalization . Ta warstwa jest szczególnym przypadkiem pod każdym możliwym względem. Oto kilka rzeczy, o których należy pamiętać.

  • BatchNormalization zawiera 2 ciężary, których nie można trenować, które są aktualizowane podczas treningu. Są to zmienne śledzące średnią i wariancję danych wejściowych.
  • Jeśli ustawisz bn_layer.trainable = False , warstwa BatchNormalization będzie działać w trybie wnioskowania i nie będzie aktualizować statystyk średniej i wariancji. Nie dotyczy to ogólnie innych warstw, ponieważ trening siłowy i tryby wnioskowania / treningu to dwie ortogonalne koncepcje . Ale te dwa są powiązane w przypadku warstwy BatchNormalization .
  • Po odblokowaniu modelu, który zawiera warstwy BatchNormalization w celu dostrojenia, należy utrzymywać warstwy BatchNormalization w trybie wnioskowania, przekazując training=False podczas wywoływania modelu podstawowego. W przeciwnym razie aktualizacje zastosowane do ciężarów, których nie można trenować, nagle zniszczą to, czego nauczył się model.

Zobaczysz ten wzorzec w akcji w kompleksowym przykładzie na końcu tego przewodnika.

Przenieś naukę i dostrajanie dzięki niestandardowej pętli treningowej

Jeśli zamiast fit() używasz własnej niskopoziomowej pętli treningowej, przepływ pracy pozostaje zasadniczo taki sam. Należy uważać, aby brać pod uwagę tylko model listy. model.trainable_weights podczas stosowania aktualizacji gradientu:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Podobnie do dostrajania.

Kompleksowy przykład: dopracowanie modelu klasyfikacji obrazu na przykładzie kotów i psów

zbiór danych

Aby utrwalić te koncepcje, przeprowadzimy Cię przez konkretny, kompleksowy przykład uczenia się i dostrajania transferu. Załadujemy model Xception, wstępnie wyszkolony w ImageNet i użyjemy go w zbiorze danych klasyfikacji Kaggle „koty kontra psy”.

Pobieranie danych

Najpierw pobierzmy zestaw danych koty kontra psy za pomocą TFDS. Jeśli masz własny zestaw danych, prawdopodobnie będziesz chciał użyć narzędzia tf.keras.preprocessing.image_dataset_from_directory do wygenerowania podobnie oznaczonych obiektów zestawu danych z zestawu obrazów na dysku umieszczonych w folderach dla określonej klasy.

Transfer uczenia się jest najbardziej przydatny podczas pracy z bardzo małymi zbiorami danych. Aby nasz zbiór danych był mały, wykorzystamy 40% oryginalnych danych szkoleniowych (25 000 obrazów) do uczenia, 10% do walidacji i 10% do testowania.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteIL7NQA/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

To jest pierwszych 9 obrazów w zestawie danych treningowych - jak widać, wszystkie mają różne rozmiary.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Widzimy również, że etykieta 1 to „pies”, a etykieta 0 to „kot”.

Standaryzacja danych

Nasze surowe obrazy mają różne rozmiary. Ponadto każdy piksel składa się z 3 wartości całkowitych od 0 do 255 (wartości poziomu RGB). To nie jest dobre rozwiązanie do zasilania sieci neuronowej. Musimy zrobić 2 rzeczy:

  • Standaryzuj do stałego rozmiaru obrazu. Wybieramy 150x150.
  • Normalizuj wartości pikseli w zakresie od -1 do 1. Zrobimy to za pomocą warstwy Normalization jako części samego modelu.

Ogólnie rzecz biorąc, dobrą praktyką jest opracowywanie modeli, które przyjmują surowe dane jako dane wejściowe, w przeciwieństwie do modeli, które wykorzystują już wstępnie przetworzone dane. Powodem jest to, że jeśli twój model oczekuje wstępnie przetworzonych danych, za każdym razem, gdy eksportujesz swój model, aby użyć go w innym miejscu (w przeglądarce internetowej, w aplikacji mobilnej), będziesz musiał ponownie zaimplementować dokładnie ten sam potok wstępnego przetwarzania. To staje się bardzo trudne, bardzo szybko. Dlatego powinniśmy wykonać jak najmniejszą ilość wstępnego przetwarzania przed uderzeniem w model.

Tutaj dokonamy zmiany rozmiaru obrazu w potoku danych (ponieważ głęboka sieć neuronowa może przetwarzać tylko ciągłe partie danych), a skalowanie wartości wejściowych wykonamy jako część modelu, kiedy go tworzymy.

Zmieńmy rozmiar obrazów na 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Poza tym zbierzmy dane i użyjmy buforowania i wstępnego pobierania, aby zoptymalizować prędkość ładowania.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Korzystanie z losowego rozszerzania danych

Jeśli nie masz dużego zestawu danych obrazu, dobrą praktyką jest sztuczne wprowadzanie różnorodności próbek przez zastosowanie losowych, ale realistycznych przekształceń do obrazów szkoleniowych, takich jak losowe przerzucanie w poziomie lub małe losowe obroty. Pomaga to uwidocznić model w różnych aspektach danych szkoleniowych, jednocześnie spowalniając nadmierne dopasowanie.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

Wizualizujmy, jak wygląda pierwszy obraz pierwszej partii po różnych losowych przekształceniach:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[i]))
        plt.axis("off")

png

Zbudować model

Zbudujmy teraz model zgodny z planem, który wyjaśniliśmy wcześniej.

Zauważ, że:

  • Dodajemy warstwę Normalization , aby skalować wartości wejściowe (początkowo z zakresu [0, 255] ) do zakresu [-1, 1] .
  • W celu uregulowania dodajemy warstwę Dropout przed warstwą klasyfikacji.
  • Upewniamy się, że podczas wywoływania modelu podstawowego przekazujemy training=False , aby działał on w trybie wnioskowania, aby statystyki wsadowe nie były aktualizowane nawet po odblokowaniu modelu podstawowego w celu dostrojenia.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

Trenuj górną warstwę

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 9s 32ms/step - loss: 0.1758 - binary_accuracy: 0.9226 - val_loss: 0.0897 - val_binary_accuracy: 0.9660
Epoch 2/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1211 - binary_accuracy: 0.9497 - val_loss: 0.0870 - val_binary_accuracy: 0.9686
Epoch 3/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1166 - binary_accuracy: 0.9503 - val_loss: 0.0814 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1125 - binary_accuracy: 0.9534 - val_loss: 0.0825 - val_binary_accuracy: 0.9695
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1073 - binary_accuracy: 0.9569 - val_loss: 0.0763 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1041 - binary_accuracy: 0.9573 - val_loss: 0.0812 - val_binary_accuracy: 0.9686
Epoch 7/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1023 - binary_accuracy: 0.9567 - val_loss: 0.0820 - val_binary_accuracy: 0.9669
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1005 - binary_accuracy: 0.9597 - val_loss: 0.0779 - val_binary_accuracy: 0.9695
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9580 - val_loss: 0.0813 - val_binary_accuracy: 0.9699
Epoch 10/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0940 - binary_accuracy: 0.9651 - val_loss: 0.0762 - val_binary_accuracy: 0.9729
Epoch 11/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0974 - binary_accuracy: 0.9613 - val_loss: 0.0752 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0965 - binary_accuracy: 0.9591 - val_loss: 0.0760 - val_binary_accuracy: 0.9721
Epoch 13/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0962 - binary_accuracy: 0.9598 - val_loss: 0.0785 - val_binary_accuracy: 0.9712
Epoch 14/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0966 - binary_accuracy: 0.9616 - val_loss: 0.0831 - val_binary_accuracy: 0.9699
Epoch 15/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1000 - binary_accuracy: 0.9574 - val_loss: 0.0741 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0940 - binary_accuracy: 0.9628 - val_loss: 0.0781 - val_binary_accuracy: 0.9686
Epoch 17/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0915 - binary_accuracy: 0.9634 - val_loss: 0.0843 - val_binary_accuracy: 0.9678
Epoch 18/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0937 - binary_accuracy: 0.9620 - val_loss: 0.0829 - val_binary_accuracy: 0.9669
Epoch 19/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0988 - binary_accuracy: 0.9601 - val_loss: 0.0862 - val_binary_accuracy: 0.9686
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0928 - binary_accuracy: 0.9644 - val_loss: 0.0798 - val_binary_accuracy: 0.9703

<tensorflow.python.keras.callbacks.History at 0x7f6104f04518>

Wykonaj rundę, aby dostroić cały model

Na koniec odblokujmy model podstawowy i przećwiczmy cały model od końca do końca z niskim współczynnikiem uczenia się.

Co ważne, chociaż model podstawowy można training=False , nadal działa w trybie wnioskowania, ponieważ przeszliśmy training=False podczas wywoływania go podczas budowania modelu. Oznacza to, że warstwy normalizacji wsadowej wewnątrz nie będą aktualizować statystyk wsadowych. Gdyby to zrobili, zrujnowaliby wyobrażenia, których nauczył się dotychczas model.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
  2/291 [..............................] - ETA: 17s - loss: 0.1439 - binary_accuracy: 0.9219WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

291/291 [==============================] - 38s 132ms/step - loss: 0.0786 - binary_accuracy: 0.9706 - val_loss: 0.0631 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0553 - binary_accuracy: 0.9790 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0442 - binary_accuracy: 0.9829 - val_loss: 0.0532 - val_binary_accuracy: 0.9819
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0369 - binary_accuracy: 0.9858 - val_loss: 0.0460 - val_binary_accuracy: 0.9832
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0335 - binary_accuracy: 0.9870 - val_loss: 0.0561 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0253 - binary_accuracy: 0.9910 - val_loss: 0.0559 - val_binary_accuracy: 0.9819
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0232 - binary_accuracy: 0.9920 - val_loss: 0.0432 - val_binary_accuracy: 0.9845
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0185 - binary_accuracy: 0.9930 - val_loss: 0.0396 - val_binary_accuracy: 0.9854
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0147 - binary_accuracy: 0.9948 - val_loss: 0.0439 - val_binary_accuracy: 0.9832
Epoch 10/10
291/291 [==============================] - 37s 129ms/step - loss: 0.0117 - binary_accuracy: 0.9954 - val_loss: 0.0538 - val_binary_accuracy: 0.9819

<tensorflow.python.keras.callbacks.History at 0x7f611c26e438>

Po 10 epokach dostrojenie daje nam tutaj niezłą poprawę.