Attend the Women in ML Symposium on December 7 Register now

Proximities and Prototypes with Random Forests

Stay organized with collections Save and categorize content based on your preferences.

View on TensorFlow.org Run in Google Colab View on GitHub Download notebook

Introduction

Leo Breiman, the author of the random forest learning algorithm, proposed a method to measure the proximity (also known as similarity) between two examples using a pre-trained Random Forest (RF) model. He qualifies this method as "[...] one of the most useful tools in random forests.". In this Notebook, we implement this method and show how to use it to interpret models.

This notebook is implemented using the TensorFlow Decision Forests library. This document is easier to understand if you are familiar with the content of the Beginner colab.

Proximities

A proximity (or a similarity) between two examples is a number indicating how "close" those two examples are. Following is an example of similarity in between the 3 examples \(\{e_1, e_2, e_3\}\):

\[ \mathrm{proxy}(e_1, e_2) = 0.1 \\ \mathrm{proxy}(e_2, e_3) = 9.6 \\ \mathrm{proxy}(e_3, e_1) = 4.1 \\ \]

For convenience, the proximity between examples is represented in matrix form:

\(e_1\) \(e_2\) \(e_3\)
\(e_1\) \(\mathrm{proxy}(e_1, e_1)\) \(\mathrm{proxy}(e_1, e_2)\) \(\mathrm{proxy}(e_1, e_3)\)
\(e_2\) \(\mathrm{proxy}(e_2, e_1)\) \(\mathrm{proxy}(e_2, e_2)\) \(\mathrm{proxy}(e_2, e_3)\)
\(e_3\) \(\mathrm{proxy}(e_3, e_1)\) \(\mathrm{proxy}(e_3, e_2)\) \(\mathrm{proxy}(e_3, e_3)\)

Proximities are used in multiple data analysis techniques, including clustering, dimensionality reductions or nearest neighbor analysis. For this reason, it is a great tool for models and predictions interpretation.

Unfortunately, measuring the proximity between two tabular examples is not straightforward as different columns might describe different quantities. For example, try to define the proximity in between the following examples.

species weight num_legs age sex
cat 2 kg 4 2 y male
dog 6 kg 4 12 y female
spider 5 g 8 3 weeks female

To define the similarity between two rows in the table above, you need to specify how much a difference in weight compares to a difference in the number of legs, or in ages. In addition, relations might be non-linear or be conditional on other columns. For example, dogs live longer than spiders, so maybe, a one year difference for a spider should not count the same one year of age for a dog.

Instead of manually defining those relations, Breiman's proximity turns a random forest model (which we know how to train on a tabular dataset), into a proximity metric.

Proximities with random forests

A random forest is a collection of decision trees. The prediction of the random the aggregation of the predictions of the individual trees. The prediction of a decision tree is computed by routing an example from the root to forest is one of the leaves according to node conditions. The leaf reached by the example \(i\) in the tree \(t\) is called its active leaf and noted \(\mathrm{leaf}(i,t)\)

Breiman defines the proximity between two examples as the ratio of shared active leafs between those two examples. Formally, the proximity between example \(i\) and example \(j\) is:

\[ \mathrm{prox}(i,j) = \mathrm{prox}(j,i) = \frac{1}{|\mathrm{Trees}|} \sum_{t \in \mathrm{Trees} } \left[ \mathrm{leaf}(i,t) = \mathrm{leaf}(j,t) \right] \]

with \(\mathrm{leaf}(j,t)\) the index of the active leaf for the example \(j\) in the tree \(t\).

Informally, if two examples are often routed to the same leaves (i.e. the two examples have the same active leaves), those examples are similar.

Let's implement this proximity function and use it in some examples.

Setup

# Install TensorFlow Dececision Forests and the dependencies used in this colab.
pip install tensorflow_decision_forests plotly scikit-learn wurlitzer -U -qq
import tensorflow_decision_forests as tfdf

import matplotlib.colors as mcolors
import math
import os
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from plotly.offline import iplot
import plotly.graph_objs as go
2022-11-23 12:09:15.027555: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-11-23 12:09:15.027652: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-11-23 12:09:15.027662: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Train a Random Forest model

The method relies on a pre-trained random forest model. First, we train a random forest model with TensorFlow Decision Forests library on the Adult binary classification dataset. The Adult dataset is well suited for this example as it contains columns that don't have a natural way to be compared.

# Download a copy of the adult dataset.
wget -q https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset/adult_train.csv -O /tmp/adult_train.csv
wget -q https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset/adult_test.csv -O /tmp/adult_test.csv
# Load the dataset in memory
train_df = pd.read_csv("/tmp/adult_train.csv")
test_df = pd.read_csv("/tmp/adult_test.csv")

# , and convert it into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="income")
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label="income")

Following are the first five examples of the training dataset. Notice that different columns represent different quantities. For example, how would you compare the distance between relationship and age?

# Print the first 5 examples.
train_df.head()

A Random Forest is trained as follows:

# Train a Random Forest
model = tfdf.keras.RandomForestModel(num_trees=1000)
model.fit(train_ds)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus.
Use /tmpfs/tmp/tmpxpsa62ou as temporary training directory
Reading training dataset...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Training dataset read in 0:00:03.666538. Found 22792 examples.
Training model...
[INFO 2022-11-23T12:09:31.144007853+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpxpsa62ou/model/ with prefix 4d8ebcbe59a2451c
[INFO 2022-11-23T12:09:35.665899507+00:00 abstract_model.cc:1306] Engine "RandomForestGeneric" built
[INFO 2022-11-23T12:09:35.665947684+00:00 kernel.cc:1021] Use fast generic engine
Model trained in 0:00:10.480653
Compiling model...
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fe1ef059280> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fe1ef059280> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function simple_ml_inference_op_with_handle at 0x7fe1ef059280> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Model compiled.
<keras.callbacks.History at 0x7fe299ac5700>

The performance of the Random Forest model is:

model_inspector = model.make_inspector()
out_of_bag_accuracy = model_inspector.evaluation().accuracy
print(f"Out-of-bag accuracy: {out_of_bag_accuracy:.4f}")
Out-of-bag accuracy: 0.8653

This is an expected accuracy value for Random Forest models on this dataset. It indicates that the model is correctly trained.

We can also measure the accuracy of the model on the test datasets:

# The test accuracy is measured on the test datasets.
model.compile(["accuracy"])
test_accuracy = model.evaluate(test_ds, return_dict=True, verbose=0)["accuracy"]
print(f"Test accuracy: {test_accuracy:.4f}")
Test accuracy: 0.8663

Proximities

First, we inspect the number of trees in the model and the number of examples in the test datasets.

print("The model contains", model_inspector.num_trees(), "trees.")
print("The test dataset contains", test_df.shape[0], "examples.")
The model contains 1000 trees.
The test dataset contains 9769 examples.

The method predict_get_leaves() returns the index of the active leaf for each example and each tree.

leaves = model.predict_get_leaves(test_ds)
print("The leaf indices:\n", leaves)
[INFO 2022-11-23T12:09:41.06379579+00:00 kernel.cc:1175] Loading model from path /tmpfs/tmp/tmpxpsa62ou/model/ with prefix 4d8ebcbe59a2451c
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_leaf_index_op_with_handle at 0x7fe1ef0594c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
[INFO 2022-11-23T12:09:44.649709738+00:00 kernel.cc:1027] Use slow generic engine
WARNING:tensorflow:AutoGraph could not transform <function simple_ml_inference_leaf_index_op_with_handle at 0x7fe1ef0594c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function simple_ml_inference_leaf_index_op_with_handle at 0x7fe1ef0594c0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
The leaf indices:
 [[498 193 142 ... 457 221 198]
 [399 466 423 ... 288 420 444]
 [639 651 562 ... 608 636 625]
 ...
 [149 296 258 ... 153 310 316]
 [481 186 131 ... 432 192 153]
 [  9   0  28 ...   4   1  42]]
print("The predicted leaves have shape", leaves.shape,
      "(we expect [num_examples, num_trees]")
The predicted leaves have shape (9769, 1000) (we expect [num_examples, num_trees]

Here, leaves[i,j] is the index of the active leaf of the i-th example in the j-th tree.

Next, we implement the \(\mathrm{prox}\) equation define earlier.

def compute_proximity(leaves, step_size=100):
  """Computes the proximity between each pair of examples.

  Args:
    leaves: A matrix of shape [num_example, num_tree] where the value [i,j] is
      the index of the leaf reached by example "i" in the tree "j".
    step_size: Size of the block of examples for the computation of the
      proximity. Does not impact the results.

  Returns:
    The example pair-wise proximity matrix of shape [n,n] with "n" the number of
    examples.
  """

  example_idx = 0
  num_examples = leaves.shape[0]
  t_leaves = np.transpose(leaves)
  proximities = []

  # Instead of computing the proximity in between all the examples at the same
  # time, we compute the similarity in blocks of "step_size" examples. This
  # makes the code more efficient with the the numpy broadcast.
  while example_idx < num_examples:
    end_idx = min(example_idx + step_size, num_examples)
    proximities.append(
        np.mean(
            leaves[..., np.newaxis] == t_leaves[:,
                                                example_idx:end_idx][np.newaxis,
                                                                     ...],
            axis=1))
    example_idx = end_idx
  return np.concatenate(proximities, axis=1)


proximity = compute_proximity(leaves)
print("The shape of proximity is", proximity.shape)
The shape of proximity is (9769, 9769)

Here, proximity[i,j] is the proximity in between the example i and j.

The proximity matrix:

proximity
array([[1.   , 0.   , 0.   , ..., 0.   , 0.053, 0.   ],
       [0.   , 1.   , 0.   , ..., 0.002, 0.   , 0.   ],
       [0.   , 0.   , 1.   , ..., 0.   , 0.   , 0.   ],
       ...,
       [0.   , 0.002, 0.   , ..., 1.   , 0.   , 0.   ],
       [0.053, 0.   , 0.   , ..., 0.   , 1.   , 0.   ],
       [0.   , 0.   , 0.   , ..., 0.   , 0.   , 1.   ]])

The proximity matrix has several interesting properties, notably, it is symmetrical, positive, and the diagonal elements are all 1.

Projection

Our first use of the proximity is to project the examples on the two dimensional plane.

If \(\mathrm{prox} \in [0,1]\) is a proximity, \(1 - \mathrm{prox}\) is a distance between examples. Breiman proposes to compute the inner products of those distances, and to plot the eigenvalues. See details here.

Instead, we will use the t-SNE which is a more modern way to visualize high-dimensional data.

distance = 1 - proximity

t_sne = TSNE(
    # Number of dimensions to display. 3d is also possible.
    n_components=2,
    # Control the shape of the projection. Higher values create more
    # distinct but also more collapsed clusters. Can be in 5-50.
    perplexity=20,
    metric="precomputed",
    init="random",
    verbose=1,
    square_distances=True,
    learning_rate="auto").fit_transform(distance)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/sklearn/manifold/_t_sne.py:830: FutureWarning:

The parameter `square_distances` has not effect and will be removed in version 1.3.
[t-SNE] Computing 61 nearest neighbors...
[t-SNE] Indexed 9769 samples in 0.206s...
[t-SNE] Computed neighbors for 9769 samples in 1.225s...
[t-SNE] Computed conditional probabilities for sample 1000 / 9769
[t-SNE] Computed conditional probabilities for sample 2000 / 9769
[t-SNE] Computed conditional probabilities for sample 3000 / 9769
[t-SNE] Computed conditional probabilities for sample 4000 / 9769
[t-SNE] Computed conditional probabilities for sample 5000 / 9769
[t-SNE] Computed conditional probabilities for sample 6000 / 9769
[t-SNE] Computed conditional probabilities for sample 7000 / 9769
[t-SNE] Computed conditional probabilities for sample 8000 / 9769
[t-SNE] Computed conditional probabilities for sample 9000 / 9769
[t-SNE] Computed conditional probabilities for sample 9769 / 9769
[t-SNE] Mean sigma: 0.188051
[t-SNE] KL divergence after 250 iterations with early exaggeration: 76.190346
[t-SNE] KL divergence after 1000 iterations: 1.117140

The next plot shows a two-dimensional projection of the test example features. The color of the points represent the label values. Note that the label values were not available to the model.

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.grid(False)

# Color the points according to the label value.
colors = (test_df["income"] == ">50K").map(lambda x: ["orange", "green"][x])
ax.scatter(
    t_sne[:, 0], t_sne[:, 1], c=colors, linewidths=0.5, marker="x", s=20)
<matplotlib.collections.PathCollection at 0x7fe220c37df0>

png

Observations:

  • There are clusters of points with similar colors. Those are examples that are easy for the model to classify.
  • There are multiple clusters with the same color. Those multiple clusters show examples with the same label, but for "different reasons" according to the model.
  • Clusters with mixed colors contain examples where the model performs poorly. In the part above, we evaluated the model test accuracy to ~86%. Those are likely those examples.

The previous plot is a static image. Let's turn it into an interactive plot and inspect the individual examples.

# docs_infra: no_execute

# Note: Run the colab (click the "Run in Google Colab" link at the top) to see
# the interactive plot.

def interactive_plot(dataset, projections):

  def label_fn(row):
    """HTML printer over each example."""
    return "<br>".join([f"<b>{k}:</b> {v}" for k, v in row.items()])

  labels = list(dataset.apply(label_fn, axis=1).values)
  iplot({
      "data": [
          go.Scatter(
              x=projections[:, 0],
              y=projections[:, 1],
              text=labels,
              mode="markers",
              marker={
                  "color": colors,
                  "size": 3,
              })
      ],
      "layout": go.Layout(width=600, height=600, template="simple_white")
  })


interactive_plot(test_df, t_sne)

Instructions: Put the mouse pointer over some examples, and try to make sense of them. Compare them to their neighbors.

Not seeing the interactive plot?: Run the colab with this link to see the interactive plot.

Instead of coloring the examples according to the label values, we can color the examples according to each feature values:

# Number of columns and rows in the multi-plot.
num_plot_cols = 5
num_plot_rows = math.ceil(test_df.shape[1] / num_plot_cols)

# Color palette for the categorical features.
palette = list(mcolors.TABLEAU_COLORS.values())

# Create the plot
plot_size_in = 3.5
fig, axs = plt.subplots(
    num_plot_rows,
    num_plot_cols,
    figsize=(num_plot_cols * plot_size_in, num_plot_rows * plot_size_in))

# Hide the borders.
for row in axs:
  for ax in row:
    ax.set_axis_off()

for col_idx, col_name in enumerate(test_df):
  ax = axs[col_idx // num_plot_cols, col_idx % num_plot_cols]

  colors = test_df[col_name]
  if colors.dtypes in [str, object]:
    # Use the color palette on categorical features.
    unique_values = list(colors.unique())
    colors = colors.map(
        lambda x: palette[unique_values.index(x) % len(palette)])

  ax.set_title(col_name)
  ax.scatter(t_sne[:, 0], t_sne[:, 1], c=colors.values, linewidths=0.5,
      marker="x", s=5)

png

Prototypes

Trying to make sense of an example by looking at all its neighbors is not always efficient. Instead, we could "group" similar examples to make this task easier. This is the underlying idea behind prototypes.

Prototypes are examples, not necessarily in the original dataset, that are representative of large trends in the dataset. Looking at prototypes is a solution to understand a dataset. For more details, see the chapter 8.7 of Interpretable Machine Learning by Molnar.

Prototypes can be computed in different ways, for example using a clustering algorithm. Instead, Breiman proposed a specific solution based on a simple iterative algorithm. The algorithm is as follow:

  1. Select the example surrounded with the highest number of neighbors with the same class among its k nearest neighbors.
  2. Create a prototype example using the median feature values of the selected example and its k neighbors.
  3. Remove those k+1 examples
  4. Repeat

Informally, prototypes are centers of clusters in the plots we created above.

Let's implement this algorithm and look at some prototypes.

First the method that selects the example in step 1.

def select_example(labels, distance_matrix, k):
  """Selects the example with the highest number of neighbors with the same class.

  Usage example:
    n = 5
    select_example(
        np.random.randint(0,2, size=n),
        np.random.uniform(size=(n,n)),
        2)

  Returns:
    The list of neighbors for the selected example. Includes the selected
    example.
  """

  partition = np.argpartition(distance_matrix, k)[:,:k]
  same_label = np.mean(np.equal(labels[partition], np.expand_dims(labels, axis=1)), axis=1)
  selected_example = np.argmax(same_label)
  return partition[selected_example, :]
def extract_prototype_examples(labels, distance_matrix, k, num_prototypes):
  """Extracts a list of examples in each prototype.

  Usage example:
    n = 50
    print(extract_prototype_examples(
        labels=np.random.randint(0, 2, size=n),
        distance_matrix=np.random.uniform(size=(n, n)),
        k=2,
        num_prototypes=3))

  Returns:
    An array where E[i][j] is the index of the j-th examples of the i-th
    prototype. 
  """

  example_idxs = np.arange(len(labels))
  prototypes = []
  examples_per_prototype = []

  for iter in range(num_prototypes):
    print(f"Iter #{iter}")
    # Select the example
    neighbors = select_example(labels, distance_matrix, k)

    # Index of the examples in the prototype
    examples_per_prototype.append(list(example_idxs[neighbors]))

    # Remove the selected examples
    example_idxs = np.delete(example_idxs, neighbors)
    labels = np.delete(labels, neighbors)
    distance_matrix = np.delete(distance_matrix, neighbors, axis=0)
    distance_matrix = np.delete(distance_matrix, neighbors, axis=1)

  return examples_per_prototype

Using the methods above, let's extract the examples for 10 prototypes.

examples_per_prototype = extract_prototype_examples(test_df["income"].values, distance, k=20, num_prototypes=10)
print(f"Found examples for {len(examples_per_prototype)} prototypes.")
Iter #0
Iter #1
Iter #2
Iter #3
Iter #4
Iter #5
Iter #6
Iter #7
Iter #8
Iter #9
Found examples for 10 prototypes.

For each of those prototypes, we want to display the statistics of the feature values. In this example, we will look at the quartiles of the numerical features, and the most frequent values for the categorical features.

def build_prototype(dataset):
  """Exacts the feature statistics of a prototype.

  For numerical features, returns the quantiles.
  For categorical features, returns the most frequent value.

  Usage example:
    n = 50
    print(build_prototype(
        pd.DataFrame({
            "f1": np.random.uniform(size=n),
            "f2": np.random.uniform(size=n),
            "f3": [f"v_{x}" for x in np.random.randint(0, 2, size=n)],
            "label": np.random.randint(0, 2, size=n)
        })))

  Return:
    A prototype as a dictionary of strings.
  """

  prototype = {}
  for col in dataset.columns:
    col_values = dataset[col]
    if col_values.dtypes in [str, object]:
      # A categorical feature.

      # Remove the missing values
      col_values = [x for x in col_values if isinstance(x,str) or not math.isnan(x)]

      # Frequency of each possible value.
      frequency_item, frequency_count = np.unique(col_values, return_counts=True)
      top_item_idx = np.argmax(frequency_count)
      top_item_probability = frequency_count[top_item_idx] / np.sum(frequency_count)

      # Print the most common item.
      prototype[col] = f"{frequency_item[top_item_idx]} ({100*top_item_probability:.0f}%)"

    else:
      # A numerical feature.
      quartiles = np.nanquantile(col_values.values, [0.25, 0.5, 0.75])
      # Print the 3 quantiles.
      prototype[col] = f"{quartiles[0]} {quartiles[1]} {quartiles[2]}"
  return prototype

Now, let's look at our prototypes.

# Extract the statistics of each prototype.
prototypes = []
for examples in examples_per_prototype:
  # Prorotype statistics.
  prototypes.append(build_prototype(test_df.iloc[examples, :]))
prototypes = pd.DataFrame(prototypes)

prototypes

Try to make sense of the prototypes.

Let's extract and plot the mean 2d t-SNE projection of the elements in those prototypes.

# Extract the projection of each prototype.
prototypes_projection = []
for examples in examples_per_prototype:
  # t-SNE for each prototype.
  prototypes_projection.append(np.mean(t_sne[examples,:],axis=0))
prototypes_projection = np.stack(prototypes_projection)
# Plot the mean 2d t-SNE projection of the elements in the prototypes.

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.grid(False)

# Color the points according to the label value.
colors = (test_df["income"] == ">50K").map(lambda x: ["orange", "green"][x])
ax.scatter(
    t_sne[:, 0], t_sne[:, 1], c=colors, linewidths=0.5, marker="x", s=20)

# Add the prototype indices.
for i in range(prototypes_projection.shape[0]):
  ax.text(prototypes_projection[i, 0],
          prototypes_projection[i, 1],
          f"{i}",
          fontdict={"size":18},
          c="red")

png

We see that the 10 prototypes cover around half of the domain. Clusters of examples without a prototype would be best explained with more prototypes.

In the example above, we extracted the prototypes automatically. However, we can also build prototypes around specific examples.

Let's create the prototype around the example #0.

example_idx = 0
k = 20
neighbors = np.argpartition(distance[example_idx, :], k)[:k]

print(f"The example #{example_idx} is:")
print("===============================")
print(test_df.iloc[example_idx, :])
print("")
print(f"The prototype around the example #{example_idx} is:")
print("============================================")
print(pd.Series(build_prototype(test_df.iloc[neighbors, :])))
The example #0 is:
===============================
age                          39
workclass             State-gov
fnlwgt                    77516
education             Bachelors
education_num                13
marital_status    Never-married
occupation         Adm-clerical
relationship      Not-in-family
race                      White
sex                        Male
capital_gain               2174
capital_loss                  0
hours_per_week               40
native_country    United-States
income                    <=50K
Name: 0, dtype: object

The prototype around the example #0 is:
============================================
age                          36.0 39.0 41.0
workclass                     Private (50%)
fnlwgt            72314.0 115188.5 138797.0
education                   Bachelors (95%)
education_num                13.0 13.0 13.0
marital_status          Never-married (65%)
occupation               Adm-clerical (70%)
relationship            Not-in-family (75%)
race                            White (95%)
sex                              Male (65%)
capital_gain                    0.0 0.0 0.0
capital_loss                    0.0 0.0 0.0
hours_per_week              38.75 40.0 40.0
native_country         United-States (100%)
income                         <=50K (100%)
dtype: object