Packs counterfactual_data with the original_input.
model_remediation.counterfactual.keras.utils.pack_counterfactual_data(
original_input: tf.data.Dataset, counterfactual_data: tf.data.Dataset
) -> tf.data.Dataset
| Arguments | |
|---|---|
| original_input | An instance of tf.data.Dataset that was used for training the original model. The output should conform to the format used in tf.keras.Model.fit. |
| counterfactual_data | An instance of tf.data.Dataset containing only examples that will be used to calculate the counterfactual_loss. This dataset is repeated to match the number of examples in original_input. |
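The repetition of counterfactual_data can be pictured with plain Python lists standing in for datasets. This is a minimal sketch, not the library's code: pair_with_repeated is a hypothetical helper, and the idea that the library zips original_input with a repeated counterfactual_data (for example via tf.data.Dataset.zip and Dataset.repeat) is an assumption.

```python
from itertools import cycle


def pair_with_repeated(original_batches, counterfactual_batches):
    """Hypothetical helper: pair each original batch with a counterfactual batch.

    Assumption: this mimics zipping original_input with a repeated
    counterfactual_data; it is an illustration, not the library's code.
    """
    # zip() stops when original_batches is exhausted, so the output length
    # always equals len(original_batches); cycle() restarts the counterfactual
    # batches from the beginning as many times as needed.
    return list(zip(original_batches, cycle(counterfactual_batches)))


pairs = pair_with_repeated(["o0", "o1", "o2", "o3", "o4"], ["c0", "c1"])
print(pairs)
# [('o0', 'c0'), ('o1', 'c1'), ('o2', 'c0'), ('o3', 'c1'), ('o4', 'c0')]
```

Because the shorter counterfactual stream is cycled, no original batch is ever dropped; the counterfactual examples are simply reused.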
This function should be used to create a dataset of
CounterfactualPackedInputs that will be passed to
counterfactual.keras.CounterfactualModel during training and, optionally,
during evaluation.
Each batch of original_input must be a tuple in the format used in
tf.keras.Model.fit. Specifically, it must be a tuple of
length 1, 2, or 3 in the form (x, y, sample_weight).
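The length-1/2/3 requirement can be illustrated with a small normalizer that always yields a full (x, y, sample_weight) triple. unpack_batch is a hypothetical helper written for illustration only; TensorFlow's tf.keras.utils.unpack_x_y_sample_weight performs a similar normalization, but this sketch is not that function and may differ in edge cases.

```python
def unpack_batch(batch):
    """Hypothetical helper: normalize a Model.fit-style batch to (x, y, sample_weight).

    Missing trailing entries are filled with None; a non-tuple batch is
    treated as x alone.
    """
    if not isinstance(batch, tuple):
        # A bare tensor/array is just the features x.
        return (batch, None, None)
    if not 1 <= len(batch) <= 3:
        raise ValueError(
            f"Expected a tuple of length 1, 2 or 3, got length {len(batch)}.")
    # Pad with None so the result always has exactly three slots.
    return batch + (None,) * (3 - len(batch))


print(unpack_batch(("features",)))           # ('features', None, None)
print(unpack_batch(("features", "labels")))  # ('features', 'labels', None)
```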
Every batch from the returned tf.data.Dataset will contain one batch from
each of the input datasets as a CounterfactualPackedInputs. Each returned
batch will be a tuple, drawn from the original dataset and the counterfactual
dataset, of the format ((x, y, sample_weight), (original_x, counterfactual_x,
counterfactual_sample_weight)), matching the length of the original_input
batches, where:
- original_input: a tf.data.Dataset that contains:
  - x: The x component taken directly from the original_input batch.
  - y: The y component taken directly from the original_input batch.
  - sample_weight: The sample_weight component taken directly from the original_input batch.
- counterfactual_data: a tf.data.Dataset that contains:
  - original_x: The x component taken directly from the original_input batch.
  - counterfactual_x: The counterfactual value for original_x (as described in build_counterfactual_data).
  - counterfactual_sample_weight: A batch of data taken directly from the counterfactual_sample_weight of counterfactual_data.
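The nesting of each packed element can be shown with a NamedTuple. PackedPair and pack_pair are hypothetical stand-ins for CounterfactualPackedInputs and the library's internal packing; their field names are illustrative assumptions, and both inputs are assumed to already be full 3-tuples (missing slots padded with None).

```python
from typing import Any, NamedTuple, Optional, Tuple


class PackedPair(NamedTuple):
    """Hypothetical stand-in for CounterfactualPackedInputs (not the real class)."""
    # (x, y, sample_weight)
    original_input: Tuple[Any, Optional[Any], Optional[Any]]
    # (original_x, counterfactual_x, counterfactual_sample_weight)
    counterfactual_data: Tuple[Any, Any, Optional[Any]]


def pack_pair(original_batch, counterfactual_batch):
    """Combine one original batch and one counterfactual batch.

    Assumes both arguments are already 3-tuples.
    """
    x, y, sample_weight = original_batch
    original_x, counterfactual_x, counterfactual_sample_weight = counterfactual_batch
    return PackedPair(
        original_input=(x, y, sample_weight),
        counterfactual_data=(original_x, counterfactual_x, counterfactual_sample_weight),
    )


packed = pack_pair(
    (["a cat"], [1], None),         # x, y, sample_weight
    (["a cat"], ["a dog"], [1.0]),  # original_x, counterfactual_x, counterfactual_sample_weight
)
print(packed.counterfactual_data[1])  # ['a dog']
```

Since a NamedTuple is still a tuple, the packed value also compares equal to the plain nested tuple ((x, y, sample_weight), (original_x, counterfactual_x, counterfactual_sample_weight)).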
Each element of the returned dataset will be an instance of
CounterfactualPackedInputs that can be used in
counterfactual.keras.CounterfactualModel when calculating the
counterfactual_loss.
| Returns | |
|---|---|
| A tf.data.Dataset of CounterfactualPackedInputs. Each CounterfactualPackedInputs represents an (original_inputs, counterfactual_data) pair, where original_inputs is an (x, y, sample_weight) tuple and counterfactual_data is an (original_x, counterfactual_x, counterfactual_sample_weight) tuple. | |