Transfer pembelajaran dengan TensorFlow Hub

Lihat di TensorFlow.org Jalankan di Google Colab Lihat di GitHub Unduh buku catatan Lihat model TF Hub

TensorFlow Hub adalah repositori model TensorFlow yang telah dilatih sebelumnya.

Tutorial ini menunjukkan cara:

  1. Gunakan model dari TensorFlow Hub dengan tf.keras .
  2. Gunakan model klasifikasi gambar dari TensorFlow Hub.
  3. Lakukan transfer learning sederhana untuk menyempurnakan model untuk kelas gambar Anda sendiri.

Mempersiapkan

import numpy as np
import time

import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub

import datetime

%load_ext tensorboard

Pengklasifikasi ImageNet

Anda akan mulai dengan menggunakan model pengklasifikasi yang telah dilatih sebelumnya pada kumpulan data benchmark ImageNet —tidak diperlukan pelatihan awal!

Unduh pengklasifikasi

Pilih model terlatih MobileNetV2 dari TensorFlow Hub dan bungkus sebagai lapisan Keras dengan hub.KerasLayer . Semua model pengklasifikasi gambar yang kompatibel dari TensorFlow Hub akan berfungsi di sini, termasuk contoh yang diberikan dalam tarik-turun di bawah.

mobilenet_v2 ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
inception_v3 = "https://tfhub.dev/google/imagenet/inception_v3/classification/5"

classifier_model = mobilenet_v2
IMAGE_SHAPE = (224, 224)

classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_model, input_shape=IMAGE_SHAPE+(3,))
])

Jalankan pada satu gambar

Unduh satu gambar untuk mencoba model di:

grace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')
grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)
grace_hopper
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg
65536/61306 [================================] - 0s 0us/step
73728/61306 [====================================] - 0s 0us/step

png

grace_hopper = np.array(grace_hopper)/255.0
grace_hopper.shape
(224, 224, 3)

Tambahkan dimensi batch (dengan np.newaxis ) dan teruskan gambar ke model:

result = classifier.predict(grace_hopper[np.newaxis, ...])
result.shape
(1, 1001)

Hasilnya adalah vektor logit 1001 elemen, menilai probabilitas setiap kelas untuk gambar.

ID kelas atas dapat ditemukan dengan tf.math.argmax :

predicted_class = tf.math.argmax(result[0], axis=-1)
predicted_class
<tf.Tensor: shape=(), dtype=int64, numpy=653>

Decode prediksi

Ambil ID predicted_class (seperti 653 ) dan ambil label kumpulan data ImageNet untuk memecahkan kode prediksi:

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
imagenet_labels = np.array(open(labels_path).read().splitlines())
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
16384/10484 [==============================================] - 0s 0us/step
24576/10484 [======================================================================] - 0s 0us/step
plt.imshow(grace_hopper)
plt.axis('off')
predicted_class_name = imagenet_labels[predicted_class]
_ = plt.title("Prediction: " + predicted_class_name.title())

png

Pembelajaran transfer sederhana

Tetapi bagaimana jika Anda ingin membuat pengklasifikasi khusus menggunakan dataset Anda sendiri yang memiliki kelas yang tidak disertakan dalam dataset ImageNet asli (dimana model pra-pelatihan dilatih)?

Untuk melakukannya, Anda dapat:

  1. Pilih model terlatih dari TensorFlow Hub; dan
  2. Latih kembali lapisan atas (terakhir) untuk mengenali kelas dari kumpulan data khusus Anda.

Himpunan data

Dalam contoh ini, Anda akan menggunakan kumpulan data bunga TensorFlow:

data_root = tf.keras.utils.get_file(
  'flower_photos',
  'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
   untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 7s 0us/step
228827136/228813984 [==============================] - 7s 0us/step

Pertama, muat data ini ke dalam model menggunakan data gambar dari disk dengan tf.keras.utils.image_dataset_from_directory , yang akan menghasilkan tf.data.Dataset :

batch_size = 32
img_height = 224
img_width = 224

train_ds = tf.keras.utils.image_dataset_from_directory(
  str(data_root),
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size
)

val_ds = tf.keras.utils.image_dataset_from_directory(
  str(data_root),
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size
)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.

Dataset bunga memiliki lima kelas:

class_names = np.array(train_ds.class_names)
print(class_names)
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']

Kedua, karena konvensi TensorFlow Hub untuk model gambar mengharapkan input float dalam kisaran [0, 1] , gunakan lapisan prapemrosesan tf.keras.layers.Rescaling untuk mencapai ini.

normalization_layer = tf.keras.layers.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.

Ketiga, selesaikan jalur input dengan menggunakan buffered prefetching dengan Dataset.prefetch , sehingga Anda dapat menghasilkan data dari disk tanpa masalah pemblokiran I/O.

Ini adalah beberapa metode tf.data terpenting yang harus Anda gunakan saat memuat data. Pembaca yang tertarik dapat mempelajari lebih lanjut tentang mereka, serta cara menyimpan data ke disk dan teknik lainnya, dalam kinerja yang lebih baik dengan panduan API tf.data .

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
(32, 224, 224, 3)
(32,)
2022-01-26 05:06:19.465331: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Jalankan classifier pada sekumpulan gambar

Sekarang, jalankan classifier pada kumpulan gambar:

result_batch = classifier.predict(train_ds)
predicted_class_names = imagenet_labels[tf.math.argmax(result_batch, axis=-1)]
predicted_class_names
array(['daisy', 'coral fungus', 'rapeseed', ..., 'daisy', 'daisy',
       'birdhouse'], dtype='<U30')

Periksa bagaimana prediksi ini sejalan dengan gambar:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_class_names[n])
  plt.axis('off')
_ = plt.suptitle("ImageNet predictions")

png

Hasilnya jauh dari sempurna, tetapi masuk akal mengingat ini bukan kelas yang dilatih modelnya (kecuali untuk "daisy").

Unduh model tanpa kepala

TensorFlow Hub juga mendistribusikan model tanpa lapisan klasifikasi teratas. Ini dapat digunakan untuk melakukan transfer learning dengan mudah.

Pilih model terlatih MobileNetV2 dari TensorFlow Hub . Semua model vektor fitur gambar yang kompatibel dari TensorFlow Hub akan berfungsi di sini, termasuk contoh dari menu tarik-turun.

mobilenet_v2 = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
inception_v3 = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"

feature_extractor_model = mobilenet_v2

Buat ekstraktor fitur dengan membungkus model yang telah dilatih sebelumnya sebagai lapisan Keras dengan hub.KerasLayer . Gunakan argumen trainable=False untuk membekukan variabel, sehingga pelatihan hanya memodifikasi lapisan classifier baru:

feature_extractor_layer = hub.KerasLayer(
    feature_extractor_model,
    input_shape=(224, 224, 3),
    trainable=False)

Ekstraktor fitur mengembalikan vektor sepanjang 1280 untuk setiap gambar (ukuran kumpulan gambar tetap pada 32 dalam contoh ini):

feature_batch = feature_extractor_layer(image_batch)
print(feature_batch.shape)
(32, 1280)

Lampirkan kepala klasifikasi

Untuk melengkapi model, bungkus lapisan ekstraktor fitur dalam model tf.keras.Sequential dan tambahkan lapisan yang terhubung penuh untuk klasifikasi:

num_classes = len(class_names)

model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(num_classes)
])

model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer_1 (KerasLayer)  (None, 1280)              2257984   
                                                                 
 dense (Dense)               (None, 5)                 6405      
                                                                 
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
predictions = model(image_batch)
predictions.shape
TensorShape([32, 5])

Latih modelnya

Gunakan Model.compile untuk mengonfigurasi proses pelatihan dan menambahkan panggilan balik tf.keras.callbacks.TensorBoard untuk membuat dan menyimpan log:

model.compile(
  optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['acc'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1) # Enable histogram computation for every epoch.

Sekarang gunakan metode Model.fit untuk melatih model.

Untuk mempersingkat contoh ini, Anda hanya akan berlatih selama 10 epoch. Untuk memvisualisasikan kemajuan pelatihan di TensorBoard nanti, buat dan simpan log panggilan balik TensorBoard .

NUM_EPOCHS = 10

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=NUM_EPOCHS,
                    callbacks=tensorboard_callback)
Epoch 1/10
92/92 [==============================] - 7s 42ms/step - loss: 0.7904 - acc: 0.7210 - val_loss: 0.4592 - val_acc: 0.8515
Epoch 2/10
92/92 [==============================] - 3s 33ms/step - loss: 0.3850 - acc: 0.8713 - val_loss: 0.3694 - val_acc: 0.8787
Epoch 3/10
92/92 [==============================] - 3s 33ms/step - loss: 0.3027 - acc: 0.9057 - val_loss: 0.3367 - val_acc: 0.8856
Epoch 4/10
92/92 [==============================] - 3s 33ms/step - loss: 0.2524 - acc: 0.9237 - val_loss: 0.3210 - val_acc: 0.8869
Epoch 5/10
92/92 [==============================] - 3s 33ms/step - loss: 0.2164 - acc: 0.9373 - val_loss: 0.3124 - val_acc: 0.8896
Epoch 6/10
92/92 [==============================] - 3s 33ms/step - loss: 0.1888 - acc: 0.9469 - val_loss: 0.3070 - val_acc: 0.8937
Epoch 7/10
92/92 [==============================] - 3s 33ms/step - loss: 0.1668 - acc: 0.9550 - val_loss: 0.3032 - val_acc: 0.9005
Epoch 8/10
92/92 [==============================] - 3s 33ms/step - loss: 0.1487 - acc: 0.9619 - val_loss: 0.3004 - val_acc: 0.9005
Epoch 9/10
92/92 [==============================] - 3s 33ms/step - loss: 0.1335 - acc: 0.9687 - val_loss: 0.2981 - val_acc: 0.9019
Epoch 10/10
92/92 [==============================] - 3s 33ms/step - loss: 0.1206 - acc: 0.9748 - val_loss: 0.2964 - val_acc: 0.9046

Mulai TensorBoard untuk melihat bagaimana metrik berubah dengan setiap zaman dan untuk melacak nilai skalar lainnya:

%tensorboard --logdir logs/fit

Cek prediksinya

Dapatkan daftar nama kelas yang diurutkan dari prediksi model:

predicted_batch = model.predict(image_batch)
predicted_id = tf.math.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]
print(predicted_label_batch)
['roses' 'dandelion' 'tulips' 'sunflowers' 'dandelion' 'roses' 'dandelion'
 'roses' 'tulips' 'dandelion' 'tulips' 'tulips' 'sunflowers' 'tulips'
 'dandelion' 'roses' 'daisy' 'tulips' 'dandelion' 'dandelion' 'dandelion'
 'tulips' 'sunflowers' 'roses' 'sunflowers' 'dandelion' 'tulips' 'roses'
 'roses' 'sunflowers' 'tulips' 'sunflowers']

Plot prediksi model:

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)

for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")

png

Ekspor dan muat ulang model Anda

Sekarang setelah Anda melatih modelnya, ekspor sebagai SavedModel untuk digunakan kembali nanti.

t = time.time()

export_path = "/tmp/saved_models/{}".format(int(t))
model.save(export_path)

export_path
2022-01-26 05:07:03.429901: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/saved_models/1643173621/assets
INFO:tensorflow:Assets written to: /tmp/saved_models/1643173621/assets
'/tmp/saved_models/1643173621'

Konfirmasikan bahwa Anda dapat memuat ulang SavedModel dan model dapat menampilkan hasil yang sama:

reloaded = tf.keras.models.load_model(export_path)
result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)
abs(reloaded_result_batch - result_batch).max()
0.0
reloaded_predicted_id = tf.math.argmax(reloaded_result_batch, axis=-1)
reloaded_predicted_label_batch = class_names[reloaded_predicted_id]
print(reloaded_predicted_label_batch)
['roses' 'dandelion' 'tulips' 'sunflowers' 'dandelion' 'roses' 'dandelion'
 'roses' 'tulips' 'dandelion' 'tulips' 'tulips' 'sunflowers' 'tulips'
 'dandelion' 'roses' 'daisy' 'tulips' 'dandelion' 'dandelion' 'dandelion'
 'tulips' 'sunflowers' 'roses' 'sunflowers' 'dandelion' 'tulips' 'roses'
 'roses' 'sunflowers' 'tulips' 'sunflowers']
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
  plt.subplot(6,5,n+1)
  plt.imshow(image_batch[n])
  plt.title(reloaded_predicted_label_batch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")

png

Langkah selanjutnya

Anda dapat menggunakan SavedModel untuk memuat inferensi atau mengonversinya menjadi model TensorFlow Lite (untuk pembelajaran mesin di perangkat) atau model TensorFlow.js (untuk pembelajaran mesin dalam JavaScript).

Temukan lebih banyak tutorial untuk mempelajari cara menggunakan model terlatih dari TensorFlow Hub pada tugas gambar, teks, audio, dan video.