model_remediation.counterfactual.keras.utils.build_counterfactual_data

Build a counterfactual dataset from a list of sensitive terms or a custom function.

original_input tf.data.Dataset that was used before applying Counterfactual. The output should conform to the format used in tf.keras.Model.fit.
sensitive_terms_to_remove List of terms that will be removed and filtered from within the original_input.
custom_counterfactual_function Optional custom function to apply to tf.data.Dataset.map to build a custom counterfactual dataset. Note that it needs to return a dataset in the form of (original_x, counterfactual_x, counterfactual_sample_weight) and should only include values that have been modified. Pass sensitive_terms_to_remove as well to filter for values that contain the terms being modified.

This function builds a tf.data.Dataset containing only examples that will be used when calculating counterfactual_loss. This resulting dataset will need to be packed with the x value in original_input, a modified version of x that will act as counterfactual_x, and a counterfactual_sample_weight that defaults to 1.0. The resulting dataset can be passed to pack_counterfactual_data to create an instance of CounterfactualPackedInputs for use within counterfactual.keras.CounterfactualModel.

original_input must output a tuple in the format used in tf.keras.Model.fit. Specifically, the output must be a tuple of length 1, 2 or 3 in the form (x,), (x, y) or (x, y, sample_weight). This output will be parsed internally in the following way:

x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
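The unpacking rule above can be illustrated with a plain-Python sketch. The helper name unpack_batch is hypothetical and is not the library's implementation; it only mirrors the documented behavior for tuples of length 1, 2 or 3, with missing entries filled in as None.

```python
def unpack_batch(batch):
    """Hypothetical stand-in mirroring the (x, y, sample_weight) unpacking."""
    if not isinstance(batch, tuple):
        # A bare value is treated as x alone.
        return batch, None, None
    if len(batch) == 1:
        return batch[0], None, None
    if len(batch) == 2:
        return batch[0], batch[1], None
    if len(batch) == 3:
        return batch
    raise ValueError("batch must be a tuple of length 1, 2 or 3")


# Each accepted form resolves to a full (x, y, sample_weight) triple.
only_x = unpack_batch(("Bad word",))
x_and_y = unpack_batch(("Bad word", 1))
full = unpack_batch(("Bad word", 1, 0.5))
```

Whatever form original_input takes, only the x component is needed to build the counterfactual values.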

As an alternative to passing sensitive_terms_to_remove, you can pass a custom function that is applied to the original dataset to build the counterfactual dataset in whatever way you specify. For example, you might want to replace a target word instead of simply removing it. The returned tf.data.Dataset must not contain any x values that were left unchanged. Passing sensitive_terms_to_remove in this case acts as a filter so that only values containing terms that have been modified are included.

A minimal example is given below:

import tensorflow as tf
from tensorflow_model_remediation import counterfactual

simple_dataset_x = tf.constant(["Bad word", "Good word"])
simple_dataset = tf.data.Dataset.from_tensor_slices((simple_dataset_x))
# Only examples containing a sensitive term are kept in the result.
counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
  original_input=simple_dataset, sensitive_terms_to_remove=['Bad'])
for original_value, counterfactual_value, cf_weight in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
  print("counterfactual_sample_weight: ", cf_weight)
original:  tf.Tensor(b'Bad word', shape=(), dtype=string)
counterfactual:  tf.Tensor(b' word', shape=(), dtype=string)
counterfactual_sample_weight:  tf.Tensor(1.0, shape=(), dtype=float32)
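The word-replacement alternative described earlier can be sketched with tf.data alone. This is only an illustration of the shape a custom function might take: the names TARGET_WORD, REPLACEMENT_WORD and replace_target_word are hypothetical, and the sketch applies the function through Dataset.filter and Dataset.map directly rather than through build_counterfactual_data, whose exact calling convention for custom functions is assumed here, not confirmed.

```python
import tensorflow as tf

TARGET_WORD = "Bad"        # hypothetical term to replace
REPLACEMENT_WORD = "Good"  # hypothetical replacement


def replace_target_word(*batch):
    """Returns (original_x, counterfactual_x, counterfactual_sample_weight)."""
    # Dataset.map passes tuple elements as separate arguments; x comes first.
    x = batch[0]
    counterfactual_x = tf.strings.regex_replace(x, TARGET_WORD, REPLACEMENT_WORD)
    # Weight of 1.0 with the same shape as x, matching the documented default.
    weight = tf.ones(tf.shape(x), dtype=tf.float32)
    return x, counterfactual_x, weight


simple_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant(["Bad word", "Good word"]))

# Drop examples that the replacement would leave unchanged, then map.
counterfactual_data = simple_dataset.filter(
    lambda x: tf.strings.regex_full_match(x, ".*" + TARGET_WORD + ".*")
).map(replace_target_word)
```

Filtering before mapping keeps only the examples that contain the target word, which matches the requirement that unchanged x values be removed from the returned dataset.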

A tf.data.Dataset whose output is a tuple matching (original_x, counterfactual_x, counterfactual_sample_weight).

ValueError If neither custom_counterfactual_function nor sensitive_terms_to_remove is provided.