
Transfer learning and fine-tuning


In this tutorial, you will learn how to classify images of cats and dogs by using transfer learning from a pre-trained network.

A pre-trained model is a saved network that was previously trained on a large dataset, typically on a large-scale image-classification task. You can either use the pre-trained model as is or use transfer learning to customize this model to a given task.

The intuition behind transfer learning for image classification is that if a model is trained on a large and general enough dataset, it will effectively serve as a generic model of the visual world. You can then take advantage of these learned feature maps without having to start from scratch by training a large model on a big dataset.

In this notebook, you will try two ways to customize a pre-trained model:

  1. Feature Extraction: Use the representations learned by a previous network to extract meaningful features from new samples. You simply add a new classifier, which will be trained from scratch, on top of the pre-trained model so that you can repurpose the feature maps learned previously for the dataset.

    You do not need to (re)train the entire model. The base convolutional network already contains features that are generically useful for classifying pictures. However, the final, classification part of the pre-trained model is specific to the original classification task, and subsequently specific to the set of classes on which the model was trained.

  2. Fine-Tuning: Unfreeze a few of the top layers of a frozen model base and jointly train both the newly-added classifier layers and the last layers of the base model. This allows us to "fine-tune" the higher-order feature representations in the base model in order to make them more relevant for the specific task.

You will follow the general machine learning workflow.

  1. Examine and understand the data
  2. Build an input pipeline, in this case using the Keras image_dataset_from_directory utility
  3. Compose the model
    • Load in the pre-trained base model (and pre-trained weights)
    • Stack the classification layers on top
  4. Train the model
  5. Evaluate the model
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

Data preprocessing

Data download

In this tutorial, you will use a dataset containing several thousand images of cats and dogs. Download and extract a zip file containing the images, then create a tf.data.Dataset for training and validation using the tf.keras.preprocessing.image_dataset_from_directory utility. You can learn more about loading images in this tutorial.

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 1s 0us/step
68616192/68606236 [==============================] - 1s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

Show the first nine images and labels from the training set:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

(Figure: a 3x3 grid of sample training images, each titled with its class name.)

As the original dataset doesn't contain a test set, you will create one. To do so, determine how many batches of data are available in the validation set using tf.data.experimental.cardinality, then move 20% of them to a test set.

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

Configure the dataset for performance

Use buffered prefetching to load images from disk without having I/O become blocking. To learn more about this method, see the data performance guide.

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

Use data augmentation

When you don't have a large image dataset, it's a good practice to artificially introduce sample diversity by applying random, yet realistic, transformations to the training images, such as rotation and horizontal flipping. This helps expose the model to different aspects of the training data and reduces overfitting. You can learn more about data augmentation in this tutorial.

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Let's repeatedly apply these layers to the same image and see the result.

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

(Figure: the same training image shown nine times with random horizontal flips and rotations applied.)

Rescale pixel values

In a moment, you will download tf.keras.applications.MobileNetV2 for use as your base model. This model expects pixel values in [-1, 1], but at this point, the pixel values in your images are in [0, 255]. To rescale them, use the preprocessing method included with the model.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

Create the base model from the pre-trained convnets

You will create the base model from the MobileNet V2 model developed at Google. This model has been pre-trained on the ImageNet dataset, a large dataset consisting of 1.4 million images and 1000 classes. ImageNet is a research training dataset with a wide variety of categories like jackfruit and syringe. This base of knowledge will help us classify cats and dogs from our specific dataset.

First, you need to pick which layer of MobileNet V2 you will use for feature extraction. The very last classification layer (on "top", as most diagrams of machine learning models go from bottom to top) is not very useful. Instead, you will follow the common practice of depending on the very last layer before the flatten operation. This layer is called the "bottleneck layer". The bottleneck-layer features retain more generality than the final/top layer.

First, instantiate a MobileNet V2 model pre-loaded with weights trained on ImageNet. By specifying the include_top=False argument, you load a network that doesn't include the classification layers at the top, which is ideal for feature extraction.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

This feature extractor converts each 160x160x3 image into a 5x5x1280 block of features. Let's see what it does to an example batch of images:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

Feature extraction

In this step, you will freeze the convolutional base created in the previous step and use it as a feature extractor. Additionally, you add a classifier on top of it and train the top-level classifier.

Freeze the convolutional base

It is important to freeze the convolutional base before you compile and train the model. Freezing (by setting layer.trainable = False) prevents the weights in a given layer from being updated during training. MobileNet V2 has many layers, so setting the entire model's trainable flag to False will freeze all of them.

base_model.trainable = False

Important note about BatchNormalization layers

Many models contain tf.keras.layers.BatchNormalization layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this tutorial.

When you set layer.trainable = False, the BatchNormalization layer will run in inference mode and will not update its mean and variance statistics.

When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing training=False when calling the base model. Otherwise, the updates applied to the non-trainable weights will destroy what the model has learned.

For details, see the Transfer learning guide.
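
A minimal sketch of this pattern (it mirrors how the model is assembled later in this tutorial, not an extra step you need to run): the trainable attribute controls whether weights are updated, while the call-time training argument controls whether BatchNormalization uses batch statistics or its stored moving averages.

# Sketch only: freezing weights vs. forcing inference mode.
base_model.trainable = False                 # all weights, including BatchNorm gamma/beta, are frozen
sketch_inputs = tf.keras.Input(shape=(160, 160, 3))
# training=False keeps every BatchNormalization layer in inference mode,
# so its moving mean/variance are not updated even after later unfreezing.
sketch_features = base_model(sketch_inputs, training=False)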

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

Add a classification head

To generate predictions from the block of features, average over the 5x5 spatial locations, using a tf.keras.layers.GlobalAveragePooling2D layer to convert the features into a single 1280-element vector per image.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

Apply a tf.keras.layers.Dense layer to convert these features into a single prediction per image. You don't need an activation function here because this prediction will be treated as a logit, or a raw prediction value. Positive numbers predict class 1, negative numbers predict class 0.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)
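
To make the raw-logit interpretation concrete, here is a small illustrative check (not part of the original notebook): passing a logit through a sigmoid yields the predicted probability of class 1.

# Illustrative only: logits map to class-1 probabilities via a sigmoid.
example_logits = tf.constant([-2.0, 0.0, 2.0])
print(tf.nn.sigmoid(example_logits).numpy())  # approximately [0.12, 0.5, 0.88]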

Build a model by chaining together the data augmentation, rescaling, base_model, and feature extractor layers using the Keras Functional API. As previously mentioned, use training=False since our model contains BatchNormalization layers.

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

Compile the model

Compile the model before training it. Since there are two classes, use the tf.keras.losses.BinaryCrossentropy loss with from_logits=True, since the model provides a linear output.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

The roughly 2.26 million parameters in MobileNet are frozen, but there are 1,281 trainable parameters in the Dense layer. These are split between two tf.Variable objects, the weights and the biases.

len(model.trainable_variables)
2
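
As an illustrative check (assuming the shapes shown in the summary above), you can list the two trainable variables; they should be the Dense layer's kernel of shape (1280, 1) and its bias of shape (1,).

# Illustrative check: the only trainable variables are the Dense kernel and bias.
for var in model.trainable_variables:
  print(var.name, var.shape)  # expect shapes (1280, 1) and (1,)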

Train the model

After training for 10 epochs, you should see ~94% accuracy on the validation set.

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 2s 18ms/step - loss: 0.8367 - accuracy: 0.4851
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.84
initial accuracy: 0.49
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 4s 26ms/step - loss: 0.7017 - accuracy: 0.5825 - val_loss: 0.5713 - val_accuracy: 0.6968
Epoch 2/10
63/63 [==============================] - 2s 24ms/step - loss: 0.5313 - accuracy: 0.7140 - val_loss: 0.4127 - val_accuracy: 0.8057
Epoch 3/10
63/63 [==============================] - 2s 24ms/step - loss: 0.4209 - accuracy: 0.7980 - val_loss: 0.3240 - val_accuracy: 0.8478
Epoch 4/10
63/63 [==============================] - 2s 24ms/step - loss: 0.3611 - accuracy: 0.8210 - val_loss: 0.2657 - val_accuracy: 0.8874
Epoch 5/10
63/63 [==============================] - 2s 23ms/step - loss: 0.3088 - accuracy: 0.8605 - val_loss: 0.2233 - val_accuracy: 0.9171
Epoch 6/10
63/63 [==============================] - 2s 23ms/step - loss: 0.2744 - accuracy: 0.8835 - val_loss: 0.1952 - val_accuracy: 0.9307
Epoch 7/10
63/63 [==============================] - 2s 23ms/step - loss: 0.2481 - accuracy: 0.8930 - val_loss: 0.1746 - val_accuracy: 0.9443
Epoch 8/10
63/63 [==============================] - 1s 23ms/step - loss: 0.2341 - accuracy: 0.9050 - val_loss: 0.1580 - val_accuracy: 0.9468
Epoch 9/10
63/63 [==============================] - 2s 23ms/step - loss: 0.2212 - accuracy: 0.9070 - val_loss: 0.1490 - val_accuracy: 0.9493
Epoch 10/10
63/63 [==============================] - 2s 23ms/step - loss: 0.1947 - accuracy: 0.9205 - val_loss: 0.1347 - val_accuracy: 0.9530

Learning curves

Let's take a look at the learning curves of the training and validation accuracy/loss when using the MobileNetV2 base model as a fixed feature extractor.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

(Figure: training and validation accuracy (top) and cross-entropy loss (bottom) over the initial 10 epochs.)

If you are wondering why the validation metrics are clearly better than the training metrics, the main factor is that layers like tf.keras.layers.BatchNormalization and tf.keras.layers.Dropout affect accuracy during training but are turned off when computing the validation loss. To a lesser extent, it is also because training metrics report the average over an epoch, while validation metrics are evaluated after the epoch, so the validation metrics see a model that has trained slightly longer.

Fine-tuning

In the feature extraction experiment, you were only training a few layers on top of the MobileNetV2 base model. The weights of the pre-trained network were not updated during training.

One way to increase performance even further is to train (or "fine-tune") the weights of the top layers of the pre-trained model alongside the training of the classifier you added. The training process will force the weights to be tuned from generic feature maps to features associated specifically with the dataset.

Also, you should try to fine-tune a small number of top layers rather than the whole MobileNet model. In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple and generic features that generalize to almost all types of images. As you go higher up, the features become increasingly specific to the dataset on which the model was trained. The goal of fine-tuning is to adapt these specialized features to work with the new dataset, rather than overwrite the generic learning.

Un-freeze the top layers of the model

All you need to do is unfreeze the base_model and set the bottom layers to be un-trainable. Then, you should recompile the model (necessary for these changes to take effect) and resume training.

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  154
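
As a quick illustrative sanity check (not part of the original notebook), you can confirm how many base-model layers remain trainable after the loop above; with fine_tune_at = 100 and 154 layers in total, the 54 layers above the cut should still have trainable = True.

# Illustrative check: layers below index 100 are frozen, layers from 100 onward stay trainable.
num_trainable_layers = sum(1 for layer in base_model.layers if layer.trainable)
print('Trainable layers in the base model:', num_trainable_layers)  # expect 154 - 100 = 54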

Compile the model

As you are training a much larger model and want to readapt the pre-trained weights, it is important to use a lower learning rate at this stage. Otherwise, your model could overfit very quickly.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

Continue training the model

If you trained to convergence earlier, this step will improve your accuracy by a few percentage points.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 8s 45ms/step - loss: 0.1384 - accuracy: 0.9440 - val_loss: 0.0653 - val_accuracy: 0.9839
Epoch 11/20
63/63 [==============================] - 2s 32ms/step - loss: 0.1087 - accuracy: 0.9520 - val_loss: 0.0447 - val_accuracy: 0.9827
Epoch 12/20
63/63 [==============================] - 2s 32ms/step - loss: 0.1177 - accuracy: 0.9510 - val_loss: 0.0453 - val_accuracy: 0.9839
Epoch 13/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0955 - accuracy: 0.9620 - val_loss: 0.0547 - val_accuracy: 0.9802
Epoch 14/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0838 - accuracy: 0.9700 - val_loss: 0.0326 - val_accuracy: 0.9864
Epoch 15/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0828 - accuracy: 0.9655 - val_loss: 0.0403 - val_accuracy: 0.9827
Epoch 16/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0807 - accuracy: 0.9670 - val_loss: 0.0487 - val_accuracy: 0.9790
Epoch 17/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0696 - accuracy: 0.9730 - val_loss: 0.0406 - val_accuracy: 0.9814
Epoch 18/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0616 - accuracy: 0.9735 - val_loss: 0.0387 - val_accuracy: 0.9851
Epoch 19/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0630 - accuracy: 0.9725 - val_loss: 0.0342 - val_accuracy: 0.9839
Epoch 20/20
63/63 [==============================] - 2s 32ms/step - loss: 0.0599 - accuracy: 0.9750 - val_loss: 0.0355 - val_accuracy: 0.9851

Let's take a look at the learning curves of the training and validation accuracy/loss when fine-tuning the last few layers of the MobileNetV2 base model and training the classifier on top of it. If the validation loss becomes much higher than the training loss, you may be getting some overfitting.

You may also get some overfitting because the new training set is relatively small and similar to the original MobileNetV2 dataset.

After fine-tuning, the model nearly reaches 98% accuracy on the validation set.

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

(Figure: training and validation accuracy and loss across all 20 epochs, with a vertical line marking the start of fine-tuning.)

Evaluation and prediction

Finally, you can verify the performance of the model on new data using the test set.

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 16ms/step - loss: 0.0431 - accuracy: 0.9740
Test accuracy : 0.9739583134651184

And now you are all set to use this model to predict whether your pet is a cat or a dog.

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 1 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 0 1 0 1 0]
Labels:
 [0 1 0 1 0 1 0 0 1 0 0 0 0 0 0 0 1 1 0 0 1 0 1 0 1 1 1 0 1 0 1 0]

(Figure: a 3x3 grid of test images, each titled with its predicted class name.)

Summary

  • Using a pre-trained model for feature extraction: When working with a small dataset, it is common practice to take advantage of features learned by a model trained on a larger dataset in the same domain. This is done by instantiating the pre-trained model and adding a fully-connected classifier on top. The pre-trained model is "frozen" and only the weights of the classifier get updated during training. In this case, the convolutional base extracts all the features associated with each image, and you just train a classifier that determines the image class given that set of extracted features.

  • Fine-tuning a pre-trained model: To further improve performance, you may want to repurpose the top-level layers of the pre-trained model to the new dataset via fine-tuning. In this case, you tune your weights so that your model learns high-level features specific to the dataset. This technique is usually recommended when the training dataset is large and very similar to the original dataset that the pre-trained model was trained on.

To learn more, visit the Transfer learning guide.

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.