Unisciti a TensorFlow a Google I/O, 11-12 maggio Registrati ora

Trasferire l'apprendimento e la messa a punto

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno

Impostare

import numpy as np
import tensorflow as tf
from tensorflow import keras

introduzione

Apprendimento trasferimento consiste nel prendere caratteristiche apprese su un problema, e facendo leva su una nuova, problema simile. Ad esempio, le funzionalità di un modello che ha imparato a identificare i procioni possono essere utili per avviare un modello inteso a identificare i tanuki.

Il trasferimento di apprendimento viene in genere eseguito per attività in cui il set di dati ha troppo pochi dati per addestrare da zero un modello su vasta scala.

L'incarnazione più comune del transfer learning nel contesto del deep learning è il seguente flusso di lavoro:

  1. Prendi i livelli da un modello precedentemente addestrato.
  2. Bloccali, in modo da evitare di distruggere le informazioni che contengono durante i futuri round di addestramento.
  3. Aggiungi alcuni nuovi livelli allenabili sopra i livelli congelati. Impareranno a trasformare le vecchie funzionalità in previsioni su un nuovo set di dati.
  4. Addestra i nuovi livelli sul tuo set di dati.

Un ultimo, passaggio facoltativo, è messa a punto, che si compone di scongelare l'intero modello è stato ottenuto in precedenza (o parte di esso), e ri-formazione sui nuovi dati con un tasso di apprendimento molto bassa. Ciò può potenzialmente ottenere miglioramenti significativi, adattando in modo incrementale le funzionalità pre-addestrate ai nuovi dati.

In primo luogo, andremo oltre la Keras trainable API in dettaglio, che è alla base la maggior parte dei flussi di lavoro di apprendimento di trasferimento e di fine-tuning.

Quindi, dimostreremo il flusso di lavoro tipico prendendo un modello preaddestrato sul set di dati ImageNet e riqualificandolo sul set di dati di classificazione "gatti contro cani" di Kaggle.

Questo è adattato da Deep Learning con Python e il post sul blog 2016 "la costruzione di potenti modelli di classificazione dell'immagine usando molto poco dei dati" .

Strati di congelamento: comprensione del trainable attributo

Livelli e modelli hanno tre attributi di peso:

  • weights l 'elenco di tutte le variabili pesi dello strato.
  • trainable_weights è la lista di quelli che sono destinati ad essere aggiornato (via discesa del gradiente) per ridurre al minimo le perdite durante l'allenamento.
  • non_trainable_weights è la lista di quelli che non sono destinate ad essere addestrato. In genere vengono aggiornati dal modello durante il passaggio in avanti.

Esempio: il Dense strato ha 2 pesi addestrabili (kernel & bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

In generale, tutti i pesi sono pesi allenabili. L'unica incorporato livello che ha pesi non addestrabili è il BatchNormalization strato. Utilizza pesi non allenabili per tenere traccia della media e della varianza dei suoi input durante l'allenamento. Per informazioni su come usare i pesi non addestrabili nei propri livelli personalizzati, consultare la guida alla scrittura di nuovi livelli da zero .

Esempio: il BatchNormalization strato ha 2 pesi addestrabili e 2 pesi non addestrabili

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Livelli e modelli dispongono anche di un attributo booleano trainable . Il suo valore può essere modificato. Impostazione layer.trainable a False muove tutti i pesi del livello da addestrabile a non addestrabile. Questo si chiama "congelamento" dello strato: lo stato di un layer congelato non verrà aggiornato durante l'allenamento (sia quando la formazione con fit() o quando la formazione con qualsiasi ciclo personalizzato che si basa su trainable_weights per applicare gli aggiornamenti di pendenza).

Esempio: impostazione trainable a False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Quando un peso allenabile diventa non allenabile, il suo valore non viene più aggiornato durante l'allenamento.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

Non confondere il layer.trainable attributo con l'argomento training in layer.__call__() (che controlla se lo strato deve essere eseguito il suo passaggio in avanti nel modo di deduzione o di modalità di formazione). Per ulteriori informazioni, vedere la Keras FAQ .

Impostazione ricorsiva del trainable attributo

Se si imposta trainable = False su un modello o su qualsiasi livello che ha sottolivelli, tutti i bambini gli strati diventano non addestrabile pure.

Esempio:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Il tipico flusso di lavoro di trasferimento di apprendimento

Questo ci porta a come un tipico flusso di lavoro di trasferimento di apprendimento può essere implementato in Keras:

  1. Crea un'istanza di un modello base e carica su di esso pesi pre-addestrati.
  2. Congelare tutti i livelli del modello di base impostando trainable = False .
  3. Crea un nuovo modello sopra l'output di uno (o più) livelli dal modello di base.
  4. Addestra il tuo nuovo modello sul tuo nuovo set di dati.

Tieni presente che un flusso di lavoro alternativo e più leggero potrebbe anche essere:

  1. Crea un'istanza di un modello base e carica su di esso pesi pre-addestrati.
  2. Esegui il tuo nuovo set di dati attraverso di esso e registra l'output di uno (o più) livelli dal modello di base. Questo è chiamato estrazione di caratteristiche.
  3. Usa quell'output come dati di input per un nuovo modello più piccolo.

Un vantaggio chiave di quel secondo flusso di lavoro è che esegui il modello di base solo una volta sui tuoi dati, anziché una volta per epoca di addestramento. Quindi è molto più veloce ed economico.

Un problema con quel secondo flusso di lavoro, tuttavia, è che non consente di modificare dinamicamente i dati di input del nuovo modello durante l'addestramento, ad esempio quando si esegue l'aumento dei dati. L'apprendimento del trasferimento viene in genere utilizzato per le attività quando il nuovo set di dati ha troppo pochi dati per addestrare da zero un modello completo e in tali scenari l'aumento dei dati è molto importante. Quindi in quanto segue, ci concentreremo sul primo flusso di lavoro.

Ecco come appare il primo flusso di lavoro in Keras:

Innanzitutto, crea un'istanza di un modello base con pesi pre-addestrati.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Quindi, congelare il modello base.

base_model.trainable = False

Crea un nuovo modello in cima.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Addestra il modello su nuovi dati.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Ritocchi

Una volta che il tuo modello è convergente sui nuovi dati, puoi provare a sbloccare tutto o parte del modello di base e riqualificare l'intero modello end-to-end con un tasso di apprendimento molto basso.

Questo è un ultimo passaggio facoltativo che può potenzialmente darti miglioramenti incrementali. Potrebbe anche portare a un rapido sovradattamento, tienilo a mente.

E 'fondamentale per fare solo questo passo dopo il modello con layer congelati è stato addestrato per la convergenza. Se mescoli livelli addestrabili inizializzati in modo casuale con livelli addestrabili che contengono funzionalità pre-addestrate, i livelli inizializzati casualmente causeranno aggiornamenti del gradiente molto grandi durante l'allenamento, che distruggeranno le funzionalità pre-addestrate.

È inoltre fondamentale utilizzare una frequenza di apprendimento molto bassa in questa fase, poiché si esegue il training di un modello molto più ampio rispetto al primo round di training, su un set di dati che in genere è molto piccolo. Di conseguenza, si corre il rischio di un sovraadattamento molto rapidamente se si applicano aggiornamenti di peso di grandi dimensioni. Qui si desidera solo riadattare i pesi pre-addestrati in modo incrementale.

Ecco come implementare la messa a punto dell'intero modello base:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Nota importante sulla compile() e trainable

La chiamata compile() su un modello ha lo scopo di "congelare" il comportamento di quel modello. Ciò implica che i trainable valori degli attributi al momento il modello è compilato dovrebbero essere conservati per tutta la durata di quel modello, fino a quando compile viene chiamato di nuovo. Quindi, se si modifica qualsiasi trainable valore, assicurarsi di chiamata compile() di nuovo sul modello per le modifiche da prendere in considerazione.

Note importanti su BatchNormalization strato

Molti modelli di immagine contengono BatchNormalization strati. Quel livello è un caso speciale su ogni conteggio immaginabile. Ecco alcune cose da tenere a mente.

  • BatchNormalization contiene 2 pesi non addestrabili che vengono aggiornati durante l'allenamento. Queste sono le variabili che tracciano la media e la varianza degli input.
  • Quando si imposta bn_layer.trainable = False , l' BatchNormalization strato funzionerà in modalità deduzione, e non aggiornerà le sue statistiche medie e varianza. Questo non è il caso di altri strati in generale, come peso trainability & inferenza / modalità di formazione sono due concetti ortogonali . Ma i due sono legati nel caso del BatchNormalization strato.
  • Quando si sblocca un modello che contiene BatchNormalization strati al fine di fare la messa a punto, si dovrebbe tenere i BatchNormalization strati in modalità inferenza passando training=False quando si chiama il modello base. In caso contrario, gli aggiornamenti applicati ai pesi non addestrabili distruggono improvvisamente ciò che il modello ha appreso.

Vedrai questo modello in azione nell'esempio end-to-end alla fine di questa guida.

Trasferisci l'apprendimento e la messa a punto con un ciclo di allenamento personalizzato

Se invece di fit() , che si sta utilizzando il proprio ciclo di formazione di basso livello, i soggiorni del flusso di lavoro essenzialmente la stessa. Si dovrebbe fare attenzione a prendere solo in considerazione l'elenco model.trainable_weights quando si applica gli aggiornamenti di gradiente:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Allo stesso modo per la messa a punto.

Un esempio end-to-end: messa a punto di un modello di classificazione delle immagini su un set di dati gatti vs cani

Per consolidare questi concetti, ti guidiamo attraverso un concreto esempio di apprendimento e messa a punto del trasferimento end-to-end. Caricheremo il modello Xception, pre-addestrato su ImageNet, e lo utilizzeremo nel set di dati di classificazione "gatti contro cani" di Kaggle.

Ottenere i dati

Per prima cosa, recuperiamo il set di dati gatti contro cani utilizzando TFDS. Se avete il vostro set di dati, probabilmente si vorrà utilizzare l'utilità tf.keras.preprocessing.image_dataset_from_directory per generare oggetti simili dataset etichettato da un set di immagini su disco archiviati in cartelle specifiche della classe.

Il trasferimento di apprendimento è molto utile quando si lavora con set di dati molto piccoli. Per mantenere piccolo il nostro set di dati, utilizzeremo il 40% dei dati di addestramento originali (25.000 immagini) per l'addestramento, il 10% per la convalida e il 10% per i test.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Queste sono le prime 9 immagini nel set di dati di addestramento: come puoi vedere, sono tutte di dimensioni diverse.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Possiamo anche vedere che l'etichetta 1 è "cane" e l'etichetta 0 è "gatto".

Standardizzazione dei dati

Le nostre immagini grezze hanno una varietà di dimensioni. Inoltre, ogni pixel è composto da 3 valori interi compresi tra 0 e 255 (valori di livello RGB). Non è una buona idea per alimentare una rete neurale. Dobbiamo fare 2 cose:

  • Standardizza a una dimensione dell'immagine fissa. Scegliamo 150x150.
  • Valori di pixel normalizza tra -1 e 1. Lo faremo utilizzando una Normalization strato come parte del modello stesso.

In generale, è una buona pratica sviluppare modelli che prendano dati grezzi come input, al contrario di modelli che prendano dati già preelaborati. Il motivo è che, se il tuo modello prevede dati preelaborati, ogni volta che esporti il ​​tuo modello per utilizzarlo altrove (in un browser Web, in un'app mobile), dovrai reimplementare la stessa identica pipeline di preelaborazione. Questo diventa molto complicato molto rapidamente. Quindi dovremmo fare il minor numero possibile di pre-elaborazione prima di colpire il modello.

Qui, eseguiremo il ridimensionamento dell'immagine nella pipeline di dati (perché una rete neurale profonda può elaborare solo batch di dati contigui) ed eseguiremo il ridimensionamento del valore di input come parte del modello, quando lo creiamo.

Ridimensioniamo le immagini a 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Inoltre, eseguiamo in batch i dati e utilizziamo la memorizzazione nella cache e il prefetching per ottimizzare la velocità di caricamento.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Utilizzo dell'aumento casuale dei dati

Quando non si dispone di un set di dati di immagini di grandi dimensioni, è consigliabile introdurre artificialmente la diversità del campione applicando trasformazioni casuali ma realistiche alle immagini di addestramento, ad esempio capovolgimenti orizzontali casuali o piccole rotazioni casuali. Ciò consente di esporre il modello a diversi aspetti dei dati di addestramento rallentando l'overfitting.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Visualizziamo come appare la prima immagine del primo batch dopo varie trasformazioni casuali:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Costruisci un modello

Ora costruiamo un modello che segua il progetto che abbiamo spiegato in precedenza.

Notare che:

  • Si aggiunge una Rescaling strato di valori di ingresso di scala (inizialmente nella [0, 255] intervallo) al [-1, 1] gamma.
  • Aggiungiamo una Dropout livello prima che lo strato di classificazione, per la regolarizzazione.
  • Facciamo in modo di passare training=False quando si chiama il modello di base, in modo che venga eseguito in modalità inferenza, in modo che le statistiche batchnorm non vengono aggiornati anche dopo che abbiamo sbloccare il modello base per la messa a punto.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Allena lo strato superiore

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Fai un giro di messa a punto dell'intero modello

Infine, sblocchiamo il modello base e addestriamo l'intero modello end-to-end con un tasso di apprendimento basso.

È importante sottolineare che, anche se il modello di base diventa addestrabile, è ancora in esecuzione in modalità di inferenza da quando abbiamo superato training=False al momento della chiamata quando abbiamo costruito il modello. Ciò significa che i livelli di normalizzazione batch all'interno non aggiorneranno le loro statistiche batch. Se lo facessero, rovinerebbero le rappresentazioni apprese finora dal modello.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Dopo 10 epoche, la messa a punto ci porta un bel miglioramento qui.