Builds a counterfactual dataset from a list of sensitive terms or a custom function.
model_remediation.counterfactual.keras.utils.build_counterfactual_data(
original_input: tf.data.Dataset,
sensitive_terms_to_remove: Optional[List[str]] = None,
custom_counterfactual_function: Optional[Callable[[Any], Any]] = None
) -> tf.data.Dataset
Arguments | |
---|---|
original_input | tf.data.Dataset that was used before applying Counterfactual. The output should conform to the format used in tf.keras.Model.fit. |
sensitive_terms_to_remove | List of terms that will be removed and filtered from within the original_input. |
custom_counterfactual_function | Optional custom function to apply to tf.data.Dataset.map to build a custom counterfactual dataset. Note that it needs to return a dataset in the form of (original_x, counterfactual_x, counterfactual_sample_weight) and should only include values that have been modified. Use sensitive_terms_to_remove to filter values that have modifying terms included. |
This function builds a tf.data.Dataset containing only the examples that will be used when calculating counterfactual_loss. Each element of the resulting dataset packs the x value from original_input, a modified version of x that acts as counterfactual_x, and a counterfactual_sample_weight that defaults to 1.0. The resulting dataset can be passed to pack_counterfactual_data to create an instance of CounterfactualPackedInputs for use within counterfactual.keras.CounterfactualModel.
original_input must output a tuple in the format used in tf.keras.Model.fit. Specifically, the output must be a tuple of length 1, 2 or 3 in the form (x, y, sample_weight). This output will be parsed internally in the following way:
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
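The tuple convention above can be illustrated without TensorFlow. The following is a minimal sketch in plain Python, not the library's implementation: unpack_element is a hypothetical helper that mirrors the documented convention by padding a shorter tuple with None.

```python
from typing import Any, Optional, Tuple


def unpack_element(element: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Pads a dataset element to (x, y, sample_weight); missing parts are None.

    Hypothetical helper for illustration only.
    """
    if not isinstance(element, tuple):
        # A bare value is treated as x alone.
        return element, None, None
    if len(element) == 1:
        return element[0], None, None
    if len(element) == 2:
        return element[0], element[1], None
    if len(element) == 3:
        return element
    raise ValueError(
        f"Expected a tuple of length 1, 2 or 3, got length {len(element)}.")


print(unpack_element(("Bad word",)))
print(unpack_element(("Bad word", 1)))
print(unpack_element(("Bad word", 1, 0.5)))
```

Only the x component matters when building counterfactual data; y and sample_weight are accepted so that a dataset already prepared for tf.keras.Model.fit can be passed unchanged.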
As an alternative to passing sensitive_terms_to_remove, you can pass a custom function that builds the counterfactual dataset in whatever way you specify. For example, you might want to replace a target word instead of simply removing it. The returned tf.data.Dataset must have the unchanged x values removed. Passing sensitive_terms_to_remove in this case acts as a filter, so that only values that have been modified are included.
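The contract for a custom function (return the original value, the modified value, and a weight, then drop unchanged values) can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: replace_term and build_examples are hypothetical names, and a Python list stands in for the tf.data.Dataset. A real custom function would operate on tensors, for example with tf.strings.regex_replace.

```python
import re
from typing import Iterable, List, Tuple

Triple = Tuple[str, str, float]


def replace_term(x: str, target: str = "Bad",
                 replacement: str = "Good") -> Triple:
    """Returns (original_x, counterfactual_x, counterfactual_sample_weight)."""
    # Replace the target word rather than removing it.
    counterfactual_x = re.sub(re.escape(target), replacement, x)
    return x, counterfactual_x, 1.0


def build_examples(xs: Iterable[str]) -> List[Triple]:
    """Applies replace_term and keeps only values that were modified."""
    triples = (replace_term(x) for x in xs)
    return [t for t in triples if t[0] != t[1]]


examples = build_examples(["Bad word", "Good word"])
print(examples)
```

"Good word" does not contain the target term, so its counterfactual would equal the original and it is dropped, which mirrors the requirement that unchanged x values be removed.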
A minimal example is given below:
import tensorflow as tf
from tensorflow_model_remediation import counterfactual
simple_dataset_x = tf.constant(["Bad word", "Good word"])
simple_dataset = tf.data.Dataset.from_tensor_slices((simple_dataset_x))
# Remove the term 'Bad'; only examples containing it are kept.
counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset, sensitive_terms_to_remove=['Bad'])
for original_value, counterfactual_value, cf_weight in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
  print("counterfactual_sample_weight: ", cf_weight)
original: tf.Tensor(b'Bad word', shape=(), dtype=string)
counterfactual: tf.Tensor(b' word', shape=(), dtype=string)
counterfactual_sample_weight: tf.Tensor(1.0, shape=(), dtype=float32)
Returns | |
---|---|
A tf.data.Dataset whose output is a tuple matching (original_x, counterfactual_x, counterfactual_sample_weight). |
Raises | |
---|---|
ValueError | If neither custom_counterfactual_function nor sensitive_terms_to_remove is provided. |