
Module: tfdf.keras


Decision Forest in a Keras Model.

Usage example:

import tensorflow as tf  # needed for tf.keras.models.load_model below
import tensorflow_decision_forests as tfdf
import pandas as pd

# Load the dataset in a Pandas dataframe.
train_df = pd.read_csv("project/train.csv")
test_df = pd.read_csv("project/test.csv")

# Convert the dataset into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="my_label")

# Train the model.
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)

# Evaluate the model on another dataset.
model.evaluate(test_ds)

# Show information about the model.
model.summary()

# Export the model in the TensorFlow SavedModel format.
model.save("/path/to/my/model")

# ...

# Load the model back: it is restored as a generic Keras model.
loaded_model = tf.keras.models.load_model("/path/to/my/model")
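The converted dataset feeds Keras with batches of `(features, label)` pairs, where `features` maps each column name to that column's values and the label column is held separately. The pure-Python sketch below mimics that layout on a list of row dicts so the structure is easy to inspect. `to_feature_label_batches` is a hypothetical helper, not part of `tfdf`, and the batch size of 2 is chosen only for readability.

```python
from typing import Any, Dict, List, Tuple

Row = Dict[str, Any]
Batch = Tuple[Dict[str, List[Any]], List[Any]]


def to_feature_label_batches(rows: List[Row], label: str, batch_size: int) -> List[Batch]:
    """Hypothetical helper: split rows into (features, label) batches.

    Mimics the layout Keras receives from a converted dataset: each feature
    column becomes a list keyed by its name, and the label column is
    separated out. This is not the tfdf implementation.
    """
    batches: List[Batch] = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        features: Dict[str, List[Any]] = {}
        labels: List[Any] = []
        for row in chunk:
            labels.append(row[label])
            for name, value in row.items():
                if name != label:
                    features.setdefault(name, []).append(value)
        batches.append((features, labels))
    return batches


rows = [
    {"age": 39, "color": "red", "my_label": 1},
    {"age": 52, "color": "blue", "my_label": 0},
    {"age": 23, "color": "red", "my_label": 1},
]
batches = to_feature_label_batches(rows, label="my_label", batch_size=2)
print(batches[0])
# ({'age': [39, 52], 'color': ['red', 'blue']}, [1, 0])
```

The last batch is allowed to be smaller than `batch_size`, which is also how batched datasets usually behave when the row count is not a multiple of the batch size.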

Modules

core module: Core wrapper.

wrappers module: Wrapper around each learning algorithm.

Classes

class AdvancedArguments: Advanced control of the model that most users won't need to use.

class CartModel: Cart learning algorithm.

class CoreModel: Keras Model V2 wrapper around an Yggdrasil Learner and Model.

class DistributedGradientBoostedTreesModel: Distributed Gradient Boosted Trees learning algorithm.

class FeatureSemantic: Semantic (e.g. numerical, categorical) of an input feature.

class FeatureUsage: Semantic and hyper-parameters for a single feature.

class GradientBoostedTreesModel: Gradient Boosted Trees learning algorithm.

class RandomForestModel: Random Forest learning algorithm.
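When no `FeatureUsage` sets a semantic for a column, the semantic is determined automatically. The sketch below approximates that choice in plain Python, assuming that numbers map to `NUMERICAL` and strings to `CATEGORICAL`. `infer_semantic` and the local `Semantic` enum are hypothetical stand-ins for illustration only; the library's real rules cover more cases (other semantics, missing values) and work on tensor dtypes rather than Python values.

```python
import enum
from typing import Any, Sequence


class Semantic(enum.Enum):
    """Hypothetical local stand-in for two tfdf.keras.FeatureSemantic values."""
    NUMERICAL = "NUMERICAL"
    CATEGORICAL = "CATEGORICAL"


def infer_semantic(values: Sequence[Any]) -> Semantic:
    """Guess a feature semantic from the Python types of its values.

    Simplified assumption: ints and floats are NUMERICAL, strings are
    CATEGORICAL. Booleans and mixed-type columns are rejected here.
    """
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
        return Semantic.NUMERICAL
    if all(isinstance(v, str) for v in values):
        return Semantic.CATEGORICAL
    raise ValueError("cannot infer a semantic for this column")


print(infer_semantic([39, 52.5, 23]).name)   # NUMERICAL
print(infer_semantic(["red", "blue"]).name)  # CATEGORICAL
```

Under this rule an integer-coded column such as a zip code comes out numerical even though it is really categorical; that is the kind of case where an explicit `FeatureUsage` with a categorical semantic is meant to help.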

Functions

build_default_feature_signature(...): Gets an example of feature values for the default model signature.

build_default_input_model_signature(...)

get_all_models(...): Gets the list of all the available models.

get_worker_idx_and_num_workers(...): Gets the current worker index and the total number of workers.

pd_dataframe_to_tf_dataset(...): Converts a Pandas DataFrame into a TF Dataset compatible with Keras.

set_training_logs_redirection(...): Controls the redirection of training logs for display.

yggdrasil_model_to_keras_model(...): Converts an Yggdrasil model into a TensorFlow SavedModel / Keras model.
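`get_worker_idx_and_num_workers(...)` returns the current worker's index and the total number of workers. One common use of such a pair in distributed input pipelines is to give each worker a disjoint shard of the input files. The sketch below shows a round-robin sharding scheme in plain Python; `shard_paths` and the file names are hypothetical, and the sketch does not call `tfdf`.

```python
from typing import List, Sequence


def shard_paths(paths: Sequence[str], worker_idx: int, num_workers: int) -> List[str]:
    """Hypothetical helper: round-robin assignment of input files to workers.

    Worker `worker_idx` keeps every `num_workers`-th path starting at its own
    index, so the shards are disjoint and together cover all paths.
    """
    if not 0 <= worker_idx < num_workers:
        raise ValueError("worker_idx must be in [0, num_workers)")
    return list(paths[worker_idx::num_workers])


paths = [f"train_{i}.csv" for i in range(5)]
for idx in range(2):
    print(idx, shard_paths(paths, worker_idx=idx, num_workers=2))
# 0 ['train_0.csv', 'train_2.csv', 'train_4.csv']
# 1 ['train_1.csv', 'train_3.csv']
```

Round-robin keeps shard sizes within one file of each other; contiguous blocks would work as well, provided every worker applies the same scheme with the same index and count.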

Other Members

Task: Instance of google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper; the enum of learning tasks (e.g. Task.CLASSIFICATION, Task.REGRESSION).