
Automated hyper-parameter tuning


Welcome to the Automated hyper-parameter tuning tutorial. In this colab, you will learn how to improve your models using automated hyper-parameter tuning with TensorFlow Decision Forests.

More precisely, we will:

  1. Train a model without hyper-parameter tuning. This model will be used to measure the quality improvement of hyper-parameter tuning.
  2. Train a model with hyper-parameter tuning using TF-DF's tuner. The hyper-parameters to optimize will be defined manually.
  3. Train another model with hyper-parameter tuning using TF-DF's tuner. This time, the hyper-parameters to optimize will be set automatically. This is the recommended first approach to try when using hyper-parameter tuning.
  4. Finally, we will train a model with hyper-parameter tuning using Keras's tuner.

Introduction

A learning algorithm trains a machine learning model on a training dataset. The parameters of the learning algorithm, called "hyper-parameters", control how the model is trained and impact its quality. Therefore, finding the best hyper-parameters is an important stage of modeling.

Some hyper-parameters are simple to configure. For example, increasing the number of trees (num_trees) in a random forest increases the quality of the model until it reaches a plateau. Therefore, setting the largest value compatible with the serving constraints (more trees means a larger model) is a valid rule of thumb. However, other hyper-parameters have a more complex interaction with the model and cannot be chosen with such a simple rule. For example, increasing the maximum tree depth (max_depth) of a gradient boosted trees model can either increase or decrease the quality of the model. Furthermore, hyper-parameters can interact with each other, so the optimal value of a hyper-parameter often cannot be found in isolation.

There are four main approaches to selecting hyper-parameter values:

  1. The default approach: Learning algorithms come with default values. While not ideal in all cases, those values produce reasonable results in most situations. This approach is recommended as the first one to use in any modeling. This page lists the default values of TF Decision Forests.

  2. The template hyper-parameter approach: In addition to the default values, TF Decision Forests also exposes hyper-parameter templates. Those are benchmark-tuned hyper-parameter values with excellent performance but high training cost (e.g. hyperparameter_template="benchmark_rank1"; a usage sketch is shown after this list).

  3. The manual tuning approach: You can manually test different hyper-parameter values and select the one that performs best. The advanced users guide gives some advice.

  4. The automated tuning approach: A tuning algorithm can be used to automatically find the best hyper-parameter values. This approach often gives the best results and does not require expertise. Its main downside is the time it takes for large datasets.
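
For reference, here is a minimal sketch of the first two approaches with TF-DF. The hyperparameter_template value is the one mentioned above; treat this as an illustration rather than a complete training script.

# Sketch: default vs. template hyper-parameters (illustrative).
import tensorflow_decision_forests as tfdf

# 1. Default approach: all hyper-parameters keep their default values.
model_default = tfdf.keras.GradientBoostedTreesModel()

# 2. Template approach: benchmark-tuned hyper-parameters shipped with TF-DF.
model_template = tfdf.keras.GradientBoostedTreesModel(
    hyperparameter_template="benchmark_rank1")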

In this colab, we show the default and automated tuning approaches with the TensorFlow Decision Forests library.

Hyper-parameter tuning algorithms

Automated tuning algorithms work by generating and evaluating a large number of hyper-parameter values. Each of those iterations is called a "trial". Evaluating a trial is expensive because it requires training a new model each time. At the end of the tuning, the hyper-parameter values with the best evaluation are used.

Tuning algorithms are configured as follows:

The search space

The search space is the list of hyper-parameters to optimize and the values they can take. For example, the maximum depth of a tree could be optimized for values between 1 and 32. Exploring more hyper-parameters and more possible values often leads to better models, but also takes more time. The hyper-parameters listed in the user manual are the most impactful ones to tune. The other hyper-parameters are listed in the documentation.

When the possible value of one hyper-parameter depends on the value of another hyper-parameter, the search space is said to be conditional.
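
To make the idea concrete, here is a small illustrative sketch (plain Python, not the TF-DF tuner API) in which the value of growing_strategy determines which other hyper-parameter is sampled:

# Sketch: sampling from a conditional search space (illustrative only).
import random

def sample_conditional_hyperparameters():
  params = {"growing_strategy": random.choice(["LOCAL", "BEST_FIRST_GLOBAL"])}
  if params["growing_strategy"] == "LOCAL":
    # "max_depth" only applies when growing trees with the LOCAL strategy.
    params["max_depth"] = random.choice([3, 4, 6, 8])
  else:
    # "max_num_nodes" only applies with the BEST_FIRST_GLOBAL strategy.
    params["max_num_nodes"] = random.choice([16, 32, 64, 128])
  return params

print(sample_conditional_hyperparameters())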

The number of trials

The number of trials defines how many models will be trained and evaluated. A larger number of trials generally leads to better models but takes more time.

The optimizer

The optimizer selects the next hyper-parameter values to evaluate using the past trial evaluations. The simplest and often reasonable optimizer is one that selects the hyper-parameter values at random.

The objective / trial score

The objective is the metric optimized by the tuner. Often, this metric is a measure of quality (e.g. accuracy, log loss) of the model evaluated on a validation dataset.

Train-valid-test

The validation dataset should be different from the training dataset: If the training and validation datasets are the same, the selected hyper-parameters will be irrelevant. The validation dataset should also be different from the testing dataset (also called the holdout dataset): Because hyper-parameter tuning is a form of training, if the testing and validation datasets are the same, you are effectively training on the test dataset. In this case, you might overfit on your test dataset without a way to measure it.
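
As an illustration, a single pandas dataframe can be split at random into the three datasets (the ratios below are arbitrary):

# Sketch: random train/validation/test split of a pandas dataframe (illustrative).
import numpy as np

def split_train_valid_test(df, valid_ratio=0.1, test_ratio=0.1, seed=1):
  r = np.random.RandomState(seed).rand(len(df))
  train_df = df[r < 1.0 - valid_ratio - test_ratio]
  valid_df = df[(r >= 1.0 - valid_ratio - test_ratio) & (r < 1.0 - test_ratio)]
  test_df = df[r >= 1.0 - test_ratio]
  return train_df, valid_df, test_df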

Cross-validation

In the case of a small dataset, for example a dataset with less than 100k examples, hyper-parameter tuning can be coupled with cross-validation: Instead of being evaluated from a single training-test round, the objective/trial score is evaluated as the average of the metric over multiple cross-validation rounds.

Similarly to the train-valid-and-test datasets, the cross-validation used to evaluate the objective/score during hyper-parameter tuning should be different from the cross-validation used to evaluate the quality of the final model.
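
As an illustration (not the TF-DF API), the trial score under cross-validation can be computed as the average of the metric over the folds. Here, train_and_evaluate is an assumed helper that trains a model with the trial's hyper-parameters and returns its metric on the validation fold:

# Sketch: cross-validated trial score (illustrative only).
import numpy as np
import pandas as pd

def cross_validated_trial_score(train_and_evaluate, df, num_folds=5):
  # Shuffle once, then split the dataframe into `num_folds` folds.
  folds = np.array_split(df.sample(frac=1.0, random_state=1), num_folds)
  scores = []
  for i in range(num_folds):
    valid_df = folds[i]
    train_df = pd.concat(folds[:i] + folds[i + 1:])
    scores.append(train_and_evaluate(train_df, valid_df))
  return float(np.mean(scores))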

Out-of-bag evaluation

Some models, like Random Forests, can be evaluated on the training dataset using the "out-of-bag evaluation" method. While not as accurate as cross-validation, the "out-of-bag evaluation" is much faster than cross-validation and does not require a separate validation dataset.

In TensorFlow Decision Forests

In TF-DF, the model "self" evaluation is always a fair way to evaluate a model. For example, an out-of-bag evaluation is used for Random Forest models while a validation dataset is used for Gradient Boosted Trees models.
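
For example, the self evaluation of a trained TF-DF model (here, any trained model, such as the ones trained below) can be read from the model inspector:

# Sketch: reading the model self evaluation (out-of-bag for Random Forest,
# validation dataset for Gradient Boosted Trees).
inspector = model.make_inspector()
print(inspector.evaluation())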

Hyper-parameter tuning with TF Decision Forests

TF-DF supports automatic hyper-parameter tuning with minimal configuration. In the next example, we will train and compare two models: One trained with default hyper-parameters, and one trained with hyper-parameter tuning.

Setup

# Install TensorFlow Decision Forests
pip install tensorflow_decision_forests -U -qq

Install Wurlitzer. Wurlitzer is required to show the detailed training logs in colabs (with verbose=2).

pip install wurlitzer -U -qq

Import the necessary libraries.

import tensorflow_decision_forests as tfdf
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import numpy as np
2023-02-01 12:08:22.560223: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-01 12:08:22.560311: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-02-01 12:08:22.560320: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

The hidden code cell limits the output height in colab.

Define "set_cell_height".

Training a model without automated hyper-parameter tuning

We will train a model on the Adult dataset available on the UCI Machine Learning Repository. Let's download the dataset.

# Download a copy of the adult dataset.
wget -q https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset/adult_train.csv -O /tmp/adult_train.csv
wget -q https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset/adult_test.csv -O /tmp/adult_test.csv

Load the training and testing datasets and convert them into TensorFlow datasets.

# Load the dataset in memory
train_df = pd.read_csv("/tmp/adult_train.csv")
test_df = pd.read_csv("/tmp/adult_test.csv")

# Convert the Pandas dataframes into TensorFlow datasets.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="income")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="income")

First, we train and evaluate the quality of a Gradient Boosted Trees model trained with the default hyper-parameters.

%%time
# Train a model with default hyper-parameters
model = tfdf.keras.GradientBoostedTreesModel()
model.fit(train_ds)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpuk7pbkq8 as temporary training directory
Reading training dataset...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Training dataset read in 0:00:03.662891. Found 22792 examples.
Training model...
2023-02-01 12:08:32.432401: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1790] "goss_alpha" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:08:32.432437: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1800] "goss_beta" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:08:32.432444: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1814] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
Model trained in 0:00:03.111172
Compiling model...
[INFO 2023-02-01T12:08:35.50196752+00:00 kernel.cc:1214] Loading model from path /tmpfs/tmp/tmpuk7pbkq8/model/ with prefix 951b90795cc74103
[INFO 2023-02-01T12:08:35.524444264+00:00 abstract_model.cc:1311] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO 2023-02-01T12:08:35.524475121+00:00 kernel.cc:1046] Use fast generic engine
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fc976a0cdc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fc976a0cdc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fc976a0cdc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.
CPU times: user 11.9 s, sys: 1.19 s, total: 13.1 s
Wall time: 8.25 s
<keras.callbacks.History at 0x7fca5c56db80>
# Evaluate the model
model.compile(["accuracy"])
test_accuracy = model.evaluate(test_ds, return_dict=True, verbose=0)["accuracy"]
print(f"Test accuracy without hyper-parameter tuning: {test_accuracy:.4f}")
Test accuracy without hyper-parameter tuning: 0.8744

The default hyper-parameters of the model are available with the learner_params attribute. The definition of those parameters is available in the documentation.

print("Default hyper-parameters of the model:\n", model.learner_params)
Default hyper-parameters of the model:
 {'adapt_subsample_for_maximum_training_duration': False, 'allow_na_conditions': False, 'apply_link_function': True, 'categorical_algorithm': 'CART', 'categorical_set_split_greedy_sampling': 0.1, 'categorical_set_split_max_num_items': -1, 'categorical_set_split_min_item_frequency': 1, 'compute_permutation_variable_importance': False, 'dart_dropout': 0.01, 'early_stopping': 'LOSS_INCREASE', 'early_stopping_initial_iteration': 10, 'early_stopping_num_trees_look_ahead': 30, 'focal_loss_alpha': 0.5, 'focal_loss_gamma': 2.0, 'forest_extraction': 'MART', 'goss_alpha': 0.2, 'goss_beta': 0.1, 'growing_strategy': 'LOCAL', 'honest': False, 'honest_fixed_separation': False, 'honest_ratio_leaf_examples': 0.5, 'in_split_min_examples_check': True, 'keep_non_leaf_label_distribution': True, 'l1_regularization': 0.0, 'l2_categorical_regularization': 1.0, 'l2_regularization': 0.0, 'lambda_loss': 1.0, 'loss': 'DEFAULT', 'max_depth': 6, 'max_num_nodes': None, 'maximum_model_size_in_memory_in_bytes': -1.0, 'maximum_training_duration_seconds': -1.0, 'min_examples': 5, 'missing_value_policy': 'GLOBAL_IMPUTATION', 'num_candidate_attributes': -1, 'num_candidate_attributes_ratio': -1.0, 'num_trees': 300, 'pure_serving_model': False, 'random_seed': 123456, 'sampling_method': 'RANDOM', 'selective_gradient_boosting_ratio': 0.01, 'shrinkage': 0.1, 'sorting_strategy': 'PRESORT', 'sparse_oblique_normalization': None, 'sparse_oblique_num_projections_exponent': None, 'sparse_oblique_projection_density_factor': None, 'sparse_oblique_weights': None, 'split_axis': 'AXIS_ALIGNED', 'subsample': 1.0, 'uplift_min_examples_in_treatment': 5, 'uplift_split_score': 'KULLBACK_LEIBLER', 'use_hessian_gain': False, 'validation_interval_in_trees': 1, 'validation_ratio': 0.1}

Training a model with automated hyper-parameter tuning and manual definition of the hyper-parameters

Hyper-parameter tuning is enabled by specifying the tuner constructor argument of the model. The tuner object contains all the configuration of the tuner (search space, optimizer, number of trials and objective).

# Configure the tuner.

# Create a Random Search tuner with 50 trials.
tuner = tfdf.tuner.RandomSearch(num_trials=50)

# Define the search space.
#
# Adding more parameters generally improves the quality of the model, but makes
# the tuning last longer.

tuner.choice("min_examples", [2, 5, 7, 10])
tuner.choice("categorical_algorithm", ["CART", "RANDOM"])

# Some hyper-parameters are only valid for specific values of other
# hyper-parameters. For example, the "max_depth" parameter is mostly useful when
# "growing_strategy=LOCAL" while "max_num_nodes" is better suited when
# "growing_strategy=BEST_FIRST_GLOBAL".

local_search_space = tuner.choice("growing_strategy", ["LOCAL"])
local_search_space.choice("max_depth", [3, 4, 5, 6, 8])

# merge=True indicates that the parameter (here "growing_strategy") is already
# defined, and that new values are added to it.
global_search_space = tuner.choice("growing_strategy", ["BEST_FIRST_GLOBAL"], merge=True)
global_search_space.choice("max_num_nodes", [16, 32, 64, 128, 256])

tuner.choice("use_hessian_gain", [True, False])
tuner.choice("shrinkage", [0.02, 0.05, 0.10, 0.15])
tuner.choice("num_candidate_attributes_ratio", [0.2, 0.5, 0.9, 1.0])

# Uncomment some (or all) of the following hyper-parameters to increase the
# quality of the search. The number of trials should be increased accordingly.

# tuner.choice("split_axis", ["AXIS_ALIGNED"])
# oblique_space = tuner.choice("split_axis", ["SPARSE_OBLIQUE"], merge=True)
# oblique_space.choice("sparse_oblique_normalization",
#                      ["NONE", "STANDARD_DEVIATION", "MIN_MAX"])
# oblique_space.choice("sparse_oblique_weights", ["BINARY", "CONTINUOUS"])
# oblique_space.choice("sparse_oblique_num_projections_exponent", [1.0, 1.5])
<tensorflow_decision_forests.component.tuner.tuner.SearchSpace at 0x7fc9702bd5e0>
%%time
%set_cell_height 300

# Tune the model. Notice the `tuner=tuner`.
tuned_model = tfdf.keras.GradientBoostedTreesModel(tuner=tuner)
tuned_model.fit(train_ds, verbose=2)

# The `num_threads` model constructor argument (not specified in the example
# above) controls how many trials are run in parallel (one per thread). If
# `num_threads` is not specified (like in the example above), one thread is
# allocated for each available CPU core.
#
# If the training is interrupted (for example, by pressing on the "stop" button
# on the top-left of the colab cell), the best model so-far will be returned.

# In the training logs, you can see lines such as `[10/50] Score: -0.45 / -0.40
# HParams: ...`. This indicates that 10 of the 50 trials have been completed,
# that the last trial returned a score of "-0.45", and that the best trial so
# far has a score of "-0.40". In this example, the model is optimized by
# log loss. Since scores are maximized and log loss should be minimized, the
# score is effectively minus the log loss.
<IPython.core.display.Javascript object>
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpgjhxwzpg as temporary training directory
Reading training dataset...
Training tensor examples:
Features: {'age': <tf.Tensor 'data:0' shape=(None,) dtype=int64>, 'workclass': <tf.Tensor 'data_13:0' shape=(None,) dtype=string>, 'fnlwgt': <tf.Tensor 'data_5:0' shape=(None,) dtype=int64>, 'education': <tf.Tensor 'data_3:0' shape=(None,) dtype=string>, 'education_num': <tf.Tensor 'data_4:0' shape=(None,) dtype=int64>, 'marital_status': <tf.Tensor 'data_7:0' shape=(None,) dtype=string>, 'occupation': <tf.Tensor 'data_9:0' shape=(None,) dtype=string>, 'relationship': <tf.Tensor 'data_11:0' shape=(None,) dtype=string>, 'race': <tf.Tensor 'data_10:0' shape=(None,) dtype=string>, 'sex': <tf.Tensor 'data_12:0' shape=(None,) dtype=string>, 'capital_gain': <tf.Tensor 'data_1:0' shape=(None,) dtype=int64>, 'capital_loss': <tf.Tensor 'data_2:0' shape=(None,) dtype=int64>, 'hours_per_week': <tf.Tensor 'data_6:0' shape=(None,) dtype=int64>, 'native_country': <tf.Tensor 'data_8:0' shape=(None,) dtype=string>}
Label: Tensor("data_14:0", shape=(None,), dtype=int64)
Weights: None
Normalized tensor features:
 {'age': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast:0' shape=(None,) dtype=float32>), 'workclass': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_13:0' shape=(None,) dtype=string>), 'fnlwgt': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_1:0' shape=(None,) dtype=float32>), 'education': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_3:0' shape=(None,) dtype=string>), 'education_num': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_2:0' shape=(None,) dtype=float32>), 'marital_status': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_7:0' shape=(None,) dtype=string>), 'occupation': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_9:0' shape=(None,) dtype=string>), 'relationship': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_11:0' shape=(None,) dtype=string>), 'race': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_10:0' shape=(None,) dtype=string>), 'sex': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_12:0' shape=(None,) dtype=string>), 'capital_gain': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_3:0' shape=(None,) dtype=float32>), 'capital_loss': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_4:0' shape=(None,) dtype=float32>), 'hours_per_week': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_5:0' shape=(None,) dtype=float32>), 'native_country': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_8:0' shape=(None,) dtype=string>)}
Training dataset read in 0:00:00.385471. Found 22792 examples.
Training model...
Standard output detected as not visible to the user e.g. running in a notebook. Creating a training log redirection. If training gets stuck, try calling tfdf.keras.set_training_logs_redirection(False).
[INFO 2023-02-01T12:08:37.779610966+00:00 kernel.cc:756] Start Yggdrasil model training
[INFO 2023-02-01T12:08:37.779679473+00:00 kernel.cc:757] Collect training examples
[INFO 2023-02-01T12:08:37.77976402+00:00 kernel.cc:388] Number of batches: 23
[INFO 2023-02-01T12:08:37.779779517+00:00 kernel.cc:389] Number of examples: 22792
[INFO 2023-02-01T12:08:37.787405763+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column native_country (40 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:08:37.787447447+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column occupation (13 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:08:37.787485818+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column workclass (7 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:08:37.792400377+00:00 kernel.cc:774] Training dataset:
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "__LABEL" CATEGORICAL integerized vocab-size:3 no-ood-item
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    8: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    9: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)
    10: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:1 (0.00464425%) most-frequent:"Prof-specialty" 2870 (13.329%)
    11: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    12: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    13: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:1 (0.0046436%) most-frequent:"Private" 15879 (73.7358%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    2: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    3: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    6: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    7: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO 2023-02-01T12:08:37.792452287+00:00 kernel.cc:790] Configure learner
2023-02-01 12:08:37.792731: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1790] "goss_alpha" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:08:37.792757: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1800] "goss_beta" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:08:37.792764: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1814] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
[INFO 2023-02-01T12:08:37.792816689+00:00 kernel.cc:804] Training config:
learner: "HYPERPARAMETER_OPTIMIZER"
features: "^age$"
features: "^capital_gain$"
features: "^capital_loss$"
features: "^education$"
features: "^education_num$"
features: "^fnlwgt$"
features: "^hours_per_week$"
features: "^marital_status$"
features: "^native_country$"
features: "^occupation$"
features: "^race$"
features: "^relationship$"
features: "^sex$"
features: "^workclass$"
label: "^__LABEL$"
task: CLASSIFICATION
metadata {
  framework: "TF Keras"
}
[yggdrasil_decision_forests.model.hyperparameters_optimizer_v2.proto.hyperparameters_optimizer_config] {
  base_learner {
    learner: "GRADIENT_BOOSTED_TREES"
    features: "^age$"
    features: "^capital_gain$"
    features: "^capital_loss$"
    features: "^education$"
    features: "^education_num$"
    features: "^fnlwgt$"
    features: "^hours_per_week$"
    features: "^marital_status$"
    features: "^native_country$"
    features: "^occupation$"
    features: "^race$"
    features: "^relationship$"
    features: "^sex$"
    features: "^workclass$"
    label: "^__LABEL$"
    task: CLASSIFICATION
    random_seed: 123456
    pure_serving_model: false
    [yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
      num_trees: 300
      decision_tree {
        max_depth: 6
        min_examples: 5
        in_split_min_examples_check: true
        keep_non_leaf_label_distribution: true
        num_candidate_attributes: -1
        missing_value_policy: GLOBAL_IMPUTATION
        allow_na_conditions: false
        categorical_set_greedy_forward {
          sampling: 0.1
          max_num_items: -1
          min_item_frequency: 1
        }
        growing_strategy_local {
        }
        categorical {
          cart {
          }
        }
        axis_aligned_split {
        }
        internal {
          sorting_strategy: PRESORTED
        }
        uplift {
          min_examples_in_treatment: 5
          split_score: KULLBACK_LEIBLER
        }
      }
      shrinkage: 0.1
      loss: DEFAULT
      validation_set_ratio: 0.1
      validation_interval_in_trees: 1
      early_stopping: VALIDATION_LOSS_INCREASE
      early_stopping_num_trees_look_ahead: 30
      l2_regularization: 0
      lambda_loss: 1
      mart {
      }
      adapt_subsample_for_maximum_training_duration: false
      l1_regularization: 0
      use_hessian_gain: false
      l2_regularization_categorical: 1
      stochastic_gradient_boosting {
        ratio: 1
      }
      apply_link_function: true
      compute_permutation_variable_importance: false
      binary_focal_loss_options {
        misprediction_exponent: 2
        positive_sample_coefficient: 0.5
      }
      early_stopping_initial_iteration: 10
    }
  }
  optimizer {
    optimizer_key: "RANDOM"
    [yggdrasil_decision_forests.model.hyperparameters_optimizer_v2.proto.random] {
      num_trials: 50
    }
  }
  search_space {
    fields {
      name: "min_examples"
      discrete_candidates {
        possible_values {
          integer: 2
        }
        possible_values {
          integer: 5
        }
        possible_values {
          integer: 7
        }
        possible_values {
          integer: 10
        }
      }
    }
    fields {
      name: "categorical_algorithm"
      discrete_candidates {
        possible_values {
          categorical: "CART"
        }
        possible_values {
          categorical: "RANDOM"
        }
      }
    }
    fields {
      name: "growing_strategy"
      discrete_candidates {
        possible_values {
          categorical: "LOCAL"
        }
        possible_values {
          categorical: "BEST_FIRST_GLOBAL"
        }
      }
      children {
        name: "max_depth"
        discrete_candidates {
          possible_values {
            integer: 3
          }
          possible_values {
            integer: 4
          }
          possible_values {
            integer: 5
          }
          possible_values {
            integer: 6
          }
          possible_values {
            integer: 8
          }
        }
        parent_discrete_values {
          possible_values {
            categorical: "LOCAL"
          }
        }
      }
      children {
        name: "max_num_nodes"
        discrete_candidates {
          possible_values {
            integer: 16
          }
          possible_values {
            integer: 32
          }
          possible_values {
            integer: 64
          }
          possible_values {
            integer: 128
          }
          possible_values {
            integer: 256
          }
        }
        parent_discrete_values {
          possible_values {
            categorical: "BEST_FIRST_GLOBAL"
          }
        }
      }
    }
    fields {
      name: "use_hessian_gain"
      discrete_candidates {
        possible_values {
          categorical: "true"
        }
        possible_values {
          categorical: "false"
        }
      }
    }
    fields {
      name: "shrinkage"
      discrete_candidates {
        possible_values {
          real: 0.02
        }
        possible_values {
          real: 0.05
        }
        possible_values {
          real: 0.1
        }
        possible_values {
          real: 0.15
        }
      }
    }
    fields {
      name: "num_candidate_attributes_ratio"
      discrete_candidates {
        possible_values {
          real: 0.2
        }
        possible_values {
          real: 0.5
        }
        possible_values {
          real: 0.9
        }
        possible_values {
          real: 1
        }
      }
    }
  }
  base_learner_deployment {
    num_threads: 1
  }
}

[INFO 2023-02-01T12:08:37.793220481+00:00 kernel.cc:807] Deployment config:
cache_path: "/tmpfs/tmp/tmpgjhxwzpg/working_cache"
num_threads: 32
try_resume_training: true

[INFO 2023-02-01T12:08:37.793376204+00:00 kernel.cc:868] Train model
[INFO 2023-02-01T12:08:39.851533704+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.583674
[INFO 2023-02-01T12:08:41.585165969+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.588227
[INFO 2023-02-01T12:08:42.092825385+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.569129
[INFO 2023-02-01T12:08:42.557544087+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.573016
[INFO 2023-02-01T12:08:42.602274077+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.575957
[INFO 2023-02-01T12:08:42.613803966+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.571689
[INFO 2023-02-01T12:08:43.519389753+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.575672
[INFO 2023-02-01T12:08:44.213780982+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.578549
[INFO 2023-02-01T12:08:44.759667214+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.573061
[INFO 2023-02-01T12:08:45.231593728+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.577748
[INFO 2023-02-01T12:08:45.479992055+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.578216
[INFO 2023-02-01T12:08:45.799623005+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.578669
[INFO 2023-02-01T12:08:45.964682066+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.586668
[INFO 2023-02-01T12:08:46.034461092+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.576506
[INFO 2023-02-01T12:08:47.555410961+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.578696
[INFO 2023-02-01T12:08:47.593686674+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.579464
[INFO 2023-02-01T12:08:47.838775628+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.574801
[INFO 2023-02-01T12:08:48.454877928+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.571936
[INFO 2023-02-01T12:08:48.711038125+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.587919
[INFO 2023-02-01T12:08:48.712049402+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.575719
[INFO 2023-02-01T12:08:48.946675629+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.578764
[INFO 2023-02-01T12:08:49.053943911+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.57629
[INFO 2023-02-01T12:08:49.315889807+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.581114
[INFO 2023-02-01T12:08:49.994744725+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.568287
[INFO 2023-02-01T12:08:52.573819212+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.570064
[INFO 2023-02-01T12:08:54.567070436+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.572158
[INFO 2023-02-01T12:08:58.961698295+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.573524
[INFO 2023-02-01T12:08:58.98892809+00:00 kernel.cc:905] Export model in log directory: /tmpfs/tmp/tmpgjhxwzpg with prefix 29a08c8a3f274273
[INFO 2023-02-01T12:08:58.995928198+00:00 kernel.cc:923] Save model in resources
[INFO 2023-02-01T12:08:58.998730119+00:00 abstract_model.cc:849] Model self evaluation:
Task: CLASSIFICATION
Label: __LABEL
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.568287

Accuracy: 0.873395  CI95[W][0 1]
ErrorRate: : 0.126605


Confusion Table:
truth\prediction
   0     1    2
0  0     0    0
1  0  1570   94
2  0   192  403
Total: 2259

One vs other classes:

[INFO 2023-02-01T12:08:59.018306354+00:00 kernel.cc:1214] Loading model from path /tmpfs/tmp/tmpgjhxwzpg/model/ with prefix 29a08c8a3f274273
[INFO 2023-02-01T12:08:59.049560135+00:00 abstract_model.cc:1311] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO 2023-02-01T12:08:59.0495917+00:00 kernel.cc:1046] Use fast generic engine
Model trained in 0:00:21.277256
Compiling model...
Model compiled.
CPU times: user 6min 57s, sys: 332 ms, total: 6min 57s
Wall time: 21.9 s
<keras.callbacks.History at 0x7fca7c6ca100>
# Evaluate the model
tuned_model.compile(["accuracy"])
tuned_test_accuracy = tuned_model.evaluate(test_ds, return_dict=True, verbose=0)["accuracy"]
print(f"Test accuracy with the TF-DF hyper-parameter tuner: {tuned_test_accuracy:.4f}")
Test accuracy with the TF-DF hyper-parameter tuner: 0.8722

The hyper-parameters and objective scores of the trials are available in the model inspector. The score value is always maximized. In this example, the score is the negative log loss on the validation dataset (selected automatically).

# Display the tuning logs.
tuning_logs = tuned_model.make_inspector().tuning_logs()
tuning_logs.head()

The single row with best=True is the one used in the final model.

# Best hyper-parameters.
tuning_logs[tuning_logs.best].iloc[0]
score                                     -0.568287
evaluation_time                           12.203865
best                                           True
min_examples                                      5
categorical_algorithm                          CART
growing_strategy                  BEST_FIRST_GLOBAL
max_depth                                       NaN
use_hessian_gain                               true
shrinkage                                       0.1
num_candidate_attributes_ratio                  0.9
max_num_nodes                                 256.0
Name: 34, dtype: object

Next, we plot the evolution of the best score during the tuning.

plt.figure(figsize=(10, 5))
plt.plot(tuning_logs["score"], label="current trial")
plt.plot(tuning_logs["score"].cummax(), label="best trial")
plt.xlabel("Tuning step")
plt.ylabel("Tuning score")
plt.legend()
plt.show()

[Figure: tuning score of the current trial and of the best trial versus the tuning step]

Training a model with automated hyper-parameter tuning and automatic definition of the hyper-parameters (recommended approach)

As before, hyper-parameter tuning is enabled by specifying the tuner constructor argument of the model. Set use_predefined_hps=True to automatically configure the search space for the hyper-parameters.

%%time
%set_cell_height 300

# Create a Random Search tuner with 50 trials and automatic hp configuration.
tuner = tfdf.tuner.RandomSearch(num_trials=50, use_predefined_hps=True)

# Define and train the model.
tuned_model = tfdf.keras.GradientBoostedTreesModel(tuner=tuner)
tuned_model.fit(train_ds, verbose=2)
<IPython.core.display.Javascript object>
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmp7w1wg5fw as temporary training directory
Reading training dataset...
Training tensor examples:
Features: {'age': <tf.Tensor 'data:0' shape=(None,) dtype=int64>, 'workclass': <tf.Tensor 'data_13:0' shape=(None,) dtype=string>, 'fnlwgt': <tf.Tensor 'data_5:0' shape=(None,) dtype=int64>, 'education': <tf.Tensor 'data_3:0' shape=(None,) dtype=string>, 'education_num': <tf.Tensor 'data_4:0' shape=(None,) dtype=int64>, 'marital_status': <tf.Tensor 'data_7:0' shape=(None,) dtype=string>, 'occupation': <tf.Tensor 'data_9:0' shape=(None,) dtype=string>, 'relationship': <tf.Tensor 'data_11:0' shape=(None,) dtype=string>, 'race': <tf.Tensor 'data_10:0' shape=(None,) dtype=string>, 'sex': <tf.Tensor 'data_12:0' shape=(None,) dtype=string>, 'capital_gain': <tf.Tensor 'data_1:0' shape=(None,) dtype=int64>, 'capital_loss': <tf.Tensor 'data_2:0' shape=(None,) dtype=int64>, 'hours_per_week': <tf.Tensor 'data_6:0' shape=(None,) dtype=int64>, 'native_country': <tf.Tensor 'data_8:0' shape=(None,) dtype=string>}
Label: Tensor("data_14:0", shape=(None,), dtype=int64)
Weights: None
Normalized tensor features:
 {'age': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast:0' shape=(None,) dtype=float32>), 'workclass': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_13:0' shape=(None,) dtype=string>), 'fnlwgt': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_1:0' shape=(None,) dtype=float32>), 'education': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_3:0' shape=(None,) dtype=string>), 'education_num': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_2:0' shape=(None,) dtype=float32>), 'marital_status': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_7:0' shape=(None,) dtype=string>), 'occupation': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_9:0' shape=(None,) dtype=string>), 'relationship': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_11:0' shape=(None,) dtype=string>), 'race': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_10:0' shape=(None,) dtype=string>), 'sex': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_12:0' shape=(None,) dtype=string>), 'capital_gain': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_3:0' shape=(None,) dtype=float32>), 'capital_loss': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_4:0' shape=(None,) dtype=float32>), 'hours_per_week': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_5:0' shape=(None,) dtype=float32>), 'native_country': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_8:0' shape=(None,) dtype=string>)}
Training dataset read in 0:00:00.369788. Found 22792 examples.
Training model...
[INFO 2023-02-01T12:09:00.101003542+00:00 kernel.cc:756] Start Yggdrasil model training
[INFO 2023-02-01T12:09:00.101032435+00:00 kernel.cc:757] Collect training examples
[INFO 2023-02-01T12:09:00.101119151+00:00 kernel.cc:388] Number of batches: 23
[INFO 2023-02-01T12:09:00.101135203+00:00 kernel.cc:389] Number of examples: 22792
[INFO 2023-02-01T12:09:00.108672367+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column native_country (40 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:09:00.108716013+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column occupation (13 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:09:00.10875597+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column workclass (7 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:09:00.113672478+00:00 kernel.cc:774] Training dataset:
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "__LABEL" CATEGORICAL integerized vocab-size:3 no-ood-item
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    8: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    9: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)
    10: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:1 (0.00464425%) most-frequent:"Prof-specialty" 2870 (13.329%)
    11: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    12: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    13: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:1 (0.0046436%) most-frequent:"Private" 15879 (73.7358%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    2: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    3: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    6: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    7: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO 2023-02-01T12:09:00.113731136+00:00 kernel.cc:790] Configure learner
2023-02-01 12:09:00.114002: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1790] "goss_alpha" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:09:00.114032: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1800] "goss_beta" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:09:00.114039: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1814] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
[INFO 2023-02-01T12:09:00.114092836+00:00 kernel.cc:804] Training config:
learner: "HYPERPARAMETER_OPTIMIZER"
features: "^age$"
features: "^capital_gain$"
features: "^capital_loss$"
features: "^education$"
features: "^education_num$"
features: "^fnlwgt$"
features: "^hours_per_week$"
features: "^marital_status$"
features: "^native_country$"
features: "^occupation$"
features: "^race$"
features: "^relationship$"
features: "^sex$"
features: "^workclass$"
label: "^__LABEL$"
task: CLASSIFICATION
metadata {
  framework: "TF Keras"
}
[yggdrasil_decision_forests.model.hyperparameters_optimizer_v2.proto.hyperparameters_optimizer_config] {
  base_learner {
    learner: "GRADIENT_BOOSTED_TREES"
    features: "^age$"
    features: "^capital_gain$"
    features: "^capital_loss$"
    features: "^education$"
    features: "^education_num$"
    features: "^fnlwgt$"
    features: "^hours_per_week$"
    features: "^marital_status$"
    features: "^native_country$"
    features: "^occupation$"
    features: "^race$"
    features: "^relationship$"
    features: "^sex$"
    features: "^workclass$"
    label: "^__LABEL$"
    task: CLASSIFICATION
    random_seed: 123456
    pure_serving_model: false
    [yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
      num_trees: 300
      decision_tree {
        max_depth: 6
        min_examples: 5
        in_split_min_examples_check: true
        keep_non_leaf_label_distribution: true
        num_candidate_attributes: -1
        missing_value_policy: GLOBAL_IMPUTATION
        allow_na_conditions: false
        categorical_set_greedy_forward {
          sampling: 0.1
          max_num_items: -1
          min_item_frequency: 1
        }
        growing_strategy_local {
        }
        categorical {
          cart {
          }
        }
        axis_aligned_split {
        }
        internal {
          sorting_strategy: PRESORTED
        }
        uplift {
          min_examples_in_treatment: 5
          split_score: KULLBACK_LEIBLER
        }
      }
      shrinkage: 0.1
      loss: DEFAULT
      validation_set_ratio: 0.1
      validation_interval_in_trees: 1
      early_stopping: VALIDATION_LOSS_INCREASE
      early_stopping_num_trees_look_ahead: 30
      l2_regularization: 0
      lambda_loss: 1
      mart {
      }
      adapt_subsample_for_maximum_training_duration: false
      l1_regularization: 0
      use_hessian_gain: false
      l2_regularization_categorical: 1
      stochastic_gradient_boosting {
        ratio: 1
      }
      apply_link_function: true
      compute_permutation_variable_importance: false
      binary_focal_loss_options {
        misprediction_exponent: 2
        positive_sample_coefficient: 0.5
      }
      early_stopping_initial_iteration: 10
    }
  }
  optimizer {
    optimizer_key: "RANDOM"
    [yggdrasil_decision_forests.model.hyperparameters_optimizer_v2.proto.random] {
      num_trials: 50
    }
  }
  base_learner_deployment {
    num_threads: 1
  }
  predefined_search_space {
  }
}

[INFO 2023-02-01T12:09:00.114222565+00:00 kernel.cc:807] Deployment config:
cache_path: "/tmpfs/tmp/tmp7w1wg5fw/working_cache"
num_threads: 32
try_resume_training: true

[INFO 2023-02-01T12:09:00.114451643+00:00 kernel.cc:868] Train model
[INFO 2023-02-01T12:09:27.673901013+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.617247
[INFO 2023-02-01T12:09:32.801753903+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.587795
[INFO 2023-02-01T12:09:38.108731583+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.579144
[INFO 2023-02-01T12:09:40.718985441+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.604322
[INFO 2023-02-01T12:09:48.519083808+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.601315
[INFO 2023-02-01T12:09:50.955101001+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.609047
[INFO 2023-02-01T12:09:51.305896716+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.576761
[INFO 2023-02-01T12:09:51.866282407+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.585247
[INFO 2023-02-01T12:09:52.639168731+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.574235
[INFO 2023-02-01T12:09:52.904800172+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.575853
[INFO 2023-02-01T12:09:57.049231376+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.634505
[INFO 2023-02-01T12:10:07.732286395+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.599006
[INFO 2023-02-01T12:10:13.226015779+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.57462
[INFO 2023-02-01T12:10:17.27092083+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.585657
[INFO 2023-02-01T12:10:17.867169081+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.575868
[INFO 2023-02-01T12:10:25.791623077+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.57937
[INFO 2023-02-01T12:10:36.666482972+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.593538
[INFO 2023-02-01T12:10:37.469933095+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.642896
[INFO 2023-02-01T12:10:37.708165541+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.580436
[INFO 2023-02-01T12:10:39.081284721+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.582612
[INFO 2023-02-01T12:10:39.402097572+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.623461
[INFO 2023-02-01T12:10:58.224513829+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.622613
[INFO 2023-02-01T12:11:03.082495092+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.636921
[INFO 2023-02-01T12:11:21.550245546+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.575538
[INFO 2023-02-01T12:11:37.764394969+00:00 kernel.cc:905] Export model in log directory: /tmpfs/tmp/tmp7w1wg5fw with prefix 5ba965b1f2a24c22
[INFO 2023-02-01T12:11:37.79468272+00:00 kernel.cc:923] Save model in resources
[INFO 2023-02-01T12:11:37.79904483+00:00 abstract_model.cc:849] Model self evaluation:
Task: CLASSIFICATION
Label: __LABEL
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.573336

Accuracy: 0.867198  CI95[W][0 1]
ErrorRate: : 0.132802


Confusion Table:
truth\prediction
   0     1    2
0  0     0    0
1  0  1578   86
2  0   214  381
Total: 2259

One vs other classes:

[INFO 2023-02-01T12:11:37.823530826+00:00 kernel.cc:1214] Loading model from path /tmpfs/tmp/tmp7w1wg5fw/model/ with prefix 5ba965b1f2a24c22
[INFO 2023-02-01T12:11:37.981384124+00:00 decision_forest.cc:661] Model loaded with 284 root(s), 48262 node(s), and 14 input feature(s).
[INFO 2023-02-01T12:11:37.981421401+00:00 abstract_model.cc:1311] Engine "GradientBoostedTreesGeneric" built
[INFO 2023-02-01T12:11:37.981443494+00:00 kernel.cc:1046] Use fast generic engine
Model trained in 0:02:37.891659
Compiling model...
Model compiled.
CPU times: user 57min, sys: 1 s, total: 57min 1s
Wall time: 2min 38s
<keras.callbacks.History at 0x7fc9a37aea30>
# Evaluate the model
tuned_model.compile(["accuracy"])
tuned_test_accuracy = tuned_model.evaluate(test_ds, return_dict=True, verbose=0)["accuracy"]
print(f"Test accuracy with the TF-DF hyper-parameter tuner: {tuned_test_accuracy:.4f}")
Test accuracy with the TF-DF hyper-parameter tuner: 0.8741

Same as before, display the tuning logs.

# Display the tuning logs.
tuning_logs = tuned_model.make_inspector().tuning_logs()
tuning_logs.head()

Same as before, show the best hyper-parameters.

# Best hyper-parameters.
tuning_logs[tuning_logs.best].iloc[0]
score                                            -0.573336
evaluation_time                                  124.74388
best                                                  True
split_axis                                  SPARSE_OBLIQUE
sparse_oblique_projection_density_factor               4.0
sparse_oblique_normalization                          NONE
sparse_oblique_weights                          CONTINUOUS
categorical_algorithm                                 CART
growing_strategy                                     LOCAL
max_num_nodes                                          NaN
sampling_method                                     RANDOM
subsample                                              0.9
shrinkage                                             0.02
min_examples                                             5
use_hessian_gain                                      true
num_candidate_attributes_ratio                         0.2
max_depth                                              8.0
Name: 41, dtype: object
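
Note that max_num_nodes is NaN for this trial: it only applies to the BEST_FIRST_GLOBAL growing strategy, and this trial uses the LOCAL strategy, whose tree size is controlled by max_depth instead.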

Finally, plot the evolution of the quality of the model during tuning:

plt.figure(figsize=(10, 5))
plt.plot(tuning_logs["score"], label="current trial")
plt.plot(tuning_logs["score"].cummax(), label="best trial")
plt.xlabel("Tuning step")
plt.ylabel("Tuning score")
plt.legend()
plt.show()

[Plot: tuning score of the current trial and of the best trial so far, over the tuning steps]
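
The tuning logs are a regular pandas DataFrame, so they can be persisted if you want to compare this run with future tuning runs. A minimal sketch (the file path is only an example):

import pandas as pd

# Save the tuning logs for later analysis (illustrative path).
tuning_logs.to_csv("/tmp/tfdf_tuning_logs.csv", index=False)

# Reload them later and check the best score that was reached.
reloaded_logs = pd.read_csv("/tmp/tfdf_tuning_logs.csv")
print("Best tuning score:", reloaded_logs["score"].max())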

Training a model with Keras Tuner (Alternative approach)

TensorFlow Decision Forests is based on the Keras framework, and it is compatible with the Keras tuner.

Currently, the TF-DF Tuner and the Keras Tuner are complementary, each offering features the other does not; a short sketch of the TF-DF tuner follows the lists for comparison.

TF-DF Tuner

  • Automatic configuration of the objective.
  • Automatic extraction of a validation dataset (if needed).
  • Support for model self-evaluation (e.g. out-of-bag evaluation).
  • Distributed hyper-parameter tuning.
  • Shared dataset access between the trials: the TensorFlow dataset is read only once, which speeds up tuning significantly on small datasets.

Keras Tuner

  • Support for tuning of the pre-processing parameters.
  • Support for the Hyperband optimizer.
  • Support for custom objectives.
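
For comparison with the Keras Tuner code below, here is roughly how the TF-DF tuner is wired in, following the same pattern as earlier in this tutorial (the trial count here is illustrative):

# TF-DF tuner sketch: the tuner is attached to the model itself, and the
# objective and validation dataset are configured automatically.
tfdf_tuner = tfdf.tuner.RandomSearch(num_trials=50, use_predefined_hps=True)
model = tfdf.keras.GradientBoostedTreesModel(tuner=tfdf_tuner)
model.fit(train_ds, verbose=0)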

Let's tune a TF-DF model using the Keras tuner.

# Install the Keras tuner
!pip install keras-tuner -U -qq
import keras_tuner as kt
%%time

def build_model(hp):
  """Creates a model."""

  model = tfdf.keras.GradientBoostedTreesModel(
      min_examples=hp.Choice("min_examples", [2, 5, 7, 10]),
      categorical_algorithm=hp.Choice("categorical_algorithm", ["CART", "RANDOM"]),
      max_depth=hp.Choice("max_depth", [4, 5, 6, 7]),
      # The Keras tuner automatically converts boolean parameters to integers.
      use_hessian_gain=bool(hp.Choice("use_hessian_gain", [True, False])),
      shrinkage=hp.Choice("shrinkage", [0.02, 0.05, 0.10, 0.15]),
      num_candidate_attributes_ratio=hp.Choice("num_candidate_attributes_ratio", [0.2, 0.5, 0.9, 1.0]),
  )

  # Optimize the model accuracy as computed on the validation dataset.
  model.compile(metrics=["accuracy"])
  return model

keras_tuner = kt.RandomSearch(
    build_model,
    objective="val_accuracy",
    max_trials=50,
    overwrite=True,
    directory="/tmp/keras_tuning")
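
# Optionally, print the search space the tuner will explore. This is standard
# Keras Tuner functionality and is purely informational.
keras_tuner.search_space_summary()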

# Important: The tuning should not be done on the test dataset.

# Extract a validation dataset from the training dataset. The new training
# dataset is called the "sub-training-dataset".

def split_dataset(dataset, test_ratio=0.30):
  """Splits a panda dataframe in two."""
  test_indices = np.random.rand(len(dataset)) < test_ratio
  return dataset[~test_indices], dataset[test_indices]

sub_train_df, sub_valid_df = split_dataset(train_df)
sub_train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(sub_train_df, label="income")
sub_valid_ds = tfdf.keras.pd_dataframe_to_tf_dataset(sub_valid_df, label="income")

# Tune the model
keras_tuner.search(sub_train_ds, validation_data=sub_valid_ds)
Trial 50 Complete [00h 00m 05s]
val_accuracy: 0.8727272748947144

Best val_accuracy So Far: 0.8779433965682983
Total elapsed time: 00h 04m 28s
INFO:tensorflow:Oracle triggered exit
INFO:tensorflow:Oracle triggered exit
CPU times: user 7min 52s, sys: 1min 22s, total: 9min 14s
Wall time: 4min 28s
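
Keras Tuner can also print a summary of the completed trials, which is useful as a quick sanity check (standard Keras Tuner API):

# Show the hyper-parameters and scores of the best trials.
keras_tuner.results_summary(num_trials=5)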

The best hyper-parameters are available with get_best_hyperparameters:

# Get the best hyper-parameters found by the tuner.
best_hyper_parameters = keras_tuner.get_best_hyperparameters()[0].values
print("Best hyper-parameters:", best_hyper_parameters)
Best hyper-parameters: {'min_examples': 7, 'categorical_algorithm': 'CART', 'max_depth': 7, 'use_hessian_gain': 0, 'shrinkage': 0.15, 'num_candidate_attributes_ratio': 0.9}

The model should then be re-trained with the best hyper-parameters. Re-training on the full training dataset is preferable because the trial models were only trained on the sub-training dataset:

%set_cell_height 300
# Re-train the model with the best hyper-parameters.
# The Keras tuner automatically converts boolean parameters to integers, so
# convert use_hessian_gain back to a bool.
best_hyper_parameters["use_hessian_gain"] = bool(best_hyper_parameters["use_hessian_gain"])
best_model = tfdf.keras.GradientBoostedTreesModel(**best_hyper_parameters)
best_model.fit(train_ds, verbose=2)
<IPython.core.display.Javascript object>
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpau6vzbt0 as temporary training directory
Reading training dataset...
Training tensor examples:
Features: {'age': <tf.Tensor 'data:0' shape=(None,) dtype=int64>, 'workclass': <tf.Tensor 'data_13:0' shape=(None,) dtype=string>, 'fnlwgt': <tf.Tensor 'data_5:0' shape=(None,) dtype=int64>, 'education': <tf.Tensor 'data_3:0' shape=(None,) dtype=string>, 'education_num': <tf.Tensor 'data_4:0' shape=(None,) dtype=int64>, 'marital_status': <tf.Tensor 'data_7:0' shape=(None,) dtype=string>, 'occupation': <tf.Tensor 'data_9:0' shape=(None,) dtype=string>, 'relationship': <tf.Tensor 'data_11:0' shape=(None,) dtype=string>, 'race': <tf.Tensor 'data_10:0' shape=(None,) dtype=string>, 'sex': <tf.Tensor 'data_12:0' shape=(None,) dtype=string>, 'capital_gain': <tf.Tensor 'data_1:0' shape=(None,) dtype=int64>, 'capital_loss': <tf.Tensor 'data_2:0' shape=(None,) dtype=int64>, 'hours_per_week': <tf.Tensor 'data_6:0' shape=(None,) dtype=int64>, 'native_country': <tf.Tensor 'data_8:0' shape=(None,) dtype=string>}
Label: Tensor("data_14:0", shape=(None,), dtype=int64)
Weights: None
Normalized tensor features:
 {'age': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast:0' shape=(None,) dtype=float32>), 'workclass': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_13:0' shape=(None,) dtype=string>), 'fnlwgt': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_1:0' shape=(None,) dtype=float32>), 'education': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_3:0' shape=(None,) dtype=string>), 'education_num': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_2:0' shape=(None,) dtype=float32>), 'marital_status': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_7:0' shape=(None,) dtype=string>), 'occupation': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_9:0' shape=(None,) dtype=string>), 'relationship': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_11:0' shape=(None,) dtype=string>), 'race': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_10:0' shape=(None,) dtype=string>), 'sex': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_12:0' shape=(None,) dtype=string>), 'capital_gain': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_3:0' shape=(None,) dtype=float32>), 'capital_loss': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_4:0' shape=(None,) dtype=float32>), 'hours_per_week': SemanticTensor(semantic=<Semantic.NUMERICAL: 1>, tensor=<tf.Tensor 'Cast_5:0' shape=(None,) dtype=float32>), 'native_country': SemanticTensor(semantic=<Semantic.CATEGORICAL: 2>, tensor=<tf.Tensor 'data_8:0' shape=(None,) dtype=string>)}
Training dataset read in 0:00:00.372057. Found 22792 examples.
Training model...
[INFO 2023-02-01T12:16:10.151802063+00:00 kernel.cc:756] Start Yggdrasil model training
[INFO 2023-02-01T12:16:10.15183364+00:00 kernel.cc:757] Collect training examples
[INFO 2023-02-01T12:16:10.151922389+00:00 kernel.cc:388] Number of batches: 23
[INFO 2023-02-01T12:16:10.151938239+00:00 kernel.cc:389] Number of examples: 22792
[INFO 2023-02-01T12:16:10.159359793+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column native_country (40 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:16:10.15940551+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column occupation (13 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:16:10.159448183+00:00 data_spec_inference.cc:303] 1 item(s) have been pruned (i.e. they are considered out of dictionary) for the column workclass (7 item(s) left) because min_value_count=5 and max_number_of_unique_values=2000
[INFO 2023-02-01T12:16:10.16406424+00:00 kernel.cc:774] Training dataset:
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "__LABEL" CATEGORICAL integerized vocab-size:3 no-ood-item
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    8: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    9: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)
    10: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:1 (0.00464425%) most-frequent:"Prof-specialty" 2870 (13.329%)
    11: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    12: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    13: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:1 (0.0046436%) most-frequent:"Private" 15879 (73.7358%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    2: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    3: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    6: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    7: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO 2023-02-01T12:16:10.164120332+00:00 kernel.cc:790] Configure learner
2023-02-01 12:16:10.164338: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1790] "goss_alpha" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:16:10.164367: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1800] "goss_beta" set but "sampling_method" not equal to "GOSS".
2023-02-01 12:16:10.164374: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1814] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
[INFO 2023-02-01T12:16:10.164413528+00:00 kernel.cc:804] Training config:
learner: "GRADIENT_BOOSTED_TREES"
features: "^age$"
features: "^capital_gain$"
features: "^capital_loss$"
features: "^education$"
features: "^education_num$"
features: "^fnlwgt$"
features: "^hours_per_week$"
features: "^marital_status$"
features: "^native_country$"
features: "^occupation$"
features: "^race$"
features: "^relationship$"
features: "^sex$"
features: "^workclass$"
label: "^__LABEL$"
task: CLASSIFICATION
random_seed: 123456
metadata {
  framework: "TF Keras"
}
pure_serving_model: false
[yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
  num_trees: 300
  decision_tree {
    max_depth: 7
    min_examples: 7
    in_split_min_examples_check: true
    keep_non_leaf_label_distribution: true
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    num_candidate_attributes_ratio: 0.9
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
    uplift {
      min_examples_in_treatment: 5
      split_score: KULLBACK_LEIBLER
    }
  }
  shrinkage: 0.15
  loss: DEFAULT
  validation_set_ratio: 0.1
  validation_interval_in_trees: 1
  early_stopping: VALIDATION_LOSS_INCREASE
  early_stopping_num_trees_look_ahead: 30
  l2_regularization: 0
  lambda_loss: 1
  mart {
  }
  adapt_subsample_for_maximum_training_duration: false
  l1_regularization: 0
  use_hessian_gain: false
  l2_regularization_categorical: 1
  stochastic_gradient_boosting {
    ratio: 1
  }
  apply_link_function: true
  compute_permutation_variable_importance: false
  binary_focal_loss_options {
    misprediction_exponent: 2
    positive_sample_coefficient: 0.5
  }
  early_stopping_initial_iteration: 10
}

[INFO 2023-02-01T12:16:10.164528213+00:00 kernel.cc:807] Deployment config:
cache_path: "/tmpfs/tmp/tmpau6vzbt0/working_cache"
num_threads: 32
try_resume_training: true

[INFO 2023-02-01T12:16:10.164745793+00:00 kernel.cc:868] Train model
[INFO 2023-02-01T12:16:12.125174473+00:00 early_stopping.cc:53] Early stop of the training because the validation loss does not decrease anymore. Best valid-loss: 0.580282
[INFO 2023-02-01T12:16:12.132791707+00:00 kernel.cc:905] Export model in log directory: /tmpfs/tmp/tmpau6vzbt0 with prefix d9f51fa1e24f4e5a
[INFO 2023-02-01T12:16:12.136270135+00:00 kernel.cc:923] Save model in resources
[INFO 2023-02-01T12:16:12.138808984+00:00 abstract_model.cc:849] Model self evaluation:
Task: CLASSIFICATION
Label: __LABEL
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.580282

Accuracy: 0.872067  CI95[W][0 1]
ErrorRate: : 0.127933


Confusion Table:
truth\prediction
   0     1    2
0  0     0    0
1  0  1572   92
2  0   197  398
Total: 2259

One vs other classes:

[INFO 2023-02-01T12:16:12.156739509+00:00 kernel.cc:1214] Loading model from path /tmpfs/tmp/tmpau6vzbt0/model/ with prefix d9f51fa1e24f4e5a
[INFO 2023-02-01T12:16:12.172480351+00:00 abstract_model.cc:1311] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO 2023-02-01T12:16:12.172513629+00:00 kernel.cc:1046] Use fast generic engine
Model trained in 0:00:02.027041
Compiling model...
Model compiled.
<keras.callbacks.History at 0x7fc9acac0130>

We can then evaluate the tuned model:

# Evaluate the model
best_model.compile(["accuracy"])
tuned_test_accuracy = best_model.evaluate(test_ds, return_dict=True, verbose=0)["accuracy"]
print(f"Test accuracy with the Keras Tuner: {tuned_test_accuracy:.4f}")
Test accuracy with the Keras Tuner: 0.8747
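
If you want to reuse the tuned model without re-training it, it can be exported like any other TF-DF/Keras model; a minimal sketch (the path is only an example):

# Export the re-trained model in the TensorFlow SavedModel format.
best_model.save("/tmp/tuned_gbt_model")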