La journée communautaire ML est le 9 novembre ! Rejoignez - nous pour les mises à jour de tensorflow, JAX et plus En savoir plus

Apprentissage par transfert et mise au point

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Dans ce didacticiel, vous apprendrez à classer des images de chats et de chiens en utilisant l'apprentissage par transfert à partir d'un réseau pré-entraîné.

Un modèle pré-entraîné est un réseau enregistré qui a été préalablement formé sur un grand ensemble de données, généralement sur une tâche de classification d'images à grande échelle. Vous utilisez soit le modèle pré-entraîné tel quel, soit l'apprentissage par transfert pour personnaliser ce modèle pour une tâche donnée.

L'intuition derrière l'apprentissage par transfert pour la classification d'images est que si un modèle est formé sur un ensemble de données suffisamment grand et général, ce modèle servira efficacement de modèle générique du monde visuel. Vous pouvez ensuite tirer parti de ces cartes de caractéristiques apprises sans avoir à repartir de zéro en entraînant un grand modèle sur un grand jeu de données.

Dans ce cahier, vous allez essayer deux manières de personnaliser un modèle pré-entraîné :

  1. Extraction de caractéristiques : utilisez les représentations apprises par un réseau précédent pour extraire des caractéristiques significatives à partir de nouveaux échantillons. Vous ajoutez simplement un nouveau classificateur, qui sera entraîné à partir de zéro, au-dessus du modèle pré-entraîné afin que vous puissiez réutiliser les cartes de caractéristiques apprises précédemment pour le jeu de données.

    Vous n'avez pas besoin de (re)former le modèle entier. Le réseau convolutif de base contient déjà des fonctionnalités génériquement utiles pour classer les images. Cependant, la partie finale de classification du modèle pré-entraîné est spécifique à la tâche de classification d'origine, et par la suite spécifique à l'ensemble de classes sur lesquelles le modèle a été entraîné.

  2. Ajustement : Dégelez quelques-unes des couches supérieures d'une base de modèle gelée et entraînez conjointement à la fois les couches de classificateur nouvellement ajoutées et les dernières couches du modèle de base. Cela nous permet d'« affiner » les représentations des caractéristiques d'ordre supérieur dans le modèle de base afin de les rendre plus pertinentes pour la tâche spécifique.

Vous suivrez le workflow général de machine learning.

  1. Examiner et comprendre les données
  2. Construire un pipeline d'entrée, dans ce cas en utilisant Keras ImageDataGenerator
  3. Composer le modèle
    • Charge dans le modèle de base pré-entraîné (et les poids pré-entraînés)
    • Empilez les couches de classification sur le dessus
  4. Former le modèle
  5. Évaluer le modèle
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

Prétraitement des données

Téléchargement de données

Dans ce tutoriel, vous allez utiliser un jeu de données contenant plusieurs milliers d'images de chats et de chiens. Télécharger et extraire un fichier zip contenant les images, puis créer une tf.data.Dataset de formation et de validation à l' aide du tf.keras.preprocessing.image_dataset_from_directory utilitaire. Vous pouvez en savoir plus sur le chargement des images dans ce tutoriel .

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 1s 0us/step
68616192/68606236 [==============================] - 1s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

Affichez les neuf premières images et étiquettes de l'ensemble d'entraînement :

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

png

Comme l'ensemble de données d'origine ne contient pas d'ensemble de test, vous en créerez un. Pour ce faire, de déterminer le nombre de lots de données sont disponibles dans l'ensemble de validation à l' aide tf.data.experimental.cardinality , puis déplacez 20% d'entre eux à un ensemble de test.

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

Configurer l'ensemble de données pour les performances

Utilisez la prélecture tamponnée pour charger des images à partir du disque sans que les E/S ne deviennent bloquantes. Pour en savoir plus sur cette méthode voir la performance des données guide.

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

Utiliser l'augmentation des données

Lorsque vous ne disposez pas d'un grand ensemble de données d'images, il est recommandé d'introduire artificiellement une diversité d'échantillons en appliquant des transformations aléatoires, mais réalistes, aux images d'entraînement, telles que la rotation et le retournement horizontal. Cela permet d' exposer le modèle à différents aspects des données de formation et de réduire overfitting . Vous pouvez en apprendre davantage sur l' augmentation des données dans ce tutoriel .

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Appliquons à plusieurs reprises ces calques sur la même image et voyons le résultat.

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

png

Redimensionner les valeurs des pixels

Dans un moment, vous téléchargerez tf.keras.applications.MobileNetV2 pour être utilisé comme modèle de base. Ce modèle prévoit des valeurs de pixel dans [-1, 1] , mais à ce stade, les valeurs de pixels de vos images sont dans [0, 255] . Pour les redimensionner, utilisez la méthode de prétraitement incluse avec le modèle.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

Créer le modèle de base à partir des convnets pré-entraînés

Vous allez créer le modèle de base du modèle MobileNet V2 développé à Google. Ceci est pré-entraîné sur l'ensemble de données ImageNet, un grand ensemble de données composé de 1,4 million d'images et de 1 000 classes. IMAGEnet est un ensemble de données de formation de recherche avec une grande variété de catégories comme jackfruit et syringe . Cette base de connaissances nous aidera à classer les chats et les chiens à partir de notre ensemble de données spécifique.

Tout d'abord, vous devez choisir la couche de MobileNet V2 que vous utiliserez pour l'extraction d'entités. La toute dernière couche de classification (en haut, car la plupart des diagrammes de modèles d'apprentissage automatique vont de bas en haut) n'est pas très utile. Au lieu de cela, vous suivrez la pratique courante de dépendre de la toute dernière couche avant l'opération d'aplatissement. Cette couche est appelée « couche de goulot d'étranglement ». Les caractéristiques de la couche de goulot d'étranglement conservent une plus grande généralité par rapport à la couche finale/supérieure.

Tout d'abord, instanciez un modèle MobileNet V2 préchargé avec des poids entraînés sur ImageNet. En spécifiant l'include_top = faux argument, vous chargez un réseau qui ne comprend pas les couches de classification au sommet, ce qui est idéal pour l' extraction de caractéristiques.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

Cet extracteur de fonctionnalité convertit chaque 160x160x3 image en un 5x5x1280 bloc de caractéristiques. Voyons ce que cela fait d'un exemple de lot d'images :

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

Extraction de caractéristiques

Dans cette étape, vous allez geler la base convolutive créée à partir de l'étape précédente et l'utiliser comme extracteur de caractéristiques. De plus, vous ajoutez un classificateur par-dessus et entraînez le classificateur de niveau supérieur.

Geler la base convolutive

Il est important de geler la base convolutive avant de compiler et d'entraîner le modèle. Le gel (en définissant layer.trainable = False) empêche les poids dans une couche donnée d'être mis à jour pendant l'entraînement. MobileNet V2 a plusieurs couches, donc définir de l' ensemble du modèle trainable drapeau False va geler tous.

base_model.trainable = False

Remarque importante sur les couches de normalisation par lots

De nombreux modèles contiennent tf.keras.layers.BatchNormalization couches. Cette couche est un cas particulier et des précautions doivent être prises dans le cadre d'un réglage fin, comme indiqué plus loin dans ce tutoriel.

Lorsque vous définissez layer.trainable = False , la BatchNormalization couche fonctionnera en mode d'inférence, et pas mettre à jour ses statistiques moyenne et la variance.

Lorsque vous dégeler un modèle qui contient des couches BatchNormalization afin de faire le réglage fin, vous devez garder les couches de BatchNormalization en mode d'inférence en passant la training = False lorsque vous appelez le modèle de base. Sinon, les mises à jour appliquées aux poids non entraînables détruiront ce que le modèle a appris.

Pour plus de détails, consultez le guide d'apprentissage de transfert .

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

Ajouter une tête de classement

Pour générer des prédictions à partir du bloc de caractéristiques, en moyenne au cours des spatiales 5x5 emplacements dans l' espace, en utilisant un tf.keras.layers.GlobalAveragePooling2D couche pour convertir les caractéristiques d'un seul vecteur 1280-élément par image.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

Appliquer une tf.keras.layers.Dense couche pour convertir ces caractéristiques en une seule prédiction par image. Vous n'avez pas besoin d' une fonction d'activation ici parce que cette prédiction sera traitée comme un logit , ou une valeur de prédiction brute. Les nombres positifs prédisent la classe 1, les nombres négatifs prédisent la classe 0.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

Construire un modèle en enchaînant l'augmentation des données, redimensionnant, base_model et fonctionnalité des couches en utilisant l'extracteur Keras API fonctionnelle . Comme mentionné précédemment, l' utilisation de training=False comme notre modèle contient une BatchNormalization couche.

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

Compiler le modèle

Compilez le modèle avant de l'entraîner. Comme il y a deux classes, utilisez la tf.keras.losses.BinaryCrossentropy perte avec from_logits=True puisque le modèle fournit une sortie linéaire.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

Les paramètres de 2,5M dans MobileNet sont congelés, mais il y a des paramètres 1.2K trainable dans la couche dense. Ceux - ci sont divisés entre deux tf.Variable objets, les poids et les biais.

len(model.trainable_variables)
2

Former le modèle

Après une formation sur 10 époques, vous devriez voir une précision d'environ 94 % sur l'ensemble de validation.

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 2s 18ms/step - loss: 0.8367 - accuracy: 0.4851
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.84
initial accuracy: 0.49
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 4s 26ms/step - loss: 0.7017 - accuracy: 0.5825 - val_loss: 0.5713 - val_accuracy: 0.6968
Epoch 2/10
63/63 [==============================] - 2s 24ms/step - loss: 0.5313 - accuracy: 0.7140 - val_loss: 0.4127 - val_accuracy: 0.8057
Epoch 3/10
63/63 [==============================] - 2s 24ms/step - loss: 0.4209 - accuracy: 0.7980 - val_loss: 0.3240 - val_accuracy: 0.8478
Epoch 4/10
63/63 [==============================] - 2s 24ms/step - loss: 0.3611 - accuracy: 0.8210 - val_loss: 0.2657 - val_accuracy: 0.8874
Epoch 5/10
63/63 [==============================] - 2s 23ms/step - loss: 0.3088 - accuracy: 0.8605 - val_loss: 0.2233 - val_accuracy: 0.9171
Epoch 6/10
63/63 [==============================] - 2s 23ms/step - loss: 0.2744 - accuracy: 0.8835 - val_loss: 0.1952 - val_accuracy: 0.9307
Epoch 7/10
63/63 [==============================] - 2s 23ms/step - loss: 0.2481 - accuracy: 0.8930 - val_loss: 0.1746 - val_accuracy: 0.9443
Epoch 8/10
63/63 [==============================] - 1s 23ms/step - loss: 0.2341 - accuracy: 0.9050 - val_loss: 0.1580 - val_accuracy: 0.9468
Epoch 9/10
63/63 [==============================] - 2s 23ms/step - loss: 0.2212 - accuracy: 0.9070 - val_loss: 0.1490 - val_accuracy: 0.9493
Epoch 10/10
63/63 [==============================] - 2s 23ms/step - loss: 0.1947 - accuracy: 0.9205 - val_loss: 0.1347 - val_accuracy: 0.9530

Courbes d'apprentissage

Examinons les courbes d'apprentissage de la précision/perte d'entraînement et de validation lors de l'utilisation du modèle de base MobileNet V2 comme extracteur de caractéristiques fixes.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

Dans une moindre mesure, c'est aussi parce que les métriques d'entraînement rapportent la moyenne pour une époque, tandis que les métriques de validation sont évaluées après l'époque, donc les métriques de validation voient un modèle qui s'est entraîné un peu plus longtemps.

Réglage fin

Dans l'expérience d'extraction de caractéristiques, vous n'étiez que quelques couches au-dessus d'un modèle de base MobileNet V2. Les poids du réseau pré-formés ont été pas mis à jour lors de la formation.

Une façon d'augmenter encore les performances consiste à entraîner (ou "affiner") les poids des couches supérieures du modèle pré-entraîné parallèlement à l'entraînement du classificateur que vous avez ajouté. Le processus de formation forcera les poids à être ajustés à partir de cartes de caractéristiques génériques vers des caractéristiques associées spécifiquement à l'ensemble de données.

En outre, vous devriez essayer d'affiner un petit nombre de couches supérieures plutôt que l'ensemble du modèle MobileNet. Dans la plupart des réseaux convolutifs, plus une couche est élevée, plus elle est spécialisée. Les premières couches apprennent des fonctionnalités très simples et génériques qui se généralisent à presque tous les types d'images. Au fur et à mesure que vous montez, les fonctionnalités sont de plus en plus spécifiques au jeu de données sur lequel le modèle a été formé. L'objectif du réglage fin est d'adapter ces fonctionnalités spécialisées pour qu'elles fonctionnent avec le nouvel ensemble de données, plutôt que d'écraser l'apprentissage générique.

Dégeler les couches supérieures du modèle

Tout ce que vous devez faire est de dégeler le base_model et définir les couches inférieures pour être non trainable. Ensuite, vous devez recompiler le modèle (nécessaire pour que ces modifications prennent effet) et reprendre l'entraînement.

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  154

Compiler le modèle

Comme vous entraînez un modèle beaucoup plus grand et que vous souhaitez réadapter les poids pré-entraînés, il est important d'utiliser un taux d'apprentissage inférieur à ce stade. Sinon, votre modèle pourrait surdimensionner très rapidement.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

Continuer la formation du modèle

Si vous vous êtes entraîné à la convergence plus tôt, cette étape améliorera votre précision de quelques points de pourcentage.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 8s 45ms/step - loss: 0.1384 - accuracy: 0.9440 - val_loss: 0.0653 - val_accuracy: 0.9839
Epoch 11/20
63/63 [==============================] - 2s 32ms/step - loss: 0.1087 - accuracy: 0.9520 - val_loss: 0.0447 - val_accuracy: 0.9827
Epoch 12/20
63/63 [==============================] - 2s 32ms/step - loss: 0.1177 - accuracy: 0.9510 - val_loss: 0.0453 - val_accuracy: 0.9839
Epoch 13/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0955 - accuracy: 0.9620 - val_loss: 0.0547 - val_accuracy: 0.9802
Epoch 14/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0838 - accuracy: 0.9700 - val_loss: 0.0326 - val_accuracy: 0.9864
Epoch 15/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0828 - accuracy: 0.9655 - val_loss: 0.0403 - val_accuracy: 0.9827
Epoch 16/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0807 - accuracy: 0.9670 - val_loss: 0.0487 - val_accuracy: 0.9790
Epoch 17/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0696 - accuracy: 0.9730 - val_loss: 0.0406 - val_accuracy: 0.9814
Epoch 18/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0616 - accuracy: 0.9735 - val_loss: 0.0387 - val_accuracy: 0.9851
Epoch 19/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0630 - accuracy: 0.9725 - val_loss: 0.0342 - val_accuracy: 0.9839
Epoch 20/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0599 - accuracy: 0.9750 - val_loss: 0.0355 - val_accuracy: 0.9851

Examinons les courbes d'apprentissage de la précision/perte de l'entraînement et de la validation lors du réglage fin des dernières couches du modèle de base MobileNet V2 et de l'entraînement du classificateur par-dessus. La perte de validation est beaucoup plus élevée que la perte d'entraînement, vous pouvez donc avoir un surapprentissage.

Vous pouvez également obtenir un surapprentissage car le nouvel ensemble d'apprentissage est relativement petit et similaire aux ensembles de données MobileNet V2 d'origine.

Après un réglage fin, le modèle atteint près de 98% de précision sur l'ensemble de validation.

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

Évaluation et prédiction

Enfin, vous pouvez vérifier les performances du modèle sur de nouvelles données à l'aide d'un ensemble de tests.

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 16ms/step - loss: 0.0431 - accuracy: 0.9740
Test accuracy : 0.9739583134651184

Et maintenant, vous êtes prêt à utiliser ce modèle pour prédire si votre animal de compagnie est un chat ou un chien.

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 1 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 0 1 0 1 0]
Labels:
 [0 1 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 0 1 0 1 0]

png

Sommaire

  • En utilisant un modèle de pré-formation pour l' extraction de caractéristiques: Lorsque vous travaillez avec un petit jeu de données, il est une pratique courante de tirer parti des fonctionnalités acquises par un modèle formé sur un ensemble de données plus vaste dans le même domaine. Cela se fait en instanciant le modèle pré-entraîné et en ajoutant un classificateur entièrement connecté au-dessus. Le modèle pré-entraîné est « figé » et seuls les poids du classificateur sont mis à jour pendant l'entraînement. Dans ce cas, la base convolutive a extrait toutes les caractéristiques associées à chaque image et vous venez d'entraîner un classificateur qui détermine la classe d'images en fonction de cet ensemble de caractéristiques extraites.

  • Affinant un modèle de pré-formation: Pour améliorer encore les performances, on peut vouloir reformater les couches de niveau supérieur des modèles pré-formés au nouvel ensemble de données via réglage fin. Dans ce cas, vous avez ajusté vos pondérations de sorte que votre modèle apprenne des caractéristiques de haut niveau spécifiques à l'ensemble de données. Cette technique est généralement recommandée lorsque l'ensemble de données d'apprentissage est volumineux et très similaire à l'ensemble de données d'origine sur lequel le modèle pré-entraîné a été entraîné.

Pour en savoir plus, visitez le guide d'apprentissage de transfert .

# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.