Inspect and debug decision forest models


In this colab, you will learn how to inspect a model's structure and how to create one directly. We assume you are familiar with the concepts introduced in the beginner and intermediate colabs.

In this colab, you will:

  1. Train a Random Forest model and access its structure programmatically.

  2. Create a Random Forest model by hand and use it as a classical model.

Setup

# Install TensorFlow Decision Forests.
!pip install tensorflow_decision_forests

# Use wurlitzer to capture training logs.
!pip install wurlitzer
import tensorflow_decision_forests as tfdf

import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import collections

try:
  from wurlitzer import sys_pipes
except ImportError:
  from colabtools.googlelog import CaptureLog as sys_pipes

from IPython.core.magic import register_line_magic
from IPython.display import Javascript
WARNING:root:Failure to load the custom c++ tensorflow ops. This error is likely caused the version of TensorFlow and TensorFlow Decision Forests are not compatible.
WARNING:root:TF Parameter Server distributed training not available.

A hidden code cell (not shown here) defines the %set_cell_height magic, which limits the output height in colab.
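The hidden cell itself is not reproduced here. Below is a minimal sketch of a cell that defines a %set_cell_height line magic. It assumes that the Colab output frame exposes the JavaScript function google.colab.output.setIframeHeight. The names set_cell_height_js and set_cell_height are illustrative. Outside an IPython kernel, the magic is simply not registered.

```python
try:
    from IPython import get_ipython
    from IPython.display import Javascript, display
except ImportError:  # Plain Python without IPython: nothing to register.
    get_ipython = None


def set_cell_height_js(size):
    """Returns the JavaScript snippet that caps the output iframe height (in pixels)."""
    # Assumption: Colab exposes google.colab.output.setIframeHeight in the output frame.
    return ("google.colab.output.setIframeHeight(0, true, {maxHeight: "
            + str(int(size)) + "})")


def set_cell_height(line):
    """Body of the `%set_cell_height <pixels>` line magic."""
    display(Javascript(set_cell_height_js(line.strip())))


# Register the magic only when running inside an IPython/Colab kernel.
shell = get_ipython() if get_ipython is not None else None
if shell is not None:
    shell.register_magic_function(set_cell_height, magic_kind="line")
```

Once the magic is registered, a cell that starts with %set_cell_height 300 caps its output at 300 pixels. The model summary below uses it this way.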

Train a simple Random Forest

Train a Random Forest as in the beginner colab:

# Download the dataset
!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv

# Load a dataset into a Pandas Dataframe.
dataset_df = pd.read_csv("/tmp/penguins.csv")

# Show the first three examples.
print(dataset_df.head(3))

# Convert the pandas dataframe into a tf dataset.
dataset_tf = tfdf.keras.pd_dataframe_to_tf_dataset(dataset_df, label="species")

# Train the Random Forest
model = tfdf.keras.RandomForestModel(compute_oob_variable_importances=True)
model.fit(x=dataset_tf)
species     island  bill_length_mm  bill_depth_mm  flipper_length_mm  \
0  Adelie  Torgersen            39.1           18.7              181.0   
1  Adelie  Torgersen            39.5           17.4              186.0   
2  Adelie  Torgersen            40.3           18.0              195.0   

   body_mass_g     sex  year  
0       3750.0    male  2007  
1       3800.0  female  2007  
2       3250.0  female  2007
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_decision_forests/keras/core.py:1612: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  features_dataframe = dataframe.drop(label, 1)
6/6 [==============================] - 4s 17ms/step
[INFO kernel.cc:736] Start Yggdrasil model training
[INFO kernel.cc:737] Collect training examples
[INFO kernel.cc:392] Number of batches: 6
[INFO kernel.cc:393] Number of examples: 344
[INFO kernel.cc:759] Dataset:
Number of records: 344
Number of columns: 8

Number of columns by type:
    NUMERICAL: 5 (62.5%)
    CATEGORICAL: 3 (37.5%)

Columns:

NUMERICAL: 5 (62.5%)
    0: "bill_depth_mm" NUMERICAL num-nas:2 (0.581395%) mean:17.1512 min:13.1 max:21.5 sd:1.9719
    1: "bill_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:43.9219 min:32.1 max:59.6 sd:5.4516
    2: "body_mass_g" NUMERICAL num-nas:2 (0.581395%) mean:4201.75 min:2700 max:6300 sd:800.781
    3: "flipper_length_mm" NUMERICAL num-nas:2 (0.581395%) mean:200.915 min:172 max:231 sd:14.0411
    6: "year" NUMERICAL mean:2008.03 min:2007 max:2009 sd:0.817166

CATEGORICAL: 3 (37.5%)
    4: "island" CATEGORICAL has-dict vocab-size:4 zero-ood-items most-frequent:"Biscoe" 168 (48.8372%)
    5: "sex" CATEGORICAL num-nas:11 (3.19767%) has-dict vocab-size:3 zero-ood-items most-frequent:"male" 168 (50.4505%)
    7: "__LABEL" CATEGORICAL integerized vocab-size:4 no-ood-item

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute which type is manually defined by the user i.e. the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.

[INFO kernel.cc:762] Configure learner
[INFO kernel.cc:787] Training config:
learner: "RANDOM_FOREST"
features: "bill_depth_mm"
features: "bill_length_mm"
features: "body_mass_g"
features: "flipper_length_mm"
features: "island"
features: "sex"
features: "year"
label: "__LABEL"
task: CLASSIFICATION
[yggdrasil_decision_forests.model.random_forest.proto.random_forest_config] {
  num_trees: 300
  decision_tree {
    max_depth: 16
    min_examples: 5
    in_split_min_examples_check: true
    missing_value_policy: GLOBAL_IMPUTATION
    allow_na_conditions: false
    categorical_set_greedy_forward {
      sampling: 0.1
      max_num_items: -1
      min_item_frequency: 1
    }
    growing_strategy_local {
    }
    categorical {
      cart {
      }
    }
    num_candidate_attributes_ratio: -1
    axis_aligned_split {
    }
    internal {
      sorting_strategy: PRESORTED
    }
  }
  winner_take_all_inference: true
  compute_oob_performances: true
  compute_oob_variable_importances: true
  adapt_bootstrap_size_ratio_for_maximum_training_duration: false
}

[INFO kernel.cc:790] Deployment config:
num_threads: 6

[INFO kernel.cc:817] Train model
[INFO random_forest.cc:315] Training random forest on 344 example(s) and 7 feature(s).
[INFO random_forest.cc:628] Training of tree  1/300 (tree index:0) done accuracy:0.964286 logloss:1.28727
[INFO random_forest.cc:628] Training of tree  11/300 (tree index:10) done accuracy:0.956268 logloss:0.584301
[INFO random_forest.cc:628] Training of tree  22/300 (tree index:21) done accuracy:0.965116 logloss:0.378823
[INFO random_forest.cc:628] Training of tree  35/300 (tree index:34) done accuracy:0.968023 logloss:0.178185
[INFO random_forest.cc:628] Training of tree  46/300 (tree index:45) done accuracy:0.973837 logloss:0.170304
[INFO random_forest.cc:628] Training of tree  58/300 (tree index:57) done accuracy:0.973837 logloss:0.171223
[INFO random_forest.cc:628] Training of tree  70/300 (tree index:69) done accuracy:0.979651 logloss:0.169564
[INFO random_forest.cc:628] Training of tree  83/300 (tree index:82) done accuracy:0.976744 logloss:0.17074
[INFO random_forest.cc:628] Training of tree  96/300 (tree index:95) done accuracy:0.976744 logloss:0.0736925
[INFO random_forest.cc:628] Training of tree  106/300 (tree index:105) done accuracy:0.976744 logloss:0.0748649
[INFO random_forest.cc:628] Training of tree  117/300 (tree index:116) done accuracy:0.976744 logloss:0.074671
[INFO random_forest.cc:628] Training of tree  130/300 (tree index:129) done accuracy:0.976744 logloss:0.0736275
[INFO random_forest.cc:628] Training of tree  140/300 (tree index:139) done accuracy:0.976744 logloss:0.0727718
[INFO random_forest.cc:628] Training of tree  152/300 (tree index:151) done accuracy:0.976744 logloss:0.0715068
[INFO random_forest.cc:628] Training of tree  162/300 (tree index:161) done accuracy:0.976744 logloss:0.0708994
[INFO random_forest.cc:628] Training of tree  173/300 (tree index:172) done accuracy:0.976744 logloss:0.069447
[INFO random_forest.cc:628] Training of tree  184/300 (tree index:183) done accuracy:0.976744 logloss:0.0695926
[INFO random_forest.cc:628] Training of tree  195/300 (tree index:194) done accuracy:0.976744 logloss:0.0690138
[INFO random_forest.cc:628] Training of tree  205/300 (tree index:204) done accuracy:0.976744 logloss:0.0694597
[INFO random_forest.cc:628] Training of tree  217/300 (tree index:216) done accuracy:0.976744 logloss:0.068122
[INFO random_forest.cc:628] Training of tree  229/300 (tree index:228) done accuracy:0.976744 logloss:0.0687641
[INFO random_forest.cc:628] Training of tree  239/300 (tree index:238) done accuracy:0.976744 logloss:0.067988
[INFO random_forest.cc:628] Training of tree  250/300 (tree index:249) done accuracy:0.976744 logloss:0.0690187
[INFO random_forest.cc:628] Training of tree  260/300 (tree index:259) done accuracy:0.976744 logloss:0.0690134
[INFO random_forest.cc:628] Training of tree  270/300 (tree index:269) done accuracy:0.976744 logloss:0.0689877
[INFO random_forest.cc:628] Training of tree  280/300 (tree index:279) done accuracy:0.976744 logloss:0.0689845
[INFO random_forest.cc:628] Training of tree  290/300 (tree index:288) done accuracy:0.976744 logloss:0.0690742
[INFO random_forest.cc:628] Training of tree  300/300 (tree index:299) done accuracy:0.976744 logloss:0.068949
[INFO random_forest.cc:696] Final OOB metrics: accuracy:0.976744 logloss:0.068949
[INFO kernel.cc:828] Export model in log directory: /tmp/tmpoqki9pfl
[INFO kernel.cc:836] Save model in resources
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 300 root(s), 5080 node(s), and 7 input feature(s).
[INFO abstract_model.cc:993] Engine "RandomForestGeneric" built
[INFO kernel.cc:848] Use fast generic engine
<keras.callbacks.History at 0x7f09eaa9cb90>

Note the compute_oob_variable_importances=True hyper-parameter in the model constructor. This option computes the Out-of-bag (OOB) variable importance during training. This is a popular permutation variable importance for Random Forest models.

Computing the OOB variable importance does not influence the final model, but it slows down training on large datasets.
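The following NumPy sketch makes permutation variable importance concrete. It shuffles one feature column at a time and measures the mean decrease in accuracy. This is a simplified illustration, not the TF-DF implementation. TF-DF computes the importance on out-of-bag examples during training. This sketch instead uses a single dataset and a hypothetical predictor, toy_predict, that only reads feature 0.

```python
import numpy as np


def permutation_importance(predict, x, y, num_repeats=5, seed=0):
    """Mean decrease in accuracy when each column of x is randomly shuffled."""
    rng = np.random.default_rng(seed)
    baseline = np.mean(predict(x) == y)
    importances = np.zeros(x.shape[1])
    for j in range(x.shape[1]):
        drops = []
        for _ in range(num_repeats):
            x_shuffled = x.copy()
            # Shuffling breaks the link between feature j and the label.
            x_shuffled[:, j] = rng.permutation(x_shuffled[:, j])
            drops.append(baseline - np.mean(predict(x_shuffled) == y))
        importances[j] = np.mean(drops)
    return importances


# Hypothetical data: the label only depends on feature 0.
rng = np.random.default_rng(42)
x = rng.normal(size=(200, 2))
y = (x[:, 0] > 0).astype(int)


def toy_predict(features):
    """Hypothetical model that only looks at feature 0."""
    return (features[:, 0] > 0).astype(int)


importances = permutation_importance(toy_predict, x, y)
print(importances)
```

Feature 1 gets an importance of exactly 0 because the predictor never reads it. Feature 0 gets roughly 0.5 because accuracy falls to about chance level. With a real model, shuffling can also improve accuracy by chance. This is how slightly negative values, such as the one for flipper_length_mm in the summary, can arise.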

Check the model summary:

%set_cell_height 300

model.summary()
Model: "random_forest_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "RANDOM_FOREST"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (7):
    bill_depth_mm
    bill_length_mm
    body_mass_g
    flipper_length_mm
    island
    sex
    year

No weights

Variable Importance: MEAN_DECREASE_IN_ACCURACY:

    1.    "bill_length_mm"  0.151163 ################
    2.            "island"  0.008721 #
    3.     "bill_depth_mm"  0.000000 
    4.       "body_mass_g"  0.000000 
    5.               "sex"  0.000000 
    6.              "year"  0.000000 
    7. "flipper_length_mm" -0.002907 

Variable Importance: MEAN_DECREASE_IN_AP_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083305 ################
    2.            "island"  0.007664 #
    3. "flipper_length_mm"  0.003400 
    4.     "bill_depth_mm"  0.002741 
    5.       "body_mass_g"  0.000722 
    6.               "sex"  0.000644 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AP_2_VS_OTHERS:

    1.    "bill_length_mm"  0.508510 ################
    2.            "island"  0.023487 
    3.     "bill_depth_mm"  0.007744 
    4. "flipper_length_mm"  0.006008 
    5.       "body_mass_g"  0.003017 
    6.               "sex"  0.001537 
    7.              "year" -0.000245 

Variable Importance: MEAN_DECREASE_IN_AP_3_VS_OTHERS:

    1.            "island"  0.002192 ################
    2.    "bill_length_mm"  0.001572 ############
    3.     "bill_depth_mm"  0.000497 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000890 

Variable Importance: MEAN_DECREASE_IN_AUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.071306 ################
    2.            "island"  0.007299 #
    3. "flipper_length_mm"  0.004506 #
    4.     "bill_depth_mm"  0.002124 
    5.       "body_mass_g"  0.000548 
    6.               "sex"  0.000480 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_AUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.108642 ################
    2.            "island"  0.014493 ##
    3.     "bill_depth_mm"  0.007406 #
    4. "flipper_length_mm"  0.005195 
    5.       "body_mass_g"  0.001012 
    6.               "sex"  0.000480 
    7.              "year" -0.000053 

Variable Importance: MEAN_DECREASE_IN_AUC_3_VS_OTHERS:

    1.            "island"  0.002126 ################
    2.    "bill_length_mm"  0.001393 ###########
    3.     "bill_depth_mm"  0.000293 #####
    4.               "sex"  0.000000 ###
    5.              "year"  0.000000 ###
    6.       "body_mass_g" -0.000037 ###
    7. "flipper_length_mm" -0.000550 

Variable Importance: MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS:

    1.    "bill_length_mm"  0.083122 ################
    2.            "island"  0.010887 ##
    3. "flipper_length_mm"  0.003425 
    4.     "bill_depth_mm"  0.002731 
    5.       "body_mass_g"  0.000719 
    6.               "sex"  0.000641 
    7.              "year"  0.000000 

Variable Importance: MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS:

    1.    "bill_length_mm"  0.497611 ################
    2.            "island"  0.024045 
    3.     "bill_depth_mm"  0.007734 
    4. "flipper_length_mm"  0.006017 
    5.       "body_mass_g"  0.003000 
    6.               "sex"  0.001528 
    7.              "year" -0.000243 

Variable Importance: MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS:

    1.            "island"  0.002187 ################
    2.    "bill_length_mm"  0.001568 ############
    3.     "bill_depth_mm"  0.000495 #######
    4.               "sex"  0.000000 ####
    5.              "year"  0.000000 ####
    6.       "body_mass_g" -0.000053 ####
    7. "flipper_length_mm" -0.000886 

Variable Importance: MEAN_MIN_DEPTH:

    1.           "__LABEL"  3.479602 ################
    2.              "year"  3.463891 ###############
    3.               "sex"  3.430498 ###############
    4.       "body_mass_g"  2.898112 ###########
    5.            "island"  2.388925 ########
    6.     "bill_depth_mm"  2.336100 #######
    7.    "bill_length_mm"  1.282960 
    8. "flipper_length_mm"  1.270079 

Variable Importance: NUM_AS_ROOT:

    1. "flipper_length_mm" 157.000000 ################
    2.    "bill_length_mm" 76.000000 #######
    3.     "bill_depth_mm" 52.000000 #####
    4.            "island" 12.000000 
    5.       "body_mass_g"  3.000000 

Variable Importance: NUM_NODES:

    1.    "bill_length_mm" 778.000000 ################
    2.     "bill_depth_mm" 463.000000 #########
    3. "flipper_length_mm" 414.000000 ########
    4.            "island" 342.000000 ######
    5.       "body_mass_g" 338.000000 ######
    6.               "sex" 36.000000 
    7.              "year" 19.000000 

Variable Importance: SUM_SCORE:

    1.    "bill_length_mm" 36515.793787 ################
    2. "flipper_length_mm" 35120.434174 ###############
    3.            "island" 14669.408395 ######
    4.     "bill_depth_mm" 14515.446617 ######
    5.       "body_mass_g" 3485.330881 #
    6.               "sex" 354.201073 
    7.              "year" 49.737758 



Winner take all: true
Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949
Number of trees: 300
Total number of nodes: 5080

Number of nodes by tree:
Count: 300 Average: 16.9333 StdDev: 3.10197
Min: 11 Max: 31 Ignored: 0
----------------------------------------------
[ 11, 12)  6   2.00%   2.00% #
[ 12, 13)  0   0.00%   2.00%
[ 13, 14) 46  15.33%  17.33% #####
[ 14, 15)  0   0.00%  17.33%
[ 15, 16) 70  23.33%  40.67% ########
[ 16, 17)  0   0.00%  40.67%
[ 17, 18) 84  28.00%  68.67% ##########
[ 18, 19)  0   0.00%  68.67%
[ 19, 20) 46  15.33%  84.00% #####
[ 20, 21)  0   0.00%  84.00%
[ 21, 22) 30  10.00%  94.00% ####
[ 22, 23)  0   0.00%  94.00%
[ 23, 24) 13   4.33%  98.33% ##
[ 24, 25)  0   0.00%  98.33%
[ 25, 26)  2   0.67%  99.00%
[ 26, 27)  0   0.00%  99.00%
[ 27, 28)  2   0.67%  99.67%
[ 28, 29)  0   0.00%  99.67%
[ 29, 30)  0   0.00%  99.67%
[ 30, 31]  1   0.33% 100.00%

Depth by leafs:
Count: 2690 Average: 3.53271 StdDev: 1.06789
Min: 2 Max: 7 Ignored: 0
----------------------------------------------
[ 2, 3) 545  20.26%  20.26% ######
[ 3, 4) 747  27.77%  48.03% ########
[ 4, 5) 888  33.01%  81.04% ##########
[ 5, 6) 444  16.51%  97.55% #####
[ 6, 7)  62   2.30%  99.85% #
[ 7, 7]   4   0.15% 100.00%

Number of training obs by leaf:
Count: 2690 Average: 38.3643 StdDev: 44.8651
Min: 5 Max: 155 Ignored: 0
----------------------------------------------
[   5,  12) 1474  54.80%  54.80% ##########
[  12,  20)  124   4.61%  59.41% #
[  20,  27)   48   1.78%  61.19%
[  27,  35)   74   2.75%  63.94% #
[  35,  42)   58   2.16%  66.10%
[  42,  50)   85   3.16%  69.26% #
[  50,  57)   96   3.57%  72.83% #
[  57,  65)   87   3.23%  76.06% #
[  65,  72)   49   1.82%  77.88%
[  72,  80)   23   0.86%  78.74%
[  80,  88)   30   1.12%  79.85%
[  88,  95)   23   0.86%  80.71%
[  95, 103)   42   1.56%  82.27%
[ 103, 110)   62   2.30%  84.57%
[ 110, 118)  115   4.28%  88.85% #
[ 118, 125)  115   4.28%  93.12% #
[ 125, 133)   98   3.64%  96.77% #
[ 133, 140)   49   1.82%  98.59%
[ 140, 148)   31   1.15%  99.74%
[ 148, 155]    7   0.26% 100.00%

Attribute in nodes:
    778 : bill_length_mm [NUMERICAL]
    463 : bill_depth_mm [NUMERICAL]
    414 : flipper_length_mm [NUMERICAL]
    342 : island [CATEGORICAL]
    338 : body_mass_g [NUMERICAL]
    36 : sex [CATEGORICAL]
    19 : year [NUMERICAL]

Attribute in nodes with depth <= 0:
    157 : flipper_length_mm [NUMERICAL]
    76 : bill_length_mm [NUMERICAL]
    52 : bill_depth_mm [NUMERICAL]
    12 : island [CATEGORICAL]
    3 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 1:
    250 : bill_length_mm [NUMERICAL]
    244 : flipper_length_mm [NUMERICAL]
    183 : bill_depth_mm [NUMERICAL]
    170 : island [CATEGORICAL]
    53 : body_mass_g [NUMERICAL]

Attribute in nodes with depth <= 2:
    462 : bill_length_mm [NUMERICAL]
    320 : flipper_length_mm [NUMERICAL]
    310 : bill_depth_mm [NUMERICAL]
    287 : island [CATEGORICAL]
    162 : body_mass_g [NUMERICAL]
    9 : sex [CATEGORICAL]
    5 : year [NUMERICAL]

Attribute in nodes with depth <= 3:
    669 : bill_length_mm [NUMERICAL]
    410 : bill_depth_mm [NUMERICAL]
    383 : flipper_length_mm [NUMERICAL]
    328 : island [CATEGORICAL]
    286 : body_mass_g [NUMERICAL]
    32 : sex [CATEGORICAL]
    10 : year [NUMERICAL]

Attribute in nodes with depth <= 5:
    778 : bill_length_mm [NUMERICAL]
    462 : bill_depth_mm [NUMERICAL]
    413 : flipper_length_mm [NUMERICAL]
    342 : island [CATEGORICAL]
    338 : body_mass_g [NUMERICAL]
    36 : sex [CATEGORICAL]
    19 : year [NUMERICAL]

Condition type in nodes:
    2012 : HigherCondition
    378 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
    288 : HigherCondition
    12 : ContainsBitmapCondition
Condition type in nodes with depth <= 1:
    730 : HigherCondition
    170 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
    1259 : HigherCondition
    296 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
    1758 : HigherCondition
    360 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
    2010 : HigherCondition
    378 : ContainsBitmapCondition
Node format: NOT_SET

Training OOB:
    trees: 1, Out-of-bag evaluation: accuracy:0.964286 logloss:1.28727
    trees: 11, Out-of-bag evaluation: accuracy:0.956268 logloss:0.584301
    trees: 22, Out-of-bag evaluation: accuracy:0.965116 logloss:0.378823
    trees: 35, Out-of-bag evaluation: accuracy:0.968023 logloss:0.178185
    trees: 46, Out-of-bag evaluation: accuracy:0.973837 logloss:0.170304
    trees: 58, Out-of-bag evaluation: accuracy:0.973837 logloss:0.171223
    trees: 70, Out-of-bag evaluation: accuracy:0.979651 logloss:0.169564
    trees: 83, Out-of-bag evaluation: accuracy:0.976744 logloss:0.17074
    trees: 96, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736925
    trees: 106, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0748649
    trees: 117, Out-of-bag evaluation: accuracy:0.976744 logloss:0.074671
    trees: 130, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0736275
    trees: 140, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0727718
    trees: 152, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0715068
    trees: 162, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0708994
    trees: 173, Out-of-bag evaluation: accuracy:0.976744 logloss:0.069447
    trees: 184, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0695926
    trees: 195, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690138
    trees: 205, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0694597
    trees: 217, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068122
    trees: 229, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0687641
    trees: 239, Out-of-bag evaluation: accuracy:0.976744 logloss:0.067988
    trees: 250, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690187
    trees: 260, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690134
    trees: 270, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689877
    trees: 280, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0689845
    trees: 290, Out-of-bag evaluation: accuracy:0.976744 logloss:0.0690742
    trees: 300, Out-of-bag evaluation: accuracy:0.976744 logloss:0.068949

Note the multiple variable importances whose names start with MEAN_DECREASE_IN_*.
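These names combine a metric (ACCURACY, AP, AUC, PRAUC) with a class index for the one-vs-others metrics. The sketch below groups the names printed in the summary by metric. The naming pattern is inferred from that output and is an assumption, not a documented format. The helper group_permutation_importances is hypothetical.

```python
import re

# Naming pattern inferred from the model summary (an assumption, not a documented format):
# MEAN_DECREASE_IN_<METRIC> or MEAN_DECREASE_IN_<METRIC>_<CLASS>_VS_OTHERS.
PATTERN = re.compile(r"^MEAN_DECREASE_IN_([A-Z]+)(?:_(\d+)_VS_OTHERS)?$")


def group_permutation_importances(names):
    """Maps each metric name to the sorted class indices it is reported for."""
    groups = {}
    for name in names:
        match = PATTERN.match(name)
        if match is None:
            continue  # Structural importances such as NUM_NODES or MEAN_MIN_DEPTH.
        metric, class_index = match.groups()
        indices = groups.setdefault(metric, [])  # Global metrics keep an empty list.
        if class_index is not None:
            indices.append(int(class_index))
    return {metric: sorted(indices) for metric, indices in groups.items()}


# Importance names as printed for this model.
names = [
    "MEAN_DECREASE_IN_AUC_3_VS_OTHERS", "NUM_AS_ROOT", "MEAN_DECREASE_IN_AUC_2_VS_OTHERS",
    "MEAN_DECREASE_IN_AP_2_VS_OTHERS", "MEAN_DECREASE_IN_ACCURACY", "SUM_SCORE",
    "MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS", "MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS",
    "MEAN_DECREASE_IN_AP_3_VS_OTHERS", "MEAN_DECREASE_IN_AUC_1_VS_OTHERS", "MEAN_MIN_DEPTH",
    "MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS", "NUM_NODES", "MEAN_DECREASE_IN_AP_1_VS_OTHERS",
]
print(group_permutation_importances(names))
```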

Plotting the model

Next, plot the model.

A Random Forest is a large model (this model has 300 trees and ~5k nodes; see the summary above). Therefore, plot only the first tree, and limit the nodes to depth 3.

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=3)
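Limiting the depth keeps the plot readable because hidden subtrees collapse into a single placeholder. The same idea can be shown with a small text renderer. This sketch uses a hypothetical nested-dict tree (condition, pos and neg keys for non-leaf nodes; a value key for leaves), not the TF-DF tree format. Depth is counted from 0 at the root. The exact convention of plot_model_in_colab may differ.

```python
def render_tree(node, max_depth, depth=0, label="root"):
    """Returns one text line per visible node; subtrees deeper than max_depth are elided."""
    indent = "    " * depth
    if "value" in node:  # Leaf.
        return [f"{indent}({label}) value: {node['value']}"]
    if depth >= max_depth:  # Non-leaf node at the depth limit: hide its children.
        return [f"{indent}({label}) {node['condition']} [...]"]
    lines = [f"{indent}({label}) {node['condition']}"]
    lines += render_tree(node["pos"], max_depth, depth + 1, "pos")
    lines += render_tree(node["neg"], max_depth, depth + 1, "neg")
    return lines


# Hypothetical tree: non-leaf nodes hold a condition and two children, leaves hold a value.
tree = {
    "condition": "x1 >= 2.0",
    "pos": {
        "condition": "x2 in ['a', 'b']",
        "pos": {"value": [0.9, 0.1]},
        "neg": {"value": [0.3, 0.7]},
    },
    "neg": {"value": [0.2, 0.8]},
}
print("\n".join(render_tree(tree, max_depth=1)))
```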

Inspect the model structure

The model structure and metadata are available through the inspector created by make_inspector().

inspector = model.make_inspector()

For our model, the available inspector fields are:

[field for field in dir(inspector) if not field.startswith("_")]
['MODEL_NAME',
 'dataspec',
 'evaluation',
 'export_to_tensorboard',
 'extract_all_trees',
 'extract_tree',
 'features',
 'iterate_on_nodes',
 'label',
 'label_classes',
 'model_type',
 'num_trees',
 'objective',
 'specialized_header',
 'task',
 'training_logs',
 'variable_importances',
 'winner_take_all_inference']

Remember to check the API reference, or use ? for the built-in documentation.

?inspector.model_type

Some of the model metadata:

print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())
Model type: RANDOM_FOREST
Number of trees: 300
Objective: Classification(label=__LABEL, class=None, num_classes=3)
Input features: ["bill_depth_mm" (1; #0), "bill_length_mm" (1; #1), "body_mass_g" (1; #2), "flipper_length_mm" (1; #3), "island" (4; #4), "sex" (4; #5), "year" (1; #6)]

evaluation() returns the evaluation of the model computed during training. The dataset used for this evaluation depends on the algorithm. For example, it can be the validation dataset or the out-of-bag dataset.

inspector.evaluation()
Evaluation(num_examples=344, accuracy=0.9767441860465116, loss=0.06894904488784283, rmse=None, ndcg=None, aucs=None)
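For a Random Forest, the out-of-bag examples of a tree are the training examples that its bootstrap sample did not draw. Each tree can therefore be evaluated on data it never saw. The NumPy sketch below assumes a bootstrap sample of the same size as the dataset, drawn with replacement, and uses a hypothetical helper oob_mask. It shows that roughly 1/e ≈ 36.8% of the examples are out-of-bag for each tree.

```python
import numpy as np


def oob_mask(num_examples, rng):
    """Draws one bootstrap sample and returns a boolean mask of the out-of-bag examples."""
    sampled = rng.integers(0, num_examples, size=num_examples)  # With replacement.
    mask = np.ones(num_examples, dtype=bool)
    mask[sampled] = False  # Examples drawn at least once are "in-bag".
    return mask


rng = np.random.default_rng(0)
num_examples = 344  # Size of the penguins dataset used above.
fractions = [oob_mask(num_examples, rng).mean() for _ in range(300)]
print(f"Average OOB fraction over 300 trees: {np.mean(fractions):.3f}")
print(f"1/e = {np.exp(-1):.3f}")
```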

The variable importances are:

print("Available variable importances:")
for importance in inspector.variable_importances().keys():
  print("\t", importance)
Available variable importances:
     MEAN_DECREASE_IN_AUC_3_VS_OTHERS
     NUM_AS_ROOT
     MEAN_DECREASE_IN_AUC_2_VS_OTHERS
     MEAN_DECREASE_IN_AP_2_VS_OTHERS
     MEAN_DECREASE_IN_ACCURACY
     SUM_SCORE
     MEAN_DECREASE_IN_PRAUC_2_VS_OTHERS
     MEAN_DECREASE_IN_PRAUC_3_VS_OTHERS
     MEAN_DECREASE_IN_AP_3_VS_OTHERS
     MEAN_DECREASE_IN_AUC_1_VS_OTHERS
     MEAN_MIN_DEPTH
     MEAN_DECREASE_IN_PRAUC_1_VS_OTHERS
     NUM_NODES
     MEAN_DECREASE_IN_AP_1_VS_OTHERS

Different variable importances have different semantics. For example, a mean decrease in AUC of 0.05 for a feature means that randomly shuffling the values of this feature (making it uninformative, without retraining the model) reduces the model's AUC by 0.05, i.e. 5 percentage points.

# Mean decrease in AUC of the class 1 vs the others.
inspector.variable_importances()["MEAN_DECREASE_IN_AUC_1_VS_OTHERS"]
[("bill_length_mm" (1; #1), 0.0713061951754389),
 ("island" (4; #4), 0.007298519736842035),
 ("flipper_length_mm" (1; #3), 0.004505893640351366),
 ("bill_depth_mm" (1; #0), 0.0021244517543865804),
 ("body_mass_g" (1; #2), 0.0005482456140351033),
 ("sex" (4; #5), 0.00047971491228060437),
 ("year" (1; #6), 0.0)]
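A one-vs-others AUC, such as MEAN_DECREASE_IN_AUC_1_VS_OTHERS, treats one class as positive and all the other classes as negative. The sketch below computes an AUC as the probability that a random positive example scores above a random negative one, with ties counting one half. It then takes the difference between two hypothetical score vectors, for example before and after shuffling a feature. binary_auc and all the numbers are illustrative.

```python
import numpy as np


def binary_auc(is_positive, scores):
    """AUC as P(score of a random positive > score of a random negative); ties count 1/2."""
    is_positive = np.asarray(is_positive, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[is_positive], scores[~is_positive]
    wins = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (wins + 0.5 * ties) / (pos.size * neg.size)


# Hypothetical labels for 3 classes; class 1 is the positive class ("1 vs others").
labels = np.array([1, 2, 1, 3, 2, 1])
is_class_1 = labels == 1
scores_before = np.array([0.9, 0.2, 0.7, 0.1, 0.6, 0.8])  # P(class 1) with the original feature.
scores_after = np.array([0.9, 0.2, 0.5, 0.1, 0.6, 0.8])   # P(class 1) after shuffling a feature.

decrease = binary_auc(is_class_1, scores_before) - binary_auc(is_class_1, scores_after)
print(f"Decrease in AUC: {decrease:.3f}")
```

In this toy case, one positive example drops below one negative example. One of the 9 positive/negative pairs is then misordered, and the AUC falls by 1/9.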

Finally, access the actual tree structure:

inspector.extract_tree(tree_idx=0)
Tree(NonLeafNode(condition=(bill_length_mm >= 43.25; miss=True), pos_child=NonLeafNode(condition=(island in ['Biscoe']; miss=True), pos_child=NonLeafNode(condition=(bill_depth_mm >= 17.225584030151367; miss=False), pos_child=LeafNode(value=ProbabilityValue([0.16666666666666666, 0.0, 0.8333333333333334],n=6.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=104.0)), value=ProbabilityValue([0.00909090909090909, 0.0, 0.990909090909091],n=110.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=61.0)), value=ProbabilityValue([0.005847953216374269, 0.3567251461988304, 0.6374269005847953],n=171.0)), neg_child=NonLeafNode(condition=(bill_depth_mm >= 15.100000381469727; miss=True), pos_child=NonLeafNode(condition=(flipper_length_mm >= 187.5; miss=True), pos_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=104.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 42.30000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.0, 1.0, 0.0],n=5.0)), neg_child=NonLeafNode(condition=(bill_length_mm >= 40.55000305175781; miss=True), pos_child=LeafNode(value=ProbabilityValue([0.8, 0.2, 0.0],n=5.0)), neg_child=LeafNode(value=ProbabilityValue([1.0, 0.0, 0.0],n=53.0)), value=ProbabilityValue([0.9827586206896551, 0.017241379310344827, 0.0],n=58.0)), value=ProbabilityValue([0.9047619047619048, 0.09523809523809523, 0.0],n=63.0)), value=ProbabilityValue([0.9640718562874252, 0.03592814371257485, 0.0],n=167.0)), neg_child=LeafNode(value=ProbabilityValue([0.0, 0.0, 1.0],n=6.0)), value=ProbabilityValue([0.930635838150289, 0.03468208092485549, 0.03468208092485549],n=173.0)), value=ProbabilityValue([0.47093023255813954, 0.19476744186046513, 0.33430232558139533],n=344.0)),label_classes={self.label_classes})
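Each node in the printed tree stores a class distribution and a count n of the (bootstrapped) training examples that reached it. A parent's distribution should therefore be the n-weighted average of its children's distributions. The short NumPy check below uses the numbers printed for the two children of the island in ['Biscoe'] node.

```python
import numpy as np

# Values copied from the printed tree: the children of the "island in ['Biscoe']" node.
pos_child = np.array([0.00909090909090909, 0.0, 0.990909090909091])
pos_n = 110.0
neg_child = np.array([0.0, 1.0, 0.0])
neg_n = 61.0

# n-weighted average of the children's class distributions.
parent = (pos_n * pos_child + neg_n * neg_child) / (pos_n + neg_n)
print("n =", pos_n + neg_n)
print("parent distribution:", parent)
# The printed tree reports n=171.0 and
# [0.005847953216374269, 0.3567251461988304, 0.6374269005847953] for this node.
```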

Extracting a tree is not efficient. If speed is important, inspect the model with the iterate_on_nodes() method instead. This method is an iterator that performs a depth-first pre-order traversal of all the nodes of the model.

The following example computes how many times each feature is used (this is a kind of structural variable importance):
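In a depth-first pre-order traversal, each node is visited before its children, and a node's whole first subtree is visited before its second subtree. The sketch below shows this order on a minimal hypothetical Node class. It is not tfdf's iterate_on_nodes(). Visiting the positive child before the negative child is an arbitrary choice made here. An explicit stack avoids Python's recursion limit on deep trees.

```python
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass
class Node:
    """Hypothetical binary tree node: leaves have no children."""
    name: str
    pos_child: Optional["Node"] = None
    neg_child: Optional["Node"] = None


def iterate_pre_order(node: Node) -> Iterator[Node]:
    """Yields every node, parents before children, without recursion."""
    stack = [node]
    while stack:
        current = stack.pop()
        yield current
        # Push the negative child first so that the positive child is popped first.
        for child in (current.neg_child, current.pos_child):
            if child is not None:
                stack.append(child)


tree = Node("root",
            pos_child=Node("A", pos_child=Node("A.pos"), neg_child=Node("A.neg")),
            neg_child=Node("B"))
print([node.name for node in iterate_pre_order(tree)])
```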

# number_of_use[F] will be the number of nodes using feature F in their condition.
number_of_use = collections.defaultdict(lambda: 0)

# Iterate over all the nodes in a depth-first pre-order traversal.
for node_iter in inspector.iterate_on_nodes():

  if not isinstance(node_iter.node, tfdf.py_tree.node.NonLeafNode):
    # Skip the leaf nodes
    continue

  # Iterate over all the features used in the condition.
  # By default, conditions are axis-aligned, i.e. each node tests a single feature.
  for feature in node_iter.node.condition.features():
    number_of_use[feature] += 1

print("Number of condition nodes per features:")
for feature, count in number_of_use.items():
  print("\t", feature.name, ":", count)
Number of condition nodes per features:
     bill_length_mm : 778
     bill_depth_mm : 463
     flipper_length_mm : 414
     island : 342
     body_mass_g : 338
     year : 19
     sex : 36
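The counts printed above appear to be the same as the NUM_NODES variable importance in the model summary. Sorting them makes the comparison direct. Their sum also matches the number of non-leaf nodes: 5080 nodes in total minus 2690 leaves. All numbers below are copied from the outputs in this colab.

```python
# Counts copied from the iterate_on_nodes() output above (dictionary order).
number_of_use = {"bill_length_mm": 778, "bill_depth_mm": 463, "flipper_length_mm": 414,
                 "island": 342, "body_mass_g": 338, "year": 19, "sex": 36}

# NUM_NODES importance copied from model.summary(), already sorted.
num_nodes = [("bill_length_mm", 778), ("bill_depth_mm", 463), ("flipper_length_mm", 414),
             ("island", 342), ("body_mass_g", 338), ("sex", 36), ("year", 19)]

ranked = sorted(number_of_use.items(), key=lambda item: item[1], reverse=True)
print("Same ranking as NUM_NODES:", ranked == num_nodes)

# Total nodes (5080) minus leaves (2690) gives the number of condition nodes.
num_condition_nodes = sum(number_of_use.values())
print("Condition nodes:", num_condition_nodes, "vs", 5080 - 2690)
```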

Creating a model by hand

In this section, you will create a small Random Forest model by hand. To make it even easier, the model will contain only a single tree:

3 label classes: red, blue and green.
2 features: f1 (numerical) and f2 (string categorical)

f1>=1.5
    ├─(pos)─ f2 in ["cat","dog"]
    │         ├─(pos)─ value: [0.8, 0.1, 0.1]
    │         └─(neg)─ value: [0.1, 0.8, 0.1]
    └─(neg)─ value: [0.1, 0.1, 0.8]
# Create the model builder
builder = tfdf.builder.RandomForestBuilder(
    path="/tmp/manual_model",
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="color", classes=["red", "blue", "green"]))

Trees are added one at a time.

# Some aliases
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue

builder.add_tree(
    Tree(
        NonLeafNode(
            condition=NumericalHigherThanCondition(
                feature=SimpleColumnSpec(name="f1", type=ColumnType.NUMERICAL),
                threshold=1.5,
                missing_evaluation=False),
            pos_child=NonLeafNode(
                condition=CategoricalIsInCondition(
                    feature=SimpleColumnSpec(name="f2",type=ColumnType.CATEGORICAL),
                    mask=["cat", "dog"],
                    missing_evaluation=False),
                pos_child=LeafNode(value=ProbabilityValue(probability=[0.8, 0.1, 0.1], num_examples=10)),
                neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.8, 0.1], num_examples=20))),
            neg_child=LeafNode(value=ProbabilityValue(probability=[0.1, 0.1, 0.8], num_examples=30)))))

Finish writing the model by closing the builder:

builder.close()
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
2021-11-08 12:19:14.555155: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
INFO:tensorflow:Assets written to: /tmp/manual_model/assets
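This builder adds a single tree. With several trees, a Random Forest has to combine their outputs. The summary of the trained model above reports "Winner take all: true". The sketch below contrasts the two rules as we understand them, using hypothetical per-tree outputs for one example. The first rule averages the trees' probability vectors. The second lets each tree vote for its most likely class and returns the vote proportions. It is an illustration, not the TF-DF inference code.

```python
import numpy as np

# Hypothetical outputs of 3 trees for one example and 2 classes.
tree_probabilities = np.array([
    [0.6, 0.4],
    [0.4, 0.6],
    [0.9, 0.1],
])


def average_probabilities(probabilities):
    """Mean of the per-tree probability vectors."""
    return probabilities.mean(axis=0)


def winner_take_all(probabilities):
    """Each tree votes for its argmax class; returns the fraction of votes per class."""
    votes = np.argmax(probabilities, axis=1)
    counts = np.bincount(votes, minlength=probabilities.shape[1])
    return counts / len(probabilities)


print("average:", average_probabilities(tree_probabilities))
print("winner take all:", winner_take_all(tree_probabilities))
```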

You can now open the model as a regular Keras model and make predictions:

manual_model = tf.keras.models.load_model("/tmp/manual_model")
[INFO kernel.cc:988] Loading model from path
[INFO decision_forest.cc:590] Model loaded with 1 root(s), 5 node(s), and 2 input feature(s).
[INFO kernel.cc:848] Use fast generic engine
examples = tf.data.Dataset.from_tensor_slices({
        "f1": [1.0, 2.0, 3.0],
        "f2": ["cat", "cat", "bird"]
    }).batch(2)

predictions = manual_model.predict(examples)

print("predictions:\n",predictions)
predictions:
 [[0.1 0.1 0.8]
 [0.8 0.1 0.1]
 [0.1 0.8 0.1]]
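These predictions can be checked by following the tree by hand. The first example has f1 = 1.0 < 1.5, so it goes to the negative branch. The second has f1 = 2.0 and f2 = "cat", so it reaches the first leaf. The third has f1 = 3.0 but f2 = "bird", so it reaches the second leaf. The hypothetical plain-Python helper predict_by_hand below reproduces this routing. It ignores missing values.

```python
def predict_by_hand(f1, f2):
    """Routes one example through the hand-built tree and returns its leaf value."""
    if f1 >= 1.5:  # NumericalHigherThanCondition with threshold=1.5.
        if f2 in ("cat", "dog"):  # CategoricalIsInCondition with mask=["cat", "dog"].
            return [0.8, 0.1, 0.1]
        return [0.1, 0.8, 0.1]
    return [0.1, 0.1, 0.8]


examples = {"f1": [1.0, 2.0, 3.0], "f2": ["cat", "cat", "bird"]}
for f1, f2 in zip(examples["f1"], examples["f2"]):
    print(f1, f2, predict_by_hand(f1, f2))
```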

Access the structure:

yggdrasil_model_path = manual_model.yggdrasil_model_path_tensor().numpy().decode("utf-8")
print("yggdrasil_model_path:",yggdrasil_model_path)

inspector = tfdf.inspector.make_inspector(yggdrasil_model_path)
print("Input features:", inspector.features())
yggdrasil_model_path: /tmp/manual_model/assets/
Input features: ["f1" (1; #1), "f2" (4; #2)]

And, of course, you can plot this manually built model:

tfdf.model_plotter.plot_model_in_colab(manual_model)