Packs `min_diff_data` with the `x` component of the original dataset.

    model_remediation.min_diff.keras.utils.pack_min_diff_data(
        original_dataset: tf.data.Dataset,
        sensitive_group_dataset: tf.data.Dataset,
        nonsensitive_group_dataset: tf.data.Dataset
    ) -> tf.data.Dataset

| Arguments | |
|---|---|
| `original_dataset` | `tf.data.Dataset` that was used before applying MinDiff. The output should conform to the format used in `tf.keras.Model.fit`. |
| `sensitive_group_dataset` | `tf.data.Dataset` containing only examples that belong to the sensitive group. The output should have the same structure as that of `original_dataset`. |
| `nonsensitive_group_dataset` | `tf.data.Dataset` containing only examples that do not belong to the sensitive group. The output should have the same structure as that of `original_dataset`. |
This function should be used to create the dataset that will be passed to `min_diff.keras.MinDiffModel` during training and, optionally, during evaluation.
Each input dataset must output a tuple in the format used in `tf.keras.Model.fit`. Specifically, the output must be a tuple of length 1, 2, or 3 in the form `(x, y, sample_weight)`.
This output will be parsed internally in the following way:

    import tensorflow as tf

    batch = ...  # Batch from any one of the input datasets.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
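To show what that parsing step does for each allowed tuple length, here is a minimal pure-Python sketch. `unpack_batch` is a hypothetical helper written only for illustration; it mimics our understanding of how `tf.keras.utils.unpack_x_y_sample_weight` behaves and is not the Keras implementation.

```python
def unpack_batch(batch):
    """Hypothetical stand-in for tf.keras.utils.unpack_x_y_sample_weight.

    Returns an (x, y, sample_weight) triple, filling missing entries with None.
    """
    # Lists are treated like tuples; any other object is taken to be x alone.
    if isinstance(batch, list):
        batch = tuple(batch)
    if not isinstance(batch, tuple):
        return batch, None, None
    if len(batch) == 1:
        return batch[0], None, None
    if len(batch) == 2:
        return batch[0], batch[1], None
    if len(batch) == 3:
        return batch[0], batch[1], batch[2]
    raise ValueError(f"Expected a tuple of length 1, 2 or 3, got {len(batch)}.")


print(unpack_batch(("x",)))           # ('x', None, None)
print(unpack_batch(("x", "y")))       # ('x', 'y', None)
print(unpack_batch(("x", "y", "w")))  # ('x', 'y', 'w')
```

Because every batch is normalized to a triple, the three input datasets may each use any of the three tuple lengths.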
The `tf.data.Dataset` returned will output a tuple of `(packed_inputs, y, sample_weight)` where:

- `packed_inputs`: is an instance of `utils.MinDiffPackedInputs` containing:
  - `original_inputs`: the `x` component from `original_dataset`.
  - `min_diff_data`: data formed from `sensitive_group_dataset` and `nonsensitive_group_dataset` as described below.
- `y`: is the `y` component taken directly from `original_dataset`.
- `sample_weight`: is the `sample_weight` component taken directly from `original_dataset`.
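The per-step structure described above can be sketched in plain Python. This is not the library's implementation: it assumes one batch from `original_dataset` is packed with already-formed `min_diff_data` at each step, and `PackedInputs` and `pack_step` are hypothetical names (only the two field names come from the description of `utils.MinDiffPackedInputs`).

```python
from collections import namedtuple

# Hypothetical stand-in for utils.MinDiffPackedInputs; only the field
# names are taken from the documentation.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def pack_step(original_batch, min_diff_data):
    """Sketch: build (packed_inputs, y, sample_weight) for one step.

    original_batch is assumed to be a tuple of length 1, 2 or 3 in the form
    (x, y, sample_weight); missing trailing entries become None.
    """
    x, y, sample_weight = (tuple(original_batch) + (None, None))[:3]
    packed_inputs = PackedInputs(original_inputs=x, min_diff_data=min_diff_data)
    # y and sample_weight are passed through unchanged from original_dataset.
    return packed_inputs, y, sample_weight


packed, y, w = pack_step(([1.0, 2.0], [0, 1]), min_diff_data="placeholder")
print(packed.original_inputs, y, w)  # [1.0, 2.0] [0, 1] None
```

Keeping `y` and `sample_weight` outside `packed_inputs` preserves the `(x, y, sample_weight)` shape that `tf.keras.Model.fit` expects, with only `x` replaced by the packed structure.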
`min_diff_data` will be used in `min_diff.keras.MinDiffModel` when calculating the `min_diff_loss`. It is a tuple of `(min_diff_x, min_diff_membership, min_diff_sample_weight)` where:

- `min_diff_x`: is formed by concatenating the `x` components of `sensitive_group_dataset` and `nonsensitive_group_dataset`.
- `min_diff_membership`: is a tensor of shape `[min_diff_batch_size, 1]` indicating which dataset each example comes from (`1.0` for `sensitive_group_dataset` and `0.0` for `nonsensitive_group_dataset`).
- `min_diff_sample_weight`: is formed by concatenating the `sample_weight` components of `sensitive_group_dataset` and `nonsensitive_group_dataset`. If both are `None`, then this will be set to `None`. If only one is `None`, it is replaced with a `Tensor` of ones of the appropriate shape.
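The three rules above can be illustrated with a small NumPy sketch. `form_min_diff_data` is a hypothetical helper, not part of the library, and NumPy arrays stand in for tensors. It assumes each input batch is already an `(x, y, sample_weight)` triple and that, when one side's weights are missing, the "appropriate shape" for its ones matches the other side's weights beyond the batch dimension.

```python
import numpy as np


def form_min_diff_data(sensitive_batch, nonsensitive_batch):
    """Sketch of (min_diff_x, min_diff_membership, min_diff_sample_weight).

    Each argument is an (x, y, sample_weight) triple of NumPy arrays; the
    sample_weight may be None and the y components are ignored.
    """
    sens_x, _, sens_w = sensitive_batch
    nonsens_x, _, nonsens_w = nonsensitive_batch

    # Sensitive examples first, then non-sensitive ones.
    min_diff_x = np.concatenate([sens_x, nonsens_x], axis=0)

    # Membership column of shape [batch, 1]: 1.0 sensitive, 0.0 non-sensitive.
    min_diff_membership = np.concatenate(
        [np.ones((len(sens_x), 1)), np.zeros((len(nonsens_x), 1))], axis=0
    )

    if sens_w is None and nonsens_w is None:
        min_diff_sample_weight = None
    else:
        # Fill the missing side with ones shaped like the other side's weights.
        if sens_w is None:
            sens_w = np.ones((len(sens_x),) + nonsens_w.shape[1:])
        if nonsens_w is None:
            nonsens_w = np.ones((len(nonsens_x),) + sens_w.shape[1:])
        min_diff_sample_weight = np.concatenate([sens_w, nonsens_w], axis=0)

    return min_diff_x, min_diff_membership, min_diff_sample_weight


sens = (np.array([[1.0], [2.0]]), None, np.array([0.5, 0.5]))
nonsens = (np.array([[3.0]]), None, None)
x, membership, weight = form_min_diff_data(sens, nonsens)
print(x.ravel().tolist())           # [1.0, 2.0, 3.0]
print(membership.ravel().tolist())  # [1.0, 1.0, 0.0]
print(weight.tolist())              # [0.5, 0.5, 1.0]
```

In this sketch `min_diff_batch_size` is simply the sum of the two input batch sizes, since the examples are concatenated along the first axis.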
| Returns |
|---|
| A `tf.data.Dataset` whose output is a tuple of `(packed_inputs, y, sample_weight)`. |