Inspect and debug decision forest models


In this colab, you will learn how to inspect the structure of a model and how to create a model directly by hand. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.

In this colab, you will:

  1. Train a Random Forest model and access its structure programmatically.

  2. Create a Random Forest model by hand and use it as a classical model.

Setup

# Install TensorFlow Decision Forests.
!pip install tensorflow_decision_forests

# Use wurlitzer to capture training logs.
!pip install wurlitzer
import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections

try:
  from wurlitzer import sys_pipes
except ImportError:
  from colabtools.googlelog import CaptureLog as sys_pipes

from IPython.core.magic import register_line_magic
from IPython.display import Javascript

The hidden code cell limits the output height in Colab.
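
For reference, here is a minimal sketch of what such a hidden cell can look like. This is an assumption, not the verbatim hidden cell; it relies on Colab's google.colab.output.setIframeHeight JavaScript API:

from IPython.display import display

# Hypothetical re-creation of the hidden cell: register a %set_cell_height
# line magic that caps the height of the current cell's output area.
@register_line_magic
def set_cell_height(size):
  display(Javascript(
      "google.colab.output.setIframeHeight(0, true, {maxHeight: " +
      str(size) + "})"))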

Train a simple Random Forest

We train a Random Forest like in the beginner colab:

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Show the first three examples.
print(dataset_df.head(3))

# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \
0  Adelie  Torgersen            39.1           18.7              181.0   
1  Adelie  Torgersen            39.5           17.4              186.0   
2  Adelie  Torgersen            40.3           18.0              195.0   

   body_mass_g     sex  year  
0       3750.0    male  2007  
1       3800.0  female  2007  
2       3250.0  female  2007
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1224: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  (dict(dataframe.drop(label, 1)), dataframe[label].values))
6/6 [==============================] - 3s 2ms/step
[INFO kernel.cc:746] Start Yggdrasil model training
[INFO kernel.cc:747] Collect training examples
[INFO kernel.cc:392] Number of batches: 6
[INFO kernel.cc:393] Number of examples: 344
[INFO kernel.cc:769] Dataset:
Number of records: 344
Number of columns: 8

Number of columns by type:
    NUMERICAL: 5 (62.5%)
    CATEGORICAL: 3 (37.5%)

Columns:

NUMERICAL: 5 (62.5%)
    0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719
    1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516
    2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781
    3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411
    6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166

CATEGORICAL: 3 (37.5%)
    4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%)
    5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%)
    7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO kernel.cc:772] Configure learner
[INFO kernel.cc:797] Training config:
learner: "RANDOM_FOREST"
features: "bill_depth_mm"
features: "bill_length_mm"
features: "body_mass_g"
features: "flipper_length_mm"
features: "island"
features: "sex"
features: "year"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    num_candidate_attributes_ratio: -1
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: true
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
}

[INFO kernel.cc:800] Deployment config:
num_threads: 6

[INFO kernel.cc:837] Train model
[INFO random_forest.cc:303] Training random forest on 344 example(s) and 7 feature(s).
[INFO random_forest.cc:578] Training of tree  1/300 (tree index:0) done accuracy:0.95 logloss:1.80218
[INFO random_forest.cc:578] Training of tree  11/300 (tree index:10) done accuracy:0.96793 logloss:0.472141
[INFO random_forest.cc:578] Training of tree  21/300 (tree index:20) done accuracy:0.982558 logloss:0.258458
[INFO random_forest.cc:578] Training of tree  32/300 (tree index:31) done accuracy:0.979651 logloss:0.26222
[INFO random_forest.cc:578] Training of tree  45/300 (tree index:44) done accuracy:0.979651 logloss:0.163172
[INFO random_forest.cc:578] Training of tree  57/300 (tree index:56) done accuracy:0.976744 logloss:0.163648
[INFO random_forest.cc:578] Training of tree  69/300 (tree index:68) done accuracy:0.979651 logloss:0.164942
[INFO random_forest.cc:578] Training of tree  79/300 (tree index:78) done accuracy:0.976744 logloss:0.165327
[INFO random_forest.cc:578] Training of tree  89/300 (tree index:88) done accuracy:0.979651 logloss:0.165204
[INFO random_forest.cc:578] Training of tree  99/300 (tree index:98) done accuracy:0.976744 logloss:0.165402
[INFO random_forest.cc:578] Training of tree  110/300 (tree index:109) done accuracy:0.976744 logloss:0.164999
[INFO random_forest.cc:578] Training of tree  120/300 (tree index:119) done accuracy:0.976744 logloss:0.164786
[INFO random_forest.cc:578] Training of tree  131/300 (tree index:130) done accuracy:0.976744 logloss:0.164439
[INFO random_forest.cc:578] Training of tree  141/300 (tree index:140) done accuracy:0.976744 logloss:0.162631
[INFO random_forest.cc:578] Training of tree  153/300 (tree index:152) done accuracy:0.976744 logloss:0.163604
[INFO random_forest.cc:578] Training of tree  163/300 (tree index:162) done accuracy:0.976744 logloss:0.16199
[INFO random_forest.cc:578] Training of tree  173/300 (tree index:172) done accuracy:0.976744 logloss:0.16083
[INFO random_forest.cc:578] Training of tree  183/300 (tree index:182) done accuracy:0.976744 logloss:0.160679
[INFO random_forest.cc:578] Training of tree  193/300 (tree index:192) done accuracy:0.976744 logloss:0.160417
[INFO random_forest.cc:578] Training of tree  203/300 (tree index:202) done accuracy:0.976744 logloss:0.16071
[INFO random_forest.cc:578] Training of tree  213/300 (tree index:212) done accuracy:0.976744 logloss:0.0684506
[INFO random_forest.cc:578] Training of tree  223/300 (tree index:222) done accuracy:0.976744 logloss:0.0690158
[INFO random_forest.cc:578] Training of tree  234/300 (tree index:233) done accuracy:0.976744 logloss:0.0690293
[INFO random_forest.cc:578] Training of tree  244/300 (tree index:242) done accuracy:0.976744 logloss:0.0689818
[INFO random_forest.cc:578] Training of tree  255/300 (tree index:254) done accuracy:0.976744 logloss:0.0696632
[INFO random_forest.cc:578] Training of tree  266/300 (tree index:265) done accuracy:0.976744 logloss:0.0697246
[INFO random_forest.cc:578] Training of tree  276/300 (tree index:275) done accuracy:0.976744 logloss:0.0692516
[INFO random_forest.cc:578] Training of tree  286/300 (tree index:285) done accuracy:0.976744 logloss:0.0697065
[INFO random_forest.cc:578] Training of tree  296/300 (tree index:294) done accuracy:0.976744 logloss:0.0699951
[INFO random_forest.cc:578] Training of tree  300/300 (tree index:299) done accuracy:0.976744 logloss:0.0695934
[INFO random_forest.cc:645] Final OOB metrics: accuracy:0.976744 logloss:0.0695934
[INFO kernel.cc:856] Export model in log directory: /tmp/tmpwjaxc8md
[INFO kernel.cc:864] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 300 root(s), 5484 node(s), and 7 input feature(s).
[INFO abstract_model.cc:993] Engine "RandomForestGeneric" built
[INFO kernel.cc:848] Use fast generic engine
<keras.callbacks.History at 0x7efde7d66210>

Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the out-of-bag (OOB) variable importances during training. This is a popular permutation variable importance for Random Forest models.

Computing the OOB variable importances does not change the final model, but it can slow down training on large datasets.

Check the model summary:

%set_cell_height 300

model.summary()
<IPython.core.display.Javascript object>
Model: "random_forest_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (7):
    bill_depth_mm
    bill_length_mm
    body_mass_g
    flipper_length_mm
    island
    sex
    year

No weights

Variable Importance: MEAN_DECREASE_IN_ACCURACY:

    1.    "bill_length_mm"  0.151163 ################
    2.            "island"  0.031977 ###
    3. "flipper_length_mm"  0.005814 
    4.     "bill_depth_mm"  0.000000 
    5.       "body_mass_g"  0.000000 
    6.               "sex"  0.000000 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:

    1.    "bill_length_mm"  0.084029 ################
    2. "flipper_length_mm"  0.014254 ##
    3.     "bill_depth_mm"  0.008158 #
    4.            "island"  0.004025 
    5.              "year"  0.001334 
    6.               "sex"  0.000688 
    7.       "body_mass_g" -0.000286 

Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:

    1.    "bill_length_mm"  0.473179 ################
    2.            "island"  0.020809 
    3.     "bill_depth_mm"  0.008128 
    4.       "body_mass_g"  0.004428 
    5.              "year"  0.001919 
    6.               "sex"  0.001662 
    7. "flipper_length_mm"  0.000122 

Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:

    1.            "island"  0.002957 ################
    2.    "bill_length_mm"  0.001680 #########
    3.     "bill_depth_mm"  0.000370 ##
    4.               "sex"  0.000000 
    5.              "year"  0.000000 
    6.       "body_mass_g" -0.000065 
    7. "flipper_length_mm" -0.000126 

Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.069730 ################
    2. "flipper_length_mm"  0.008532 #
    3.            "island"  0.004489 #
    4.     "bill_depth_mm"  0.003255 
    5.              "year"  0.002210 
    6.               "sex"  0.000463 
    7.       "body_mass_g" -0.000051 

Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.099158 ################
    2.     "bill_depth_mm"  0.011456 #
    3.            "island"  0.011322 #
    4.              "year"  0.006980 #
    5.       "body_mass_g"  0.002957 
    6.               "sex"  0.000719 
    7. "flipper_length_mm"  0.000613 

Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:

    1.            "island"  0.002896 ################
    2.    "bill_length_mm"  0.001393 ########
    3.     "bill_depth_mm"  0.000220 #
    4.               "sex"  0.000000 
    5.              "year"  0.000000 
    6.       "body_mass_g" -0.000037 
    7. "flipper_length_mm" -0.000110 

Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083762 ################
    2. "flipper_length_mm"  0.014094 ##
    3.     "bill_depth_mm"  0.008042 #
    4.              "year"  0.004573 
    5.            "island"  0.004017 
    6.               "sex"  0.000705 
    7.       "body_mass_g" -0.000264 

Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.463712 ################
    2.            "island"  0.020673 
    3.     "bill_depth_mm"  0.008362 
    4.       "body_mass_g"  0.004522 
    5.              "year"  0.003970 
    6.               "sex"  0.001614 
    7. "flipper_length_mm"  0.000089 

Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:

    1.            "island"  0.002950 ################
    2.    "bill_length_mm"  0.001676 #########
    3.     "bill_depth_mm"  0.000369 ##
    4.               "sex"  0.000000 
    5.              "year"  0.000000 
    6.       "body_mass_g" -0.000064 
    7. "flipper_length_mm" -0.000126 

Variable Importance: MEAN_MIN_DEPTH:

    1.           "__LABEL"  3.592893 ################
    2.              "year"  3.578962 ###############
    3.               "sex"  3.536878 ###############
    4.       "body_mass_g"  3.021166 ############
    5.     "bill_depth_mm"  2.497914 ########
    6.            "island"  2.195246 ######
    7.    "bill_length_mm"  1.322871 
    8. "flipper_length_mm"  1.285246 

Variable Importance: NUM_AS_ROOT:

    1. "flipper_length_mm" 164.000000 ################
    2.    "bill_length_mm" 74.000000 #######
    3.     "bill_depth_mm" 44.000000 ####
    4.            "island" 15.000000 #
    5.       "body_mass_g"  3.000000 

Variable Importance: NUM_NODES:

    1.    "bill_length_mm" 847.000000 ################
    2.     "bill_depth_mm" 495.000000 #########
    3. "flipper_length_mm" 440.000000 ########
    4.            "island" 420.000000 #######
    5.       "body_mass_g" 319.000000 #####
    6.               "sex" 51.000000 
    7.              "year" 20.000000 

Variable Importance: SUM_SCORE:

    1. "flipper_length_mm" 35884.294356 ################
    2.    "bill_length_mm" 35451.100839 ###############
    3.            "island" 18086.699162 ########
    4.     "bill_depth_mm" 11832.748897 #####
    5.       "body_mass_g" 2905.187669 #
    6.               "sex" 395.149033 
    7.              "year" 44.817843 



Winner take all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695934
Number of trees: 300
Total number of nodes: 5484

Number of nodes by tree:
Count: 300 Average: 18.28 StdDev: 3.36079
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12)  5   1.67%   1.67% #
[ 12, 13)  0   0.00%   1.67%
[ 13, 14) 20   6.67%   8.33% ###
[ 14, 15)  0   0.00%   8.33%
[ 15, 16) 52  17.33%  25.67% #######
[ 16, 17)  0   0.00%  25.67%
[ 17, 18) 72  24.00%  49.67% ##########
[ 18, 19)  0   0.00%  49.67%
[ 19, 20) 62  20.67%  70.33% #########
[ 20, 21)  0   0.00%  70.33%
[ 21, 22) 55  18.33%  88.67% ########
[ 22, 23)  0   0.00%  88.67%
[ 23, 24) 20   6.67%  95.33% ###
[ 24, 25)  0   0.00%  95.33%
[ 25, 26)  6   2.00%  97.33% #
[ 26, 27)  0   0.00%  97.33%
[ 27, 28)  6   2.00%  99.33% #
[ 28, 29)  0   0.00%  99.33%
[ 29, 30)  1   0.33%  99.67%
[ 30, 31]  1   0.33% 100.00%

Depth by leafs:
Count: 2892 Average: 3.64281 StdDev: 1.06562
Min: 1 Max: 6 Ignored: 0
----------------------------------------------
[ 1, 2)   1   0.03%   0.03%
[ 2, 3) 479  16.56%  16.60% #####
[ 3, 4) 812  28.08%  44.67% #########
[ 4, 5) 948  32.78%  77.46% ##########
[ 5, 6) 564  19.50%  96.96% ######
[ 6, 6]  88   3.04% 100.00% #

Number of training obs by leaf:
Count: 2892 Average: 35.6846 StdDev: 43.0402
Min: 5 Max: 151 Ignored: 0
----------------------------------------------
[   5,  12) 1601  55.36%  55.36% ##########
[  12,  19)  163   5.64%  61.00% #
[  19,  27)   74   2.56%  63.55%
[  27,  34)   77   2.66%  66.22%
[  34,  41)   72   2.49%  68.71%
[  41,  49)  112   3.87%  72.58% #
[  49,  56)   94   3.25%  75.83% #
[  56,  63)   62   2.14%  77.97%
[  63,  71)   39   1.35%  79.32%
[  71,  78)   20   0.69%  80.01%
[  78,  85)   22   0.76%  80.77%
[  85,  93)   39   1.35%  82.12%
[  93, 100)   39   1.35%  83.47%
[ 100, 107)   61   2.11%  85.58%
[ 107, 115)  122   4.22%  89.80% #
[ 115, 122)   99   3.42%  93.22% #
[ 122, 129)   73   2.52%  95.75%
[ 129, 137)   71   2.46%  98.20%
[ 137, 144)   40   1.38%  99.59%
[ 144, 151]   12   0.41% 100.00%

Attribute in nodes:
    847 : bill_length_mm [NUMERICAL]
    495 : bill_depth_mm [NUMERICAL]
    440 : flipper_length_mm [NUMERICAL]
    420 : island [CATEGORICAL]
    319 : body_mass_g [NUMERICAL]
    51 : sex [CATEGORICAL]
    20 : year [NUMERICAL]

Attribute in nodes with depth <= 0:
    164 : flipper_length_mm [NUMERICAL]
    74 : bill_length_mm [NUMERICAL]
    44 : bill_depth_mm [NUMERICAL]
    15 : island [CATEGORICAL]
    3 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 1:
    250 : bill_length_mm [NUMERICAL]
    249 : flipper_length_mm [NUMERICAL]
    220 : island [CATEGORICAL]
    131 : bill_depth_mm [NUMERICAL]
    49 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 2:
    464 : bill_length_mm [NUMERICAL]
    372 : island [CATEGORICAL]
    344 : flipper_length_mm [NUMERICAL]
    270 : bill_depth_mm [NUMERICAL]
    154 : body_mass_g [NUMERICAL]
    11 : sex [CATEGORICAL]
    3 : year [NUMERICAL]

Attribute in nodes with depth <= 3:
    722 : bill_length_mm [NUMERICAL]
    413 : island [CATEGORICAL]
    412 : bill_depth_mm [NUMERICAL]
    399 : flipper_length_mm [NUMERICAL]
    248 : body_mass_g [NUMERICAL]
    38 : sex [CATEGORICAL]
    12 : year [NUMERICAL]

Attribute in nodes with depth <= 5:
    847 : bill_length_mm [NUMERICAL]
    495 : bill_depth_mm [NUMERICAL]
    440 : flipper_length_mm [NUMERICAL]
    420 : island [CATEGORICAL]
    319 : body_mass_g [NUMERICAL]
    51 : sex [CATEGORICAL]
    20 : year [NUMERICAL]

Condition type in nodes:
    2121 : HigherCondition
    471 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
    285 : HigherCondition
    15 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
    679 : HigherCondition
    220 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
    1235 : HigherCondition
    383 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
    1793 : HigherCondition
    451 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
    2121 : HigherCondition
    471 : ContainsBitmapCondition
Node format: NOT_SET

Training OOB:
    trees: 1, Out-of-bag evaluation: accuracy:0.95 logloss:1.80218
    trees: 11, Out-of-bag evaluation: accuracy:0.96793 logloss:0.472141
    trees: 21, Out-of-bag evaluation: accuracy:0.982558 logloss:0.258458
    trees: 32, Out-of-bag evaluation: accuracy:0.979651 logloss:0.26222
    trees: 45, Out-of-bag evaluation: accuracy:0.979651 logloss:0.163172
    trees: 57, Out-of-bag evaluation: accuracy:0.976744 logloss:0.163648
    trees: 69, Out-of-bag evaluation: accuracy:0.979651 logloss:0.164942
    trees: 79, Out-of-bag evaluation: accuracy:0.976744 logloss:0.165327
    trees: 89, Out-of-bag evaluation: accuracy:0.979651 logloss:0.165204
    trees: 99, Out-of-bag evaluation: accuracy:0.976744 logloss:0.165402
    trees: 110, Out-of-bag evaluation: accuracy:0.976744 logloss:0.164999
    trees: 120, Out-of-bag evaluation: accuracy:0.976744 logloss:0.164786
    trees: 131, Out-of-bag evaluation: accuracy:0.976744 logloss:0.164439
    trees: 141, Out-of-bag evaluation: accuracy:0.976744 logloss:0.162631
    trees: 153, Out-of-bag evaluation: accuracy:0.976744 logloss:0.163604
    trees: 163, Out-of-bag evaluation: accuracy:0.976744 logloss:0.16199
    trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.16083
    trees: 183, Out-of-bag evaluation: accuracy:0.976744 logloss:0.160679
    trees: 193, Out-of-bag evaluation: accuracy:0.976744 logloss:0.160417
    trees: 203, Out-of-bag evaluation: accuracy:0.976744 logloss:0.16071
    trees: 213, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0684506
    trees: 223, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690158
    trees: 234, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690293
    trees: 244, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689818
    trees: 255, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0696632
    trees: 266, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0697246
    trees: 276, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0692516
    trees: 286, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0697065
    trees: 296, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0699951
    trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695934

Note the multiple variable importances named MEAN_DECREASE_IN_*: they are the OOB permutation importances enabled by compute_oob_variable_importances=True.

Plotting the model

Next, plot the model.

A Random Forest is a large model (this one has 300 trees and ~5k nodes; see the summary above). Therefore, plot only the first tree, and limit the plot to a maximum depth of 3.

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)
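
If you are not running in Colab, here is a small sketch that writes the plot to a file instead; it assumes tfdf.model_plotter.plot_model returns the same plot as an HTML string:

# Sketch for non-Colab environments (assumption: plot_model returns the
# interactive plot as an HTML string). Write it to disk and open it in a
# browser.
with open("/tmp/model_tree.html", "w") as f:
  f.write(tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=3))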

Inspect the model structure

The model structure and meta-data are available through the inspector created by make_inspector().

inspector = model.make_inspector()

For our model, the available inspector fields are:

[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME',
 'dataspec',
 'evaluation',
 'export_to_tensorboard',
 'extract_all_trees',
 'extract_tree',
 'features',
 'iterate_on_nodes',
 'label',
 'label_classes',
 'model_type',
 'num_trees',
 'objective',
 'specialized_header',
 'task',
 'training_logs',
 'variable_importances',
 'winner_take_all_inference']

See the API reference, or use ? for the built-in documentation:

?inspector.model_type

Some of the model meta-data:

print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]
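
In the feature list, each entry shows the feature's column type id and column index: for example, "island" (4; #4) is a CATEGORICAL column (type 4) at index 4 of the dataspec shown during training.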

evaluation() returns the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.

inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06959344029729796, rmse=None, ndcg=None, aucs=None)
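
The result behaves like a named tuple, so the individual metrics shown above can be read directly:

# Read individual metrics from the training evaluation.
evaluation = inspector.evaluation()
print("Accuracy:", evaluation.accuracy)
print("Loss:", evaluation.loss)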

The variable importances are:

print(f"Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances:
     MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
     NUM_AS_ROOT
     SUM_SCORE
     MEAN_DECREASE_IN_AUC_2_VS_OTHERS
     NUM_NODES
     MEAN_DECREASE_IN_AP_3_VS_OTHERS
     MEAN_MIN_DEPTH
     MEAN_DECREASE_IN_AUC_3_VS_OTHERS
     MEAN_DECREASE_IN_AUC_1_VS_OTHERS
     MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
     MEAN_DECREASE_IN_AP_2_VS_OTHERS
     MEAN_DECREASE_IN_ACCURACY
     MEAN_DECREASE_IN_AP_1_VS_OTHERS
     MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS

Different variable importances have different semantics. For example, a feature with a mean decrease in AUC of 0.05 means that removing this feature from the training dataset would reduce (hurt) the AUC by 5%.

# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.06972998903508787),
 ("flipper_length_mm" (1; #3), 0.008532072368420462),
 ("island" (4; #4), 0.004488760964911687),
 ("bill_depth_mm" (1; #0), 0.0032552083333338144),
 ("year" (1; #6), 0.0022101151315789824),
 ("sex" (4; #5), 0.0004625822368422572),
 ("body_mass_g" (1; #2), -5.1398026315596645e-05)]

Finally, access the actual tree structure:

inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.349998474121094; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 16.349998474121094; miss=True), pos_child=NonLeafNode(condition=(bill_length_mm >= 48.19999694824219; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=7.0)), neg_child=LeafNode(value=ProbabilityValue([0.2, 0.0, 0.8],n=5.0)), value=ProbabilityValue([0.08333333333333333, 0.0, 0.9166666666666666],n=12.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=96.0)), value=ProbabilityValue([0.009259259259259259, 0.0, 0.9907407407407407],n=108.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005917159763313609, 0.3609467455621302, 0.6331360946745562],n=169.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.949999809265137; miss=True), pos_child=NonLeafNode(condition=(bill_length_mm >= 42.349998474121094; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 189.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=10.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), value=ProbabilityValue([0.6666666666666666, 0.3333333333333333, 0.0],n=15.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 16.650001525878906; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=137.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 36.75; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=10.0)), value=ProbabilityValue([0.9333333333333333, 0.06666666666666667, 0.0],n=15.0)), value=ProbabilityValue([0.993421052631579, 0.006578947368421052, 0.0],n=152.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=8.0)), value=ProbabilityValue([0.92, 0.03428571428571429, 0.045714285714285714],n=175.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})
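
The returned tree is a plain Python object, so it can be traversed recursively. As a sketch (assuming the tree exposes its root node as tree.root, and that non-leaf nodes expose pos_child and neg_child as in the builder API used later in this colab):

tree = inspector.extract_tree(tree_idx=0)

# Compute the maximum depth of the tree by recursing over its nodes.
def max_depth(node):
  if isinstance(node, tfdf.py_tree.node.LeafNode):
    return 0
  return 1 + max(max_depth(node.pos_child), max_depth(node.neg_child))

print("Maximum depth of tree #0:", max_depth(tree.root))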

Extracting a tree is not efficient. If speed is important, inspect the model with the iterate_on_nodes() method instead: it iterates over all the nodes of the model in a depth-first pre-order traversal.

The following example computes how many times each feature is used (a kind of structural variable importance):

# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)

# Iterate over all the nodes in a depth-first pre-order traversal.
for node_iter in inspector.iterate_on_nodes():

  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes
    continue

  # Iterate over all the features used in the condition.
  # By default, conditions are "axis-aligned", i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1

print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features:
     bill_length_mm : 847
     bill_depth_mm : 495
     flipper_length_mm : 440
     island : 420
     body_mass_g : 319
     sex : 51
     year : 20
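
As a quick follow-up sketch with the same iterator, count the leaf nodes; the result should match the 2892 leaves reported under "Depth by leafs" in the summary above:

# Count the leaf nodes of the whole forest.
num_leaves = sum(
    isinstance(node_iter.node, tfdf.py_tree.node.LeafNode)
    for node_iter in inspector.iterate_on_nodes())
print("Number of leaf nodes:", num_leaves)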

Creating a model by hand

In this section you will create a small Random Forest model by hand. To make it extra easy, the model will only contain one simple tree:

3 label classes: red, blue, and green.
2 features: f1 (numerical) and f2 (string categorical)

f1>=1.5
    ├─(pos)─ f2 in ["cat","dog"]
    │         ├─(pos)─ value: [0.8, 0.1, 0.1]
    │         └─(neg)─ value: [0.1, 0.8, 0.1]
    └─(neg)─ value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))

Trees are added one by one, using add_tree().

# Some aliases to keep the tree definition short.
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2",type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))

Close the builder to finalize and write the model:

builder.close()
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
2021-09-01 11:08:37.515994: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/manual_model/assets

Now you can open the model as a regular Keras model and make predictions:

manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
        "f1": [1.0, 2.0, 3.0],
        "f2": ["cat", "cat", "bird"]
    }).batch(2)

predictions = manual_model.predict(examples)

print("predictions:\n",predictions)
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
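
Each prediction matches a leaf of the hand-built tree: the first example (f1=1.0) fails the root condition f1>=1.5 and lands on [0.1, 0.1, 0.8]; the second (f1=2.0, f2="cat") follows both positive branches to [0.8, 0.1, 0.1]; the third (f1=3.0, f2="bird") passes the root condition but fails the categorical one, giving [0.1, 0.8, 0.1].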

Access the structure:

yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)

inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]

And of course, you can plot this manually constructed model:

tfdf.model_plotter.plot_model_in_colab(manual_model)