
Gradient Boosted Trees: Model understanding


For an end-to-end walkthrough of training a Gradient Boosting model, check out the boosted trees tutorial. In this tutorial you will:

  • Learn how to interpret a Boosted Tree model both locally and globally
  • Gain intuition for how a Boosted Trees model fits a dataset

How to interpret Boosted Trees models both locally and globally

Local interpretability refers to an understanding of a model’s predictions at the individual example level, while global interpretability refers to an understanding of the model as a whole. Such techniques can help machine learning (ML) practitioners detect bias and bugs during the model development stage.

For local interpretability, you will learn how to create and visualize per-instance contributions. To distinguish this from feature importances, we refer to these values as directional feature contributions (DFCs).

For global interpretability, you will retrieve and visualize gain-based feature importances and permutation feature importances, and also show aggregated DFCs.

Load the titanic dataset

You will be using the titanic dataset, where the (rather morbid) goal is to predict passenger survival, given characteristics such as gender, age, class, etc.

!pip install -q tf-nightly  # Requires tf 1.13
from __future__ import absolute_import, division, print_function

import numpy as np
import pandas as pd
import tensorflow as tf


# Load dataset.
dftrain = pd.read_csv('')
dfeval = pd.read_csv('')
y_train = dftrain.pop('survived')
y_eval = dfeval.pop('survived')

For a description of the features, please review the prior tutorial.

Create feature columns, input_fn, and train the estimator

Preprocess the data

Create the feature columns, using the original numeric columns as-is and one-hot encoding the categorical variables.

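fc = tf.feature_column
CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',
                       'embark_town', 'alone']
NUMERIC_COLUMNS = ['age', 'fare']

def one_hot_cat_column(feature_name, vocab):
  # Map each category to its index in `vocab`, then expand it to a one-hot vector.
  return fc.indicator_column(
      fc.categorical_column_with_vocabulary_list(feature_name, vocab))

feature_columns = []
for feature_name in CATEGORICAL_COLUMNS:
  # Need to one-hot encode categorical features.
  vocabulary = dftrain[feature_name].unique()
  feature_columns.append(one_hot_cat_column(feature_name, vocabulary))
for feature_name in NUMERIC_COLUMNS:
  # Numeric features are used as-is.
  feature_columns.append(fc.numeric_column(feature_name, dtype=tf.float32))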
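This minimal NumPy sketch shows the matrix you get when you one-hot encode a column against a fixed vocabulary. It illustrates the idea behind `fc.indicator_column` and is not TensorFlow code. The helper `one_hot` is hypothetical. Mapping out-of-vocabulary values (here the made-up category 'Crew') to all-zero rows is an assumption of this sketch.

```python
import numpy as np

def one_hot(values, vocab):
    """Hypothetical helper: one row per value, one column per vocabulary entry.

    Values not found in `vocab` are left as all-zero rows (an assumption
    made for this sketch).
    """
    index = {v: i for i, v in enumerate(vocab)}
    out = np.zeros((len(values), len(vocab)), dtype=np.float32)
    for row, value in enumerate(values):
        col = index.get(value)
        if col is not None:
            out[row, col] = 1.0
    return out

# Vocabulary in order of first appearance, as dftrain[...].unique() would return it.
vocab = ['Third', 'First', 'Second']
encoded = one_hot(['First', 'Third', 'Second', 'Crew'], vocab)
print(encoded)
```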

Build the input pipeline

Create the input functions, using the from_tensor_slices method in the tf.data API to read data directly from Pandas.

# Use entire batch since this is such a small dataset.
NUM_EXAMPLES = len(y_train)

def make_input_fn(X, y, n_epochs=None, shuffle=True):
  def input_fn():
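    dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))
    if shuffle:
      dataset = dataset.shuffle(NUM_EXAMPLES)
    # For training, cycle through the dataset as many times as needed (n_epochs=None).
    dataset = (dataset
      .repeat(n_epochs)
      .batch(NUM_EXAMPLES))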
    return dataset
  return input_fn

# Training and evaluation input functions.
train_input_fn = make_input_fn(dftrain, y_train)
eval_input_fn = make_input_fn(dfeval, y_eval, shuffle=False, n_epochs=1)
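The batch size is the entire dataset, so each batch from these input functions is one complete, optionally shuffled, pass over the data. With n_epochs=None, those passes repeat indefinitely. The plain-Python generator below roughly sketches that behavior; it is not how tf.data works internally. The name `full_batch_epochs` is hypothetical.

```python
import itertools
import numpy as np

def full_batch_epochs(X, y, n_epochs=None, shuffle=True, seed=0):
    """Hypothetical sketch: yield (X_batch, y_batch) pairs, each one a whole epoch.

    n_epochs=None yields forever (like the training input_fn);
    n_epochs=1 yields a single pass (like the evaluation input_fn).
    """
    rng = np.random.default_rng(seed)
    epochs = itertools.count() if n_epochs is None else range(n_epochs)
    for _ in epochs:
        order = rng.permutation(len(y)) if shuffle else np.arange(len(y))
        yield X[order], y[order]

X_demo = np.arange(10).reshape(5, 2)
y_demo = np.array([0, 1, 0, 1, 1])

eval_batches = list(full_batch_epochs(X_demo, y_demo, n_epochs=1, shuffle=False))
train_batches = list(itertools.islice(full_batch_epochs(X_demo, y_demo), 3))
```

Full-batch training is only practical because this dataset is tiny. Larger datasets would use a smaller batch size.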

Train the model

params = {
  'n_trees': 50,
  'max_depth': 3,
  'n_batches_per_layer': 1,
  # You must enable center_bias = True to get DFCs. This will force the model to 
  # make an initial prediction before using any features (e.g. use the mean of 
  # the training labels for regression or log odds for classification when
  # using cross entropy loss).
  'center_bias': True
}

est = tf.estimator.BoostedTreesClassifier(feature_columns, **params)
est.train(train_input_fn, max_steps=100)
results = est.evaluate(eval_input_fn)
accuracy 0.803030
accuracy_baseline 0.625000
auc 0.866850
auc_precision_recall 0.841105
average_loss 0.425117
global_step 100.000000
label/mean 0.375000
loss 0.425117
precision 0.747368
prediction/mean 0.388147
recall 0.717172
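The comment on center_bias says that, for classification with cross-entropy loss, the initial prediction is the log odds of the training labels. As a small worked example, suppose the label mean is 0.375. That value is borrowed from the label/mean reported above for the evaluation set and is used only for illustration. The starting logit and its probability would then be:

```python
import math

def centered_bias_logit(label_mean):
    """Log odds of the positive class: log(p / (1 - p))."""
    return math.log(label_mean / (1.0 - label_mean))

def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))

p = 0.375
bias_logit = centered_bias_logit(p)
print(round(bias_logit, 4))           # → -0.5108  (log(0.375 / 0.625) = log(0.6))
print(round(sigmoid(bias_logit), 4))  # → 0.375
```

Every tree then adds corrections on top of this feature-independent starting point. We assume this is why DFCs require center_bias: contributions are measured relative to it.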

For performance reasons, when your data fits in memory, we recommend using the boosted_trees_classifier_train_in_memory function. However, if training time is not a concern, or if you have a very large dataset and want to do distributed training, use the tf.estimator.BoostedTrees API shown above.

When using this method, you should not batch your input data, as the method operates on the entire dataset.

in_memory_params = dict(params)
del in_memory_params['n_batches_per_layer']
# In-memory input_fn does not use batching.
def make_inmemory_train_input_fn(X, y):
  def input_fn():
    return dict(X), y
  return input_fn
train_input_fn = make_inmemory_train_input_fn(dftrain, y_train)

# Train the model.
est = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
    train_input_fn,
    feature_columns,
    **in_memory_params)

print(est.evaluate(eval_input_fn))


{'accuracy': 0.8068182, 'accuracy_baseline': 0.625, 'auc': 0.8668504, 'auc_precision_recall': 0.8509292, 'average_loss': 0.4199287, 'label/mean': 0.375, 'loss': 0.4199287, 'precision': 0.7553192, 'prediction/mean': 0.3870664, 'recall': 0.7171717, 'global_step': 153}

Model interpretation and plotting

import matplotlib.pyplot as plt
import seaborn as sns
sns_colors = sns.color_palette('colorblind')

Local interpretability

Next, you will output the directional feature contributions (DFCs) to explain individual predictions, using the approach outlined in Palczewska et al. and by Saabas in Interpreting Random Forests (this method is also available for scikit-learn Random Forests through the treeinterpreter package). The DFCs are generated with:

pred_dicts = list(est.experimental_predict_with_explanations(pred_input_fn))

(Note: The method is named experimental as we may modify the API before dropping the experimental prefix.)
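The Saabas-style idea is easiest to see on a single tree. Walk the example's decision path and credit each change in node value to the feature used at that split. The sum of those credits plus the root value (the bias) telescopes to the leaf value.

The tree below is hand-built with made-up node values, and `Node` and `directional_contributions` are hypothetical helpers. TensorFlow's implementation differs in ways this sketch does not reproduce:
- it sums contributions over all trees in the ensemble;
- for this classifier, it reports contributions on the probability scale.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass
class Node:
    value: float                     # mean prediction of training examples at this node
    feature: Optional[str] = None    # split feature; None marks a leaf
    threshold: float = 0.0
    left: Optional["Node"] = None    # followed when example[feature] <= threshold
    right: Optional["Node"] = None


def directional_contributions(tree: Node,
                              example: Dict[str, float]) -> Tuple[float, Dict[str, float], float]:
    """Return (bias, per-feature contributions, prediction) for one example."""
    contributions: Dict[str, float] = defaultdict(float)
    node = tree
    while node.feature is not None:
        child = node.left if example[node.feature] <= node.threshold else node.right
        # Credit the change in predicted value to the feature that was split on.
        contributions[node.feature] += child.value - node.value
        node = child
    return tree.value, dict(contributions), node.value


# Hand-built tree with illustrative (made-up) node values.
tree = Node(0.38, 'fare', 20.0,
            left=Node(0.25, 'age', 10.0, left=Node(0.55), right=Node(0.20)),
            right=Node(0.60, 'age', 50.0, left=Node(0.65), right=Node(0.40)))

tree_bias, tree_dfc, tree_pred = directional_contributions(tree, {'age': 30.0, 'fare': 8.0})
print(tree_bias, tree_dfc, tree_pred)
```

This example takes the path fare ≤ 20 (contribution 0.25 - 0.38 = -0.13), then age > 10 (contribution 0.20 - 0.25 = -0.05). It ends at the leaf value 0.20, which equals 0.38 - 0.13 - 0.05.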

pred_dicts = list(est.experimental_predict_with_explanations(eval_input_fn))
# Create DFC Pandas dataframe.
labels = y_eval.values
probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])
df_dfc = pd.DataFrame([pred['dfc'] for pred in pred_dicts])
df_dfc.describe().T
count mean std min 25% 50% 75% max
age 264.0 -0.025428 0.096347 -0.146072 -0.078276 -0.057353 0.001774 0.496446
sex 264.0 0.007022 0.110539 -0.124210 -0.074910 -0.073353 0.139825 0.183282
class 264.0 0.015867 0.093520 -0.078707 -0.045622 -0.044700 0.033924 0.251914
deck 264.0 -0.015923 0.033929 -0.076938 -0.042056 -0.029379 0.003550 0.200937
embark_town 264.0 -0.007180 0.027693 -0.052807 -0.014748 -0.013727 -0.003101 0.068115
fare 264.0 0.021740 0.086823 -0.262387 -0.028537 -0.007287 0.057555 0.230759
n_siblings_spouses 264.0 0.002836 0.027726 -0.157888 0.002401 0.004710 0.006721 0.116872
parch 264.0 0.000294 0.007607 -0.066935 0.000315 0.000395 0.000725 0.024951
alone 264.0 0.000279 0.003937 -0.006541 0.000000 0.000000 0.000000 0.017987

A nice property of DFCs is that the sum of the contributions plus the bias equals the prediction for a given example.

# Sum of DFCs + bias == probability.
bias = pred_dicts[0]['bias']
dfc_prob = df_dfc.sum(axis=1) + bias
np.testing.assert_almost_equal(dfc_prob.values, probs.values)

Plot DFCs for an individual passenger.

# Plot results.
ID = 182
example = df_dfc.iloc[ID]  # Choose ith example from evaluation set.
TOP_N = 8  # View top 8 features.
sorted_ix = example.abs().sort_values()[-TOP_N:].index
ax = example[sorted_ix].plot(kind='barh', color=sns_colors[3])
ax.grid(False, axis='y')

ax.set_title('Feature contributions for example {}\n pred: {:1.2f}; label: {}'.format(ID, probs[ID], labels[ID]))
ax.set_xlabel('Contribution to predicted probability');


Larger-magnitude contributions have a larger impact on the model's prediction. Negative contributions indicate that this example's value for the feature decreased the model's prediction, while positive contributions increased it.

Improved plotting

Let's improve the plot by color-coding the bars according to the sign of each contribution and adding the feature values to the figure.

# Boilerplate code for plotting :)
def _get_color(value):
    """To make positive DFCs plot green, negative DFCs plot red."""
    green, red = sns.color_palette()[2:4]
    if value >= 0: return green
    return red

def _add_feature_values(feature_values, ax):
    """Display feature's values on left of plot."""
    x_coord = ax.get_xlim()[0]
    OFFSET = 0.15
    for y_coord, (feat_name, feat_val) in enumerate(feature_values.items()):
        t = plt.text(x_coord, y_coord - OFFSET, '{}'.format(feat_val), size=12)
        t.set_bbox(dict(facecolor='white', alpha=0.5))
    from matplotlib.font_manager import FontProperties
    font = FontProperties()
    t = plt.text(x_coord, y_coord + 1 - OFFSET, 'feature\nvalue',
                 fontproperties=font, size=12)

def plot_example(example):
  TOP_N = 8 # View top 8 features.
  sorted_ix = example.abs().sort_values()[-TOP_N:].index  # Sort by magnitude.
  example = example[sorted_ix]
  # Green bars for positive contributions, red bars for negative ones.
  colors = example.map(_get_color).tolist()
  ax = example.plot(kind='barh',
                    color=colors,
                    alpha=0.75,
                    figsize=(10, 6))
  ax.grid(False, axis='y')
  ax.set_yticklabels(ax.get_yticklabels(), size=14)

  # Add feature values (uses the global ID of the example being plotted).
  _add_feature_values(dfeval.iloc[ID][sorted_ix], ax)
  return ax

Plot example.

example = df_dfc.iloc[ID]  # Choose IDth example from evaluation set.
ax = plot_example(example)
ax.set_title('Feature contributions for example {}\n pred: {:1.2f}; label: {}'.format(ID, probs[ID], labels[ID]))
ax.set_xlabel('Contribution to predicted probability', size=14);


You can also compare the example's DFCs with the distribution of DFCs over the entire evaluation set using a violin plot.

# Boilerplate plotting code.
def dist_violin_plot(df_dfc, ID):
  # Initialize plot.
  fig, ax = plt.subplots(1, 1, figsize=(10, 6))

  # Create example dataframe.
  TOP_N = 8  # View top 8 features.
  example = df_dfc.iloc[ID]
  ix = example.abs().sort_values()[-TOP_N:].index
  example = example[ix]
  example_df = example.to_frame(name='dfc')

  # Add contributions of entire distribution.
  parts=ax.violinplot([df_dfc[w] for w in ix],
                      vert=False,
                      showextrema=False,
                      widths=0.7,
                      positions=np.arange(len(ix)))
  face_color = sns_colors[0]
  alpha = 0.15
  for pc in parts['bodies']:
    pc.set_facecolor(face_color)
    pc.set_alpha(alpha)

  # Add feature values (use this function's `ix`, not a global index).
  _add_feature_values(dfeval.iloc[ID][ix], ax)

  # Add local contributions.
  ax.scatter(example,
             np.arange(example.shape[0]),
             color=sns.color_palette()[2],
             s=100,
             marker='s',
             label='contributions for example')

  # Legend
  # Proxy plot, to show violinplot dist on legend.
  ax.plot([0,0], [1,1], label='eval set contributions\ndistributions',
          color=face_color, alpha=alpha, linewidth=10)
  legend = ax.legend(loc='lower right', shadow=True, fontsize='x-large',
                     frameon=True)
  legend.get_frame().set_facecolor('white')

  # Format plot.
  ax.set_yticks(np.arange(example.shape[0]))
  ax.set_yticklabels(example.index)
  ax.grid(False, axis='y')
  ax.set_xlabel('Contribution to predicted probability', size=14)

Plot this example.

dist_violin_plot(df_dfc, ID)
plt.title('Feature contributions for example {}\n pred: {:1.2f}; label: {}'.format(ID, probs[ID], labels[ID]));


Finally, third-party tools such as LIME and shap can also help you understand individual predictions for a model.

Global feature importances

Additionally, you might want to understand the model as a whole, rather than studying individual predictions. Below, you will compute and use:

  1. Gain-based feature importances using est.experimental_feature_importances
  2. Aggregate DFCs using est.experimental_predict_with_explanations
  3. Permutation importances

Gain-based feature importances measure the reduction in loss achieved by splits on a particular feature. Permutation feature importances are computed by shuffling each feature of the evaluation set, one at a time, and attributing the resulting change in model performance to the shuffled feature.

In general, permutation feature importances are preferred to gain-based feature importances. However, both methods can be unreliable when potential predictor variables vary in their scale of measurement or their number of categories, and when features are correlated (source). Check out this article for an in-depth overview and a great discussion of different feature importance types.

1. Gain-based feature importances

Gain-based feature importances are built into the TensorFlow Boosted Trees estimators using est.experimental_feature_importances.

importances = est.experimental_feature_importances(normalize=True)
df_imp = pd.Series(importances)

# Visualize importances.
N = 8
ax = (df_imp.iloc[0:N][::-1]
    .plot(kind='barh',
          color=sns_colors[0],
          title='Gain feature importances',
          figsize=(10, 6)))
ax.grid(False, axis='y')
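The sketch below is a rough illustration of gain-based importance. It scores a split by the reduction in squared error it produces, then sums those gains per feature. Optionally, it normalizes the totals to sum to 1 (compare normalize=True in the call above).

Squared-error gain is a simplification chosen for clarity. As far as we know, TensorFlow's boosted trees instead compute gains from gradient and Hessian statistics with regularization. All helper names here are hypothetical.

```python
from collections import defaultdict

import numpy as np


def sse(values):
    """Sum of squared errors around the mean (0 for an empty node)."""
    return float(((values - values.mean()) ** 2).sum()) if len(values) else 0.0


def split_gain(y_node, goes_left):
    """Loss reduction from splitting a node's targets with boolean mask `goes_left`."""
    return sse(y_node) - sse(y_node[goes_left]) - sse(y_node[~goes_left])


def gain_importances(splits, normalize=True):
    """Sum gains by feature over (feature, gain) pairs, sorted from largest to smallest."""
    totals = defaultdict(float)
    for feature, gain in splits:
        totals[feature] += gain
    if normalize:
        total = sum(totals.values())
        totals = {f: g / total for f, g in totals.items()}
    return dict(sorted(totals.items(), key=lambda kv: kv[1], reverse=True))


y_node = np.array([0.0, 0.0, 1.0, 1.0])
root_gain = split_gain(y_node, np.array([True, True, False, False]))  # a perfect split
toy_importances = gain_importances([('sex', root_gain), ('age', 0.25), ('sex', 0.25)])
print(root_gain, toy_importances)
```

Here the root split removes all of the node's squared error, so its gain is 1.0. 'sex' accumulates 1.25 of the 1.5 total gain, or about 83% after normalization.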


2. Average absolute DFCs

You can also average the absolute values of DFCs to understand impact at a global level.

# Plot.
dfc_mean = df_dfc.abs().mean()
N = 8
sorted_ix = dfc_mean.abs().sort_values()[-N:].index  # Average and sort by absolute.
ax = dfc_mean[sorted_ix].plot(kind='barh',
                       title='Mean |directional feature contributions|',
                       figsize=(10, 6))
ax.grid(False, axis='y')
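Taking the absolute value before averaging matters because DFCs are signed. A feature that pushes some predictions up and others down can have a mean contribution near zero while still being influential. Here is a toy illustration with made-up contributions:

```python
import pandas as pd

# Made-up DFCs for four examples and two features.
toy_dfc = pd.DataFrame({
    'age':   [0.20, -0.20, 0.10, -0.10],   # large, but in both directions
    'parch': [0.01,  0.00, -0.01, 0.00],   # small in both directions
})

signed_mean = toy_dfc.mean()        # contributions cancel out
mean_abs = toy_dfc.abs().mean()     # magnitude survives
print(pd.DataFrame({'mean': signed_mean, 'mean |dfc|': mean_abs}))
```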


You can also see how DFCs vary as a feature value varies.

FEATURE = 'fare'
feature = pd.Series(df_dfc[FEATURE].values, index=dfeval[FEATURE].values).sort_index()
# lowess=True requires the statsmodels package.
ax = sns.regplot(x=feature.index.values, y=feature.values, lowess=True);
ax.set_xlim(0, 100);
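The lowess=True option relies on statsmodels. If that package is not available, you can average the DFCs within quantile bins of the feature instead, which gives a coarser view of the same trend. The sketch below swaps in that binning technique. `binned_dfc` is a hypothetical helper; in this tutorial you might call it as binned_dfc(dfeval['fare'], df_dfc['fare']), though we have not run it against the real data.

```python
import numpy as np
import pandas as pd


def binned_dfc(feature_values, dfc_values, n_bins=5):
    """Mean DFC within quantile bins of a feature (a simple stand-in for a lowess curve)."""
    feature_values = np.asarray(feature_values, dtype=float)
    dfc_values = np.asarray(dfc_values, dtype=float)
    # duplicates='drop' merges bins when many examples share a value (e.g. common fares).
    bins = pd.qcut(feature_values, q=n_bins, duplicates='drop')
    return pd.Series(dfc_values).groupby(bins, observed=True).mean()


# Synthetic check: contributions that rise with the feature value.
fare_demo = np.arange(100, dtype=float)
dfc_demo = 0.001 * fare_demo
trend = binned_dfc(fare_demo, dfc_demo)
print(trend)
```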


3. Permutation feature importance

def permutation_importances(est, X_eval, y_eval, metric, features):
    """Column by column, shuffle values and observe effect on eval set.
    A similar approach can be done during training. See "Drop-column importance"
    in the above article."""
    baseline = metric(est, X_eval, y_eval)
    imp = []
    for col in features:
        save = X_eval[col].copy()
        X_eval[col] = np.random.permutation(X_eval[col])
        m = metric(est, X_eval, y_eval)
        X_eval[col] = save
        imp.append(baseline - m)
    return np.array(imp)

def accuracy_metric(est, X, y):
    """TensorFlow estimator accuracy."""
    eval_input_fn = make_input_fn(X,
                                  y=y,
                                  shuffle=False,
                                  n_epochs=1)
    return est.evaluate(input_fn=eval_input_fn)['accuracy']

features = CATEGORICAL_COLUMNS + NUMERIC_COLUMNS
importances = permutation_importances(est, dfeval, y_eval, accuracy_metric,
                                      features)
df_imp = pd.Series(importances, index=features)

sorted_ix = df_imp.abs().sort_values().index
ax = df_imp[sorted_ix][-5:].plot(kind='barh', color=sns_colors[2], figsize=(10, 6))
ax.grid(False, axis='y')
ax.set_title('Permutation feature importance');
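A single shuffle per feature gives a noisy estimate, and permutation_importances above temporarily modifies X_eval in place. The variant below avoids both issues:
- it shuffles a copy of the data, leaving the caller's DataFrame untouched;
- it repeats each shuffle several times and reports the mean and spread of the metric drop.

It is written against a generic predict function and metric rather than the estimator. All names, including the toy model, are hypothetical.

```python
import numpy as np
import pandas as pd


def permutation_importance_repeated(predict, X, y, metric, n_repeats=5, seed=0):
    """Mean and std of (baseline - permuted) metric for each column of X."""
    rng = np.random.default_rng(seed)
    baseline = metric(y, predict(X))
    rows = {}
    for col in X.columns:
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()  # leave the caller's data untouched
            X_perm[col] = rng.permutation(X_perm[col].to_numpy())
            drops.append(baseline - metric(y, predict(X_perm)))
        rows[col] = {'mean': np.mean(drops), 'std': np.std(drops)}
    return pd.DataFrame.from_dict(rows, orient='index')


# Toy data: the label depends only on 'signal'; the model ignores 'noise'.
rng = np.random.default_rng(42)
X_toy = pd.DataFrame({'signal': rng.normal(size=500), 'noise': rng.normal(size=500)})
y_toy = (X_toy['signal'] > 0).astype(int).to_numpy()

def toy_predict(X):
    return (X['signal'] > 0).astype(int).to_numpy()

def accuracy(y_true, y_pred):
    return float(np.mean(y_true == y_pred))

imp = permutation_importance_repeated(toy_predict, X_toy, y_toy, accuracy)
print(imp)
```

Shuffling 'noise' leaves every prediction unchanged, so its importance is exactly zero. Shuffling 'signal' drops accuracy to roughly chance level.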


Visualizing model fitting

Let's first simulate training data using the following formula:

\[z = x e^{-x^2 - y^2}\]

where \(z\) is the dependent variable you are trying to predict and \(x\) and \(y\) are the features.
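It helps to know the range of the target before reading the contour plots. Set the partial derivatives to zero:
- \(\partial z/\partial x = (1 - 2x^2)e^{-x^2-y^2} = 0\)
- \(\partial z/\partial y = -2xy\,e^{-x^2-y^2} = 0\)

The extrema are at \((x, y) = (\pm 1/\sqrt{2}, 0)\), so \(z\) lies between \(-e^{-1/2}/\sqrt{2}\) and \(e^{-1/2}/\sqrt{2}\), roughly \(\pm 0.429\). A quick numerical check on a grid over \([-2, 2]^2\):

```python
import numpy as np

# Evaluate z = x * exp(-x^2 - y^2) on a fine grid over the training range.
grid = np.linspace(-2, 2, 401)
gx, gy = np.meshgrid(grid, grid)
gz = gx * np.exp(-gx**2 - gy**2)

z_max_analytic = np.exp(-0.5) / np.sqrt(2)   # value at (1/sqrt(2), 0)
i = np.unravel_index(np.argmax(gz), gz.shape)
print(z_max_analytic, gz.max(), gx[i], gy[i])
```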

from numpy.random import uniform, seed
from matplotlib.mlab import griddata

# Create fake data
npts = 5000
x = uniform(-2, 2, npts)
y = uniform(-2, 2, npts)
z = x*np.exp(-x**2 - y**2)
# Prep data for training.
df = pd.DataFrame({'x': x, 'y': y, 'z': z})

xi = np.linspace(-2.0, 2.0, 200)
yi = np.linspace(-2.1, 2.1, 210)
xi, yi = np.meshgrid(xi, yi)

df_predict = pd.DataFrame({
    'x' : xi.flatten(),
    'y' : yi.flatten(),
})
predict_shape = xi.shape

def plot_contour(x, y, z, **kwargs):
  plt.figure(figsize=(10, 8))
  # Draw contour lines, then filled contours. The color scale is fixed by the
  # global `zi` (the gridded training data, computed below), so every plot
  # shares the same scale.
  CS = plt.contour(x, y, z, 15, linewidths=0.5, colors='k')
  CS = plt.contourf(x, y, z, 15,
                    vmax=abs(zi).max(), vmin=-abs(zi).max(), cmap='RdBu_r')
  plt.colorbar()  # Draw colorbar.
  plt.xlim(-2, 2)
  plt.ylim(-2, 2)

You can visualize the function. Redder colors correspond to larger function values.

zi = griddata(x, y, z, xi, yi, interp='linear')
plot_contour(xi, yi, zi)
plt.scatter(df.x, df.y, marker='.')
plt.title('Contour on training data');
/usr/local/lib/python3.6/dist-packages/ MatplotlibDeprecationWarning: The griddata function was deprecated in Matplotlib 2.2 and will be removed in 3.1. Use scipy.interpolate.griddata instead.


fc = [tf.feature_column.numeric_column('x'),
      tf.feature_column.numeric_column('y')]

def predict(est):
  """Predictions from a given estimator."""
  # Feed the whole prediction grid as a single batch.
  predict_input_fn = lambda: tf.data.Dataset.from_tensors(dict(df_predict))
  preds = np.array([p['predictions'][0] for p in est.predict(predict_input_fn)])
  return preds.reshape(predict_shape)

First let's try to fit a linear model to the data.

train_input_fn = make_input_fn(df, df.z)
est = tf.estimator.LinearRegressor(fc)
est.train(train_input_fn, max_steps=500);
plot_contour(xi, yi, predict(est))


It's not a very good fit. Next, let's fit a GBDT model to the same data and try to understand how the model fits the function.

def create_bt_est(n_trees):
  return tf.estimator.BoostedTreesRegressor(fc,
                                            n_batches_per_layer=1,
                                            n_trees=n_trees)

N_TREES = [1,2,3,4,10,20,50,100]
for n in N_TREES:
  est = create_bt_est(n)
  est.train(train_input_fn, max_steps=500)
  plot_contour(xi, yi, predict(est))
  plt.text(-1.8, 2.1, '# trees: {}'.format(n), color='w', backgroundcolor='black', size=20);









As you increase the number of trees, the model's predictions better approximate the underlying function.
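The same effect can be reproduced without TensorFlow. The sketch below is a bare-bones gradient boosting loop for squared error:
1. It starts from the mean target, similar in spirit to center_bias.
2. At each round, it fits a depth-1 tree (a stump) to the current residuals.
3. It adds a shrunken copy of that stump to the prediction.

The helpers are hypothetical and much simpler than BoostedTreesRegressor. In particular, stumps can only build functions of the form \(f(x) + g(y)\), so unlike deeper trees they cannot represent the interaction between \(x\) and \(y\) in this target. The training error still falls as trees are added.

```python
import numpy as np


def fit_stump(X, residual, n_thresholds=32):
    """Least-squares depth-1 tree: returns (feature index, threshold, left value, right value)."""
    best, best_sse = None, np.inf
    for j in range(X.shape[1]):
        # Candidate thresholds: interior quantiles of feature j.
        for t in np.quantile(X[:, j], np.linspace(0.02, 0.98, n_thresholds)):
            left = X[:, j] <= t
            if left.all() or not left.any():
                continue
            lv, rv = residual[left].mean(), residual[~left].mean()
            sse = ((residual[left] - lv) ** 2).sum() + ((residual[~left] - rv) ** 2).sum()
            if sse < best_sse:
                best, best_sse = (j, t, lv, rv), sse
    return best


def predict_stump(stump, X):
    j, t, lv, rv = stump
    return np.where(X[:, j] <= t, lv, rv)


def boost(X, target, n_trees, learning_rate=0.5):
    """Gradient boosting for squared error; returns final predictions and MSE after each round."""
    pred = np.full_like(target, target.mean())  # start from the mean target
    mse = [np.mean((target - pred) ** 2)]
    for _ in range(n_trees):
        stump = fit_stump(X, target - pred)      # fit the current residuals
        pred = pred + learning_rate * predict_stump(stump, X)
        mse.append(np.mean((target - pred) ** 2))
    return pred, mse


rng = np.random.default_rng(0)
xs, ys = rng.uniform(-2, 2, 1000), rng.uniform(-2, 2, 1000)
zs = xs * np.exp(-xs**2 - ys**2)
X_s = np.column_stack([xs, ys])

for n in [1, 2, 3, 4, 10, 20, 50]:
    _, mse = boost(X_s, zs, n_trees=n)
    print(n, round(mse[-1], 5))
```

Each stump's leaf values are the mean residuals in that leaf. Adding learning_rate times the stump, with 0 < learning_rate < 2, therefore cannot increase the training error. A learning rate below 1 makes each tree take a smaller step, which typically generalizes better at the cost of needing more trees.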


In this tutorial you learned how to interpret Boosted Trees models using directional feature contributions and feature importance techniques. These techniques provide insight into how the features impact a model's predictions. Finally, you also gained intuition for how a Boosted Trees model fits a complex function by viewing the decision surface for several models.