Pairwise cosine loss between the original and counterfactual predictions.
Inherits From: CounterfactualLoss
model_remediation.counterfactual.losses.PairwiseCosineLoss(
name: Optional[str] = None
)
Arguments
name: Name used for logging and tracking. Defaults to 'pairwise_cosine_loss'.
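As a rough illustration of what "pairwise" means here, the following pure-Python sketch pairs each row of original with the row at the same batch index in counterfactual and scores each pair by its cosine distance, 1 - cos(u, v). It assumes 2-D inputs of shape [batch_size, d], and the helper names cosine_similarity and pairwise_cosine_distances are hypothetical. The exact formula PairwiseCosineLoss uses (any rescaling of the cosine term, the axis it is computed over, the handling of zero vectors) is not stated in this reference and may differ.

```python
import math
from typing import List, Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors (hypothetical helper)."""
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        # Convention chosen for this sketch only: a zero vector has similarity 0.
        return 0.0
    return dot / (norm_u * norm_v)


def pairwise_cosine_distances(
    original: Sequence[Sequence[float]],
    counterfactual: Sequence[Sequence[float]],
) -> List[float]:
    """Per-example distance 1 - cos(original[i], counterfactual[i]).

    Assumes 2-D inputs of shape [batch_size, d]; rows are paired by batch index.
    """
    if len(original) != len(counterfactual):
        raise ValueError("original and counterfactual must have the same batch size")
    return [1.0 - cosine_similarity(o, c) for o, c in zip(original, counterfactual)]


original = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]]
counterfactual = [[3.0, 0.0], [1.0, 0.0], [-3.0, 0.0]]
print(pairwise_cosine_distances(original, counterfactual))  # [0.0, 1.0, 2.0]
```

Because cosine similarity ignores magnitude, the first pair scores 0.0 even though the counterfactual prediction is three times larger; only a change in direction is penalized, up to 2.0 for vectors pointing in opposite directions.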
Methods
__call__
__call__(
original: types.TensorType,
counterfactual: types.TensorType,
sample_weight: Optional[types.TensorType] = None
)
Computes the counterfactual loss.
Arguments
original: The predictions for the original examples, with shape [batch_size, d0, ..., dN]. A tensor of type float32 or float64. Required.
counterfactual: The predictions for the counterfactual examples, with shape [batch_size, d0, ..., dN]. A tensor of the same type and shape as original. Required.
sample_weight: (Optional) A coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If sample_weight is a tensor of size [batch_size], the total loss for each sample in the batch is rescaled by the corresponding element of the sample_weight vector.
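The two sample_weight cases can be illustrated with a short sketch that takes a list of per-example losses (for instance, pairwise cosine distances) and applies either a scalar or a [batch_size] weight vector. The function name weighted_loss is hypothetical, and the final step of averaging over the batch is an assumption modeled on the Keras default reduction (SUM_OVER_BATCH_SIZE); this reference does not say which reduction PairwiseCosineLoss applies.

```python
from typing import Optional, Sequence, Union

Weight = Union[float, Sequence[float]]


def weighted_loss(per_example_losses: Sequence[float], sample_weight: Optional[Weight] = None) -> float:
    """Apply sample_weight to per-example losses, then average over the batch.

    Averaging over the batch size is an assumption of this sketch
    (it mirrors the Keras SUM_OVER_BATCH_SIZE reduction).
    """
    batch_size = len(per_example_losses)
    if batch_size == 0:
        raise ValueError("per_example_losses must not be empty")
    if sample_weight is None:
        weighted = list(per_example_losses)
    elif isinstance(sample_weight, (int, float)):
        # Scalar weight: every example's loss is scaled by the same value.
        weighted = [loss * sample_weight for loss in per_example_losses]
    else:
        # Vector weight: one coefficient per example, so the sizes must match.
        if len(sample_weight) != batch_size:
            raise ValueError("sample_weight must have size [batch_size]")
        weighted = [loss * w for loss, w in zip(per_example_losses, sample_weight)]
    return sum(weighted) / batch_size


losses = [0.0, 1.0, 2.0]
print(weighted_loss(losses))                   # 1.0
print(weighted_loss(losses, 2.0))              # 2.0
print(weighted_loss(losses, [0.0, 3.0, 0.0]))  # 1.0
```

A weight of 0.0 removes an example from the numerator, but under this assumed reduction it still counts toward the batch size in the denominator.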
Returns
The computed counterfactual loss.
Raises
ValueError: If any of the input arguments are invalid.
TypeError: If any of the arguments are not of the expected type.
InvalidArgumentError: If original, counterfactual, or sample_weight have incompatible shapes.