TensorFlow Lite Model Maker による画像分類

TensorFlow.org で表示 Google Colabで実行 GitHubでソースを表示 ノートブックをダウンロード TF Hub モデルを見る

TensorFlow Lite Model Maker ライブラリは、TensorFlow ニューラルネットワークモデルを適合し、オンデバイス ML アプリケーションにこのモデルをデプロイする際の特定の入力データに変換するプロセスを単純化します。

このノートブックでは、この Model Maker を使用したエンドツーエンドの例を示し、モバイルデバイスで花を分類するために一般的に使用される画像分類モデルの適合と変換を説明します。

前提条件

この例を実行するにはまず、GitHub リポジトリ にある Model Maker パッケージなど、いくつかの必要なパッケージをインストールしてください。

pip install -q tflite-model-maker

必要なパッケージをインポートします。

import os

import numpy as np

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

import matplotlib.pyplot as plt

簡単なエンドツーエンドの例

データパスの取得

この簡単なエンドツーエンドの例に使用する画像を取得しましょう。データ数が多いほどより高い精度を得ることができますが、Model Maker を使用し始めるには、数百枚の画像があれば十分です。

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 45s 0us/step
228827136/228813984 [==============================] - 45s 0us/step

上記の image_path を自分の画像フォルダに置き換えてください。Colab にデータをアップロードする場合は、下の画像の赤く囲まれたアップロードボタンを使用できます。Zip ファイルをアップロードし、解凍してみてください。ルートファイルパスは現在のパスです。

Upload File

画像をクラウドにアップロードしない場合は、GitHub のガイドに従って、ローカルでライブラリを実行できます。

例の実行

以下に示される通り、例は 4 行で構成されています。各行は、プロセス全体の 1 ステップを表します。

ステップ 1. オンデバイス ML アプリに固有の入力データを読み込み、トレーニングデータとテストデータに分割します。

data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

ステップ 2. TensorFlow モデルをカスタマイズします。

model = image_classifier.create(train_data)
INFO:tensorflow:Retraining the models...
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 hub_keras_layer_v1v2 (HubKe  (None, 1280)             3413024   
 rasLayerV1V2)                                                   
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 5)                 6405      
                                                                 
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
103/103 [==============================] - 6s 28ms/step - loss: 0.8594 - accuracy: 0.7752
Epoch 2/5
103/103 [==============================] - 3s 29ms/step - loss: 0.6545 - accuracy: 0.8941
Epoch 3/5
103/103 [==============================] - 3s 28ms/step - loss: 0.6165 - accuracy: 0.9202
Epoch 4/5
103/103 [==============================] - 3s 28ms/step - loss: 0.5948 - accuracy: 0.9287
Epoch 5/5
103/103 [==============================] - 3s 28ms/step - loss: 0.5873 - accuracy: 0.9329

ステップ 3. モデルを評価します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 37ms/step - loss: 0.6266 - accuracy: 0.9074

ステップ 4. TensorFlow Lite モデルをエクスポートします。

ここでは、モデル記述の標準を提供するメタデータで TensorFlow Lite モデルをエクスポートします。ラベルファイルはメタデータに埋め込まれます。デフォルトのポストトレーニング量子化手法は、画像分類タスクの完全整数量子化です。

アップロードと同じように、左サイドバーでダウンロードできます。

model.export(export_dir='.')
2022-08-04 21:07:38.861701: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpjvt16e7h/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpjvt16e7h/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:746: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn("Statistics for quantized inputs were expected, but not "
2022-08-04 21:07:46.009436: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-08-04 21:07:46.009511: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmpfs/tmp/tmpi9va8ih6/labels.txt
INFO:tensorflow:Saving labels in /tmpfs/tmp/tmpi9va8ih6/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite

上記の簡単な手順を実行したら、画像分類の参照アプリのようなオンデバイスアプリケーションで、TensorFlow Lite モデルファイルとラベルファイルを使用できるようになります。

詳細なプロセス

現在のところ、EfficientNet-Lite* モデル、MobileNetV2、ResNet50 などの複数のモデルが画像分類用に事前トレーニングされたモデルとしてサポートされています。ただし、非常に柔軟性に優れているため、わずか数行のコードで、新しいトレーニング済みのモデルをこのライブラリに追加することができます。

次のウォークスルーでは、このエンドツーエンドの例の詳細を手順を追って説明します。

ステップ 1: オンデバイス ML アプリ固有の入力データを読み込む

flower データセットには、5 つのクラスに属する 3670 個の画像が含まれます。データセットのアーカイブバージョンをダウンロードして解凍してください。

データセットには次のディレクトリ構造があります。

<b>flower_photos</b>
|__ <b>daisy</b>
    |______ 100080576_f52e8ee070_n.jpg
    |______ 14167534527_781ceb1b7a_n.jpg
    |______ ...
|__ <b>dandelion</b>
    |______ 10043234166_e6dd915111_n.jpg
    |______ 1426682852_e62169221f_m.jpg
    |______ ...
|__ <b>roses</b>
    |______ 102501987_3cdb8e5394_n.jpg
    |______ 14982802401_a3dfb22afb.jpg
    |______ ...
|__ <b>sunflowers</b>
    |______ 12471791574_bb1be83df4.jpg
    |______ 15122112402_cafa41934f.jpg
    |______ ...
|__ <b>tulips</b>
    |______ 13976522214_ccec508fe7.jpg
    |______ 14487943607_651e8062a1_m.jpg
    |______ ...
image_path = tf.keras.utils.get_file(
      'flower_photos.tgz',
      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

DataLoader クラスを使用して、データを読み込みます。

from_folder() メソッドについては、フォルダからデータを読み込むことができます。同じクラスの画像データは同じサブディレクトリに存在し、サブフォルダ名はクラス名であることを前提とします。現在のところ、JPEG エンコード画像と PNG エンコード画像がサポートされています。

data = DataLoader.from_folder(image_path)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

これをトレーニングデータ(80%)、検証データ(10% - オプション)、およびテストデータ(10%)に分割します。

train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

ラベル付きの 25 個の画像サンプルを表示します。

plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)
  plt.xlabel(data.index_to_label[label.numpy()])
plt.show()

png

ステップ 2: TensorFlow モデルをカスタマイズする

読み込んだデータをもとに、カスタム画像分類器モデルを作成します。デフォルトのモデルは EfficientNet-Lite0 です。

model = image_classifier.create(train_data, validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 hub_keras_layer_v1v2_1 (Hub  (None, 1280)             3413024   
 KerasLayerV1V2)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 1280)              0         
                                                                 
 dense_1 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
91/91 [==============================] - 6s 42ms/step - loss: 0.9027 - accuracy: 0.7541 - val_loss: 0.6885 - val_accuracy: 0.8883
Epoch 2/5
91/91 [==============================] - 3s 36ms/step - loss: 0.6651 - accuracy: 0.8887 - val_loss: 0.6585 - val_accuracy: 0.8965
Epoch 3/5
91/91 [==============================] - 3s 36ms/step - loss: 0.6244 - accuracy: 0.9114 - val_loss: 0.6430 - val_accuracy: 0.8856
Epoch 4/5
91/91 [==============================] - 3s 36ms/step - loss: 0.6022 - accuracy: 0.9289 - val_loss: 0.6297 - val_accuracy: 0.8856
Epoch 5/5
91/91 [==============================] - 3s 36ms/step - loss: 0.5861 - accuracy: 0.9327 - val_loss: 0.6235 - val_accuracy: 0.8965

モデル構造を詳しく確認します。

model.summary()
Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 hub_keras_layer_v1v2_1 (Hub  (None, 1280)             3413024   
 KerasLayerV1V2)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 1280)              0         
                                                                 
 dense_1 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________

ステップ 3: カスタマイズ済みのモデルを評価する

モデルの結果を評価し、モデルの損失と精度を取得します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6209 - accuracy: 0.9101

100 個のテスト画像で予測結果を描画できます。赤色の予測ラベルは誤った予測結果を表し、ほかは正しい結果を表します。

# A helper function that returns 'red'/'black' depending on if its two input
# parameter matches or not.
def get_label_color(val1, val2):
  if val1 == val2:
    return 'black'
  else:
    return 'red'

# Then plot 100 test images and their predicted labels.
# If a prediction result is different from the label provided label in "test"
# dataset, we will highlight it in red color.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):
  ax = plt.subplot(10, 10, i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)

  predict_label = predicts[i][0][0]
  color = get_label_color(predict_label,
                          test_data.index_to_label[label.numpy()])
  ax.xaxis.label.set_color(color)
  plt.xlabel('Predicted: %s' % predict_label)
plt.show()

png

精度がアプリの要件を満たさない場合は、高度な使用を参照し、より大規模なモデルに変更したり、再トレーニングパラメータを調整したりといった別の手段を調べてください。

ステップ 4: TensorFlow Lite モデルにエクスポートする

トレーニングされたモデルをメタデータで TensorFlow Lite モデル形式に変換し、後でオンデバイス ML アプリケーションで使用できるようにします。ラベルファイルと語彙ファイルはメタデータに埋め込まれています。デフォルトの TFLite ファイル名は model.tflite です。

多くのオンデバイス ML アプリケーションでは、モデルサイズが重要な要因です。そのため、モデルの量子化を適用して小さくし、実行速度を高められるようにすることをお勧めします。デフォルトのポストトレーニング量子化手法は、画像分類タスクの完全整数量子化です。

model.export(export_dir='.')
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpb3z_yy8_/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpb3z_yy8_/assets
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/convert.py:746: UserWarning: Statistics for quantized inputs were expected, but not specified; continuing anyway.
  warnings.warn("Statistics for quantized inputs were expected, but not "
2022-08-04 21:09:13.647158: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-08-04 21:09:13.647229: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
fully_quantize: 0, inference_type: 6, input_inference_type: 3, output_inference_type: 3
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmpfs/tmp/tmpyb2hz81i/labels.txt
INFO:tensorflow:Saving labels in /tmpfs/tmp/tmpyb2hz81i/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite

TensorFlow Lite モデルをモバイルアプリに統合する方法については、画像分類のサンプルアプリケーションとガイドをご覧ください。

このモデルは、ImageClassifier APITensorFlow Lite Task ライブラリ)を使って Android または iOS アプリに統合することができます。

次のいずれかのエクスポートフォーマットを使用できます。

デフォルトでは、メタデータとともに TensorFlow Lite モデルをエクスポートするだけです。さまざまなファイルを選択的にエクスポートすることも可能です。たとえば、ラベルファイルのみをエクスポートする場合は、次のように行います。

model.export(export_dir='.', export_format=ExportFormat.LABEL)
INFO:tensorflow:Saving labels in ./labels.txt
INFO:tensorflow:Saving labels in ./labels.txt

また、evaluate_tflite メソッドを使って tflite を評価することもできます。

model.evaluate_tflite('model.tflite', test_data)
{'accuracy': 0.9073569482288828}

高度な使用

このライブラリでは、create 関数が非常に重要な役割を果たします。この関数は、チュートリアルと同様に、トレーニング済みのモデルで転移学習を使います。

create 関数には、次のステップが含まれます。

  1. パラメータ validation_ratiotest_ratio に基づき、データをトレーニング、検証、テストのデータに分割します。validation_ratiotest_ratio のデフォルト値は、0.10.1 です。
  2. ベースモデルとして、TensorFlow Hub から Image Feature Vector をダウンロードします。デフォルトのトレーニング済みモデルは EfficientNet-Lite0 です。
  3. ヘッドレイヤーとトレーニング済みモデルの間に、dropout_rate を使用して、ドロップアウトレイヤー付きの分類器ヘッドを追加します。デフォルトのdropout_rate は TensorFlow Hub の make_image_classifier_lib のデフォルトの dropout_rate 値です。
  4. 生の入力データを前処理します。現在のところ、前処理ステップには、各画像ピクセルの値をモデルの入力スケールに正規化し、モデルの入力サイズにサイズ変更することが含まれます。EfficientNet-Lite0 の入力スケールは [0, 1]、入力画像サイズは [224, 224, 3] です。
  5. データを分類器モデルにフィードします。デフォルトでは、トレーニングエポック、バッチサイズ、学習率、運動量などのトレーニングパラメータは、TensorFlow Hub のmake_image_classifier_lib のデフォルト値です。分類器ヘッドのみがトレーニングされています。

このセクションでは、異なる画像分類モデルへの切り替えやトレーニングハイパーパラメータの変更など、いくつかの高度なトピックを説明します。

TensorFlow Lite モデルでポストトレーニング量子化をカスタマイズする

ポストトレーニング量子化は、モデルサイズと推論レイテンシを縮小できる変換テクニックです。このテクニックでは、モデル精度にほとんど影響することなく、CPU とハードウェアアクセラレータの推論速度も改善することができます。したがって、モデルを改善するために広く使われています。

Model Maker ライブラリは、モデルをエクスポートする際に、デフォルトのポストトレーニング量子化手法を適用します。ポストトレーニング量子化をカスタマイズするのであれば、Model Maker は、QuantizationConfig を使った複数のポストトレーニング量子化オプションもサポートしています。例として、float16 量子化を見てみましょう。まず、量子化構成を定義します。

config = QuantizationConfig.for_float16()

次に、その構成で TensorFlow Lite モデルをエクスポートします。

model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpcvg1psqi/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmpcvg1psqi/assets
INFO:tensorflow:Label file is inside the TFLite model with metadata.
2022-08-04 21:16:52.236213: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:357] Ignored output_format.
2022-08-04 21:16:52.236277: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:360] Ignored drop_control_dependency.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmpfs/tmp/tmplkfljo0n/labels.txt
INFO:tensorflow:Saving labels in /tmpfs/tmp/tmplkfljo0n/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite

Colab では、前述のアップロード手順と同様に、左サイドバーから model_fp16.tflite というモデルをダウンロードできます。

モデルを変更する

このライブラリでサポートされているモデルに変更する

このライブラリは、EfficientNet-Lite モデル、MobileNetV2、ResNet50 をサポートします。EfficientNet-Lite は、最新の精度を達成し、エッジデバイスに適切した一群の画像分類モデルです。デフォルトのモデルは EfficientNet-Lite0 です。

このモデルを、create メソッドのパラメータ model_spec を MobileNet_v2_spec に設定することで、MobileNetV2 に切り替えることができます。

model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 hub_keras_layer_v1v2_2 (Hub  (None, 1280)             2257984   
 KerasLayerV1V2)                                                 
                                                                 
 dropout_2 (Dropout)         (None, 1280)              0         
                                                                 
 dense_2 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
None
Epoch 1/5
91/91 [==============================] - 7s 40ms/step - loss: 0.9210 - accuracy: 0.7545 - val_loss: 0.7319 - val_accuracy: 0.8610
Epoch 2/5
91/91 [==============================] - 3s 33ms/step - loss: 0.6888 - accuracy: 0.8788 - val_loss: 0.7115 - val_accuracy: 0.8610
Epoch 3/5
91/91 [==============================] - 3s 33ms/step - loss: 0.6501 - accuracy: 0.9018 - val_loss: 0.7060 - val_accuracy: 0.8610
Epoch 4/5
91/91 [==============================] - 3s 33ms/step - loss: 0.6225 - accuracy: 0.9217 - val_loss: 0.6986 - val_accuracy: 0.8638
Epoch 5/5
91/91 [==============================] - 3s 34ms/step - loss: 0.6071 - accuracy: 0.9272 - val_loss: 0.6885 - val_accuracy: 0.8856

新たにトレーニングした MobileNetV2 モデルを評価し、テストデータで精度と損失を確認します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 24ms/step - loss: 0.6786 - accuracy: 0.8965

TensorFlow Hub にモデルに変更する

さらに、画像を入力し、TensorFlow Hub 形式の特徴ベクトルを出力する他の新しいモデルに切り替えることもできます。

Inception V3 モデルを例とすると、inception_v3_spec を定義することができます。これには、image_classifier.ModelSpec のオブジェクトであり、Inception V3 モデルの仕様が含まれます。

モデル名 name、TensorFlow Hub モデルの URL uri を指定する必要があります。その間、input_image_shape のデフォルト値は [224, 224] です。これを Inception V3 モデルの [299, 299] に変更する必要があります。

inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]

次に、create メソッドでパラメータ model_specinception_v3_spec に設定することで、Inception V3 モデルを再トレーニングすることができます。

残りのステップはまったく同じで、最終的にカスタマイズされた InceptionV3 TensorFlow Lite モデルを得ることができます。

独自のカスタムモデルを変更する

TensorFlow Hub にないカスタムモデルを使用する場合は、ModelSpec を作成して TensorFlow Hub にエクスポートする必要があります。

次に、上記のプロセスのように ImageModelSpec オブジェクトを定義し始めます。

トレーニングハイパーパラメータを変更する

また、モデルの精度に影響する epochsdropout_rate、および batch_size などのトレーニングハイパーパラメータも変更できます。以下は、調整できるモデルパラメータです。

  • epochs: エポックを増やすと、収束するまでより優れた精度を達成できますが、エポック数が多すぎると、トレーニングは過適合となる可能性があります。
  • dropout_rate: ドロップアウト率。過適合を回避します。
  • batch_size: 1 つのトレーニングステップに使用するサンプル数。デフォルトは None。
  • validation_data: 検証データ。None の場合は、検証をスキップします。デフォルトは None。
  • train_whole_model: true の場合、Hub モジュールは上の分類レイヤーとともにトレーニングされます。そうでない場合は、上の分類レイヤーのみがトレーニングされます。デフォルトは None です。
  • learning_rate: 基本学習率。デフォルトは None です。
  • momentum: オプティマイザに転送される Python float。use_hub_library が True の場合にのみ使用されます。デフォルトは None です。
  • shuffle: データをシャッフルするかどうかを決めるブール型。デフォルトは False です。
  • use_augmentation: 前処理にデータ拡張を行うかを決めるブール型。デフォルトは False です。
  • use_hub_library: モデルの再トレーニングに TensorFlow Hub の make_image_classifier_lib を使用するかを決めるブール型。このトレーニングパイプラインは、多数のカテゴリを持つ複雑なデータセットのパフォーマンスを改善する可能性があります。デフォルトは True です。
  • warmup_steps: 学習率に関するウォームアップスケジュールのウォームアップステップ数。None の場合、2 エポックの傍系トレーニングステップ数であるデフォルトの warmup_steps が使用されます。use_hub_library が False の場合にのみ使用されます。デフォルトは None です。
  • model_dir: オプション。モデルチェックポイントファイルの場所です。use_hub_library がFalse の場合にのみ使用されます。デフォルトは None です。

epochs など、デフォルトが None であるパラメータは、TensorFlow Hub library の make_image_classifier_lib または train_image_classifier_lib にある具体的なデフォルトパラメータを取得します。

たとえば、エポック数を増やしてトレーニングすることができます。

model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 hub_keras_layer_v1v2_3 (Hub  (None, 1280)             3413024   
 KerasLayerV1V2)                                                 
                                                                 
 dropout_3 (Dropout)         (None, 1280)              0         
                                                                 
 dense_3 (Dense)             (None, 5)                 6405      
                                                                 
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/10
91/91 [==============================] - 6s 42ms/step - loss: 0.8802 - accuracy: 0.7641 - val_loss: 0.6785 - val_accuracy: 0.8801
Epoch 2/10
91/91 [==============================] - 3s 36ms/step - loss: 0.6549 - accuracy: 0.8990 - val_loss: 0.6455 - val_accuracy: 0.8965
Epoch 3/10
91/91 [==============================] - 3s 36ms/step - loss: 0.6209 - accuracy: 0.9159 - val_loss: 0.6332 - val_accuracy: 0.8992
Epoch 4/10
91/91 [==============================] - 3s 36ms/step - loss: 0.6041 - accuracy: 0.9248 - val_loss: 0.6256 - val_accuracy: 0.8992
Epoch 5/10
91/91 [==============================] - 3s 36ms/step - loss: 0.5914 - accuracy: 0.9310 - val_loss: 0.6203 - val_accuracy: 0.9074
Epoch 6/10
91/91 [==============================] - 3s 36ms/step - loss: 0.5739 - accuracy: 0.9451 - val_loss: 0.6172 - val_accuracy: 0.9128
Epoch 7/10
91/91 [==============================] - 3s 36ms/step - loss: 0.5716 - accuracy: 0.9444 - val_loss: 0.6163 - val_accuracy: 0.9046
Epoch 8/10
91/91 [==============================] - 3s 36ms/step - loss: 0.5661 - accuracy: 0.9461 - val_loss: 0.6133 - val_accuracy: 0.9128
Epoch 9/10
91/91 [==============================] - 3s 36ms/step - loss: 0.5578 - accuracy: 0.9495 - val_loss: 0.6141 - val_accuracy: 0.9074
Epoch 10/10
91/91 [==============================] - 3s 36ms/step - loss: 0.5488 - accuracy: 0.9581 - val_loss: 0.6076 - val_accuracy: 0.9210

新たに再トレーニングされたモデルを 10 個のトレーニングエポックで評価します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 27ms/step - loss: 0.6170 - accuracy: 0.9019

その他の資料

技術的な詳細については、画像分類の例をご覧ください。詳細については、以下をご覧ください。