model_remediation.min_diff.losses.MMDLoss
Maximum Mean Discrepancy between predictions on two groups of examples.
Inherits From: MinDiffLoss
model_remediation.min_diff.losses.MMDLoss(
    kernel='gaussian',
    predictions_transform=None,
    name: Optional[str] = None,
    enable_summary_histogram: Optional[bool] = True
)
Arguments
kernel: String (the name of a kernel) or a losses.MinDiffKernel instance to be applied to the predictions. Defaults to 'gaussian'; it is recommended that this be either 'gaussian' (min_diff.losses.GaussianKernel) or 'laplacian' (min_diff.losses.LaplacianKernel).
predictions_transform: Optional transform function to be applied to the predictions. This can be used to smooth out the distributions or to limit the range of the predictions. Whether to apply a transform is task- and data-dependent. For example, for classifiers it might make sense to apply a tf.sigmoid transform to the predictions (if this is not already done) so that MMD is calculated in probability space rather than on raw predictions. In some cases, such as regression, applying no transform is more likely to yield successful results.
name: Name used for logging and tracking. Defaults to 'mmd_loss'.
enable_summary_histogram: Optional bool indicating whether a tf.summary.histogram should be included within the loss. Defaults to True.
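To make the kernel and predictions_transform arguments concrete, here is a plain-Python sketch of what a Gaussian kernel, a Laplacian kernel, and a sigmoid transform compute on scalar predictions. The helper names and the bandwidth value of 0.1 are hypothetical illustrations; the exact parameterization of min_diff.losses.GaussianKernel and min_diff.losses.LaplacianKernel may differ.

```python
import math


def sigmoid(logit: float) -> float:
    """Maps a raw logit to a probability in (0, 1), like a tf.sigmoid transform."""
    return 1.0 / (1.0 + math.exp(-logit))


def gaussian_kernel(x: float, y: float, bandwidth: float = 0.1) -> float:
    """Gaussian (RBF) similarity: 1.0 when x == y, decaying with squared distance.

    The bandwidth default is a hypothetical value chosen for illustration.
    """
    return math.exp(-((x - y) ** 2) / bandwidth ** 2)


def laplacian_kernel(x: float, y: float, bandwidth: float = 0.1) -> float:
    """Laplacian similarity: decays with absolute distance, giving heavier tails."""
    return math.exp(-abs(x - y) / bandwidth)


# Two raw logits that are far apart but map to nearly identical probabilities.
a, b = 6.0, 9.0
raw_similarity = gaussian_kernel(a, b)
prob_similarity = gaussian_kernel(sigmoid(a), sigmoid(b))
print(f"raw: {raw_similarity:.3g}, after sigmoid: {prob_similarity:.3g}")
```

With a narrow kernel, two confident positive predictions look unrelated when compared as raw logits, but nearly indistinguishable once the sigmoid maps them into probability space. This is one reason a transform can change what the MMD penalizes.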
The Maximum Mean Discrepancy (MMD) is a measure of the distance between the distributions of prediction scores on two groups of examples. For characteristic kernels, such as the Gaussian and Laplacian kernels, the MMD is 0 if and only if the two distributions being compared are identical.
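For reference, the textbook definition of the squared MMD between distributions P and Q under a kernel k, together with its biased empirical estimate from samples x_1, ..., x_m drawn from P and y_1, ..., y_n drawn from Q, is shown below. This is the standard formulation, not a transcription of this class's code; the library may differ in details such as sample weighting or whether a square root is taken.

```latex
\begin{aligned}
\mathrm{MMD}^2(P, Q) &= \mathbb{E}_{x, x' \sim P}\left[k(x, x')\right]
  + \mathbb{E}_{y, y' \sim Q}\left[k(y, y')\right]
  - 2\,\mathbb{E}_{x \sim P,\, y \sim Q}\left[k(x, y)\right] \\
\widehat{\mathrm{MMD}}^2 &= \frac{1}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m} k(x_i, x_j)
  + \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n} k(y_i, y_j)
  - \frac{2}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n} k(x_i, y_j)
\end{aligned}
```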
The membership input indicates, with a numerical value, whether each example is part of the sensitive group. Currently, only hard membership values of 0.0 or 1.0 are supported.
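The following plain-Python sketch shows how hard 0.0/1.0 membership values can split a batch of predictions into two groups, and how a biased empirical estimate of squared MMD can then be computed between those groups. The function names, the Gaussian bandwidth of 0.1, and the unweighted averaging are assumptions for illustration; this is not the library's implementation.

```python
import math
from typing import Callable, List, Sequence, Tuple

Kernel = Callable[[float, float], float]


def gaussian_kernel(x: float, y: float, bandwidth: float = 0.1) -> float:
    """Gaussian (RBF) kernel on scalar predictions; the bandwidth is an assumed value."""
    return math.exp(-((x - y) ** 2) / bandwidth ** 2)


def split_by_membership(
    membership: Sequence[float], predictions: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Splits predictions into (sensitive, other) groups using hard 0.0/1.0 membership."""
    if len(membership) != len(predictions):
        raise ValueError("membership and predictions must have the same length")
    sensitive: List[float] = []
    other: List[float] = []
    for m, p in zip(membership, predictions):
        if m == 1.0:
            sensitive.append(p)
        elif m == 0.0:
            other.append(p)
        else:
            # Soft membership is rejected, mirroring the restriction described above.
            raise ValueError(f"membership must be 0.0 or 1.0, got {m}")
    return sensitive, other


def mmd_squared(
    xs: Sequence[float], ys: Sequence[float], kernel: Kernel = gaussian_kernel
) -> float:
    """Biased empirical estimate of squared MMD between two non-empty samples."""

    def mean_kernel(a: Sequence[float], b: Sequence[float]) -> float:
        # Average kernel value over every pair (u, v) with u from a and v from b.
        return sum(kernel(u, v) for u in a for v in b) / (len(a) * len(b))

    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


membership = [1.0, 1.0, 0.0, 0.0]
predictions = [0.9, 0.8, 0.2, 0.1]
sensitive, other = split_by_membership(membership, predictions)
print(mmd_squared(sensitive, other))        # well-separated groups: clearly positive
print(mmd_squared(predictions, predictions))  # identical samples: 0.0
```

Comparing a sample with itself gives an estimate of 0.0, while groups whose predictions are far apart relative to the kernel bandwidth give a clearly positive value.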
For more details, see the paper.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2022-07-01 UTC.