model_remediation.min_diff.keras.utils.pack_min_diff_data

Packs min_diff_data with the x component of the original dataset.

original_dataset: tf.data.Dataset that was used before applying min diff. The output should conform to the format used in tf.keras.Model.fit.
sensitive_group_dataset: tf.data.Dataset containing only examples that belong to the sensitive group. The output should have the same structure as that of original_dataset.
nonsensitive_group_dataset: tf.data.Dataset containing only examples that do not belong to the sensitive group. The output should have the same structure as that of original_dataset.

This function should be used to create the dataset that will be passed to min_diff.keras.MinDiffModel during training and, optionally, during evaluation.

Each input dataset must output a tuple in the format used in tf.keras.Model.fit. Specifically, the output must be a tuple of length 1, 2, or 3 in the form (x, y, sample_weight).

This output will be parsed internally in the following way:

import tensorflow as tf

batch = ...  # Batch from any one of the input datasets.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
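
To make the tuple convention concrete, here is a minimal pure-Python sketch of how batches of length 1, 2, or 3 map onto (x, y, sample_weight). The function name unpack_sketch is hypothetical and is only a simplified stand-in for tf.keras.utils.unpack_x_y_sample_weight, not its implementation; missing components are assumed to come back as None.

```python
def unpack_sketch(batch):
    """Simplified stand-in (hypothetical) for the (x, y, sample_weight) convention."""
    if not isinstance(batch, tuple):
        # A bare value is treated as x only.
        return batch, None, None
    if len(batch) == 1:
        return batch[0], None, None
    if len(batch) == 2:
        return batch[0], batch[1], None
    if len(batch) == 3:
        return batch
    raise ValueError("Expected a tuple of length 1, 2, or 3, got length %d" % len(batch))


# Placeholder strings stand in for real tensors.
only_x = unpack_sketch(("features",))
x_and_y = unpack_sketch(("features", "labels"))
full = unpack_sketch(("features", "labels", "weights"))
```

A dataset that yields only features therefore behaves as if y and sample_weight were both None.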

The tf.data.Dataset returned will output a tuple of (packed_inputs, y, sample_weight) where:

  • packed_inputs: is an instance of utils.MinDiffPackedInputs containing:

    • original_inputs: x component from the original_dataset.
    • min_diff_data: data formed from sensitive_group_dataset and nonsensitive_group_dataset as described below.
  • y: is the y component taken directly from original_dataset.
  • sample_weight: is the sample_weight component taken directly from original_dataset.
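
The nesting described above can be sketched in plain Python. PackedInputsSketch and pack_element_sketch are hypothetical names used only for illustration; PackedInputsSketch stands in for utils.MinDiffPackedInputs and is assumed to expose the two fields named in the list. The placeholder strings stand in for real tensors.

```python
from collections import namedtuple

# Hypothetical stand-in for utils.MinDiffPackedInputs (assumption: two named fields).
PackedInputsSketch = namedtuple(
    "PackedInputsSketch", ["original_inputs", "min_diff_data"])


def pack_element_sketch(original_batch, min_diff_data):
    """Wraps x with min_diff_data and passes y and sample_weight through unchanged."""
    x, y, sample_weight = original_batch  # Assumes a full 3-tuple for simplicity.
    packed_inputs = PackedInputsSketch(
        original_inputs=x, min_diff_data=min_diff_data)
    return packed_inputs, y, sample_weight


element = pack_element_sketch(
    ("x_batch", "y_batch", None),
    ("min_diff_x", "min_diff_membership", None))
packed_inputs, y, sample_weight = element
```

Only the x slot changes; y and sample_weight keep the values they had in original_dataset.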

min_diff_data will be used in min_diff.keras.MinDiffModel when calculating the min_diff_loss. It is a tuple of (min_diff_x, min_diff_membership, min_diff_sample_weight) where:

  • min_diff_x: is formed by concatenating the x components of sensitive_group_dataset and nonsensitive_group_dataset.
  • min_diff_membership: is a tensor of size [min_diff_batch_size, 1] indicating which dataset each example comes from (1.0 for sensitive_group_dataset and 0.0 for nonsensitive_group_dataset).
  • min_diff_sample_weight: is formed by concatenating the sample_weight components of sensitive_group_dataset and nonsensitive_group_dataset. If both are None, then this will be set to None. If only one is None, it is replaced with a Tensor of ones of the appropriate shape.
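
The three rules above can be illustrated with a small NumPy sketch. This is not the library's implementation: build_min_diff_data_sketch is a hypothetical helper, NumPy arrays stand in for tensors, each batch is assumed to be an (x, sample_weight) pair, and weights are assumed to have shape [batch_size, 1].

```python
import numpy as np


def build_min_diff_data_sketch(sensitive_batch, nonsensitive_batch):
    """Hypothetical illustration of how min_diff_data is assembled."""
    sens_x, sens_w = sensitive_batch
    nonsens_x, nonsens_w = nonsensitive_batch

    # Sensitive examples come first, then nonsensitive examples.
    min_diff_x = np.concatenate([sens_x, nonsens_x], axis=0)

    # 1.0 marks the sensitive group, 0.0 marks the nonsensitive group.
    min_diff_membership = np.concatenate([
        np.ones((len(sens_x), 1), dtype=np.float32),
        np.zeros((len(nonsens_x), 1), dtype=np.float32),
    ], axis=0)

    if sens_w is None and nonsens_w is None:
        min_diff_sample_weight = None
    else:
        # A missing weight is replaced with ones (assumed shape: [batch_size, 1]).
        if sens_w is None:
            sens_w = np.ones((len(sens_x), 1), dtype=np.float32)
        if nonsens_w is None:
            nonsens_w = np.ones((len(nonsens_x), 1), dtype=np.float32)
        min_diff_sample_weight = np.concatenate([sens_w, nonsens_w], axis=0)

    return min_diff_x, min_diff_membership, min_diff_sample_weight


# Two sensitive examples without weights, three nonsensitive examples with weights.
sens_x = np.array([[1.0], [2.0]])
nonsens_x = np.array([[3.0], [4.0], [5.0]])
nonsens_w = np.array([[0.5], [0.5], [0.5]])

md_x, md_membership, md_weight = build_min_diff_data_sketch(
    (sens_x, None), (nonsens_x, nonsens_w))
```

In this example min_diff_batch_size is 5, the sum of the two input batch sizes, and the missing sensitive weights are filled with ones before concatenation.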

Returns a tf.data.Dataset whose output is a tuple of (packed_inputs, y, sample_weight).