Keras Tuner の基礎

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

Keras Tuner は、TensorFlow プログラム向けに最適なハイパーパラメータを選択するためのライブラリです。ユーザーの機械学習(ML)アプリケーションに適切なハイパーパラメータを選択するためのプロセスは、ハイパーパラメータチューニングまたはハイパーチューニングと呼ばれます。

ハイパーパラメータは、ML のトレーニングプロセスとトポロジーを管理する変数です。これらの変数はトレーニングプロセス中、一貫して定数を維持し、ML プログラムのパフォーマンスに直接影響を与えます。ハイパーパラメータには、以下の 2 種類があります。

  1. モデルハイパーパラメータ: 非表示レイヤーの数と幅などのモデルの選択に影響します。
  2. アルゴリズムハイパーパラメータ: 確率的勾配降下法 (SGD) の学習率や k 最近傍 (KNN) 分類器の最近傍の数など、学習アルゴリズムの速度と質に影響します。

このチュートリアルでは、Keras Tuner を使用して、画像分類アプリケーションのハイパーチューニングを実施します。

セットアップ

import tensorflow as tf
from tensorflow import keras
2022-12-15 00:15:50.634115: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-15 00:15:50.634212: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-15 00:15:50.634221: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Keras Tuner をインストールしてインポートします。

pip install -q -U keras-tuner
import keras_tuner as kt

データセットをダウンロードして準備する

このチュートリアルでは、Keras Tuner を使用して、Fashion MNIST データセットの服飾の画像を分類する学習モデル向けに最適なハイパーパラメータを見つけます。

データを読み込みます。

(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0

モデルを定義する

ハイパーチューニングを行うモデルを構築する際、モデルアーキテクチャのほかにハイパーパラメータ検索空間も定義します。ハイパーチューニング用にセットアップするモデルをハイパーモデルと呼びます。

ハイパーモデルの定義は、以下の 2 つの方法で行います。

  • モデルビルダー関数を使用する
  • Keras Tuner API の HyperModel クラスをサブクラス化する

また、コンピュータビジョンアプリケーション用の HyperXceptionHyperResNet という 2 つの事前定義済みの HyperModel クラスも使用します。

このチュートリアルでは、モデルビルダー関数を使用して、画像分類モデルを定義します。モデルビルダー関数は、コンパイル済みのモデルを返し、インラインで定義するハイパーパラメータを使用してモデルをハイパーチューニングします。

def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

チューナーをインスタンス化してハイパーチューニングを実行する

チューナーをインスタンス化して、ハイパーチューニングを実行します。Keras Tuner には、RandomSearchHyperbandBayesianOptimization、および Sklearn チューナーがあります。このチュートリアルでは、Hyperband チューナーを使用します。

Hyperband チューナーをインスタンス化するには、ハイパーモデル、最適化する objective、およびトレーニングするエポックの最大数 (max_epochs) を指定する必要があります。

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')

Hyperband チューニングアルゴリズムは、適応型リソース割り当てと早期停止を使用して、高パフォーマンスモデルに素早く収束させます。これは、トーナメント式のツリーを使用して行われます。アルゴリズムは、数回のエポックで大量のモデルをトレーニングし、性能の高い上位半数のモデル次のラウンドに持ち越します。Hyperband は、1 + logfactor(max_epochs) を計算し、直近の整数に繰り上げて、トーナメントでトレーニングするモデル数を決定します。

検証損失の特定の値に達した後、トレーニングを早期に停止するためのコールバックを作成します。

stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

ハイパーパラメータ検索を実行します。検索メソッドの引数は、上記のコールバックのほか、tf.keras.model.fit に使用される引数と同じです。

tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])

# Get the optimal hyperparameters
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]

print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 41s]
val_accuracy: 0.890999972820282

Best val_accuracy So Far: 0.890999972820282
Total elapsed time: 00h 08m 39s
INFO:tensorflow:Oracle triggered exit

The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 352 and the optimal learning rate for the optimizer
is 0.001.

モデルをトレーニングする

検索から取得したハイパーパラメータを使用してモデルをトレーニングするための最適なエポック数を見つけます。

# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)

val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50
1500/1500 [==============================] - 5s 3ms/step - loss: 0.5018 - accuracy: 0.8220 - val_loss: 0.4314 - val_accuracy: 0.8442
Epoch 2/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.3681 - accuracy: 0.8672 - val_loss: 0.3792 - val_accuracy: 0.8604
Epoch 3/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3345 - accuracy: 0.8762 - val_loss: 0.3690 - val_accuracy: 0.8682
Epoch 4/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3045 - accuracy: 0.8886 - val_loss: 0.3394 - val_accuracy: 0.8755
Epoch 5/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.2870 - accuracy: 0.8938 - val_loss: 0.3368 - val_accuracy: 0.8751
Epoch 6/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2708 - accuracy: 0.8995 - val_loss: 0.3317 - val_accuracy: 0.8802
Epoch 7/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2595 - accuracy: 0.9030 - val_loss: 0.3100 - val_accuracy: 0.8894
Epoch 8/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2473 - accuracy: 0.9072 - val_loss: 0.3224 - val_accuracy: 0.8866
Epoch 9/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2368 - accuracy: 0.9112 - val_loss: 0.3311 - val_accuracy: 0.8837
Epoch 10/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2274 - accuracy: 0.9151 - val_loss: 0.3317 - val_accuracy: 0.8868
Epoch 11/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2152 - accuracy: 0.9193 - val_loss: 0.3350 - val_accuracy: 0.8848
Epoch 12/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2102 - accuracy: 0.9216 - val_loss: 0.3139 - val_accuracy: 0.8938
Epoch 13/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2020 - accuracy: 0.9237 - val_loss: 0.3175 - val_accuracy: 0.8906
Epoch 14/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1934 - accuracy: 0.9282 - val_loss: 0.3257 - val_accuracy: 0.8903
Epoch 15/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1865 - accuracy: 0.9296 - val_loss: 0.3399 - val_accuracy: 0.8907
Epoch 16/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1822 - accuracy: 0.9313 - val_loss: 0.3378 - val_accuracy: 0.8916
Epoch 17/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1762 - accuracy: 0.9336 - val_loss: 0.3275 - val_accuracy: 0.8923
Epoch 18/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1699 - accuracy: 0.9355 - val_loss: 0.3450 - val_accuracy: 0.8942
Epoch 19/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1622 - accuracy: 0.9396 - val_loss: 0.3701 - val_accuracy: 0.8880
Epoch 20/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1596 - accuracy: 0.9393 - val_loss: 0.3491 - val_accuracy: 0.8923
Epoch 21/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1547 - accuracy: 0.9415 - val_loss: 0.3982 - val_accuracy: 0.8850
Epoch 22/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1509 - accuracy: 0.9432 - val_loss: 0.3607 - val_accuracy: 0.8922
Epoch 23/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1436 - accuracy: 0.9462 - val_loss: 0.3635 - val_accuracy: 0.8947
Epoch 24/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1412 - accuracy: 0.9473 - val_loss: 0.3740 - val_accuracy: 0.8951
Epoch 25/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1359 - accuracy: 0.9488 - val_loss: 0.3809 - val_accuracy: 0.8940
Epoch 26/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1334 - accuracy: 0.9483 - val_loss: 0.3876 - val_accuracy: 0.8988
Epoch 27/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1297 - accuracy: 0.9513 - val_loss: 0.3957 - val_accuracy: 0.8943
Epoch 28/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1254 - accuracy: 0.9528 - val_loss: 0.3954 - val_accuracy: 0.8913
Epoch 29/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1205 - accuracy: 0.9541 - val_loss: 0.4045 - val_accuracy: 0.8942
Epoch 30/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1184 - accuracy: 0.9563 - val_loss: 0.3903 - val_accuracy: 0.8959
Epoch 31/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1167 - accuracy: 0.9560 - val_loss: 0.4016 - val_accuracy: 0.8961
Epoch 32/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1125 - accuracy: 0.9576 - val_loss: 0.4034 - val_accuracy: 0.8955
Epoch 33/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1117 - accuracy: 0.9581 - val_loss: 0.4231 - val_accuracy: 0.8965
Epoch 34/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1069 - accuracy: 0.9593 - val_loss: 0.4221 - val_accuracy: 0.8924
Epoch 35/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1023 - accuracy: 0.9623 - val_loss: 0.4215 - val_accuracy: 0.8954
Epoch 36/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1009 - accuracy: 0.9631 - val_loss: 0.4290 - val_accuracy: 0.8962
Epoch 37/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0993 - accuracy: 0.9622 - val_loss: 0.4494 - val_accuracy: 0.8949
Epoch 38/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0974 - accuracy: 0.9630 - val_loss: 0.4511 - val_accuracy: 0.8953
Epoch 39/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0939 - accuracy: 0.9634 - val_loss: 0.4881 - val_accuracy: 0.8907
Epoch 40/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0928 - accuracy: 0.9651 - val_loss: 0.4662 - val_accuracy: 0.8952
Epoch 41/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0926 - accuracy: 0.9651 - val_loss: 0.4695 - val_accuracy: 0.8967
Epoch 42/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0911 - accuracy: 0.9666 - val_loss: 0.4913 - val_accuracy: 0.8891
Epoch 43/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0849 - accuracy: 0.9684 - val_loss: 0.5057 - val_accuracy: 0.8872
Epoch 44/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0839 - accuracy: 0.9688 - val_loss: 0.4992 - val_accuracy: 0.8902
Epoch 45/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0849 - accuracy: 0.9676 - val_loss: 0.4998 - val_accuracy: 0.8946
Epoch 46/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0808 - accuracy: 0.9690 - val_loss: 0.5265 - val_accuracy: 0.8912
Epoch 47/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0825 - accuracy: 0.9688 - val_loss: 0.4989 - val_accuracy: 0.8940
Epoch 48/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0803 - accuracy: 0.9700 - val_loss: 0.4896 - val_accuracy: 0.8959
Epoch 49/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.0757 - accuracy: 0.9716 - val_loss: 0.5493 - val_accuracy: 0.8899
Epoch 50/50
1500/1500 [==============================] - 4s 3ms/step - loss: 0.0735 - accuracy: 0.9723 - val_loss: 0.5423 - val_accuracy: 0.8962
Best epoch: 26

ハイパーモデルを再インスタンス化し、前述の最適なエポック数でトレーニングします。

hypermodel = tuner.hypermodel.build(best_hps)

# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/26
1500/1500 [==============================] - 5s 3ms/step - loss: 0.5010 - accuracy: 0.8224 - val_loss: 0.3890 - val_accuracy: 0.8578
Epoch 2/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3708 - accuracy: 0.8662 - val_loss: 0.3712 - val_accuracy: 0.8637
Epoch 3/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3325 - accuracy: 0.8777 - val_loss: 0.3565 - val_accuracy: 0.8731
Epoch 4/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.3035 - accuracy: 0.8870 - val_loss: 0.3420 - val_accuracy: 0.8759
Epoch 5/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2905 - accuracy: 0.8917 - val_loss: 0.3238 - val_accuracy: 0.8834
Epoch 6/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2715 - accuracy: 0.8988 - val_loss: 0.3311 - val_accuracy: 0.8796
Epoch 7/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2583 - accuracy: 0.9033 - val_loss: 0.3331 - val_accuracy: 0.8832
Epoch 8/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2472 - accuracy: 0.9072 - val_loss: 0.3168 - val_accuracy: 0.8866
Epoch 9/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2366 - accuracy: 0.9129 - val_loss: 0.3397 - val_accuracy: 0.8769
Epoch 10/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2283 - accuracy: 0.9151 - val_loss: 0.3723 - val_accuracy: 0.8683
Epoch 11/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2183 - accuracy: 0.9185 - val_loss: 0.3304 - val_accuracy: 0.8872
Epoch 12/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2125 - accuracy: 0.9201 - val_loss: 0.3233 - val_accuracy: 0.8887
Epoch 13/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.2039 - accuracy: 0.9230 - val_loss: 0.3346 - val_accuracy: 0.8832
Epoch 14/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1944 - accuracy: 0.9272 - val_loss: 0.3237 - val_accuracy: 0.8935
Epoch 15/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1888 - accuracy: 0.9284 - val_loss: 0.3399 - val_accuracy: 0.8865
Epoch 16/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1842 - accuracy: 0.9302 - val_loss: 0.3117 - val_accuracy: 0.8956
Epoch 17/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1757 - accuracy: 0.9333 - val_loss: 0.3336 - val_accuracy: 0.8899
Epoch 18/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1706 - accuracy: 0.9360 - val_loss: 0.3282 - val_accuracy: 0.8925
Epoch 19/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1642 - accuracy: 0.9383 - val_loss: 0.3294 - val_accuracy: 0.8893
Epoch 20/26
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1602 - accuracy: 0.9397 - val_loss: 0.3415 - val_accuracy: 0.8928
Epoch 21/26
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1532 - accuracy: 0.9416 - val_loss: 0.3773 - val_accuracy: 0.8888
Epoch 22/26
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1506 - accuracy: 0.9433 - val_loss: 0.3477 - val_accuracy: 0.8957
Epoch 23/26
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1443 - accuracy: 0.9461 - val_loss: 0.3533 - val_accuracy: 0.8939
Epoch 24/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1421 - accuracy: 0.9463 - val_loss: 0.3553 - val_accuracy: 0.8953
Epoch 25/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1376 - accuracy: 0.9491 - val_loss: 0.3485 - val_accuracy: 0.8948
Epoch 26/26
1500/1500 [==============================] - 4s 3ms/step - loss: 0.1328 - accuracy: 0.9506 - val_loss: 0.3717 - val_accuracy: 0.8923
<keras.callbacks.History at 0x7ff8dfe57580>

このチュートリアルを終了するには、テストデータでハイパーモデルを評価します。

eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.4146 - accuracy: 0.8872
[test loss, test accuracy]: [0.4146318733692169, 0.8871999979019165]

my_dir/intro_to_kt ディレクトリには、ハイパーパラメータ検索中に実行された各トライアル(モデル構成)の詳細なログとチェックポイントが含まれます。ハイパーパラメータ検索を再実行する場合、Keras Tuner は、これらのログの既存の状態を使用して、検索を再開します。この動作を無効にするには、チューナーをインスタンス化する際に、overwrite = True 引数を追加で渡してください。

まとめ

このチュートリアルでは、Keras Tuner の使用して、モデルのハイパーパラメータを調整する方法を学習しました。Keras Tuner の調査委については、以下のその他のリソースをご覧ください。

また、モデルのハイパーパラメータを対話式で調整できる、TensorBoard の HParams Dashboard もご覧ください。