Inspect and debug decision forest models


In this colab, you will learn how to inspect and create the structure of a model directly. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.

In this colab, you will:

  1. Train a Random Forest model and access its structure programmatically.

  2. Create a Random Forest model by hand and use it as a classical model.

Setup

# Install TensorFlow Decision Forests.
pip install tensorflow_decision_forests

# Use wurlitzer to show the training logs.
pip install wurlitzer
import os
# Keep using Keras 2
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import tensorflow_decision_forests as tfdf

import numpy as np
import pandas as pd
import tensorflow as tf
import tf_keras
import matplotlib.pyplot as plt
import math
import collections

The hidden code cell limits the output height in colab.

Train a simple Random Forest

We train a Random Forest like in the beginner colab:

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Show the first three examples.
print(dataset_df.head(3))

# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \
0  Adelie  Torgersen            39.1           18.7              181.0   
1  Adelie  Torgersen            39.5           17.4              186.0   
2  Adelie  Torgersen            40.3           18.0              195.0   

   body_mass_g     sex  year  
0       3750.0    male  2007  
1       3800.0  female  2007  
2       3250.0  female  2007  
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
I0000 00:00:1768226920.022534  150155 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13638 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5
I0000 00:00:1768226920.024914  150155 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13756 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:06.0, compute capability: 7.5
I0000 00:00:1768226920.027104  150155 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 13756 MB memory:  -> device: 2, name: Tesla T4, pci bus id: 0000:00:07.0, compute capability: 7.5
I0000 00:00:1768226920.029256  150155 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 13756 MB memory:  -> device: 3, name: Tesla T4, pci bus id: 0000:00:08.0, compute capability: 7.5
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmp1133vua7 as temporary training directory
Reading training dataset...
Training dataset read in 0:00:03.771078. Found 344 examples.
Training model...
Model trained in 0:00:00.097883
Compiling model...
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1768226924.178818  150155 kernel.cc:782] Start Yggdrasil model training
I0000 00:00:1768226924.178870  150155 kernel.cc:783] Collect training examples
I0000 00:00:1768226924.178879  150155 kernel.cc:795] Dataspec guide:
column_guides {
  column_name_pattern: "^__LABEL$"
  type: CATEGORICAL
  categorial {
    min_vocab_frequency: 0
    max_vocab_count: -1
  }
}
default_column_guide {
  categorial {
    max_vocab_count: 2000
  }
  discretized_numerical {
    maximum_num_bins: 255
  }
}
ignore_columns_without_guides: false
detect_numerical_as_discretized_numerical: false

I0000 00:00:1768226924.179251  150155 kernel.cc:401] Number of batches: 1
I0000 00:00:1768226924.179269  150155 kernel.cc:402] Number of examples: 344
I0000 00:00:1768226924.179368  150155 kernel.cc:802] Training dataset:
Number of records: 344
Number of columns: 8

Number of columns by type:
    NUMERICAL: 5 (62.5%)
    CATEGORICAL: 3 (37.5%)

Columns:

NUMERICAL: 5 (62.5%)
    1: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719
    2: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516
    3: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781
    4: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411
    7: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166

CATEGORICAL: 3 (37.5%)
    0: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item
    5: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%)
    6: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%)

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

I0000 00:00:1768226924.179391  150155 kernel.cc:818] Configure learner
I0000 00:00:1768226924.179572  150155 kernel.cc:831] Training config:
learner: "RANDOM_FOREST"
features: "^bill_depth_mm$"
features: "^bill_length_mm$"
features: "^body_mass_g$"
features: "^flipper_length_mm$"
features: "^island$"
features: "^sex$"
features: "^year$"
label: "^__LABEL$"
task: CLASSIFICATION
random_seed: 123456
metadata {
  framework: "TF Keras"
}
pure_serving_model: false
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    keep_non_leaf_label_distribution: true
    num_candidate_attributes: 0
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
    uplift {
      min_examples_in_treatment: 5
      split_score: KULLBACK_LEIBLER
    }
    numerical_vector_sequence {
      max_num_test_examples: 1000
      num_random_selected_anchors: 100
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: true
  num_oob_variable_importances_permutations: 1
  bootstrap_training_dataset: true
  bootstrap_size_ratio: 1
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
  sampling_with_replacement: true
}

I0000 00:00:1768226924.179932  150155 kernel.cc:834] Deployment config:
cache_path: "/tmpfs/tmp/tmp1133vua7/working_cache"
num_threads: 32
try_resume_training: true

I0000 00:00:1768226924.180042  150380 kernel.cc:895] Train model
I0000 00:00:1768226924.180148  150380 random_forest.cc:438] Training random forest on 344 example(s) and 7 feature(s).
I0000 00:00:1768226924.184195  150380 gpu.cc:93] Cannot initialize GPU: Not compiled with GPU support
I0000 00:00:1768226924.186044  150389 random_forest.cc:865] Train tree 1/300 accuracy:0.944882 logloss:1.98666 [index:2 total:0.00s tree:0.00s]
I0000 00:00:1768226924.186216  150398 random_forest.cc:865] Train tree 11/300 accuracy:0.945813 logloss:1.77896 [index:10 total:0.00s tree:0.00s]
I0000 00:00:1768226924.186533  150409 random_forest.cc:865] Train tree 24/300 accuracy:0.965986 logloss:0.76054 [index:23 total:0.00s tree:0.00s]
I0000 00:00:1768226924.186858  150414 random_forest.cc:865] Train tree 34/300 accuracy:0.965732 logloss:0.477947 [index:29 total:0.00s tree:0.00s]
I0000 00:00:1768226924.188161  150399 random_forest.cc:865] Train tree 45/300 accuracy:0.965116 logloss:0.578036 [index:44 total:0.00s tree:0.00s]
I0000 00:00:1768226924.190016  150408 random_forest.cc:865] Train tree 55/300 accuracy:0.97093 logloss:0.376297 [index:54 total:0.01s tree:0.00s]
I0000 00:00:1768226924.191278  150415 random_forest.cc:865] Train tree 65/300 accuracy:0.97093 logloss:0.378839 [index:64 total:0.01s tree:0.00s]
I0000 00:00:1768226924.193038  150393 random_forest.cc:865] Train tree 77/300 accuracy:0.973837 logloss:0.275567 [index:76 total:0.01s tree:0.00s]
I0000 00:00:1768226924.196059  150416 random_forest.cc:865] Train tree 98/300 accuracy:0.976744 logloss:0.17131 [index:97 total:0.01s tree:0.00s]
I0000 00:00:1768226924.197309  150417 random_forest.cc:865] Train tree 108/300 accuracy:0.976744 logloss:0.166666 [index:107 total:0.01s tree:0.00s]
I0000 00:00:1768226924.199854  150419 random_forest.cc:865] Train tree 125/300 accuracy:0.976744 logloss:0.0725837 [index:124 total:0.02s tree:0.00s]
I0000 00:00:1768226924.201389  150416 random_forest.cc:865] Train tree 136/300 accuracy:0.976744 logloss:0.0721071 [index:135 total:0.02s tree:0.00s]
I0000 00:00:1768226924.203526  150402 random_forest.cc:865] Train tree 150/300 accuracy:0.976744 logloss:0.0716672 [index:149 total:0.02s tree:0.00s]
I0000 00:00:1768226924.205494  150418 random_forest.cc:865] Train tree 163/300 accuracy:0.976744 logloss:0.0723785 [index:162 total:0.02s tree:0.00s]
I0000 00:00:1768226924.206989  150395 random_forest.cc:865] Train tree 173/300 accuracy:0.976744 logloss:0.0718518 [index:172 total:0.02s tree:0.00s]
I0000 00:00:1768226924.208726  150402 random_forest.cc:865] Train tree 184/300 accuracy:0.976744 logloss:0.0708714 [index:183 total:0.02s tree:0.00s]
I0000 00:00:1768226924.210282  150413 random_forest.cc:865] Train tree 194/300 accuracy:0.976744 logloss:0.0697211 [index:193 total:0.03s tree:0.00s]
I0000 00:00:1768226924.211915  150393 random_forest.cc:865] Train tree 204/300 accuracy:0.976744 logloss:0.0684952 [index:203 total:0.03s tree:0.00s]
I0000 00:00:1768226924.213450  150411 random_forest.cc:865] Train tree 215/300 accuracy:0.976744 logloss:0.0681587 [index:214 total:0.03s tree:0.00s]
I0000 00:00:1768226924.215064  150388 random_forest.cc:865] Train tree 225/300 accuracy:0.976744 logloss:0.0688663 [index:224 total:0.03s tree:0.00s]
I0000 00:00:1768226924.216633  150409 random_forest.cc:865] Train tree 237/300 accuracy:0.976744 logloss:0.068661 [index:236 total:0.03s tree:0.00s]
I0000 00:00:1768226924.218987  150408 random_forest.cc:865] Train tree 251/300 accuracy:0.976744 logloss:0.0680787 [index:250 total:0.03s tree:0.00s]
I0000 00:00:1768226924.221144  150390 random_forest.cc:865] Train tree 267/300 accuracy:0.976744 logloss:0.0685897 [index:266 total:0.04s tree:0.00s]
I0000 00:00:1768226924.223877  150394 random_forest.cc:865] Train tree 287/300 accuracy:0.976744 logloss:0.0681651 [index:286 total:0.04s tree:0.00s]
I0000 00:00:1768226924.225839  150395 random_forest.cc:865] Train tree 300/300 accuracy:0.976744 logloss:0.0676584 [index:299 total:0.04s tree:0.00s]
I0000 00:00:1768226924.231081  150380 random_forest.cc:949] Final OOB metrics: accuracy:0.976744 logloss:0.0676584
I0000 00:00:1768226924.232352  150380 feature_importance.cc:196] Running 8 features on 32 threads with 1 rounds
I0000 00:00:1768226924.236343  150380 kernel.cc:926] Export model in log directory: /tmpfs/tmp/tmp1133vua7 with prefix f016df2d741d4761
I0000 00:00:1768226924.239996  150380 kernel.cc:944] Save model in resources
I0000 00:00:1768226924.242495  150155 abstract_model.cc:921] Model self evaluation:
Number of predictions (without weights): 344
Number of predictions (with weights): 344
Task: CLASSIFICATION
Label: __LABEL

Accuracy: 0.976744  CI95[W][0.958431 0.988377]
LogLoss: : 0.0676584
ErrorRate: : 0.0232558

Default Accuracy: : 0.44186
Default LogLoss: : 1.04916
Default ErrorRate: : 0.55814

Confusion Table:
truth\prediction
     1   2    3
1  148   3    1
2    2  66    0
3    2   0  122
Total: 344


WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1768226924.267848  150155 decision_forest.cc:808] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s).
I0000 00:00:1768226924.271050  150155 abstract_model.cc:1439] Engine "RandomForestGeneric" built
Model compiled.
<tf_keras.src.callbacks.History at 0x7fed709d28b0>

Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the out-of-bag (OOB) variable importances during training. These are a popular permutation-based variable importance for Random Forest models.

Computing the OOB variable importances does not affect the final model, but it slows down training on large datasets.
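A permutation variable importance measures how much a quality metric drops when one feature's values are shuffled across examples, which breaks the link between that feature and the label. The following is a simplified plain-Python sketch of the mean decrease in accuracy. It uses a hypothetical toy model in place of the Random Forest, and it does not model TF-DF's per-tree out-of-bag evaluation. The helper names are illustrative only.

```python
import random

def accuracy(predict, rows, labels):
    """Fraction of rows whose predicted label matches the true label."""
    correct = sum(predict(row) == label for row, label in zip(rows, labels))
    return correct / len(rows)

def permutation_importance(predict, rows, labels, feature, seed=0):
    """Decrease in accuracy when `feature` is shuffled across rows.

    Simplified sketch: a single permutation round, evaluated on the rows given
    (TF-DF instead evaluates on each tree's out-of-bag examples).
    """
    baseline = accuracy(predict, rows, labels)
    shuffled_values = [row[feature] for row in rows]
    random.Random(seed).shuffle(shuffled_values)
    permuted_rows = [{**row, feature: value}
                     for row, value in zip(rows, shuffled_values)]
    return baseline - accuracy(predict, permuted_rows, labels)

# Hypothetical toy model: it only looks at "bill_length_mm".
def toy_predict(row):
    return "Gentoo" if row["bill_length_mm"] >= 43.25 else "Adelie"

rows = [{"bill_length_mm": 35.0 + i, "year": 2007 + i % 3} for i in range(20)]
labels = [toy_predict(row) for row in rows]  # The toy model is perfect here.

for feature in ["bill_length_mm", "year"]:
    print(feature, permutation_importance(toy_predict, rows, labels, feature))
```

Because the toy model ignores "year", shuffling that feature leaves the accuracy unchanged, so its importance is zero. Shuffling "bill_length_mm" almost surely lowers the accuracy.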

Check the model summary:

%set_cell_height 300

model.summary()
Model: "random_forest_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1 (1.00 Byte)
Trainable params: 0 (0.00 Byte)
Non-trainable params: 1 (1.00 Byte)
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (7):
    bill_depth_mm
    bill_length_mm
    body_mass_g
    flipper_length_mm
    island
    sex
    year

No weights

Variable Importance: INV_MEAN_MIN_DEPTH:

    1. "flipper_length_mm"  0.440513 ################
    2.    "bill_length_mm"  0.438028 ###############
    3.     "bill_depth_mm"  0.299751 #####
    4.            "island"  0.295079 #####
    5.       "body_mass_g"  0.256534 ##
    6.               "sex"  0.225708 
    7.              "year"  0.224020 

Variable Importance: MEAN_DECREASE_IN_ACCURACY:

    1.    "bill_length_mm"  0.151163 ################
    2.            "island"  0.008721 #
    3.     "bill_depth_mm"  0.000000 
    4.       "body_mass_g"  0.000000 
    5.               "sex"  0.000000 
    6.              "year"  0.000000 
    7. "flipper_length_mm" -0.002907 

Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083305 ################
    2.            "island"  0.007664 #
    3. "flipper_length_mm"  0.003400 
    4.     "bill_depth_mm"  0.002741 
    5.       "body_mass_g"  0.000722 
    6.               "sex"  0.000644 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:

    1.    "bill_length_mm"  0.508510 ################
    2.            "island"  0.023487 
    3.     "bill_depth_mm"  0.007744 
    4. "flipper_length_mm"  0.006008 
    5.       "body_mass_g"  0.003017 
    6.               "sex"  0.001537 
    7.              "year" -0.000245 

Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:

    1.            "island"  0.002192 ################
    2.    "bill_length_mm"  0.001572 ############
    3.     "bill_depth_mm"  0.000497 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000890 

Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.071306 ################
    2.            "island"  0.007299 #
    3. "flipper_length_mm"  0.004506 #
    4.     "bill_depth_mm"  0.002124 
    5.       "body_mass_g"  0.000548 
    6.               "sex"  0.000480 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.108642 ################
    2.            "island"  0.014493 ##
    3.     "bill_depth_mm"  0.007406 #
    4. "flipper_length_mm"  0.005195 
    5.       "body_mass_g"  0.001012 
    6.               "sex"  0.000480 
    7.              "year" -0.000053 

Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:

    1.            "island"  0.002126 ################
    2.    "bill_length_mm"  0.001393 ###########
    3.     "bill_depth_mm"  0.000293 #####
    4.               "sex"  0.000000 ###
    5.              "year"  0.000000 ###
    6.       "body_mass_g" -0.000037 ###
    7. "flipper_length_mm" -0.000550 

Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083122 ################
    2.            "island"  0.010887 ##
    3. "flipper_length_mm"  0.003425 
    4.     "bill_depth_mm"  0.002731 
    5.       "body_mass_g"  0.000719 
    6.               "sex"  0.000641 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.497611 ################
    2.            "island"  0.024045 
    3.     "bill_depth_mm"  0.007734 
    4. "flipper_length_mm"  0.006017 
    5.       "body_mass_g"  0.003000 
    6.               "sex"  0.001528 
    7.              "year" -0.000243 

Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:

    1.            "island"  0.002187 ################
    2.    "bill_length_mm"  0.001568 ############
    3.     "bill_depth_mm"  0.000495 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000886 

Variable Importance: NUM_AS_ROOT:

    1. "flipper_length_mm" 157.000000 ################
    2.    "bill_length_mm" 76.000000 #######
    3.     "bill_depth_mm" 52.000000 #####
    4.            "island" 12.000000 
    5.       "body_mass_g"  3.000000 

Variable Importance: NUM_NODES:

    1.    "bill_length_mm" 778.000000 ################
    2.     "bill_depth_mm" 463.000000 #########
    3. "flipper_length_mm" 414.000000 ########
    4.            "island" 342.000000 ######
    5.       "body_mass_g" 338.000000 ######
    6.               "sex" 36.000000 
    7.              "year" 19.000000 

Variable Importance: SUM_SCORE:

    1.    "bill_length_mm" 36515.793787 ################
    2. "flipper_length_mm" 35120.434174 ###############
    3.            "island" 14669.408395 ######
    4.     "bill_depth_mm" 14515.446617 ######
    5.       "body_mass_g" 3485.330881 #
    6.               "sex" 354.201073 
    7.              "year" 49.737758 



Winner takes all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.0676584
Number of trees: 300
Total number of nodes: 5080

Number of nodes by tree:
Count: 300 Average: 16.9333 StdDev: 3.10197
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12)  6   2.00%   2.00% #
[ 12, 13)  0   0.00%   2.00%
[ 13, 14) 46  15.33%  17.33% #####
[ 14, 15)  0   0.00%  17.33%
[ 15, 16) 70  23.33%  40.67% ########
[ 16, 17)  0   0.00%  40.67%
[ 17, 18) 84  28.00%  68.67% ##########
[ 18, 19)  0   0.00%  68.67%
[ 19, 20) 46  15.33%  84.00% #####
[ 20, 21)  0   0.00%  84.00%
[ 21, 22) 30  10.00%  94.00% ####
[ 22, 23)  0   0.00%  94.00%
[ 23, 24) 13   4.33%  98.33% ##
[ 24, 25)  0   0.00%  98.33%
[ 25, 26)  2   0.67%  99.00%
[ 26, 27)  0   0.00%  99.00%
[ 27, 28)  2   0.67%  99.67%
[ 28, 29)  0   0.00%  99.67%
[ 29, 30)  0   0.00%  99.67%
[ 30, 31]  1   0.33% 100.00%

Depth by leafs:
Count: 2690 Average: 3.53271 StdDev: 1.06789
Min: 2 Max: 7 Ignored: 0
----------------------------------------------
[ 2, 3) 545  20.26%  20.26% ######
[ 3, 4) 747  27.77%  48.03% ########
[ 4, 5) 888  33.01%  81.04% ##########
[ 5, 6) 444  16.51%  97.55% #####
[ 6, 7)  62   2.30%  99.85% #
[ 7, 7]   4   0.15% 100.00%

Number of training obs by leaf:
Count: 2690 Average: 38.3643 StdDev: 44.8651
Min: 5 Max: 155 Ignored: 0
----------------------------------------------
[   5,  12) 1474  54.80%  54.80% ##########
[  12,  20)  124   4.61%  59.41% #
[  20,  27)   48   1.78%  61.19%
[  27,  35)   74   2.75%  63.94% #
[  35,  42)   58   2.16%  66.10%
[  42,  50)   85   3.16%  69.26% #
[  50,  57)   96   3.57%  72.83% #
[  57,  65)   87   3.23%  76.06% #
[  65,  72)   49   1.82%  77.88%
[  72,  80)   23   0.86%  78.74%
[  80,  88)   30   1.12%  79.85%
[  88,  95)   23   0.86%  80.71%
[  95, 103)   42   1.56%  82.27%
[ 103, 110)   62   2.30%  84.57%
[ 110, 118)  115   4.28%  88.85% #
[ 118, 125)  115   4.28%  93.12% #
[ 125, 133)   98   3.64%  96.77% #
[ 133, 140)   49   1.82%  98.59%
[ 140, 148)   31   1.15%  99.74%
[ 148, 155]    7   0.26% 100.00%

Attribute in nodes:
    778 : bill_length_mm [NUMERICAL]
    463 : bill_depth_mm [NUMERICAL]
    414 : flipper_length_mm [NUMERICAL]
    342 : island [CATEGORICAL]
    338 : body_mass_g [NUMERICAL]
    36 : sex [CATEGORICAL]
    19 : year [NUMERICAL]

Attribute in nodes with depth <= 0:
    157 : flipper_length_mm [NUMERICAL]
    76 : bill_length_mm [NUMERICAL]
    52 : bill_depth_mm [NUMERICAL]
    12 : island [CATEGORICAL]
    3 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 1:
    250 : bill_length_mm [NUMERICAL]
    244 : flipper_length_mm [NUMERICAL]
    183 : bill_depth_mm [NUMERICAL]
    170 : island [CATEGORICAL]
    53 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 2:
    462 : bill_length_mm [NUMERICAL]
    320 : flipper_length_mm [NUMERICAL]
    310 : bill_depth_mm [NUMERICAL]
    287 : island [CATEGORICAL]
    162 : body_mass_g [NUMERICAL]
    9 : sex [CATEGORICAL]
    5 : year [NUMERICAL]

Attribute in nodes with depth <= 3:
    669 : bill_length_mm [NUMERICAL]
    410 : bill_depth_mm [NUMERICAL]
    383 : flipper_length_mm [NUMERICAL]
    328 : island [CATEGORICAL]
    286 : body_mass_g [NUMERICAL]
    32 : sex [CATEGORICAL]
    10 : year [NUMERICAL]

Attribute in nodes with depth <= 5:
    778 : bill_length_mm [NUMERICAL]
    462 : bill_depth_mm [NUMERICAL]
    413 : flipper_length_mm [NUMERICAL]
    342 : island [CATEGORICAL]
    338 : body_mass_g [NUMERICAL]
    36 : sex [CATEGORICAL]
    19 : year [NUMERICAL]

Condition type in nodes:
    2012 : HigherCondition
    378 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
    288 : HigherCondition
    12 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
    730 : HigherCondition
    170 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
    1259 : HigherCondition
    296 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
    1758 : HigherCondition
    360 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
    2010 : HigherCondition
    378 : ContainsBitmapCondition
Node format: NOT_SET

Training OOB:
    trees: 1, Out-of-bag evaluation: accuracy:0.944882 logloss:1.98666
    trees: 11, Out-of-bag evaluation: accuracy:0.945813 logloss:1.77896
    trees: 24, Out-of-bag evaluation: accuracy:0.965986 logloss:0.76054
    trees: 34, Out-of-bag evaluation: accuracy:0.965732 logloss:0.477947
    trees: 45, Out-of-bag evaluation: accuracy:0.965116 logloss:0.578036
    trees: 55, Out-of-bag evaluation: accuracy:0.97093 logloss:0.376297
    trees: 65, Out-of-bag evaluation: accuracy:0.97093 logloss:0.378839
    trees: 77, Out-of-bag evaluation: accuracy:0.973837 logloss:0.275567
    trees: 98, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17131
    trees: 108, Out-of-bag evaluation: accuracy:0.976744 logloss:0.166666
    trees: 125, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0725837
    trees: 136, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0721071
    trees: 150, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0716672
    trees: 163, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0723785
    trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0718518
    trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708714
    trees: 194, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0697211
    trees: 204, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0684952
    trees: 215, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0681587
    trees: 225, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0688663
    trees: 237, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068661
    trees: 251, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0680787
    trees: 267, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0685897
    trees: 287, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0681651
    trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0676584

Note the multiple variable importances named MEAN_DECREASE_IN_*. These are the OOB permutation importances enabled by compute_oob_variable_importances=True.

Plotting the model

Next, plot the model.

A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, only plot the first tree, and limit the nodes to depth 3.
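To illustrate what a depth limit does to a tree display, here is a plain-Python text renderer that collapses every node deeper than max_depth into "...". It is a hypothetical sketch on a nested-dict tree, not the TF-DF plotter, and it makes no claim about how plot_model_in_colab counts depth.

```python
def render_tree(node, max_depth, depth=0):
    """Return the lines of a text rendering of `node`, truncated at `max_depth`.

    Nodes at depth >= max_depth are collapsed into a single "..." line.
    """
    indent = "    " * depth
    if depth >= max_depth:
        return [indent + "..."]
    if "value" in node:  # Leaf node.
        return [f"{indent}value: {node['value']}"]
    lines = [indent + node["condition"]]
    lines += render_tree(node["pos"], max_depth, depth + 1)
    lines += render_tree(node["neg"], max_depth, depth + 1)
    return lines

# Hypothetical toy tree stored as nested dicts.
toy_tree = {
    "condition": "f1 >= 1.5",
    "pos": {
        "condition": 'f2 in ["cat", "dog"]',
        "pos": {"value": [0.8, 0.1, 0.1]},
        "neg": {"value": [0.1, 0.8, 0.1]},
    },
    "neg": {"value": [0.1, 0.1, 0.8]},
}

print("\n".join(render_tree(toy_tree, max_depth=1)))
print("\n".join(render_tree(toy_tree, max_depth=3)))
```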

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)

Inspect the model structure

The model structure and meta-data are available through the inspector created by make_inspector().

inspector = model.make_inspector()

For our model, the available inspector fields are:

[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME',
 'dataspec',
 'directory',
 'evaluation',
 'export_to_tensorboard',
 'extract_all_trees',
 'extract_tree',
 'features',
 'file_prefix',
 'header',
 'iterate_on_nodes',
 'label',
 'label_classes',
 'metadata',
 'model_type',
 'num_trees',
 'objective',
 'specialized_header',
 'task',
 'training_logs',
 'tuning_logs',
 'variable_importances',
 'winner_take_all_inference']

Refer to the API reference, or use ? for the built-in documentation.

?inspector.model_type

Some of the model meta-data:

print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #1), "bill_length_mm" (1; #2), "body_mass_g" (1; #3), "flipper_length_mm" (1; #4), "island" (4; #5), "sex" (4; #6), "year" (1; #7)]

evaluation() returns the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.

inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06765840018471313, rmse=None, ndcg=None, aucs=None, auuc=None, qini=None)
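The training config above shows bootstrap_training_dataset: true and sampling_with_replacement: true. Each tree is therefore trained on a bootstrap sample, and the examples that tree never saw are its out-of-bag examples. The OOB evaluation reported here uses these examples. The following sketch uses Python's random module as a stand-in for TF-DF's sampler to estimate what fraction of the 344 examples ends up out-of-bag for each tree. The expected fraction is (1 - 1/n)^n, which is close to 1/e, or about 0.368.

```python
import math
import random

def out_of_bag_indices(num_examples, rng):
    """Draw a bootstrap sample (with replacement) and return the unused indices."""
    in_bag = {rng.randrange(num_examples) for _ in range(num_examples)}
    return [i for i in range(num_examples) if i not in in_bag]

rng = random.Random(42)
num_examples = 344  # Size of the penguins dataset.
# One bootstrap sample per tree, for 300 trees.
oob_fractions = [len(out_of_bag_indices(num_examples, rng)) / num_examples
                 for _ in range(300)]
mean_oob = sum(oob_fractions) / len(oob_fractions)

print(f"Average out-of-bag fraction: {mean_oob:.3f}")
print(f"Theoretical (1 - 1/n)^n: {(1 - 1 / num_examples) ** num_examples:.3f}")
print(f"Limit 1/e: {math.exp(-1):.3f}")
```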

The variable importances are:

print("Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances:
     MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
     MEAN_DECREASE_IN_AP_2_VS_OTHERS
     MEAN_DECREASE_IN_AP_3_VS_OTHERS
     MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS
     MEAN_DECREASE_IN_AP_1_VS_OTHERS
     NUM_AS_ROOT
     MEAN_DECREASE_IN_AUC_1_VS_OTHERS
     MEAN_DECREASE_IN_AUC_3_VS_OTHERS
     INV_MEAN_MIN_DEPTH
     MEAN_DECREASE_IN_ACCURACY
     NUM_NODES
     MEAN_DECREASE_IN_AUC_2_VS_OTHERS
     SUM_SCORE
     MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS

Different variable importances have different semantics. For example, a mean decrease in AUC of 0.05 means that randomly permuting the values of this feature, which makes it uninformative, lowers the model's AUC by 0.05 (for example, from 0.95 to 0.90).

# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #2), 0.0713061951754389),
 ("island" (4; #5), 0.007298519736842035),
 ("flipper_length_mm" (1; #4), 0.004505893640351366),
 ("bill_depth_mm" (1; #1), 0.0021244517543865804),
 ("body_mass_g" (1; #3), 0.0005482456140351033),
 ("sex" (4; #6), 0.00047971491228060437),
 ("year" (1; #7), 0.0)]
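A "class X vs. others" AUC treats class X as positive and every other class as negative, then scores the model's predicted probability of X. The following is a plain-Python sketch using the pairwise (Mann-Whitney) definition of AUC, applied to made-up probabilities before and after a hypothetical feature permutation. TF-DF may compute AUC differently internally, but the pairwise definition gives the same value.

```python
def one_vs_others_auc(scores, labels, positive_class):
    """AUC of `positive_class` vs. all other classes.

    `scores[i]` is the predicted probability of `positive_class` for example i.
    AUC is the probability that a random positive example scores higher than a
    random negative one, with ties counting as 1/2.
    """
    positives = [s for s, y in zip(scores, labels) if y == positive_class]
    negatives = [s for s, y in zip(scores, labels) if y != positive_class]
    if not positives or not negatives:
        raise ValueError("Need at least one positive and one negative example.")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))

labels = ["Adelie", "Gentoo", "Adelie", "Chinstrap"]
# Made-up probabilities of "Adelie" before and after permuting some feature.
scores_before = [0.9, 0.2, 0.6, 0.6]
scores_after = [0.5, 0.6, 0.6, 0.2]

auc_before = one_vs_others_auc(scores_before, labels, "Adelie")
auc_after = one_vs_others_auc(scores_after, labels, "Adelie")
print("AUC before:", auc_before)                   # 0.875
print("AUC after:", auc_after)                     # 0.625
print("Decrease in AUC:", auc_before - auc_after)  # 0.25
```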

Plot the variable importances from the inspector using Matplotlib:

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

# Mean decrease in AUC of the class 1 vs the others.
variable_importance_metric = "MEAN_DECREASE_IN_AUC_1_VS_OTHERS"
variable_importances = inspector.variable_importances()[variable_importance_metric]

# Extract the feature name and importance values.
#
# `variable_importances` is a list of <feature, importance> tuples.
feature_names = [vi[0].name for vi in variable_importances]
feature_importances = [vi[1] for vi in variable_importances]
# The features are ordered by decreasing importance value.
feature_ranks = range(len(feature_names))

bar = plt.barh(feature_ranks, feature_importances, label=[str(x) for x in feature_ranks])
plt.yticks(feature_ranks, feature_names)
plt.gca().invert_yaxis()

# TODO: Replace with "plt.bar_label()" when available.
# Label each bar with values
for importance, patch in zip(feature_importances, bar.patches):
  plt.text(patch.get_x() + patch.get_width(), patch.get_y(), f"{importance:.4f}", va="top")

plt.xlabel(variable_importance_metric)
plt.title("Mean decrease in AUC of the class 1 vs the others")
plt.tight_layout()
plt.show()


Finally, access the actual tree structure:

inspector.extract_tree(tree_idx=0)
Tree(root=NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True, score=0.5482327342033386), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True, score=0.6515106558799744), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False, score=0.027205035090446472), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0), idx=7), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0), idx=6), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0), idx=5), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True, score=0.150658518075943), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True, score=0.036139510571956635), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0), idx=4), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True, score=0.23430533707141876), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0), idx=3), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True, score=0.043961383402347565), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0), idx=2), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0), idx=1), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0), idx=0), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)), label_classes=None)
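The printed tree can be read as a sequence of tests from the root down to a leaf. In this sketch, the tree above is transcribed by hand into hypothetical nested dicts (not the tfdf.py_tree classes), and an example is routed to its leaf. Two interpretations are assumptions: pos_child is followed when the condition holds, and miss gives the branch taken when the feature value is missing. The printout does not name the classes (label_classes=None), so the leaf probability vectors are simply in the model's internal class order.

```python
def leaf(idx, value):
    return {"leaf": idx, "value": value}

def higher(feature, threshold, miss, pos, neg):
    """Numerical condition: feature >= threshold."""
    return {"feature": feature, "threshold": threshold, "miss": miss,
            "pos": pos, "neg": neg}

def contains(feature, items, miss, pos, neg):
    """Categorical condition: feature in items."""
    return {"feature": feature, "items": items, "miss": miss,
            "pos": pos, "neg": neg}

def predict_leaf(node, example):
    """Follow conditions down to a leaf; return (leaf idx, probabilities)."""
    while "leaf" not in node:
        value = example.get(node["feature"])
        if value is None:
            goes_positive = node["miss"]  # Assumed meaning of `miss`.
        elif "threshold" in node:
            goes_positive = value >= node["threshold"]
        else:
            goes_positive = value in node["items"]
        node = node["pos"] if goes_positive else node["neg"]
    return node["leaf"], node["value"]

# Tree 0, transcribed from the printout above.
tree_0 = higher("bill_length_mm", 43.25, True,
    pos=contains("island", ["Biscoe"], True,
        pos=higher("bill_depth_mm", 17.225584030151367, False,
            pos=leaf(7, [0.16666666666666666, 0.0, 0.8333333333333334]),
            neg=leaf(6, [0.0, 0.0, 1.0])),
        neg=leaf(5, [0.0, 1.0, 0.0])),
    neg=higher("bill_depth_mm", 15.100000381469727, True,
        pos=higher("flipper_length_mm", 187.5, True,
            pos=leaf(4, [1.0, 0.0, 0.0]),
            neg=higher("bill_length_mm", 42.30000305175781, True,
                pos=leaf(3, [0.0, 1.0, 0.0]),
                neg=higher("bill_length_mm", 40.55000305175781, True,
                    pos=leaf(2, [0.8, 0.2, 0.0]),
                    neg=leaf(1, [1.0, 0.0, 0.0])))),
        neg=leaf(0, [0.0, 0.0, 1.0])))

# The first penguin of the dataset (island Torgersen).
first_penguin = {"bill_length_mm": 39.1, "bill_depth_mm": 18.7,
                 "flipper_length_mm": 181.0, "island": "Torgersen"}
print(predict_leaf(tree_0, first_penguin))  # (1, [1.0, 0.0, 0.0])
```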

Extracting a tree is not efficient. If speed matters, inspect the model with the iterate_on_nodes() method instead. This method iterates over all the nodes of the model in depth-first pre-order.
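To make "depth-first pre-order" concrete, here is a minimal, self-contained sketch. It uses hypothetical `Leaf` and `Split` classes, not the TF-DF node classes. Each node is visited before its children, and a whole subtree is exhausted before its sibling is visited. Visiting the positive child before the negative one is a choice made for this sketch, not a documented guarantee of `iterate_on_nodes()`.

```python
from dataclasses import dataclass
from typing import Iterator, Tuple, Union


@dataclass
class Leaf:
  """Hypothetical leaf node holding a prediction label."""
  value: str


@dataclass
class Split:
  """Hypothetical non-leaf node with a condition and two children."""
  condition: str
  pos_child: "Node"
  neg_child: "Node"


Node = Union[Leaf, Split]


def iterate_preorder(node: Node, depth: int = 0) -> Iterator[Tuple[int, Node]]:
  """Yields (depth, node) pairs: parent first, then its full subtrees."""
  yield depth, node
  if isinstance(node, Split):
    yield from iterate_preorder(node.pos_child, depth + 1)
    yield from iterate_preorder(node.neg_child, depth + 1)


tree = Split("f1 >= 1.5",
             pos_child=Split("f2 in [cat, dog]", Leaf("A"), Leaf("B")),
             neg_child=Leaf("C"))

for depth, node in iterate_preorder(tree):
  label = node.condition if isinstance(node, Split) else "leaf " + node.value
  print("  " * depth + label)
```

Because a generator yields nodes lazily, no intermediate tree object has to be built. This is the same reason an iterator is cheaper than extracting every tree.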

The following example computes how many times each feature is used in a condition (a kind of structural variable importance):

# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)

# Iterate over all the nodes in depth-first pre-order.
for node_iter in inspector.iterate_on_nodes():

  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes
    continue

  # Iterate over all the features used in the condition.
  # By default, conditions are axis-aligned, i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1

print("Number of condition nodes per feature:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per feature:
     bill_length_mm : 778
     bill_depth_mm : 463
     flipper_length_mm : 414
     island : 342
     body_mass_g : 338
     year : 19
     sex : 36
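With the default axis-aligned conditions, each condition tests exactly one feature. The counts therefore add up to the number of non-leaf nodes in the whole forest, and each count is easier to compare as a share of that total. The following minimal sketch hard-codes the counts printed above. Those counts come from this particular trained model and will differ between runs. The dictionary is named `counts` to avoid overwriting `number_of_use`.

```python
# Counts copied from the output above.
counts = {
    "bill_length_mm": 778,
    "bill_depth_mm": 463,
    "flipper_length_mm": 414,
    "island": 342,
    "body_mass_g": 338,
    "year": 19,
    "sex": 36,
}

# One condition per non-leaf node, so this is the forest's non-leaf node count.
total = sum(counts.values())
print("Total condition nodes:", total)  # Total condition nodes: 2390

# Print features by decreasing count, with their share of all conditions.
for name, count in sorted(counts.items(), key=lambda kv: -kv[1]):
  print(f"\t{name}: {count} ({100 * count / total:.1f}%)")
```

Read this way, `bill_length_mm` alone accounts for roughly a third of all splits, while `year` and `sex` are rarely used.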

Creating a model by hand

In this section you will create a small Random Forest model by hand. To keep things simple, the model contains a single tree with:

3 label classes: red, blue and green.
2 features: f1 (numerical) and f2 (string categorical).

f1>=1.5
    ├─(pos) f2 in ["cat","dog"]
             ├─(pos) value: [0.8, 0.1, 0.1]
             └─(neg) value: [0.1, 0.8, 0.1]
    └─(neg) value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))

Each tree is added one by one.

# Some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2", type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))

Finalize the model by closing the builder, which writes it to disk:

builder.close()
I0000 00:00:1768226925.405440  150155 decision_forest.cc:808] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
INFO:tensorflow:Assets written to: /tmp/manual_model/assets

Now you can open the model as a regular Keras model and make predictions:

manual_model = tf_keras.models.load_model("/tmp/manual_model")
I0000 00:00:1768226926.188144  150328 decision_forest.cc:808] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
examples = tf.data.Dataset.from_tensor_slices({
        "f1": [1.0, 2.0, 3.0],
        "f2": ["cat", "cat", "bird"]
    }).batch(2)

predictions = manual_model.predict(examples)

print("predictions:\n",predictions)
2/2 [==============================] - 0s 3ms/step
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
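These predictions can be checked against the tree diagram by hand. With a single tree, the model's output is simply the value of the leaf each example reaches. The following is a plain-Python sketch of the same routing. `predict_manual_tree` is a hypothetical helper, not part of TF-DF, and it does not handle missing feature values.

```python
def predict_manual_tree(f1: float, f2: str) -> list:
  """Plain-Python version of the hand-built tree (classes: red, blue, green)."""
  if f1 >= 1.5:               # Root condition: f1 >= 1.5.
    if f2 in ("cat", "dog"):  # Categorical condition: f2 in ["cat", "dog"].
      return [0.8, 0.1, 0.1]
    return [0.1, 0.8, 0.1]
  return [0.1, 0.1, 0.8]


# Same three examples as the tf.data.Dataset above.
for f1, f2 in zip([1.0, 2.0, 3.0], ["cat", "cat", "bird"]):
  print(f1, f2, predict_manual_tree(f1, f2))
```

The first example fails the root condition, the second reaches the `["cat", "dog"]` branch, and the third ("bird") falls to the negative child of the categorical condition. This matches the rows returned by `manual_model.predict`.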

You can also inspect the structure of the manually created model:

yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)

inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]

And of course, you can plot this manually constructed model:

tfdf.model_plotter.plot_model_in_colab(manual_model)