Model that adds one or more loss component(s) to another model during training.
model_remediation.min_diff.keras.MinDiffModel(
original_model: tf.keras.Model,
loss,
loss_weight=1.0,
predictions_transform=None,
**kwargs
)
Inherits from: tf.keras.Model
Arguments | |
---|---|
original_model | Instance of tf.keras.Model that will be trained with the additional min_diff_loss. |
loss | dict or single element of string(s) (name of loss) or min_diff.losses.MinDiffLoss instance(s) that will be used to calculate the min_diff_loss(es). |
loss_weight | dict of scalars or single scalar applied to the min_diff_loss(es) before being included in training. |
predictions_transform | Optional if the output of original_model is a tf.Tensor. Function that transforms the output of original_model after it is called on MinDiff examples. The resulting predictions tensor is what will be passed in to the losses.MinDiffLoss(es). |
**kwargs | Named parameters that will be passed directly to the base class's __init__ function. |
MinDiffModel wraps the model passed in, original_model, and adds a component to the loss during training and optionally during evaluation.
Construction
There are two ways to construct a MinDiffModel instance; the first is the simplest and the most common:
1 - Directly wrap your model with MinDiffModel. This is the simplest usage and is most likely what you will want to use (unless your original model has some custom implementations that need to be taken into account).
```python
import tensorflow as tf
from tensorflow_model_remediation import min_diff

model = tf.keras.Sequential([...])
model = min_diff.keras.MinDiffModel(model, ...)
```
In this case, all methods other than the ones listed below will use the default implementations of tf.keras.Model.
If this is your use case, the next section is not relevant to you and you can skip to the Usage section.
2 - Subclassing MinDiffModel to integrate custom implementations. This will likely be needed if the original_model is itself a customized subclass of tf.keras.Model. If that is the case and you want to preserve the custom implementations, you can create a new custom class that inherits first from MinDiffModel and second from your custom class.
```python
import tensorflow as tf
from tensorflow_model_remediation import min_diff

class CustomSequential(tf.keras.Sequential):

    def train_step(self, data):
        print("In a custom train_step!")
        # Return the metrics computed by the parent implementation.
        return super().train_step(data)

class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomSequential):
    pass  # No additional implementation is required.

model = CustomSequential([...])
model = CustomMinDiffModel(model, ...)  # This will use the custom train_step.
```
If you need to customize methods defined by MinDiffModel, then you can create a direct subclass and override whatever is needed.
```python
import tensorflow as tf
from tensorflow_model_remediation import min_diff

class CustomMinDiffModel(min_diff.keras.MinDiffModel):

    def unpack_min_diff_data(self, inputs):
        print("In a custom MinDiffModel method!")
        # Return the parent result so callers still receive min_diff_data.
        return super().unpack_min_diff_data(inputs)

model = tf.keras.Sequential([...])
model = CustomMinDiffModel(model, ...)  # This will use the custom
                                        # unpack_min_diff_data method.
```
Multiple Applications of MinDiff
It is possible to apply MinDiff multiple times within a single instance of MinDiffModel. To do so, you can pass in a dictionary of losses where keys are the names of each MinDiff application and the values are the names or instances of losses.MinDiffLoss that will be applied for each respective MinDiff application.
Loss weights can be set as either one value that will be used for all
applications or with a dictionary that specifies weights for individual
applications. Weights not specified will default to 1.0.
```python
import tensorflow as tf
from tensorflow_model_remediation import min_diff

model = tf.keras.Sequential([...])
model = min_diff.keras.MinDiffModel(
    model,
    loss={
        "application1": min_diff.losses.MMDLoss(),  # Loss for first application.
        "application2": min_diff.losses.MMDLoss(),  # Loss for second application.
    },
    loss_weight=2.0)  # 2.0 will be used as the weight for all applications.
```
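To make the weighting rule concrete, here is a plain-Python sketch of how a scalar or dictionary loss_weight could map onto each application, and how the weighted MinDiff losses would combine with the task loss. It assumes, as for any loss registered through add_loss, that the weighted MinDiff losses are summed into the compiled loss; other regularization losses are ignored. resolve_loss_weights and total_loss are hypothetical helpers, not part of the MinDiff API.

```python
def resolve_loss_weights(application_names, loss_weight=1.0):
    """Map each MinDiff application name to the weight applied to its loss."""
    if isinstance(loss_weight, dict):
        # Applications missing from the dict fall back to the default of 1.0.
        return {name: float(loss_weight.get(name, 1.0)) for name in application_names}
    # A single scalar is shared by every application.
    return {name: float(loss_weight) for name in application_names}


def total_loss(task_loss, min_diff_losses, loss_weight=1.0):
    """Task loss plus the weighted sum of per-application MinDiff losses."""
    weights = resolve_loss_weights(min_diff_losses, loss_weight)
    return task_loss + sum(weights[name] * value for name, value in min_diff_losses.items())


# Hypothetical per-application MinDiff loss values for one batch.
losses = {"application1": 0.1, "application2": 0.2}
print(resolve_loss_weights(losses, {"application1": 3.0}))
print(total_loss(0.5, losses, loss_weight=2.0))
```

With loss_weight=2.0 both applications are scaled by 2.0; with {"application1": 3.0}, application2 keeps the default weight of 1.0.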
A MinDiffModel initialized as shown above will expect min_diff_data to have a structure matching that of loss (i.e. a dictionary of inputs with keys matching that of loss). See MinDiffModel.compute_min_diff_loss for details.
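The structural requirement can be illustrated with a simplified, hypothetical check that covers only the two cases described here: a dictionary of losses paired with a dictionary of min_diff_data with the same keys, or a single loss paired with a single element. check_min_diff_structure is not a MinDiff function, and a real implementation may compare arbitrarily nested structures, which this sketch does not attempt.

```python
def check_min_diff_structure(loss, min_diff_data):
    """Raise ValueError unless min_diff_data mirrors the structure of loss.

    Simplified: only handles a flat dict of applications or a single element.
    """
    if isinstance(loss, dict):
        # One min_diff_data entry per named MinDiff application.
        if not isinstance(min_diff_data, dict) or set(min_diff_data) != set(loss):
            raise ValueError("min_diff_data must be a dict with the same keys as loss.")
    elif isinstance(min_diff_data, dict):
        # A single loss expects a single (non-dict) min_diff_data element.
        raise ValueError("A single loss expects a single min_diff_data element.")


loss = {"application1": "mmd_loss", "application2": "mmd_loss"}
# Passes: keys match (tuple contents are placeholders).
check_min_diff_structure(loss, {"application1": ("x1", "m1"), "application2": ("x2", "m2")})
```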
Usage
Once you have created an instance of MinDiffModel, it can be used almost exactly the same way as the model it wraps. The two main exceptions to this are:
- During training, the inputs must include min_diff_data; see MinDiffModel.compute_min_diff_loss for details.
- Saving and loading a model can have slightly different behavior if you are subclassing MinDiffModel. See MinDiffModel.save and MinDiffModel.save_original_model for details.
Optionally, inputs containing min_diff_data can be passed in to evaluate and predict. For the former, this will result in the min_diff_loss appearing in the metrics. For predict, this should have no visible effect.
Raises | |
---|---|
ValueError | If predictions_transform is passed in but not callable. |
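Attributes | |
---|---|
original_model | tf.keras.Model to be trained with the additional min_diff_loss term. Inference and evaluation will also come from the results this model provides. |
predictions_transform | Function to be applied on MinDiff predictions before calculating loss. MinDiff predictions are the output of original_model on the MinDiff examples (see MinDiffModel.compute_min_diff_loss for details). If they are not a single tf.Tensor (for example, if original_model has multiple outputs), this function should convert them into one. This can be done by selecting one of the outputs or by combining them in some way. If no predictions_transform is passed in, it defaults to the identity. The result of applying predictions_transform to the MinDiff predictions must be a tf.Tensor; the min_diff_loss is calculated on that result. |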
Methods
call
call(
inputs, training=None, mask=None
)
Calls original_model with optional min_diff_loss as a regularization loss.
Args | |
---|---|
inputs | Inputs to original_model, optionally containing min_diff_data as described below. |
training | Boolean indicating whether to run in training or inference mode. See tf.keras.Model.call for details. |
mask | Mask or list of masks as described in tf.keras.Model.call. |
This method should be used the same way as tf.keras.Model.call. Depending on whether you are in train mode, inputs may need to include min_diff_data (see MinDiffModel.compute_min_diff_loss for details on what form that needs to take).
- If training=True: inputs must contain min_diff_data (see details below).
- If training=False: including min_diff_data is optional.
If present, the min_diff_loss is added by calling self.add_loss and will show up in self.losses.
```python
model = ...  # MinDiffModel.
dataset = ...  # Dataset containing min_diff_data.
for batch in dataset.take(1):
    model(batch, training=True)
model.losses[0]  # First element(s) will be the min_diff_loss(es).
```
Including min_diff_data in inputs requires that MinDiffModel.unpack_original_inputs and MinDiffModel.unpack_min_diff_data behave as expected when called on inputs (see methods for details). This condition is satisfied with the default implementations if you use min_diff.keras.utils.pack_min_diff_data to create the dataset that includes min_diff_data.
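The control flow described for call can be sketched in plain Python, without TensorFlow. Everything below is hypothetical: PackedInputs stands in for whatever structure pack_min_diff_data actually produces, the two unpack functions mirror only the documented behavior of the default methods, and appending to self.losses stands in for self.add_loss.

```python
from collections import namedtuple

# Hypothetical stand-in for the packed structure; the real type may differ.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def unpack_original_inputs(inputs):
    # Documented behavior: the original component if packed, else inputs unchanged.
    return inputs.original_inputs if isinstance(inputs, PackedInputs) else inputs


def unpack_min_diff_data(inputs):
    # Documented behavior: min_diff_data if present, else None.
    return inputs.min_diff_data if isinstance(inputs, PackedInputs) else None


class ToyMinDiffWrapper:
    """Plain-Python toy of the call control flow; not the real MinDiffModel."""

    def __init__(self, original_model, min_diff_loss_fn):
        self.original_model = original_model
        self.min_diff_loss_fn = min_diff_loss_fn
        self.losses = []

    def call(self, inputs, training=None):
        self.losses = []  # This toy keeps only the losses from the latest call.
        min_diff_data = unpack_min_diff_data(inputs)
        if training and min_diff_data is None:
            raise ValueError("inputs must include min_diff_data when training=True.")
        if min_diff_data is not None:
            # Stands in for self.add_loss(min_diff_loss).
            self.losses.append(self.min_diff_loss_fn(self.original_model, min_diff_data))
        return self.original_model(unpack_original_inputs(inputs))


def double(xs):
    # Hypothetical "original model": doubles each input value.
    return [2 * x for x in xs]


toy = ToyMinDiffWrapper(double, lambda model, data: 0.25)
print(toy.call([1, 2], training=False))  # [2, 4], and toy.losses stays empty.
```

The predictions always come from the original inputs alone; min_diff_data only affects what ends up in losses.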
Returns | |
---|---|
A tf.Tensor or nested structure of tf.Tensors according to the behavior of original_model. See tf.keras.Model.call for details. |
Raises | |
---|---|
ValueError | If training is set to True but inputs does not include min_diff_data. |
compile
compile(
*args, **kwargs
)
Compile both self and original_model using the same parameters. See tf.keras.Model.compile for details.
compute_min_diff_loss
compute_min_diff_loss(
min_diff_data, training=None, mask=None
)
Computes min_diff_loss(es) corresponding to min_diff_data.
Arguments | |
---|---|
min_diff_data | Tuple of data or valid MinDiff structure of tuples as described below. |
training | Boolean indicating whether to run in training or inference mode. See tf.keras.Model.call for details. |
mask | Mask or list of masks as described in tf.keras.Model.call. These will be applied when calling the original_model. |
min_diff_data must have a structure (or be a single element) matching that of the loss parameter passed in during initialization. Each element of min_diff_data (and loss) corresponds to one application of MinDiff.

Like the input requirements described in tf.keras.Model.fit, each element of min_diff_data must be a tuple of length 2 or 3. The tuple will be unpacked using the standard tf.keras.utils.unpack_x_y_sample_weight function:
```python
import tensorflow as tf

min_diff_data_elem = ...  # Single element from a batch of min_diff_data.
min_diff_x, min_diff_membership, min_diff_sample_weight = (
    tf.keras.utils.unpack_x_y_sample_weight(min_diff_data_elem))
```
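The length-2 and length-3 cases of that unpacking can be sketched in plain Python. unpack_min_diff_element is a hypothetical helper, not part of Keras or MinDiff; the real tf.keras.utils.unpack_x_y_sample_weight also accepts other inputs, and the ValueError here only illustrates the 2-or-3 requirement stated above.

```python
def unpack_min_diff_element(elem):
    """Split one MinDiff element into (x, membership, sample_weight)."""
    if isinstance(elem, list):
        elem = tuple(elem)  # Treat a list like a tuple.
    if not isinstance(elem, tuple) or len(elem) not in (2, 3):
        raise ValueError("Each min_diff_data element must be a tuple of length 2 or 3.")
    if len(elem) == 2:
        x, membership = elem
        return x, membership, None  # No sample weights were supplied.
    return elem  # Already (x, membership, sample_weight).


x, membership, weight = unpack_min_diff_element((["ex1", "ex2"], [1.0, 0.0]))
print(weight)  # None
```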
The components are defined as follows:
- min_diff_x: inputs to original_model to get the corresponding MinDiff predictions.
- min_diff_membership: numerical [batch_size, 1] Tensor indicating which group each example comes from (marked as 0.0 or 1.0).
- min_diff_sample_weight: Optional weight Tensor. The weights will be applied to the examples during the min_diff_loss calculation.
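To show how membership and sample weights interact, here is a deliberately simple stand-in loss: the absolute gap between the weighted mean predictions of the two groups. This is NOT MMDLoss or any loss shipped with MinDiff; weighted_mean_gap is a hypothetical illustration, and the [batch_size, 1] tensors are flattened to plain lists.

```python
def weighted_mean_gap(predictions, membership, sample_weight=None):
    """|weighted mean(group 1.0) - weighted mean(group 0.0)| over flat lists.

    Assumes each group has a positive total weight.
    """
    if sample_weight is None:
        sample_weight = [1.0] * len(predictions)  # Unweighted by default.
    weighted_sums = {0.0: 0.0, 1.0: 0.0}
    total_weights = {0.0: 0.0, 1.0: 0.0}
    for pred, group, weight in zip(predictions, membership, sample_weight):
        weighted_sums[float(group)] += weight * pred
        total_weights[float(group)] += weight
    mean_1 = weighted_sums[1.0] / total_weights[1.0]
    mean_0 = weighted_sums[0.0] / total_weights[0.0]
    return abs(mean_1 - mean_0)


preds = [0.9, 0.7, 0.2, 0.4]
groups = [1.0, 1.0, 0.0, 0.0]
print(weighted_mean_gap(preds, groups))                # about 0.5
print(weighted_mean_gap(preds, groups, [1, 3, 1, 1]))  # about 0.45
```

Up-weighting the second example pulls the group 1.0 mean from 0.8 to 0.75, which shrinks the gap; real MinDiff losses compare the two groups' prediction distributions in more sophisticated ways.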
For each application of MinDiff, the min_diff_loss is ultimately calculated from the MinDiff predictions, which are evaluated in the following way:
```python
...  # In compute_min_diff_loss call.
min_diff_x = ...  # Single batch of MinDiff examples.
# Get predictions for MinDiff examples.
min_diff_predictions = self.original_model(min_diff_x, training=training)
# Transform the predictions if needed. By default this is the identity.
min_diff_predictions = self.predictions_transform(min_diff_predictions)
```
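The predictions_transform step can be sketched in plain Python: default to the identity when nothing is passed, reject a non-callable value (mirroring the ValueError in the class's Raises table), and otherwise apply the user's function. make_predictions_transform, select_main_output, and the "main" output key are hypothetical names for illustration only.

```python
def identity(predictions):
    return predictions


def make_predictions_transform(predictions_transform=None):
    """Return the function to apply to MinDiff predictions."""
    if predictions_transform is None:
        return identity  # Default: leave predictions unchanged.
    if not callable(predictions_transform):
        raise ValueError("predictions_transform must be callable.")
    return predictions_transform


def select_main_output(predictions):
    # Hypothetical transform for a model whose outputs are a dict of heads.
    return predictions["main"]


transform = make_predictions_transform(select_main_output)
print(transform({"main": [0.2], "aux": [0.9]}))  # [0.2]
```

Selecting a single head is the simplest way to satisfy the requirement that the transformed predictions be one tensor.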
Returns | |
---|---|
Scalar (if only one) or list of min_diff_loss values calculated from min_diff_data. |
Raises | |
---|---|
ValueError | If the structure of min_diff_data does not match that of the loss that was passed to the model during initialization. |
ValueError | If the transformed min_diff_predictions is not a tf.Tensor. |
save
save(
*args, **kwargs
)
Exports the model as described in tf.keras.Model.save.
For subclasses of MinDiffModel that have not been registered as Keras objects, this method will likely be what you want to call to continue training your model with MinDiff after having loaded it. If you want to use the loaded model purely for inference, you will likely want to use MinDiffModel.save_original_model instead.
The exception noted above for unregistered MinDiffModel subclasses is the only difference from tf.keras.Model.save. To avoid these subtle differences, we strongly recommend registering MinDiffModel subclasses as Keras objects. See the documentation of tf.keras.utils.register_keras_serializable for details.
save_original_model
save_original_model(
*args, **kwargs
)
Exports the original_model.

When loaded, this model will be the type of original_model and will no longer be able to train or evaluate with MinDiff data.
unpack_min_diff_data
unpack_min_diff_data(
inputs
)
Extracts min_diff_data from inputs if present or returns None.
Arguments | |
---|---|
inputs | inputs as described in MinDiffModel.call. |
Identifies whether min_diff_data is included in inputs and returns min_diff_data if it is.
```python
model = ...  # MinDiffModel.
inputs = ...  # Batch containing `min_diff_data`
min_diff_data = model.unpack_min_diff_data(inputs)
```
If min_diff_data is not included, then None is returned.
```python
model = ...  # MinDiffModel.
# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_min_diff_data([1, 2, 3]))  # None
```
The default implementation is a pure wrapper around min_diff.keras.utils.unpack_min_diff_data. See there for implementation details.
Returns | |
---|---|
min_diff_data to be passed to MinDiffModel.compute_min_diff_loss if present or None otherwise. |
unpack_original_inputs
unpack_original_inputs(
inputs
)
Extracts original_inputs from inputs.
Arguments | |
---|---|
inputs | inputs as described in MinDiffModel.call. |
Identifies whether min_diff_data is included in inputs. If it is, then what is returned is the component that is only meant to be used in the call to original_model.
```python
model = ...  # MinDiffModel.
inputs = ...  # Batch containing `min_diff_data`
# Extracts component that is only meant to be passed to `original_model`.
original_inputs = model.unpack_original_inputs(inputs)
```
If min_diff_data is not included, then inputs is returned directly.
```python
model = ...  # MinDiffModel.
# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_original_inputs([1, 2, 3]))  # [1, 2, 3]
```
The default implementation is a pure wrapper around min_diff.keras.utils.unpack_original_inputs. See there for implementation details.
Returns | |
---|---|
Inputs to be used in the call to original_model. |