Apprentissage par transfert et mise au point

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Installer

import numpy as np
import tensorflow as tf
from tensorflow import keras

introduction

L' apprentissage de transfert consiste à prendre sur un éléments appris problème, et les appuyant sur un nouveau problème similaire. Par exemple, les caractéristiques d'un modèle qui a appris à identifier les ratons laveurs peuvent être utiles pour lancer un modèle destiné à identifier les tanukis.

L'apprentissage par transfert est généralement effectué pour les tâches pour lesquelles votre ensemble de données contient trop peu de données pour former un modèle à grande échelle à partir de zéro.

L'incarnation la plus courante de l'apprentissage par transfert dans le contexte de l'apprentissage en profondeur est le workflow suivant :

  1. Prenez des calques à partir d'un modèle préalablement formé.
  2. Congelez-les, afin d'éviter de détruire les informations qu'ils contiennent lors des futurs tours d'entraînement.
  3. Ajoutez de nouvelles couches pouvant être entraînées au-dessus des couches gelées. Ils apprendront à transformer les anciennes caractéristiques en prédictions sur un nouvel ensemble de données.
  4. Entraînez les nouvelles couches sur votre jeu de données.

Une dernière, étape facultative, est bien mise au point, qui consiste à dégeler l'ensemble du modèle que vous avez obtenu ci - dessus (ou une partie de celui - ci), et re-formation sur les nouvelles données avec un taux d'apprentissage très faible. Cela peut potentiellement apporter des améliorations significatives, en adaptant progressivement les fonctionnalités pré-entraînées aux nouvelles données.

Tout d' abord, nous irons sur la Keras trainable API en détail, ce qui sous - tend la plupart apprentissage de transfert et des flux de travail de réglage fin.

Ensuite, nous démontrerons le flux de travail typique en prenant un modèle pré-entraîné sur le jeu de données ImageNet et en le réformant sur le jeu de données de classification Kaggle « cats vs dogs ».

Ceci est adapté de profond apprentissage avec Python et 2016 blog « la construction de modèles de classification des images puissantes en utilisant très peu de données » .

Couches de congélation: comprendre le trainable attribut

Les calques et les modèles ont trois attributs de poids :

  • weights est la liste de toutes les variables de poids de la couche.
  • trainable_weights est la liste de ceux qui sont destinés à être mis à jour (via une descente de gradient) afin de minimiser la perte lors de la formation.
  • non_trainable_weights est la liste de ceux qui ne sont pas destinés à être formés. Généralement, ils sont mis à jour par le modèle lors de la passe avant.

Exemple: la Dense couche a 2 poids entraînables (kernel et polarisation)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

En général, tous les poids sont des poids pouvant être entraînés. La seule intégrée dans la couche qui a des poids non entraînables est le BatchNormalization couche. Il utilise des poids non entraînables pour suivre la moyenne et la variance de ses entrées pendant l'entraînement. Pour apprendre à utiliser des poids non trainable dans vos propres couches personnalisées, consultez le guide pour l' écriture de nouvelles couches à partir de zéro .

Exemple: la BatchNormalization couche a 2 poids et 2 entraînables poids non entraînables

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Les couches et les modèles disposent également d' un attribut booléen trainable . Sa valeur peut être modifiée. Réglage layer.trainable de False mouvements , tous les poids de la couche de trainable à la non-trainable. On appelle cela le « gel » de la couche: l'état d'une couche gelée ne sera pas mis à jour lors de la formation (soit lors de la formation avec fit() ou lors de l' entraînement avec une boucle personnalisée qui repose sur trainable_weights pour appliquer les mises à jour de gradient).

Exemple: réglage de trainable à False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Lorsqu'un poids entraînable devient non entraînable, sa valeur n'est plus mise à jour pendant l'entraînement.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

Ne confondez pas l' layer.trainable attribut avec l'argument de training en layer.__call__() (qui contrôle si la couche doit exécuter son passage vers l' avant en mode d'inférence ou le mode de formation). Pour plus d' informations, consultez le Keras FAQ .

Réglage récursive de l' trainable attribut

Si vous définissez trainable = False sur un modèle ou sur une couche qui a des sous - couches, tous les enfants deviennent des couches non trainable ainsi.

Exemple:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Le flux de travail typique de l'apprentissage par transfert

Cela nous amène à la manière dont un workflow d'apprentissage par transfert typique peut être mis en œuvre dans Keras :

  1. Instanciez un modèle de base et chargez-y des poids pré-entraînés.
  2. Congeler toutes les couches du modèle de base en définissant trainable = False .
  3. Créez un nouveau modèle au-dessus de la sortie d'une (ou plusieurs) couches du modèle de base.
  4. Entraînez votre nouveau modèle sur votre nouvel ensemble de données.

Notez qu'un workflow alternatif plus léger pourrait également être :

  1. Instanciez un modèle de base et chargez-y des poids pré-entraînés.
  2. Exécutez votre nouveau jeu de données et enregistrez la sortie d'une (ou plusieurs) couches du modèle de base. Ceci est appelé extraction de caractéristiques.
  3. Utilisez cette sortie comme données d'entrée pour un nouveau modèle plus petit.

Un avantage clé de ce deuxième workflow est que vous n'exécutez le modèle de base qu'une seule fois sur vos données, plutôt qu'une fois par période d'entraînement. C'est donc beaucoup plus rapide et moins cher.

Un problème avec ce deuxième workflow, cependant, est qu'il ne vous permet pas de modifier dynamiquement les données d'entrée de votre nouveau modèle pendant la formation, ce qui est nécessaire lors de l'augmentation des données, par exemple. L'apprentissage par transfert est généralement utilisé pour les tâches lorsque votre nouvel ensemble de données contient trop peu de données pour former un modèle à grande échelle à partir de zéro, et dans de tels scénarios, l'augmentation des données est très importante. Dans ce qui suit, nous allons donc nous concentrer sur le premier workflow.

Voici à quoi ressemble le premier workflow dans Keras :

Tout d'abord, instanciez un modèle de base avec des poids pré-entraînés.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Ensuite, gelez le modèle de base.

base_model.trainable = False

Créez un nouveau modèle sur le dessus.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Entraînez le modèle sur de nouvelles données.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Réglage fin

Une fois que votre modèle a convergé sur les nouvelles données, vous pouvez essayer de dégeler tout ou partie du modèle de base et recycler l'ensemble du modèle de bout en bout avec un taux d'apprentissage très faible.

Il s'agit d'une dernière étape facultative qui peut potentiellement vous apporter des améliorations incrémentielles. Cela pourrait également potentiellement conduire à un surapprentissage rapide - gardez cela à l'esprit.

Il est essentiel de faire que cette étape après le modèle avec des couches congelées a été formé à la convergence. Si vous mélangez des couches d'apprentissage initialisées de manière aléatoire avec des couches d'apprentissage contenant des entités pré-entraînées, les couches initialisées de manière aléatoire entraîneront des mises à jour de gradient très importantes pendant l'entraînement, ce qui détruira vos entités pré-entraînées.

Il est également essentiel d'utiliser un taux d'apprentissage très faible à ce stade, car vous entraînez un modèle beaucoup plus grand que lors du premier cycle d'entraînement, sur un ensemble de données généralement très petit. Par conséquent, vous risquez un surapprentissage très rapidement si vous appliquez des mises à jour de poids importantes. Ici, vous souhaitez uniquement réadapter les poids pré-entraînés de manière incrémentielle.

Voici comment implémenter le réglage fin de l'ensemble du modèle de base :

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Remarque importante sur la compile() et trainable

L' appel de la compile() sur un modèle est destiné à « geler » le comportement de ce modèle. Cela implique que les trainable valeurs d'attribut au moment où le modèle est compilé doivent être conservés tout au long de la durée de vie de ce modèle, jusqu'à ce que la compile est appelé à nouveau. Ainsi, Si vous modifiez trainable valeur, veillez à l' appel compile() nouveau votre modèle vos modifications à tenir compte.

Remarques importantes sur BatchNormalization couche

De nombreux modèles d'image contiennent BatchNormalization couches. Cette couche est un cas particulier à tous les égards imaginables. Voici quelques éléments à garder à l'esprit.

  • BatchNormalization contient 2 poids non trainable qui mis à jour pendant la formation. Ce sont les variables qui suivent la moyenne et la variance des entrées.
  • Lorsque vous définissez bn_layer.trainable = False , la BatchNormalization couche fonctionnera en mode d'inférence, et ne sera pas mettre à jour ses statistiques moyennes et la variance. Ce n'est pas le cas pour d' autres couches en général, comme le poids et l' éducabilité inférence / modes de formation sont deux concepts orthogonaux . Mais les deux sont liés dans le cas de la BatchNormalization couche.
  • Lorsque vous dégeler un modèle qui contient BatchNormalization couches afin de faire le réglage fin, vous devez garder les BatchNormalization couches en mode d'inférence en passant la training=False lorsque vous appelez le modèle de base. Sinon, les mises à jour appliquées aux poids non entraînables détruiront soudainement ce que le modèle a appris.

Vous verrez ce modèle en action dans l'exemple de bout en bout à la fin de ce guide.

Transférer l'apprentissage et affiner avec une boucle de formation personnalisée

Si au lieu d' fit() , vous utilisez votre propre boucle de formation de bas niveau, les séjours de flux de travail essentiellement les mêmes. Vous devez être prudent de ne prendre en compte la liste model.trainable_weights lors de l' application des mises à jour de gradient:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

De même pour le réglage fin.

Un exemple de bout en bout : affiner un modèle de classification d'images sur un ensemble de données chats vs chiens

Pour consolider ces concepts, passons en revue un exemple concret d'apprentissage et de mise au point par transfert de bout en bout. Nous allons charger le modèle Xception, pré-entraîné sur ImageNet, et l'utiliser sur le jeu de données de classification Kaggle "cats vs. dogs".

Obtenir les données

Tout d'abord, récupérons l'ensemble de données chats contre chiens à l'aide de TFDS. Si vous avez votre propre ensemble de données, vous voudrez probablement utiliser l'utilitaire tf.keras.preprocessing.image_dataset_from_directory pour générer ensemble de données similaires étiquetés objets à partir d' une série d'images sur le disque déposés dans des dossiers spécifiques à chaque classe.

L'apprentissage par transfert est particulièrement utile lorsque vous travaillez avec de très petits ensembles de données. Pour garder notre ensemble de données petit, nous utiliserons 40 % des données d'entraînement d'origine (25 000 images) pour l'entraînement, 10 % pour la validation et 10 % pour les tests.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Ce sont les 9 premières images de l'ensemble de données d'entraînement. Comme vous pouvez le voir, elles sont toutes de tailles différentes.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Nous pouvons également voir que l'étiquette 1 est "chien" et l'étiquette 0 est "chat".

Standardiser les données

Nos images brutes ont une variété de tailles. De plus, chaque pixel est constitué de 3 valeurs entières comprises entre 0 et 255 (valeurs de niveau RVB). Ce n'est pas une bonne solution pour alimenter un réseau de neurones. Nous devons faire 2 choses :

  • Standardiser à une taille d'image fixe. Nous choisissons 150x150.
  • Valeurs de pixels Normaliser entre -1 et 1. Nous allons le faire en utilisant une Normalization couche dans le cadre du modèle lui - même.

En général, il est recommandé de développer des modèles qui prennent des données brutes en entrée, par opposition aux modèles qui prennent des données déjà prétraitées. La raison en est que, si votre modèle attend des données prétraitées, chaque fois que vous exportez votre modèle pour l'utiliser ailleurs (dans un navigateur Web, dans une application mobile), vous devrez réimplémenter exactement le même pipeline de prétraitement. Cela devient très délicat très rapidement. Nous devons donc effectuer le moins de prétraitements possible avant de toucher le modèle.

Ici, nous allons redimensionner l'image dans le pipeline de données (car un réseau de neurones profonds ne peut traiter que des lots de données contigus) et nous allons faire la mise à l'échelle de la valeur d'entrée dans le cadre du modèle, lorsque nous le créons.

Redimensionnons les images à 150x150 :

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

En outre, regroupons les données et utilisons la mise en cache et la prélecture pour optimiser la vitesse de chargement.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Utilisation de l'augmentation de données aléatoire

Lorsque vous ne disposez pas d'un grand ensemble de données d'images, il est recommandé d'introduire artificiellement une diversité d'échantillons en appliquant des transformations aléatoires mais réalistes aux images d'apprentissage, telles que le retournement horizontal aléatoire ou de petites rotations aléatoires. Cela permet d'exposer le modèle à différents aspects des données d'entraînement tout en ralentissant le surapprentissage.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Visualisons à quoi ressemble la première image du premier lot après diverses transformations aléatoires :

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Construire un modèle

Maintenant, construisons un modèle qui suit le plan que nous avons expliqué plus tôt.

Noter que:

  • Nous ajoutons un Rescaling couche de valeurs d'entrée à grande échelle ( à l' origine dans le [0, 255] gamme) à la [-1, 1] gamme.
  • Nous ajoutons une Dropout couche avant que la couche de classification, de régularisation.
  • Nous nous assurons de passer la training=False lorsque vous appelez le modèle de base, pour qu'il fonctionne en mode d'inférence, de sorte que les statistiques ne batchnorm pas mis à jour même après que nous dégeler le modèle de base pour réglage fin.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Former la couche supérieure

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Faire un tour de réglage fin de l'ensemble du modèle

Enfin, libérons le modèle de base et formons l'ensemble du modèle de bout en bout avec un faible taux d'apprentissage.

Il est important, bien que le modèle de base devient facile à former, il est toujours en cours d' exécution en mode d'inférence puisque nous avons passé la training=False lorsque vous appelez quand nous avons construit le modèle. Cela signifie que les couches de normalisation de lot à l'intérieur ne mettront pas à jour leurs statistiques de lot. S'ils le faisaient, ils feraient des ravages sur les représentations apprises par le modèle jusqu'à présent.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Après 10 époques, le réglage fin nous apporte ici une belle amélioration.