このページは Cloud Translation API によって翻訳されました。
Switch to English

画像分類子の再トレーニング

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示 ノートブックをダウンロード

前書き

画像分類モデルには数百万のパラメータがあります。最初からそれらをトレーニングするには、ラベル付きのトレーニングデータと大量のコンピューティング能力が必要です。転移学習は、関連するタスクですでにトレーニングされたモデルの一部を取り、それを新しいモデルで再利用することにより、この多くを短縮する手法です。

このColabは、TensorFlow Hubの事前トレーニングされたTF2 SavedModelを使用して、画像の特徴を抽出し、はるかに大規模でより一般的なImageNetデータセットでトレーニングすることにより、5種類の花を分類するKerasモデルを構築する方法を示します。オプションで、新しく追加された分類子とともに、特徴抽出子をトレーニング(「微調整」)できます。

代わりにツールをお探しですか?

これはTensorFlowコーディングチュートリアルです。 TensorFlowまたはTF Liteモデルを構築するだけのツールが必要な場合は、PIPパッケージtensorflow-hub[make_image_classifier]によってインストールされるmake_image_classifierコマンドラインツール、またはこの TF Lite colabをご覧ください。

セットアップ

 import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.test.is_gpu_available() else "NOT AVAILABLE")
 
TF version: 2.2.0
Hub version: 0.8.0
WARNING:tensorflow:From <ipython-input-2-0831fa394ed3>:12: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.
GPU is available

使用するTF2 SavedModelモジュールを選択します

まず第一に、使用のhttps:// tfhub.dev / / imagenet / mobilenet_v2_100_224 / feature_vector / 4グーグル。同じURLをコードで使用してSavedModelを識別したり、ブラウザーでそのURLを使用してドキュメントを表示したりできます。 (TF1ハブ形式のモデルはここでは機能しないことに注意してください。)

 module_selection = ("mobilenet_v2_100_224", 224) 
handle_base, pixels = module_selection
MODULE_HANDLE ="https://  tfhub.dev  /google/imagenet/{}/feature_vector/4".format(handle_base)
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

BATCH_SIZE = 32 
 
Using https://  tfhub.dev  /google/imagenet/mobilenet_v2_100_224/feature_vector/4 with input size (224, 224)

花のデータセットを設定する

入力は、選択したモジュールに合わせて適切にサイズ変更されます。データセットの拡張(つまり、読み取られるたびに画像がランダムに歪む)により、トレーニングが向上します。微調整するとき。

 data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
 
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 4s 0us/step

 datagen_kwargs = dict(rescale=1./255, validation_split=.20)
dataflow_kwargs = dict(target_size=IMAGE_SIZE, batch_size=BATCH_SIZE,
                   interpolation="bilinear")

valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    **datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
    data_dir, subset="validation", shuffle=False, **dataflow_kwargs)

do_data_augmentation = False 
if do_data_augmentation:
  train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=40,
      horizontal_flip=True,
      width_shift_range=0.2, height_shift_range=0.2,
      shear_range=0.2, zoom_range=0.2,
      **datagen_kwargs)
else:
  train_datagen = valid_datagen
train_generator = train_datagen.flow_from_directory(
    data_dir, subset="training", shuffle=True, **dataflow_kwargs)
 
Found 731 images belonging to 5 classes.
Found 2939 images belonging to 5 classes.

モデルの定義

必要なのは、ハブモジュールを使用してfeature_extractor_layer上に線形分類子を配置することだけです。

速度を上げるために、最初はトレーニング不可能なfeature_extractor_layerから始めますが、精度を高めるために微調整を有効にすることもできます。

 do_fine_tuning = False 
 
 print("Building model with", MODULE_HANDLE)
model = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)),
    hub.KerasLayer(MODULE_HANDLE, trainable=do_fine_tuning),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None,)+IMAGE_SIZE+(3,))
model.summary()
 
Building model with https://  tfhub.dev  /google/imagenet/mobilenet_v2_100_224/feature_vector/4
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
keras_layer (KerasLayer)     (None, 1280)              2257984   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________

モデルのトレーニング

 model.compile(
  optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9), 
  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])
 
 steps_per_epoch = train_generator.samples // train_generator.batch_size
validation_steps = valid_generator.samples // valid_generator.batch_size
hist = model.fit(
    train_generator,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=validation_steps).history
 
Epoch 1/5
91/91 [==============================] - 15s 165ms/step - loss: 0.9333 - accuracy: 0.7544 - val_loss: 0.7685 - val_accuracy: 0.8239
Epoch 2/5
91/91 [==============================] - 14s 158ms/step - loss: 0.7039 - accuracy: 0.8720 - val_loss: 0.7094 - val_accuracy: 0.8537
Epoch 3/5
91/91 [==============================] - 14s 157ms/step - loss: 0.6443 - accuracy: 0.9023 - val_loss: 0.7079 - val_accuracy: 0.8565
Epoch 4/5
91/91 [==============================] - 14s 158ms/step - loss: 0.6212 - accuracy: 0.9168 - val_loss: 0.7030 - val_accuracy: 0.8537
Epoch 5/5
91/91 [==============================] - 14s 156ms/step - loss: 0.6051 - accuracy: 0.9278 - val_loss: 0.6684 - val_accuracy: 0.8864

 plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0,1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
 
[<matplotlib.lines.Line2D at 0x7f9737e58fd0>]

png

png

最後に、トレーニング済みモデルを保存して、次のようにTF ServingまたはTF Lite(モバイル)にデプロイできます。

 saved_model_path = "/tmp/saved_flowers_model"
tf.saved_model.save(model, saved_model_path)
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

INFO:tensorflow:Assets written to: /tmp/saved_flowers_model/assets

オプション:TensorFlow Liteへのデプロイ

TensorFlow Liteを使用すると、TensorFlowモデルをモバイルデバイスやIoTデバイスにデプロイできます。以下のコードは、トレーニング済みモデルをTF Liteに変換し、 TensorFlowモデル最適化ツールキットからトレーニング後のツールを適用する方法を示しています。最後に、TF Liteインタープリターで実行して、結果の品質を調べます

  • 最適化せずに変換すると、以前と同じ結果が得られます(丸め誤差まで)。
  • データなしで最適化して変換すると、モデルの重みが8ビットに量子化されますが、推論では、ニューラルネットワークのアクティブ化に浮動小数点計算が使用されます。これにより、モデルサイズがほぼ4分の1になり、モバイルデバイスのCPUレイテンシが向上します。
  • さらに、小さな参照データセットが量子化範囲を調整するために提供されている場合は、ニューラルネットワークのアクティブ化の計算を8ビット整数に量子化することもできます。モバイルデバイスでは、これにより推論がさらに加速され、EdgeTPUなどのアクセラレータで実行できるようになります。
 
# TODO(b/156102192)
optimize_lite_model = False  

num_calibration_examples = 60  
representative_dataset = None
if optimize_lite_model and num_calibration_examples:
  # Use a bounded number of training examples without labels for calibration.
  # TFLiteConverter expects a list of input tensors, each with batch size 1.
  representative_dataset = lambda: itertools.islice(
      ([image[None, ...]] for batch, _ in train_generator for image in batch),
      num_calibration_examples)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
if optimize_lite_model:
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  if representative_dataset:  # This is optional, see above.
    converter.representative_dataset = representative_dataset
lite_model_content = converter.convert()

with open("/tmp/lite_flowers_model", "wb") as f:
  f.write(lite_model_content)
print("Wrote %sTFLite model of %d bytes." %
      ("optimized " if optimize_lite_model else "", len(lite_model_content)))
 
 interpreter = tf.lite.Interpreter(model_content=lite_model_content)
# This little helper wraps the TF Lite interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
 
 
num_eval_examples = 50  
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_generator
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count +=1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TF Lite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TF Lite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))