Model Remediation Case Study

In this notebook, we’ll train a text classifier to identify written content that could be considered toxic or harmful, and apply MinDiff to remediate some fairness concerns. In our workflow, we will:

  1. Evaluate our baseline model’s performance on text containing references to sensitive groups.
  2. Improve performance on any underperforming groups by training with MinDiff.
  3. Evaluate the new model’s performance on our chosen metric.

Our purpose is to demonstrate usage of the MinDiff technique with a very minimal workflow, not to lay out a principled approach to fairness in machine learning. As such, our evaluation will only focus on one sensitive category and a single metric. We also don’t address potential shortcomings in the dataset, nor tune our configurations. In a production setting, you would want to approach each of these with rigor. For more information on evaluating for fairness, see this guide.


We begin by installing Fairness Indicators and TensorFlow Model Remediation.


Import all necessary components, including MinDiff and Fairness Indicators for evaluation.

import copy
import os
import requests
import tempfile
import zipfile

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_model_remediation.min_diff as md
from tensorflow_model_remediation.tools.tutorials_utils import min_diff_keras_utils
from fairness_indicators.tutorial_utils import util as fi_util


We use a utility function to download the preprocessed data and prepare the labels to match the model’s output shape. The function also downloads the data as TFRecords to make later evaluation quicker. Alternatively, you may convert the Pandas DataFrame into TFRecords with any available utility conversion function.

# We use a helper utility to download the preprocessed data for convenience and speed.
data_train, data_validate, validate_tfrecord_file, labels_train, labels_validate = min_diff_keras_utils.download_and_process_civil_comments_data()
Downloading data from https://storage.googleapis.com/civil_comments_dataset/train_df_processed.csv
345702400/345699197 [==============================] - 8s 0us/step
Downloading data from https://storage.googleapis.com/civil_comments_dataset/validate_df_processed.csv
229974016/229970098 [==============================] - 5s 0us/step
Downloading data from https://storage.googleapis.com/civil_comments_dataset/validate_tf_processed.tfrecord
324943872/324941336 [==============================] - 9s 0us/step

We define a few useful constants. We will train the model on the 'comment_text' feature, with our target label as 'toxicity'. Note that the batch size here is chosen arbitrarily, but in a production setting you would need to tune it for best performance.

TEXT_FEATURE = 'comment_text'
LABEL = 'toxicity'
BATCH_SIZE = 512  # Chosen arbitrarily; tune for your own setup.

Set random seeds. (Note that this does not fully stabilize results.)
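
The seed-setting cell itself is not shown here. A minimal sketch of what it might look like follows. The `set_seeds` helper and the seed value are assumptions made for illustration, and the NumPy and TensorFlow imports are guarded so that the snippet also runs in environments where they are not installed.

```python
import random


def set_seeds(seed):
    """Seed Python's RNG, plus NumPy's and TensorFlow's when available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is not installed; nothing to seed.
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)
    except ImportError:
        pass  # TensorFlow is not installed; nothing to seed.


set_seeds(1)  # The seed value is arbitrary (an assumption for this sketch).
```

Even with all three generators seeded, sources of nondeterminism such as parallel data loading or GPU kernels can still cause run-to-run differences, which is why seeding does not fully stabilize results.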


Define and train the baseline model

To reduce runtime, we use a pretrained model by default. It is a simple Keras sequential model with an initial embedding and convolution layers, outputting a toxicity prediction. If you prefer, you can change this and train from scratch using our utility function to create the model. (Note that since your environment is likely different from ours, you would need to customize the tuning and evaluation thresholds.)

use_pretrained_model = True

if use_pretrained_model:
  URL = 'https://storage.googleapis.com/civil_comments_model/baseline_model.zip'
  BASE_PATH = tempfile.mkdtemp()
  ZIP_PATH = os.path.join(BASE_PATH, 'baseline_model.zip')
  MODEL_PATH = os.path.join(BASE_PATH, 'tmp/baseline_model')

  # Download the zipped SavedModel.
  r = requests.get(URL, allow_redirects=True)
  with open(ZIP_PATH, 'wb') as f:
    f.write(r.content)

  # Extract the archive so the model can be loaded from MODEL_PATH.
  with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
    zip_ref.extractall(BASE_PATH)
  baseline_model = tf.keras.models.load_model(
      MODEL_PATH, custom_objects={'KerasLayer' : hub.KerasLayer})
else:
  # Train the baseline from scratch instead of downloading it.
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  loss = tf.keras.losses.BinaryCrossentropy()

  baseline_model = min_diff_keras_utils.create_keras_sequential_model()

  baseline_model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

  baseline_model.fit(x=data_train[TEXT_FEATURE],
                     y=labels_train,
                     batch_size=BATCH_SIZE,
                     epochs=20)


We save the model in order to evaluate using Fairness Indicators.

base_dir = tempfile.mkdtemp(prefix='saved_models')
baseline_model_location = os.path.join(base_dir, 'model_export_baseline')
baseline_model.save(baseline_model_location, save_format='tf')
INFO:tensorflow:Assets written to: /tmp/saved_models867b8d74/model_export_baseline/assets

Next we run Fairness Indicators. As a reminder, we’re just going to perform sliced evaluation for comments referencing one category, religious groups. In a production environment, we recommend taking a thoughtful approach to determining which categories and metrics to evaluate across.

To compute model performance, the utility function makes a few convenient choices for metrics, slices, and classifier thresholds.

# We use a helper utility to hide the evaluation logic for readability.
base_dir = tempfile.mkdtemp(prefix='eval')
eval_dir = os.path.join(base_dir, 'tfma_eval_result')
eval_result = fi_util.get_eval_results(
    baseline_model_location, eval_dir, validate_tfrecord_file)
WARNING:absl:Tensorflow version (2.5.0) found. Note that TFMA support for TF 2.0 is currently in beta
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py:113: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 

Render Evaluation Results


Let’s look at the evaluation results. Try selecting the metric false positive rate (FPR) with threshold 0.450. We can see that the model does not perform as well for some religious groups as for others, displaying a much higher FPR. Note the wide confidence intervals on some groups because they have too few examples. This makes it difficult to say with certainty that there is a significant difference in performance for these slices. We may want to collect more examples to address this issue. We can, however, attempt to apply MinDiff for the two groups that we are confident are underperforming.

We’ve chosen to focus on FPR, because a higher FPR means that comments referencing these identity groups are more likely to be incorrectly flagged as toxic than other comments. This could lead to inequitable outcomes for users engaging in dialogue about religion, but note that disparities in other metrics can lead to other types of harm.
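
To make the metric concrete, here is a minimal, self-contained sketch of how a per-slice false positive rate could be computed at a fixed threshold. The toy records, slice names, and helper functions are hypothetical. In the notebook itself, Fairness Indicators computes these values (along with confidence intervals) for you.

```python
from collections import defaultdict


def false_positive_rate(labels, scores, threshold):
    """FPR = FP / (FP + TN), computed over ground-truth negatives only."""
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not negatives:
        return float('nan')
    false_positives = sum(1 for s in negatives if s > threshold)
    return false_positives / len(negatives)


def fpr_by_slice(records, threshold):
    """Group (slice, label, score) records by slice and compute each FPR."""
    grouped = defaultdict(lambda: ([], []))
    for slice_name, label, score in records:
        grouped[slice_name][0].append(label)
        grouped[slice_name][1].append(score)
    return {name: false_positive_rate(labels, scores, threshold)
            for name, (labels, scores) in grouped.items()}


# Hypothetical (slice, ground-truth label, predicted score) records.
toy_records = [
    ('group_a', 0, 0.10), ('group_a', 0, 0.60),
    ('group_a', 1, 0.90), ('group_a', 0, 0.20),
    ('group_b', 0, 0.50), ('group_b', 0, 0.70),
    ('group_b', 1, 0.80), ('group_b', 0, 0.30),
]
slice_fprs = fpr_by_slice(toy_records, threshold=0.45)
```

In this toy data, one of three non-toxic `group_a` comments is flagged, versus two of three for `group_b`, so `group_b` has the higher FPR. Predictions on toxic comments do not enter the calculation at all.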

Define and Train the MinDiff Model

Now, we’ll try to improve the FPR for underperforming religious groups. We’ll attempt to do so using MinDiff, a remediation technique that seeks to balance error rates across slices of your data by penalizing disparities in performance during training. When we apply MinDiff, model performance may degrade slightly on other slices. As such, our goals with MinDiff will be:

  • Improved performance for underperforming groups
  • Limited degradation for other groups and overall performance

Prepare your data

To use MinDiff, we create two additional data splits:

  • A split for non-toxic examples referencing minority groups: In our case, this will include comments with references to our underperforming identity terms. We don’t include some of the groups because they have too few examples, which leads to high uncertainty and wide confidence intervals.
  • A split for non-toxic examples referencing the majority group.

It’s important to have sufficient examples belonging to the underperforming classes. Based on your model architecture, data distribution, and MinDiff configuration, the amount of data needed can vary significantly. In past applications, we have seen MinDiff work well with 5,000 examples in each data split.

In our case, the groups in the minority splits have example quantities of 9,688 and 3,906. Note the class imbalances in the dataset; in practice, this could be cause for concern, but we won’t seek to address them in this notebook since our intention is just to demonstrate MinDiff.

We select only negative examples for these groups, so that MinDiff can optimize on getting these examples right. It may seem counterintuitive to carve out sets of ground truth negative examples if we’re primarily concerned with disparities in false positive rate, but remember that a false positive prediction is a ground truth negative example that’s incorrectly classified as positive, which is the issue we’re trying to address.
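
The selection logic can be pictured on a handful of toy rows before it is applied to the real DataFrame. This is a pure-Python sketch with hypothetical rows and helper names. It assumes, as the masks used in this notebook suggest, that the `religion` field holds the string form of a list.

```python
# Hypothetical rows; 'religion' is the string form of a list of identity terms.
toy_rows = [
    {'comment_text': 'a', 'religion': "['jewish']", 'toxicity': 0},
    {'comment_text': 'b', 'religion': "['muslim', 'christian']", 'toxicity': 0},
    {'comment_text': 'c', 'religion': "['christian']", 'toxicity': 0},
    {'comment_text': 'd', 'religion': "['christian']", 'toxicity': 1},
    {'comment_text': 'e', 'religion': '[]', 'toxicity': 0},
    {'comment_text': 'f', 'religion': "['muslim']", 'toxicity': 1},
]


def is_minority(row):
    """True if the comment references either underperforming identity term."""
    return any(term in row['religion'] for term in ('jewish', 'muslim'))


def is_majority(row):
    """True if the comment references only the majority group."""
    return row['religion'] == "['christian']"


def is_negative(row):
    """True for ground-truth non-toxic comments."""
    return row['toxicity'] == 0


sensitive = [r for r in toy_rows if is_minority(r) and is_negative(r)]
nonsensitive = [r for r in toy_rows if is_majority(r) and is_negative(r)]
```

Rows `d` and `f` are excluded because they are toxic, and row `e` is excluded because it references neither group. Only non-toxic comments can become false positives, so they are the only rows the two MinDiff splits need.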

Create MinDiff DataFrames

# Create masks for the sensitive and nonsensitive groups
minority_mask = data_train.religion.apply(
    lambda x: any(religion in x for religion in ('jewish', 'muslim')))
majority_mask = data_train.religion.apply(lambda x: x == "['christian']")

# Select nontoxic examples, so MinDiff will be able to reduce sensitive FP rate.
true_negative_mask = data_train['toxicity'] == 0

data_train_main = copy.copy(data_train)
data_train_sensitive = data_train[minority_mask & true_negative_mask]
data_train_nonsensitive = data_train[majority_mask & true_negative_mask]

We also need to convert our Pandas DataFrames into TensorFlow Datasets for MinDiff input. Note that unlike the Keras model API for Pandas DataFrames, using Datasets means that we need to provide the model’s input features and labels together in one Dataset. Here we provide the 'comment_text' as an input feature and reshape the label to match the model’s expected output.

We batch the Dataset at this stage, too, since MinDiff requires batched Datasets. Note that we tune the batch size the same way it is tuned for the baseline model, balancing training speed and hardware considerations against model performance. Here we have chosen the same batch size for all three datasets. This is not a requirement, although it’s good practice to have the two MinDiff batch sizes be equivalent.
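
As a rough mental model of what a batched Dataset of (feature, label) pairs contains, consider the pure-Python analogue below. It is only an illustration of the resulting structure, not of how `tf.data` works internally, and the `to_batches` helper and toy inputs are hypothetical.

```python
def to_batches(texts, labels, batch_size):
    """Pair each text with its label as a length-1 float list, then batch."""
    examples = [(text, [float(label)]) for text, label in zip(texts, labels)]
    return [examples[i:i + batch_size]
            for i in range(0, len(examples), batch_size)]


toy_texts = ['first comment', 'second comment', 'third comment']
toy_labels = [0, 1, 0]
batches = to_batches(toy_texts, toy_labels, batch_size=2)
```

Each example carries its feature and its label together, and each label is a one-element float list, mirroring the effect of reshaping the label column to `(-1, 1)` and multiplying by `1.0`. The final batch is smaller when the number of examples is not a multiple of the batch size.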

Create MinDiff Datasets

# Convert the pandas DataFrames to Datasets. Each Dataset yields
# (text feature, label) pairs, with labels reshaped to a float column.
dataset_train_main = tf.data.Dataset.from_tensor_slices(
    (data_train_main[TEXT_FEATURE].values,
     data_train_main.pop(LABEL).values.reshape(-1,1) * 1.0)).batch(BATCH_SIZE)
dataset_train_sensitive = tf.data.Dataset.from_tensor_slices(
    (data_train_sensitive[TEXT_FEATURE].values,
     data_train_sensitive.pop(LABEL).values.reshape(-1,1) * 1.0)).batch(BATCH_SIZE)
dataset_train_nonsensitive = tf.data.Dataset.from_tensor_slices(
    (data_train_nonsensitive[TEXT_FEATURE].values,
     data_train_nonsensitive.pop(LABEL).values.reshape(-1,1) * 1.0)).batch(BATCH_SIZE)

Train and evaluate the model

To train with MinDiff, simply take the original model and wrap it in a MinDiffModel with a corresponding loss and loss_weight. We are using 1.5 as the default loss_weight, but this is a parameter that needs to be tuned for your use case, since it depends on your model and product requirements. You can experiment with changing the value to see how it impacts the model, noting that increasing it pushes the performance of the minority and majority groups closer together but may come with more pronounced tradeoffs.

Then we compile the model normally (using the regular non-MinDiff loss) and fit to train.
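
To build intuition for how the loss and loss_weight interact, the sketch below computes a simplified maximum mean discrepancy (MMD) penalty between the predictions for two groups and adds it, scaled by a weight, to a primary loss. This is an illustrative pure-Python approximation under stated assumptions (scalar predictions, a Gaussian kernel, a biased estimator, hypothetical function names). The library's MMDLoss may differ in kernel choice, normalization, and other details.

```python
import math


def gaussian_kernel(a, b, kernel_length):
    """Similarity between two scalar predictions (1.0 when identical)."""
    return math.exp(-((a - b) ** 2) / (kernel_length ** 2))


def mean_kernel(xs, ys, kernel_length):
    """Average kernel value over all pairs drawn from xs and ys."""
    total = sum(gaussian_kernel(x, y, kernel_length) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_penalty(sensitive_preds, nonsensitive_preds, kernel_length=0.1):
    """Biased estimate of squared MMD between two sets of predictions."""
    return (mean_kernel(sensitive_preds, sensitive_preds, kernel_length)
            + mean_kernel(nonsensitive_preds, nonsensitive_preds, kernel_length)
            - 2.0 * mean_kernel(sensitive_preds, nonsensitive_preds, kernel_length))


def total_loss(primary_loss, sensitive_preds, nonsensitive_preds, loss_weight):
    """Primary task loss plus the weighted distribution-matching penalty."""
    return primary_loss + loss_weight * mmd_penalty(
        sensitive_preds, nonsensitive_preds)


# Hypothetical predictions on non-toxic comments from each group.
sensitive_preds = [0.60, 0.70, 0.65]
nonsensitive_preds = [0.10, 0.20, 0.15]
far_apart = mmd_penalty(sensitive_preds, nonsensitive_preds)
identical = mmd_penalty(nonsensitive_preds, nonsensitive_preds)
```

The penalty is zero when the two prediction distributions coincide and grows as they separate, so a larger weight pushes the optimizer harder toward matching them, at a potential cost to the primary loss.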

Train MinDiffModel

use_pretrained_model = True

base_dir = tempfile.mkdtemp(prefix='saved_models')
min_diff_model_location = os.path.join(base_dir, 'model_export_min_diff')

if use_pretrained_model:
  BASE_MIN_DIFF_PATH = tempfile.mkdtemp()
  MIN_DIFF_URL = 'https://storage.googleapis.com/civil_comments_model/min_diff_model.zip'
  ZIP_PATH = os.path.join(BASE_MIN_DIFF_PATH, 'min_diff_model.zip')
  MIN_DIFF_MODEL_PATH = os.path.join(BASE_MIN_DIFF_PATH, 'tmp/min_diff_model')

  # Download the zipped SavedModel.
  r = requests.get(MIN_DIFF_URL, allow_redirects=True)
  with open(ZIP_PATH, 'wb') as f:
    f.write(r.content)

  # Extract the archive so the model can be loaded from MIN_DIFF_MODEL_PATH.
  with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
    zip_ref.extractall(BASE_MIN_DIFF_PATH)
  min_diff_model = tf.keras.models.load_model(
      MIN_DIFF_MODEL_PATH, custom_objects={'KerasLayer' : hub.KerasLayer})

  min_diff_model.save(min_diff_model_location, save_format='tf')

else:
  min_diff_weight = 1.5

  # Create the dataset that will be passed to the MinDiffModel during training.
  dataset = md.keras.utils.input_utils.pack_min_diff_data(
      dataset_train_main, dataset_train_sensitive, dataset_train_nonsensitive)

  # Create the original model.
  original_model = min_diff_keras_utils.create_keras_sequential_model()

  # Wrap the original model in a MinDiffModel, passing in one of the MinDiff
  # losses and using the set loss_weight.
  min_diff_loss = md.losses.MMDLoss()
  min_diff_model = md.keras.MinDiffModel(original_model,
                                         min_diff_loss,
                                         min_diff_weight)

  # Compile the model normally after wrapping the original model.  Note that
  # this means we use the baseline model's loss here.
  optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
  loss = tf.keras.losses.BinaryCrossentropy()
  min_diff_model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

  min_diff_model.fit(dataset, epochs=20)

  min_diff_model.save_original_model(min_diff_model_location, save_format='tf')
INFO:tensorflow:Assets written to: /tmp/saved_modelsb3zkcos_/model_export_min_diff/assets

Next we evaluate the results.

min_diff_eval_subdir = os.path.join(base_dir, 'tfma_eval_result')
min_diff_eval_result = fi_util.get_eval_results(
    min_diff_model_location, min_diff_eval_subdir, validate_tfrecord_file)
WARNING:absl:Tensorflow version (2.5.0) found. Note that TFMA support for TF 2.0 is currently in beta

To ensure we evaluate a new model correctly, we need to select a threshold the same way that we would for the baseline model. In a production setting, this would mean ensuring that evaluation metrics meet launch standards. In our case, we will pick the threshold that results in a similar overall FPR to the baseline model. This threshold may be different from the one you selected for the baseline model. Try selecting false positive rate with threshold 0.400. (Note that subgroups with very few examples have very wide confidence intervals and don’t give predictable results.)
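
One way to picture this threshold selection is as a search over candidate thresholds for the one whose overall FPR lands closest to the baseline's. The sketch below does this on toy data. The scores, the baseline FPR value, the candidate list, and the helper names are all hypothetical, and in the notebook you would read these values off the Fairness Indicators widget instead.

```python
def overall_fpr(labels, scores, threshold):
    """Fraction of ground-truth negatives scored above the threshold."""
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    return sum(1 for s in negatives if s > threshold) / len(negatives)


def closest_threshold(labels, scores, target_fpr, candidates):
    """Pick the candidate threshold whose overall FPR is nearest the target."""
    return min(candidates,
               key=lambda t: abs(overall_fpr(labels, scores, t) - target_fpr))


# Hypothetical scores from the new model: ten non-toxic and two toxic comments.
toy_labels = [0] * 10 + [1] * 2
toy_scores = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.33, 0.38, 0.42, 0.60,
              0.80, 0.90]

baseline_fpr = 0.2  # Hypothetical overall FPR of the baseline model.
candidates = [0.30, 0.35, 0.40, 0.45, 0.50]
chosen = closest_threshold(toy_labels, toy_scores, baseline_fpr, candidates)
```

Matching the overall FPR first means that any remaining per-group differences reflect how the errors are distributed across groups rather than a shift in how strict the classifier is overall.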


Reviewing these results, you may notice that the FPRs for our target groups have improved. The gap between our lowest-performing group and the majority group has narrowed from 0.024 to 0.006. Given the improvements we’ve observed and the continued strong performance for the majority group, we’ve satisfied both of our goals. Depending on the product, further improvements may be necessary, but this approach has gotten our model one step closer to performing equitably for all users.
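
For a sense of scale, the reported change in the FPR gap can be expressed as a relative reduction. The two gap values below are taken from the text above, and the variable names are illustrative.

```python
baseline_gap = 0.024  # FPR gap before MinDiff, as reported above.
min_diff_gap = 0.006  # FPR gap after MinDiff, as reported above.

absolute_reduction = baseline_gap - min_diff_gap
relative_reduction = absolute_reduction / baseline_gap

print(f'Gap reduced by {absolute_reduction:.3f} '
      f'({relative_reduction:.0%} of the original gap).')
```

That is a reduction of roughly three quarters of the original gap, though with wide confidence intervals on small slices the exact figure should be read with caution.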