質問があります? TensorFlowフォーラム訪問フォーラムでコミュニティとつながる

TensorFlowLiteモデルメーカーによる画像分類

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロードTFハブモデルを参照してください

TensorFlow Lite Model Makerライブラリは、デバイス上のMLアプリケーションにこのモデルをデプロイするときに、TensorFlowニューラルネットワークモデルを特定の入力データに適合させて変換するプロセスを簡素化します。

このノートブックは、このModel Makerライブラリを利用して、モバイルデバイスで花を分類するために一般的に使用される画像分類モデルの適応と変換を説明するエンドツーエンドの例を示しています。

前提条件

この例を実行するには、最初に、GitHub リポジトリにあるModel Makerパッケージを含む、いくつかの必要なパッケージをインストールする必要があります。

pip install -q tflite-model-maker

必要なパッケージをインポートします。

import os

import numpy as np

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

import matplotlib.pyplot as plt
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/numba/core/errors.py:154: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)

単純なエンドツーエンドの例

データパスを取得する

この単純なエンドツーエンドの例で再生するいくつかの画像を取得しましょう。何百もの画像はModelMakerにとって良いスタートですが、より多くのデータがより良い精度を達成する可能性があります。

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 2s 0us/step

image_pathを独自の画像フォルダに置き換えることができます。 colabへのデータのアップロードについては、下の画像に赤い長方形で示されている左側のサイドバーにアップロードボタンがあります。 zipファイルをアップロードして解凍してみてください。ルートファイルのパスは現在のパスです。

ファイルをアップロードする

画像をクラウドにアップロードしたくない場合は、GitHubの ガイドに従ってローカルでライブラリを実行してみてください。

例を実行する

この例は、以下に示すように4行のコードで構成されており、各コードはプロセス全体の1つのステップを表しています。

手順1.デバイス上のMLアプリに固有の入力データを読み込みます。それをトレーニングデータとテストデータに分割します。

data = DataLoader.from_folder(image_path)
train_data, test_data = data.split(0.9)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

手順2.TensorFlowモデルをカスタマイズします。

model = image_classifier.create(train_data)
INFO:tensorflow:Retraining the models...
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2 (HubKer (None, 1280)              3413024   
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
103/103 [==============================] - 12s 31ms/step - loss: 0.8568 - accuracy: 0.7797
Epoch 2/5
103/103 [==============================] - 4s 35ms/step - loss: 0.6462 - accuracy: 0.9020
Epoch 3/5
103/103 [==============================] - 4s 35ms/step - loss: 0.6186 - accuracy: 0.9190
Epoch 4/5
103/103 [==============================] - 3s 32ms/step - loss: 0.5984 - accuracy: 0.9217
Epoch 5/5
103/103 [==============================] - 4s 36ms/step - loss: 0.5853 - accuracy: 0.9336

ステップ3.モデルを評価します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 2s 39ms/step - loss: 0.6424 - accuracy: 0.8828

ステップ4.TensorFlowLiteモデルにエクスポートします。

ここでは、モデル記述の標準を提供するメタデータを使用してTensorFlowLiteモデルをエクスポートします。ラベルファイルはメタデータに埋め込まれています。デフォルトのトレーニング後の量子化手法は、画像分類タスクの完全な整数量子化です。

アップロード部分と同じ左側のサイドバーからダウンロードして、自分で使用することができます。

model.export(export_dir='.')
INFO:tensorflow:Assets written to: /tmp/tmpagqtnzk2/assets
INFO:tensorflow:Assets written to: /tmp/tmpagqtnzk2/assets
WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmprk11gbxt/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmprk11gbxt/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite

これらの簡単な4つの手順の後、画像分類参照アプリなどのオンデバイスアプリケーションでTensorFlowLiteモデルファイルをさらに使用できます。

詳細なプロセス

現在、画像分類用の事前トレーニング済みモデルとして、EfficientNet-Lite *モデル、MobileNetV2、ResNet50などのいくつかのモデルをサポートしています。ただし、数行のコードでこのライブラリに新しい事前トレーニング済みモデルを追加することは非常に柔軟です。

以下では、このエンドツーエンドの例を段階的に説明して、詳細を示します。

ステップ1:デバイス上のMLアプリに固有の入力データを読み込む

花のデータセットには、5つのクラスに属する3670枚の画像が含まれています。データセットのアーカイブバージョンをダウンロードして、解凍します。

データセットのディレクトリ構造は次のとおりです。

flower_photos
|__ daisy
    |______ 100080576_f52e8ee070_n.jpg
    |______ 14167534527_781ceb1b7a_n.jpg
    |______ ...
|__ dandelion
    |______ 10043234166_e6dd915111_n.jpg
    |______ 1426682852_e62169221f_m.jpg
    |______ ...
|__ roses
    |______ 102501987_3cdb8e5394_n.jpg
    |______ 14982802401_a3dfb22afb.jpg
    |______ ...
|__ sunflowers
    |______ 12471791574_bb1be83df4.jpg
    |______ 15122112402_cafa41934f.jpg
    |______ ...
|__ tulips
    |______ 13976522214_ccec508fe7.jpg
    |______ 14487943607_651e8062a1_m.jpg
    |______ ...
image_path = tf.keras.utils.get_file(
      'flower_photos.tgz',
      'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
      extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

DataLoaderクラスを使用してデータをロードします。

from_folder()メソッドに関しては、フォルダーからデータをロードできます。同じクラスの画像データが同じサブディレクトリにあり、サブフォルダ名がクラス名であると想定しています。現在、JPEGエンコード画像とPNGエンコード画像がサポートされています。

data = DataLoader.from_folder(image_path)
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.
INFO:tensorflow:Load image with size: 3670, num_label: 5, labels: daisy, dandelion, roses, sunflowers, tulips.

トレーニングデータ(80%)、検証データ(10%、オプション)、およびテストデータ(10%)に分割します。

train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

25枚の画像の例をラベル付きで表示します。

plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
  plt.subplot(5,5,i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)
  plt.xlabel(data.index_to_label[label.numpy()])
plt.show()

png

ステップ2:TensorFlowモデルをカスタマイズする

ロードされたデータに基づいてカスタム画像分類モデルを作成します。デフォルトのモデルはEfficientNet-Lite0です。

model = image_classifier.create(train_data, validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_1 (HubK (None, 1280)              3413024   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
91/91 [==============================] - 6s 49ms/step - loss: 0.8756 - accuracy: 0.7689 - val_loss: 0.6990 - val_accuracy: 0.8551
Epoch 2/5
91/91 [==============================] - 4s 45ms/step - loss: 0.6632 - accuracy: 0.8856 - val_loss: 0.6706 - val_accuracy: 0.8835
Epoch 3/5
91/91 [==============================] - 4s 45ms/step - loss: 0.6257 - accuracy: 0.9090 - val_loss: 0.6627 - val_accuracy: 0.8920
Epoch 4/5
91/91 [==============================] - 4s 46ms/step - loss: 0.6064 - accuracy: 0.9214 - val_loss: 0.6562 - val_accuracy: 0.8920
Epoch 5/5
91/91 [==============================] - 4s 47ms/step - loss: 0.5894 - accuracy: 0.9303 - val_loss: 0.6510 - val_accuracy: 0.8892

詳細なモデル構造をご覧ください。

model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_1 (HubK (None, 1280)              3413024   
_________________________________________________________________
dropout_1 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________

ステップ3:カスタマイズされたモデルを評価する

モデルの結果を評価し、モデルの損失と精度を取得します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 24ms/step - loss: 0.5999 - accuracy: 0.9292

予測結果を100枚のテスト画像にプロットできました。赤い色の予測ラベルは間違った予測結果ですが、他のラベルは正しいです。

# A helper function that returns 'red'/'black' depending on if its two input
# parameter matches or not.
def get_label_color(val1, val2):
  if val1 == val2:
    return 'black'
  else:
    return 'red'

# Then plot 100 test images and their predicted labels.
# If a prediction result is different from the label provided label in "test"
# dataset, we will highlight it in red color.
plt.figure(figsize=(20, 20))
predicts = model.predict_top_k(test_data)
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(100)):
  ax = plt.subplot(10, 10, i+1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  plt.imshow(image.numpy(), cmap=plt.cm.gray)

  predict_label = predicts[i][0][0]
  color = get_label_color(predict_label,
                          test_data.index_to_label[label.numpy()])
  ax.xaxis.label.set_color(color)
  plt.xlabel('Predicted: %s' % predict_label)
plt.show()

png

精度がアプリの要件を満たしていない場合は、高度な使用法を参照して、より大きなモデルへの変更、再トレーニングパラメーターの調整などの代替案を検討できます。

ステップ4:TensorFlowLiteモデルにエクスポートする

トレーニング済みモデルをメタデータを含むTensorFlowLiteモデル形式に変換して、後でデバイス上のMLアプリケーションで使用できるようにします。ラベルファイルと語彙ファイルはメタデータに埋め込まれています。デフォルトのTFLiteファイル名はmodel.tfliteです。

多くのオンデバイスMLアプリケーションでは、モデルサイズが重要な要素です。したがって、モデルを小さくし、実行速度を上げる可能性があるように、クオンタイズを適用することをお勧めします。デフォルトのトレーニング後の量子化手法は、画像分類タスクの完全な整数量子化です。

model.export(export_dir='.')
INFO:tensorflow:Assets written to: /tmp/tmp2s_q22rj/assets
INFO:tensorflow:Assets written to: /tmp/tmp2s_q22rj/assets
WARNING:absl:For model inputs containing unsupported operations which cannot be quantized, the `inference_input_type` attribute will default to the original type.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmpo2h6s5wc/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmpo2h6s5wc/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite

TensorFlow Liteモデルをモバイルアプリに統合する方法の詳細については、サンプルアプリケーションと画像分類のガイドを参照してください。

このモデルは、使用してAndroidやiOSアプリに統合することができImageClassifierのAPITensorFlow Liteのタスク・ライブラリーを

許可されるエクスポート形式は、次の1つまたはリストのいずれかです。

デフォルトでは、メタデータを使用してTensorFlowLiteモデルをエクスポートするだけです。異なるファイルを選択的にエクスポートすることもできます。たとえば、次のようにラベルファイルのみをエクスポートします。

model.export(export_dir='.', export_format=ExportFormat.LABEL)
INFO:tensorflow:Saving labels in ./labels.txt
INFO:tensorflow:Saving labels in ./labels.txt

また、 evaluate_tfliteメソッドを使用してtfliteモデルを評価することもできます。

model.evaluate_tflite('model.tflite', test_data)
{'accuracy': 0.9318801089918256}

高度な使用法

create関数は、このライブラリの重要な部分です。チュートリアルと同様に、事前にトレーニングされたモデルを使用した転移学習を使用します。

create関数には、次の手順が含まれています。

  1. パラメータvalidation_ratioおよびtest_ratioに従って、データをトレーニング、検証、テストデータに分割します。デフォルトの値validation_ratiotest_ratioある0.10.1
  2. TensorFlowHubからベースモデルとして画像特徴ベクトルをダウンロードします。デフォルトの事前トレーニング済みモデルはEfficientNet-Lite0です。
  3. ヘッドレイヤーと事前トレーニング済みモデルの間にdropout_rateを持つドロップアウトレイヤーを持つ分類器ヘッドを追加します。デフォルトのdropout_rateは、 TensorFlowHubによるmake_image_classifier_libからのデフォルトのdropout_rate値です。
  4. 生の入力データを前処理します。現在、各画像ピクセルの値をモデル入力スケールに正規化し、モデル入力サイズにサイズ変更するなどの前処理ステップ。 EfficientNet-Lite0には、入力スケール[0, 1]と入力画像サイズ[224, 224, 3] 224、224、3 [224, 224, 3]ます。
  5. データを分類器モデルにフィードします。デフォルトでは、トレーニングエポック、バッチサイズ、学習率、勢いなどのトレーニングパラメータは、 TensorFlowHubによるmake_image_classifier_libからのデフォルト値です。分類器のヘッドのみがトレーニングされます。

このセクションでは、別の画像分類モデルへの切り替え、トレーニングハイパーパラメータの変更など、いくつかの高度なトピックについて説明します。

TensorFLowLiteモデルでトレーニング後の量子化をカスタマイズする

トレーニング後の量子化は、モデルの精度を少し低下させながら、CPUとハードウェアアクセラレータの推論速度を向上させながら、モデルのサイズと推論の待ち時間を短縮できる変換手法です。したがって、モデルを最適化するために広く使用されています。

モデルメーカーライブラリは、モデルをエクスポートするときに、デフォルトのトレーニング後の量子化手法を適用します。トレーニング後の量子化をカスタマイズする場合、Model Makerは、 QuantizationConfigを使用した複数のトレーニング後の量子化オプションもサポートしています。たとえば、float16量子化を取り上げましょう。まず、量子化構成を定義します。

config = QuantizationConfig.for_float16()

次に、そのような構成でTensorFlowLiteモデルをエクスポートします。

model.export(export_dir='.', tflite_filename='model_fp16.tflite', quantization_config=config)
INFO:tensorflow:Assets written to: /tmp/tmpalsh2f4f/assets
INFO:tensorflow:Assets written to: /tmp/tmpalsh2f4f/assets
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Label file is inside the TFLite model with metadata.
INFO:tensorflow:Saving labels in /tmp/tmpjwujjf3x/labels.txt
INFO:tensorflow:Saving labels in /tmp/tmpjwujjf3x/labels.txt
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model_fp16.tflite

Colabでは、上記のアップロード部分と同じように、左側のサイドバーからmodel_fp16.tfliteという名前のモデルをダウンロードできます。

モデルを変更する

このライブラリでサポートされているモデルに変更します。

このライブラリは、現在、EfficientNet-Liteモデル、MobileNetV2、ResNet50をサポートしています。 EfficientNet-Liteは、最先端の精度を実現し、エッジデバイスに適した画像分類モデルのファミリーです。デフォルトのモデルはEfficientNet-Lite0です。

createメソッドでパラメータmodel_specをMobileNetV2モデル仕様に設定するだけで、モデルをMobileNetV2に切り替えることができcreate

model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'), validation_data=validation_data)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_2 (HubK (None, 1280)              2257984   
_________________________________________________________________
dropout_2 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
None
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
91/91 [==============================] - 7s 50ms/step - loss: 0.9188 - accuracy: 0.7565 - val_loss: 0.7582 - val_accuracy: 0.8409
Epoch 2/5
91/91 [==============================] - 4s 45ms/step - loss: 0.6916 - accuracy: 0.8712 - val_loss: 0.7522 - val_accuracy: 0.8608
Epoch 3/5
91/91 [==============================] - 4s 48ms/step - loss: 0.6505 - accuracy: 0.9062 - val_loss: 0.7361 - val_accuracy: 0.8580
Epoch 4/5
91/91 [==============================] - 4s 46ms/step - loss: 0.6184 - accuracy: 0.9210 - val_loss: 0.7362 - val_accuracy: 0.8636
Epoch 5/5
91/91 [==============================] - 4s 47ms/step - loss: 0.6006 - accuracy: 0.9313 - val_loss: 0.7257 - val_accuracy: 0.8636

新しく再トレーニングされたMobileNetV2モデルを評価して、テストデータの精度と損失を確認します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 32ms/step - loss: 0.6459 - accuracy: 0.9019

TensorFlowハブのモデルに変更します

さらに、画像を入力し、TensorFlowHub形式で特徴ベクトルを出力する他の新しいモデルに切り替えることもできます。

例としてInceptionV3モデルとして、 image_classifier.ModelSpecのオブジェクトであり、InceptionV3モデルの仕様を含むinception_v3_specを定義できます。

モデル名name 、TensorFlowHubモデルuri URLを指定する必要があります。一方、 input_image_shapeのデフォルト値は[224, 224] input_image_shape [224, 224]です。 Inception V3モデルでは[299, 299]に変更する必要があります。

inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]

次に、 createメソッドでパラメーターmodel_specinception_v3_specに設定することにより、InceptionV3モデルを再トレーニングできます。

残りの手順はまったく同じで、最終的にカスタマイズされたInceptionV3 TensorFlowLiteモデルを取得できます。

独自のカスタムモデルを変更する

我々はTensorFlowハブにないカスタムモデルを使用したい場合は、我々は作成してエクスポートする必要がありModelSpecをTensorFlowハブに。

次に、上記のプロセスのようにModelSpecオブジェクトの定義を開始します。

トレーニングハイパーパラメータを変更する

モデルの精度に影響を与える可能性のあるepochsdropout_ratebatch_sizeなどのトレーニングハイパーパラメータを変更することもできます。調整できるモデルパラメータは次のとおりです。

  • epochs :エポックが多いほど、収束するまで精度が向上する可能性がありますが、エポックが多すぎると、過剰適合につながる可能性があります。
  • dropout_rate :ドロップアウトのレート。過剰適合を避けます。デフォルトではなし。
  • batch_size :1つのトレーニングステップで使用するサンプルの数。デフォルトではなし。
  • validation_data :検証データ。 Noneの場合、検証プロセスをスキップします。デフォルトではなし。
  • train_whole_model :trueの場合、ハブモジュールは最上位の分類レイヤーとともにトレーニングされます。それ以外の場合は、最上位の分類レイヤーのみをトレーニングします。デフォルトではなし。
  • learning_rate :基本学習率。デフォルトではなし。
  • momentum :オプティマイザーに転送されるPythonフロート。 use_hub_libraryがTrueの場合にのみ使用されます。デフォルトではなし。
  • shuffle :ブール値。データをシャッフルする必要があるかどうか。デフォルトではFalseです。
  • use_augmentation :ブール値。前処理にデータ拡張を使用します。デフォルトではFalseです。
  • use_hub_library :ブール値。テンソルフローハブのmake_image_classifier_libを使用して、モデルを再トレーニングします。このトレーニングパイプラインは、多くのカテゴリを持つ複雑なデータセットのパフォーマンスを向上させる可能性があります。デフォルトではTrue。
  • warmup_steps :学習率のウォームアップスケジュールのウォームアップステップ数。 Noneの場合、デフォルトのwarmup_stepsが使用されます。これは、2つのエポックでのトレーニングステップの合計です。 use_hub_libraryがFalseの場合にのみ使用されます。デフォルトではなし。
  • model_dir :オプションで、モデルチェックポイントファイルの場所。 use_hub_libraryがFalseの場合にのみ使用されます。デフォルトではなし。

以下のように、デフォルトではNoneですパラメータepochsでの具体的なデフォルトパラメータを取得しますmake_image_classifier_lib TensorFlowハブライブラリやからtrain_image_classifier_libを

たとえば、より多くのエポックでトレーニングすることができます。

model = image_classifier.create(train_data, validation_data=validation_data, epochs=10)
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hub_keras_layer_v1v2_3 (HubK (None, 1280)              3413024   
_________________________________________________________________
dropout_3 (Dropout)          (None, 1280)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 5)                 6405      
=================================================================
Total params: 3,419,429
Trainable params: 6,405
Non-trainable params: 3,413,024
_________________________________________________________________
None
Epoch 1/10
91/91 [==============================] - 6s 50ms/step - loss: 0.8710 - accuracy: 0.7682 - val_loss: 0.6951 - val_accuracy: 0.8722
Epoch 2/10
91/91 [==============================] - 4s 46ms/step - loss: 0.6585 - accuracy: 0.8925 - val_loss: 0.6643 - val_accuracy: 0.8807
Epoch 3/10
91/91 [==============================] - 4s 48ms/step - loss: 0.6215 - accuracy: 0.9131 - val_loss: 0.6533 - val_accuracy: 0.8835
Epoch 4/10
91/91 [==============================] - 4s 45ms/step - loss: 0.6021 - accuracy: 0.9186 - val_loss: 0.6467 - val_accuracy: 0.8864
Epoch 5/10
91/91 [==============================] - 4s 47ms/step - loss: 0.5890 - accuracy: 0.9320 - val_loss: 0.6461 - val_accuracy: 0.8835
Epoch 6/10
91/91 [==============================] - 4s 42ms/step - loss: 0.5758 - accuracy: 0.9399 - val_loss: 0.6411 - val_accuracy: 0.8864
Epoch 7/10
91/91 [==============================] - 4s 42ms/step - loss: 0.5676 - accuracy: 0.9475 - val_loss: 0.6367 - val_accuracy: 0.8949
Epoch 8/10
91/91 [==============================] - 4s 42ms/step - loss: 0.5614 - accuracy: 0.9485 - val_loss: 0.6369 - val_accuracy: 0.8920
Epoch 9/10
91/91 [==============================] - 4s 42ms/step - loss: 0.5536 - accuracy: 0.9536 - val_loss: 0.6397 - val_accuracy: 0.8920
Epoch 10/10
91/91 [==============================] - 4s 42ms/step - loss: 0.5489 - accuracy: 0.9550 - val_loss: 0.6385 - val_accuracy: 0.8977

10のトレーニングエポックで新しく再トレーニングされたモデルを評価します。

loss, accuracy = model.evaluate(test_data)
12/12 [==============================] - 1s 24ms/step - loss: 0.5883 - accuracy: 0.9428

続きを読む

画像分類の例を読んで、技術的な詳細を学ぶことができます。詳細については、以下を参照してください。