Welcome to the Prediction Colab for TensorFlow Decision Forests (TF-DF). In this colab, you will learn about different ways to generate predictions with a previously trained TF-DF model using the Python API.
Remark: The Python API shown in this Colab is simple to use and well-suited for experimentation. However, other APIs, such as TensorFlow Serving and the C++ API, are better suited for production systems as they are faster and more stable. The exhaustive list of all Serving APIs is available here.
In this colab, you will:
- Use the model.predict() function on a TensorFlow Dataset created with pd_dataframe_to_tf_dataset.
- Use the model.predict() function on a TensorFlow Dataset created manually.
- Use the model.predict() function on NumPy arrays.
- Make predictions with the CLI API.
- Benchmark the inference speed of a model with the CLI API.
Important remark
The dataset used for predictions should have the same feature names and types as the dataset used for training. Failing to do so will likely raise errors.
For example, training a model with two features f1 and f2, and trying to generate predictions on a dataset without f2 will fail. Note that it is okay to set (some or all) feature values as "missing". Similarly, training a model where f2 is a numerical feature (e.g., float32), and applying this model on a dataset where f2 is a text (e.g., string) feature will fail.
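Here is a minimal sketch of a check to run before inference. It assumes the features are held in a dictionary of NumPy arrays, which is one of the input formats model.predict accepts. It also assumes all features are numerical, so a missing value can be encoded as NaN. `align_feature_names` is a hypothetical helper and is not part of TF-DF. It keeps a feature that is absent at inference time as an all-missing column instead of dropping it:

```python
import numpy as np


def align_feature_names(features, training_feature_names):
    """Returns a dict containing exactly the training feature names.

    Hypothetical helper: a feature seen during training but absent from
    `features` is added as an all-NaN float32 array, i.e. every value is
    marked as missing (assumes numerical features). Features not used in
    training are left out.
    """
    num_examples = len(next(iter(features.values())))
    aligned = {}
    for name in training_feature_names:
        if name in features:
            aligned[name] = features[name]
        else:
            # Keep the feature, but mark all of its values as missing.
            aligned[name] = np.full(num_examples, np.nan, dtype=np.float32)
    return aligned


serving = {"f1": np.array([0.1, 0.5, 0.9], dtype=np.float32)}  # "f2" is absent
aligned = align_feature_names(serving, ["f1", "f2"])
print(sorted(aligned))                 # ['f1', 'f2']
print(np.isnan(aligned["f2"]).all())   # True
```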
Although both are exposed through the Keras API, a model instantiated in Python (e.g., with tfdf.keras.RandomForestModel()) and a model loaded from disk (e.g., with tf_keras.models.load_model()) can behave differently. Notably, a Python-instantiated model automatically applies the necessary type conversions. For example, if a float64 feature is fed to a model expecting a float32 feature, the conversion is performed implicitly. However, such a conversion is not possible for models loaded from disk. It is therefore important that the training data and the inference data always have exactly the same types.
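A model loaded from disk does not convert types, so one option is to cast the inference features to the dtypes used at training time before calling predict. Below is a minimal sketch. It assumes the features are a dictionary of NumPy arrays and that you recorded the training dtypes yourself. The `training_dtypes` dictionary and the `cast_to_training_dtypes` helper are hypothetical:

```python
import numpy as np


def cast_to_training_dtypes(features, training_dtypes):
    """Casts each feature array to the dtype it had during training.

    `training_dtypes` maps feature names to NumPy dtypes (a KeyError is
    raised for unknown features). Arrays that already have the right dtype
    are returned as-is, without a copy.
    """
    return {
        name: np.asarray(values).astype(training_dtypes[name], copy=False)
        for name, values in features.items()
    }


training_dtypes = {"f1": np.float32, "f2": np.float32}
# np.random.rand returns float64 arrays.
serving = {"f1": np.random.rand(4), "f2": np.random.rand(4)}
casted = cast_to_training_dtypes(serving, training_dtypes)
print({name: str(v.dtype) for name, v in casted.items()})
# {'f1': 'float32', 'f2': 'float32'}
```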
Setup
First, we install TensorFlow Decision Forests:
pip install tensorflow_decision_forests
Then, we import the libraries used in this example.
import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
model.predict(...) and pd_dataframe_to_tf_dataset function
TensorFlow Decision Forests implements the Keras model API.
As such, TF-DF models have a predict function to make predictions. This function takes as input a TensorFlow Dataset and outputs a prediction array.
The simplest way to create a TensorFlow dataset is to use Pandas and the tfdf.keras.pd_dataframe_to_tf_dataset(...) function.
The next example shows how to create a TensorFlow dataset using pd_dataframe_to_tf_dataset.
pd_dataset = pd.DataFrame({
"feature_1": [1,2,3],
"feature_2": ["a", "b", "c"],
"label": [0, 1, 0],
})
pd_dataset
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_dataset, label="label")
for features, label in tf_dataset:
  print("Features:", features)
  print("label:", label)
Features: {'feature_1': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, 'feature_2': <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'a', b'b', b'c'], dtype=object)>}
label: tf.Tensor([0 1 0], shape=(3,), dtype=int64)
Note: "pd" stands for "pandas". "tf" stands for "TensorFlow".
A TensorFlow Dataset is an object that produces a sequence of values. Those values can be simple arrays (called Tensors) or arrays organized into a structure (for example, arrays organized in a dictionary).
The following example shows training and inference (using predict) on a toy dataset:
# Creating a training dataset in Pandas
pd_train_dataset = pd.DataFrame({
"feature_1": np.random.rand(1000),
"feature_2": np.random.rand(1000),
})
pd_train_dataset["label"] = pd_train_dataset["feature_1"] > pd_train_dataset["feature_2"]
pd_train_dataset
# Creating a serving dataset with Pandas
pd_serving_dataset = pd.DataFrame({
"feature_1": np.random.rand(500),
"feature_2": np.random.rand(500),
})
pd_serving_dataset
Let's convert the Pandas dataframes into TensorFlow datasets:
tf_train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label")
tf_serving_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_serving_dataset)
We can now train a model on tf_train_dataset:
model = tfdf.keras.RandomForestModel(verbose=0)
model.fit(tf_train_dataset)
<tf_keras.src.callbacks.History at 0x7f1f40344400>
And then generate predictions on tf_serving_dataset:
# Print the first 10 predictions.
model.predict(tf_serving_dataset, verbose=0)[:10]
array([[0. ],
[0.99666584],
[0. ],
[0.76666605],
[0.12333328],
[0.99999917],
[0.00333333],
[0.99999917],
[0. ],
[0.99999917]], dtype=float32)
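This binary classification model returns one column, and each value can be read as the predicted probability of the positive class (label True). The minimal sketch below turns these probabilities into boolean labels. The 0.5 threshold is an assumption and can be tuned. So that the snippet runs without a model, the values are copied from the first rows printed above:

```python
import numpy as np

# The first five predictions printed above, shape (num_examples, 1).
predictions = np.array(
    [[0.0], [0.99666584], [0.0], [0.76666605], [0.12333328]], dtype=np.float32
)

# Column 0 holds the probability of the positive class.
threshold = 0.5
predicted_labels = predictions[:, 0] >= threshold
print(predicted_labels.tolist())  # [False, True, False, True, False]
```

With the trained model, the same expression applies directly to `model.predict(tf_serving_dataset, verbose=0)`.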
model.predict(...) and manual TF datasets
In the previous section, we showed how to create a TF dataset using the pd_dataframe_to_tf_dataset function. This option is simple but poorly suited for large datasets, for which TensorFlow offers several other ways to create a dataset.
The next example shows how to create a dataset using the tf.data.Dataset.from_tensor_slices() function.
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5])
for value in dataset:
  print("value:", value.numpy())
value: 1
value: 2
value: 3
value: 4
value: 5
TensorFlow models are trained with mini-batching: instead of being fed one at a time, examples are grouped into "batches". For Neural Networks, the batch size impacts the quality of the model, and the optimal value needs to be determined by the user during training. For Decision Forests, the batch size has no impact on the model. However, for compatibility reasons, TensorFlow Decision Forests expects the dataset to be batched. Batching is done with the batch() function.
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).batch(2)
for value in dataset:
  print("value:", value.numpy())
value: [1 2]
value: [3 4]
value: [5]
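You can mimic the grouping done by batch() above in plain Python to see how many batches a dataset yields. Each batch is a slice of consecutive examples. When the number of examples is not a multiple of the batch size, the last batch is smaller, which matches the default behaviour of tf.data.Dataset.batch (drop_remainder=False). This only illustrates the semantics, not how tf.data implements it, and make_batches is a hypothetical helper:

```python
import math


def make_batches(examples, batch_size):
    """Splits `examples` into consecutive batches of at most `batch_size` items."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


print(make_batches([1, 2, 3, 4, 5], batch_size=2))  # [[1, 2], [3, 4], [5]]

# The number of batches is ceil(num_examples / batch_size); for instance,
# 100 examples batched by 2 give 50 batches.
num_batches = len(make_batches(list(range(100)), batch_size=2))
print(num_batches, math.ceil(100 / 2))  # 50 50
```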
TensorFlow Decision Forests expects the dataset to have one of two structures:
- features, label
- features, label, weights
The features can be a single 2-dimensional array (where each column is a feature and each row is an example), or a dictionary of arrays.
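The two feature layouts hold the same values. Stacking the per-feature arrays of a dictionary as columns gives the 2-dimensional layout, with one row per example. The small NumPy sketch below reuses the values of the dictionary example that follows. Note that the 2D array itself carries no feature names:

```python
import numpy as np

# Dictionary layout: one array per feature.
features_dict = {
    "feature_1": np.array([1, 2, 3]),
    "feature_2": np.array([4, 5, 6]),
}

# 2D layout: each column is a feature, each row is an example.
feature_names = list(features_dict)  # column order follows dictionary order
features_2d = np.stack([features_dict[name] for name in feature_names], axis=1)
print(features_2d.shape)  # (3, 2)
print(features_2d[0])     # [1 4]  (the first example)
```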
Following is an example of a dataset compatible with TensorFlow Decision Forests:
# A dataset with a single 2d array.
tf_dataset = tf.data.Dataset.from_tensor_slices(
([[1,2],[3,4],[5,6]], # Features
[0,1,0], # Label
)).batch(2)
for features, label in tf_dataset:
  print("features:", features)
  print("label:", label)
features: tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: tf.Tensor([[5 6]], shape=(1, 2), dtype=int32)
label: tf.Tensor([0], shape=(1,), dtype=int32)
# A dataset with a dictionary of features.
tf_dataset = tf.data.Dataset.from_tensor_slices(
({
"feature_1": [1,2,3],
"feature_2": [4,5,6],
},
[0,1,0], # Label
)).batch(2)
for features, label in tf_dataset:
  print("features:", features)
  print("label:", label)
features: {'feature_1': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 5], dtype=int32)>}
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: {'feature_1': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([6], dtype=int32)>}
label: tf.Tensor([0], shape=(1,), dtype=int32)
Let's train a model with this second option.
tf_dataset = tf.data.Dataset.from_tensor_slices(
({
"feature_1": np.random.rand(100),
"feature_2": np.random.rand(100),
},
np.random.rand(100) >= 0.5, # Label
)).batch(2)
model = tfdf.keras.RandomForestModel(verbose=0)
model.fit(tf_dataset)
<tf_keras.src.callbacks.History at 0x7f1f40064d30>
The predict function can be used directly on the training dataset:
# The first 10 predictions.
model.predict(tf_dataset, verbose=0)[:10]
array([[0.84333265],
[0.9733325 ],
[0.9699992 ],
[0.5733329 ],
[0.3233331 ],
[0.23666652],
[0.41666636],
[0.17666657],
[0.2666665 ],
[0.40666637]], dtype=float32)
model.predict(...) and model.predict_on_batch() on dictionaries
In some cases, the predict function can be used with an array (or a dictionary of arrays) instead of a TensorFlow Dataset.
The following example uses the previously trained model with a dictionary of NumPy arrays.
# The first 10 predictions.
model.predict({
"feature_1": np.random.rand(100),
"feature_2": np.random.rand(100),
}, verbose=0)[:10]
array([[0.08666665],
[0.4766663 ],
[0.2466665 ],
[0.47999963],
[0.25333318],
[0.6066662 ],
[0.3766664 ],
[0.08333332],
[0.21999986],
[0.07666666]], dtype=float32)
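You can also build this dictionary format from an existing Pandas DataFrame, for instance to call predict on pd_serving_dataset without going through pd_dataframe_to_tf_dataset. Below is a minimal sketch. dataframe_to_feature_dict is a hypothetical helper, and it drops the label column because the label is not a model input:

```python
import numpy as np
import pandas as pd


def dataframe_to_feature_dict(df, label=None):
    """Returns {column name: NumPy array} for every column except `label`."""
    return {name: df[name].to_numpy() for name in df.columns if name != label}


df = pd.DataFrame({
    "feature_1": np.random.rand(5),
    "feature_2": np.random.rand(5),
    "label": [True, False, True, True, False],
})
features = dataframe_to_feature_dict(df, label="label")
print(sorted(features))             # ['feature_1', 'feature_2']
print(features["feature_1"].shape)  # (5,)
```

With the model above, usage would look like `model.predict(dataframe_to_feature_dict(pd_serving_dataset), verbose=0)`.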
In the previous example, the arrays are automatically batched. Alternatively, the predict_on_batch function can be used to make sure that all the examples are run in the same batch.
# The first 10 predictions.
model.predict_on_batch({
"feature_1": np.random.rand(100),
"feature_2": np.random.rand(100),
})[:10]
array([[0.49666628],
[0.18333323],
[0.47999963],
[0.5199996 ],
[0.33666644],
[0.48999962],
[0.3333331 ],
[0.1899999 ],
[0.2566665 ],
[0.20999987]], dtype=float32)
Inference with the YDF format
This example shows how to run a TF-DF model, trained with the Python API, using the CLI API (one of the other Serving APIs). We will also use the Benchmark tool to measure the inference speed of the model.
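For a rough end-to-end number from Python, you can time predict calls with time.perf_counter. Unlike the Benchmark tool, this measurement includes Keras and Python overhead, so it will likely be higher. Below is a minimal sketch. average_latency_per_example is a hypothetical helper, and dummy_predict only stands in for a call such as model.predict_on_batch(features):

```python
import time


def average_latency_per_example(predict_fn, inputs, num_examples, num_runs=10):
    """Returns the mean wall-clock time per example, in microseconds.

    `predict_fn(inputs)` is called once as a warm-up (not measured), then
    `num_runs` more times while the clock is running.
    """
    predict_fn(inputs)  # Warm-up: the first call can be slower.
    start = time.perf_counter()
    for _ in range(num_runs):
        predict_fn(inputs)
    elapsed = time.perf_counter() - start
    return elapsed / (num_runs * num_examples) * 1e6


def dummy_predict(batch):
    # Stand-in for a model: scores each (feature_1, feature_2) pair with a rule.
    return [float(a > b) for a, b in batch]


batch = [(0.2, 0.7), (0.9, 0.1)] * 500
latency_us = average_latency_per_example(dummy_predict, batch, num_examples=len(batch))
print(f"{latency_us:.3f} microseconds per example")
```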
Let's start by training and saving a model:
model = tfdf.keras.GradientBoostedTreesModel(verbose=0)
model.fit(tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label"))
model.save("my_model")
I0000 00:00:1768227113.585144 153715 gradient_boosted_trees.cc:1669] Create final snapshot of the model at iteration 68
I0000 00:00:1768227113.588020 153715 gradient_boosted_trees.cc:279] Truncates the model to 39 tree(s) i.e. 39 iteration(s).
I0000 00:00:1768227113.588240 153715 gradient_boosted_trees.cc:341] Final model num-trees:39 valid-loss:0.154844 valid-accuracy:0.961039
I0000 00:00:1768227113.588837 153715 kernel.cc:926] Export model in log directory: /tmpfs/tmp/tmp4417l1rz with prefix f5d28a164e744a13
I0000 00:00:1768227113.590380 153715 kernel.cc:944] Save model in resources
I0000 00:00:1768227113.592303 152844 abstract_model.cc:921] Model self evaluation:
Task: CLASSIFICATION
Label: __LABEL
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.154844
Accuracy: 0.961039 CI95[W][0 1]
ErrorRate: : 0.0389611
Confusion Table:
truth\prediction
1 2
1 37 1
2 2 37
Total: 77
I0000 00:00:1768227113.603835 152844 quick_scorer_extended.cc:927] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference.
INFO:tensorflow:Assets written to: my_model/assets
Let's also export the dataset to a CSV file:
pd_serving_dataset.to_csv("dataset.csv")
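Because the dataset used for predictions must carry the same feature names as the training data, a quick check on the exported CSV header can catch a missing column before calling the YDF tools. The sketch below is a hypothetical helper using only the Python standard library; the feature names feature_1 and feature_2 are the ones reported in the training logs above.

```python
import csv

def missing_features(csv_path, required_features):
    """Return the required feature names that are absent from the CSV header."""
    with open(csv_path, newline="") as f:
        header = next(csv.reader(f))
    # DataFrame.to_csv() also writes the index as an unnamed first column.
    # This check only looks for missing columns, not for extra ones.
    present = set(header)
    return [name for name in required_features if name not in present]

# Hypothetical usage with the features reported in the training logs:
# assert missing_features("dataset.csv", ["feature_1", "feature_2"]) == []
```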
Let's download and extract the Yggdrasil Decision Forests CLI tools.
!wget https://github.com/google/yggdrasil-decision-forests/releases/download/1.0.0/cli_linux.zip
!unzip cli_linux.zip
--2026-01-12 14:11:54-- https://github.com/google/yggdrasil-decision-forests/releases/download/1.0.0/cli_linux.zip
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://release-assets.githubusercontent.com/github-production-release-asset/360444739/bfcd0b9d-5cbc-42a8-be0a-02131875f9a6?sp=r&sv=2018-11-09&sr=b&spr=https&se=2026-01-12T15%3A00%3A08Z&rscd=attachment%3B+filename%3Dcli_linux.zip&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2026-01-12T13%3A59%3A29Z&ske=2026-01-12T15%3A00%3A08Z&sks=b&skv=2018-11-09&sig=PFFh4ZoBrSUveSKyVVWStvEDAP30UyjdJrPxDW90ek4%3D&jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmVsZWFzZS1hc3NldHMuZ2l0aHVidXNlcmNvbnRlbnQuY29tIiwia2V5Ijoia2V5MSIsImV4cCI6MTc2ODIyODkxNCwibmJmIjoxNzY4MjI3MTE0LCJwYXRoIjoicmVsZWFzZWFzc2V0cHJvZHVjdGlvbi5ibG9iLmNvcmUud2luZG93cy5uZXQifQ.tkjqR6oxVIRrYVJ1Eq6d_OOG0Dh9o0ofYJ3Suy7BDq8&response-content-disposition=attachment%3B%20filename%3Dcli_linux.zip&response-content-type=application%2Foctet-stream [following]
--2026-01-12 14:11:54-- https://release-assets.githubusercontent.com/github-production-release-asset/360444739/bfcd0b9d-5cbc-42a8-be0a-02131875f9a6?sp=r&sv=2018-11-09&sr=b&spr=https&se=2026-01-12T15%3A00%3A08Z&rscd=attachment%3B+filename%3Dcli_linux.zip&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2026-01-12T13%3A59%3A29Z&ske=2026-01-12T15%3A00%3A08Z&sks=b&skv=2018-11-09&sig=PFFh4ZoBrSUveSKyVVWStvEDAP30UyjdJrPxDW90ek4%3D&jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmVsZWFzZS1hc3NldHMuZ2l0aHVidXNlcmNvbnRlbnQuY29tIiwia2V5Ijoia2V5MSIsImV4cCI6MTc2ODIyODkxNCwibmJmIjoxNzY4MjI3MTE0LCJwYXRoIjoicmVsZWFzZWFzc2V0cHJvZHVjdGlvbi5ibG9iLmNvcmUud2luZG93cy5uZXQifQ.tkjqR6oxVIRrYVJ1Eq6d_OOG0Dh9o0ofYJ3Suy7BDq8&response-content-disposition=attachment%3B%20filename%3Dcli_linux.zip&response-content-type=application%2Foctet-stream
Resolving release-assets.githubusercontent.com (release-assets.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to release-assets.githubusercontent.com (release-assets.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 31516027 (30M) [application/octet-stream]
Saving to: ‘cli_linux.zip’
cli_linux.zip 100%[===================>] 30.06M 83.2MB/s in 0.4s
2026-01-12 14:11:54 (83.2 MB/s) - ‘cli_linux.zip’ saved [31516027/31516027]
Archive: cli_linux.zip
inflating: README
inflating: cli.txt
inflating: train
inflating: show_model
inflating: show_dataspec
inflating: predict
inflating: infer_dataspec
inflating: evaluate
inflating: convert_dataset
inflating: benchmark_inference
inflating: edit_model
inflating: synthetic_dataset
inflating: grpc_worker_main
inflating: LICENSE
inflating: CHANGELOG.md
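If wget or unzip are not available, the download and extraction can be done from Python instead. This is a minimal standard-library sketch; the helper names are hypothetical, and it assumes the 1.0.0 release URL used above is still served. One detail matters here: unlike the unzip command, Python's zipfile module does not restore Unix permission bits, so the extracted binaries (predict, benchmark_inference, ...) must be marked executable explicitly.

```python
import os
import stat
import urllib.request
import zipfile

# Release archive used in this colab (assumed to still be available).
CLI_URL = ("https://github.com/google/yggdrasil-decision-forests/"
           "releases/download/1.0.0/cli_linux.zip")

def extract_cli_tools(zip_path, dest_dir="."):
    """Extract the archive and mark every extracted file as executable.

    zipfile does not restore permission bits, so binaries such as
    `predict` would otherwise not be runnable.
    """
    with zipfile.ZipFile(zip_path) as archive:
        names = archive.namelist()
        archive.extractall(dest_dir)
    for name in names:
        path = os.path.join(dest_dir, name)
        if os.path.isfile(path):
            mode = os.stat(path).st_mode
            os.chmod(path, mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
    return names

def download_cli_tools(dest_dir="."):
    """Download the CLI archive (following the GitHub redirect) and extract it."""
    zip_path = os.path.join(dest_dir, "cli_linux.zip")
    urllib.request.urlretrieve(CLI_URL, zip_path)
    return extract_cli_tools(zip_path, dest_dir)
```

For simplicity this marks every extracted file, including README and LICENSE, as executable, which is harmless.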
Finally, let's make predictions:
Remarks:
- TensorFlow Decision Forests (TF-DF) is based on the Yggdrasil Decision Forests (YDF) library, and a TF-DF model always contains a YDF model internally. When a TF-DF model is saved to disk, the TF-DF model directory contains an assets sub-directory that holds the YDF model. This YDF model can be used with all YDF tools. In the next examples, we use the predict and benchmark_inference tools. See the model format documentation for more details.
- YDF tools assume that the type of the dataset is specified using a prefix, e.g. csv:. See the YDF user manual for more details.
!./predict --model=my_model/assets --dataset=csv:dataset.csv --output=csv:predictions.csv
[INFO abstract_model.cc:1296] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO predict.cc:133] Run predictions with semi-fast engine
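Outside a notebook, the same predict call can be issued from a Python script. The sketch below only reuses the flags shown in the command above; build_predict_command and run_predict are hypothetical helper names, and it assumes the predict binary sits in the current directory.

```python
import os
import subprocess

def build_predict_command(model_dir, dataset_csv, output_csv, binary="./predict"):
    """Assemble the arguments for the YDF `predict` tool."""
    return [
        binary,
        # The YDF model is stored in the "assets" sub-directory of the TF-DF model.
        f"--model={os.path.join(model_dir, 'assets')}",
        # YDF tools expect the dataset type as a prefix, e.g. "csv:".
        f"--dataset=csv:{dataset_csv}",
        f"--output=csv:{output_csv}",
    ]

def run_predict(model_dir, dataset_csv, output_csv):
    cmd = build_predict_command(model_dir, dataset_csv, output_csv)
    # check=True raises CalledProcessError if the tool exits with an error.
    subprocess.run(cmd, check=True)
```

Passing an argument list (rather than a single shell string) avoids quoting issues with paths that contain spaces.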
We can now look at the predictions:
pd.read_csv("predictions.csv")
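To turn per-class scores into a single predicted class, one can take the column with the highest value in each row. This sketch assumes that every column of predictions.csv holds the predicted probability of one class and is named after that class; that layout is an assumption about the predict output and is not verified here. predicted_classes is a hypothetical helper.

```python
import csv

def predicted_classes(csv_path):
    """Return, for each row, the name of the column with the highest value.

    Assumes one probability column per class, named after the class.
    """
    classes = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            # Pick the column (class name) with the largest probability.
            best = max(row, key=lambda name: float(row[name]))
            classes.append(best)
    return classes
```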
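The inference speed of a model can be measured with the benchmark_inference tool. In the following code, a placeholder __LABEL column is first added to the dataset, since the tool expects the label column to be present.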
# Create the empty label column.
pd_serving_dataset["__LABEL"] = 0
pd_serving_dataset.to_csv("dataset.csv")
!./benchmark_inference \
--model=my_model/assets \
--dataset=csv:dataset.csv \
--batch_size=100 \
--warmup_runs=10 \
--num_runs=50
[INFO benchmark_inference.cc:245] Loading model
[INFO benchmark_inference.cc:248] The model is of type: GRADIENT_BOOSTED_TREES
[INFO benchmark_inference.cc:250] Loading dataset
[INFO benchmark_inference.cc:259] Found 3 compatible fast engines.
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesGeneric
[INFO decision_forest.cc:639] Model loaded with 39 root(s), 2179 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesQuickScorerExtended
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesOptPred
[INFO decision_forest.cc:639] Model loaded with 39 root(s), 2179 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:268] Running the slow generic engine
batch_size : 100 num_runs : 50
time/example(us) time/batch(us) method
----------------------------------------
0.34825 34.825 GradientBoostedTreesQuickScorerExtended [virtual interface]
0.5015 50.15 GradientBoostedTreesOptPred [virtual interface]
1.3113 131.12 GradientBoostedTreesGeneric [virtual interface]
3.2043 320.43 Generic slow engine
----------------------------------------
In this benchmark, we see the inference speed of the different inference engines. For example, "time/example(us) = 0.34825" for GradientBoostedTreesQuickScorerExtended (this value can change between runs) indicates that the inference of one example takes about 0.35 microseconds. That is, the model can be run ~2.9 million times per second.
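The conversion from time/example to throughput, and the general warm-up-then-measure pattern behind the --warmup_runs and --num_runs flags, can be sketched in plain Python. This illustrates the general benchmarking idea under our own assumptions, not how benchmark_inference is implemented; predict_batch is a hypothetical stand-in for any function that scores a batch of examples.

```python
import time

def examples_per_second(time_per_example_us):
    """Convert a per-example latency in microseconds into a throughput."""
    return 1e6 / time_per_example_us

def time_per_example_us(predict_batch, batch, warmup_runs=10, num_runs=50):
    """Average per-example latency in microseconds over `num_runs` timed runs."""
    # Untimed warm-up runs, typically used so that caches and lazy
    # initialization do not distort the measurement.
    for _ in range(warmup_runs):
        predict_batch(batch)
    start = time.perf_counter()
    for _ in range(num_runs):
        predict_batch(batch)
    elapsed = time.perf_counter() - start
    return elapsed / (num_runs * len(batch)) * 1e6
```

For example, examples_per_second(0.34825), using the fastest engine's time/example from the table above, is about 2.87 million examples per second.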