Keras Tuner 简介

使用集合让一切井井有条 根据您的偏好保存内容并对其进行分类。

在 TensorFlow.org 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码 下载笔记本

概述

Keras Tuner 是一个库,可帮助您为 TensorFlow 程序选择最佳的超参数集。为您的机器学习 (ML) 应用选择正确的超参数集,这一过程称为超参数调节超调

超参数是控制训练过程和 ML 模型拓扑的变量。这些变量在训练过程中保持不变,并会直接影响 ML 程序的性能。超参数有两种类型:

  1. 模型超参数:影响模型的选择,例如隐藏层的数量和宽度
  2. 算法超参数:影响学习算法的速度和质量,例如随机梯度下降 (SGD) 的学习率以及 k 近邻 (KNN) 分类器的近邻数

在本教程中,您将使用 Keras Tuner 对图像分类应用执行超调。

设置

import tensorflow as tf
from tensorflow import keras
2022-08-31 04:58:31.497507: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-31 04:58:32.213817: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-31 04:58:32.214070: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-31 04:58:32.214083: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

安装并导入 Keras Tuner。

pip install -q -U keras-tuner
import keras_tuner as kt

下载并准备数据集

在本教程中,您将使用 Keras Tuner 为某个对 Fashion MNIST 数据集内的服装图像进行分类的机器学习模型找到最佳超参数。

加载数据。

(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0

定义模型

构建用于超调的模型时,除了模型架构之外,还要定义超参数搜索空间。您为超调设置的模型称为超模型

您可以通过两种方式定义超模型:

  • 使用模型构建工具函数
  • 将 Keras Tuner API 的 HyperModel 类子类化

您还可以将两个预定义的 HyperModelHyperXceptionHyperResNet 用于计算机视觉应用。

在本教程中,您将使用模型构建工具函数来定义图像分类模型。模型构建工具函数将返回已编译的模型,并使用您以内嵌方式定义的超参数对模型进行超调。

def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model

实例化调节器并执行超调

实例化调节器以执行超调。Keras Tuner 提供了四种调节器:RandomSearchHyperbandBayesianOptimizationSklearn。在本教程中,您将使用 Hyperband 调节器。

要实例化 Hyperband 调节器,必须指定超模型、要优化的 objective 和要训练的最大周期数 (max_epochs)。

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')

Hyperband 调节算法使用自适应资源分配和早停法来快速收敛到高性能模型。该过程采用了体育竞技争冠模式的排除法。算法会将大量模型训练多个周期,并仅将性能最高的一半模型送入下一轮训练。Hyperband 通过计算 1 + logfactor(max_epochs) 并将其向上舍入到最接近的整数来确定要训练的模型的数量。

创建回调以在验证损失达到特定值后提前停止训练。

stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

运行超参数搜索。除了上面的回调外,搜索方法的参数也与 tf.keras.model.fit 所用参数相同。

tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])

# Get the optimal hyperparameters
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]

print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 36s]
val_accuracy: 0.8759999871253967

Best val_accuracy So Far: 0.887333333492279
Total elapsed time: 00h 07m 28s
INFO:tensorflow:Oracle triggered exit

The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 192 and the optimal learning rate for the optimizer
is 0.001.

训练模型

使用从搜索中获得的超参数找到训练模型的最佳周期数。

# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)

val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50
1500/1500 [==============================] - 4s 2ms/step - loss: 0.5081 - accuracy: 0.8205 - val_loss: 0.3976 - val_accuracy: 0.8615
Epoch 2/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3815 - accuracy: 0.8616 - val_loss: 0.3777 - val_accuracy: 0.8647
Epoch 3/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3387 - accuracy: 0.8771 - val_loss: 0.3343 - val_accuracy: 0.8793
Epoch 4/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3135 - accuracy: 0.8851 - val_loss: 0.3447 - val_accuracy: 0.8777
Epoch 5/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2946 - accuracy: 0.8910 - val_loss: 0.3187 - val_accuracy: 0.8843
Epoch 6/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2791 - accuracy: 0.8957 - val_loss: 0.3390 - val_accuracy: 0.8797
Epoch 7/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2665 - accuracy: 0.9010 - val_loss: 0.3288 - val_accuracy: 0.8785
Epoch 8/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2529 - accuracy: 0.9068 - val_loss: 0.3263 - val_accuracy: 0.8817
Epoch 9/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2421 - accuracy: 0.9096 - val_loss: 0.3267 - val_accuracy: 0.8882
Epoch 10/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2322 - accuracy: 0.9135 - val_loss: 0.3045 - val_accuracy: 0.8922
Epoch 11/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2252 - accuracy: 0.9160 - val_loss: 0.3102 - val_accuracy: 0.8940
Epoch 12/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2173 - accuracy: 0.9181 - val_loss: 0.3215 - val_accuracy: 0.8908
Epoch 13/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2087 - accuracy: 0.9217 - val_loss: 0.3235 - val_accuracy: 0.8910
Epoch 14/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2010 - accuracy: 0.9253 - val_loss: 0.3191 - val_accuracy: 0.8938
Epoch 15/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1939 - accuracy: 0.9285 - val_loss: 0.3409 - val_accuracy: 0.8881
Epoch 16/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1888 - accuracy: 0.9298 - val_loss: 0.3285 - val_accuracy: 0.8949
Epoch 17/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1843 - accuracy: 0.9301 - val_loss: 0.3573 - val_accuracy: 0.8884
Epoch 18/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1752 - accuracy: 0.9340 - val_loss: 0.3179 - val_accuracy: 0.8963
Epoch 19/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1716 - accuracy: 0.9362 - val_loss: 0.3325 - val_accuracy: 0.8928
Epoch 20/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1672 - accuracy: 0.9374 - val_loss: 0.3874 - val_accuracy: 0.8819
Epoch 21/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1619 - accuracy: 0.9392 - val_loss: 0.3625 - val_accuracy: 0.8905
Epoch 22/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1563 - accuracy: 0.9416 - val_loss: 0.3905 - val_accuracy: 0.8863
Epoch 23/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1534 - accuracy: 0.9426 - val_loss: 0.3585 - val_accuracy: 0.8957
Epoch 24/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1487 - accuracy: 0.9445 - val_loss: 0.3646 - val_accuracy: 0.8906
Epoch 25/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1448 - accuracy: 0.9464 - val_loss: 0.3592 - val_accuracy: 0.8942
Epoch 26/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1410 - accuracy: 0.9465 - val_loss: 0.3615 - val_accuracy: 0.8980
Epoch 27/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1382 - accuracy: 0.9482 - val_loss: 0.3929 - val_accuracy: 0.8906
Epoch 28/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1305 - accuracy: 0.9511 - val_loss: 0.3898 - val_accuracy: 0.8898
Epoch 29/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1285 - accuracy: 0.9516 - val_loss: 0.3907 - val_accuracy: 0.8867
Epoch 30/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1251 - accuracy: 0.9526 - val_loss: 0.4096 - val_accuracy: 0.8885
Epoch 31/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1230 - accuracy: 0.9533 - val_loss: 0.4078 - val_accuracy: 0.8908
Epoch 32/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1206 - accuracy: 0.9543 - val_loss: 0.3960 - val_accuracy: 0.8929
Epoch 33/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1180 - accuracy: 0.9552 - val_loss: 0.4026 - val_accuracy: 0.8933
Epoch 34/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1153 - accuracy: 0.9567 - val_loss: 0.4129 - val_accuracy: 0.8931
Epoch 35/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1118 - accuracy: 0.9579 - val_loss: 0.4496 - val_accuracy: 0.8833
Epoch 36/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1111 - accuracy: 0.9590 - val_loss: 0.4864 - val_accuracy: 0.8848
Epoch 37/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1049 - accuracy: 0.9608 - val_loss: 0.4736 - val_accuracy: 0.8821
Epoch 38/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1079 - accuracy: 0.9593 - val_loss: 0.4342 - val_accuracy: 0.8955
Epoch 39/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1020 - accuracy: 0.9618 - val_loss: 0.4411 - val_accuracy: 0.8942
Epoch 40/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0997 - accuracy: 0.9624 - val_loss: 0.4438 - val_accuracy: 0.8894
Epoch 41/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0979 - accuracy: 0.9641 - val_loss: 0.4515 - val_accuracy: 0.8945
Epoch 42/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0997 - accuracy: 0.9630 - val_loss: 0.4590 - val_accuracy: 0.8846
Epoch 43/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0922 - accuracy: 0.9658 - val_loss: 0.4472 - val_accuracy: 0.8958
Epoch 44/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0928 - accuracy: 0.9654 - val_loss: 0.4950 - val_accuracy: 0.8882
Epoch 45/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0919 - accuracy: 0.9658 - val_loss: 0.4881 - val_accuracy: 0.8903
Epoch 46/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0885 - accuracy: 0.9665 - val_loss: 0.4729 - val_accuracy: 0.8913
Epoch 47/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0857 - accuracy: 0.9675 - val_loss: 0.5060 - val_accuracy: 0.8899
Epoch 48/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0874 - accuracy: 0.9675 - val_loss: 0.5096 - val_accuracy: 0.8907
Epoch 49/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0829 - accuracy: 0.9694 - val_loss: 0.5013 - val_accuracy: 0.8891
Epoch 50/50
1500/1500 [==============================] - 3s 2ms/step - loss: 0.0829 - accuracy: 0.9682 - val_loss: 0.5330 - val_accuracy: 0.8888
Best epoch: 26

重新实例化超模型并使用上面的最佳周期数对其进行训练。

hypermodel = tuner.hypermodel.build(best_hps)

# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.5091 - accuracy: 0.8206 - val_loss: 0.4011 - val_accuracy: 0.8560
Epoch 2/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3795 - accuracy: 0.8622 - val_loss: 0.3718 - val_accuracy: 0.8668
Epoch 3/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3415 - accuracy: 0.8752 - val_loss: 0.3526 - val_accuracy: 0.8750
Epoch 4/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.3150 - accuracy: 0.8837 - val_loss: 0.3425 - val_accuracy: 0.8751
Epoch 5/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2972 - accuracy: 0.8916 - val_loss: 0.3255 - val_accuracy: 0.8837
Epoch 6/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2799 - accuracy: 0.8971 - val_loss: 0.3634 - val_accuracy: 0.8671
Epoch 7/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2661 - accuracy: 0.9003 - val_loss: 0.3273 - val_accuracy: 0.8847
Epoch 8/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2543 - accuracy: 0.9055 - val_loss: 0.3319 - val_accuracy: 0.8790
Epoch 9/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2438 - accuracy: 0.9082 - val_loss: 0.3318 - val_accuracy: 0.8827
Epoch 10/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2336 - accuracy: 0.9120 - val_loss: 0.3279 - val_accuracy: 0.8827
Epoch 11/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2258 - accuracy: 0.9152 - val_loss: 0.3272 - val_accuracy: 0.8900
Epoch 12/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2197 - accuracy: 0.9180 - val_loss: 0.3087 - val_accuracy: 0.8938
Epoch 13/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2094 - accuracy: 0.9215 - val_loss: 0.3391 - val_accuracy: 0.8814
Epoch 14/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.2050 - accuracy: 0.9217 - val_loss: 0.3199 - val_accuracy: 0.8900
Epoch 15/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1971 - accuracy: 0.9268 - val_loss: 0.3374 - val_accuracy: 0.8919
Epoch 16/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1912 - accuracy: 0.9280 - val_loss: 0.3417 - val_accuracy: 0.8886
Epoch 17/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1837 - accuracy: 0.9305 - val_loss: 0.3579 - val_accuracy: 0.8838
Epoch 18/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1780 - accuracy: 0.9332 - val_loss: 0.3185 - val_accuracy: 0.8971
Epoch 19/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1725 - accuracy: 0.9354 - val_loss: 0.3566 - val_accuracy: 0.8856
Epoch 20/26
1500/1500 [==============================] - 4s 2ms/step - loss: 0.1699 - accuracy: 0.9368 - val_loss: 0.3564 - val_accuracy: 0.8895
Epoch 21/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1632 - accuracy: 0.9399 - val_loss: 0.3474 - val_accuracy: 0.8930
Epoch 22/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1599 - accuracy: 0.9400 - val_loss: 0.3639 - val_accuracy: 0.8882
Epoch 23/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1539 - accuracy: 0.9425 - val_loss: 0.3570 - val_accuracy: 0.8932
Epoch 24/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1507 - accuracy: 0.9433 - val_loss: 0.3596 - val_accuracy: 0.8925
Epoch 25/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1475 - accuracy: 0.9458 - val_loss: 0.3587 - val_accuracy: 0.8908
Epoch 26/26
1500/1500 [==============================] - 3s 2ms/step - loss: 0.1426 - accuracy: 0.9470 - val_loss: 0.3748 - val_accuracy: 0.8864
<keras.callbacks.History at 0x7f5e2eb57c40>

要完成本教程,请在测试数据上评估超模型。

eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.4196 - accuracy: 0.8807
[test loss, test accuracy]: [0.41957515478134155, 0.8806999921798706]

my_dir/intro_to_kt 目录中包含了在超参数搜索期间每次试验(模型配置)运行的详细日志和检查点。如果重新运行超参数搜索,Keras Tuner 将使用这些日志中记录的现有状态来继续搜索。要停用此行为,请在实例化调节器时传递一个附加的 overwrite = True 参数。

总结

在本教程中,您学习了如何使用 Keras Tuner 调节模型的超参数。要详细了解 Keras Tuner,请查看以下其他资源:

另请查看 TensorBoard 中的 HParams Dashboard,以交互方式调节模型超参数。