Welcome to the Prediction Colab for TensorFlow Decision Forests (TF-DF). In this colab, you will learn about different ways to generate predictions with a previously trained TF-DF model using the Python API.
Remark: The Python API shown in this Colab is simple to use and well-suited for experimentation. However, other APIs, such as TensorFlow Serving and the C++ API, are better suited for production systems as they are faster and more stable. The exhaustive list of all Serving APIs is available here.
In this colab, you will:
- Use the model.predict() function on a TensorFlow Dataset created with pd_dataframe_to_tf_dataset.
- Use the model.predict() function on a TensorFlow Dataset created manually.
- Use the model.predict() function on NumPy arrays.
- Make predictions with the CLI API.
- Benchmark the inference speed of a model with the CLI API.
Important remark
The dataset used for predictions should have the same feature names and types as the dataset used for training. Failing to do so will likely raise errors.
For example, training a model with two features f1 and f2, and trying to generate predictions on a dataset without f2 will fail. Note that it is okay to set (some or all) feature values as "missing". Similarly, training a model where f2 is a numerical feature (e.g., float32), and applying this model on a dataset where f2 is a text (e.g., string) feature will fail.
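The rule above can be checked before calling predict. The following is a minimal sketch in plain Python, not part of the TF-DF API: the helper names infer_schema and schema_problems are hypothetical, and datasets are represented as dictionaries of lists where None stands for a missing value.

```python
# Hypothetical helpers (not part of TF-DF): compare feature names and value
# types between a training dataset and a serving dataset.

def infer_schema(columns):
    """Maps each column name to the set of Python type names of its non-missing values."""
    return {
        name: {type(value).__name__ for value in values if value is not None}
        for name, values in columns.items()
    }


def schema_problems(train_columns, serving_columns, label="label"):
    """Returns a list of human-readable mismatches (empty if the datasets are compatible)."""
    train = infer_schema(train_columns)
    serving = infer_schema(serving_columns)
    problems = []
    for name, train_types in train.items():
        if name == label:
            continue  # The label is not needed to generate predictions.
        if name not in serving:
            problems.append(f"missing feature: {name}")
        elif serving[name] and serving[name] != train_types:
            # A column made only of missing values has no type and is accepted.
            problems.append(
                f"type mismatch for {name}: trained on {sorted(train_types)}, "
                f"got {sorted(serving[name])}"
            )
    return problems


train = {"f1": [1.0, 2.0], "f2": [0.5, 0.1], "label": [0, 1]}
print(schema_problems(train, {"f1": [3.0], "f2": [None]}))  # f2 present but missing: fine
print(schema_problems(train, {"f1": [3.0]}))                # f2 absent
print(schema_problems(train, {"f1": [3.0], "f2": ["a"]}))   # f2 is text instead of numerical
```

A check like this only catches the two failure modes described here (absent features and changed types); it does not replace the validation done by the model itself.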
While abstracted by the Keras API, a model instantiated in Python (e.g., with tfdf.keras.RandomForestModel()) and a model loaded from disk (e.g., with tf.keras.models.load_model()) can behave differently. Notably, a Python-instantiated model automatically applies necessary type conversions. For example, if a float64 feature is fed to a model expecting a float32 feature, this conversion is performed implicitly. However, such a conversion is not possible for models loaded from disk. It is therefore important that the training data and the inference data always have the exact same type.
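One way to keep types aligned is to cast the inference features explicitly before feeding them to a model loaded from disk. The sketch below uses only NumPy; the helper cast_to_training_dtypes and the TRAINING_DTYPES table are hypothetical, and the feature names and dtypes are assumptions chosen for illustration.

```python
import numpy as np

# Dtypes the model is assumed to have been trained with (illustrative values;
# replace them with the dtypes of your own training data).
TRAINING_DTYPES = {"feature_1": np.float32, "feature_2": np.float32}


def cast_to_training_dtypes(features, expected=TRAINING_DTYPES):
    """Returns a new dictionary where each feature array has the expected dtype."""
    missing = set(expected) - set(features)
    if missing:
        raise KeyError(f"missing features: {sorted(missing)}")
    return {
        name: np.asarray(features[name]).astype(dtype)
        for name, dtype in expected.items()
    }


# np.random.rand returns float64 arrays by default.
serving = {"feature_1": np.random.rand(4), "feature_2": np.random.rand(4)}
casted = cast_to_training_dtypes(serving)
print({name: array.dtype for name, array in casted.items()})
```

The casted dictionary can then be passed wherever a dictionary of arrays is accepted.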
Setup
First, we install TensorFlow Decision Forests...
pip install tensorflow_decision_forests
... and import the libraries used in this example.
import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
model.predict(...) and the pd_dataframe_to_tf_dataset function
TensorFlow Decision Forests implements the Keras model API.
As such, TF-DF models have a predict function to make predictions. This function takes as input a TensorFlow Dataset and outputs a prediction array.
The simplest way to create a TensorFlow dataset is to use Pandas and the tfdf.keras.pd_dataframe_to_tf_dataset(...) function.
The next example shows how to create a TensorFlow dataset using pd_dataframe_to_tf_dataset.
pd_dataset = pd.DataFrame({
    "feature_1": [1,2,3],
    "feature_2": ["a", "b", "c"],
    "label": [0, 1, 0],
})
pd_dataset
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_dataset, label="label")
for features, label in tf_dataset:
  print("Features:",features)
  print("label:", label)
Features: {'feature_1': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, 'feature_2': <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'a', b'b', b'c'], dtype=object)>}
label: tf.Tensor([0 1 0], shape=(3,), dtype=int64)
Note: "pd" stands for "pandas". "tf" stands for "TensorFlow".
A TensorFlow Dataset is a function that outputs a sequence of values. Those values can be simple arrays (called Tensors) or arrays organized into a structure (for example, arrays organized in a dictionary).
The following example shows the training and inference (using predict) on a toy dataset:
# Creating a training dataset in Pandas
pd_train_dataset = pd.DataFrame({
    "feature_1": np.random.rand(1000),
    "feature_2": np.random.rand(1000),
})
pd_train_dataset["label"] = pd_train_dataset["feature_1"] > pd_train_dataset["feature_2"]
pd_train_dataset
# Creating a serving dataset with Pandas
pd_serving_dataset = pd.DataFrame({
    "feature_1": np.random.rand(500),
    "feature_2": np.random.rand(500),
})
pd_serving_dataset
Let's convert the Pandas dataframes into TensorFlow datasets:
tf_train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label")
tf_serving_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pd_serving_dataset)
We can now train a model on tf_train_dataset:
model = tfdf.keras.RandomForestModel(verbose=0)
model.fit(tf_train_dataset)
And then generate predictions on tf_serving_dataset:
# Print the first 10 predictions.
model.predict(tf_serving_dataset, verbose=0)[:10]
array([[0. ], [0.99999917], [0. ], [0.29666647], [0.99999917], [0. ], [0.99999917], [0.99999917], [0.99999917], [0. ]], dtype=float32)
model.predict(...) and manual TF datasets
In the previous section, we showed how to create a TF dataset using the pd_dataframe_to_tf_dataset function. This option is simple but poorly suited for large datasets. Instead, TensorFlow offers several options to create a TensorFlow dataset.
The next example shows how to create a dataset using the tf.data.Dataset.from_tensor_slices() function.
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5])
for value in dataset:
  print("value:", value.numpy())
value: 1
value: 2
value: 3
value: 4
value: 5
TensorFlow models are trained with mini-batching: Instead of being fed one at a time, examples are grouped in "batches". For Neural Networks, the batch size impacts the quality of the model, and the optimal value needs to be determined by the user during training. For Decision Forests, the batch size has no impact on the model. However, for compatibility reasons, TensorFlow Decision Forests expects the dataset to be batched. Batching is done with the batch() function.
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).batch(2)
for value in dataset:
  print("value:", value.numpy())
value: [1 2]
value: [3 4]
value: [5]
TensorFlow Decision Forests expects the dataset to be of one of two structures:
- features, label
- features, label, weights
The features can be a single 2 dimensional array (where each column is a feature and each row is an example), or a dictionary of arrays.
Following is an example of a dataset compatible with TensorFlow Decision Forests:
# A dataset with a single 2d array.
tf_dataset = tf.data.Dataset.from_tensor_slices(
    ([[1,2],[3,4],[5,6]], # Features
    [0,1,0], # Label
    )).batch(2)
for features, label in tf_dataset:
  print("features:", features)
  print("label:", label)
features: tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32)
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: tf.Tensor([[5 6]], shape=(1, 2), dtype=int32)
label: tf.Tensor([0], shape=(1,), dtype=int32)
# A dataset with a dictionary of features.
tf_dataset = tf.data.Dataset.from_tensor_slices(
    ({
    "feature_1": [1,2,3],
    "feature_2": [4,5,6],
    },
    [0,1,0], # Label
    )).batch(2)
for features, label in tf_dataset:
  print("features:", features)
  print("label:", label)
features: {'feature_1': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1, 2], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 5], dtype=int32)>}
label: tf.Tensor([0 1], shape=(2,), dtype=int32)
features: {'feature_1': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>, 'feature_2': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([6], dtype=int32)>}
label: tf.Tensor([0], shape=(1,), dtype=int32)
Let's train a model with this second option.
tf_dataset = tf.data.Dataset.from_tensor_slices(
    ({
    "feature_1": np.random.rand(100),
    "feature_2": np.random.rand(100),
    },
    np.random.rand(100) >= 0.5, # Label
    )).batch(2)
model = tfdf.keras.RandomForestModel(verbose=0)
model.fit(tf_dataset)
[INFO 2022-12-14T12:07:00.416575763+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpvzrrxxmw/model/ with prefix 0bc6f955d2d1456e [INFO 2022-12-14T12:07:00.440516186+00:00 kernel.cc:1021] Use fast generic engine <keras.callbacks.History at 0x7f75f016e220>
The predict function can be used directly on the training dataset:
# The first 10 predictions.
model.predict(tf_dataset, verbose=0)[:10]
array([[0.43666634], [0.58999956], [0.42999968], [0.73333275], [0.75666606], [0.20666654], [0.67666614], [0.66666615], [0.82333267], [0.3999997 ]], dtype=float32)
model.predict(...) and model.predict_on_batch() on dictionaries
In some cases, the predict function can be used with an array (or a dictionary of arrays) instead of a TensorFlow Dataset.
The following example uses the previously trained model with a dictionary of NumPy arrays.
# The first 10 predictions.
model.predict({
    "feature_1": np.random.rand(100),
    "feature_2": np.random.rand(100),
    }, verbose=0)[:10]
array([[0.6533328 ], [0.5399996 ], [0.2133332 ], [0.22999986], [0.16333325], [0.18333323], [0.3766664 ], [0.5066663 ], [0.20333321], [0.8633326 ]], dtype=float32)
In the previous example, the arrays are automatically batched. Alternatively, the predict_on_batch function can be used to make sure that all the examples are run in the same batch.
# The first 10 predictions.
model.predict_on_batch({
    "feature_1": np.random.rand(100),
    "feature_2": np.random.rand(100),
    })[:10]
array([[0.54666626], [0.21666653], [0.18333323], [0.5299996 ], [0.5499996 ], [0.12666662], [0.6299995 ], [0.06000001], [0.33999977], [0.08999998]], dtype=float32)
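To make the difference between automatic batching and a single batch concrete, here is a small sketch in plain Python. It does not call Keras. The helper batch_sizes is hypothetical, and the batch size of 32 is an assumption about the default Keras uses for array inputs when none is given, so check the documentation of your Keras version.

```python
import math


def batch_sizes(num_examples, batch_size=32):
    """Sizes of the consecutive batches needed to cover `num_examples` examples."""
    num_batches = math.ceil(num_examples / batch_size)
    return [
        min(batch_size, num_examples - index * batch_size)
        for index in range(num_batches)
    ]


# 100 examples split automatically, assuming a batch size of 32.
print(batch_sizes(100))       # [32, 32, 32, 4]
# 100 examples run as one batch, as with predict_on_batch.
print(batch_sizes(100, 100))  # [100]
```

Since the batch size has no impact on a Decision Forest model, both approaches are expected to give the same predictions for the same inputs; only the way the examples are grouped differs.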
Inference with the YDF format
This example shows how to run a TF-DF model with the CLI API (one of the other Serving APIs). We will also use the benchmark tool to measure the inference speed of the model.
Let's start by training and saving a model:
model = tfdf.keras.GradientBoostedTreesModel(verbose=0)
model.fit(tfdf.keras.pd_dataframe_to_tf_dataset(pd_train_dataset, label="label"))
model.save("my_model")
2022-12-14 12:07:00.950798: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1765] Subsample hyperparameter given but sampling method does not match. 2022-12-14 12:07:00.950839: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1778] GOSS alpha hyperparameter given but GOSS is disabled. 2022-12-14 12:07:00.950846: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1787] GOSS beta hyperparameter given but GOSS is disabled. 2022-12-14 12:07:00.950852: W external/ydf/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc:1799] SelGB ratio hyperparameter given but SelGB is disabled. [INFO 2022-12-14T12:07:01.160357659+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpo37712qo/model/ with prefix 391746915b7842cb [INFO 2022-12-14T12:07:01.164736847+00:00 kernel.cc:1021] Use fast generic engine WARNING:absl:Found untraced functions such as call_get_leaves, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: my_model/assets INFO:tensorflow:Assets written to: my_model/assets
Let's also export the dataset to a CSV file:
pd_serving_dataset.to_csv("dataset.csv")
Let's download and extract the Yggdrasil Decision Forests CLI tools.
wget https://github.com/google/yggdrasil-decision-forests/releases/download/1.0.0/cli_linux.zip
unzip cli_linux.zip
Finally, let's make predictions:
Remarks:
- TensorFlow Decision Forests (TF-DF) is based on the Yggdrasil Decision Forests (YDF) library, and a TF-DF model always contains a YDF model internally. When saving a TF-DF model to disk, the TF-DF model directory contains an assets sub-directory containing the YDF model. This YDF model can be used with all YDF tools. In the next example, we will use the predict and benchmark_inference tools. See the model format documentation for more details.
- YDF tools assume that the type of the dataset is specified using a prefix, e.g. csv:. See the YDF user manual for more details.
./predict --model=my_model/assets --dataset=csv:dataset.csv --output=csv:predictions.csv
[INFO abstract_model.cc:1296] Engine "GradientBoostedTreesQuickScorerExtended" built [INFO predict.cc:133] Run predictions with semi-fast engine
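When the CLI is driven from a Python script rather than typed by hand, the same command can be assembled programmatically. The sketch below only builds and prints the command. The helpers typed_path and predict_command are hypothetical, and the flags are the ones used in the command above.

```python
import shlex


def typed_path(fmt, path):
    """Prefixes a dataset path with its format, e.g. "csv:dataset.csv"."""
    return f"{fmt}:{path}"


def predict_command(model_dir, dataset_csv, output_csv, binary="./predict"):
    """Builds the argument list for the YDF predict tool with CSV input and output."""
    return [
        binary,
        f"--model={model_dir}",
        f"--dataset={typed_path('csv', dataset_csv)}",
        f"--output={typed_path('csv', output_csv)}",
    ]


cmd = predict_command("my_model/assets", "dataset.csv", "predictions.csv")
print(shlex.join(cmd))
# To actually run the tool (requires the extracted CLI binaries):
#   import subprocess; subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting issues if a path contains spaces.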
We can now look at the predictions:
pd.read_csv("predictions.csv")
The speed of inference of a model can be measured with the benchmark inference tool.
# Create the empty label column.
pd_serving_dataset["__LABEL"] = 0
pd_serving_dataset.to_csv("dataset.csv")
!./benchmark_inference \
--model=my_model/assets \
--dataset=csv:dataset.csv \
--batch_size=100 \
--warmup_runs=10 \
--num_runs=50
[INFO benchmark_inference.cc:245] Loading model
[INFO benchmark_inference.cc:248] The model is of type: GRADIENT_BOOSTED_TREES
[INFO benchmark_inference.cc:250] Loading dataset
[INFO benchmark_inference.cc:259] Found 3 compatible fast engines.
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesGeneric
[INFO decision_forest.cc:639] Model loaded with 27 root(s), 1471 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesQuickScorerExtended
[INFO benchmark_inference.cc:262] Running GradientBoostedTreesOptPred
[INFO decision_forest.cc:639] Model loaded with 27 root(s), 1471 node(s), and 2 input feature(s).
[INFO benchmark_inference.cc:268] Running the slow generic engine
batch_size : 100 num_runs : 50
time/example(us) time/batch(us) method
----------------------------------------
0.22425 22.425 GradientBoostedTreesOptPred [virtual interface]
0.2465 24.65 GradientBoostedTreesQuickScorerExtended [virtual interface]
0.6875 68.75 GradientBoostedTreesGeneric [virtual interface]
1.825 182.5 Generic slow engine
----------------------------------------
In this benchmark, we see the inference speed for different inference engines. For example, "time/example(us) = 0.6875" for the GradientBoostedTreesGeneric engine (the value can change between runs) indicates that the inference of one example takes about 0.69 microseconds. That is, this engine can run the model roughly 1.45 million times per second.
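The conversion from time per example to throughput is a simple inversion. As a small illustration, the sketch below applies it to the values copied from the benchmark output shown above; the helper name examples_per_second is hypothetical.

```python
def examples_per_second(time_per_example_us):
    """Converts a time per example in microseconds into examples per second."""
    return 1e6 / time_per_example_us


# time/example(us) values copied from the benchmark output above.
results_us = {
    "GradientBoostedTreesOptPred": 0.22425,
    "GradientBoostedTreesQuickScorerExtended": 0.2465,
    "GradientBoostedTreesGeneric": 0.6875,
    "Generic slow engine": 1.825,
}

for engine, time_us in results_us.items():
    rate_millions = examples_per_second(time_us) / 1e6
    print(f"{engine}: {rate_millions:.2f} million examples per second")
```

These figures are for a single run on one machine, so they are best used to compare engines against each other rather than as absolute numbers.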