บันทึกวันที่! Google I / O ส่งคืนวันที่ 18-20 พฤษภาคม ลงทะเบียนเลย
หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

ถ่ายทอดการเรียนรู้และการปรับแต่ง

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

ในบทช่วยสอนนี้คุณจะได้เรียนรู้วิธีการจัดประเภทรูปภาพของแมวและสุนัขโดยใช้การถ่ายทอดการเรียนรู้จากเครือข่ายที่ได้รับการฝึกฝนมาก่อน

แบบจำลองที่ได้รับการฝึกอบรมมาก่อนคือเครือข่ายที่บันทึกไว้ซึ่งได้รับการฝึกอบรมก่อนหน้านี้เกี่ยวกับชุดข้อมูลขนาดใหญ่โดยทั่วไปจะใช้กับงานการจัดประเภทรูปภาพขนาดใหญ่ คุณใช้โมเดลที่กำหนดไว้ล่วงหน้าตามที่เป็นอยู่หรือใช้การเรียนรู้การถ่ายโอนเพื่อปรับแต่งโมเดลนี้ตามงานที่กำหนด

สัญชาตญาณที่อยู่เบื้องหลังการเรียนรู้การถ่ายโอนสำหรับการจำแนกภาพคือหากโมเดลได้รับการฝึกฝนบนชุดข้อมูลที่มีขนาดใหญ่และทั่วไปเพียงพอโมเดลนี้จะทำหน้าที่เป็นแบบจำลองทั่วไปของโลกภาพได้อย่างมีประสิทธิภาพ จากนั้นคุณสามารถใช้ประโยชน์จากแผนที่คุณลักษณะที่เรียนรู้เหล่านี้ได้โดยไม่ต้องเริ่มต้นใหม่ตั้งแต่ต้นด้วยการฝึกโมเดลขนาดใหญ่บนชุดข้อมูลขนาดใหญ่

ในสมุดบันทึกนี้คุณจะลองปรับแต่งโมเดลที่กำหนดไว้ล่วงหน้าได้สองวิธี:

  1. การแยกคุณลักษณะ: ใช้การแสดงที่เรียนรู้โดยเครือข่ายก่อนหน้านี้เพื่อดึงคุณลักษณะที่มีความหมายจากตัวอย่างใหม่ คุณเพียงแค่เพิ่มตัวแยกประเภทใหม่ซึ่งจะได้รับการฝึกฝนตั้งแต่เริ่มต้นที่ด้านบนของแบบจำลองที่กำหนดไว้ล่วงหน้าเพื่อให้คุณสามารถนำคุณลักษณะแผนที่ที่เรียนรู้มาก่อนหน้านี้สำหรับชุดข้อมูลกลับมาใช้

    คุณไม่จำเป็นต้องฝึกโมเดลทั้งหมดอีกครั้ง เครือข่าย Convolutional พื้นฐานมีคุณสมบัติที่เป็นประโยชน์โดยทั่วไปสำหรับการจำแนกรูปภาพ อย่างไรก็ตามส่วนสุดท้ายการจำแนกประเภทของแบบจำลองที่กำหนดไว้ล่วงหน้านั้นเฉพาะสำหรับงานการจำแนกประเภทดั้งเดิมและต่อมาจะเจาะจงไปยังชุดของชั้นเรียนที่โมเดลนั้นได้รับการฝึกฝน

  2. การปรับแต่งแบบละเอียด: คลายการตรึงชั้นบนสุดสองสามชั้นของฐานแบบจำลองที่แช่แข็งและร่วมกันฝึกทั้งชั้นลักษณนามที่เพิ่มเข้ามาใหม่และชั้นสุดท้ายของโมเดลฐาน สิ่งนี้ช่วยให้เราสามารถ "ปรับแต่ง" การนำเสนอคุณลักษณะลำดับที่สูงขึ้นในแบบจำลองพื้นฐานเพื่อให้มีความเกี่ยวข้องมากขึ้นสำหรับงานนั้น ๆ

คุณจะทำตามขั้นตอนการทำงานของแมชชีนเลิร์นนิงทั่วไป

  1. ตรวจสอบและทำความเข้าใจข้อมูล
  2. สร้างท่อส่งข้อมูลในกรณีนี้โดยใช้ Keras ImageDataGenerator
  3. ประกอบโมเดล
    • โหลดในแบบจำลองพื้นฐานที่กำหนดไว้ล่วงหน้า (และน้ำหนักที่กำหนดไว้ล่วงหน้า)
    • ซ้อนชั้นการจัดหมวดหมู่ไว้ด้านบน
  4. ฝึกโมเดล
  5. ประเมินแบบจำลอง
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

การประมวลผลข้อมูลล่วงหน้า

ดาวน์โหลดข้อมูล

ในบทช่วยสอนนี้คุณจะใช้ชุดข้อมูลที่มีรูปภาพแมวและสุนัขหลายพันรูป ดาวน์โหลดและแตกไฟล์ zip ที่มีอิมเมจจากนั้นสร้างtf.data.Dataset สำหรับการฝึกอบรมและการตรวจสอบความถูกต้องโดยใช้ยูทิลิตี้ tf.keras.preprocessing.image_dataset_from_directory คุณสามารถเรียนรู้เพิ่มเติมเกี่ยวกับการโหลดภาพได้ใน บทช่วยสอน นี้

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 2s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

แสดงภาพและป้ายกำกับเก้าภาพแรกจากชุดการฝึก:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

png

เนื่องจากชุดข้อมูลเดิมไม่มีชุดทดสอบคุณจึงต้องสร้างขึ้นมา ในการทำเช่นนั้นให้กำหนดจำนวนชุดข้อมูลที่มีอยู่ในชุดการตรวจสอบความถูกต้องโดยใช้ tf.data.experimental.cardinality จากนั้นย้ายข้อมูล 20% ไปยังชุดทดสอบ

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

กำหนดค่าชุดข้อมูลเพื่อประสิทธิภาพ

ใช้การดึงข้อมูลล่วงหน้าที่บัฟเฟอร์เพื่อโหลดอิมเมจจากดิสก์โดยไม่ต้องให้ I / O ปิดกั้น หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับวิธีนี้โปรดดูคู่มือ ประสิทธิภาพของข้อมูล

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

ใช้การเพิ่มข้อมูล

เมื่อคุณไม่มีชุดข้อมูลรูปภาพขนาดใหญ่เป็นแนวทางปฏิบัติที่ดีในการนำเสนอความหลากหลายของตัวอย่างโดยใช้การเปลี่ยนแปลงแบบสุ่ม แต่เหมือนจริงกับภาพการฝึกอบรมเช่นการหมุนและการพลิกแนวนอน สิ่งนี้ช่วยให้โมเดลเปิดเผยข้อมูลการฝึกอบรมในแง่มุมต่างๆและลด การฟิตติ้ง มาก เกินไป คุณสามารถเรียนรู้เพิ่มเติมเกี่ยวกับการเพิ่มข้อมูลได้ใน บทช่วยสอน นี้

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

ลองใช้เลเยอร์เหล่านี้ซ้ำ ๆ กับภาพเดียวกันและดูผลลัพธ์

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

png

ปรับขนาดค่าพิกเซลใหม่

อีกสักครู่คุณจะดาวน์โหลด tf.keras.applications.MobileNetV2 เพื่อใช้เป็นโมเดลพื้นฐานของคุณ รุ่นนี้คาดว่าจะมีค่าพิกเซลเป็น [-1,1] แต่ ณ จุดนี้ค่าพิกเซลในภาพของคุณจะอยู่ที่ [0-255] หากต้องการลดขนาดให้ใช้วิธีการประมวลผลล่วงหน้าที่มาพร้อมกับโมเดล

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

สร้างแบบจำลองพื้นฐานจาก Convnets ที่ผ่านการฝึกอบรมมาแล้ว

คุณจะสร้างโมเดลพื้นฐานจากโมเดล MobileNet V2 ที่ พัฒนาโดย Google สิ่งนี้ได้รับการฝึกอบรมล่วงหน้าเกี่ยวกับชุดข้อมูล ImageNet ซึ่งเป็นชุดข้อมูลขนาดใหญ่ที่ประกอบด้วยรูปภาพ 1.4 ล้านรายการและคลาส 1,000 รายการ ImageNet เป็นชุดข้อมูลการฝึกอบรมการวิจัยที่มีหมวดหมู่มากมายเช่น jackfruit และ syringe ฐานความรู้นี้จะช่วยให้เราจำแนกแมวและสุนัขจากชุดข้อมูลเฉพาะของเรา

ขั้นแรกคุณต้องเลือกเลเยอร์ของ MobileNet V2 ที่คุณจะใช้ในการแยกคุณสมบัติ เลเยอร์การจัดหมวดหมู่สุดท้าย (ที่ "ด้านบน" เนื่องจากไดอะแกรมของโมเดลแมชชีนเลิร์นนิงส่วนใหญ่เปลี่ยนจากล่างขึ้นบน) จึงไม่มีประโยชน์มากนัก แต่คุณจะปฏิบัติตามแนวทางปฏิบัติทั่วไปเพื่อขึ้นอยู่กับชั้นสุดท้ายก่อนการดำเนินการแบน ชั้นนี้เรียกว่า "ชั้นคอขวด" คุณสมบัติของชั้นคอขวดยังคงมีลักษณะทั่วไปมากกว่าเมื่อเทียบกับชั้นสุดท้าย / ชั้นบนสุด

ขั้นแรกให้สร้างโมเดล MobileNet V2 ที่โหลดไว้ล่วงหน้าด้วยน้ำหนักที่ฝึกบน ImageNet ด้วยการระบุอาร์กิวเมนต์ include_top = False คุณจะโหลดเครือข่ายที่ไม่มีเลเยอร์การจัดหมวดหมู่ที่ด้านบนซึ่งเหมาะอย่างยิ่งสำหรับการแยกคุณลักษณะ

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

ตัวแยกคุณสมบัตินี้จะแปลงรูปภาพ 160x160x3 แต่ละภาพให้เป็นบล็อก 5x5x1280 มาดูกันว่ามันทำอะไรกับชุดรูปภาพตัวอย่าง:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

การแยกคุณลักษณะ

ในขั้นตอนนี้คุณจะตรึงฐาน Convolutional ที่สร้างขึ้นจากขั้นตอนก่อนหน้าและใช้เป็นตัวแยกคุณสมบัติ นอกจากนี้คุณยังเพิ่มลักษณนามที่ด้านบนและฝึกลักษณนามระดับบนสุด

ตรึงฐาน Convolutional

เป็นสิ่งสำคัญที่จะต้องตรึงฐาน Convolutional ก่อนที่คุณจะคอมไพล์และฝึกโมเดล การแช่แข็ง (โดยการตั้งค่า layer.trainable = False) ป้องกันไม่ให้มีการอัปเดตน้ำหนักในเลเยอร์ที่กำหนดระหว่างการฝึก MobileNet V2 มีหลายเลเยอร์ดังนั้นการตั้งค่าแฟ trainable ของโมเดลทั้งหมดเป็น False จะหยุดการทำงานทั้งหมด

base_model.trainable = False

หมายเหตุสำคัญเกี่ยวกับเลเยอร์ BatchNormalization

หลายรุ่นมีเลเยอร์ tf.keras.layers.BatchNormalization เลเยอร์นี้เป็นกรณีพิเศษและควรใช้ความระมัดระวังในบริบทของการปรับแต่งตามที่แสดงในบทช่วยสอนนี้

เมื่อคุณตั้งค่า layer.trainable = False เลเยอร์ BatchNormalization จะทำงานในโหมดอนุมานและจะไม่อัปเดตค่าเฉลี่ยและสถิติความแปรปรวน

เมื่อคุณยกเลิกการตรึงโมเดลที่มีเลเยอร์ BatchNormalization เพื่อทำการปรับแต่งคุณควรเก็บเลเยอร์ BatchNormalization ไว้ในโหมดอนุมานโดยผ่าน training = False เมื่อเรียกโมเดลพื้นฐาน มิฉะนั้นการอัปเดตที่ใช้กับน้ำหนักที่ไม่สามารถฝึกได้จะทำลายสิ่งที่โมเดลได้เรียนรู้

สำหรับรายละเอียดโปรดดูคู่มือการ ถ่ายโอนการเรียนรู้

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

เพิ่มหัวการจัดหมวดหมู่

ในการสร้างการคาดการณ์จากบล็อกของคุณสมบัติให้เฉลี่ยบนตำแหน่งเชิงพื้นที่ 5x5 เชิงพื้นที่โดยใช้เลเยอร์ tf.keras.layers.GlobalAveragePooling2D เพื่อแปลงคุณสมบัติเป็นเวกเตอร์ 1280 องค์ประกอบเดียวต่อภาพ

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

ใช้เลเยอร์ tf.keras.layers.Dense เพื่อแปลงคุณสมบัติเหล่านี้เป็นการคาดคะเนเดียวต่อภาพ คุณไม่จำเป็นต้องมีฟังก์ชั่นการเปิดใช้งานที่นี่เนื่องจากการคาดการณ์นี้จะถือว่าเป็น logit หรือค่าการคาดการณ์ดิบ ตัวเลขบวกทำนายชั้น 1 เลขลบทำนายชั้น 0

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

สร้างแบบจำลองโดยเชื่อมโยงการเพิ่มข้อมูลการปรับขนาดการปรับขนาดฐานและตัวแยกคุณลักษณะโดยใช้ Keras Functional API ตามที่กล่าวไว้ก่อนหน้านี้ให้ใช้ training = False เนื่องจากโมเดลของเรามีเลเยอร์ BatchNormalization

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

รวบรวมโมเดล

รวบรวมแบบจำลองก่อนการฝึกอบรม เนื่องจากมีสองคลาสให้ใช้การสูญเสียข้ามเอนโทรปีแบบไบนารีกับ from_logits=True เนื่องจากโมเดลมีเอาต์พุตเชิงเส้น

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

พารามิเตอร์ 2.5M ใน MobileNet ถูกแช่แข็ง แต่มี 1.2K พารามิเตอร์สุวินัยในชั้นหนาแน่น สิ่งเหล่านี้ถูกแบ่งระหว่างสอง tf.Variable วัตถุที่เปลี่ยนแปลงได้น้ำหนักและอคติ

len(model.trainable_variables)
2

ฝึกโมเดล

หลังจากการฝึกอบรมเป็นเวลา 10 ยุคคุณจะเห็นความแม่นยำ ~ 94% ในชุดการตรวจสอบความถูกต้อง

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 3s 33ms/step - loss: 0.8183 - accuracy: 0.4566
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.80
initial accuracy: 0.47
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 7s 64ms/step - loss: 0.6962 - accuracy: 0.5840 - val_loss: 0.5437 - val_accuracy: 0.6708
Epoch 2/10
63/63 [==============================] - 4s 55ms/step - loss: 0.5169 - accuracy: 0.7270 - val_loss: 0.4002 - val_accuracy: 0.8082
Epoch 3/10
63/63 [==============================] - 4s 54ms/step - loss: 0.4253 - accuracy: 0.7935 - val_loss: 0.3053 - val_accuracy: 0.8725
Epoch 4/10
63/63 [==============================] - 4s 52ms/step - loss: 0.3604 - accuracy: 0.8365 - val_loss: 0.2515 - val_accuracy: 0.9047
Epoch 5/10
63/63 [==============================] - 3s 51ms/step - loss: 0.3200 - accuracy: 0.8565 - val_loss: 0.2226 - val_accuracy: 0.9183
Epoch 6/10
63/63 [==============================] - 3s 51ms/step - loss: 0.2788 - accuracy: 0.8750 - val_loss: 0.1887 - val_accuracy: 0.9356
Epoch 7/10
63/63 [==============================] - 4s 51ms/step - loss: 0.2579 - accuracy: 0.8895 - val_loss: 0.1682 - val_accuracy: 0.9418
Epoch 8/10
63/63 [==============================] - 4s 51ms/step - loss: 0.2536 - accuracy: 0.8815 - val_loss: 0.1554 - val_accuracy: 0.9468
Epoch 9/10
63/63 [==============================] - 4s 53ms/step - loss: 0.2240 - accuracy: 0.9080 - val_loss: 0.1410 - val_accuracy: 0.9468
Epoch 10/10
63/63 [==============================] - 4s 53ms/step - loss: 0.2158 - accuracy: 0.9005 - val_loss: 0.1345 - val_accuracy: 0.9493

การเรียนรู้เส้นโค้ง

มาดูเส้นโค้งการเรียนรู้ของการฝึกอบรมและความถูกต้อง / การสูญเสียในการตรวจสอบความถูกต้องเมื่อใช้โมเดลพื้นฐาน MobileNet V2 เป็นตัวแยกคุณสมบัติคงที่

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

ในระดับที่น้อยกว่านั้นยังเป็นเพราะเมตริกการฝึกอบรมรายงานค่าเฉลี่ยสำหรับยุคหนึ่งในขณะที่เมตริกการตรวจสอบความถูกต้องจะได้รับการประเมินหลังยุคดังนั้นเมตริกการตรวจสอบความถูกต้องจะเห็นโมเดลที่ผ่านการฝึกอบรมมานานกว่าเล็กน้อย

ปรับจูน

ในการทดสอบการแยกคุณลักษณะคุณได้ฝึกอบรมเพียงไม่กี่ชั้นจากโมเดลพื้นฐาน MobileNet V2 น้ำหนักของเครือข่ายก่อนการฝึกอบรม ไม่ได้ รับการปรับปรุงในระหว่างการฝึกอบรม

วิธีหนึ่งในการเพิ่มประสิทธิภาพให้ดียิ่งขึ้นไปอีกคือการฝึก (หรือ "ปรับแต่ง") น้ำหนักของชั้นบนสุดของแบบจำลองที่ได้รับการฝึกฝนมาแล้วควบคู่ไปกับการฝึกลักษณนามที่คุณเพิ่มเข้าไป กระบวนการฝึกอบรมจะบังคับให้ปรับน้ำหนักจากแผนที่คุณลักษณะทั่วไปไปจนถึงคุณลักษณะที่เกี่ยวข้องกับชุดข้อมูลโดยเฉพาะ

นอกจากนี้คุณควรพยายามปรับแต่งเลเยอร์บนสุดจำนวนเล็กน้อยแทนที่จะเป็นโมเดล MobileNet ทั้งหมด ในเครือข่าย Convolutional ส่วนใหญ่ยิ่งเลเยอร์สูงขึ้นเท่าไหร่ก็ยิ่งมีความเชี่ยวชาญมากขึ้นเท่านั้น สองสามชั้นแรกจะเรียนรู้คุณสมบัติที่เรียบง่ายและทั่วไปซึ่งแสดงถึงภาพเกือบทุกประเภท เมื่อคุณก้าวไปสูงขึ้นคุณลักษณะต่างๆจะมีความเฉพาะเจาะจงมากขึ้นสำหรับชุดข้อมูลที่โมเดลได้รับการฝึกฝน เป้าหมายของการปรับแต่งอย่างละเอียดคือการปรับคุณสมบัติพิเศษเหล่านี้ให้ทำงานร่วมกับชุดข้อมูลใหม่แทนที่จะเขียนทับการเรียนรู้ทั่วไป

ยกเลิกการตรึงเลเยอร์บนสุดของโมเดล

สิ่งที่คุณต้องทำคือยกเลิกการตรึง base_model และตั้งค่าเลเยอร์ด้านล่างให้ไม่สามารถฝึกได้ จากนั้นคุณควรคอมไพล์โมเดลใหม่ (จำเป็นเพื่อให้การเปลี่ยนแปลงเหล่านี้มีผล) และเริ่มการฝึกอบรมต่อ

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  154

รวบรวมโมเดล

ในขณะที่คุณกำลังฝึกโมเดลที่ใหญ่กว่ามากและต้องการอ่านน้ำหนักที่กำหนดไว้ล่วงหน้าสิ่งสำคัญคือต้องใช้อัตราการเรียนรู้ที่ต่ำกว่าในขั้นตอนนี้ มิฉะนั้นแบบจำลองของคุณอาจเกินได้อย่างรวดเร็ว

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

ฝึกโมเดลต่อไป

หากคุณฝึกการลู่เข้ามาก่อนหน้านี้ขั้นตอนนี้จะช่วยเพิ่มความแม่นยำของคุณโดยไม่กี่เปอร์เซ็นต์

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 9s 69ms/step - loss: 0.1798 - accuracy: 0.9281 - val_loss: 0.0650 - val_accuracy: 0.9752
Epoch 11/20
63/63 [==============================] - 4s 55ms/step - loss: 0.1283 - accuracy: 0.9453 - val_loss: 0.0519 - val_accuracy: 0.9827
Epoch 12/20
63/63 [==============================] - 4s 55ms/step - loss: 0.1112 - accuracy: 0.9536 - val_loss: 0.0463 - val_accuracy: 0.9777
Epoch 13/20
63/63 [==============================] - 4s 55ms/step - loss: 0.0973 - accuracy: 0.9662 - val_loss: 0.0529 - val_accuracy: 0.9790
Epoch 14/20
63/63 [==============================] - 4s 55ms/step - loss: 0.0916 - accuracy: 0.9585 - val_loss: 0.0496 - val_accuracy: 0.9790
Epoch 15/20
63/63 [==============================] - 4s 56ms/step - loss: 0.0839 - accuracy: 0.9659 - val_loss: 0.0412 - val_accuracy: 0.9827
Epoch 16/20
63/63 [==============================] - 4s 52ms/step - loss: 0.0848 - accuracy: 0.9683 - val_loss: 0.0410 - val_accuracy: 0.9827
Epoch 17/20
63/63 [==============================] - 4s 53ms/step - loss: 0.0704 - accuracy: 0.9733 - val_loss: 0.0404 - val_accuracy: 0.9827
Epoch 18/20
63/63 [==============================] - 4s 54ms/step - loss: 0.0730 - accuracy: 0.9691 - val_loss: 0.0645 - val_accuracy: 0.9777
Epoch 19/20
63/63 [==============================] - 4s 55ms/step - loss: 0.0746 - accuracy: 0.9729 - val_loss: 0.0294 - val_accuracy: 0.9839
Epoch 20/20
63/63 [==============================] - 4s 56ms/step - loss: 0.0639 - accuracy: 0.9787 - val_loss: 0.0358 - val_accuracy: 0.9827

มาดูเส้นโค้งการเรียนรู้ของการฝึกอบรมและความถูกต้อง / การสูญเสียในการตรวจสอบความถูกต้องเมื่อปรับแต่งสองสามชั้นสุดท้ายของโมเดลพื้นฐาน MobileNet V2 และฝึกตัวแยกประเภทที่ด้านบน การสูญเสียการตรวจสอบความถูกต้องสูงกว่าการสูญเสียการฝึกอบรมมากดังนั้นคุณอาจได้รับการติดตั้งมากเกินไป

นอกจากนี้คุณอาจได้รับการติดตั้งมากเกินไปเนื่องจากชุดการฝึกอบรมใหม่มีขนาดค่อนข้างเล็กและคล้ายกับชุดข้อมูล MobileNet V2 เดิม

หลังจากปรับโมเดลอย่างละเอียดแล้วความแม่นยำเกือบถึง 98% ในชุดการตรวจสอบความถูกต้อง

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

การประเมินผลและการทำนาย

ขั้นสุดท้ายคุณสามารถตรวจสอบประสิทธิภาพของโมเดลกับข้อมูลใหม่โดยใช้ชุดทดสอบ

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 1s 38ms/step - loss: 0.0428 - accuracy: 0.9740
Test accuracy : 0.9739583134651184

และตอนนี้คุณพร้อมที่จะใช้แบบจำลองนี้เพื่อทำนายว่าสัตว์เลี้ยงของคุณเป็นแมวหรือสุนัข

#Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 0 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 1 0 1 1 1 1 1 1 1 1]
Labels:
 [0 0 1 0 0 1 1 0 0 0 0 1 1 0 1 1 0 1 0 0 1 0 1 0 1 1 1 1 1 1 1 1]

png

สรุป

  • การใช้แบบจำลองที่ได้รับการฝึกฝนมาก่อนสำหรับการแยกคุณลักษณะ : เมื่อทำงานกับชุดข้อมูลขนาดเล็กถือเป็นแนวทางปฏิบัติทั่วไปในการใช้ประโยชน์จากคุณลักษณะที่เรียนรู้โดยโมเดลที่ได้รับการฝึกฝนบนชุดข้อมูลขนาดใหญ่ในโดเมนเดียวกัน สิ่งนี้ทำได้โดยการสร้างอินสแตนซ์โมเดลที่ฝึกไว้ล่วงหน้าและเพิ่มลักษณนามที่เชื่อมต่อเต็มรูปแบบไว้ด้านบน แบบจำลองที่ผ่านการฝึกอบรมแล้วคือ "แช่แข็ง" และจะมีการอัปเดตเฉพาะน้ำหนักของลักษณนามระหว่างการฝึกเท่านั้น ในกรณีนี้ฐาน Convolutional ดึงคุณลักษณะทั้งหมดที่เกี่ยวข้องกับแต่ละภาพและคุณเพิ่งฝึกลักษณนามที่กำหนดคลาสของรูปภาพที่กำหนดคุณสมบัติที่แยกออกมานั้น

  • การปรับโมเดลที่ผ่านการฝึกอบรมมาแล้วอย่างละเอียด : เพื่อปรับปรุงประสิทธิภาพให้ดียิ่งขึ้นเราอาจต้องการนำเลเยอร์ระดับบนสุดของโมเดลที่ผ่านการฝึกอบรมมาใช้กับชุดข้อมูลใหม่ผ่านการปรับแต่งอย่างละเอียด ในกรณีนี้คุณได้ปรับน้ำหนักของคุณเพื่อให้โมเดลของคุณเรียนรู้คุณลักษณะระดับสูงเฉพาะสำหรับชุดข้อมูล โดยปกติจะแนะนำให้ใช้เทคนิคนี้เมื่อชุดข้อมูลการฝึกอบรมมีขนาดใหญ่และคล้ายกับชุดข้อมูลดั้งเดิมที่ได้รับการฝึกอบรมมาก่อน

หากต้องการเรียนรู้เพิ่มเติมโปรดไปที่ คู่มือการถ่ายโอนการเรียนรู้

# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.