Model that adds a loss component to another model during training.
model_remediation.min_diff.keras.MinDiffModel(
original_model: tf.keras.Model,
loss,
loss_weight: complex = 1.0,
predictions_transform=None,
**kwargs
)
Inherits from: tf.keras.Model
Arguments  

original_model

Instance of tf.keras.Model that will be trained with the
additional min_diff_loss.

loss

String (name of loss) or min_diff.losses.MinDiffLoss instance that
will be used to calculate the min_diff_loss.

loss_weight

Scalar applied to the min_diff_loss before being included
in training.

predictions_transform

Optional if the output of original_model is a
tf.Tensor. Function that transforms the output of original_model after
it is called on MinDiff examples. The resulting predictions tensor is
what will be passed in to the losses.MinDiffLoss.

**kwargs

Named parameters that will be passed directly to the base
class' __init__ function.

MinDiffModel wraps the model passed in, original_model, and adds a
component to the loss during training and optionally during evaluation.
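As a rough mental model, the effect on the training objective can be sketched in plain Python. This is a minimal sketch that assumes the weighted MinDiff term is simply added to the loss of original_model; the helper name combined_loss is hypothetical and not part of the library, which records the extra term through Keras losses instead.

```python
def combined_loss(original_loss: float, min_diff_loss: float, loss_weight: float = 1.0) -> float:
    """Hypothetical helper: the original loss plus the weighted MinDiff term."""
    # loss_weight scales min_diff_loss before it is included in training.
    return original_loss + loss_weight * min_diff_loss


print(combined_loss(2.0, 0.5, loss_weight=2.0))  # 3.0
```

Raising loss_weight makes the MinDiff term count for more relative to the original objective; setting it to 0.0 would leave only the original loss in this sketch.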
Construction
There are two ways to construct a MinDiffModel instance. The first is the
simplest and the most common:
1. Directly wrap your model with MinDiffModel. This is the simplest usage
and is most likely what you will want to use (unless your original model has
some custom implementations that need to be taken into account).
import tensorflow as tf
model = tf.keras.Sequential([...])
model = MinDiffModel(model, ...)
In this case, all methods other than the ones listed below will use the
default implementations of tf.keras.Model.
If you are in this use case, the next section is not relevant to you and you can skip to the section on usage.
2. Subclass MinDiffModel to integrate custom implementations. This will
likely be needed if original_model is itself a customized subclass of
tf.keras.Model. If that is the case and you want to preserve the custom
implementations, you can create a new custom class that inherits first from
MinDiffModel and second from your custom class.
import tensorflow as tf

class CustomSequential(tf.keras.Sequential):

    def train_step(self, data):
        print("In a custom train_step!")
        # Return the base result so Keras still receives the metrics dict.
        return super().train_step(data)

class CustomMinDiffModel(MinDiffModel, CustomSequential):
    pass  # No additional implementation is required.

model = CustomSequential([...])
model = CustomMinDiffModel(model, ...)  # This will use the custom train_step.
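Why the inheritance order matters can be illustrated without TensorFlow. The sketch below uses hypothetical stand-in classes (StubModel, StubSequential, StubMinDiffModel) that only mimic the shape of the real hierarchy; it shows that Python's method resolution order (MRO) places the MinDiff class first and the custom class second, so a train_step defined only on the custom class is still found.

```python
class StubModel:
    """Hypothetical stand-in for tf.keras.Model."""

    def train_step(self, data):
        return "default train_step"


class StubSequential(StubModel):
    """Hypothetical stand-in for tf.keras.Sequential."""


class StubMinDiffModel(StubModel):
    """Hypothetical stand-in for MinDiffModel; it does not define train_step."""

    def call(self, inputs):
        return "StubMinDiffModel.call"


class CustomSequential(StubSequential):
    def train_step(self, data):
        # Delegate to the next class in the MRO and return its result.
        return "custom " + super().train_step(data)


class CustomMinDiffModel(StubMinDiffModel, CustomSequential):
    pass  # No additional implementation is required.


print([cls.__name__ for cls in CustomMinDiffModel.__mro__])
# ['CustomMinDiffModel', 'StubMinDiffModel', 'CustomSequential', 'StubSequential', 'StubModel', 'object']
print(CustomMinDiffModel().train_step(None))  # custom default train_step
```

If the bases were listed the other way round, the custom class's methods would take priority wherever both classes define the same name; listing the MinDiff class first lets its own methods win while still picking up everything else from the custom class.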
If you need to customize methods defined by MinDiffModel, you can
create a direct subclass and override whatever is needed.
import tensorflow as tf

class CustomMinDiffModel(MinDiffModel):

    def unpack_min_diff_data(self, inputs):
        print("In a custom MinDiffModel method!")
        # Return the extracted data (or None) so callers still receive it.
        return super().unpack_min_diff_data(inputs)

model = tf.keras.Sequential([...])
model = CustomMinDiffModel(model, ...)  # Uses the custom unpack_min_diff_data.
Usage
Once you have created an instance of MinDiffModel, it can be used almost
exactly the same way as the model it wraps. The two main exceptions are:
- During training, the inputs must include min_diff_data; see
MinDiffModel.compute_min_diff_loss for details.
- Saving and loading a model has slightly different behavior. See
MinDiffModel.save and MinDiffModel.save_original_model for details.
Optionally, inputs containing min_diff_data can be passed in to evaluate
and predict. For evaluate, this will result in the min_diff_loss
appearing in the metrics. For predict, this should have no visible effect.
Raises  

ValueError

If predictions_transform is passed in but not callable.

Attributes  

original_model

tf.keras.Model to be trained with the additional min_diff_loss term.
Inference and evaluation will also come from the results this model provides. 
predictions_transform

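Function to be applied on MinDiff predictions before calculating loss.
MinDiff predictions are the output of original_model when it is called on the MinDiff examples. If that output is not a single tf.Tensor, this function should transform it into one. This can be done by selecting one of the outputs or by combining them in some way.
If no predictions_transform is passed in, the identity function is used.
The result of applying predictions_transform to the MinDiff predictions must be a tf.Tensor; the min_diff_loss is calculated on this result.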
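The documented rules for predictions_transform (identity by default, ValueError if not callable) can be sketched as follows. The helper resolve_predictions_transform and the dictionary output with "score" and "aux" heads are hypothetical; a real model would return tensors rather than Python lists.

```python
from typing import Any, Callable, Optional


def resolve_predictions_transform(
    transform: Optional[Callable[[Any], Any]],
) -> Callable[[Any], Any]:
    """Hypothetical helper applying the documented predictions_transform rules."""
    if transform is None:
        # Documented default: the identity function.
        return lambda predictions: predictions
    if not callable(transform):
        # Documented: a non-callable predictions_transform raises ValueError.
        raise ValueError(f"predictions_transform must be callable, got {transform!r}.")
    return transform


# Toy multi-output predictions; MinDiff should only see the "score" head.
multi_output_predictions = {"score": [0.1, 0.7], "aux": [3.0, 4.0]}
select_score = resolve_predictions_transform(lambda preds: preds["score"])
print(select_score(multi_output_predictions))  # [0.1, 0.7]
```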
Methods
call
call(
inputs, training=None, mask=None
)
Calls original_model with optional min_diff_loss as regularization loss.
Args  

inputs

Inputs to original_model, optionally containing min_diff_data as
described below.

training

Boolean indicating whether to run in training or inference mode.
See tf.keras.Model.call for details.

mask

Mask or list of masks as described in tf.keras.Model.call.

This method should be used the same way as tf.keras.Model.call. Depending
on whether you are in training mode, inputs may need to include
min_diff_data (see MinDiffModel.compute_min_diff_loss for details on
what form that needs to take).
- If training=True: inputs must contain min_diff_data (see details below).
- If training=False: including min_diff_data is optional.
If present, the min_diff_loss is added by calling self.add_loss and will
show up in self.losses.
model = ... # MinDiffModel.
dataset = ... # Dataset containing min_diff_data.
for batch in dataset.take(1):
model(batch, training=True)
model.losses[0] # First element will be the min_diff_loss.
Including min_diff_data in inputs requires that
MinDiffModel.unpack_original_inputs and MinDiffModel.unpack_min_diff_data
behave as expected when called on inputs (see those methods for details).
This condition is satisfied by the default implementations if you use
min_diff.keras.utils.pack_min_diff_data to create the dataset that
includes min_diff_data.
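The call contract described above can be sketched in plain Python. This is a toy model, not the library implementation: PackedInputs stands in for whatever structure pack_min_diff_data actually produces, losses is an ordinary list standing in for self.add_loss, and double and toy_loss are made-up callables that only show where min_diff_data is consumed.

```python
from collections import namedtuple

# Hypothetical stand-in for the packed structure built by pack_min_diff_data.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


class ToyMinDiffModel:
    """Pure-Python sketch of the documented call() contract (not the real class)."""

    def __init__(self, original_model, min_diff_loss_fn, loss_weight=1.0):
        self.original_model = original_model
        self.min_diff_loss_fn = min_diff_loss_fn
        self.loss_weight = loss_weight
        self.losses = []

    def unpack_original_inputs(self, inputs):
        return inputs.original_inputs if isinstance(inputs, PackedInputs) else inputs

    def unpack_min_diff_data(self, inputs):
        return inputs.min_diff_data if isinstance(inputs, PackedInputs) else None

    def call(self, inputs, training=None):
        self.losses = []  # This sketch keeps only the losses from the latest call.
        min_diff_data = self.unpack_min_diff_data(inputs)
        if training and min_diff_data is None:
            raise ValueError("min_diff_data is required when training=True.")
        if min_diff_data is not None:
            # Stand-in for self.add_loss(...): record the weighted MinDiff term.
            loss = self.min_diff_loss_fn(self.original_model, min_diff_data)
            self.losses.append(self.loss_weight * loss)
        return self.original_model(self.unpack_original_inputs(inputs))


def double(xs):
    """Toy original_model: doubles every input value."""
    return [2 * x for x in xs]


def toy_loss(model, min_diff_data):
    """Hypothetical loss: mean prediction on the MinDiff examples."""
    predictions = model(min_diff_data[0])
    return sum(predictions) / len(predictions)


model = ToyMinDiffModel(double, toy_loss)
batch = PackedInputs(original_inputs=[1.0, 2.0], min_diff_data=([0.5, 1.5], [[0.0], [1.0]]))
print(model.call(batch, training=True))  # [2.0, 4.0]
print(model.losses)  # [2.0]
```

Note that the returned predictions come only from the original inputs; the MinDiff examples influence nothing but the extra loss entry.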
Returns  

A tf.Tensor or nested structure of tf.Tensors according to the
behavior of original_model. See tf.keras.Model.call for details.

Raises  

ValueError

If training is set to True but inputs does not include
min_diff_data.

compile
compile(
*args, **kwargs
)
Compile both self and original_model using the same parameters.
See tf.keras.Model.compile for details.
compute_min_diff_loss
compute_min_diff_loss(
min_diff_data, training=None, mask=None
)
Computes and returns the min_diff_loss corresponding to min_diff_data.
Arguments  

min_diff_data

Tuple of length 2 or 3 as described below. 
training

Boolean indicating whether to run in training or inference mode.
See tf.keras.Model.call for details.

mask

Mask or list of masks as described in tf.keras.Model.call.

Like the input requirements described in tf.keras.Model.fit,
min_diff_data must be a tuple of length 2 or 3. The tuple will be unpacked
using the standard tf.keras.utils.unpack_x_y_sample_weight function:
min_diff_data = ... # Single batch of min_diff_data.
min_diff_x, min_diff_membership, min_diff_sample_weight = (
tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))
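The tuple-length rule for min_diff_data can be made concrete with a small validation helper. split_min_diff_data is hypothetical; it only mirrors the 2- and 3-element cases described here and is not a reimplementation of tf.keras.utils.unpack_x_y_sample_weight, which also accepts other shapes of input.

```python
def split_min_diff_data(min_diff_data):
    """Hypothetical helper: split min_diff_data into (x, membership, sample_weight)."""
    if not isinstance(min_diff_data, tuple) or len(min_diff_data) not in (2, 3):
        raise ValueError("min_diff_data must be a tuple of length 2 or 3.")
    if len(min_diff_data) == 2:
        min_diff_x, min_diff_membership = min_diff_data
        return min_diff_x, min_diff_membership, None  # No sample weights supplied.
    return min_diff_data  # Already (x, membership, sample_weight).


x, membership, weights = split_min_diff_data(([[1.0], [2.0]], [[0.0], [1.0]]))
print(weights)  # None
```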
The components are defined as follows:
- min_diff_x: inputs to original_model to get the corresponding MinDiff
predictions.
- min_diff_membership: numerical [batch_size, 1] Tensor indicating which
group each example comes from (marked as 0.0 or 1.0).
- min_diff_sample_weight: Optional weight Tensor. The weights will be
applied to the examples during the min_diff_loss calculation.
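To see how the membership and sample-weight components fit together, here is a stand-in loss that compares weighted mean predictions between the two groups. This is an illustrative assumption only: the MinDiffLoss classes shipped with the library are not assumed to compute this quantity, and weighted_group_mean_gap is a hypothetical name.

```python
def weighted_group_mean_gap(predictions, membership, sample_weight=None):
    """Stand-in loss: |weighted mean of group 1.0 - weighted mean of group 0.0|.

    predictions: one float per MinDiff example.
    membership: [batch_size, 1] nested list of 0.0 / 1.0 group labels.
    sample_weight: optional per-example weights (defaults to 1.0 each).
    """
    if sample_weight is None:
        sample_weight = [1.0] * len(predictions)
    totals = {0.0: 0.0, 1.0: 0.0}
    weight_sums = {0.0: 0.0, 1.0: 0.0}
    for prediction, (group,), weight in zip(predictions, membership, sample_weight):
        totals[group] += weight * prediction
        weight_sums[group] += weight
    if weight_sums[0.0] == 0.0 or weight_sums[1.0] == 0.0:
        raise ValueError("Both groups need a positive total weight.")
    mean_0 = totals[0.0] / weight_sums[0.0]
    mean_1 = totals[1.0] / weight_sums[1.0]
    return abs(mean_1 - mean_0)


predictions = [1.0, 0.5, 0.25, 0.25]
membership = [[0.0], [0.0], [1.0], [1.0]]
print(weighted_group_mean_gap(predictions, membership))  # 0.5
print(weighted_group_mean_gap(predictions, membership, [1.0, 3.0, 1.0, 1.0]))  # 0.375
```

The second call shows the role of min_diff_sample_weight: upweighting the second example pulls the group 0.0 mean from 0.75 down to 0.625, which shrinks the gap.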
The min_diff_loss is ultimately calculated from the MinDiff
predictions, which are evaluated in the following way:
... # In compute_min_diff_loss call.
min_diff_x = ... # Single batch of MinDiff examples.
# Get predictions for MinDiff examples.
min_diff_predictions = self.original_model(min_diff_x, training=training)
# Transform the predictions if needed. By default this is the identity.
min_diff_predictions = self.predictions_transform(min_diff_predictions)
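The full sequence (unpack, predict, transform, check, compute loss) can be outlined without TensorFlow. In this hypothetical sketch a flat list of floats stands in for a tf.Tensor, and toy_model and spread_loss are toy callables; only the order of the steps and the ValueError on a non-tensor result follow the documentation.

```python
def compute_min_diff_loss_sketch(original_model, predictions_transform, loss_fn,
                                 min_diff_data, training=None):
    """Hypothetical, framework-free outline of the documented steps."""
    # 1. Unpack (x, membership[, sample_weight]).
    if len(min_diff_data) == 3:
        min_diff_x, membership, sample_weight = min_diff_data
    else:
        min_diff_x, membership = min_diff_data
        sample_weight = None
    # 2. Get predictions for the MinDiff examples.
    predictions = original_model(min_diff_x, training=training)
    # 3. Transform them (the identity by default in the real class).
    predictions = predictions_transform(predictions)
    # 4. The real method requires a tf.Tensor; a flat list of floats stands in here.
    if not (isinstance(predictions, list) and all(isinstance(p, float) for p in predictions)):
        raise ValueError("Transformed MinDiff predictions must be a single tensor.")
    # 5. Hand the result to the loss.
    return loss_fn(membership, predictions, sample_weight)


def toy_model(x, training=None):
    """Toy two-headed model returning a dict of outputs."""
    return {"score": [v / 2 for v in x], "aux": x}


def spread_loss(membership, predictions, sample_weight):
    """Toy loss that ignores groups: max minus min prediction."""
    return max(predictions) - min(predictions)


loss = compute_min_diff_loss_sketch(
    toy_model, lambda p: p["score"], spread_loss, ([1.0, 0.5], [[0.0], [1.0]]))
print(loss)  # 0.25
```

Passing an identity transform here would hand the whole dict to step 4 and raise ValueError, which is the situation predictions_transform exists to avoid.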
Returns  

min_diff_loss calculated from min_diff_data.

Raises  

ValueError

If the transformed min_diff_predictions is not a
tf.Tensor.

save
save(
*args, **kwargs
)
Exports the model as described in tf.keras.Model.save.
You may want to use this if you want to continue training your model with
MinDiff after having loaded it. If you want to use the loaded model purely
for inference, you will likely want to use MinDiffModel.save_original_model
instead.
Otherwise, this method has the same behavior as tf.keras.Model.save.
save_original_model
save_original_model(
*args, **kwargs
)
Exports the original_model for inference without min_diff_data.
Saving the original_model allows you to load a model and run
tf.keras.Model.evaluate or tf.keras.Model.predict without requiring
min_diff_data to be included.
This is most likely what you will want to use if you want to save your model
for inference only. Most cases will need to use this method instead of
MinDiffModel.save.
unpack_min_diff_data
unpack_min_diff_data(
inputs
)
Extracts min_diff_data from inputs if present or returns None.
Arguments  

inputs

inputs as described in MinDiffModel.call.

Identifies whether min_diff_data is included in inputs and returns
min_diff_data if it is.
model = ... # MinDiffModel.
inputs = ... # Batch containing `min_diff_data`
min_diff_data = model.unpack_min_diff_data(inputs)
If min_diff_data is not included, then None is returned.
model = ... # MinDiffModel.
# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_min_diff_data([1, 2, 3])) # None
The default implementation is a pure wrapper around
min_diff.keras.utils.unpack_min_diff_data. See there for implementation
details.
Returns  

min_diff_data to be passed to MinDiffModel.compute_min_diff_loss if
present or None otherwise.
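The default behavior of unpack_min_diff_data can be sketched with a named tuple. PackedInputs, its field names, and unpack_min_diff_data_sketch are assumptions made for this illustration; the actual structure is whatever min_diff.keras.utils.pack_min_diff_data produces.

```python
from collections import namedtuple

# Hypothetical packed structure; the real layout comes from pack_min_diff_data.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def unpack_min_diff_data_sketch(inputs):
    """Return min_diff_data if inputs is packed, otherwise None."""
    if isinstance(inputs, PackedInputs):
        return inputs.min_diff_data
    return None


packed = PackedInputs(original_inputs=[1, 2, 3], min_diff_data=([4, 5], [[0.0], [1.0]]))
print(unpack_min_diff_data_sketch(packed))     # ([4, 5], [[0.0], [1.0]])
print(unpack_min_diff_data_sketch([1, 2, 3]))  # None
```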

unpack_original_inputs
unpack_original_inputs(
inputs
)
Extracts original_inputs from inputs.
Arguments  

inputs

inputs as described in MinDiffModel.call.

Identifies whether min_diff_data is included in inputs. If it is, then
what is returned is the component that is only meant to be used in the call
to original_model.
model = ... # MinDiffModel.
inputs = ... # Batch containing `min_diff_data`
# Extracts component that is only meant to be passed to `original_model`.
original_inputs = model.unpack_original_inputs(inputs)
If min_diff_data is not included, then inputs is returned directly.
model = ... # MinDiffModel.
# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_original_inputs([1, 2, 3])) # [1, 2, 3]
The default implementation is a pure wrapper around
min_diff.keras.utils.unpack_original_inputs. See there for implementation
details.
Returns  

Inputs to be used in the call to original_model.
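The pass-through behavior of unpack_original_inputs can be sketched in the same way. PackedInputs, its field names, and unpack_original_inputs_sketch are hypothetical; the real packed layout is defined by min_diff.keras.utils.pack_min_diff_data.

```python
from collections import namedtuple

# Hypothetical packed structure; the real layout comes from pack_min_diff_data.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def unpack_original_inputs_sketch(inputs):
    """Return only the part meant for original_model; pass other inputs through."""
    if isinstance(inputs, PackedInputs):
        return inputs.original_inputs
    return inputs


packed = PackedInputs(original_inputs=[1, 2, 3], min_diff_data=([4, 5], [[0.0], [1.0]]))
print(unpack_original_inputs_sketch(packed))     # [1, 2, 3]
print(unpack_original_inputs_sketch([1, 2, 3]))  # [1, 2, 3]
```

Because unpacked inputs pass through unchanged, the same code path can serve both MinDiff batches and ordinary inference batches.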
