Merken Sie den Termin vor! Google I / O kehrt vom 18. bis 20. Mai zurück Registrieren Sie sich jetzt
Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Transferlernen und Feinabstimmung

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

In diesem Tutorial lernen Sie, wie Sie Bilder von Katzen und Hunden mithilfe von Transferlernen aus einem vorab geschulten Netzwerk klassifizieren.

Ein vorab trainiertes Modell ist ein gespeichertes Netzwerk, das zuvor für einen großen Datensatz trainiert wurde, normalerweise für eine umfangreiche Bildklassifizierungsaufgabe. Sie verwenden das vorab trainierte Modell entweder unverändert oder verwenden das Transferlernen, um dieses Modell an eine bestimmte Aufgabe anzupassen.

Die Intuition hinter dem Transferlernen für die Bildklassifizierung besteht darin, dass ein Modell, wenn es auf einem großen und allgemein genug großen Datensatz trainiert wird, effektiv als generisches Modell der visuellen Welt dient. Sie können diese erlernten Feature-Maps dann nutzen, ohne von vorne beginnen zu müssen, indem Sie ein großes Modell auf einem großen Datensatz trainieren.

In diesem Notizbuch werden Sie zwei Möglichkeiten ausprobieren, um ein vorab geschultes Modell anzupassen:

  1. Merkmalsextraktion: Verwenden Sie die von einem früheren Netzwerk gelernten Darstellungen, um aussagekräftige Merkmale aus neuen Beispielen zu extrahieren. Sie fügen einfach einen neuen Klassifikator hinzu, der von Grund auf neu trainiert wird, und zwar über dem vorab trainierten Modell, damit Sie die zuvor für das Dataset erlernten Feature-Maps neu verwenden können.

    Sie müssen nicht das gesamte Modell (neu) trainieren. Das Basis-Faltungsnetzwerk enthält bereits Funktionen, die allgemein zum Klassifizieren von Bildern nützlich sind. Der letzte Klassifizierungsteil des vorab trainierten Modells ist jedoch spezifisch für die ursprüngliche Klassifizierungsaufgabe und anschließend spezifisch für die Klassen, für die das Modell trainiert wurde.

  2. Feinabstimmung: Entfrieren Sie einige der obersten Schichten einer gefrorenen Modellbasis und trainieren Sie gemeinsam die neu hinzugefügten Klassifiziererebenen und die letzten Schichten des Basismodells. Auf diese Weise können wir die Feature-Darstellungen höherer Ordnung im Basismodell "fein abstimmen", um sie für die jeweilige Aufgabe relevanter zu machen.

Sie folgen dem allgemeinen Workflow für maschinelles Lernen.

  1. Untersuchen und verstehen Sie die Daten
  2. Erstellen Sie eine Eingabepipeline, in diesem Fall mit Keras ImageDataGenerator
  3. Verfassen Sie das Modell
    • Laden Sie das vorab trainierte Basismodell (und die vorab trainierten Gewichte).
    • Stapeln Sie die Klassifizierungsebenen oben
  4. Trainiere das Modell
  5. Modell auswerten
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

Datenvorverarbeitung

Daten herunterladen

In diesem Tutorial verwenden Sie einen Datensatz mit mehreren tausend Bildern von Katzen und Hunden. Laden Sie eine Zip-Datei mit den Bildern herunter, extrahieren Sie sie und erstellentf.data.Dataset anschließend mit dem Dienstprogramm tf.keras.preprocessing.image_dataset_from_directory eintf.data.Dataset für Schulung und Validierung. Weitere Informationen zum Laden von Bildern finden Sie in diesem Tutorial .

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 2s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

Zeigen Sie die ersten neun Bilder und Beschriftungen aus dem Trainingsset:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

png

Da der ursprüngliche Datensatz keinen Testsatz enthält, erstellen Sie einen. Bestimmen Sie dazu mithilfe von tf.data.experimental.cardinality , wie viele Datenstapel im Validierungssatz tf.data.experimental.cardinality , und verschieben tf.data.experimental.cardinality dann 20% davon in einen tf.data.experimental.cardinality .

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

Konfigurieren Sie das Dataset für die Leistung

Verwenden Sie das gepufferte Prefetching, um Bilder von der Festplatte zu laden, ohne dass E / A blockiert werden. Weitere Informationen zu dieser Methode finden Sie im Handbuch zur Datenleistung .

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

Verwenden Sie die Datenerweiterung

Wenn Sie keinen großen Bilddatensatz haben, empfiehlt es sich, die Stichprobenvielfalt künstlich einzuführen, indem Sie zufällige, aber realistische Transformationen auf die Trainingsbilder anwenden, z. B. Rotation und horizontales Spiegeln. Dies hilft, das Modell verschiedenen Aspekten der Trainingsdaten auszusetzen und Überanpassungen zu reduzieren. Weitere Informationen zur Datenerweiterung finden Sie in diesem Lernprogramm .

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Wenden wir diese Ebenen wiederholt auf dasselbe Bild an und sehen das Ergebnis.

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

png

Pixelwerte neu skalieren

In Kürze laden Sie tf.keras.applications.MobileNetV2 zur Verwendung als Basismodell herunter. Dieses Modell erwartet Pixelwerte in [-1,1] , aber zu diesem Zeitpunkt befinden sich die Pixelwerte in Ihren Bildern in [0-255] . Verwenden Sie zum erneuten Skalieren die im Modell enthaltene Vorverarbeitungsmethode.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

Erstellen Sie das Basismodell aus den vorab trainierten Convnets

Sie erstellen das Basismodell aus dem bei Google entwickelten MobileNet V2- Modell. Dies ist für das ImageNet-Dataset vorab trainiert, ein großes Dataset, das aus 1,4 Millionen Bildern und 1000 Klassen besteht. ImageNet ist ein Datensatz für Forschungstrainings mit einer Vielzahl von Kategorien wie jackfruit und syringe . Diese Wissensbasis hilft uns, Katzen und Hunde aus unserem spezifischen Datensatz zu klassifizieren.

Zunächst müssen Sie auswählen, welche Schicht von MobileNet V2 Sie für die Funktionsextraktion verwenden möchten. Die allerletzte Klassifizierungsebene (oben ", da die meisten Diagramme von Modellen für maschinelles Lernen von unten nach oben verlaufen) ist nicht sehr nützlich. Stattdessen befolgen Sie die gängige Praxis, um von der allerletzten Schicht vor dem Abflachen abhängig zu sein. Diese Schicht wird als "Engpassschicht" bezeichnet. Die Merkmale der Engpassschicht bleiben im Vergleich zur endgültigen / obersten Schicht allgemeiner.

Instanziieren Sie zunächst ein MobileNet V2-Modell, das mit auf ImageNet trainierten Gewichten vorinstalliert ist. Indem Sie das Argument include_top = False angeben , laden Sie ein Netzwerk, in dem die Klassifizierungsebenen oben nicht enthalten sind. Dies ist ideal für die Feature-Extraktion.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Dieser Feature-Extraktor konvertiert jedes 160x160x3 Bild in einen 5x5x1280 Funktionsblock. Mal sehen, was es mit einem Beispielstapel von Bildern macht:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

Merkmalsextraktion

In diesem Schritt frieren Sie die aus dem vorherigen Schritt erstellte Faltungsbasis ein und verwenden sie als Feature-Extraktor. Zusätzlich fügen Sie einen Klassifikator hinzu und trainieren den Klassifikator der obersten Ebene.

Frieren Sie die Faltungsbasis ein

Es ist wichtig, die Faltungsbasis einzufrieren, bevor Sie das Modell kompilieren und trainieren. Das Einfrieren (durch Setzen von layer.trainable = False) verhindert, dass die Gewichte in einer bestimmten Ebene während des Trainings aktualisiert werden. MobileNet V2 verfügt über viele Ebenen. Wenn Sie also das trainable Flag des gesamten Modells auf "Falsch" setzen, werden alle Ebenen eingefroren.

base_model.trainable = False

Wichtiger Hinweis zu BatchNormalization-Ebenen

Viele Modelle enthalten tf.keras.layers.BatchNormalization Ebenen. Diese Ebene ist ein Sonderfall, und im Zusammenhang mit der Feinabstimmung sollten Vorsichtsmaßnahmen getroffen werden, wie später in diesem Lernprogramm gezeigt wird.

Wenn Sie layer.trainable = False BatchNormalization , wird die BatchNormalization Ebene im Inferenzmodus ausgeführt und aktualisiert ihre Mittelwert- und Varianzstatistik nicht.

Wenn Sie ein Modell mit BatchNormalization-Ebenen aufheben, um eine Feinabstimmung vorzunehmen, sollten Sie die BatchNormalization-Ebenen im Inferenzmodus belassen, indem Sie beim Aufrufen des Basismodells training = False . Andernfalls zerstören die Aktualisierungen der nicht trainierbaren Gewichte das, was das Modell gelernt hat.

Einzelheiten finden Sie im Transfer-Lernhandbuch .

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

Fügen Sie einen Klassifizierungskopf hinzu

Um Vorhersagen aus dem Block von Features zu generieren, tf.keras.layers.GlobalAveragePooling2D Sie die räumlichen 5x5 räumlichen Positionen mithilfe einer tf.keras.layers.GlobalAveragePooling2D , um die Features in einen einzelnen Vektor mit 1280 Elementen pro Bild zu konvertieren.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

Wenden Sie eine tf.keras.layers.Dense an, um diese Features in eine einzelne Vorhersage pro Bild zu konvertieren. Sie benötigen hier keine Aktivierungsfunktion, da diese Vorhersage als logit oder roher Vorhersagewert behandelt wird. Positive Zahlen sagen Klasse 1 voraus, negative Zahlen sagen Klasse 0 voraus.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

Erstellen Sie ein Modell, indem Sie die Ebenen für Datenerweiterung, Neuskalierung, base_model und Feature-Extraktor mithilfe der Keras-Funktions-API verketten. Wie bereits erwähnt, verwenden Sie training = False, da unser Modell eine BatchNormalization-Ebene enthält.

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

Kompilieren Sie das Modell

Kompilieren Sie das Modell, bevor Sie es trainieren. Da es zwei Klassen gibt, verwenden Sie einen binären Kreuzentropieverlust mit from_logits=True da das Modell eine lineare Ausgabe liefert.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

Die 2,5 Millionen Parameter in MobileNet sind eingefroren, aber es gibt 1,2 KB trainierbare Parameter in der dichten Schicht. Diese werden auf zwei tf.Variable Objekte aufgeteilt, die Gewichte und Verzerrungen.

len(model.trainable_variables)
2

Trainiere das Modell

Nach dem Training für 10 Epochen sollte der Validierungssatz eine Genauigkeit von ~ 94% aufweisen.

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 3s 33ms/step - loss: 0.8183 - accuracy: 0.4566
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.80
initial accuracy: 0.47
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 7s 64ms/step - loss: 0.6962 - accuracy: 0.5840 - val_loss: 0.5437 - val_accuracy: 0.6708
Epoch 2/10
63/63 [==============================] - 4s 55ms/step - loss: 0.5169 - accuracy: 0.7270 - val_loss: 0.4002 - val_accuracy: 0.8082
Epoch 3/10
63/63 [==============================] - 4s 54ms/step - loss: 0.4253 - accuracy: 0.7935 - val_loss: 0.3053 - val_accuracy: 0.8725
Epoch 4/10
63/63 [==============================] - 4s 52ms/step - loss: 0.3604 - accuracy: 0.8365 - val_loss: 0.2515 - val_accuracy: 0.9047
Epoch 5/10
63/63 [==============================] - 3s 51ms/step - loss: 0.3200 - accuracy: 0.8565 - val_loss: 0.2226 - val_accuracy: 0.9183
Epoch 6/10
63/63 [==============================] - 3s 51ms/step - loss: 0.2788 - accuracy: 0.8750 - val_loss: 0.1887 - val_accuracy: 0.9356
Epoch 7/10
63/63 [==============================] - 4s 51ms/step - loss: 0.2579 - accuracy: 0.8895 - val_loss: 0.1682 - val_accuracy: 0.9418
Epoch 8/10
63/63 [==============================] - 4s 51ms/step - loss: 0.2536 - accuracy: 0.8815 - val_loss: 0.1554 - val_accuracy: 0.9468
Epoch 9/10
63/63 [==============================] - 4s 53ms/step - loss: 0.2240 - accuracy: 0.9080 - val_loss: 0.1410 - val_accuracy: 0.9468
Epoch 10/10
63/63 [==============================] - 4s 53ms/step - loss: 0.2158 - accuracy: 0.9005 - val_loss: 0.1345 - val_accuracy: 0.9493

Lernkurven

Werfen wir einen Blick auf die Lernkurven der Trainings- und Validierungsgenauigkeit / -verluste bei Verwendung des MobileNet V2-Basismodells als Extraktor für feste Funktionen.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

In geringerem Maße liegt dies auch daran, dass Trainingsmetriken den Durchschnitt für eine Epoche angeben, während Validierungsmetriken nach der Epoche ausgewertet werden, sodass Validierungsmetriken ein Modell sehen, das etwas länger trainiert hat.

Feintuning

Im Feature-Extraktionsexperiment haben Sie nur einige Ebenen auf einem MobileNet V2-Basismodell trainiert. Die Gewichte des vorab trainierten Netzwerks wurden während des Trainings nicht aktualisiert.

Eine Möglichkeit, die Leistung noch weiter zu steigern, besteht darin, die Gewichte der obersten Schichten des vorab trainierten Modells neben dem Training des von Ihnen hinzugefügten Klassifikators zu trainieren (oder zu "verfeinern"). Der Trainingsprozess erzwingt, dass die Gewichte von allgemeinen Feature-Maps auf Features abgestimmt werden, die speziell mit dem Datensatz verknüpft sind.

Außerdem sollten Sie versuchen, eine kleine Anzahl von obersten Ebenen und nicht das gesamte MobileNet-Modell zu optimieren. In den meisten Faltungsnetzwerken ist eine Schicht umso spezialisierter, je höher sie liegt. Die ersten Ebenen lernen sehr einfache und allgemeine Funktionen, die auf fast alle Arten von Bildern verallgemeinert werden können. Je höher Sie steigen, desto spezifischer werden die Funktionen für den Datensatz, für den das Modell trainiert wurde. Das Ziel der Feinabstimmung besteht darin, diese speziellen Funktionen an die Arbeit mit dem neuen Datensatz anzupassen, anstatt das generische Lernen zu überschreiben.

Entfrieren Sie die obersten Schichten des Modells

Alles, was Sie tun müssen, ist, das base_model base_model und die unteren Ebenen so einzustellen, dass sie nicht trainiert werden können. Anschließend sollten Sie das Modell neu kompilieren (damit diese Änderungen wirksam werden) und das Training fortsetzen.

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  154

Kompilieren Sie das Modell

Da Sie ein viel größeres Modell trainieren und die vorab trainierten Gewichte neu anpassen möchten, ist es wichtig, in dieser Phase eine niedrigere Lernrate zu verwenden. Andernfalls könnte Ihr Modell sehr schnell überpassen.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

Trainieren Sie das Modell weiter

Wenn Sie früher auf Konvergenz trainiert haben, verbessert dieser Schritt Ihre Genauigkeit um einige Prozentpunkte.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 9s 69ms/step - loss: 0.1798 - accuracy: 0.9281 - val_loss: 0.0650 - val_accuracy: 0.9752
Epoch 11/20
63/63 [==============================] - 4s 55ms/step - loss: 0.1283 - accuracy: 0.9453 - val_loss: 0.0519 - val_accuracy: 0.9827
Epoch 12/20
63/63 [==============================] - 4s 55ms/step - loss: 0.1112 - accuracy: 0.9536 - val_loss: 0.0463 - val_accuracy: 0.9777
Epoch 13/20
63/63 [==============================] - 4s 55ms/step - loss: 0.0973 - accuracy: 0.9662 - val_loss: 0.0529 - val_accuracy: 0.9790
Epoch 14/20
63/63 [==============================] - 4s 55ms/step - loss: 0.0916 - accuracy: 0.9585 - val_loss: 0.0496 - val_accuracy: 0.9790
Epoch 15/20
63/63 [==============================] - 4s 56ms/step - loss: 0.0839 - accuracy: 0.9659 - val_loss: 0.0412 - val_accuracy: 0.9827
Epoch 16/20
63/63 [==============================] - 4s 52ms/step - loss: 0.0848 - accuracy: 0.9683 - val_loss: 0.0410 - val_accuracy: 0.9827
Epoch 17/20
63/63 [==============================] - 4s 53ms/step - loss: 0.0704 - accuracy: 0.9733 - val_loss: 0.0404 - val_accuracy: 0.9827
Epoch 18/20
63/63 [==============================] - 4s 54ms/step - loss: 0.0730 - accuracy: 0.9691 - val_loss: 0.0645 - val_accuracy: 0.9777
Epoch 19/20
63/63 [==============================] - 4s 55ms/step - loss: 0.0746 - accuracy: 0.9729 - val_loss: 0.0294 - val_accuracy: 0.9839
Epoch 20/20
63/63 [==============================] - 4s 56ms/step - loss: 0.0639 - accuracy: 0.9787 - val_loss: 0.0358 - val_accuracy: 0.9827

Werfen wir einen Blick auf die Lernkurven der Trainings- und Validierungsgenauigkeit / -verluste, wenn Sie die letzten Schichten des MobileNet V2-Basismodells optimieren und den Klassifikator darüber trainieren. Der Validierungsverlust ist viel höher als der Trainingsverlust, daher kann es zu einer Überanpassung kommen.

Es kann auch zu einer Überanpassung kommen, da der neue Trainingssatz relativ klein ist und den ursprünglichen MobileNet V2-Datensätzen ähnelt.

Nach der Feinabstimmung erreicht das Modell eine Genauigkeit von fast 98% im Validierungssatz.

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

Bewertung und Vorhersage

Schließlich können Sie die Leistung des Modells anhand neuer Daten mithilfe eines Testsatzes überprüfen.

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 1s 38ms/step - loss: 0.0428 - accuracy: 0.9740
Test accuracy : 0.9739583134651184

Und jetzt können Sie mit diesem Modell vorhersagen, ob es sich bei Ihrem Haustier um eine Katze oder einen Hund handelt.

#Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 0 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 1 0 1 1 1 1 1 1 1 1]
Labels:
 [0 0 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 1 0 1 1 1 1 1 1 1 1]

png

Zusammenfassung

  • Verwenden eines vorab trainierten Modells für die Merkmalsextraktion : Wenn Sie mit einem kleinen Datensatz arbeiten, ist es üblich, Funktionen zu nutzen, die von einem Modell gelernt wurden, das auf einem größeren Datensatz in derselben Domäne trainiert wurde. Dies erfolgt durch Instanziieren des vorab trainierten Modells und Hinzufügen eines vollständig verbundenen Klassifikators. Das vorab trainierte Modell ist "eingefroren" und nur die Gewichte des Klassifikators werden während des Trainings aktualisiert. In diesem Fall hat die Faltungsbasis alle mit jedem Bild verknüpften Merkmale extrahiert, und Sie haben gerade einen Klassifizierer trainiert, der die Bildklasse für diesen Satz extrahierter Merkmale bestimmt.

  • Feinabstimmung eines vorab trainierten Modells : Um die Leistung weiter zu verbessern, möchten Sie möglicherweise die obersten Ebenen der vorab trainierten Modelle durch Feinabstimmung für den neuen Datensatz verwenden. In diesem Fall haben Sie Ihre Gewichte so angepasst, dass Ihr Modell die für den Datensatz spezifischen Funktionen auf hoher Ebene gelernt hat. Diese Technik wird normalerweise empfohlen, wenn der Trainingsdatensatz groß und dem ursprünglichen Datensatz, auf dem das vorab trainierte Modell trainiert wurde, sehr ähnlich ist.

Weitere Informationen finden Sie im Transfer-Lernhandbuch .

# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.