Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Mentransfer pembelajaran dan fine-tuning

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Mendirikan

import numpy as np
import tensorflow as tf
from tensorflow import keras

pengantar

Pembelajaran transfer terdiri dari mengambil fitur yang dipelajari pada satu masalah, dan memanfaatkannya pada masalah baru yang serupa. Misalnya, fitur dari model yang telah belajar mengidentifikasi racoons mungkin berguna untuk memulai model yang dimaksudkan untuk mengidentifikasi tanukis.

Pembelajaran transfer biasanya dilakukan untuk tugas-tugas yang set data Anda memiliki terlalu sedikit data untuk melatih model skala penuh dari awal.

Inkarnasi pembelajaran transfer yang paling umum dalam konteks pembelajaran mendalam adalah alur kerja berikut:

  1. Ambil lapisan dari model yang dilatih sebelumnya.
  2. Bekukan mereka, untuk menghindari penghancuran informasi apa pun yang dikandungnya selama putaran pelatihan di masa depan.
  3. Tambahkan beberapa lapisan baru yang bisa dilatih di atas lapisan beku. Mereka akan belajar mengubah fitur lama menjadi prediksi pada kumpulan data baru.
  4. Latih layer baru pada dataset Anda.

Langkah terakhir opsional, adalah fine-tuning , yang terdiri dari pemblokiran seluruh model yang Anda peroleh di atas (atau sebagian darinya), dan melatihnya kembali pada data baru dengan kecepatan pembelajaran yang sangat rendah. Ini berpotensi mencapai peningkatan yang berarti, dengan secara bertahap mengadaptasi fitur yang telah dilatih sebelumnya ke data baru.

Pertama, kita akan membahas API yang trainable Keras secara detail, yang mendasari sebagian besar alur kerja pemelajaran transfer & penyempurnaan.

Kemudian, kami akan mendemonstrasikan alur kerja umum dengan mengambil model yang dilatih sebelumnya pada kumpulan data ImageNet, dan melatihnya kembali di kumpulan data klasifikasi "kucing vs anjing" Kaggle.

Ini diadaptasi dari Deep Learning dengan Python dan entri blog 2016 "membangun model klasifikasi gambar yang kuat menggunakan data yang sangat sedikit" .

Lapisan beku: memahami atribut yang trainable

Lapisan & model memiliki tiga atribut bobot:

  • weights adalah daftar semua variabel bobot pada lapisan.
  • trainable_weights adalah daftar yang dimaksudkan untuk diperbarui (melalui penurunan gradien) untuk meminimalkan kerugian selama pelatihan.
  • non_trainable_weights adalah daftar yang tidak dimaksudkan untuk dilatih. Biasanya mereka diperbarui oleh model selama forward pass.

Contoh: lapisan Dense memiliki 2 bobot yang dapat dilatih (kernel & bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Secara umum, semua beban adalah anak timbangan yang bisa dilatih. Satu-satunya lapisan BatchNormalization yang memiliki bobot yang tidak dapat dilatih adalah lapisan BatchNormalization . Ini menggunakan bobot yang tidak bisa dilatih untuk melacak mean dan varians inputnya selama pelatihan. Untuk mempelajari cara menggunakan bobot yang tidak bisa dilatih di lapisan kustom Anda sendiri, lihat panduan untuk menulis lapisan baru dari awal .

Contoh: lapisan BatchNormalization memiliki 2 bobot yang dapat dilatih dan 2 bobot yang tidak dapat dilatih

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Lapisan & model juga dilengkapi atribut boolean yang trainable . Nilainya bisa diubah. Menyetel layer.trainable ke False memindahkan semua bobot lapisan dari bisa dilatih menjadi tidak bisa dilatih. Ini disebut "membekukan" lapisan: status lapisan beku tidak akan diperbarui selama pelatihan (baik saat melatih dengan fit() atau saat melatih dengan loop khusus apa pun yang mengandalkan trainable_weights untuk menerapkan pembaruan gradien).

Contoh: menyetel trainable ke False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Saat beban yang dapat dilatih menjadi tidak dapat dilatih, nilainya tidak lagi diperbarui selama pelatihan.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 1ms/step - loss: 0.1275

Jangan mengacaukan atribut layer.trainable dengan training argumen di layer.__call__() (yang mengontrol apakah lapisan harus menjalankan layer.trainable dalam mode inferensi atau mode pelatihan). Untuk informasi selengkapnya, lihat FAQ Keras .

Pengaturan rekursif dari atribut trainable

Jika Anda menyetel trainable = False pada model atau lapisan apa pun yang memiliki sub-lapisan, semua lapisan turunan juga menjadi tidak dapat dilatih.

Contoh:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Alur kerja transfer-belajar yang khas

Ini mengarahkan kita ke bagaimana alur kerja pembelajaran transfer biasa dapat diterapkan di Keras:

  1. Membuat instance model dasar dan memuat anak timbangan terlatih ke dalamnya.
  2. Bekukan semua lapisan dalam model dasar dengan menyetel trainable = False .
  3. Buat model baru di atas output dari satu (atau beberapa) lapisan dari model dasar.
  4. Latih model baru Anda pada kumpulan data baru Anda.

Perhatikan bahwa alternatif, alur kerja yang lebih ringan juga bisa berupa:

  1. Membuat instance model dasar dan memuat anak timbangan terlatih ke dalamnya.
  2. Jalankan dataset baru Anda melaluinya dan catat keluaran dari satu (atau beberapa) lapisan dari model dasar. Ini disebut ekstraksi fitur .
  3. Gunakan keluaran tersebut sebagai data masukan untuk model baru yang lebih kecil.

Keuntungan utama dari alur kerja kedua tersebut adalah Anda hanya menjalankan model dasar sekali pada data Anda, bukan sekali per periode pelatihan. Jadi jauh lebih cepat & lebih murah.

Masalah dengan alur kerja kedua tersebut, bagaimanapun, adalah bahwa hal itu tidak memungkinkan Anda untuk secara dinamis mengubah data masukan model baru Anda selama pelatihan, yang diperlukan saat melakukan augmentasi data, misalnya. Pembelajaran transfer biasanya digunakan untuk tugas-tugas ketika kumpulan data baru Anda memiliki terlalu sedikit data untuk melatih model skala penuh dari awal, dan dalam skenario seperti itu, augmentasi data sangat penting. Jadi berikut ini, kami akan fokus pada alur kerja pertama.

Inilah tampilan alur kerja pertama di Keras:

Pertama, buat contoh model dasar dengan bobot terlatih.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Kemudian, bekukan model dasarnya.

base_model.trainable = False

Buat model baru di atas.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Latih model pada data baru.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Mencari setelan

Setelah model Anda berkumpul di data baru, Anda dapat mencoba mencairkan semua atau sebagian dari model dasar dan melatih kembali seluruh model secara menyeluruh dengan kecepatan pembelajaran yang sangat rendah.

Ini adalah langkah terakhir opsional yang berpotensi memberi Anda peningkatan bertahap. Ini juga berpotensi menyebabkan overfitting cepat - ingatlah itu.

Sangat penting untuk melakukan langkah ini hanya setelah model dengan lapisan beku telah dilatih untuk konvergensi. Jika Anda mencampur lapisan yang dapat dilatih secara acak dengan lapisan yang dapat dilatih yang memiliki fitur terlatih, lapisan yang diinisialisasi secara acak akan menyebabkan pembaruan gradien yang sangat besar selama pelatihan, yang akan menghancurkan fitur terlatih Anda.

Penting juga untuk menggunakan kecepatan pemelajaran yang sangat rendah pada tahap ini, karena Anda melatih model yang jauh lebih besar daripada di putaran pertama pelatihan, pada kumpulan data yang biasanya sangat kecil. Akibatnya, Anda berisiko mengalami overfitting dengan sangat cepat jika menerapkan pembaruan bobot yang besar. Di sini, Anda hanya ingin menyesuaikan kembali bobot yang telah dilatih sebelumnya secara bertahap.

Ini adalah cara mengimplementasikan fine-tuning seluruh model dasar:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Catatan penting tentang compile() dan trainable

Memanggil compile() pada model dimaksudkan untuk "membekukan" perilaku model tersebut. Ini menyiratkan bahwa nilai atribut yang trainable pada saat model dikompilasi harus dipertahankan selama masa pakai model tersebut, hingga compile dipanggil lagi. Oleh karena itu, jika Anda mengubah nilai yang trainable , pastikan untuk memanggil compile() lagi pada model Anda agar perubahan Anda dipertimbangkan.

Catatan penting tentang lapisan BatchNormalization

Banyak model gambar berisi lapisan BatchNormalization . Lapisan itu adalah kasus khusus pada setiap hitungan yang bisa dibayangkan. Berikut beberapa hal yang perlu diperhatikan.

  • BatchNormalization berisi 2 beban yang tidak bisa dilatih yang diperbarui selama pelatihan. Ini adalah variabel yang melacak mean dan varians dari input.
  • Saat Anda menyetel bn_layer.trainable = False , lapisan BatchNormalization akan berjalan dalam mode inferensi, dan tidak akan memperbarui statistik mean & variansnya. Ini tidak berlaku untuk lapisan lain secara umum, karena pelatihan beban & mode inferensi / pelatihan adalah dua konsep ortogonal . Tetapi keduanya terikat dalam kasus lapisan BatchNormalization .
  • Saat Anda mencairkan model yang berisi lapisan BatchNormalization untuk melakukan BatchNormalization , Anda harus mempertahankan lapisan BatchNormalization dalam mode inferensi dengan meneruskan training=False saat memanggil model dasar. Jika tidak, pembaruan yang diterapkan pada bobot yang tidak dapat dilatih tiba-tiba akan menghancurkan apa yang telah dipelajari model.

Anda akan melihat pola ini beraksi di contoh ujung ke ujung di akhir panduan ini.

Mentransfer pembelajaran & menyempurnakan dengan loop pelatihan kustom

Jika alih-alih fit() , Anda menggunakan loop pelatihan tingkat rendah Anda sendiri, alur kerja pada dasarnya tetap sama. Anda harus berhati-hati untuk hanya memperhitungkan model model.trainable_weights saat menerapkan pembaruan gradien:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
.dll

Begitu juga untuk fine-tuning.

Contoh ujung ke ujung: menyempurnakan model klasifikasi gambar pada kucing vs. anjing

Himpunan data

Untuk memperkuat konsep-konsep ini, mari memandu Anda melalui pembelajaran transfer ujung-ke-ujung yang konkret & contoh fine-tuning. Kami akan memuat model Xception, yang telah dilatih sebelumnya di ImageNet, dan menggunakannya pada set data klasifikasi "kucing vs. anjing" Kaggle.

Mendapatkan datanya

Pertama, mari ambil kumpulan data kucing vs anjing menggunakan TFDS. Jika Anda memiliki kumpulan data Anda sendiri, Anda mungkin ingin menggunakan utilitas tf.keras.preprocessing.image_dataset_from_directory untuk menghasilkan objek kumpulan data berlabel serupa dari sekumpulan gambar pada disk yang dimasukkan ke dalam folder khusus kelas.

Pembelajaran transfer paling berguna saat bekerja dengan kumpulan data yang sangat kecil. Untuk menjaga kumpulan data kami kecil, kami akan menggunakan 40% dari data pelatihan asli (25.000 gambar) untuk pelatihan, 10% untuk validasi, dan 10% untuk pengujian.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteIL7NQA/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Ini adalah 9 gambar pertama dalam set data pelatihan - seperti yang Anda lihat, ukurannya berbeda-beda.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Kita juga dapat melihat bahwa label 1 adalah "anjing" dan label 0 adalah "kucing".

Standarisasi data

Gambar mentah kami memiliki berbagai ukuran. Selain itu, setiap piksel terdiri dari 3 nilai integer antara 0 dan 255 (nilai level RGB). Ini tidak cocok untuk memberi makan jaringan saraf. Kita perlu melakukan 2 hal:

  • Standarisasi ke ukuran gambar tetap. Kami memilih 150x150.
  • Normalisasi nilai piksel antara -1 dan 1. Kita akan melakukan ini menggunakan lapisan Normalization sebagai bagian dari model itu sendiri.

Secara umum, merupakan praktik yang baik untuk mengembangkan model yang menggunakan data mentah sebagai masukan, bukan model yang mengambil data yang sudah diproses sebelumnya. Alasannya adalah, jika model Anda mengharapkan data praproses, setiap kali Anda mengekspor model untuk digunakan di tempat lain (di browser web, di aplikasi seluler), Anda harus menerapkan ulang pipeline praproses yang sama persis. Ini menjadi sangat rumit dengan sangat cepat. Jadi, kita harus melakukan preprocessing sesedikit mungkin sebelum mencapai model.

Di sini, kami akan melakukan pengubahan ukuran gambar di pipeline data (karena deep neural network hanya dapat memproses kumpulan data yang berdekatan), dan kami akan melakukan penskalaan nilai input sebagai bagian dari model, saat kami membuatnya.

Mari ubah ukuran gambar menjadi 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Selain itu, mari kita batch data dan gunakan caching & prefetching untuk mengoptimalkan kecepatan pemuatan.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Menggunakan augmentasi data acak

Jika Anda tidak memiliki kumpulan data gambar yang besar, praktik yang baik adalah memperkenalkan keragaman sampel secara artifisial dengan menerapkan transformasi acak namun realistis ke gambar pelatihan, seperti membalik horizontal acak atau rotasi acak kecil. Ini membantu memaparkan model ke berbagai aspek data pelatihan sambil memperlambat overfitting.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(0.1),
    ]
)

Mari kita visualisasikan seperti apa gambar pertama dari kumpulan pertama setelah berbagai transformasi acak:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[i]))
        plt.axis("off")

png

Bangun model

Sekarang mari kita buat model yang mengikuti cetak biru yang telah kita jelaskan sebelumnya.

Perhatikan bahwa:

  • Kami menambahkan lapisan Normalization untuk menskalakan nilai input (awalnya dalam rentang [0, 255] ) ke rentang [-1, 1] .
  • Kami menambahkan lapisan Dropout sebelum lapisan klasifikasi, untuk regularisasi.
  • Kami memastikan untuk meneruskan training=False saat memanggil model dasar, sehingga model tersebut berjalan dalam mode inferensi, sehingga statistik batchnorm tidak diperbarui bahkan setelah kami mencairkan model dasar untuk penyempurnaan.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be normalized
# from (0, 255) to a range (-1., +1.), the normalization layer
# does the following, outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________

Latih lapisan atas

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 9s 32ms/step - loss: 0.1758 - binary_accuracy: 0.9226 - val_loss: 0.0897 - val_binary_accuracy: 0.9660
Epoch 2/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1211 - binary_accuracy: 0.9497 - val_loss: 0.0870 - val_binary_accuracy: 0.9686
Epoch 3/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1166 - binary_accuracy: 0.9503 - val_loss: 0.0814 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1125 - binary_accuracy: 0.9534 - val_loss: 0.0825 - val_binary_accuracy: 0.9695
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1073 - binary_accuracy: 0.9569 - val_loss: 0.0763 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1041 - binary_accuracy: 0.9573 - val_loss: 0.0812 - val_binary_accuracy: 0.9686
Epoch 7/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1023 - binary_accuracy: 0.9567 - val_loss: 0.0820 - val_binary_accuracy: 0.9669
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1005 - binary_accuracy: 0.9597 - val_loss: 0.0779 - val_binary_accuracy: 0.9695
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9580 - val_loss: 0.0813 - val_binary_accuracy: 0.9699
Epoch 10/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0940 - binary_accuracy: 0.9651 - val_loss: 0.0762 - val_binary_accuracy: 0.9729
Epoch 11/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0974 - binary_accuracy: 0.9613 - val_loss: 0.0752 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0965 - binary_accuracy: 0.9591 - val_loss: 0.0760 - val_binary_accuracy: 0.9721
Epoch 13/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0962 - binary_accuracy: 0.9598 - val_loss: 0.0785 - val_binary_accuracy: 0.9712
Epoch 14/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0966 - binary_accuracy: 0.9616 - val_loss: 0.0831 - val_binary_accuracy: 0.9699
Epoch 15/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1000 - binary_accuracy: 0.9574 - val_loss: 0.0741 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0940 - binary_accuracy: 0.9628 - val_loss: 0.0781 - val_binary_accuracy: 0.9686
Epoch 17/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0915 - binary_accuracy: 0.9634 - val_loss: 0.0843 - val_binary_accuracy: 0.9678
Epoch 18/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0937 - binary_accuracy: 0.9620 - val_loss: 0.0829 - val_binary_accuracy: 0.9669
Epoch 19/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0988 - binary_accuracy: 0.9601 - val_loss: 0.0862 - val_binary_accuracy: 0.9686
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0928 - binary_accuracy: 0.9644 - val_loss: 0.0798 - val_binary_accuracy: 0.9703

<tensorflow.python.keras.callbacks.History at 0x7f6104f04518>

Lakukan satu putaran fine-tuning untuk seluruh model

Terakhir, mari kita cabut model dasar dan latih seluruh model secara menyeluruh dengan kecepatan pembelajaran yang rendah.

Yang penting, meskipun model dasar menjadi dapat dilatih, model tersebut tetap berjalan dalam mode inferensi karena kita lulus training=False saat memanggilnya saat kita membuat model. Ini berarti bahwa lapisan normalisasi batch di dalamnya tidak akan memperbarui statistik batch-nya. Jika mereka melakukannya, mereka akan merusak representasi yang dipelajari oleh model sejauh ini.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3)       7         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
  2/291 [..............................] - ETA: 17s - loss: 0.1439 - binary_accuracy: 0.9219WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.

291/291 [==============================] - 38s 132ms/step - loss: 0.0786 - binary_accuracy: 0.9706 - val_loss: 0.0631 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0553 - binary_accuracy: 0.9790 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0442 - binary_accuracy: 0.9829 - val_loss: 0.0532 - val_binary_accuracy: 0.9819
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0369 - binary_accuracy: 0.9858 - val_loss: 0.0460 - val_binary_accuracy: 0.9832
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0335 - binary_accuracy: 0.9870 - val_loss: 0.0561 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0253 - binary_accuracy: 0.9910 - val_loss: 0.0559 - val_binary_accuracy: 0.9819
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0232 - binary_accuracy: 0.9920 - val_loss: 0.0432 - val_binary_accuracy: 0.9845
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0185 - binary_accuracy: 0.9930 - val_loss: 0.0396 - val_binary_accuracy: 0.9854
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0147 - binary_accuracy: 0.9948 - val_loss: 0.0439 - val_binary_accuracy: 0.9832
Epoch 10/10
291/291 [==============================] - 37s 129ms/step - loss: 0.0117 - binary_accuracy: 0.9954 - val_loss: 0.0538 - val_binary_accuracy: 0.9819

<tensorflow.python.keras.callbacks.History at 0x7f611c26e438>

Setelah 10 epoch, fine-tuning memberi kita peningkatan yang bagus di sini.