Making predictions


Welcome to the Prediction Colab for TensorFlow Decision Forests (TF-DF). In this colab, you will learn about different ways to generate predictions with a previously trained TF-DF model using the Python API.

Remark: The Python API shown in this Colab is simple to use and well-suited for experimentation. However, other APIs, such as TensorFlow Serving and the C++ API, are better suited for production systems as they are faster and more stable. The exhaustive list of all Serving APIs is available here.

In this colab, you will:

  1. Use the model.predict() function on a TensorFlow Dataset created with pd_dataframe_to_tf_dataset.
  2. Use the model.predict() function on a TensorFlow Dataset created manually.
  3. Use the model.predict() function on NumPy arrays.
  4. Make predictions with the CLI API.
  5. Benchmark the inference speed of a model with the CLI API.

Important remark

The dataset used for predictions should have the same feature names and types as the dataset used for training. Failing to do so will likely raise errors.

For example, training a model with two features f1 and f2, and trying to generate predictions on a dataset without f2 will fail. Note that it is okay to set (some or all) feature values as "missing". Similarly, training a model where f2 is a numerical feature (e.g., float32), and applying this model on a dataset where f2 is a text (e.g., string) feature will fail.
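Before calling predict, it can help to compare the serving data with the training data. The pandas sketch below (check_serving_schema is a hypothetical helper, not part of TF-DF) reports features that are missing or whose dtype differs, while accepting a feature whose values are all missing (NaN), in line with the remark above.

```python
import pandas as pd

def check_serving_schema(train_df, serving_df, label):
    """Hypothetical helper: lists training features that are missing or have another dtype."""
    problems = []
    for name, dtype in train_df.drop(columns=[label]).dtypes.items():
        if name not in serving_df.columns:
            problems.append(f"missing feature: {name}")
        elif serving_df[name].dtype != dtype:
            problems.append(f"dtype mismatch for {name}: {serving_df[name].dtype} != {dtype}")
    return problems

train = pd.DataFrame({"f1": [1.0, 2.0], "f2": [0.5, 0.7], "label": [0, 1]})

# f2 is absent: predictions would fail.
print(check_serving_schema(train, pd.DataFrame({"f1": [3.0]}), label="label"))
# f2 is present but its value is missing (NaN): same name, same float dtype, so no problem.
print(check_serving_schema(train, pd.DataFrame({"f1": [3.0], "f2": [float("nan")]}), label="label"))
# f2 holds text instead of a number: dtype mismatch.
print(check_serving_schema(train, pd.DataFrame({"f1": [3.0], "f2": ["x"]}), label="label"))
```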

Although the Keras API hides this difference, a model instantiated in Python (e.g., with tfdf.keras.RandomForestModel()) and a model loaded from disk (e.g., with tf_keras.models.load_model()) can behave differently. Notably, a model instantiated in Python automatically applies the necessary type conversions: if a float64 feature is fed to a model expecting a float32 feature, the conversion is performed implicitly. Models loaded from disk do not perform such conversions. It is therefore important that the training data and the inference data always have exactly the same types.
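Since a model loaded from disk does not convert types for you, one option is to cast the inference data to the dtypes seen during training before calling predict. The pandas sketch below assumes the training dtypes are still available (here, taken from the training DataFrame); align_dtypes is a hypothetical helper, not a TF-DF function.

```python
import numpy as np
import pandas as pd

def align_dtypes(serving_df, train_dtypes):
    """Hypothetical helper: casts the columns shared with training to the training dtypes."""
    shared = {name: dtype for name, dtype in train_dtypes.items() if name in serving_df.columns}
    return serving_df.astype(shared)

# Training data stored as float32.
train_df = pd.DataFrame({"feature_1": np.array([0.1, 0.2], dtype=np.float32)})
# Serving data built from Python floats, which pandas stores as float64.
serving_df = pd.DataFrame({"feature_1": [0.3, 0.4]})

aligned = align_dtypes(serving_df, train_df.dtypes.to_dict())
print(serving_df["feature_1"].dtype, "->", aligned["feature_1"].dtype)
```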

Setup

First, we install TensorFlow Decision Forests...

pip install tensorflow_decision_forests

...and import the libraries used in this example.

import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math

model.predict(...) and pd_dataframe_to_tf_dataset function

TensorFlow Decision Forests implements the Keras model API. As such, TF-DF models have a predict function to make predictions. This function takes a TensorFlow Dataset as input and outputs a prediction array. The simplest way to create a TensorFlow dataset is to use Pandas and the tfdf.keras.pd_dataframe_to_tf_dataset(...) function.

The next example shows how to create a TensorFlow dataset using pd_dataframe_to_tf_dataset.

pd_dataset = pd.DataFrame({
    "feature_1": [1,2,3],
    "feature_2": ["a", "b", "c"],
    "label": [0, 1, 0],
})

pd_dataset
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_dataset, label="label")

for features, label in tf_dataset:
  print("Features:",features)
  print("label:", label)
Features: {'feature_1': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, 'feature_2': <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'a', b'b', b'c'], dtype=object)>}
label: tf.Tensor([0 1 0], shape=(3,), dtype=int64)

Note: "pd" stands for "pandas". "tf" stands for "TensorFlow".

A TensorFlow Dataset is an object that produces a sequence of values. Those values can be simple arrays (called Tensors) or arrays organized into a structure (for example, arrays organized in a dictionary).
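To make the (features, label) structure concrete, here is a pure-Python analogy using plain lists instead of tensors: a generator that yields a dictionary of feature columns plus the label, one batch at a time. toy_dataset is a hypothetical name, and this only mimics the shape of what iterating over tf_dataset returned above; it is not how tf.data is implemented.

```python
def toy_dataset(columns, label, batch_size):
    """Pure-Python analogy (no TensorFlow): yields (features, label) one batch at a time."""
    num_examples = len(columns[label])
    for start in range(0, num_examples, batch_size):
        end = start + batch_size
        # Every column except the label becomes an entry of the features dictionary.
        features = {name: values[start:end] for name, values in columns.items() if name != label}
        yield features, columns[label][start:end]

columns = {"feature_1": [1, 2, 3], "feature_2": ["a", "b", "c"], "label": [0, 1, 0]}
for features, label in toy_dataset(columns, label="label", batch_size=3):
    print("Features:", features)
    print("label:", label)
```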

The following example shows the training and inference (using predict) on a toy dataset:

# Creating a training dataset in Pandas
pd_train_dataset = pd.DataFrame({
    "feature_1": np.random.rand(1000),
    "feature_2": np.random.rand(1000),
})
pd_train_dataset["label"] = pd_train_dataset["feature_1"] > pd_train_dataset["feature_2"] 

pd_train_dataset
# Creating a serving dataset with Pandas
pd_serving_dataset = pd.DataFrame({
    "feature_1": np.random.rand(500),
    "feature_2": np.random.rand(500),
})

pd_serving_dataset

Let's convert the Pandas dataframes into TensorFlow datasets:

tf_train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label")
tf_serving_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_serving_dataset)

We can now train a model on tf_train_dataset:

model = tfdf.keras.RandomForestModel(verbose=0)
model.fit(tf_train_dataset)
[INFO 24-01-26 12:24:11.0410 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpfc2o173s/model/ with prefix 9596ffa4587343b0
[INFO 24-01-26 12:24:11.0737 UTC decision_forest.cc:660] Model loaded with 300 root(s), 11030 node(s), and 2 input feature(s).
[INFO 24-01-26 12:24:11.0738 UTC abstract_model.cc:1344] Engine "RandomForestOptPred" built
[INFO 24-01-26 12:24:11.0738 UTC kernel.cc:1061] Use fast generic engine
<keras.src.callbacks.History at 0x7f2ae8180820>

And then generate predictions on tf_serving_dataset:

# Print the first 10 predictions.
model.predict(tf_serving_dataset, verbose=0)[:10]
array([[0.        ],
       [0.99999917],
       [0.6833328 ],
       [0.99999917],
       [0.00666667],
       [0.6066662 ],
       [0.99999917],
       [0.99999917],
       [0.796666  ],
       [0.99999917]], dtype=float32)
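For this binary classification task, predict returns an array with one row per example and a single column. Assuming each value is the predicted probability of the positive class (label True, i.e. feature_1 > feature_2), a minimal NumPy sketch turns the first five predictions printed above into class labels. The 0.5 threshold is also an assumption; other thresholds trade precision against recall.

```python
import numpy as np

# The first five predictions printed above: one row per example, one column.
probabilities = np.array(
    [[0.0], [0.99999917], [0.6833328], [0.99999917], [0.00666667]], dtype=np.float32
)

# Assumption: the column is the probability of the positive class; threshold at 0.5.
predicted_labels = probabilities[:, 0] >= 0.5
print(predicted_labels)
```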

model.predict(...) and manual TF datasets

In the previous section, we showed how to create a TF dataset using the pd_dataframe_to_tf_dataset function. This option is simple but poorly suited for large datasets. TensorFlow offers several other ways to create a TensorFlow dataset. The next example shows how to create a dataset using the tf.data.Dataset.from_tensor_slices() function.

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5])

for value in dataset:
  print("value:", value.numpy())
value: 1
value: 2
value: 3
value: 4
value: 5

TensorFlow models are trained with mini-batching: instead of being fed one at a time, examples are grouped into "batches". For neural networks, the batch size impacts the quality of the model, and the optimal value needs to be determined by the user during training. For decision forests, the batch size has no impact on the model. However, for compatibility reasons, TensorFlow Decision Forests expects the dataset to be batched. Batching is done with the batch() function.

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).batch(2)

for value in dataset:
  print("value:", value.numpy())
value: [1 2]
value: [3 4]
value: [5]
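The pure-Python sketch below mimics .batch() on a list (the last batch may be smaller, as in the output above) and scores each batch with a hypothetical one-node decision tree. Because a tree scores every example independently, the concatenated predictions are the same whatever the batch size. This is an illustration only, not TF-DF code.

```python
def batch(values, batch_size):
    """Pure-Python stand-in for .batch(): consecutive groups, the last one possibly smaller."""
    return [values[i:i + batch_size] for i in range(0, len(values), batch_size)]

def stump(x):
    """Hypothetical one-node decision tree: the output depends only on this example."""
    return 1.0 if x > 2.5 else 0.0

values = [1, 2, 3, 4, 5]
print(batch(values, 2))

# Score batch by batch, then concatenate: the result does not depend on the batch size.
for batch_size in (1, 2, 5):
    predictions = [stump(x) for b in batch(values, batch_size) for x in b]
    print(batch_size, predictions)
```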

TensorFlow Decision Forests expects the dataset to have one of two structures:

  • features, label
  • features, label, weights

The features can be either a single two-dimensional array (where each column is a feature and each row is an example) or a dictionary of arrays.
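Both layouts carry the same information: each column of the 2D array corresponds to one entry of the dictionary. The NumPy sketch below converts one into the other. The feature names are chosen for illustration only and say nothing about how TF-DF names the columns of a 2D input.

```python
import numpy as np

# A 2D feature array: one row per example, one column per feature.
features_2d = np.array([[1, 2], [3, 4], [5, 6]])
# Illustrative feature names (an assumption), one per column.
names = ["feature_1", "feature_2"]

# 2D array -> dictionary of 1D arrays, one per column.
features_dict = {name: features_2d[:, i] for i, name in enumerate(names)}
# Dictionary -> 2D array, stacking the columns in the order given by `names`.
restored = np.stack([features_dict[name] for name in names], axis=1)
print(features_dict)
```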

The following examples show datasets compatible with TensorFlow Decision Forests:

# A dataset with a single 2d array.
tf_dataset = tf.data.Dataset.from_tensor_slices(
    ([[1,2],[3,4],[5,6]], # Features
    [0,1,0], # Label
    )).batch(2)

for features, label in tf_dataset:
  print("features:", features)
  print("label:", label)
features: tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: tf.Tensor([[5 6]], shape=(1, 2), dtype=int32)
label: tf.Tensor([0], shape=(1,), dtype=int32)
# A dataset with a dictionary of features.
tf_dataset = tf.data.Dataset.from_tensor_slices(
    ({
    "feature_1": [1,2,3],
    "feature_2": [4,5,6],
    },
    [0,1,0], # Label
    )).batch(2)

for features, label in tf_dataset:
  print("features:", features)
  print("label:", label)
features: {'feature_1': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 5], dtype=int32)>}
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: {'feature_1': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([6], dtype=int32)>}
label: tf.Tensor([0], shape=(1,), dtype=int32)

Let's train a model with this second option.

tf_dataset = tf.data.Dataset.from_tensor_slices(
    ({
    "feature_1": np.random.rand(100),
    "feature_2": np.random.rand(100),
    },
    np.random.rand(100) >= 0.5, # Label
    )).batch(2)

model = tfdf.keras.RandomForestModel(verbose=0)
model.fit(tf_dataset)
[INFO 24-01-26 12:24:12.1921 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpffb588tv/model/ with prefix 8badff20d9284054
[INFO 24-01-26 12:24:12.2165 UTC decision_forest.cc:660] Model loaded with 300 root(s), 8312 node(s), and 2 input feature(s).
[INFO 24-01-26 12:24:12.2165 UTC kernel.cc:1061] Use fast generic engine
<keras.src.callbacks.History at 0x7f2a9473abe0>

The predict function can be used directly on the training dataset:

# The first 10 predictions.
model.predict(tf_dataset, verbose=0)[:10]
array([[0.15999992],
       [0.5966662 ],
       [0.53333294],
       [0.15666659],
       [0.15666659],
       [0.7833327 ],
       [0.54333293],
       [0.54666626],
       [0.6066662 ],
       [0.20999987]], dtype=float32)

model.predict(...) and model.predict_on_batch() on dictionaries

In some cases, the predict function can be used with an array (or a dictionary of arrays) instead of a TensorFlow Dataset.

The following example uses the previously trained model with a dictionary of NumPy arrays.

# The first 10 predictions.
model.predict({
    "feature_1": np.random.rand(100),
    "feature_2": np.random.rand(100),
    }, verbose=0)[:10]
array([[0.49666628],
       [0.05      ],
       [0.53333294],
       [0.5533329 ],
       [0.09999997],
       [0.6266662 ],
       [0.64666617],
       [0.60999954],
       [0.68999946],
       [0.2566665 ]], dtype=float32)

In the previous example, the arrays are automatically batched. Alternatively, the predict_on_batch function can be used to make sure that all the examples are run in the same batch.

# The first 10 predictions.
model.predict_on_batch({
    "feature_1": np.random.rand(100),
    "feature_2": np.random.rand(100),
    })[:10]
array([[0.16333325],
       [0.1899999 ],
       [0.6966661 ],
       [0.74333274],
       [0.31999978],
       [0.61333287],
       [0.34666643],
       [0.6399995 ],
       [0.39666638],
       [0.6733328 ]], dtype=float32)
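The two calls above differ only in how the examples are grouped. With arrays as input, Keras predict splits the examples into batches of batch_size examples (32 by default), whereas predict_on_batch treats them as a single batch. The stdlib sketch below (batch_sizes is a hypothetical helper) shows the grouping for the 100 examples used above. Since a decision forest scores each example independently, the grouping affects how the work is split, not the predictions.

```python
import math

def batch_sizes(num_examples, batch_size=32):
    """Hypothetical helper: sizes of the batches formed from num_examples examples.

    The default of 32 mirrors the Keras default batch_size of predict().
    The last batch may be partial.
    """
    num_batches = math.ceil(num_examples / batch_size)
    return [min(batch_size, num_examples - i * batch_size) for i in range(num_batches)]

print(batch_sizes(100))                  # grouping used by predict(...) with the default batch size
print(batch_sizes(100, batch_size=100))  # one batch, as with predict_on_batch(...)
```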

Inference with the YDF format

This example shows how to run inference on a TF-DF model with the CLI API (one of the other Serving APIs). We will also use the benchmark tool to measure the inference speed of the model.

Let's start by training and saving a model:

model = tfdf.keras.GradientBoostedTreesModel(verbose=0)
model.fit(tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label"))
model.save("my_model")
[WARNING 24-01-26 12:24:12.6505 UTC gradient_boosted_trees.cc:1886] "goss_alpha" set but "sampling_method" not equal to "GOSS".
[WARNING 24-01-26 12:24:12.6506 UTC gradient_boosted_trees.cc:1897] "goss_beta" set but "sampling_method" not equal to "GOSS".
[WARNING 24-01-26 12:24:12.6506 UTC gradient_boosted_trees.cc:1911] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB".
[INFO 24-01-26 12:24:13.0598 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpo81qqjke/model/ with prefix 22c6b8cf911b423a
[INFO 24-01-26 12:24:13.0659 UTC quick_scorer_extended.cc:903] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference.
[INFO 24-01-26 12:24:13.0663 UTC kernel.cc:1061] Use fast generic engine
INFO:tensorflow:Assets written to: my_model/assets

Let's also export the serving dataset to a CSV file:

pd_serving_dataset.to_csv("dataset.csv")

Let's download and extract the Yggdrasil Decision Forests CLI tools.

wget https://github.com/google/yggdrasil-decision-forests/releases/download/1.0.0/cli_linux.zip
unzip cli_linux.zip
Archive:  cli_linux.zip
  inflating: README                  
  inflating: cli.txt                 
  inflating: train                   
  inflating: show_model              
  inflating: show_dataspec           
  inflating: predict                 
  inflating: infer_dataspec          
  inflating: evaluate                
  inflating: convert_dataset         
  inflating: benchmark_inference     
  inflating: edit_model              
  inflating: synthetic_dataset       
  inflating: grpc_worker_main        
  inflating: LICENSE                 
  inflating: CHANGELOG.md

Finally, let's make predictions:

Remarks:

  • TensorFlow Decision Forests (TF-DF) is based on the Yggdrasil Decision Forests (YDF) library, and a TF-DF model always contains a YDF model internally. When a TF-DF model is saved to disk, the TF-DF model directory contains an assets sub-directory holding the YDF model. This YDF model can be used with all YDF tools. In the next example, we will use the predict and benchmark_inference tools. See the model format documentation for more details.
  • YDF tools expect the format of a dataset to be specified with a prefix, e.g. csv:. See the YDF user manual for more details.
./predict --model=my_model/assets --dataset=csv:dataset.csv --output=csv:predictions.csv
[INFO abstract_model.cc:1296] Engine "GradientBoostedTreesQuickScorerExtended" built
[INFO predict.cc:133] Run predictions with semi-fast engine
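In a Python script rather than a notebook, the same predict call can be launched with the standard subprocess module. The sketch below only rebuilds the command shown above (same binary, same flags); the helper names are hypothetical, and actually running it requires the extracted ./predict binary and the saved my_model directory.

```python
import subprocess

def ydf_predict_command(model_dir, dataset_csv, output_csv, binary="./predict"):
    """Hypothetical helper: the predict invocation above as an argument list.

    The YDF model lives in the "assets" sub-directory of the saved TF-DF model,
    and YDF tools expect a format prefix such as "csv:" on dataset paths.
    """
    return [
        binary,
        f"--model={model_dir}/assets",
        f"--dataset=csv:{dataset_csv}",
        f"--output=csv:{output_csv}",
    ]

def run_ydf_tool(cmd):
    """Runs a YDF CLI tool; raises subprocess.CalledProcessError on a non-zero exit code."""
    subprocess.run(cmd, check=True)

cmd = ydf_predict_command("my_model", "dataset.csv", "predictions.csv")
print(" ".join(cmd))
```

Calling run_ydf_tool(cmd) would then write predictions.csv, like the shell command above. Passing an argument list (rather than a single shell string) avoids shell quoting issues with paths.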

We can now look at the predictions:

pd.read_csv("predictions.csv")

The inference speed of a model can be measured with the benchmark_inference tool.

# Add a placeholder label column ("__LABEL" is the label name of the YDF model inside the TF-DF model).
pd_serving_dataset["__LABEL"] = 0
pd_serving_dataset.to_csv("dataset.csv")
!./benchmark_inference \
  --model=my_model/assets \
  --dataset=csv:dataset.csv \
  --batch_size=100 \
  --warmup_runs=10 \
  --num_runs=50
[INFO benchmark_inference.cc:245] Loading model
[INFO benchmark_inference.cc:248] The model is of type: GRADIENT_BOOSTED_TREES
[INFO benchmark_inference.cc:250] Loading dataset
[INFO benchmark_inference.cc:259] Found 3 compatible fast engines.
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesGeneric
[INFO decision_forest.cc:639] Model loaded with 43 root(s), 2165 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesQuickScorerExtended
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesOptPred
[INFO decision_forest.cc:639] Model loaded with 43 root(s), 2165 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:268] Running the slow generic engine
batch_size : 100  num_runs : 50
time/example(us)  time/batch(us)  method
----------------------------------------
         0.35825          35.825  GradientBoostedTreesQuickScorerExtended [virtual interface]
          0.5065           50.65  GradientBoostedTreesOptPred [virtual interface]
          1.4692          146.93  GradientBoostedTreesGeneric [virtual interface]
          3.4152          341.52  Generic slow engine
----------------------------------------

In this benchmark, we see the inference speed of different inference engines. For example, "time/example(us) = 0.35825" for GradientBoostedTreesQuickScorerExtended (the value can change between runs) indicates that the inference of one example takes about 0.36 microseconds. That is, the model can be run roughly 2.8 million times per second.
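The time/batch column is time/example multiplied by the batch size (100 here), and the throughput is the inverse of time/example. The short Python check below recomputes both from the values copied out of the table above; it adds no new measurements.

```python
import math

BATCH_SIZE = 100  # --batch_size used in the benchmark above

# engine -> (time/example in microseconds, time/batch in microseconds), copied from the table.
timings = {
    "GradientBoostedTreesQuickScorerExtended": (0.35825, 35.825),
    "GradientBoostedTreesOptPred": (0.5065, 50.65),
    "GradientBoostedTreesGeneric": (1.4692, 146.93),
    "Generic slow engine": (3.4152, 341.52),
}

throughput = {}
for engine, (us_per_example, us_per_batch) in timings.items():
    # time/batch = time/example * batch size, up to the rounding of the printed values.
    assert math.isclose(us_per_example * BATCH_SIZE, us_per_batch, rel_tol=1e-3)
    # 1 second = 1e6 microseconds, so examples per second = 1e6 / time per example.
    throughput[engine] = 1e6 / us_per_example
    print(f"{engine}: ~{throughput[engine] / 1e6:.2f} million examples per second")
```

With these numbers, the fastest engine is roughly 9.5 times faster than the generic slow engine.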