
TF Lattice Premade Models


Overview

Premade Models are a quick and easy way to build TFL tf.keras.Model instances for typical use cases. This guide outlines the steps needed to construct a TFL Premade Model and then train and test it.

Setup

Installing TF Lattice package:

pip install -q tensorflow-lattice pydot

Importing required packages:

from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

import logging
import numpy as np
import pandas as pd
import sys
import tensorflow_lattice as tfl
from tensorflow import feature_column as fc
logging.disable(sys.maxsize)

Downloading the UCI Statlog (Heart) dataset:

csv_file = tf.keras.utils.get_file(
    'heart.csv', 'http://storage.googleapis.com/applied-dl/heart.csv')
df = pd.read_csv(csv_file)
train_size = int(len(df) * 0.8)
train_dataframe = df[:train_size]
test_dataframe = df[train_size:]
df.head()

Extract and convert features and labels to tensors:

# Features:
# - age
# - sex
# - cp        chest pain type (4 values)
# - trestbps  resting blood pressure
# - chol      serum cholesterol in mg/dl
# - fbs       fasting blood sugar > 120 mg/dl
# - restecg   resting electrocardiographic results (values 0,1,2)
# - thalach   maximum heart rate achieved
# - exang     exercise induced angina
# - oldpeak   ST depression induced by exercise relative to rest
# - slope     the slope of the peak exercise ST segment
# - ca        number of major vessels (0-3) colored by fluoroscopy
# - thal      3 = normal; 6 = fixed defect; 7 = reversible defect
#
# This ordering of feature names will be the exact same order that we construct
# our model to expect.
feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
feature_name_indices = {name: index for index, name in enumerate(feature_names)}
# This is the vocab list and mapping we will use for the 'thal' categorical
# feature.
thal_vocab_list = ['normal', 'fixed', 'reversible']
thal_map = {category: i for i, category in enumerate(thal_vocab_list)}
# Custom function for converting thal categories to buckets
def convert_thal_features(thal_features):
  # Note that two examples in the test set are already converted.
  return np.array([
      thal_map[feature] if feature in thal_vocab_list else feature
      for feature in thal_features
  ])


# Custom function for extracting each feature.
def extract_features(dataframe,
                     label_name='target',
                     feature_names=feature_names):
  features = []
  for feature_name in feature_names:
    if feature_name == 'thal':
      features.append(
          convert_thal_features(dataframe[feature_name].values).astype(float))
    else:
      features.append(dataframe[feature_name].values.astype(float))
  labels = dataframe[label_name].values.astype(float)
  return features, labels
train_xs, train_ys = extract_features(train_dataframe)
test_xs, test_ys = extract_features(test_dataframe)
# Let's define our label minimum and maximum.
min_label, max_label = float(np.min(train_ys)), float(np.max(train_ys))
# Our lattice models may have predictions above 1.0 due to numerical errors.
# We can subtract this small epsilon value from our output_max to make sure we
# do not predict values outside of our label bound.
numerical_error_epsilon = 1e-5

Setting the default values used for training in this guide:

LEARNING_RATE = 0.01
BATCH_SIZE = 128
NUM_EPOCHS = 500

Feature Configs

Feature calibration and per-feature configurations are set using tfl.configs.FeatureConfig. Feature configurations include monotonicity constraints, per-feature regularization (see tfl.configs.RegularizerConfig), and lattice sizes for lattice models.

Note that we must fully specify the feature config for any feature that we want our model to recognize. Otherwise the model will have no way of knowing that such a feature exists.

Compute Quantiles

Although the default setting for pwl_calibration_input_keypoints in tfl.configs.FeatureConfig is 'quantiles', for premade models we have to define the input keypoints manually. To do so, we first define our own helper function for computing quantiles.

def compute_quantiles(features,
                      num_keypoints=10,
                      clip_min=None,
                      clip_max=None,
                      missing_value=None):
  # Clip min and max if desired.
  if clip_min is not None:
    features = np.maximum(features, clip_min)
    features = np.append(features, clip_min)
  if clip_max is not None:
    features = np.minimum(features, clip_max)
    features = np.append(features, clip_max)
  # Make features unique.
  unique_features = np.unique(features)
  # Remove missing values if specified.
  if missing_value is not None:
    unique_features = np.delete(unique_features,
                                np.where(unique_features == missing_value))
  # Compute and return quantiles over unique non-missing feature values.
  return np.quantile(
      unique_features,
      np.linspace(0., 1., num=num_keypoints),
      interpolation='nearest').astype(float)
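
To see what the input keypoints are used for, here is a minimal sketch of a piecewise-linear calibrator in plain Python. It is an illustration only, not the TFL implementation: the function name pwl_calibrate and all keypoint values are hypothetical, and it assumes that inputs outside the keypoint range are clamped to the end values.

```python
from bisect import bisect_right


def pwl_calibrate(x, input_keypoints, output_keypoints):
  """Linearly interpolates between (input, output) keypoint pairs."""
  # Assumption: inputs outside the keypoint range map to the end outputs.
  if x <= input_keypoints[0]:
    return output_keypoints[0]
  if x >= input_keypoints[-1]:
    return output_keypoints[-1]
  # Index of the segment [x0, x1) that contains x.
  i = bisect_right(input_keypoints, x) - 1
  x0, x1 = input_keypoints[i], input_keypoints[i + 1]
  y0, y1 = output_keypoints[i], output_keypoints[i + 1]
  return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


# Hypothetical keypoints for an 'age'-like feature. The output values are made
# up; in a trained model they are learned. They are increasing here, which is
# the shape a monotonicity='increasing' constraint would enforce.
age_input_keypoints = [29.0, 45.0, 54.0, 61.0, 77.0]
age_output_keypoints = [0.0, 0.2, 0.5, 0.8, 1.0]

print(pwl_calibrate(50.0, age_input_keypoints, age_output_keypoints))
```

Placing the input keypoints at quantiles gives each linear segment a similar share of the training examples, so the calibrator has more resolution where the data is dense.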

Defining Our Feature Configs

Now that we can compute our quantiles, we define a feature config for each feature that we want our model to take as input.

# Feature configs are used to specify how each feature is calibrated and used.
feature_configs = [
    tfl.configs.FeatureConfig(
        name='age',
        lattice_size=3,
        monotonicity='increasing',
        # We must set the keypoints manually.
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['age']],
            num_keypoints=5,
            clip_max=100),
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_wrinkle', l2=0.1),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='sex',
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='cp',
        monotonicity='increasing',
        # Keypoints that are uniformly spaced.
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=np.linspace(
            np.min(train_xs[feature_name_indices['cp']]),
            np.max(train_xs[feature_name_indices['cp']]),
            num=4),
    ),
    tfl.configs.FeatureConfig(
        name='chol',
        monotonicity='increasing',
        # Explicit input keypoints initialization.
        pwl_calibration_input_keypoints=[126.0, 210.0, 247.0, 286.0, 564.0],
        # Calibration can be forced to span the full output range by clamping.
        pwl_calibration_clamp_min=True,
        pwl_calibration_clamp_max=True,
        # Per feature regularization.
        regularizer_configs=[
            tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
        ],
    ),
    tfl.configs.FeatureConfig(
        name='fbs',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='trestbps',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['trestbps']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='thalach',
        monotonicity='decreasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['thalach']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='restecg',
        # Partial monotonicity: output(0) <= output(1), output(0) <= output(2)
        monotonicity=[(0, 1), (0, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='exang',
        # Partial monotonicity: output(0) <= output(1)
        monotonicity=[(0, 1)],
        num_buckets=2,
    ),
    tfl.configs.FeatureConfig(
        name='oldpeak',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=5,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['oldpeak']], num_keypoints=5),
    ),
    tfl.configs.FeatureConfig(
        name='slope',
        # Partial monotonicity: output(0) <= output(1), output(1) <= output(2)
        monotonicity=[(0, 1), (1, 2)],
        num_buckets=3,
    ),
    tfl.configs.FeatureConfig(
        name='ca',
        monotonicity='increasing',
        pwl_calibration_num_keypoints=4,
        pwl_calibration_input_keypoints=compute_quantiles(
            train_xs[feature_name_indices['ca']], num_keypoints=4),
    ),
    tfl.configs.FeatureConfig(
        name='thal',
        # Partial monotonicity:
        # output(normal) <= output(fixed)
        # output(normal) <= output(reversible)
        monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],
        num_buckets=3,
        # We must specify the vocabulary list in order to later set the
        # monotonicities since we used names and not indices.
        vocabulary_list=thal_vocab_list,
    ),
]

Next we need to make sure to properly set the monotonicities for features where we used a custom vocabulary (such as 'thal' above).

tfl.premade_lib.set_categorical_monotonicities(feature_configs)
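
Conceptually, this step turns monotonicity pairs written with category names into pairs of bucket indices, using the position of each name in the vocabulary list. The sketch below illustrates that idea in plain Python; to_index_pairs is a hypothetical helper written for this guide, not the library's code.

```python
def to_index_pairs(monotonicity_pairs, vocabulary_list):
  """Maps (category, category) name pairs to (index, index) pairs."""
  index_of = {category: i for i, category in enumerate(vocabulary_list)}
  return [(index_of[lower], index_of[higher])
          for lower, higher in monotonicity_pairs]


thal_vocab = ['normal', 'fixed', 'reversible']
thal_pairs = [('normal', 'fixed'), ('normal', 'reversible')]
index_pairs = to_index_pairs(thal_pairs, thal_vocab)
print(index_pairs)  # [(0, 1), (0, 2)]
```

The resulting index pairs have the same form as the ones given directly for 'fbs', 'restecg', 'exang', and 'slope' above.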

Calibrated Linear Model

To construct a TFL premade model, first construct a model configuration from tfl.configs. A calibrated linear model is constructed using tfl.configs.CalibratedLinearConfig. It applies piecewise-linear and categorical calibration on the input features, followed by a linear combination and an optional output piecewise-linear calibration. When using output calibration or when output bounds are specified, the linear layer will apply weighted averaging on calibrated inputs.
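
As a rough illustration of that structure (not the library's implementation), the forward pass can be sketched in plain Python as "calibrate each input, then take a weighted average". The function calibrated_linear_predict, the two calibrators, and the weights are all hypothetical, and the optional output calibration is left out.

```python
def calibrated_linear_predict(inputs, calibrators, weights):
  """Calibrates each input, then takes a weighted average of the results."""
  calibrated = [calibrate(x) for calibrate, x in zip(calibrators, inputs)]
  # Normalizing non-negative weights keeps the output within the range of
  # the calibrated values.
  return sum(w * c for w, c in zip(weights, calibrated)) / sum(weights)


# Hypothetical calibrators that map each raw feature into [0, 1].
calibrators = [
    # Numeric feature: a single clamped linear segment from 29 to 77.
    lambda age: min(max((age - 29.0) / (77.0 - 29.0), 0.0), 1.0),
    # Categorical feature: one made-up output value per bucket.
    lambda sex: [0.3, 0.6][int(sex)],
]
weights = [3.0, 1.0]  # Made-up, non-negative.

prediction = calibrated_linear_predict([53.0, 1.0], calibrators, weights)
print(prediction)
```

Because every calibrated value lies in [0, 1] and the weights are averaged, the prediction also stays in [0, 1], which is why bounded outputs pair naturally with weighted averaging.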

The following code creates a calibrated linear model on the first 5 features.

# The model config defines the structure of the premade model.
linear_model_config = tfl.configs.CalibratedLinearConfig(
    feature_configs=feature_configs[:5],
    use_bias=True,
    # We must set the output min and max to that of the label.
    output_min=min_label,
    output_max=max_label,
    output_calibration=True,
    output_initialization=[min_label, max_label],
    regularizer_configs=[
        # Regularizer for the output calibrator.
        tfl.configs.RegularizerConfig(name='output_calib_hessian', l2=1e-4),
    ])
# A CalibratedLinear premade model constructed from the given model config.
linear_model = tfl.premade.CalibratedLinear(linear_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(linear_model, show_layer_names=False, rankdir='LR')


Now, as with any other tf.keras.Model, we compile and fit the model to our data.

linear_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
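# The model was built from feature_configs[:5], so pass only the matching
# first five feature arrays.
linear_model.fit(
    train_xs[:5], train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
    verbose=False)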
<tensorflow.python.keras.callbacks.History at 0x7ff3900e6048>

After training our model, we can evaluate it on our test set.

print('Test Set Evaluation...')
# Evaluate on the same first five features the model was built from.
print(linear_model.evaluate(test_xs[:5], test_ys))
Test Set Evaluation...
61/61 [==============================] - 0s 6ms/sample - loss: 0.4779 - auc: 0.8340
[0.47787934248564673, 0.83395994]
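
The two numbers returned by evaluate are the loss followed by the compiled metric (here, AUC). To help interpret the loss value, here is a minimal plain-Python sketch of mean binary cross-entropy. The clipping constant eps is an assumption chosen to keep the logarithm finite, and the labels and predictions are made up.

```python
import math


def binary_crossentropy(labels, predictions, eps=1e-7):
  """Mean binary cross-entropy over (label, prediction) pairs."""
  total = 0.0
  for y, p in zip(labels, predictions):
    # Clip so that log(0) never occurs for predictions of exactly 0 or 1.
    p = min(max(p, eps), 1.0 - eps)
    total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
  return total / len(labels)


# Made-up labels and predicted probabilities.
loss = binary_crossentropy([1.0, 0.0, 1.0], [0.9, 0.2, 0.6])
print(loss)
```

A model that always predicts 0.5 scores log(2), about 0.693, so a test loss below that indicates the model is doing better than an uninformed guess.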

Calibrated Lattice Model

A calibrated lattice model is constructed using tfl.configs.CalibratedLatticeConfig. A calibrated lattice model applies piecewise-linear and categorical calibration on the input features, followed by a lattice model and an optional output piecewise-linear calibration.
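
To give a feel for what the lattice step computes, the sketch below interpolates over a small two-dimensional grid of vertex values in plain Python. It is an illustration under stated assumptions, not the TFL layer: lattice_2d and the vertex values are hypothetical, multilinear interpolation is assumed, and the calibrated inputs are assumed to already lie within the grid.

```python
import math


def lattice_2d(x, y, params):
  """Bilinear interpolation over a 2-D grid of vertex values.

  params[i][j] is the value at grid vertex (i, j). x and y are calibrated
  inputs, assumed to lie in [0, rows - 1] and [0, cols - 1] respectively.
  """
  rows, cols = len(params), len(params[0])
  # Lower corner of the cell containing (x, y), kept inside the grid.
  i = min(int(math.floor(x)), rows - 2)
  j = min(int(math.floor(y)), cols - 2)
  dx, dy = x - i, y - j
  return (params[i][j] * (1 - dx) * (1 - dy)
          + params[i + 1][j] * dx * (1 - dy)
          + params[i][j + 1] * (1 - dx) * dy
          + params[i + 1][j + 1] * dx * dy)


# Made-up vertex values for a 3x2 lattice: lattice size 3 for the first
# feature and 2 for the second. In a trained model these are learned.
params = [[0.0, 0.2],
          [0.3, 0.6],
          [0.5, 1.0]]
value = lattice_2d(1.5, 0.5, params)
print(value)
```

A lattice over n features has one parameter per grid vertex, so its size grows multiplicatively with each feature's lattice size. That growth is the reason the ensemble models later in this guide split the features across several small lattices.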

The following code creates a calibrated lattice model on the first 5 features.

# This is a calibrated lattice model: inputs are calibrated, then combined
# non-linearly using a lattice layer.
lattice_model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=feature_configs[:5],
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label],
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-2),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-2),
    ])
# A CalibratedLattice premade model constructed from the given model config.
lattice_model = tfl.premade.CalibratedLattice(lattice_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(lattice_model, show_layer_names=False, rankdir='LR')


As before, we compile, fit, and evaluate our model.

lattice_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
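# The model was built from feature_configs[:5], so pass only the matching
# first five feature arrays when fitting and evaluating.
lattice_model.fit(
    train_xs[:5], train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
    verbose=False)
print('Test Set Evaluation...')
print(lattice_model.evaluate(test_xs[:5], test_ys))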
Test Set Evaluation...
61/61 [==============================] - 0s 8ms/sample - loss: 0.4782 - auc_1: 0.8402
[0.4782405182963512, 0.8402256]

Calibrated Lattice Ensemble Model

When the number of features is large, you can use an ensemble model, which creates multiple smaller lattices for subsets of the features and averages their outputs instead of creating a single huge lattice. Ensemble lattice models are constructed using tfl.configs.CalibratedLatticeEnsembleConfig. A calibrated lattice ensemble model applies piecewise-linear and categorical calibration on the input features, followed by an ensemble of lattice models and an optional output piecewise-linear calibration.

Explicit Lattice Ensemble Initialization

If you already know which subsets of features you want to feed into your lattices, then you can explicitly set the lattices using feature names. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
explicit_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices=[['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],
              ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],
              ['restecg', 'age', 'sex']],
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label])
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
explicit_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    explicit_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    explicit_ensemble_model, show_layer_names=False, rankdir='LR')


As before, we compile, fit, and evaluate our model.

explicit_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
explicit_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(explicit_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
61/61 [==============================] - 0s 7ms/sample - loss: 0.4334 - auc_2: 0.8709
[0.433376672815104, 0.87092733]
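
When writing the lattices list by hand, it can help to sanity-check it before building the model. The plain-Python sketch below is an optional check written for this guide, not part of TFL. It confirms that the list matches num_lattices and lattice_rank, and reports which features are unused or used more than once.

```python
from collections import Counter

feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
lattices = [['trestbps', 'chol', 'ca'], ['fbs', 'restecg', 'thal'],
            ['fbs', 'cp', 'oldpeak'], ['exang', 'slope', 'thalach'],
            ['restecg', 'age', 'sex']]
num_lattices, lattice_rank = 5, 3

# The explicit structure should agree with the declared ensemble shape.
assert len(lattices) == num_lattices
assert all(len(lattice) == lattice_rank for lattice in lattices)

# Count how often each feature appears across all lattices.
usage = Counter(name for lattice in lattices for name in lattice)
unused = [name for name in feature_names if name not in usage]
reused = sorted(name for name, count in usage.items() if count > 1)
print('unused:', unused)  # unused: []
print('reused:', reused)  # reused: ['fbs', 'restecg']
```

Here all 13 features are covered. Because 5 lattices of rank 3 provide 15 slots, two features ('fbs' and 'restecg') appear in two lattices each.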

Random Lattice Ensemble

If you are not sure which subsets of features to feed into your lattices, another option is to use random subsets of features for each lattice. This example creates a calibrated lattice ensemble model with 5 lattices and 3 features per lattice.

# This is a calibrated lattice ensemble model: inputs are calibrated, then
# combined non-linearly and averaged using multiple lattice layers.
random_ensemble_model_config = tfl.configs.CalibratedLatticeEnsembleConfig(
    feature_configs=feature_configs,
    lattices='random',
    num_lattices=5,
    lattice_rank=3,
    output_min=min_label,
    output_max=max_label - numerical_error_epsilon,
    output_initialization=[min_label, max_label])
# Now we must set the random lattice structure and construct the model.
tfl.premade_lib.set_random_lattice_ensemble(random_ensemble_model_config)
# A CalibratedLatticeEnsemble premade model constructed from the given
# model config.
random_ensemble_model = tfl.premade.CalibratedLatticeEnsemble(
    random_ensemble_model_config)
# Let's plot our model.
tf.keras.utils.plot_model(
    random_ensemble_model, show_layer_names=False, rankdir='LR')


As before, we compile, fit, and evaluate our model.

random_ensemble_model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.keras.metrics.AUC()],
    optimizer=tf.keras.optimizers.Adam(LEARNING_RATE))
random_ensemble_model.fit(
    train_xs, train_ys, epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, verbose=False)
print('Test Set Evaluation...')
print(random_ensemble_model.evaluate(test_xs, test_ys))
Test Set Evaluation...
61/61 [==============================] - 0s 7ms/sample - loss: 0.3838 - auc_3: 0.8947
[0.3837592499177964, 0.8947368]
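
For intuition about what a random lattice structure looks like, the sketch below draws feature subsets with Python's standard random module. It is one simple way to pick such subsets and is not necessarily how tfl.premade_lib.set_random_lattice_ensemble chooses them. The function random_lattices and its seed parameter are hypothetical.

```python
import random


def random_lattices(feature_names, num_lattices, lattice_rank, seed=0):
  """Draws `lattice_rank` distinct features at random for each lattice."""
  rng = random.Random(seed)
  return [rng.sample(feature_names, lattice_rank)
          for _ in range(num_lattices)]


feature_names = [
    'age', 'sex', 'cp', 'chol', 'fbs', 'trestbps', 'thalach', 'restecg',
    'exang', 'oldpeak', 'slope', 'ca', 'thal'
]
sampled_lattices = random_lattices(feature_names, num_lattices=5,
                                   lattice_rank=3)
for lattice in sampled_lattices:
  print(lattice)
```

Note that this naive sampling does not guarantee that every feature ends up in at least one lattice, so a check for unused features is worthwhile when adapting it.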