Packs `min_diff_data` with the `x` component of the original dataset.
```python
model_remediation.min_diff.keras.utils.pack_min_diff_data(
    original_dataset: tf.data.Dataset,
    sensitive_group_dataset=None,
    nonsensitive_group_dataset=None,
    min_diff_dataset=None
) -> tf.data.Dataset
```
| Arguments | |
|---|---|
| `original_dataset` | `tf.data.Dataset` that was used before applying MinDiff. The output should conform to the format used in `tf.keras.Model.fit`. |
| `sensitive_group_dataset` | `tf.data.Dataset` or valid MinDiff structure (unnested dict) of `tf.data.Dataset`s containing only examples that belong to the sensitive group. This must be passed in if `nonsensitive_group_dataset` is passed in. |
| `nonsensitive_group_dataset` | `tf.data.Dataset` or valid MinDiff structure (unnested dict) of `tf.data.Dataset`s containing only examples that do not belong to the sensitive group. This must be passed in if `sensitive_group_dataset` is passed in. |
| `min_diff_dataset` | `tf.data.Dataset` or valid MinDiff structure (unnested dict) of `tf.data.Dataset`s containing only examples to be used to calculate the `min_diff_loss`. This should only be set if neither `sensitive_group_dataset` nor `nonsensitive_group_dataset` is passed in. |
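To make "valid MinDiff structure (unnested dict)" concrete, here is a minimal pure-Python sketch of the check. It is not the library's validator: `is_valid_min_diff_structure` is a hypothetical helper, the caller supplies an `is_dataset` predicate so the sketch runs without TensorFlow, and rejecting an empty dict is an assumption of this sketch.

```python
from typing import Any, Callable


def is_valid_min_diff_structure(value: Any, is_dataset: Callable[[Any], bool]) -> bool:
    """Hypothetical check: a single dataset, or a flat (unnested) dict of datasets."""
    if is_dataset(value):
        return True
    if isinstance(value, dict):
        # Every value must itself be a dataset, so a nested dict is rejected.
        # Treating an empty dict as invalid is an assumption of this sketch.
        return len(value) > 0 and all(is_dataset(v) for v in value.values())
    return False
```

In a quick test, plain lists can stand in for datasets, for example `is_valid_min_diff_structure({"a": [1], "b": [2]}, lambda v: isinstance(v, list))` returns `True`, while a nested dict such as `{"a": {"b": [1]}}` returns `False`.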
This function should be used to create the dataset that will be passed to `min_diff.keras.MinDiffModel` during training and, optionally, during evaluation.
The inputs should either have both `sensitive_group_dataset` and `nonsensitive_group_dataset` passed in with `min_diff_dataset` left unset, or have `min_diff_dataset` passed in with both group datasets left unset. In the former case, `min_diff_data` will be built using `utils.build_min_diff_dataset`.
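The either/or rule for the three optional arguments can be sketched as a small validation helper. This is a hypothetical pure-Python illustration, not the library's own check: the name `check_min_diff_inputs` and its `"build"`/`"direct"` return values are invented here, and treating "nothing passed in" as an error is an assumption.

```python
from typing import Any, Optional


def check_min_diff_inputs(sensitive_group_dataset: Optional[Any] = None,
                          nonsensitive_group_dataset: Optional[Any] = None,
                          min_diff_dataset: Optional[Any] = None) -> str:
    """Hypothetical check: return the input mode, or raise ValueError if invalid."""
    has_sensitive = sensitive_group_dataset is not None
    has_nonsensitive = nonsensitive_group_dataset is not None
    has_min_diff = min_diff_dataset is not None

    # The two group datasets must be passed in together.
    if has_sensitive != has_nonsensitive:
        raise ValueError("sensitive_group_dataset and nonsensitive_group_dataset "
                         "must be passed in together.")
    # The group datasets and min_diff_dataset are mutually exclusive.
    if has_sensitive and has_min_diff:
        raise ValueError("min_diff_dataset must be left unset when the group "
                         "datasets are passed in.")
    # Assumption of this sketch: one of the two modes must be used.
    if not has_sensitive and not has_min_diff:
        raise ValueError("Pass in both group datasets or min_diff_dataset.")
    # "build": min_diff_data would be built from the two groups;
    # "direct": min_diff_data is taken from min_diff_dataset as-is.
    return "build" if has_sensitive else "direct"
```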
Each input dataset must output a tuple in the format used in `tf.keras.Model.fit`. Specifically, the output must be a tuple of length 1, 2 or 3 in the form `(x, y, sample_weight)`.
This output will be parsed internally in the following way:

```python
batch = ...  # Batch from any one of the input datasets.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
```
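The unpacking step above fills in `None` for whichever components a batch omits. The following is a simplified pure-Python model of that behavior, written so it runs without TensorFlow; `unpack_batch` is a hypothetical name, and it is not a drop-in replacement for `tf.keras.utils.unpack_x_y_sample_weight` (for example, it does not handle lists).

```python
from typing import Any, Optional, Tuple


def unpack_batch(batch: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Simplified model: split a batch into (x, y, sample_weight), padding with None."""
    if not isinstance(batch, tuple):
        # A bare, non-tuple batch is treated as x only.
        return batch, None, None
    if len(batch) == 1:
        return batch[0], None, None
    if len(batch) == 2:
        return batch[0], batch[1], None
    if len(batch) == 3:
        return batch[0], batch[1], batch[2]
    raise ValueError(f"Expected a tuple of length 1, 2 or 3, got length {len(batch)}.")
```

For instance, `unpack_batch(("x", "y"))` gives `("x", "y", None)`, so downstream code can always work with three components.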
Every batch from the returned `tf.data.Dataset` will contain one batch from each of the input datasets. Each returned batch will be a tuple of `(packed_inputs, original_y, original_sample_weight)` matching the length of `original_dataset` batches, where:

- `packed_inputs`: an instance of `utils.MinDiffPackedInputs` containing:
  - `original_inputs`: the `x` component taken directly from the `original_dataset` batch.
  - `min_diff_data`: a batch of data formed from `sensitive_group_dataset` and `nonsensitive_group_dataset` (as described in `utils.build_min_diff_dataset`) or taken directly from `min_diff_dataset`.
- `original_y`: the `y` component taken directly from the `original_dataset` batch.
- `original_sample_weight`: the `sample_weight` component taken directly from the `original_dataset` batch.
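The pairing described in the list above can be modeled on plain Python iterables. This is an illustrative sketch rather than the library's code: `PackedInputs` is a hypothetical stand-in for `utils.MinDiffPackedInputs` that reuses only the two field names documented here, and `pack_batches` is an invented helper. Because it uses Python's `zip`, it stops at the shorter input; that is a property of this sketch, not a documented library behavior.

```python
from collections import namedtuple
from typing import Any, Iterable, Iterator, Tuple

# Hypothetical stand-in for utils.MinDiffPackedInputs (same two field names).
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def pack_batches(original_batches: Iterable[Any],
                 min_diff_batches: Iterable[Any]) -> Iterator[Tuple[Any, ...]]:
    """Pair each original batch with one min-diff batch, keeping the original tuple length."""
    for original, min_diff in zip(original_batches, min_diff_batches):
        if not isinstance(original, tuple):
            original = (original,)
        packed = PackedInputs(original_inputs=original[0], min_diff_data=min_diff)
        # y and sample_weight (if present) are carried over unchanged,
        # so a length-2 original batch yields a length-2 packed batch.
        yield (packed,) + original[1:]
```

Keeping the trailing components untouched is what lets the packed dataset be passed to `fit` in the same `(x, y, sample_weight)` shape as `original_dataset`.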
`min_diff_data` will be used in `min_diff.keras.MinDiffModel` when calculating the `min_diff_loss`. It is a tuple or structure (matching the structure of the inputs) of `(min_diff_x, min_diff_membership, min_diff_sample_weight)`.
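When the two group datasets are passed in, each `min_diff_data` batch combines examples from both groups with a membership label. The sketch below models one such batch on Python lists. It is hypothetical: `build_min_diff_batch` is an invented helper, the convention of membership `1.0` for sensitive-group examples and `0.0` for the rest is assumed here (see `utils.build_min_diff_dataset` for the authoritative behavior), and defaulting a missing weight list to all ones is an assumption of this sketch.

```python
from typing import List, Optional, Sequence, Tuple


def build_min_diff_batch(
    sensitive_x: Sequence[float],
    nonsensitive_x: Sequence[float],
    sensitive_weight: Optional[Sequence[float]] = None,
    nonsensitive_weight: Optional[Sequence[float]] = None,
) -> Tuple[List[float], List[float], Optional[List[float]]]:
    """Hypothetical: combine one batch per group into (min_diff_x, membership, sample_weight)."""
    min_diff_x = list(sensitive_x) + list(nonsensitive_x)
    # Assumed convention: 1.0 marks sensitive-group examples, 0.0 marks the rest.
    membership = [1.0] * len(sensitive_x) + [0.0] * len(nonsensitive_x)
    if sensitive_weight is None and nonsensitive_weight is None:
        return min_diff_x, membership, None
    # Assumption of this sketch: a missing weight list defaults to all ones.
    s_w = list(sensitive_weight) if sensitive_weight is not None else [1.0] * len(sensitive_x)
    n_w = list(nonsensitive_weight) if nonsensitive_weight is not None else [1.0] * len(nonsensitive_x)
    return min_diff_x, membership, s_w + n_w
```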
| Returns |
|---|
| A `tf.data.Dataset` whose output is a tuple of `(packed_inputs, original_y, original_sample_weight)` matching the output length of `original_dataset`. |