Creating a custom Counterfactual Logit Pairing Dataset


Applying Counterfactual Logit Pairing (CLP) to evaluate and improve the fairness of your model requires a counterfactual dataset. You create a counterfactual dataset by duplicating your existing dataset and modifying the copy to add, remove, or replace identity terms. This tutorial explains the approach and techniques for creating a counterfactual dataset from your existing text dataset.

To use your counterfactual dataset with the CLP technique, you create a new data object, CounterfactualPackedInputs, that contains the original_input and the counterfactual_data. It looks like the following:

CounterfactualPackedInputs(
  original_input=(x, y, sample_weight),
  counterfactual_data=(original_x, counterfactual_x,
                       counterfactual_sample_weight)
)

original_input should be the original dataset used to train your Keras model. counterfactual_data should be a tf.data.Dataset that contains the original x value, the corresponding counterfactual_x value, and the counterfactual_sample_weight. The counterfactual_x value is nearly identical to the original value, but with one or more of the sensitive attributes removed or replaced. CLP uses this dataset to add a loss term that pairs the model's outputs (logits) on the original value and on the counterfactual value. The goal is to ensure that the model's prediction doesn't change when only the sensitive attribute is different. original_input and counterfactual_data need to have the same shape. If counterfactual_data has fewer elements than original_input, you can duplicate values from counterfactual_data until the two have the same number of elements.
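The duplication step can be illustrated with a small pure-Python sketch. It does not call TensorFlow Model Remediation. The helper name `repeat_to_length` is hypothetical, and the library's own resizing (for example inside pack_counterfactual_data) may work differently. The sketch simply cycles through the counterfactual triples until the target length is reached.

```python
from itertools import cycle, islice


def repeat_to_length(counterfactual_examples, target_length):
    """Cycle through counterfactual examples until target_length items exist.

    Hypothetical helper for illustration only; not part of
    tensorflow_model_remediation.
    """
    if not counterfactual_examples:
        raise ValueError("counterfactual_examples must not be empty")
    return list(islice(cycle(counterfactual_examples), target_length))


# (original_x, counterfactual_x, counterfactual_sample_weight) triples.
counterfactual_examples = [
    ("I am a gay man", "I am a  man", 1.0),
    ("gay people are welcome", " people are welcome", 1.0),
]

# Pretend original_input has 5 elements.
resized = repeat_to_length(counterfactual_examples, 5)
print(len(resized))    # 5
print(resized[2][0])   # I am a gay man
```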

counterfactual_data must have the following properties:

  • Every original_x value references an identity group
  • Each counterfactual_x value is identical to its original value, except that one or more of the attributes are removed or replaced
  • It has the same shape as original_input (you can duplicate values so that the shapes match)
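The first two properties can be checked for each pair before training. The following pure-Python sketch shows one way to do this. The function name `check_counterfactual_example` and the identity-term list are illustrative assumptions, not part of the library. The shape property is handled separately by duplicating values.

```python
import re


def check_counterfactual_example(original_x, counterfactual_x, identity_terms):
    """Return a list of problems found in one (original_x, counterfactual_x) pair.

    Hypothetical helper; the checks mirror the listed properties of
    counterfactual_data.
    """
    problems = []
    # Match identity terms as whole words, ignoring case.
    pattern = re.compile(
        r"\b(" + "|".join(re.escape(term) for term in identity_terms) + r")\b",
        re.IGNORECASE)
    if not pattern.search(original_x):
        problems.append("original_x does not reference an identity term")
    if counterfactual_x == original_x:
        problems.append("counterfactual_x is unchanged")
    return problems


identity_terms = ["gay", "lesbian"]
print(check_counterfactual_example("I am a gay man", "I am a man", identity_terms))
# []
print(check_counterfactual_example("I am a man", "I am a man", identity_terms))
# ['original_x does not reference an identity term', 'counterfactual_x is unchanged']
```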

counterfactual_data does not need to:

  • Overlap with the data in original_input
  • Have ground truth labels

Here's an example of what counterfactual_data looks like if you remove the term "gay":

original_x: "I am a gay man"
counterfactual_x: "I am a man"
counterfactual_sample_weight: 1

If you have a text classifier, you can use build_counterfactual_data to help create a counterfactual dataset. For all other data types, you need to provide a counterfactual dataset directly.

Setup

You'll begin by installing TensorFlow Model Remediation.

pip install --upgrade tensorflow-model-remediation

import tensorflow as tf
from tensorflow_model_remediation import counterfactual

Create a simple Dataset

For demonstration purposes, you'll create counterfactual data from the original input using build_counterfactual_data. You can also construct counterfactual data from unlabeled data instead of from the original input. The simple dataset you create contains 20 sentences: ten variations of "I am a gay man" and ten of "I am a man". It serves as the original_input.

Build a Counterfactual Dataset

Because this example uses a text classifier, you can create the counterfactual dataset with build_counterfactual_data in two ways:

  1. Remove terms: Pass build_counterfactual_data a list of words to remove from the dataset with tf.strings.regex_replace.
  2. Replace terms: Write a custom function and pass it to build_counterfactual_data. The function might use more specific regex patterns to replace words in your original dataset, or it might support non-text features.
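The following rough pure-Python analogue shows what the remove-terms approach does to each example. Identity terms are deleted with a regular expression, and only examples that actually changed are kept. The sketch uses Python's `re` module rather than `tf.strings.regex_replace`, which uses RE2 syntax. For simple literal terms like these the two behave the same, but this is not the library's implementation. The name `remove_terms` is hypothetical.

```python
import re


def remove_terms(sentences, terms_to_remove):
    """Delete each term and keep only (original, counterfactual) pairs that changed."""
    pattern = "|".join(re.escape(term) for term in terms_to_remove)
    pairs = []
    for original in sentences:
        counterfactual = re.sub(pattern, "", original)
        if counterfactual != original:  # drop examples with no identity term
            pairs.append((original, counterfactual))
    return pairs


# The same 20 sentences as the tutorial's simple dataset.
sentences = (["I am a gay man" + str(i) for i in range(10)] +
             ["I am a man" + str(i) for i in range(10)])
pairs = remove_terms(sentences, ["gay"])
print(len(sentences))  # 20
print(pairs[0])        # ('I am a gay man0', 'I am a  man0')
print(len(pairs))      # 10
```

The counts match the build_counterfactual_data output shown for option 1. The ten sentences without "gay" are filtered out, and removing the term leaves a double space behind.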

build_counterfactual_data takes original_input and either removes or replaces terms, depending on which optional parameters you pass. In most cases, removing terms (option 1) is sufficient to run CLP. Passing a custom function (option 2) gives you more precise control over the counterfactual values.

Option 1: List of Words to Remove

Pass a list of identity terms to remove to build_counterfactual_data.

When you use a simple regex to create the counterfactual dataset, keep in mind that it may also alter words that shouldn't be changed. It's good practice to check that each counterfactual_x value makes sense in the context of its original_x value. Additionally, build_counterfactual_data returns only the examples that contain a counterfactual instance, that is, examples that were actually changed. The resulting dataset can therefore have a different shape from original_input. It is resized when passed to pack_counterfactual_data.
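The following pure-Python sketch (using `re`, not TensorFlow) illustrates the regex pitfall. A naive pattern such as "man" also matches inside "woman" and "manager". Word boundaries (`\b`) limit the match to the standalone word. RE2, the syntax used by tf.strings.regex_replace, also supports `\b`, but verify any pattern against your own data. For example, `\bman\b` would not match the tutorial's synthetic "man0" tokens.

```python
import re

sentence = "A man and a woman met the manager"

# Naive pattern: "man" also matches inside "woman" and "manager".
naive = re.sub("man", "woman", sentence)
print(naive)    # A woman and a wowoman met the womanager

# Word boundaries (\b) restrict the match to "man" as a standalone word.
bounded = re.sub(r"\bman\b", "woman", sentence)
print(bounded)  # A woman and a woman met the manager

# Removing a term can leave a double space behind; collapse it if needed.
removed = re.sub(r"\bgay\b", "", "I am a gay man")
cleaned = re.sub(r"\s+", " ", removed).strip()
print(repr(removed))  # 'I am a  man'
print(cleaned)        # I am a man
```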

simple_dataset_x = tf.constant(
    ["I am a gay man" + str(i) for i in range(10)] +
    ["I am a man" + str(i) for i in range(10)])
print("Length of starting values: " + str(len(simple_dataset_x)))

simple_dataset = tf.data.Dataset.from_tensor_slices(
    (simple_dataset_x, None, None))

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'])

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value, _ in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))

Length of starting values: 20
original:  tf.Tensor(b'I am a gay man0', shape=(), dtype=string)
counterfactual:  tf.Tensor(b'I am a  man0', shape=(), dtype=string)
Length of dataset after build_counterfactual_data: 10

Option 2: Custom Function

For more flexibility around ways of modifying your original dataset, you can instead pass a custom function to build_counterfactual_data.

In this example, you replace identity terms that reference men with terms that reference women. To do this, you write a function that replaces each word in a dictionary.

The custom function has two requirements. First, it must be a callable that accepts and returns a tuple in the format used by Model.fit. Second, values that contain no changes should be removed. You can remove unchanged values by also passing the relevant terms to sensitive_terms_to_remove.
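The following pure-Python sketch shows the logic such a custom function typically needs. It assumes plain strings rather than tensors. The names `replace_words` and `build_pairs` and the replacement table are hypothetical. It applies all replacements in a single regex pass, so that swaps in both directions (man→woman and woman→man) don't chain into each other. It also drops examples that didn't change.

```python
import re

# Hypothetical replacement table; swapping in both directions is intentional.
words_to_replace = {"man": "woman", "woman": "man", "he": "she", "she": "he"}


def replace_words(sentence, replacements):
    """Replace every listed whole word in a single regex pass."""
    pattern = r"\b(" + "|".join(re.escape(w) for w in replacements) + r")\b"
    return re.sub(pattern, lambda match: replacements[match.group(0)], sentence)


def build_pairs(sentences, replacements, weight=1.0):
    """Return (original_x, counterfactual_x, weight) for sentences that changed."""
    pairs = []
    for original in sentences:
        counterfactual = replace_words(original, replacements)
        if counterfactual != original:  # drop examples with nothing to swap
            pairs.append((original, counterfactual, weight))
    return pairs


pairs = build_pairs(["he met a woman", "the weather is nice"], words_to_replace)
print(pairs)  # [('he met a woman', 'she met a man', 1.0)]
```

Applying the replacements one after another would turn "woman" into "man" and then back into "woman". The single combined pattern avoids that.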

words_to_replace = {"man": "woman"}
print("Length of starting values: " + str(len(simple_dataset_x)))

def replace_words(original_batch):
  original_x, _, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(original_batch))
  # Apply each replacement to the running result so that every entry in
  # words_to_replace takes effect, not only the last one.
  counterfactual_x = original_x
  for word, replacement in words_to_replace.items():
    counterfactual_x = tf.strings.regex_replace(
        counterfactual_x, word, replacement)
  return tf.keras.utils.pack_x_y_sample_weight(
      original_x, counterfactual_x, sample_weight=original_sample_weight)

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'],
    custom_counterfactual_function=replace_words)

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))

Length of starting values: 20
original:  tf.Tensor(b'I am a gay man0', shape=(), dtype=string)
counterfactual:  tf.Tensor(b'I am a gay man0', shape=(), dtype=string)
Length of dataset after build_counterfactual_data: 10

To learn more, see the API documentation for build_counterfactual_data.