TensorFlow Constrained Optimization Example Using CelebA Dataset


This notebook demonstrates an easy way to create and optimize constrained problems using the TFCO library. This method can be useful in improving models when we find that they’re not performing equally well across different slices of our data, which we can identify using Fairness Indicators. The second of Google’s AI principles states that our technology should avoid creating or reinforcing unfair bias, and we believe this technique can help improve model fairness in some situations. In particular, this notebook will:

  • Train a simple, unconstrained neural network model to detect a person's smile in images using tf.keras and the large-scale CelebFaces Attributes (CelebA) dataset.
  • Evaluate model performance against a commonly used fairness metric across age groups, using Fairness Indicators.
  • Set up a simple constrained optimization problem to achieve fairer performance across age groups.
  • Retrain the now constrained model and evaluate performance again, ensuring that our chosen fairness metric has improved.

Last updated: 3/11 Feb 2020

Installation

This notebook was created in Colaboratory, connected to the Python 3 Google Compute Engine backend. If you wish to host this notebook in a different environment, then you should not experience any major issues provided you include all the required packages in the cells below.

Note that the very first time you run the pip installs, you may be asked to restart the runtime because of preinstalled, out-of-date packages. Once you do so, the correct packages will be used.

Pip installs

Note that depending on when you run the cell below, you may receive a warning about the default version of TensorFlow in Colab switching to TensorFlow 2.X soon. You can safely ignore that warning as this notebook was designed to be compatible with TensorFlow 1.X and 2.X.

Import Modules

Additionally, we add a few imports that are specific to Fairness Indicators which we will use to evaluate and visualize the model's performance.

Although TFCO is compatible with eager and graph execution, this notebook assumes that eager execution is enabled by default as it is in TensorFlow 2.x. To ensure that nothing breaks, eager execution will be enabled in the cell below.

Enable Eager Execution and Print Versions

Eager execution enabled by default.
TensorFlow 2.10.0
TFMA 0.41.0
TFDS 4.7.0
FI 0.41.0

CelebA Dataset

CelebA is a large-scale face attributes dataset with more than 200,000 celebrity images, each with 40 attribute annotations (such as hair type, fashion accessories, facial features, etc.) and 5 landmark locations (eyes, mouth and nose positions). For more details, take a look at the paper. With the permission of the owners, we have stored this dataset on Google Cloud Storage and mostly access it via TensorFlow Datasets (tfds).

In this notebook:

  • Our model will attempt to classify whether the subject of the image is smiling, as represented by the "Smiling" attribute*.
  • Images will be resized from 218x178 to 28x28 to reduce the execution time and memory when training.
  • Our model's performance will be evaluated across age groups, using the binary "Young" attribute. We will call this "age group" in this notebook.

* While there is little information available about the labeling methodology for this dataset, we will assume that the "Smiling" attribute was determined by a pleased, kind, or amused expression on the subject's face. For the purpose of this case study, we will take these labels as ground truth.

gcs_base_dir = "gs://celeb_a_dataset/"
celeb_a_builder = tfds.builder("celeb_a", data_dir=gcs_base_dir, version='2.0.0')

celeb_a_builder.download_and_prepare()

num_test_shards_dict = {'0.3.0': 4, '2.0.0': 2} # Used because we download the test dataset separately
version = str(celeb_a_builder.info.version)
print('Celeb_A dataset version: %s' % version)
Celeb_A dataset version: 2.0.0

Test dataset helper functions

Caveats

Before moving forward, there are several considerations to keep in mind in using CelebA:

  • Although in principle this notebook could use any dataset of face images, CelebA was chosen because it contains public domain images of public figures.
  • All of the attribute annotations in CelebA are operationalized as binary categories. For example, the "Young" attribute (as determined by the dataset labelers) is denoted as either present or absent in the image.
  • CelebA's categorizations do not reflect real human diversity of attributes.
  • For the purposes of this notebook, the feature containing the "Young" attribute is referred to as "age group", where the presence of the "Young" attribute in an image is labeled as a member of the "Young" age group and the absence of the "Young" attribute is labeled as a member of the "Not Young" age group. These are assumptions made as this information is not mentioned in the original paper.
  • As such, performance in the models trained in this notebook is tied to the ways the attributes have been operationalized and annotated by the authors of CelebA.
  • This model should not be used for commercial purposes as that would violate CelebA's non-commercial research agreement.

Setting Up Input Functions

The subsequent cells will help streamline the input pipeline as well as visualize performance.

First we define some data-related variables and define a requisite preprocessing function.

Define Variables

Define Preprocessing Functions

Then, we build out the data functions we need in the rest of the colab.

# Train data returning either 2 or 3 elements (the third element being the group)
def celeb_a_train_data_wo_group(batch_size):
  celeb_a_train_data = celeb_a_builder.as_dataset(split='train').shuffle(1024).repeat().batch(batch_size).map(preprocess_input_dict)
  return celeb_a_train_data.map(get_image_and_label)
def celeb_a_train_data_w_group(batch_size):
  celeb_a_train_data = celeb_a_builder.as_dataset(split='train').shuffle(1024).repeat().batch(batch_size).map(preprocess_input_dict)
  return celeb_a_train_data.map(get_image_label_and_group)

# Test data for the overall evaluation
celeb_a_test_data = celeb_a_builder.as_dataset(split='test').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)
# Copy test data locally to be able to read it into tfma
copy_test_files_to_local()

Build a simple DNN Model

Because this notebook focuses on TFCO, we will assemble a simple, unconstrained tf.keras.Sequential model.

We may be able to greatly improve model performance by adding some complexity (e.g., more densely-connected layers, exploring different activation functions, increasing image size), but that may distract from the goal of demonstrating how easy it is to apply the TFCO library when working with Keras. For that reason, the model will be kept simple — but feel encouraged to explore this space.

def create_model():
  # For this notebook, accuracy will be used to evaluate performance.
  METRICS = [
    tf.keras.metrics.BinaryAccuracy(name='accuracy')
  ]

  # The model consists of:
  # 1. An input layer that flattens the 28x28x3 image.
  # 2. A fully connected layer with 64 units activated by a ReLU function.
  # 3. A single-unit readout layer that outputs real-valued scores instead of probabilities.
  model = keras.Sequential([
      keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), name='image'),
      keras.layers.Dense(64, activation='relu'),
      keras.layers.Dense(1, activation=None)
  ])

  # TFCO by default uses hinge loss — and that will also be used in the model.
  model.compile(
      optimizer=tf.keras.optimizers.Adam(0.001),
      loss='hinge',
      metrics=METRICS)
  return model
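
The readout layer above emits real-valued scores rather than probabilities, which suits hinge loss. As a rough illustration (a plain-Python sketch, not the Keras implementation; the helper name `hinge_loss` is made up here), 0/1 labels are mapped to -1/+1 and each example is penalized by max(0, 1 - y * score):

```python
def hinge_loss(labels, scores):
    """Mean hinge loss for 0/1 labels and real-valued scores (illustrative helper)."""
    total = 0.0
    for label, score in zip(labels, scores):
        signed = 1.0 if label == 1 else -1.0  # map {0, 1} -> {-1, +1}
        total += max(0.0, 1.0 - signed * score)
    return total / len(labels)


# A confident, correct score (2.0 for a positive) costs nothing;
# a positive score (0.5) for a negative example is penalized by 1.5.
print(hinge_loss([1, 0], [2.0, 0.5]))  # mean of 0.0 and 1.5 -> 0.75
```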

We also define a function to set seeds to ensure reproducible results. Note that this colab is meant as an educational tool and does not have the stability of a finely tuned production pipeline. Running without setting a seed may lead to varied results.

def set_seeds():
  np.random.seed(121212)
  tf.compat.v1.set_random_seed(212121)

Fairness Indicators Helper Functions

Before training our model, we define a number of helper functions that will allow us to evaluate the model's performance via Fairness Indicators.

First, we create a helper function to save our model once we train it.

def save_model(model, subdir):
  base_dir = tempfile.mkdtemp(prefix='saved_models')
  model_location = os.path.join(base_dir, subdir)
  model.save(model_location, save_format='tf')
  return model_location

Next, we define functions used to preprocess the data in order to correctly pass it through to TFMA.

Data Preprocessing functions for

Finally, we define a function that evaluates the results in TFMA.

def get_eval_results(model_location, eval_subdir):
  base_dir = tempfile.mkdtemp(prefix='saved_eval_results')
  tfma_eval_result_path = os.path.join(base_dir, eval_subdir)

  eval_config_pbtxt = """
        model_specs {
          label_key: "%s"
        }
        metrics_specs {
          metrics {
            class_name: "FairnessIndicators"
            config: '{ "thresholds": [0.22, 0.5, 0.75] }'
          }
          metrics {
            class_name: "ExampleCount"
          }
        }
        slicing_specs {}
        slicing_specs { feature_keys: "%s" }
        options {
          compute_confidence_intervals { value: False }
          disabled_outputs{values: "analysis"}
        }
      """ % (LABEL_KEY, GROUP_KEY)

  eval_config = text_format.Parse(eval_config_pbtxt, tfma.EvalConfig())

  eval_shared_model = tfma.default_eval_shared_model(
        eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING])

  schema_pbtxt = """
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "%s"
              value {
                dense_tensor {
                  column_name: "%s"
                  shape {
                    dim { size: 28 }
                    dim { size: 28 }
                    dim { size: 3 }
                  }
                }
              }
            }
          }
        }
        feature {
          name: "%s"
          type: FLOAT
        }
        feature {
          name: "%s"
          type: FLOAT
        }
        feature {
          name: "%s"
          type: BYTES
        }
        """ % (IMAGE_KEY, IMAGE_KEY, IMAGE_KEY, LABEL_KEY, GROUP_KEY)
  schema = text_format.Parse(schema_pbtxt, schema_pb2.Schema())
  coder = tf_example_record.TFExampleBeamRecord(
      physical_format='inmem', schema=schema,
      raw_record_column_name=tfma.ARROW_INPUT_COLUMN)
  tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
    arrow_schema=coder.ArrowSchema(),
    tensor_representations=coder.TensorRepresentations())
  # Run the fairness evaluation.
  with beam.Pipeline() as pipeline:
    _ = (
          tfds_as_pcollection(pipeline, 'celeb_a', 'test')
          | 'ExamplesToRecordBatch' >> coder.BeamSource()
          | 'ExtractEvaluateAndWriteResults' >>
          tfma.ExtractEvaluateAndWriteResults(
              eval_config=eval_config,
              eval_shared_model=eval_shared_model,
              output_path=tfma_eval_result_path,
              tensor_adapter_config=tensor_adapter_config)
    )
  return tfma.load_eval_result(output_path=tfma_eval_result_path)

Train & Evaluate Unconstrained Model

With the model now defined and the input pipeline in place, we’re now ready to train our model. To cut back on the amount of execution time and memory, we will train the model by slicing the data into small batches with only a few repeated iterations.

Note that running this notebook in TensorFlow < 2.0.0 may result in a deprecation warning for np.where. Safely ignore this warning as TensorFlow addresses this in 2.X by using tf.where in place of np.where.

BATCH_SIZE = 32

# Set seeds to get reproducible results
set_seeds()

model_unconstrained = create_model()
model_unconstrained.fit(celeb_a_train_data_wo_group(BATCH_SIZE), epochs=5, steps_per_epoch=1000)
Epoch 1/5
1000/1000 [==============================] - 27s 3ms/step - loss: 0.5210 - accuracy: 0.7593
Epoch 2/5
1000/1000 [==============================] - 3s 3ms/step - loss: 0.3573 - accuracy: 0.8387
Epoch 3/5
1000/1000 [==============================] - 3s 3ms/step - loss: 0.3460 - accuracy: 0.8460
Epoch 4/5
1000/1000 [==============================] - 3s 3ms/step - loss: 0.3432 - accuracy: 0.8450
Epoch 5/5
1000/1000 [==============================] - 3s 3ms/step - loss: 0.3228 - accuracy: 0.8559
<keras.callbacks.History at 0x7fb504787dc0>

Evaluating the model on the test data should result in a final accuracy score of just over 85%. Not bad for a simple model with no fine tuning.

print('Overall Results, Unconstrained')
celeb_a_test_data = celeb_a_builder.as_dataset(split='test').batch(1).map(preprocess_input_dict).map(get_image_label_and_group)
results = model_unconstrained.evaluate(celeb_a_test_data)
Overall Results, Unconstrained
WARNING:tensorflow:`evaluate()` received a value for `sample_weight`, but `weighted_metrics` were not provided.  Did you mean to pass metrics to `weighted_metrics` in `compile()`?  If this is intentional you can pass `weighted_metrics=[]` to `compile()` in order to silence this warning.
19962/19962 [==============================] - 37s 2ms/step - loss: 0.2097 - accuracy: 0.8643

However, performance evaluated across age groups may reveal some shortcomings.

To explore this further, we evaluate the model with Fairness Indicators (via TFMA). In particular, we are interested in seeing whether there is a significant gap in performance between "Young" and "Not Young" categories when evaluated on false positive rate.

A false positive error occurs when the model incorrectly predicts the positive class. In this context, a false positive outcome occurs when the ground truth is an image of a celebrity 'Not Smiling' and the model predicts 'Smiling'. By extension, the false positive rate, which is used in the visualization below, is the fraction of actual negatives ('Not Smiling' images) that the model incorrectly labels as positive. While this is a relatively mundane error to make in this context, false positive errors can sometimes cause more problematic behaviors. For instance, a false positive error in a spam classifier could cause a user to miss an important email.
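
To make the metric concrete, here is a minimal sketch of the false positive rate, FP / (FP + TN), computed on toy data. The function name and the data are illustrative assumptions, not part of Fairness Indicators:

```python
def false_positive_rate(labels, predictions):
    """FPR = FP / (FP + TN), computed over examples whose true label is 0."""
    false_pos = sum(1 for y, p in zip(labels, predictions) if y == 0 and p == 1)
    true_neg = sum(1 for y, p in zip(labels, predictions) if y == 0 and p == 0)
    negatives = false_pos + true_neg
    return false_pos / negatives if negatives else 0.0


# Toy data: 1 = "Smiling", 0 = "Not Smiling".
labels      = [0, 0, 0, 0, 1, 1]
predictions = [1, 0, 0, 0, 1, 0]
print(false_positive_rate(labels, predictions))  # 1 of 4 negatives flagged -> 0.25
```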

model_location = save_model(model_unconstrained, 'model_export_unconstrained')
eval_results_unconstrained = get_eval_results(model_location, 'eval_results_unconstrained')
INFO:tensorflow:Assets written to: /tmpfs/tmp/saved_modelsr9qp0f4m/model_export_unconstrained/assets
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.8/site-packages/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py:110: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`

As mentioned above, we are concentrating on the false positive rate. The current version of Fairness Indicators (0.1.2) selects false negative rate by default. After running the line below, deselect false_negative_rate and select false_positive_rate to look at the metric we are interested in.

tfma.addons.fairness.view.widget_view.render_fairness_indicator(eval_results_unconstrained)
FairnessIndicatorViewer(slicingMetrics=[{'sliceValue': 'Young', 'slice': 'Young:Young', 'metrics': {'example_c…

As the results show above, we do see a disproportionate gap between "Young" and "Not Young" categories.
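
The kind of per-slice comparison that Fairness Indicators displays can be approximated by hand. The sketch below computes the false positive rate separately for each group and the gap between them; the helper name and the toy data are made up for illustration and are not results from this notebook:

```python
from collections import defaultdict


def fpr_by_group(labels, predictions, groups):
    """Return {group: FPR}, where FPR = FP / (FP + TN) within each group."""
    false_pos = defaultdict(int)
    negatives = defaultdict(int)
    for y, p, g in zip(labels, predictions, groups):
        if y == 0:  # only actual negatives can produce false positives
            negatives[g] += 1
            false_pos[g] += int(p == 1)
    return {g: false_pos[g] / negatives[g] for g in negatives}


# Made-up toy data: four negatives per group plus two positives.
labels      = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
predictions = [1, 1, 0, 0, 1, 0, 0, 0, 1, 1]
groups      = ["Not Young"] * 4 + ["Young"] * 4 + ["Young", "Not Young"]

rates = fpr_by_group(labels, predictions, groups)
gap = abs(rates["Not Young"] - rates["Young"])
print(rates, gap)
```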

This is where TFCO can help, by constraining the false positive rate to meet a more acceptable criterion.

Constrained Model Set Up

As documented in TFCO's library, there are several helpers that will make it easier to constrain the problem:

  1. tfco.rate_context() – This is what will be used in constructing a constraint for each age group category.
  2. tfco.RateMinimizationProblem() – The rate expression to be minimized here is the overall error rate, subject to a constraint on the false positive rate of one age group. For this demonstration, the constraint requires the false positive rate on "Not Young" images to be less than or equal to 5%.
  3. tfco.ProxyLagrangianOptimizerV2() – This is the helper that will actually solve the rate constraint problem.

The cell below will call on these helpers to set up model training with the fairness constraint.

# The batch size is needed to create the input, labels and group tensors.
# These tensors are initialized with all 0's and are later assigned the
# content of each batch. The batch size should be large enough that every
# batch contains enough "Young" and "Not Young" examples.
set_seeds()
model_constrained = create_model()
BATCH_SIZE = 32

# Create input tensor.
input_tensor = tf.Variable(
    np.zeros((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3), dtype="float32"),
    name="input")

# Create labels and group tensors (assuming both labels and groups are binary).
labels_tensor = tf.Variable(
    np.zeros(BATCH_SIZE, dtype="float32"), name="labels")
groups_tensor = tf.Variable(
    np.zeros(BATCH_SIZE, dtype="float32"), name="groups")

# Create a function that returns the applied 'model' to the input tensor
# and generates constrained predictions.
def predictions():
  return model_constrained(input_tensor)

# Create overall context and subsetted context.
# The subsetted context contains subset of examples where group attribute < 1
# (i.e. the subset of "Not Young" celebrity images).
# "groups_tensor < 1" is used instead of "groups_tensor == 0" as the former
# would be a comparison on the tensor value, while the latter would be a
# comparison on the Tensor object.
context = tfco.rate_context(predictions, labels=lambda:labels_tensor)
context_subset = context.subset(lambda:groups_tensor < 1)

# Setup list of constraints.
# In this notebook, the constraint will just be: FPR less than or equal to 5%.
constraints = [tfco.false_positive_rate(context_subset) <= 0.05]

# Setup rate minimization problem: minimize overall error rate s.t. constraints.
problem = tfco.RateMinimizationProblem(tfco.error_rate(context), constraints)

# Create constrained optimizer and obtain train_op.
# Separate optimizers are specified for the objective and constraints
optimizer = tfco.ProxyLagrangianOptimizerV2(
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      constraint_optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      num_constraints=problem.num_constraints)

# A list of all trainable variables is also needed to use TFCO.
var_list = (model_constrained.trainable_weights + list(problem.trainable_variables) +
            optimizer.trainable_variables())

The model is now set up and ready to be trained with the false positive rate constraint across age group.
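
Written out, the problem configured above is roughly the following, where θ denotes the model weights. The symbols and the Lagrangian form are added here for illustration only; TFCO's proxy-Lagrangian formulation differs in detail, in particular by using a differentiable proxy for the rates when updating θ.

```latex
\begin{aligned}
\min_{\theta}\quad & \mathrm{ErrorRate}(\theta) \\
\text{s.t.}\quad & \mathrm{FPR}_{\text{Not Young}}(\theta) \le 0.05
\end{aligned}
\qquad
\mathcal{L}(\theta, \lambda) = \mathrm{ErrorRate}(\theta)
  + \lambda \left( \mathrm{FPR}_{\text{Not Young}}(\theta) - 0.05 \right),
\quad \lambda \ge 0
```

In a Lagrangian scheme of this kind, θ is updated to decrease the Lagrangian while the multiplier λ is updated to increase it, so λ grows for as long as the constraint is violated.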

Now, because the last iteration of the constrained model may not necessarily be the best performing model in terms of the defined constraint, the TFCO library comes equipped with tfco.find_best_candidate_index() that can help choose the best iterate out of the ones found after each epoch. Think of tfco.find_best_candidate_index() as an added heuristic that ranks each of the outcomes based on accuracy and fairness constraint (in this case, false positive rate across age group) separately with respect to the training data. That way, it can search for a better trade-off between overall accuracy and the fairness constraint.
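
The ranking idea can be sketched in plain Python. This is a simplified illustration of selecting the candidate with the smallest worst-case rank, not TFCO's implementation; `rank` and `pick_candidate` are hypothetical helpers, and the library's handling of ties and multiple constraints may differ:

```python
def rank(values):
    """Rank of each value (0 = smallest); tied values share the lowest rank."""
    ordered = sorted(values)
    return [ordered.index(v) for v in values]


def pick_candidate(objectives, violations):
    """Index of the candidate whose worse rank (objective vs. violation) is smallest."""
    # Collapse each candidate's constraints to its largest violation, clipped at 0.
    worst = [max(0.0, max(v)) for v in violations]
    obj_rank = rank(objectives)
    vio_rank = rank(worst)
    # Score = (worse of the two ranks, objective); the objective breaks ties.
    scores = [(max(o, v), obj) for o, v, obj in zip(obj_rank, vio_rank, objectives)]
    return min(range(len(scores)), key=lambda i: scores[i])


# Made-up snapshots: the first is accurate but unfair, the last is fair but inaccurate.
objectives = [0.10, 0.14, 0.30]
violations = [[0.40], [0.02], [0.00]]
print(pick_candidate(objectives, violations))  # the middle trade-off -> 1
```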

The following cells will start the training with constraints while also finding the best performing model per iteration.

# Obtain train set batches.

NUM_ITERATIONS = 100  # Number of training iterations.
SKIP_ITERATIONS = 10  # Print training stats once in this many iterations.

# Create temp directory for saving snapshots of models.
temp_directory = tempfile.mkdtemp()

# List of objective and constraints across iterations.
objective_list = []
violations_list = []

# Training iterations.
iteration_count = 0
for (image, label, group) in celeb_a_train_data_w_group(BATCH_SIZE):
  # Assign current batch to input, labels and groups tensors.
  input_tensor.assign(image)
  labels_tensor.assign(label)
  groups_tensor.assign(group)

  # Run gradient update.
  optimizer.minimize(problem, var_list=var_list)

  # Record objective and violations.
  objective = problem.objective()
  violations = problem.constraints()

  sys.stdout.write(
      "\r Iteration %d: Hinge Loss = %.3f, Max. Constraint Violation = %.3f"
      % (iteration_count + 1, objective, max(violations)))

  # Snapshot model once in SKIP_ITERATIONS iterations.
  if iteration_count % SKIP_ITERATIONS == 0:
    objective_list.append(objective)
    violations_list.append(violations)

    # Save snapshot of model weights.
    # Integer division keeps the snapshot index an int, so the file name
    # matches the index returned by find_best_candidate_index below.
    model_constrained.save_weights(
        temp_directory + "/celeb_a_constrained_" +
        str(iteration_count // SKIP_ITERATIONS) + ".h5")

  iteration_count += 1
  if iteration_count >= NUM_ITERATIONS:
    break

# Choose best model from recorded iterates and load that model.
best_index = tfco.find_best_candidate_index(
    np.array(objective_list), np.array(violations_list))

model_constrained.load_weights(
    temp_directory + "/celeb_a_constrained_" + str(int(best_index)) + ".h5")

# Remove temp directory.
os.system("rm -r " + temp_directory)
Iteration 100: Hinge Loss = 0.485, Max. Constraint Violation = 0.612
0

After having applied the constraint, we evaluate the results once again using Fairness Indicators.

model_location = save_model(model_constrained, 'model_export_constrained')
eval_result_constrained = get_eval_results(model_location, 'eval_results_constrained')
INFO:tensorflow:Assets written to: /tmpfs/tmp/saved_modelso5ldfuga/model_export_constrained/assets

As with the previous time we used Fairness Indicators, deselect false_negative_rate and select false_positive_rate to look at the metric we are interested in.

Note that to fairly compare the two versions of our model, it is important to use thresholds that set the overall false positive rate to be roughly equal. This ensures that we are looking at actual change as opposed to just a shift in the model equivalent to simply moving the threshold boundary. In our case, comparing the unconstrained model at 0.5 and the constrained model at 0.22 provides a fair comparison for the models.
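
One way to find comparable thresholds is to measure the overall false positive rate of one model at its threshold and then pick the threshold for the other model that comes closest. The sketch below uses made-up scores for two hypothetical models; the helper names are illustrative and are not part of TFMA or Fairness Indicators:

```python
def fpr_at_threshold(labels, scores, threshold):
    """FPR when scores strictly above `threshold` are predicted positive."""
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not negatives:
        return 0.0
    return sum(s > threshold for s in negatives) / len(negatives)


def closest_threshold(labels, scores, target_fpr, candidates):
    """Pick the candidate threshold whose overall FPR is nearest target_fpr."""
    return min(candidates,
               key=lambda t: abs(fpr_at_threshold(labels, scores, t) - target_fpr))


# Made-up scores for two hypothetical models on the same six examples.
labels   = [0, 0, 0, 0, 1, 1]
scores_a = [0.9, 0.4, 0.2, 0.1, 0.8, 0.7]    # stands in for the unconstrained model
scores_b = [0.4, 0.1, 0.05, 0.0, 0.5, 0.3]   # stands in for the constrained model

target = fpr_at_threshold(labels, scores_a, 0.5)  # overall FPR of model A at 0.5
match = closest_threshold(labels, scores_b, target, [0.22, 0.5, 0.75])
print(target, match)
```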

eval_results_dict = {
    'constrained': eval_result_constrained,
    'unconstrained': eval_results_unconstrained,
}
tfma.addons.fairness.view.widget_view.render_fairness_indicator(multi_eval_results=eval_results_dict)
FairnessIndicatorViewer(evalName='constrained', evalNameCompare='unconstrained', slicingMetrics=[{'sliceValue'…

With TFCO's ability to express a more complex requirement as a rate constraint, we helped this model achieve a more desirable outcome with little impact on overall performance. There is, of course, still room for improvement, but TFCO was able to find a model that gets close to satisfying the constraint and reduces the disparity between the groups.