เข้าร่วม Women in ML Symposium ในวันที่ 7 ธันวาคม ลงทะเบียนตอนนี้

ถ่ายทอดการเรียนรู้และการปรับแต่ง

จัดทุกอย่างให้เป็นระเบียบอยู่เสมอด้วยคอลเล็กชัน บันทึกและจัดหมวดหมู่เนื้อหาตามค่ากำหนดของคุณ

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ติดตั้ง

import numpy as np
import tensorflow as tf
from tensorflow import keras

บทนำ

การเรียนรู้การถ่ายโอนประกอบด้วยการเรียนรู้คุณสมบัติในการแก้ปัญหาอย่างใดอย่างหนึ่งและใช้ประโยชน์จากพวกเขาในใหม่ปัญหาที่คล้ายกัน ตัวอย่างเช่น คุณลักษณะจากโมเดลที่เรียนรู้ที่จะระบุแรคคูนอาจมีประโยชน์ในการเริ่มต้นโมเดลเพื่อระบุแทนนุกิ

โดยปกติการถ่ายโอนการเรียนรู้จะทำได้สำหรับงานที่ชุดข้อมูลของคุณมีข้อมูลน้อยเกินไปที่จะฝึกแบบจำลองเต็มรูปแบบตั้งแต่เริ่มต้น

การแปลงโฉมการเรียนรู้โดยทั่วไปในบริบทของการเรียนรู้เชิงลึกคือขั้นตอนการทำงานต่อไปนี้:

  1. นำเลเยอร์จากแบบจำลองที่ได้รับการฝึกมาก่อนหน้านี้
  2. ตรึงพวกเขาไว้ เพื่อหลีกเลี่ยงการทำลายข้อมูลใด ๆ ที่มีอยู่ในรอบการฝึกอบรมในอนาคต
  3. เพิ่มเลเยอร์ใหม่ที่สามารถฝึกได้บนเลเยอร์ที่แช่แข็ง พวกเขาจะได้เรียนรู้การเปลี่ยนคุณสมบัติเก่าเป็นการคาดคะเนในชุดข้อมูลใหม่
  4. ฝึกเลเยอร์ใหม่ในชุดข้อมูลของคุณ

ที่ผ่านมาขั้นตอนที่เลือกคือปรับจูนซึ่งประกอบด้วย unfreezing รูปแบบทั้งคุณได้รับข้างต้น (หรือส่วนหนึ่งของมัน) และการฝึกอบรมอีกครั้งบนข้อมูลใหม่ที่มีอัตราการเรียนรู้ที่ต่ำมาก สิ่งนี้สามารถบรรลุการปรับปรุงที่มีความหมาย โดยการปรับคุณลักษณะที่ได้รับการฝึกมาล่วงหน้าให้เข้ากับข้อมูลใหม่ทีละน้อย

ครั้งแรกที่เราจะไปกว่า Keras trainable API ในรายละเอียดซึ่งรองรับการถ่ายโอนการเรียนรู้มากที่สุดและปรับจูนเวิร์กโฟลว์

จากนั้น เราจะสาธิตเวิร์กโฟลว์ทั่วไปโดยใช้โมเดลที่ได้รับการฝึกอบรมล่วงหน้าบนชุดข้อมูล ImageNet และฝึกใหม่ในชุดข้อมูลการจัดหมวดหมู่ "cats vs dogs" ของ Kaggle

นี้ถูกดัดแปลงมาจาก การเรียนรู้ลึกกับงูหลาม และบล็อกโพสต์ 2016 "ที่มีประสิทธิภาพการสร้างแบบจำลองการจัดหมวดหมู่ภาพโดยใช้ข้อมูลน้อยมาก"

ชั้นแช่แข็ง: การทำความเข้าใจ trainable แอตทริบิวต์

เลเยอร์และโมเดลมีคุณสมบัติด้านน้ำหนักสามอย่าง:

  • weights คือรายการน้ำหนักตัวแปรของชั้นที่
  • trainable_weights เป็นรายชื่อของผู้ที่มีความหมายที่จะได้รับการปรับปรุง (ผ่านการไล่ระดับสีโคตร) เพื่อลดการสูญเสียในระหว่างการฝึก
  • non_trainable_weights เป็นรายชื่อของผู้ที่ไม่ได้หมายถึงการได้รับการอบรม โดยทั่วไปแล้วโมเดลจะอัปเดตในระหว่างการส่งต่อ

ตัวอย่าง: Dense ชั้นมี 2 น้ำหนักสุวินัย (เคอร์เนลและอคติ)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

โดยทั่วไป ตุ้มน้ำหนักทั้งหมดเป็นตุ้มน้ำหนักที่ฝึกได้ ในตัวชั้นที่มีน้ำหนักไม่ใช่สุวินัยเพียง แต่เป็น BatchNormalization ชั้น ใช้ตุ้มน้ำหนักที่ไม่สามารถฝึกได้เพื่อติดตามค่าเฉลี่ยและความแปรปรวนของอินพุตระหว่างการฝึก เรียนรู้วิธีการใช้น้ำหนักที่ไม่ใช่สุวินัยในชั้นของคุณเองให้ดูที่ คู่มือการเขียนชั้นใหม่จากรอยขีดข่วน

ตัวอย่าง: BatchNormalization ชั้นมี 2 น้ำหนักสุวินัยและ 2 น้ำหนักไม่ใช่สุวินัย

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

ชั้นและรุ่นที่มีคุณลักษณะแบบบูล trainable ค่าของมันสามารถเปลี่ยนแปลงได้ การตั้งค่า layer.trainable การ False ย้ายน้ำหนักทุกชั้นจากสุวินัยไม่ใช่สุวินัย นี้เรียกว่า "แช่แข็ง" ชั้น: รัฐของชั้นแช่แข็งจะไม่ได้รับการปรับปรุงในระหว่างการฝึกอบรม (ทั้งเมื่อการฝึกอบรมกับ fit() หรือเมื่อการฝึกอบรมกับห่วงกำหนดเองใด ๆ ที่อาศัย trainable_weights ใช้การปรับปรุงการไล่ระดับสี)

ตัวอย่าง: การตั้งค่า trainable การ False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

เมื่อตุ้มน้ำหนักที่ฝึกได้นั้นไม่สามารถฝึกได้ ค่าของตุ้มน้ำหนักจะไม่ได้รับการอัปเดตอีกต่อไปในระหว่างการฝึก

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

อย่าสับสน layer.trainable แอตทริบิวต์ที่มีการโต้แย้ง training ใน layer.__call__() (ซึ่งควบคุมว่าชั้นควรจะทำงานผ่านไปข้างหน้าในโหมดการอนุมานหรือโหมดการฝึกอบรม) สำหรับข้อมูลเพิ่มเติมโปรดดูที่ Keras คำถามที่พบบ่อย

การตั้งค่าซ้ำของ trainable แอตทริบิวต์

ถ้าคุณตั้งค่า trainable = False ในรูปแบบหรือบนชั้นใด ๆ ที่มี sublayers เด็กทุกชั้นไม่เป็นสุวินัยเช่นกัน

ตัวอย่าง:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

เวิร์กโฟลว์การโอนย้ายการเรียนรู้ทั่วไป

สิ่งนี้นำเราไปสู่วิธีการปรับใช้เวิร์กโฟลว์การเรียนรู้การถ่ายโอนทั่วไปใน Keras:

  1. สร้างโมเดลพื้นฐานและโหลดตุ้มน้ำหนักที่ฝึกไว้ล่วงหน้าลงไป
  2. ตรึงทุกชั้นในรูปแบบฐานโดยการตั้งค่า trainable = False
  3. สร้างโมเดลใหม่ที่ด้านบนของเอาต์พุตของหนึ่ง (หรือหลายเลเยอร์) จากโมเดลพื้นฐาน
  4. ฝึกโมเดลใหม่ของคุณกับชุดข้อมูลใหม่ของคุณ

โปรดทราบว่าทางเลือกอื่น เวิร์กโฟลว์ที่เบากว่าอาจเป็น:

  1. สร้างโมเดลพื้นฐานและโหลดตุ้มน้ำหนักที่ฝึกไว้ล่วงหน้าลงไป
  2. รันชุดข้อมูลใหม่ของคุณผ่านมันและบันทึกเอาต์พุตของหนึ่ง (หรือหลายเลเยอร์) จากโมเดลพื้นฐาน นี้เรียกว่าการสกัดคุณลักษณะ
  3. ใช้เอาต์พุตนั้นเป็นข้อมูลอินพุตสำหรับโมเดลใหม่ที่เล็กกว่า

ข้อได้เปรียบที่สำคัญของเวิร์กโฟลว์ที่สองนั้นคือ คุณเรียกใช้แบบจำลองพื้นฐานเพียงครั้งเดียวกับข้อมูลของคุณ แทนที่จะเรียกใช้หนึ่งครั้งต่อยุคของการฝึก เร็วกว่าและถูกกว่ามาก

อย่างไรก็ตาม ปัญหาของเวิร์กโฟลว์ที่สองนั้นก็คือ ไม่อนุญาตให้คุณแก้ไขข้อมูลอินพุตของโมเดลใหม่ของคุณแบบไดนามิกในระหว่างการฝึกอบรม ซึ่งจำเป็นเมื่อทำการเสริมข้อมูล เป็นต้น โดยทั่วไปแล้วการถ่ายโอนการเรียนรู้จะใช้สำหรับงานเมื่อชุดข้อมูลใหม่ของคุณมีข้อมูลน้อยเกินไปที่จะฝึกโมเดลแบบเต็มสเกลตั้งแต่เริ่มต้น และในสถานการณ์ดังกล่าว การเพิ่มข้อมูลมีความสำคัญมาก ดังนั้นในสิ่งต่อไปนี้ เราจะเน้นที่เวิร์กโฟลว์แรก

นี่คือลักษณะของเวิร์กโฟลว์แรกใน Keras:

ขั้นแรก สร้างตัวอย่างโมเดลพื้นฐานด้วยตุ้มน้ำหนักที่ฝึกไว้ล่วงหน้า

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

จากนั้นตรึงโมเดลพื้นฐาน

base_model.trainable = False

สร้างรูปแบบใหม่ด้านบน

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

ฝึกโมเดลเกี่ยวกับข้อมูลใหม่

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

ปรับจูน

เมื่อโมเดลของคุณผสานเข้ากับข้อมูลใหม่แล้ว คุณสามารถลองยกเลิกการตรึงโมเดลพื้นฐานทั้งหมดหรือบางส่วน และฝึกโมเดลทั้งหมดแบบ end-to-end ด้วยอัตราการเรียนรู้ที่ต่ำมาก

นี่เป็นขั้นตอนสุดท้ายที่เป็นทางเลือกที่อาจทำให้คุณมีการปรับปรุงเพิ่มขึ้นได้ นอกจากนี้ยังอาจนำไปสู่การใส่มากเกินไปอย่างรวดเร็ว - พึงระลึกไว้เสมอว่า

มันเป็นสิ่งสำคัญที่จะเพียง แต่ทำตามขั้นตอนนี้หลังจากที่รุ่นที่มีชั้นแช่แข็งได้รับการฝึกอบรมเพื่อการบรรจบกัน หากคุณผสมเลเยอร์ที่ฝึกได้แบบสุ่มที่เริ่มต้นแบบสุ่มกับเลเยอร์ที่ฝึกได้ซึ่งมีคุณลักษณะที่ได้รับการฝึกไว้ล่วงหน้า เลเยอร์ที่เริ่มต้นแบบสุ่มจะทำให้เกิดการอัปเดตการไล่ระดับสีที่มีขนาดใหญ่มากระหว่างการฝึก ซึ่งจะทำลายคุณลักษณะที่ได้รับการฝึกล่วงหน้าของคุณ

สิ่งสำคัญคือต้องใช้อัตราการเรียนรู้ที่ต่ำมากในขั้นตอนนี้ เนื่องจากคุณกำลังฝึกโมเดลที่ใหญ่กว่าในการฝึกรอบแรก ในชุดข้อมูลที่ปกติแล้วมีขนาดเล็กมาก ด้วยเหตุนี้ คุณจึงมีความเสี่ยงที่จะฟิตมากเกินไปอย่างรวดเร็วหากคุณใช้การอัปเดตน้ำหนักจำนวนมาก ที่นี่ คุณต้องการอ่านเฉพาะตุ้มน้ำหนักที่ฝึกไว้ล่วงหน้าด้วยวิธีที่เพิ่มขึ้นเท่านั้น

นี่คือวิธีการปรับใช้การปรับแต่งแบบจำลองพื้นฐานทั้งหมด:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

หมายเหตุสำคัญเกี่ยวกับการ compile() และ trainable

โทร compile() ในรูปแบบที่จะหมายถึงการ "แช่แข็ง" พฤติกรรมของรูปแบบว่า ซึ่งหมายความว่า trainable ค่าแอตทริบิวต์ในเวลาแบบจะรวบรวมควรจะเก็บรักษาตลอดอายุการใช้งานของรูปแบบนั้นจนกระทั่ง compile ถูกเรียกอีกครั้ง ดังนั้นหากคุณเปลี่ยน trainable ค่าให้แน่ใจว่าจะโทร compile() อีกครั้งในรูปแบบของคุณสำหรับการเปลี่ยนแปลงของคุณจะถูกนำเข้าบัญชี

หมายเหตุสำคัญเกี่ยวกับ BatchNormalization ชั้น

รูปแบบของภาพจำนวนมากประกอบด้วย BatchNormalization ชั้น เลเยอร์นั้นเป็นกรณีพิเศษในทุกจำนวนเท่าที่จะจินตนาการได้ นี่คือสิ่งที่ควรคำนึงถึง

  • BatchNormalization มี 2 น้ำหนักไม่ใช่สุวินัยที่ได้รับการปรับปรุงในระหว่างการฝึกอบรม เหล่านี้เป็นตัวแปรที่ติดตามค่าเฉลี่ยและความแปรปรวนของอินพุต
  • เมื่อคุณตั้ง bn_layer.trainable = False ที่ BatchNormalization ชั้นจะทำงานในโหมดการอนุมานและจะไม่ปรับปรุงสถิติค่าเฉลี่ยและความแปรปรวนของ กรณีนี้ไม่ได้สำหรับชั้นอื่น ๆ โดยทั่วไปเป็น trainability น้ำหนักและการอนุมาน / โหมดการฝึกอบรมเป็นสองแนวคิดมุมฉาก แต่ทั้งสองจะเชื่อมโยงในกรณีของ BatchNormalization ชั้น
  • เมื่อคุณยกเลิกการตรึงรูปแบบที่มี BatchNormalization ชั้นในเพื่อที่จะทำปรับจูน, คุณควรเก็บ BatchNormalization ชั้นในโหมดการอนุมานโดยผ่าน training=False เมื่อโทรฐานแบบจำลอง มิฉะนั้น การอัปเดตที่ใช้กับตุ้มน้ำหนักที่ไม่สามารถฝึกได้จะทำลายสิ่งที่โมเดลได้เรียนรู้ในทันที

คุณจะเห็นรูปแบบการทำงานนี้ในตัวอย่างตั้งแต่ต้นจนจบที่ส่วนท้ายของคู่มือนี้

ถ่ายทอดการเรียนรู้และการปรับแต่งด้วยลูปการฝึกแบบกำหนดเอง

ถ้าแทน fit() , คุณกำลังใช้ระดับต่ำห่วงการฝึกอบรมของคุณเองการเข้าพักขั้นตอนการทำงานหลักเดียวกัน คุณควรจะระมัดระวังในการใช้เวลาเพียงเข้าบัญชีรายการ model.trainable_weights เมื่อใช้การอัปเดตการไล่ระดับสี:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

ในทำนองเดียวกันสำหรับการปรับจูน

ตัวอย่างแบบ end-to-end: การปรับแต่งโมเดลการจัดประเภทรูปภาพบนชุดข้อมูล cat vs. dog อย่างละเอียด

ในการทำให้แนวคิดเหล่านี้แข็งแกร่งขึ้น เรามาแนะนำคุณผ่านการเรียนรู้การถ่ายโอนจากต้นทางถึงปลายทางที่เป็นรูปธรรมและตัวอย่างการปรับแต่งอย่างละเอียด เราจะโหลดโมเดล Xception ซึ่งได้รับการฝึกอบรมล่วงหน้าบน ImageNet และใช้กับชุดข้อมูลการจัดหมวดหมู่ "cats vs. dogs" ของ Kaggle

การรับข้อมูล

ขั้นแรก ให้ดึงชุดข้อมูลแมวกับสุนัขโดยใช้ TFDS หากคุณมีชุดข้อมูลของคุณเองคุณอาจจะต้องการที่จะใช้ยูทิลิตี้ tf.keras.preprocessing.image_dataset_from_directory ในการสร้างที่คล้ายกันที่มีป้ายกำกับวัตถุชุดจากชุดของภาพบนดิสก์ที่ยื่นลงในโฟลเดอร์ระดับที่เฉพาะเจาะจง

การถ่ายโอนการเรียนรู้มีประโยชน์มากที่สุดเมื่อทำงานกับชุดข้อมูลขนาดเล็กมาก เพื่อให้ชุดข้อมูลของเรามีขนาดเล็ก เราจะใช้ข้อมูลการฝึกอบรมดั้งเดิม 40% (25,000 ภาพ) สำหรับการฝึกอบรม 10% สำหรับการตรวจสอบ และ 10% สำหรับการทดสอบ

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

นี่คือ 9 ภาพแรกในชุดข้อมูลการฝึก อย่างที่คุณเห็น พวกมันมีขนาดต่างกันทั้งหมด

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

นอกจากนี้เรายังสามารถเห็นได้ว่าป้ายกำกับ 1 คือ "สุนัข" และป้ายกำกับ 0 คือ "แมว"

การกำหนดมาตรฐานข้อมูล

ภาพดิบของเรามีหลายขนาด นอกจากนี้ แต่ละพิกเซลยังประกอบด้วยค่าจำนวนเต็ม 3 ค่าระหว่าง 0 ถึง 255 (ค่าระดับ RGB) วิธีนี้ไม่เหมาะกับการป้อนโครงข่ายประสาทเทียม เราต้องทำ 2 สิ่ง:

  • กำหนดมาตรฐานให้มีขนาดภาพคงที่ เราเลือก 150x150
  • ค่าพิกเซลปกติระหว่าง -1 และ 1 เราจะทำเช่นนี้โดยใช้ Normalization ชั้นเป็นส่วนหนึ่งของรูปแบบของตัวเอง

โดยทั่วไป แนวทางปฏิบัติที่ดีในการพัฒนาโมเดลที่รับข้อมูลดิบเป็นอินพุต ตรงข้ามกับโมเดลที่ใช้ข้อมูลที่ประมวลผลล่วงหน้าแล้ว เหตุผลก็คือ หากโมเดลของคุณคาดหวังข้อมูลที่ประมวลผลล่วงหน้า ทุกครั้งที่คุณส่งออกโมเดลของคุณไปใช้ที่อื่น (ในเว็บเบราว์เซอร์ ในแอปบนอุปกรณ์เคลื่อนที่) คุณจะต้องนำไปป์ไลน์การประมวลผลล่วงหน้าเดิมกลับมาใช้ใหม่ สิ่งนี้จะยุ่งยากมากอย่างรวดเร็ว ดังนั้นเราควรประมวลผลล่วงหน้าให้น้อยที่สุดก่อนที่จะกดโมเดล

ในที่นี้ เราจะทำการปรับขนาดรูปภาพในไปป์ไลน์ข้อมูล (เนื่องจากโครงข่ายประสาทเทียมเชิงลึกสามารถประมวลผลเฉพาะกลุ่มข้อมูลที่ต่อเนื่องกัน) และเราจะทำการปรับขนาดค่าอินพุตซึ่งเป็นส่วนหนึ่งของโมเดล เมื่อเราสร้างมันขึ้นมา

มาปรับขนาดภาพเป็น 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

นอกจากนี้ มาแบทช์ข้อมูลและใช้การแคชและการดึงข้อมูลล่วงหน้าเพื่อปรับความเร็วในการโหลดให้เหมาะสมที่สุด

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

การใช้การเพิ่มข้อมูลแบบสุ่ม

เมื่อคุณไม่มีชุดข้อมูลรูปภาพขนาดใหญ่ เป็นการดีที่จะแนะนำความหลากหลายของตัวอย่างโดยใช้การแปลงแบบสุ่มแต่เหมือนจริงกับรูปภาพการฝึก เช่น การพลิกแนวนอนแบบสุ่มหรือการหมุนแบบสุ่มเล็กๆ ซึ่งจะช่วยให้โมเดลมองเห็นข้อมูลการฝึกในแง่มุมต่างๆ ในขณะที่ลดความเร็วลง

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

ลองนึกภาพว่าภาพแรกของชุดแรกมีลักษณะอย่างไรหลังจากการแปลงแบบสุ่มต่างๆ:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

สร้างแบบจำลอง

ตอนนี้ มาสร้างแบบจำลองที่เป็นไปตามพิมพ์เขียวที่เราได้อธิบายไว้ก่อนหน้านี้

โปรดทราบว่า:

  • เราเพิ่ม Rescaling ชั้นค่าที่ป้อนเข้าขนาด (ครั้งแรกใน [0, 255] ช่วง) เพื่อ [-1, 1] ช่วง
  • เราเพิ่ม Dropout ชั้นก่อนที่จะจำแนกชั้นสำหรับกู
  • เราแน่ใจว่าจะผ่าน training=False เมื่อโทรฐานแบบจำลองเพื่อที่จะทำงานในโหมดการอนุมานเพื่อให้สถิติ batchnorm ไม่ได้รับการปรับปรุงแม้หลังจากที่เรายกเลิกการตรึงฐานแบบจำลองสำหรับการปรับจูน
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

ฝึกชั้นบนสุด

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

ทำการจูนอย่างละเอียดทั้งรุ่น

สุดท้าย เรามาเลิกตรึงโมเดลพื้นฐานและฝึกโมเดลทั้งหมดตั้งแต่ต้นจนจบด้วยอัตราการเรียนรู้ที่ต่ำ

ที่สำคัญแม้ว่าฐานแบบจำลองจะกลายเป็นสุวินัยจะยังคงทำงานในโหมดการอนุมานตั้งแต่เราผ่าน training=False เมื่อเรียกมันว่าเมื่อเราสร้างแบบจำลอง ซึ่งหมายความว่าเลเยอร์การทำให้เป็นมาตรฐานของแบทช์ภายในจะไม่อัปเดตสถิติของแบทช์ ถ้าเป็นเช่นนั้น พวกเขาจะทำลายความหายนะให้กับตัวแทนที่เรียนรู้โดยแบบจำลองจนถึงตอนนี้

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

หลังจากผ่านไป 10 ยุค การปรับแต่งอย่างละเอียดทำให้เราได้รับการปรับปรุงที่ดีที่นี่