CounterfactualLoss abstract base class.
model_remediation.counterfactual.losses.CounterfactualLoss(
name: Optional[str] = None
)
Inherits from: tf.keras.losses.Loss
A CounterfactualLoss instance measures the difference in prediction scores (typically score distributions) between two groups of examples identified by the value in the counterfactual_weights column.
If the predictions for the two groups are indistinguishable, the loss should be 0. The greater the difference between the two sets of scores, the higher the loss.
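To make the contract above concrete, here is a minimal, framework-free sketch of one loss that satisfies it. The function name mean_absolute_counterfactual_loss and the choice of mean absolute difference are assumptions for illustration only. They are not part of the model_remediation API, and the sketch uses plain Python lists in place of tensors.

```python
from typing import Sequence


def mean_absolute_counterfactual_loss(
    original: Sequence[float], counterfactual: Sequence[float]
) -> float:
    """Hypothetical loss: mean |original - counterfactual| over paired predictions.

    Returns 0.0 when the two sets of predictions are identical and grows as
    they diverge, which is the behavior described for CounterfactualLoss.
    """
    if len(original) != len(counterfactual):
        # Loosely mirrors the documented error for incompatible shapes.
        raise ValueError("original and counterfactual must have the same length")
    if not original:
        raise ValueError("inputs must be non-empty")
    diffs = [abs(o - c) for o, c in zip(original, counterfactual)]
    return sum(diffs) / len(diffs)


original = [0.2, 0.7, 0.9]
print(mean_absolute_counterfactual_loss(original, original))  # 0.0
print(mean_absolute_counterfactual_loss(original, [0.2, 0.6, 0.9]))
print(mean_absolute_counterfactual_loss(original, [0.9, 0.1, 0.2]))
```

Mean absolute difference is only one choice; any measure that is 0 for identical inputs and increases with divergence fits the same description.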
Methods
__call__
__call__(
original: types.TensorType,
counterfactual: types.TensorType,
sample_weight: Optional[types.TensorType] = None
)
Computes the counterfactual loss.
Arguments | |
---|---|
original | The predictions from the original example values. Shape: [batch_size, d0, .. dN]. Tensor of type float32 or float64. Required. |
counterfactual | The predictions from the counterfactual examples. Shape: [batch_size, d0, .. dN]. Tensor of the same type and shape as original. Required. |
sample_weight | (Optional) sample_weight acts as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sample_weight is a tensor of size [batch_size], the total loss for each sample of the batch is rescaled by the corresponding element in the sample_weight vector. |
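The sample_weight behavior described above can be sketched in plain Python. This assumes per-sample losses have already been computed and that the weighted losses are then averaged over the batch size, similar to the default Keras sum-over-batch-size reduction. weighted_batch_loss is a hypothetical helper, and the library's exact reduction may differ.

```python
from typing import Optional, Sequence, Union

# Hypothetical weight type: a scalar or one weight per sample.
Weight = Union[float, Sequence[float]]


def weighted_batch_loss(
    per_sample_loss: Sequence[float], sample_weight: Optional[Weight] = None
) -> float:
    """Applies sample_weight to per-sample losses, then averages over the batch.

    Assumption: the reduction divides by batch size, not by the sum of weights.
    """
    batch_size = len(per_sample_loss)
    if batch_size == 0:
        raise ValueError("per_sample_loss must be non-empty")
    if sample_weight is None:
        weights = [1.0] * batch_size
    elif isinstance(sample_weight, (int, float)):
        # Scalar: every sample is scaled by the same value.
        weights = [float(sample_weight)] * batch_size
    else:
        # Vector of size [batch_size]: each sample gets its own coefficient.
        weights = [float(w) for w in sample_weight]
        if len(weights) != batch_size:
            raise ValueError("sample_weight must be a scalar or have size batch_size")
    return sum(loss * w for loss, w in zip(per_sample_loss, weights)) / batch_size


losses = [0.5, 1.5]
print(weighted_batch_loss(losses))              # 1.0
print(weighted_batch_loss(losses, 3.0))         # 3.0
print(weighted_batch_loss(losses, [2.0, 0.0]))  # 0.5
```

Under this assumed reduction, a zero weight removes a sample from the numerator but not from the denominator, which is why the last call returns 0.5 rather than 1.0.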
Returns | |
---|---|
The computed counterfactual loss. |
Raises | |
---|---|
ValueError | If any of the input arguments are invalid. |
TypeError | If any of the arguments are not of the expected type. |
InvalidArgumentError | If original, counterfactual or sample_weight have incompatible shapes. |