Module: tfdf

User entry point for the TensorFlow Decision Forests API.

Basic usage:

# Imports
import tensorflow_decision_forests as tfdf
import tf_keras  # Keras 2; used below to load the saved model.
import pandas as pd
from wurlitzer import sys_pipes

# Load a dataset into a Pandas DataFrame.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Display the first 3 examples.
dataset_df.head(3)

# Convert the Pandas DataFrame to a TensorFlow dataset, using "species" as the label.
tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

model = tfdf.keras.RandomForestModel()
with sys_pipes():
    model.fit(tf_dataset)
# Note: `sys_pipes` forwards the C++ training logs to Python's stdout/stderr so they are displayed during training (e.g. in a notebook).

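# Evaluate the model. `test_dataset` is a placeholder for a held-out tf dataset
# built the same way as `tf_dataset`; it is not defined in this snippet.
model.compile(metrics=["accuracy"])
model.evaluate(test_dataset)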

# Save model.
model.save("/tmp/my_saved_model")

# ...

# Load the model: it loads as a generic Keras model.
loaded_model = tf_keras.models.load_model("/tmp/my_saved_model")
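The evaluation step above needs a held-out `test_dataset`. A minimal sketch of one way to obtain it, assuming a simple random split of the Pandas DataFrame is acceptable: the `split_dataset` helper, the 30% test ratio and the seed are hypothetical choices, not part of the TF-DF API, and the toy DataFrame stands in for the penguins CSV.

```python
import pandas as pd


def split_dataset(df: pd.DataFrame, test_ratio: float = 0.30, seed: int = 1234):
    """Hypothetical helper: randomly split `df` into (train_df, test_df).

    Assumes `df` has a unique index (the default RangeIndex from pd.read_csv).
    """
    # Sample the test rows reproducibly, then keep every other row for training.
    test_df = df.sample(frac=test_ratio, random_state=seed)
    train_df = df.drop(test_df.index)
    return train_df, test_df


# Toy stand-in for the penguins DataFrame (10 rows).
toy_df = pd.DataFrame({
    "bill_length_mm": [35.0 + i for i in range(10)],
    "species": ["Adelie", "Gentoo"] * 5,
})
train_df, test_df = split_dataset(toy_df)
print(len(train_df), len(test_df))
```

Each part can then be converted separately, e.g. `tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="species")`, to produce the `test_dataset` passed to `model.evaluate`. Fixing the seed keeps the split reproducible across runs.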

Modules

builder module: Model builder.

check_version module: Check that the version of TensorFlow is compatible with TF-DF.

inspector module: Model inspector.

keras module: Decision Forest in a Keras Model.

model_plotter module: Plotting of decision forest models.

py_tree module: Decision trees stored as Python objects.

tuner module: Specification of the parameters of a tuner.
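The py_tree module stores decision trees as plain Python objects that can be inspected or traversed. The toy sketch below only illustrates that idea: the `Leaf`, `Split` and `predict` names are hypothetical and are not the py_tree API, and the thresholds are made up rather than taken from a trained model.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    """Terminal node holding a predicted class label (hypothetical)."""
    label: str


@dataclass
class Split:
    """Internal node: go to `pos` if example[feature] >= threshold, else `neg`."""
    feature: str
    threshold: float
    pos: Node
    neg: Node


Node = Union[Leaf, Split]


def predict(node: Node, example: dict[str, float]) -> str:
    """Walk from the root to a leaf and return that leaf's label."""
    while isinstance(node, Split):
        node = node.pos if example[node.feature] >= node.threshold else node.neg
    return node.label


# Illustrative tree over two penguin features; thresholds are invented.
tree = Split(
    "flipper_length_mm", 206.0,
    pos=Leaf("Gentoo"),
    neg=Split("bill_length_mm", 43.0, pos=Leaf("Chinstrap"), neg=Leaf("Adelie")),
)
print(predict(tree, {"flipper_length_mm": 215.0, "bill_length_mm": 47.0}))
```

The traversal is iterative rather than recursive, so very deep trees do not run into Python's recursion limit.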

Other members

__version__ '1.9.0'

compatible_tf_versions ['2.16.1']
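`compatible_tf_versions` lists the TensorFlow releases this TF-DF build is meant to run with. A minimal sketch of a compatibility check against such a list, assuming a match on the leading `X.Y.Z` release is what you want: `release_tuple` and `is_compatible` are hypothetical helpers, and the real logic in the check_version module may differ.

```python
import re


def release_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading 'X.Y.Z' of a version string, e.g. '2.16.1-rc0' -> (2, 16, 1)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def is_compatible(tf_version: str, compatible_tf_versions: list[str]) -> bool:
    """Hypothetical check: True if tf_version's X.Y.Z release appears in the list."""
    wanted = release_tuple(tf_version)
    return any(release_tuple(v) == wanted for v in compatible_tf_versions)


# Literal values mirror the members listed above; in a real environment you
# would pass tf.__version__ and tfdf.compatible_tf_versions instead.
print(is_compatible("2.16.1", ["2.16.1"]))  # True
print(is_compatible("2.15.0", ["2.16.1"]))  # False
```

Matching only on `X.Y.Z` is a lenient design choice: it treats a pre-release or build suffix (such as `-rc0`) as the same release, which a stricter check might reject.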