画像分類器を再トレーニングする

TensorFlow.org で表示 Google Colabで実行 GitHub でソースを表示 ノートブックをダウンロード TF Hub モデルを参照

はじめに

画像分類モデルには数百個のパラメータがあります。モデルをゼロからトレーニングするには、ラベル付きの多数のトレーニングデータと膨大なトレーニング性能が必要となります。転移学習とは、関連するタスクでトレーニングされたモデルの一部を取り出して新しいモデルで再利用することで、学習の大部分を省略するテクニックを指します。

この Colab では、より大規模で一般的な ImageNet データセットでトレーニングされた、TensorFlow Hub のトレーニング済み TF2 SavedModel を使用して画像特徴量を抽出することで、5 種類の花を分類する Keras モデルの構築方法を実演します。オプションとして、特徴量抽出器を新たに追加される分類器とともにトレーニング(「ファインチューニング」)することができます。

代替ツールをお探しですか?

これは、TensorFlow のコーディングチュートリアルです。TensorFlow または TF Lite モデルを構築するだけのツールをお探しの方は、PIP パッケージ tensorflow-hub[make_image_classifier] によってインストールされる make_image_classifier コマンドラインツール、またはこちらの TF Lite Colab をご覧ください。

セットアップ

import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")
2024-01-11 19:59:24.086086: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 19:59:24.086129: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 19:59:24.087797: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TF version: 2.15.0
Hub version: 0.15.0
GPU is available

使用する TF2 SavedModel モジュールを選択する

手始めに、https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4 を使用します。同じ URL を、SavedModel を識別するコードに使用できます。またブラウザで使用すれば、そのドキュメントを表示することができます。(ここでは TF1 Hub 形式のモデルは機能しないことに注意してください。)

画像特徴量ベクトルを生成するその他の TF2 モデルは、こちらをご覧ください。

試すことのできるモデルはたくさんあります。下のセルから別のモデルを選択し、ノートブックの指示に従ってください。

Selected model: efficientnetv2-xl-21k : https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2
Input size (512, 512)

Flowers データセットをセットアップする

入力は、選択されたモジュールに合わせてサイズ変更されます。データセットを拡張することで(読み取られるたびに画像をランダムに歪みを加える)、特にファインチューニング時のトレーニングが改善されます。

data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228813984/228813984 [==============================] - 1s 0us/step

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.

モデルを定義する

Hub モジュールを使用して、線形分類器を feature_extractor_layer の上に配置するだけで定義できます。

高速化するため、トレーニング不可能な feature_extractor_layer から始めますが、ファインチューニングを実施して精度を高めることもできます。

do_fine_tuning = False
print("Building model with", model_handle)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(model_handle, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(len(class_names),
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()
Building model with https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 1280)              207615832 
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 5)                 6405      
                                                                 
=================================================================
Total params: 207622237 (792.02 MB)
Trainable params: 6405 (25.02 KB)
Non-trainable params: 207615832 (791.99 MB)
_________________________________________________________________

モデルをトレーニングする

model.compile(
  optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9), 
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])
steps_per_epoch = train_size // BATCH_SIZE
validation_steps = valid_size // BATCH_SIZE
hist = model.fit(
    train_ds,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps).history
Epoch 1/5
  1/183 [..............................] - ETA: 1:52:17 - loss: 2.6761 - accuracy: 0.1875
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1705003243.156430  162420 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
183/183 [==============================] - 323s 2s/step - loss: 0.8854 - accuracy: 0.8921 - val_loss: 0.7000 - val_accuracy: 0.9347
Epoch 2/5
183/183 [==============================] - 290s 2s/step - loss: 0.6304 - accuracy: 0.9541 - val_loss: 0.5608 - val_accuracy: 0.9667
Epoch 3/5
183/183 [==============================] - 286s 2s/step - loss: 0.5825 - accuracy: 0.9603 - val_loss: 0.5395 - val_accuracy: 0.9653
Epoch 4/5
183/183 [==============================] - 286s 2s/step - loss: 0.5705 - accuracy: 0.9661 - val_loss: 0.5450 - val_accuracy: 0.9639
Epoch 5/5
183/183 [==============================] - 285s 2s/step - loss: 0.5294 - accuracy: 0.9753 - val_loss: 0.5530 - val_accuracy: 0.9639
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
[<matplotlib.lines.Line2D at 0x7f214a750d90>]

png

png

検証データの画像でモデルが機能するか試してみましょう。

x, y = next(iter(val_ds))
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + class_names[true_index])
print("Predicted label: " + class_names[predicted_index])

png

1/1 [==============================] - 6s 6s/step
True label: sunflowers
Predicted label: sunflowers

最後に次のようにして、トレーニングされたモデルを、TF Serving または TF Lite(モバイル)用に保存することができます。

saved_model_path = f"/tmp/saved_flowers_model_{model_name}"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/saved_flowers_model_efficientnetv2-xl-21k/assets
INFO:tensorflow:Assets written to: /tmp/saved_flowers_model_efficientnetv2-xl-21k/assets

オプション: TensorFlow Lite にデプロイする

TensorFlow Lite では、TensorFlow モデルをモバイルおよび IoT デバイスにデプロイすることができます。以下のコードには、トレーニングされたモデルを TF Lite に変換して、TensorFlow Model Optimization Toolkit のポストトレーニングツールを適用する方法が示されています。最後に、結果の質を調べるために、変換したモデルを TF Lite Interpreter で実行しています。

  • 最適化せずに変換すると、前と同じ結果が得られます(丸め誤差まで)。
  • データなしで最適化して変換すると、モデルの重みを 8 ビットに量子化しますが、それでもニューラルネットワークアクティベーションの推論では浮動小数点数計算が使用されます。これにより、モデルのサイズが約 4 倍に縮小されるため、モバイルデバイスの CPU レイテンシが改善されます。
  • 最上部の、ニューラルネットワークアクティベーションの計算は、量子化の範囲を調整するために小規模な参照データセットが提供される場合、8 ビット整数に量子化されます。モバイルデバイスでは、これにより推論がさらに高速化されるため、EdgeTPU などのアクセラレータで実行することが可能となります。

Optimization settings

2024-01-11 20:25:57.979361: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2024-01-11 20:25:57.979412: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 879, Total Ops 2180, % non-converted = 40.32 %
 * 879 ARITH ops

- arith.constant:  879 occurrences  (f32: 878, i32: 1)



  (f32: 94)
  (f32: 358)
  (f32: 80)
  (f32: 1)
  (f32: 342)
  (f32: 81)
  (f32: 342)
Wrote TFLite model of 826217852 bytes.
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TFLite Interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
num_eval_examples = 50 
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_ds
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count +=1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TFLite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TFLite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
TFLite model agrees with original model on 50 of 50 examples (100%).
TFLite model is accurate on 50 of 50 examples (100%).