Overview
The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.
Hyperparameters are the variables that govern the training process and the topology of an ML model. These variables remain constant over the training process and directly impact the performance of your ML program. Hyperparameters are of two types:
- Model hyperparameters, which influence model selection, such as the number and width of hidden layers
- Algorithm hyperparameters, which influence the speed and quality of the learning algorithm, such as the learning rate for Stochastic Gradient Descent (SGD) and the number of nearest neighbors for a k-Nearest Neighbors (KNN) classifier
In this tutorial, you will use the Keras Tuner to perform hypertuning for an image classification application.
Setup
import tensorflow as tf
from tensorflow import keras
Install and import the Keras Tuner.
pip install -q -U keras-tuner
import keras_tuner as kt
Download and prepare the dataset
In this tutorial, you will use the Keras Tuner to find the best hyperparameters for a machine learning model that classifies images of clothing from the Fashion MNIST dataset.
Load the data.
(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
Define the model
When you build a model for hypertuning, you also define the hyperparameter search space in addition to the model architecture. The model you set up for hypertuning is called a hypermodel.
You can define a hypermodel through two approaches:
- By using a model builder function
- By subclassing the HyperModel class of the Keras Tuner API (see the sketch after this list)
You can also use two pre-defined HyperModel classes, HyperXception and HyperResNet, for computer vision applications.
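For reference, the subclassing approach looks roughly like this (a minimal sketch of an equivalent search space; MyHyperModel is a hypothetical name, not part of this tutorial):

class MyHyperModel(kt.HyperModel):
  def build(self, hp):
    # Define the search space and return a compiled model,
    # just as a model builder function would.
    model = keras.Sequential()
    model.add(keras.layers.Flatten(input_shape=(28, 28)))
    model.add(keras.layers.Dense(hp.Int('units', min_value=32, max_value=512, step=32),
                                 activation='relu'))
    model.add(keras.layers.Dense(10))
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])
    return model

An instance of such a subclass can then be passed to a tuner in place of a model builder function.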
In this tutorial, you use a model builder function to define the image classification model. The model builder function returns a compiled model and uses hyperparameters you define inline to hypertune the model.
def model_builder(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))

  # Tune the number of units in the first Dense layer
  # Choose an optimal value between 32-512
  hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
  model.add(keras.layers.Dense(units=hp_units, activation='relu'))
  model.add(keras.layers.Dense(10))

  # Tune the learning rate for the optimizer
  # Choose an optimal value from 0.01, 0.001, or 0.0001
  hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model
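The HyperParameters object passed to the builder supports other search-space methods besides hp.Int and hp.Choice. The following variant (a sketch for illustration only; model_builder_v2, the use_dropout flag, and the value ranges are assumptions, not used in the rest of this tutorial) samples the learning rate on a log scale with hp.Float and toggles a dropout layer with hp.Boolean:

def model_builder_v2(hp):
  model = keras.Sequential()
  model.add(keras.layers.Flatten(input_shape=(28, 28)))
  model.add(keras.layers.Dense(hp.Int('units', min_value=32, max_value=512, step=32),
                               activation='relu'))

  # Optionally insert a dropout layer, controlled by a boolean hyperparameter
  if hp.Boolean('use_dropout'):
    model.add(keras.layers.Dropout(0.25))

  model.add(keras.layers.Dense(10))

  # Sample the learning rate on a log scale instead of from a fixed list
  hp_learning_rate = hp.Float('learning_rate', min_value=1e-4, max_value=1e-2, sampling='log')

  model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

  return model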
Instantiate the tuner and perform hypertuning
Instantiate the tuner to perform the hypertuning. The Keras Tuner has four tuners available: RandomSearch, Hyperband, BayesianOptimization, and Sklearn. In this tutorial, you use the Hyperband tuner.
To instantiate the Hyperband tuner, you must specify the hypermodel, the objective to optimize, and the maximum number of epochs to train (max_epochs).
tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     directory='my_dir',
                     project_name='intro_to_kt')
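If you want to try one of the other tuners, the instantiation is similar. For example, a RandomSearch tuner could be set up as follows (a sketch; max_trials=10 and the project name are arbitrary choices, not part of this tutorial):

tuner_rs = kt.RandomSearch(model_builder,
                           objective='val_accuracy',
                           max_trials=10,  # total number of hyperparameter combinations to try
                           directory='my_dir',
                           project_name='intro_to_kt_rs')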
The Hyperband tuning algorithm uses adaptive resource allocation and early stopping to quickly converge on a high-performing model. This is done using a sports-championship-style bracket: the algorithm trains a large number of models for a few epochs and carries forward only the top-performing half of models to the next round. Hyperband determines the number of models to train in a bracket by computing 1 + log_factor(max_epochs) and rounding it up to the nearest integer.
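With the values used above (max_epochs=10 and factor=3), that works out as follows:

import math

# 1 + log_3(10) = 1 + 2.096... ≈ 3.096, which rounds up to 4
models_per_bracket = math.ceil(1 + math.log(10) / math.log(3))
print(models_per_bracket)  # 4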
Create a callback to stop training early once the validation loss has stopped improving for a set number of epochs (the patience).
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
Run the hyperparameter search. The arguments for the search method are the same as those used for tf.keras.Model.fit, in addition to the callback above.
tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])
# Get the optimal hyperparameters
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
Trial 30 Complete [00h 00m 41s]
val_accuracy: 0.8550833463668823
Best val_accuracy So Far: 0.8900833129882812
Total elapsed time: 00h 08m 43s

The hyperparameter search is complete. The optimal number of units in the first densely-connected layer is 224 and the optimal learning rate for the optimizer is 0.001.
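If you want more detail than the single best trial, you can also print a per-trial summary of the search; the output (omitted here) lists each trial's hyperparameter values and score:

# Print a summary of the results of every trial
tuner.results_summary()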
Train the model
Find the optimal number of epochs to train the model with the hyperparameters obtained from the search.
# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)
val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
Epoch 1/50 1500/1500 [==============================] - 5s 3ms/step - loss: 0.5055 - accuracy: 0.8205 - val_loss: 0.4009 - val_accuracy: 0.8582
Epoch 2/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.3772 - accuracy: 0.8628 - val_loss: 0.3637 - val_accuracy: 0.8685
Epoch 3/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.3366 - accuracy: 0.8766 - val_loss: 0.3698 - val_accuracy: 0.8662
Epoch 4/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.3110 - accuracy: 0.8858 - val_loss: 0.3599 - val_accuracy: 0.8703
Epoch 5/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2924 - accuracy: 0.8906 - val_loss: 0.3289 - val_accuracy: 0.8818
Epoch 6/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2768 - accuracy: 0.8958 - val_loss: 0.3491 - val_accuracy: 0.8743
Epoch 7/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2622 - accuracy: 0.9022 - val_loss: 0.3127 - val_accuracy: 0.8866
Epoch 8/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2512 - accuracy: 0.9067 - val_loss: 0.3378 - val_accuracy: 0.8822
Epoch 9/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2412 - accuracy: 0.9104 - val_loss: 0.3282 - val_accuracy: 0.8848
Epoch 10/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2294 - accuracy: 0.9143 - val_loss: 0.3398 - val_accuracy: 0.8838
Epoch 11/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2217 - accuracy: 0.9166 - val_loss: 0.3158 - val_accuracy: 0.8897
Epoch 12/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2124 - accuracy: 0.9197 - val_loss: 0.3443 - val_accuracy: 0.8858
Epoch 13/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2051 - accuracy: 0.9226 - val_loss: 0.3649 - val_accuracy: 0.8854
Epoch 14/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1986 - accuracy: 0.9254 - val_loss: 0.3195 - val_accuracy: 0.8901
Epoch 15/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1908 - accuracy: 0.9287 - val_loss: 0.3173 - val_accuracy: 0.8971
Epoch 16/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1823 - accuracy: 0.9306 - val_loss: 0.3480 - val_accuracy: 0.8911
Epoch 17/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1803 - accuracy: 0.9314 - val_loss: 0.3258 - val_accuracy: 0.8929
Epoch 18/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1721 - accuracy: 0.9370 - val_loss: 0.3331 - val_accuracy: 0.8950
Epoch 19/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1676 - accuracy: 0.9383 - val_loss: 0.3331 - val_accuracy: 0.8962
Epoch 20/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1634 - accuracy: 0.9382 - val_loss: 0.3432 - val_accuracy: 0.8932
Epoch 21/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1566 - accuracy: 0.9405 - val_loss: 0.3597 - val_accuracy: 0.8873
Epoch 22/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1538 - accuracy: 0.9412 - val_loss: 0.3446 - val_accuracy: 0.8933
Epoch 23/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1493 - accuracy: 0.9435 - val_loss: 0.3677 - val_accuracy: 0.8888
Epoch 24/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1459 - accuracy: 0.9454 - val_loss: 0.3472 - val_accuracy: 0.8961
Epoch 25/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1400 - accuracy: 0.9469 - val_loss: 0.3984 - val_accuracy: 0.8827
Epoch 26/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1374 - accuracy: 0.9484 - val_loss: 0.3767 - val_accuracy: 0.8931
Epoch 27/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1323 - accuracy: 0.9491 - val_loss: 0.3849 - val_accuracy: 0.8909
Epoch 28/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1312 - accuracy: 0.9511 - val_loss: 0.3897 - val_accuracy: 0.8903
Epoch 29/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1242 - accuracy: 0.9533 - val_loss: 0.4042 - val_accuracy: 0.8907
Epoch 30/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1238 - accuracy: 0.9533 - val_loss: 0.3784 - val_accuracy: 0.8934
Epoch 31/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1176 - accuracy: 0.9554 - val_loss: 0.4152 - val_accuracy: 0.8940
Epoch 32/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1152 - accuracy: 0.9570 - val_loss: 0.4081 - val_accuracy: 0.8886
Epoch 33/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1123 - accuracy: 0.9578 - val_loss: 0.4372 - val_accuracy: 0.8856
Epoch 34/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1120 - accuracy: 0.9582 - val_loss: 0.4068 - val_accuracy: 0.8937
Epoch 35/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1073 - accuracy: 0.9607 - val_loss: 0.4246 - val_accuracy: 0.8943
Epoch 36/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1040 - accuracy: 0.9606 - val_loss: 0.4211 - val_accuracy: 0.8934
Epoch 37/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1034 - accuracy: 0.9613 - val_loss: 0.4291 - val_accuracy: 0.8933
Epoch 38/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0991 - accuracy: 0.9627 - val_loss: 0.4504 - val_accuracy: 0.8942
Epoch 39/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0977 - accuracy: 0.9635 - val_loss: 0.4331 - val_accuracy: 0.8950
Epoch 40/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0948 - accuracy: 0.9653 - val_loss: 0.4429 - val_accuracy: 0.8944
Epoch 41/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0939 - accuracy: 0.9643 - val_loss: 0.4727 - val_accuracy: 0.8888
Epoch 42/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0937 - accuracy: 0.9650 - val_loss: 0.4521 - val_accuracy: 0.8969
Epoch 43/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0888 - accuracy: 0.9673 - val_loss: 0.4801 - val_accuracy: 0.8908
Epoch 44/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0880 - accuracy: 0.9678 - val_loss: 0.4582 - val_accuracy: 0.8973
Epoch 45/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0878 - accuracy: 0.9668 - val_loss: 0.5006 - val_accuracy: 0.8920
Epoch 46/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0862 - accuracy: 0.9678 - val_loss: 0.4547 - val_accuracy: 0.8942
Epoch 47/50 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0836 - accuracy: 0.9680 - val_loss: 0.5050 - val_accuracy: 0.8908
Epoch 48/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0808 - accuracy: 0.9692 - val_loss: 0.4956 - val_accuracy: 0.8954
Epoch 49/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0803 - accuracy: 0.9696 - val_loss: 0.5260 - val_accuracy: 0.8928
Epoch 50/50 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0761 - accuracy: 0.9716 - val_loss: 0.5449 - val_accuracy: 0.8914
Best epoch: 44
Re-instantiate the hypermodel and train it with the optimal number of epochs from above.
hypermodel = tuner.hypermodel.build(best_hps)
# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
Epoch 1/44 1500/1500 [==============================] - 5s 3ms/step - loss: 0.5087 - accuracy: 0.8195 - val_loss: 0.4183 - val_accuracy: 0.8519
Epoch 2/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.3767 - accuracy: 0.8639 - val_loss: 0.3740 - val_accuracy: 0.8653
Epoch 3/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.3355 - accuracy: 0.8771 - val_loss: 0.3642 - val_accuracy: 0.8691
Epoch 4/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.3109 - accuracy: 0.8860 - val_loss: 0.3444 - val_accuracy: 0.8782
Epoch 5/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2908 - accuracy: 0.8918 - val_loss: 0.3312 - val_accuracy: 0.8801
Epoch 6/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.2757 - accuracy: 0.8969 - val_loss: 0.3437 - val_accuracy: 0.8782
Epoch 7/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2617 - accuracy: 0.9030 - val_loss: 0.3414 - val_accuracy: 0.8788
Epoch 8/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2504 - accuracy: 0.9062 - val_loss: 0.3221 - val_accuracy: 0.8827
Epoch 9/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2389 - accuracy: 0.9105 - val_loss: 0.3210 - val_accuracy: 0.8858
Epoch 10/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2310 - accuracy: 0.9140 - val_loss: 0.3371 - val_accuracy: 0.8807
Epoch 11/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2208 - accuracy: 0.9172 - val_loss: 0.3135 - val_accuracy: 0.8898
Epoch 12/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2143 - accuracy: 0.9191 - val_loss: 0.3253 - val_accuracy: 0.8863
Epoch 13/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.2049 - accuracy: 0.9233 - val_loss: 0.3268 - val_accuracy: 0.8873
Epoch 14/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1993 - accuracy: 0.9259 - val_loss: 0.3168 - val_accuracy: 0.8919
Epoch 15/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1927 - accuracy: 0.9273 - val_loss: 0.3196 - val_accuracy: 0.8913
Epoch 16/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1832 - accuracy: 0.9316 - val_loss: 0.3353 - val_accuracy: 0.8911
Epoch 17/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1791 - accuracy: 0.9338 - val_loss: 0.3295 - val_accuracy: 0.8903
Epoch 18/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1728 - accuracy: 0.9364 - val_loss: 0.3304 - val_accuracy: 0.8913
Epoch 19/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1672 - accuracy: 0.9370 - val_loss: 0.3382 - val_accuracy: 0.8899
Epoch 20/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1605 - accuracy: 0.9400 - val_loss: 0.3623 - val_accuracy: 0.8908
Epoch 21/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1593 - accuracy: 0.9403 - val_loss: 0.3704 - val_accuracy: 0.8881
Epoch 22/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1519 - accuracy: 0.9430 - val_loss: 0.3617 - val_accuracy: 0.8940
Epoch 23/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1512 - accuracy: 0.9441 - val_loss: 0.3308 - val_accuracy: 0.8947
Epoch 24/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1439 - accuracy: 0.9460 - val_loss: 0.3594 - val_accuracy: 0.8918
Epoch 25/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1391 - accuracy: 0.9475 - val_loss: 0.3781 - val_accuracy: 0.8931
Epoch 26/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1359 - accuracy: 0.9493 - val_loss: 0.3700 - val_accuracy: 0.8934
Epoch 27/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1331 - accuracy: 0.9505 - val_loss: 0.3792 - val_accuracy: 0.8917
Epoch 28/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1299 - accuracy: 0.9519 - val_loss: 0.3837 - val_accuracy: 0.8921
Epoch 29/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1264 - accuracy: 0.9532 - val_loss: 0.4043 - val_accuracy: 0.8850
Epoch 30/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1238 - accuracy: 0.9534 - val_loss: 0.3983 - val_accuracy: 0.8917
Epoch 31/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1175 - accuracy: 0.9572 - val_loss: 0.4241 - val_accuracy: 0.8884
Epoch 32/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1151 - accuracy: 0.9563 - val_loss: 0.4080 - val_accuracy: 0.8930
Epoch 33/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1150 - accuracy: 0.9566 - val_loss: 0.4250 - val_accuracy: 0.8902
Epoch 34/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1120 - accuracy: 0.9588 - val_loss: 0.4427 - val_accuracy: 0.8895
Epoch 35/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1073 - accuracy: 0.9603 - val_loss: 0.4317 - val_accuracy: 0.8926
Epoch 36/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1057 - accuracy: 0.9610 - val_loss: 0.4138 - val_accuracy: 0.8936
Epoch 37/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.1039 - accuracy: 0.9616 - val_loss: 0.4180 - val_accuracy: 0.8917
Epoch 38/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.1007 - accuracy: 0.9625 - val_loss: 0.4698 - val_accuracy: 0.8807
Epoch 39/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0989 - accuracy: 0.9632 - val_loss: 0.4371 - val_accuracy: 0.8920
Epoch 40/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0954 - accuracy: 0.9644 - val_loss: 0.4582 - val_accuracy: 0.8933
Epoch 41/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0944 - accuracy: 0.9648 - val_loss: 0.5068 - val_accuracy: 0.8855
Epoch 42/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0909 - accuracy: 0.9663 - val_loss: 0.5006 - val_accuracy: 0.8864
Epoch 43/44 1500/1500 [==============================] - 4s 2ms/step - loss: 0.0922 - accuracy: 0.9653 - val_loss: 0.4598 - val_accuracy: 0.8942
Epoch 44/44 1500/1500 [==============================] - 4s 3ms/step - loss: 0.0860 - accuracy: 0.9673 - val_loss: 0.4797 - val_accuracy: 0.8913
<keras.src.callbacks.History at 0x7f6b341323a0>
To finish this tutorial, evaluate the hypermodel on the test data.
eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
313/313 [==============================] - 1s 2ms/step - loss: 0.5223 - accuracy: 0.8872
[test loss, test accuracy]: [0.5223038792610168, 0.8871999979019165]
The my_dir/intro_to_kt directory contains detailed logs and checkpoints for every trial (model configuration) run during the hyperparameter search. If you re-run the hyperparameter search, the Keras Tuner uses the existing state from these logs to resume the search. To disable this behavior, pass an additional overwrite=True argument while instantiating the tuner, as sketched below.
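For example (the same Hyperband instantiation as above, with resuming disabled):

tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',
                     max_epochs=10,
                     factor=3,
                     overwrite=True,  # discard any existing results in my_dir/intro_to_kt
                     directory='my_dir',
                     project_name='intro_to_kt')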
Summary
In this tutorial, you learned how to use the Keras Tuner to tune hyperparameters for a model. To learn more, check out the Keras Tuner guides and API reference on the Keras website (https://keras.io/keras_tuner/).
Also check out the HParams Dashboard in TensorBoard to interactively tune your model hyperparameters.