TF Lattice Custom Estimators

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

Overview

You can use custom estimators to create arbitrarily monotonic models using TFL layers. This guide outlines the steps needed to create such estimators.

Setup

Installing TF Lattice package:

pip install tensorflow-lattice

Importing required packages:

import tensorflow as tf

import logging
import numpy as np
import pandas as pd
import sys
import tensorflow_lattice as tfl
from tensorflow import feature_column as fc

from tensorflow_estimator.python.estimator.canned import optimizers
from tensorflow_estimator.python.estimator.head import binary_class_head
logging.disable(sys.maxsize)
2023-08-23 11:33:26.737232: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-08-23 11:33:26.737278: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-08-23 11:33:26.737310: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/distributions/distribution.py:259: ReparameterizationType.__init__ (from tensorflow.python.ops.distributions.distribution) is deprecated and will be removed after 2019-01-01.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/distributions/bernoulli.py:165: RegisterKL.__init__ (from tensorflow.python.ops.distributions.kullback_leibler) is deprecated and will be removed after 2019-01-01.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.

Downloading the UCI Statlog (Heart) dataset:

csv_file = tf.keras.utils.get_file(
    'heart.csv', 'http://storage.googleapis.com/download.tensorflow.org/data/heart.csv')
df = pd.read_csv(csv_file)
target = df.pop('target')
train_size = int(len(df) * 0.8)
train_x = df[:train_size]
train_y = target[:train_size]
test_x = df[train_size:]
test_y = target[train_size:]
df.head()

Setting the default values used for training in this guide:

LEARNING_RATE = 0.1
BATCH_SIZE = 128
NUM_EPOCHS = 1000

Feature Columns

As for any other TF estimator, data needs to be passed to the estimator, which is typically via an input_fn and parsed using FeatureColumns.

# Feature columns.
# - age
# - sex
# - ca        number of major vessels (0-3) colored by flourosopy
# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect
feature_columns = [
    fc.numeric_column('age', default_value=-1),
    fc.categorical_column_with_vocabulary_list('sex', [0, 1]),
    fc.numeric_column('ca'),
    fc.categorical_column_with_vocabulary_list(
        'thal', ['normal', 'fixed', 'reversible']),
]

Note that categorical features do not need to be wrapped by a dense feature column, since tfl.laysers.CategoricalCalibration layer can directly consume category indices.

Creating input_fn

As for any other estimator, you can use input_fn to feed data to the model for training and evaluation.

train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(
    x=train_x,
    y=train_y,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    num_threads=1)

test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(
    x=test_x,
    y=test_y,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_epochs=1,
    num_threads=1)

Creating model_fn

There are several ways to create a custom estimator. Here we will construct a model_fn that calls a Keras model on the parsed input tensors. To parse the input features, you can use tf.feature_column.input_layer, tf.keras.layers.DenseFeatures, or tfl.estimators.transform_features. If you use the latter, you will not need to wrap categorical features with dense feature columns, and the resulting tensors will not be concatenated, which makes it easier to use the features in the calibration layers.

To construct a model, you can mix and match TFL layers or any other Keras layers. Here we create a calibrated lattice Keras model out of TFL layers and impose several monotonicity constraints. We then use the Keras model to create the custom estimator.

def model_fn(features, labels, mode, config):
  """model_fn for the custom estimator."""
  del config
  input_tensors = tfl.estimators.transform_features(features, feature_columns)
  inputs = {
      key: tf.keras.layers.Input(shape=(1,), name=key) for key in input_tensors
  }

  lattice_sizes = [3, 2, 2, 2]
  lattice_monotonicities = ['increasing', 'none', 'increasing', 'increasing']
  lattice_input = tf.keras.layers.Concatenate(axis=1)([
      tfl.layers.PWLCalibration(
          input_keypoints=np.linspace(10, 100, num=8, dtype=np.float32),
          # The output range of the calibrator should be the input range of
          # the following lattice dimension.
          output_min=0.0,
          output_max=lattice_sizes[0] - 1.0,
          monotonicity='increasing',
      )(inputs['age']),
      tfl.layers.CategoricalCalibration(
          # Number of categories including any missing/default category.
          num_buckets=2,
          output_min=0.0,
          output_max=lattice_sizes[1] - 1.0,
      )(inputs['sex']),
      tfl.layers.PWLCalibration(
          input_keypoints=[0.0, 1.0, 2.0, 3.0],
          output_min=0.0,
          output_max=lattice_sizes[0] - 1.0,
          # You can specify TFL regularizers as tuple
          # ('regularizer name', l1, l2).
          kernel_regularizer=('hessian', 0.0, 1e-4),
          monotonicity='increasing',
      )(inputs['ca']),
      tfl.layers.CategoricalCalibration(
          num_buckets=3,
          output_min=0.0,
          output_max=lattice_sizes[1] - 1.0,
          # Categorical monotonicity can be partial order.
          # (i, j) indicates that we must have output(i) <= output(j).
          # Make sure to set the lattice monotonicity to 'increasing' for this
          # dimension.
          monotonicities=[(0, 1), (0, 2)],
      )(inputs['thal']),
  ])
  output = tfl.layers.Lattice(
      lattice_sizes=lattice_sizes, monotonicities=lattice_monotonicities)(
          lattice_input)

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  model = tf.keras.Model(inputs=inputs, outputs=output)
  logits = model(input_tensors, training=training)

  if training:
    optimizer = optimizers.get_optimizer_instance_v2('Adagrad', LEARNING_RATE)
  else:
    optimizer = None

  head = binary_class_head.BinaryClassHead()
  return head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      optimizer=optimizer,
      logits=logits,
      trainable_variables=model.trainable_variables,
      update_ops=model.updates)

Training and Estimator

Using the model_fn we can create and train the estimator.

estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=train_input_fn)
results = estimator.evaluate(input_fn=test_input_fn)
print('AUC: {}'.format(results['auc']))
2023-08-23 11:33:29.458817: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
AUC: 0.7857142686843872