Module: tfdf.builder

Model builder.

The model builder lets the user create models by hand, i.e. by defining the tree structure manually.

The available builders are:

  • RandomForestBuilder
  • CARTBuilder
  • GradientBoostedTreeBuilder

About categorical and categorical-set features with a string dictionary:

Categorical and categorical-set features are tied to a dictionary of possible values. In addition, the special value "out-of-dictionary" (OOD) designates all the values that are not in the dictionary. For example, the condition "a in ["x", "<OOD>"]" is true if the feature "a" is equal to "x" or to any value not in the dictionary.
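The OOD rule described above can be illustrated with a small sketch that does not depend on the library. The helper `is_in_condition` is hypothetical and not part of `tfdf`; it only mimics the described semantics, assuming the OOD item is spelled `<OOD>` as in the usage example further down.

```python
OOD = "<OOD>"  # Assumed spelling of the out-of-dictionary item.


def is_in_condition(value, mask, dictionary):
    """Hypothetical helper: evaluates `value in mask` with OOD semantics."""
    # A value absent from the dictionary is first mapped to the OOD item.
    item = value if value in dictionary else OOD
    return item in mask


dictionary = [OOD, "x", "y"]
mask = ["x", OOD]

# "x" matches directly, "y" is known but not in the mask, and "unseen" is
# not in the dictionary, so it is treated as the OOD item.
results = {v: is_in_condition(v, mask, dictionary) for v in ["x", "y", "unseen"]}
print(results)
```

In this sketch, adding "unseen" to the dictionary would flip its result, which is why the dictionary content matters for conditions that test the OOD item.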

The feature dictionaries are automatically assembled as the union of all the values observed in the tree conditions. Alternatively, dictionaries can be read or set manually with "{get,set}_dictionary()", or imported from an existing dataspec with the "import_dataspec" constructor argument.
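As a rough illustration of the automatic assembly, the sketch below collects the union of the values seen in a list of conditions. The function `assemble_dictionaries` and the `(feature, mask)` tuples are hypothetical stand-ins, not the builder's actual implementation, and the sketch leaves out the OOD item for brevity.

```python
def assemble_dictionaries(conditions):
    """Hypothetical helper: union of the values observed per feature."""
    dictionaries = {}
    for feature, mask in conditions:
        seen = dictionaries.setdefault(feature, [])
        for value in mask:
            # Keep the first-seen order and skip duplicates.
            if value not in seen:
                seen.append(value)
    return dictionaries


# Each tuple stands for one "feature in mask" condition found in the trees.
conditions = [("f2", ["x", "y"]), ("f2", ["y", "w"]), ("f3", ["a"])]
auto = assemble_dictionaries(conditions)
print(auto)
```

A value such as "z" never appears in a condition on "f2" here, so it would stay out of the assembled dictionary unless the dictionary is set manually.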

About "file prefix": Multiple Yggdrasil Decision Forests models can be stored in a single directory. This is a requirement of the TensorFlow SavedModel API. To support this, the files of each individual model are prefixed with a unique identifier. When loading a model from a directory path, this prefix can be provided or detected automatically. Note that automatic detection fails if the directory contains more than one model.
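The prefix detection can be pictured with the following sketch. It is not the library's code: the helper `detect_prefix` is hypothetical, and the marker file name `data_spec.pb` is an assumption used only to make the example concrete.

```python
import os
import tempfile

MARKER = "data_spec.pb"  # Assumed name of a per-model file; illustrative only.


def detect_prefix(directory):
    """Hypothetical helper: infers the model prefix from the marker file."""
    matches = [f for f in os.listdir(directory) if f.endswith(MARKER)]
    if len(matches) != 1:
        # Zero or several models: the prefix cannot be inferred.
        raise ValueError(
            f"Expected exactly one model in {directory}, found {len(matches)}")
    return matches[0][:-len(MARKER)]


with tempfile.TemporaryDirectory() as d:
    # One model in the directory: the prefix is detected.
    open(os.path.join(d, "abc123" + MARKER), "w").close()
    single = detect_prefix(d)

    # A second model makes the detection ambiguous.
    open(os.path.join(d, "def456" + MARKER), "w").close()
    try:
        detect_prefix(d)
        ambiguous = False
    except ValueError:
        ambiguous = True

print(single, ambiguous)
```

With several models in one directory, passing the prefix explicitly avoids the ambiguity.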

Usage:


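import tensorflow as tf
import tensorflow_decision_forests as tfdf

# Aliases used below; the tree components live in `tfdf.py_tree`.
builder_lib = tfdf.builder
py_tree = tfdf.py_tree
Tree = py_tree.tree.Tree
NonLeafNode = py_tree.node.NonLeafNode
LeafNode = py_tree.node.LeafNode
NumericalHigherThanCondition = py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = py_tree.condition.CategoricalIsInCondition
SimpleColumnSpec = py_tree.dataspec.SimpleColumnSpec
ProbabilityValue = py_tree.value.ProbabilityValue

# Create a binary classification CART model.
builder = builder_lib.CARTBuilder(
    path="/path/to/model",
    objective=py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue"]))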

# Create the tree
#  f1>=1.5
#    ├─(pos)─ [0.1, 0.9]
#    └─(neg)─ [0.8, 0.2]
#
# The components of the tree (e.g. `NonLeafNode`, `Tree`) are defined in
# `tfdf.py_tree`.
#
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(
                    name="f1", type=py_tree.dataspec.ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=LeafNode(
                value=ProbabilityValue(probability=[0.1, 0.9])),
            neg_child=LeafNode(
                value=ProbabilityValue(probability=[0.8, 0.2])))))

# Create a second tree
#  f2 in ["x", "y"]
#    ├─(pos)─ [0.1, 0.9]
#    └─(neg)─ [0.8, 0.2]
#
builder.add_tree(
    Tree(
        NonLeafNode(
            condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(
                        name="f2",
                        type=py_tree.dataspec.ColumnType.CATEGORICAL),
                    mask=["x", "y"],
                    missing_evaluation=False),
            pos_child=LeafNode(
                value=ProbabilityValue(probability=[0.1, 0.9])),
            neg_child=LeafNode(
                value=ProbabilityValue(probability=[0.8, 0.2])))))

# Optionally set the dictionary of the categorical feature "f2".
# If not set, all the values not seen in the model ("z" in this case) will not
# be known by the model and will be treated as OOD (out of dictionary).
#
# Defining a dictionary only has an impact if a condition is testing for the
# `<OOD>` item directly i.e. the test `f2 in ["<OOD>"]` depends on the content
# of the dictionary.
builder.set_dictionary("f2", ["<OOD>", "x", "y", "z"])

builder.close()

# Load and use the model
model = tf.keras.models.load_model("/path/to/model")
predictions = model.predict(...)
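To make the semantics of the first tree concrete, here is a plain-Python sketch of how it could be evaluated on a single example. The function `evaluate_first_tree` is hypothetical and independent of `tfdf`. It assumes that `missing_evaluation` is the value taken by the condition when the feature is missing, and that the probabilities follow the order of `classes` (`["red", "blue"]`).

```python
def evaluate_first_tree(example):
    """Hypothetical evaluation of the tree `f1 >= 1.5` from the usage above."""
    threshold = 1.5
    missing_evaluation = False  # Assumed: condition result when f1 is missing.

    f1 = example.get("f1")  # None stands for a missing value.
    condition = missing_evaluation if f1 is None else f1 >= threshold

    # Positive branch -> [0.1, 0.9]; negative branch -> [0.8, 0.2].
    return [0.1, 0.9] if condition else [0.8, 0.2]


p_high = evaluate_first_tree({"f1": 2.0})
p_low = evaluate_first_tree({"f1": 1.0})
p_missing = evaluate_first_tree({})
print(p_high, p_low, p_missing)
```

Under these assumptions, an example with a missing "f1" follows the negative branch, the same as an example below the threshold.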

Classes

class AbstractBuilder: Generic model builder.

class AbstractDecisionForestBuilder: Generic decision forest model builder.

class AdvancedArguments: Advanced control of the model building.

class CARTBuilder: CART model builder.

class Enum: Generic enumeration.

class GradientBoostedTreeBuilder: Gradient Boosted Tree model builder.

class ModelFormat: Model formats on disk.

class RandomForestBuilder: Random Forest model builder.

Functions

dataclass(...): Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Other Members

ColumnType: Instance of google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper

Task: Instance of google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper

absolute_import: Instance of __future__._Feature

division: Instance of __future__._Feature

print_function: Instance of __future__._Feature