Builds a counterfactual dataset from a list of sensitive terms or a custom function.
```python
model_remediation.counterfactual.keras.utils.build_counterfactual_data(
    original_input: tf.data.Dataset,
    sensitive_terms_to_remove: Optional[List[str]] = None,
    custom_counterfactual_function: Optional[Callable[[Any], Any]] = None
) -> tf.data.Dataset
```
| Arguments | |
|---|---|
| `original_input` | `tf.data.Dataset` that was used before applying Counterfactual. Its output should conform to the format used in `tf.keras.Model.fit`. |
| `sensitive_terms_to_remove` | List of terms that will be removed from the examples in `original_input` and used to filter them. |
| `custom_counterfactual_function` | Optional custom function to pass to `tf.data.Dataset.map` to build a custom counterfactual dataset. It must produce elements of the form `(original_x, counterfactual_x, counterfactual_sample_weight)`, and the resulting dataset should include only values that have been modified. Use `sensitive_terms_to_remove` to filter for values that contain the terms being modified. |
This function builds a `tf.data.Dataset` containing only the examples that will be used when calculating `counterfactual_loss`. Each element of the resulting dataset is packed as a tuple of the `x` value from `original_input`, a modified version of `x` that acts as `counterfactual_x`, and a `counterfactual_sample_weight` that defaults to 1.0. The resulting dataset can be passed to `pack_counterfactual_data` to create an instance of `CounterfactualPackedInputs` for use within `counterfactual.keras.CounterfactualModel`.
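As a rough illustration of what the `sensitive_terms_to_remove` path produces, the plain-Python sketch below models the documented behavior on a list of strings instead of a `tf.data.Dataset`. It is not the library's implementation: the helper name `build_counterfactual_triples` is hypothetical, and it assumes each term is removed by plain substring deletion (the library may match terms differently).

```python
from typing import List, Tuple


def build_counterfactual_triples(
    texts: List[str], sensitive_terms: List[str]
) -> List[Tuple[str, str, float]]:
    """Hypothetical plain-Python model of the documented behavior.

    Each sensitive term is deleted from the text (assumption: plain substring
    deletion). Only texts that actually changed are kept, each paired with a
    default counterfactual_sample_weight of 1.0.
    """
    triples = []
    for original_x in texts:
        counterfactual_x = original_x
        for term in sensitive_terms:
            counterfactual_x = counterfactual_x.replace(term, "")
        # Unmodified examples are dropped, mirroring the documented filtering.
        if counterfactual_x != original_x:
            triples.append((original_x, counterfactual_x, 1.0))
    return triples


print(build_counterfactual_triples(["Bad word", "Good word"], ["Bad"]))
# [('Bad word', ' word', 1.0)]
```

Note that `"Good word"` disappears entirely: it contains no sensitive term, so it has no counterfactual and contributes nothing to `counterfactual_loss`.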
`original_input` must output a tuple in the format used in `tf.keras.Model.fit`. Specifically, the output must be a tuple of length 1, 2, or 3 in the form `(x, y, sample_weight)`. This output is parsed internally as follows:
```python
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(batch)
```
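The parsing rule above can be sketched in plain Python. The function below is a hypothetical stand-in written to mirror how `tf.keras.utils.unpack_x_y_sample_weight` is generally understood to behave; it is not the Keras source. One detail worth noting is that a bare, non-tuple element (like the single string tensors in the example further down) is treated as `x` alone.

```python
from typing import Any, Optional, Tuple


def unpack_x_y_sample_weight_sketch(
    data: Any,
) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Hypothetical mirror of the (x, y, sample_weight) unpacking rule."""
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        # A bare element is treated as x with no label or weight.
        return (data, None, None)
    if len(data) == 1:
        return (data[0], None, None)
    if len(data) == 2:
        return (data[0], data[1], None)
    if len(data) == 3:
        return (data[0], data[1], data[2])
    raise ValueError(
        f"Expected a tuple of length 1, 2 or 3, got length {len(data)}."
    )


x, y, sample_weight = unpack_x_y_sample_weight_sketch(("Bad word", 1))
print(x, y, sample_weight)  # Bad word 1 None
```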
As an alternative to passing `sensitive_terms_to_remove`, you can write a custom function that is mapped over `original_input` to build a counterfactual dataset to your own specification. For example, you might want to replace a target word instead of simply removing it. The returned `tf.data.Dataset` must have the unchanged `x` values removed. In this case, passing `sensitive_terms_to_remove` acts as a filter so that only examples containing those terms, i.e. the ones that have been modified, are included.
A minimal example is given below:
```python
import tensorflow as tf
from tensorflow_model_remediation import counterfactual

simple_dataset_x = tf.constant(["Bad word", "Good word"])
simple_dataset = tf.data.Dataset.from_tensor_slices(simple_dataset_x)
counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset, sensitive_terms_to_remove=["Bad"])

# Bind the third element to cf_weight so it can be printed below.
for original_value, counterfactual_value, cf_weight in counterfactual_data.take(1):
    print("original: ", original_value)
    print("counterfactual: ", counterfactual_value)
    print("counterfactual_sample_weight: ", cf_weight)
```

Output:

```
original:  tf.Tensor(b'Bad word', shape=(), dtype=string)
counterfactual:  tf.Tensor(b' word', shape=(), dtype=string)
counterfactual_sample_weight:  tf.Tensor(1.0, shape=(), dtype=float32)
```
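To illustrate the custom-function alternative described above (replacing a target word rather than deleting it), here is a plain-Python model on a list of strings. Both `swap_bad_for_good` and `apply_custom_function` are hypothetical names, and the order of operations (filter on the sensitive terms, then apply the custom function) is an assumption about how the filter interacts with the custom function, not a statement of the library's internals.

```python
from typing import Callable, List, Tuple

Triple = Tuple[str, str, float]


def swap_bad_for_good(x: str) -> Triple:
    """Hypothetical custom function: replace a term instead of deleting it."""
    return (x, x.replace("Bad", "Good"), 1.0)


def apply_custom_function(
    texts: List[str],
    fn: Callable[[str], Triple],
    sensitive_terms: List[str],
) -> List[Triple]:
    """Keep only texts containing a sensitive term, then apply fn to each."""
    return [fn(x) for x in texts if any(term in x for term in sensitive_terms)]


print(apply_custom_function(["Bad word", "Good word"], swap_bad_for_good, ["Bad"]))
# [('Bad word', 'Good word', 1.0)]
```

Without the sensitive-term filter, `"Good word"` would map to an identical counterfactual, which is exactly the kind of unchanged value the returned dataset is supposed to exclude.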
| Returns |
|---|
| A `tf.data.Dataset` whose output is a tuple matching `(original_x, counterfactual_x, counterfactual_sample_weight)`. |
| Raises | |
|---|---|
| `ValueError` | If neither `custom_counterfactual_function` nor `sensitive_terms_to_remove` is provided. |
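The error condition amounts to "at least one of the two arguments must be set"; passing both is allowed, in which case the terms act as a filter. A minimal sketch of that check follows; `check_counterfactual_args` is a hypothetical helper and the message text is illustrative, not the library's.

```python
from typing import Any, Callable, List, Optional


def check_counterfactual_args(
    sensitive_terms_to_remove: Optional[List[str]] = None,
    custom_counterfactual_function: Optional[Callable[[Any], Any]] = None,
) -> None:
    """Raise ValueError only when neither argument is supplied."""
    if sensitive_terms_to_remove is None and custom_counterfactual_function is None:
        raise ValueError(
            "Provide sensitive_terms_to_remove, custom_counterfactual_function, "
            "or both."
        )


try:
    check_counterfactual_args()
except ValueError as err:
    print("ValueError:", err)
```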