Eine Frage haben? Verbinden Sie sich mit der Community im TensorFlow Forum Visit Forum

Transferlernen und Feinabstimmung

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

In diesem Tutorial erfahren Sie, wie Sie Bilder von Katzen und Hunden mithilfe von Transferlernen aus einem vortrainierten Netzwerk klassifizieren.

Ein vortrainiertes Modell ist ein gespeichertes Netzwerk, das zuvor an einem großen Datensatz trainiert wurde, typischerweise an einer groß angelegten Bildklassifizierungsaufgabe. Sie verwenden entweder das vortrainierte Modell unverändert oder verwenden Transfer-Learning, um dieses Modell an eine bestimmte Aufgabe anzupassen.

Die Intuition hinter dem Transferlernen für die Bildklassifizierung ist, dass, wenn ein Modell mit einem großen und ausreichend allgemeinen Datensatz trainiert wird, dieses Modell effektiv als generisches Modell der visuellen Welt dient. Sie können dann diese erlernten Feature-Maps nutzen, ohne bei Null anfangen zu müssen, indem Sie ein großes Modell mit einem großen Dataset trainieren.

In diesem Notebook werden Sie zwei Möglichkeiten ausprobieren, um ein vortrainiertes Modell anzupassen:

  1. Merkmalsextraktion: Verwenden Sie die Repräsentationen, die von einem früheren Netzwerk gelernt wurden, um aussagekräftige Merkmale aus neuen Stichproben zu extrahieren. Sie fügen dem vortrainierten Modell einfach einen neuen Klassifikator hinzu, der von Grund auf neu trainiert wird, damit Sie die zuvor gelernten Feature-Maps für das Dataset wiederverwenden können.

    Sie müssen nicht das gesamte Modell (neu) trainieren. Das grundlegende Faltungsnetzwerk enthält bereits Merkmale, die allgemein zum Klassifizieren von Bildern nützlich sind. Der letzte Klassifikationsteil des vortrainierten Modells ist jedoch spezifisch für die ursprüngliche Klassifikationsaufgabe und anschließend spezifisch für den Klassensatz, auf dem das Modell trainiert wurde.

  2. Feinabstimmung: Entfrosten Sie einige der oberen Schichten einer eingefrorenen Modellbasis und trainieren Sie gemeinsam sowohl die neu hinzugefügten Klassifikatorschichten als auch die letzten Schichten des Basismodells. Auf diese Weise können wir die Feature-Repräsentationen höherer Ordnung im Basismodell "feinabstimmen", um sie für die spezifische Aufgabe relevanter zu machen.

Sie folgen dem allgemeinen Workflow für maschinelles Lernen.

  1. Untersuchen und verstehen Sie die Daten
  2. Erstellen Sie eine Eingabepipeline, in diesem Fall mit Keras ImageDataGenerator
  3. Stellen Sie das Modell zusammen
    • Laden im vortrainierten Basismodell (und vortrainierten Gewichten)
    • Stapeln Sie die Klassifizierungsebenen darüber layers
  4. Trainiere das Modell
  5. Modell auswerten
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

Datenvorverarbeitung

Datendownload

In diesem Tutorial verwenden Sie einen Datensatz mit mehreren Tausend Bildern von Katzen und Hunden. Laden Sie eine ZIP-Datei mit den Bildern herunter und extrahieren Sie sie. Erstellentf.data.Dataset dann mit dem Dienstprogramm tf.keras.preprocessing.image_dataset_from_directory eintf.data.Dataset für das Training und die Validierung. In diesem Tutorial erfahren Sie mehr über das Laden von Bildern.

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 2s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

Zeigen Sie die ersten neun Bilder und Labels aus dem Trainingsset an:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

png

Da das ursprüngliche Dataset kein Testset enthält, erstellen Sie eines. Bestimmen Sie dazu mit tf.data.experimental.cardinality , wie viele tf.data.experimental.cardinality im Validierungssatz tf.data.experimental.cardinality , und verschieben Sie 20 % davon in einen tf.data.experimental.cardinality .

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

Konfigurieren Sie das Dataset für die Leistung

Verwenden Sie gepuffertes Prefetching, um Bilder von der Festplatte zu laden, ohne dass die E/A blockiert wird. Weitere Informationen zu dieser Methode finden Sie im Leitfaden zur Datenleistung .

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

Datenerweiterung verwenden

Wenn Sie kein großes Bild-Dataset haben, empfiehlt es sich, die Stichprobendiversität künstlich einzuführen, indem Sie zufällige, aber realistische Transformationen auf die Trainingsbilder anwenden, z. B. Drehung und horizontales Spiegeln. Dies trägt dazu bei, das Modell verschiedenen Aspekten der Trainingsdaten auszusetzen und eine Überanpassung zu reduzieren. In diesem Tutorial erfahren Sie mehr über die Datenerweiterung.

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Lassen Sie uns diese Ebenen wiederholt auf dasselbe Bild anwenden und das Ergebnis sehen.

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

png

Pixelwerte neu skalieren

In Kürze laden Sie tf.keras.applications.MobileNetV2 zur Verwendung als Basismodell herunter. Dieses Modell erwartet Pixelwerte in [-1, 1] , aber zu diesem Zeitpunkt sind die Pixelwerte in Ihren Bildern in [0, 255] . Um sie neu zu skalieren, verwenden Sie die im Modell enthaltene Vorverarbeitungsmethode.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

Erstellen Sie das Basismodell aus den vortrainierten Convnets

Sie erstellen das Basismodell aus dem bei Google entwickelten MobileNet V2- Modell. Dies wird auf dem ImageNet-Datensatz vortrainiert, einem großen Datensatz, der aus 1,4 Millionen Bildern und 1000 Klassen besteht. ImageNet ist ein Forschungstrainingsdatensatz mit einer Vielzahl von Kategorien wie jackfruit und syringe . Diese Wissensbasis wird uns helfen, Katzen und Hunde aus unserem spezifischen Datensatz zu klassifizieren.

Zuerst müssen Sie auswählen, welche Schicht von MobileNet V2 Sie für die Funktionsextraktion verwenden möchten. Die allerletzte Klassifizierungsschicht (auf "oben", da die meisten Diagramme von Machine-Learning-Modellen von unten nach oben verlaufen) ist nicht sehr nützlich. Stattdessen folgen Sie der üblichen Praxis, sich vor dem Abflachen auf die allerletzte Schicht zu verlassen. Diese Schicht wird als "Flaschenhalsschicht" bezeichnet. Die Merkmale der Engpassschicht behalten im Vergleich zur letzten/obersten Schicht eine größere Allgemeingültigkeit.

Instanziieren Sie zunächst ein MobileNet V2-Modell, das mit auf ImageNet trainierten Gewichtungen vorinstalliert ist. Durch Angabe des Arguments include_top=False laden Sie ein Netzwerk, das die Klassifizierungs-Layer oben nicht enthält, was ideal für die Feature-Extraktion ist.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Dieser Feature-Extraktor konvertiert jedes 160x160x3 Bild in einen 5x5x1280 Block von Features. Sehen wir uns an, was es mit einem Beispielstapel von Bildern macht:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

Merkmalsextraktion

In diesem Schritt frieren Sie die im vorherigen Schritt erstellte Faltungsbasis ein und verwenden sie als Feature-Extraktor. Darüber hinaus fügen Sie darüber einen Klassifikator hinzu und trainieren den Klassifikator der obersten Ebene.

Friere die Faltungsbasis ein

Es ist wichtig, die Faltungsbasis einzufrieren, bevor Sie das Modell kompilieren und trainieren. Das Einfrieren (durch Setzen von layer.trainable = False) verhindert, dass die Gewichtungen in einem bestimmten Layer während des Trainings aktualisiert werden. MobileNet V2 hat viele Ebenen, so dass das Setzen des trainable Flags des gesamten Modells auf False alle einfriert.

base_model.trainable = False

Wichtiger Hinweis zu BatchNormalization-Layern

Viele Modelle enthalten tf.keras.layers.BatchNormalization Layer. Dieser Layer ist ein Sonderfall und im Rahmen der Feinabstimmung sollten Vorsichtsmaßnahmen getroffen werden, wie später in diesem Tutorial gezeigt.

Wenn Sie layer.trainable = False BatchNormalization , wird der BatchNormalization Layer im Inferenzmodus ausgeführt und aktualisiert seine Mittelwert- und Varianzstatistiken nicht.

Wenn Sie ein Modell freigeben, das BatchNormalization-Layer enthält, um eine Feinabstimmung durchzuführen, sollten Sie die BatchNormalization-Layer im Inferenzmodus belassen, indem Sie beim Aufrufen des Basismodells training = False . Andernfalls zerstören die Aktualisierungen, die auf die nicht trainierbaren Gewichtungen angewendet werden, das, was das Modell gelernt hat.

Weitere Informationen finden Sie im Transfer-Lernhandbuch .

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

Klassifikationskopf hinzufügen

Um Vorhersagen aus dem Block von Features zu generieren, tf.keras.layers.GlobalAveragePooling2D über die räumlichen 5x5 räumlichen Positionen, indem Sie einen tf.keras.layers.GlobalAveragePooling2D Layer verwenden, um die Features in einen einzelnen 1280-Element-Vektor pro Bild zu konvertieren.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

Wenden Sie einen tf.keras.layers.Dense Layer an, um diese Features in eine einzelne Vorhersage pro Bild umzuwandeln. Sie benötigen hier keine Aktivierungsfunktion, da diese Vorhersage als logit oder als logit behandelt wird. Positive Zahlen sagen Klasse 1 voraus, negative Zahlen sagen Klasse 0 voraus.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

Erstellen Sie ein Modell, indem Sie die Datenerweiterungs-, Neuskalierungs-, base_model- und Feature-Extraktor-Layer mit der Keras Functional API verketten. Wie bereits erwähnt, verwenden Sie training=False, da unser Modell eine BatchNormalization-Schicht enthält.

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

Kompilieren Sie das Modell

Kompilieren Sie das Modell, bevor Sie es trainieren. Da es zwei Klassen gibt, verwenden Sie einen binären Kreuzentropieverlust mit from_logits=True da das Modell eine lineare Ausgabe liefert.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

Die 2,5 Millionen Parameter in MobileNet sind eingefroren, aber es gibt 1,2K trainierbare Parameter in der Dichteschicht. Diese sind auf zwei tf.Variable Objekte aufgeteilt, die Gewichte und Bias.

len(model.trainable_variables)
2

Trainiere das Modell

Nach dem Training für 10 Epochen sollten Sie eine Genauigkeit von ~94 % im Validierungssatz sehen.

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 2s 16ms/step - loss: 1.0528 - accuracy: 0.5223
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 1.05
initial accuracy: 0.52
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 4s 22ms/step - loss: 0.8021 - accuracy: 0.6015 - val_loss: 0.5626 - val_accuracy: 0.7166
Epoch 2/10
63/63 [==============================] - 1s 20ms/step - loss: 0.5231 - accuracy: 0.7185 - val_loss: 0.4190 - val_accuracy: 0.8243
Epoch 3/10
63/63 [==============================] - 1s 20ms/step - loss: 0.4331 - accuracy: 0.7880 - val_loss: 0.3245 - val_accuracy: 0.8738
Epoch 4/10
63/63 [==============================] - 1s 20ms/step - loss: 0.3672 - accuracy: 0.8330 - val_loss: 0.2675 - val_accuracy: 0.9010
Epoch 5/10
63/63 [==============================] - 1s 20ms/step - loss: 0.3213 - accuracy: 0.8660 - val_loss: 0.2280 - val_accuracy: 0.9257
Epoch 6/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2867 - accuracy: 0.8750 - val_loss: 0.1962 - val_accuracy: 0.9369
Epoch 7/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2659 - accuracy: 0.8920 - val_loss: 0.1704 - val_accuracy: 0.9517
Epoch 8/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2401 - accuracy: 0.8990 - val_loss: 0.1520 - val_accuracy: 0.9542
Epoch 9/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2368 - accuracy: 0.8985 - val_loss: 0.1384 - val_accuracy: 0.9592
Epoch 10/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2096 - accuracy: 0.9140 - val_loss: 0.1312 - val_accuracy: 0.9592

Lernkurven

Werfen wir einen Blick auf die Lernkurven der Trainings- und Validierungsgenauigkeit/-verlust, wenn das MobileNet V2-Basismodell als Extraktor fester Funktionen verwendet wird.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

In geringerem Maße liegt dies auch daran, dass Trainingsmetriken den Durchschnitt für eine Epoche angeben, während Validierungsmetriken nach der Epoche ausgewertet werden, sodass Validierungsmetriken ein Modell sehen, das etwas länger trainiert wurde.

Feintuning

In dem Feature-Extraktionsexperiment haben Sie nur einige Layer auf einem MobileNet V2-Basismodell trainiert. Die Gewichte des vortrainierten Netzwerks wurden während des Trainings nicht aktualisiert.

Eine Möglichkeit, die Leistung noch weiter zu steigern, besteht darin, die Gewichtungen der obersten Schichten des vortrainierten Modells neben dem Training des von Ihnen hinzugefügten Klassifikators zu trainieren (oder "fein abzustimmen"). Der Trainingsprozess erzwingt, dass die Gewichtungen von generischen Feature-Maps auf Features abgestimmt werden, die speziell mit dem Dataset verknüpft sind.

Außerdem sollten Sie versuchen, eine kleine Anzahl von Top-Layer zu optimieren, anstatt das gesamte MobileNet-Modell. In den meisten Faltungsnetzwerken gilt: Je höher eine Schicht ist, desto spezialisierter ist sie. Die ersten paar Ebenen lernen sehr einfache und generische Funktionen, die sich auf fast alle Arten von Bildern verallgemeinern lassen. Je höher Sie gehen, desto spezifischer werden die Features für das Dataset, auf dem das Modell trainiert wurde. Das Ziel der Feinabstimmung besteht darin, diese speziellen Funktionen an die Arbeit mit dem neuen Datensatz anzupassen, anstatt das generische Lernen zu überschreiben.

Entfrosten Sie die oberen Schichten des Modells

Alles, was Sie tun müssen, ist, das base_model base_model und die unteren Schichten so einzustellen, dass sie nicht trainierbar sind. Anschließend sollten Sie das Modell neu kompilieren (erforderlich, damit diese Änderungen wirksam werden) und das Training wieder aufnehmen.

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  154

Kompilieren Sie das Modell

Da Sie ein viel größeres Modell trainieren und die vortrainierten Gewichte neu anpassen möchten, ist es wichtig, in dieser Phase eine niedrigere Lernrate zu verwenden. Sonst könnte Ihr Modell sehr schnell overfit werden.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

Trainieren Sie das Modell weiter

Wenn Sie früher auf Konvergenz trainiert haben, verbessert dieser Schritt Ihre Genauigkeit um einige Prozentpunkte.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 6s 38ms/step - loss: 0.1514 - accuracy: 0.9375 - val_loss: 0.0548 - val_accuracy: 0.9740
Epoch 11/20
63/63 [==============================] - 2s 27ms/step - loss: 0.1262 - accuracy: 0.9490 - val_loss: 0.0597 - val_accuracy: 0.9691
Epoch 12/20
63/63 [==============================] - 2s 26ms/step - loss: 0.1073 - accuracy: 0.9595 - val_loss: 0.0444 - val_accuracy: 0.9777
Epoch 13/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0997 - accuracy: 0.9590 - val_loss: 0.0437 - val_accuracy: 0.9777
Epoch 14/20
63/63 [==============================] - 2s 26ms/step - loss: 0.0983 - accuracy: 0.9610 - val_loss: 0.0466 - val_accuracy: 0.9790
Epoch 15/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0715 - accuracy: 0.9740 - val_loss: 0.0378 - val_accuracy: 0.9814
Epoch 16/20
63/63 [==============================] - 2s 26ms/step - loss: 0.0753 - accuracy: 0.9720 - val_loss: 0.0465 - val_accuracy: 0.9765
Epoch 17/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0727 - accuracy: 0.9710 - val_loss: 0.0348 - val_accuracy: 0.9827
Epoch 18/20
63/63 [==============================] - 2s 26ms/step - loss: 0.0638 - accuracy: 0.9715 - val_loss: 0.0393 - val_accuracy: 0.9839
Epoch 19/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0552 - accuracy: 0.9755 - val_loss: 0.0349 - val_accuracy: 0.9851
Epoch 20/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0605 - accuracy: 0.9730 - val_loss: 0.0419 - val_accuracy: 0.9864

Werfen wir einen Blick auf die Lernkurven der Trainings- und Validierungsgenauigkeit/-verlust bei der Feinabstimmung der letzten Schichten des MobileNet V2-Basismodells und dem darauf aufbauenden Training des Klassifikators. Der Validierungsverlust ist viel höher als der Trainingsverlust, daher kann es zu einer Überanpassung kommen.

Es kann auch zu einer Überanpassung kommen, da der neue Trainingssatz relativ klein ist und den ursprünglichen MobileNet V2-Datensätzen ähnelt.

Nach der Feinabstimmung erreicht das Modell auf dem Validierungssatz eine Genauigkeit von fast 98%.

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

Auswertung und Vorhersage

Schließlich können Sie die Leistung des Modells mit neuen Daten mit einem Testsatz überprüfen.

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 13ms/step - loss: 0.0438 - accuracy: 0.9792
Test accuracy : 0.9791666865348816

Und jetzt können Sie dieses Modell verwenden, um vorherzusagen, ob Ihr Haustier eine Katze oder ein Hund ist.

#Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 1 0 0 1 0 1 1 0 0 0 0 0 1 1 0 1 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1]
Labels:
 [0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 1 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1]

png

Zusammenfassung

  • Verwenden eines vortrainierten Modells für die Feature-Extraktion : Beim Arbeiten mit einem kleinen Dataset ist es üblich, Features zu nutzen, die von einem Modell gelernt wurden, das mit einem größeren Dataset in derselben Domäne trainiert wurde. Dies geschieht durch Instanziieren des vortrainierten Modells und Hinzufügen eines vollständig verbundenen Klassifikators darüber. Das vortrainierte Modell wird "eingefroren" und nur die Gewichte des Klassifikators werden während des Trainings aktualisiert. In diesem Fall hat die Faltungsbasis alle mit jedem Bild verknüpften Merkmale extrahiert und Sie haben gerade einen Klassifikator trainiert, der die Bildklasse anhand dieses Satzes extrahierter Merkmale bestimmt.

  • Feinabstimmung eines vortrainierten Modells : Um die Leistung weiter zu verbessern, möchten Sie möglicherweise die obersten Schichten der vortrainierten Modelle per Feinabstimmung für das neue Dataset verwenden. In diesem Fall haben Sie Ihre Gewichtungen so abgestimmt, dass Ihr Modell für das Dataset spezifische High-Level-Features gelernt hat. Diese Technik wird normalerweise empfohlen, wenn das Trainings-Dataset groß ist und dem ursprünglichen Dataset, auf dem das vortrainierte Modell trainiert wurde, sehr ähnlich ist.

Um mehr zu erfahren, besuchen Sie den Transfer-Lernleitfaden .

# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.