Model that adds a loss component to another model during training.

Inherits from: tf.keras.Model

Args:
  original_model: Instance of tf.keras.Model that will be trained with the additional min_diff_loss.
  loss: String (name of loss) or min_diff.losses.MinDiffLoss instance that will be used to calculate the min_diff_loss.
  loss_weight: Scalar applied to the min_diff_loss before being included in training.
  predictions_transform: Optional if the output of original_model is a tf.Tensor. Function that transforms the output of original_model after it is called on MinDiff examples. The resulting predictions tensor is what will be passed in to the losses.MinDiffLoss.
  **kwargs: Named parameters that will be passed directly to the base class' __init__ function.

MinDiffModel wraps the model passed in, original_model, and adds a component to the loss during training and optionally during evaluation.
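As a rough mental model, and assuming the weighted min_diff_loss is simply added on top of the loss original_model is compiled with (it is added through add_loss as a regularization term), the training objective can be sketched as below. `combined_loss` is a hypothetical helper for illustration only, not part of the library.

```python
def combined_loss(original_loss: float, min_diff_loss: float, loss_weight: float) -> float:
    """Hypothetical illustration of the objective a MinDiffModel trains on.

    The min_diff_loss is scaled by loss_weight and added to the loss that
    original_model would normally be trained with.
    """
    return original_loss + loss_weight * min_diff_loss


print(combined_loss(1.0, 2.0, 0.5))  # 2.0
# With loss_weight=0.0 the MinDiff term has no effect on the objective.
print(combined_loss(0.7, 0.2, 0.0))  # 0.7
```

Larger values of loss_weight push training harder toward reducing min_diff_loss, at the cost of giving the original loss relatively less influence.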


There are two ways to construct a MinDiffModel instance; the first is the simplest and the most common:

1 - Directly wrap your model with MinDiffModel. This is the simplest usage and is most likely what you will want to use (unless your original model has some custom implementations that need to be taken into account).

import tensorflow as tf

model = tf.keras.Sequential([...])

model = MinDiffModel(model, ...)

In this case, all methods other than the ones listed below will use the default implementations of tf.keras.Model.

If you are in this use case, the next section is not relevant to you and you can skip to the section on usage.

2 - Subclassing MinDiffModel to integrate custom implementations. This will likely be needed if the original_model is itself a customized subclass of tf.keras.Model. If that is the case and you want to preserve the custom implementations, you can create a new custom class that inherits first from MinDiffModel and second from your custom class.

import tensorflow as tf

class CustomSequential(tf.keras.Sequential):

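  def train_step(self, data):
    print("In a custom train_step!")
    # Delegate to the default logic so training still returns its metrics.
    return super().train_step(data)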

class CustomMinDiffModel(MinDiffModel, CustomSequential):
  pass  # No additional implementation is required.

model = CustomSequential([...])

model = CustomMinDiffModel(model, ...)  # This will use the custom train_step.
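The inheritance order matters because of Python's method resolution order (MRO). The pure-Python sketch below uses hypothetical stand-in classes (Model, Sequential, MinDiffModelStub) instead of the real Keras and MinDiff classes, and the stub only overrides call for illustration. It shows that listing the MinDiff class first keeps its overrides while still picking up the custom train_step.

```python
class Model:
    """Hypothetical stand-in for tf.keras.Model."""

    def call(self, inputs):
        return "Model.call"

    def train_step(self, data):
        return "Model.train_step"


class Sequential(Model):
    """Hypothetical stand-in for tf.keras.Sequential."""


class MinDiffModelStub(Model):
    """Hypothetical stand-in for MinDiffModel; overrides call only."""

    def call(self, inputs):
        return "MinDiffModel.call"


class CustomSequential(Sequential):
    def train_step(self, data):
        return "CustomSequential.train_step"


class CustomMinDiffModel(MinDiffModelStub, CustomSequential):
    pass  # Nothing else needed: the MRO combines both sets of overrides.


model = CustomMinDiffModel()
print(model.call(None))        # MinDiffModel.call
print(model.train_step(None))  # CustomSequential.train_step
print([cls.__name__ for cls in CustomMinDiffModel.__mro__])
# ['CustomMinDiffModel', 'MinDiffModelStub', 'CustomSequential', 'Sequential', 'Model', 'object']
```

Reversing the base classes would place CustomSequential ahead of the MinDiff class in the MRO, so any method both define would resolve to the custom class instead.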

If you need to customize methods defined by MinDiffModel, then you can create a direct subclass and override whatever is needed.

import tensorflow as tf

class CustomMinDiffModel(MinDiffModel):

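  def unpack_min_diff_data(self, inputs):
    print("In a custom MinDiffModel method!")
    # Reuse the default unpacking so min_diff_data is still returned.
    return super().unpack_min_diff_data(inputs)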

model = tf.keras.Sequential([...])

model = CustomMinDiffModel(model, ...)  # This will use the custom
                                        # unpack_min_diff_data method.


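Once you have created an instance of MinDiffModel, it can be used almost exactly the same way as the model it wraps. The main two exceptions to this are:

  • During training, inputs must include min_diff_data (see MinDiffModel.compute_min_diff_loss for details).
  • Saving behaves slightly differently (see MinDiffModel.save and MinDiffModel.save_original_model for details).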

Optionally, inputs containing min_diff_data can also be passed to evaluate and predict. For evaluate, this results in the min_diff_loss appearing in the metrics; for predict, it should have no visible effect.

Raises:
  ValueError: If predictions_transform is passed in but is not callable.

Attributes:
  original_model: tf.keras.Model to be trained with the additional min_diff_loss term. Inference and evaluation results also come from this model.
  predictions_transform: Function to be applied on MinDiff predictions before calculating the loss.

MinDiff predictions are the output of original_model on the MinDiff examples (see compute_min_diff_loss for details). These might not initially be a tf.Tensor, for example if the model is multi-output. If this is the case, the predictions need to be converted into a tf.Tensor.

This can be done by selecting one of the outputs or by combining them in some way.

# Pick out a specific output to use for MinDiff.
transform = lambda predictions: predictions["output2"]

model = MinDiffModel(..., predictions_transform=transform)

# Test data imitating multi-output predictions.
test_predictions = {
    "output1": [1, 2, 3],
    "output2": [4, 5, 6],
}
model.predictions_transform(test_predictions)  # [4, 5, 6]

If no predictions_transform parameter is passed in (or None is used), then it will default to the identity.

model = MinDiffModel(..., predictions_transform=None)

model.predictions_transform([1, 2, 3])  # [1, 2, 3]

The result of applying predictions_transform on the MinDiff predictions must be a tf.Tensor. The min_diff_loss will be calculated on these results.
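The documented handling of predictions_transform (identity when None, ValueError when not callable) can be sketched in plain Python. `resolve_predictions_transform` is a hypothetical helper written for this illustration; the library's actual validation code may differ.

```python
def resolve_predictions_transform(predictions_transform=None):
    """Hypothetical helper mirroring the documented predictions_transform rules."""
    if predictions_transform is None:
        # Default: the identity, so predictions pass through unchanged.
        return lambda predictions: predictions
    if not callable(predictions_transform):
        raise ValueError(
            f"predictions_transform must be callable, got: {predictions_transform!r}")
    return predictions_transform


pick_output2 = resolve_predictions_transform(lambda p: p["output2"])
print(pick_output2({"output1": [1, 2, 3], "output2": [4, 5, 6]}))  # [4, 5, 6]
print(resolve_predictions_transform(None)([1, 2, 3]))  # [1, 2, 3]

try:
    resolve_predictions_transform("output2")  # A string is not callable.
except ValueError as error:
    print("ValueError:", error)
```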



MinDiffModel.call

Calls original_model with optional min_diff_loss as regularization loss.

Args:
  inputs: Inputs to original_model, optionally containing min_diff_data as described below.
  training: Boolean indicating whether to run in training or inference mode. See the tf.keras.Model documentation for details.
  mask: Mask or list of masks, as described in the tf.keras.Model documentation.

This method should be used the same way as a regular tf.keras.Model call. Depending on whether you are in training mode, inputs may need to include min_diff_data (see MinDiffModel.compute_min_diff_loss for details on what form that needs to take).

  • If training=True: inputs must contain min_diff_data (see details below).
  • If training=False: including min_diff_data is optional.

If present, the min_diff_loss is added by calling self.add_loss and will show up in self.losses.

model = ...  # MinDiffModel.

dataset = ...  # Dataset containing min_diff_data.

for batch in dataset.take(1):
  model(batch, training=True)

model.losses[0]  # First element will be the min_diff_loss.

Including min_diff_data in inputs assumes that MinDiffModel.unpack_original_inputs and MinDiffModel.unpack_min_diff_data behave as expected when called on inputs (see those methods for details).

This condition is satisfied with the default implementations if you use min_diff.keras.utils.pack_min_diff_data to create the dataset that includes min_diff_data.

Returns:
  A tf.Tensor or nested structure of tf.Tensors, according to the behavior of original_model.

Raises:
  ValueError: If training is set to True but inputs does not include min_diff_data.
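The control flow described above can be sketched without TensorFlow. ToyMinDiffModel, PackedInputs, and the constant loss below are hypothetical stand-ins: the real class builds on tf.keras.Model, and the real packed structure comes from min_diff.keras.utils.pack_min_diff_data, whose exact type may differ.

```python
from collections import namedtuple

# Hypothetical stand-in for the output of pack_min_diff_data.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


class ToyMinDiffModel:
    """Pure-Python sketch of the documented call() behavior."""

    def __init__(self, original_model, min_diff_loss_fn):
        self.original_model = original_model
        self.min_diff_loss_fn = min_diff_loss_fn
        self.losses = []

    def unpack_original_inputs(self, inputs):
        return inputs.original_inputs if isinstance(inputs, PackedInputs) else inputs

    def unpack_min_diff_data(self, inputs):
        return inputs.min_diff_data if isinstance(inputs, PackedInputs) else None

    def compute_min_diff_loss(self, min_diff_data, training=None):
        return self.min_diff_loss_fn(min_diff_data)

    def add_loss(self, loss):
        self.losses.append(loss)

    def __call__(self, inputs, training=None):
        self.losses = []  # In this sketch, losses are rebuilt on every call.
        min_diff_data = self.unpack_min_diff_data(inputs)
        if training and min_diff_data is None:
            raise ValueError("inputs must include min_diff_data when training=True.")
        if min_diff_data is not None:
            self.add_loss(self.compute_min_diff_loss(min_diff_data, training=training))
        # Predictions always come from original_model on the original inputs.
        return self.original_model(self.unpack_original_inputs(inputs))


model = ToyMinDiffModel(original_model=lambda x: [2 * v for v in x],
                        min_diff_loss_fn=lambda data: 0.25)
print(model(PackedInputs([1, 2], min_diff_data="batch"), training=True))  # [2, 4]
print(model.losses)  # [0.25]
print(model([1, 2], training=False))  # [2, 4]
print(model.losses)  # []
```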


MinDiffModel.compile

Compile both self and original_model using the same parameters.

See tf.keras.Model.compile for details.
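A minimal sketch of forwarding identical compile arguments to both the wrapper and the wrapped model, using hypothetical ToyModel and ToyMinDiffWrapper classes in place of tf.keras.Model and MinDiffModel. One plausible benefit is that original_model can then be evaluated on its own with the same configuration; treat that as an assumption rather than documented behavior.

```python
class ToyModel:
    """Hypothetical stand-in for tf.keras.Model that records compile arguments."""

    def __init__(self):
        self.compile_kwargs = None

    def compile(self, **kwargs):
        self.compile_kwargs = kwargs


class ToyMinDiffWrapper(ToyModel):
    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def compile(self, **kwargs):
        # Compile the wrapper and the wrapped model with the same parameters.
        super().compile(**kwargs)
        self.original_model.compile(**kwargs)


wrapper = ToyMinDiffWrapper(ToyModel())
wrapper.compile(optimizer="adam", loss="binary_crossentropy")
print(wrapper.original_model.compile_kwargs == wrapper.compile_kwargs)  # True
```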


MinDiffModel.compute_min_diff_loss

Computes and returns the min_diff_loss corresponding to min_diff_data.

Args:
  min_diff_data: Tuple of length 2 or 3 as described below.
  training: Boolean indicating whether to run in training or inference mode. See the tf.keras.Model documentation for details.
  mask: Mask or list of masks, as described in the tf.keras.Model documentation.

Like the standard Keras (x, y, sample_weight) input format, min_diff_data must be a tuple of length 2 or 3. The tuple will be unpacked using the standard tf.keras.utils.unpack_x_y_sample_weight function:

min_diff_data = ...  # Single batch of min_diff_data.

min_diff_x, min_diff_membership, min_diff_sample_weight = (
    tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))

The components are defined as follows:

  • min_diff_x: inputs to original_model to get the corresponding MinDiff predictions.
  • min_diff_membership: numerical [batch_size, 1] Tensor indicating which group each example comes from (marked as 0.0 or 1.0).
  • min_diff_sample_weight: Optional weight Tensor. The weights will be applied to the examples during the min_diff_loss calculation.
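For plain-Python intuition, the hypothetical helper below mimics how a length-2 or length-3 tuple is split into these three components, filling in None when no sample weights are given. It enforces only the 2-or-3 requirement stated above; it does not reproduce how the real tf.keras.utils.unpack_x_y_sample_weight treats other inputs.

```python
def unpack_min_diff_tuple(min_diff_data):
    """Hypothetical analogue of unpacking min_diff_data into three components."""
    if not isinstance(min_diff_data, tuple) or len(min_diff_data) not in (2, 3):
        raise ValueError("min_diff_data must be a tuple of length 2 or 3.")
    if len(min_diff_data) == 2:
        min_diff_x, min_diff_membership = min_diff_data
        return min_diff_x, min_diff_membership, None  # No sample weights.
    return min_diff_data


x = [[0.1], [0.5], [0.9]]
membership = [[0.0], [1.0], [1.0]]  # Shape [batch_size, 1], 0.0 or 1.0 per example.
print(unpack_min_diff_tuple((x, membership)))
# ([[0.1], [0.5], [0.9]], [[0.0], [1.0], [1.0]], None)
print(unpack_min_diff_tuple((x, membership, [1.0, 2.0, 1.0]))[2])  # [1.0, 2.0, 1.0]
```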

The min_diff_loss is ultimately calculated from the MinDiff predictions which are evaluated in the following way:

...  # In compute_min_diff_loss call.

min_diff_x = ...  # Single batch of MinDiff examples.

# Get predictions for MinDiff examples.
min_diff_predictions = self.original_model(min_diff_x, training=training)
# Transform the predictions if needed. By default this is the identity.
min_diff_predictions = self.predictions_transform(min_diff_predictions)

Returns:
  The min_diff_loss calculated from min_diff_data.

Raises:
  ValueError: If the transformed min_diff_predictions is not a tf.Tensor.
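The steps above (unpack, predict, transform, type-check, score) can be traced end to end in plain Python. Everything here is hypothetical: `toy_group_gap_loss` is a simple stand-in, not one of the library's MinDiffLoss implementations; membership is flattened to a 1-D list for simplicity; and the `list` check plays the role of the tf.Tensor check.

```python
def toy_group_gap_loss(membership, predictions, sample_weight=None):
    """Hypothetical stand-in loss: absolute gap between the weighted mean
    prediction of group 1.0 and group 0.0. Assumes both groups are present."""
    if sample_weight is None:
        sample_weight = [1.0] * len(predictions)
    totals = {0.0: 0.0, 1.0: 0.0}
    weights = {0.0: 0.0, 1.0: 0.0}
    for group, prediction, weight in zip(membership, predictions, sample_weight):
        totals[group] += weight * prediction
        weights[group] += weight
    return abs(totals[1.0] / weights[1.0] - totals[0.0] / weights[0.0])


def compute_min_diff_loss_sketch(original_model, predictions_transform, loss_fn,
                                 min_diff_data):
    """Sketch of the documented compute_min_diff_loss steps."""
    if len(min_diff_data) == 2:
        min_diff_x, membership = min_diff_data
        sample_weight = None
    else:
        min_diff_x, membership, sample_weight = min_diff_data
    # Get predictions for the MinDiff examples, then transform them.
    predictions = predictions_transform(original_model(min_diff_x))
    # Stand-in for the tf.Tensor check: this sketch expects a flat list.
    if not isinstance(predictions, list):
        raise ValueError(
            f"Transformed MinDiff predictions must be a list, got {type(predictions)}.")
    return loss_fn(membership, predictions, sample_weight)


def multi_output_model(xs):
    """Hypothetical two-output model returning a dict of predictions."""
    return {"score": [2 * x for x in xs], "aux": [0 for _ in xs]}


data = ([1, 2, 3, 5], [0.0, 0.0, 1.0, 1.0])
print(compute_min_diff_loss_sketch(
    multi_output_model, lambda p: p["score"], toy_group_gap_loss, data))  # 5.0
```

Passing the identity as predictions_transform here would leave the predictions as a dict and raise the ValueError, which is the multi-output situation the predictions_transform attribute exists to handle.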


MinDiffModel.save

Exports the model as described in the tf.keras.Model documentation.

You may want to use this if you want to continue training your model with MinDiff after having loaded it. If you want to use the loaded model purely for inference, you will likely want to use MinDiffModel.save_original_model instead.

Other than the exception noted above, this method has the same behavior as its tf.keras.Model counterpart.


MinDiffModel.save_original_model

Exports the original_model for inference without min_diff_data.

Saving the original_model allows you to load a model and run tf.keras.Model.evaluate or tf.keras.Model.predict without requiring min_diff_data to be included.

This is most likely what you will want to use if you are saving your model for inference only; in most cases you will need this method rather than saving the full MinDiffModel.


MinDiffModel.unpack_min_diff_data

Extracts min_diff_data from inputs if present or returns None.

Args:
  inputs: Inputs to the model, optionally containing min_diff_data.

Identifies whether min_diff_data is included in inputs and returns min_diff_data if it is.

model = ...  # MinDiffModel.

inputs = ...  # Batch containing `min_diff_data`

min_diff_data = model.unpack_min_diff_data(inputs)

If min_diff_data is not included, then None is returned.

model = ...  # MinDiffModel.

# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_min_diff_data([1, 2, 3]))  # None

The default implementation is a pure wrapper around min_diff.keras.utils.unpack_min_diff_data. See there for implementation details.

Returns:
  min_diff_data to be passed to MinDiffModel.compute_min_diff_loss if present, or None otherwise.
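The default unpacking behavior for both unpack methods can be pictured with the pure-Python sketch below. PackedInputs is a hypothetical stand-in for the structure that min_diff.keras.utils.pack_min_diff_data produces, and the real type may differ; the point is only that the two functions split one batch in complementary ways.

```python
from collections import namedtuple

# Hypothetical stand-in for the output of pack_min_diff_data.
PackedInputs = namedtuple("PackedInputs", ["original_inputs", "min_diff_data"])


def unpack_min_diff_data(inputs):
    """Return min_diff_data if inputs carries it, otherwise None."""
    if isinstance(inputs, PackedInputs):
        return inputs.min_diff_data
    return None


def unpack_original_inputs(inputs):
    """Return the component of inputs meant only for original_model."""
    if isinstance(inputs, PackedInputs):
        return inputs.original_inputs
    return inputs  # No min_diff_data: inputs are returned unchanged.


packed = PackedInputs(original_inputs=[1, 2, 3],
                      min_diff_data=([[0.1], [0.9]], [[0.0], [1.0]]))
print(unpack_min_diff_data(packed))       # ([[0.1], [0.9]], [[0.0], [1.0]])
print(unpack_original_inputs(packed))     # [1, 2, 3]
print(unpack_min_diff_data([1, 2, 3]))    # None
print(unpack_original_inputs([1, 2, 3]))  # [1, 2, 3]
```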


MinDiffModel.unpack_original_inputs

Extracts original_inputs from inputs.

Args:
  inputs: Inputs to the model, optionally containing min_diff_data.

Identifies whether min_diff_data is included in inputs. If it is, the component that is only meant to be used in the call to original_model is returned.

model = ...  # MinDiffModel.

inputs = ...  # Batch containing `min_diff_data`

# Extracts component that is only meant to be passed to `original_model`.
original_inputs = model.unpack_original_inputs(inputs)

If min_diff_data is not included, then inputs is returned directly.

model = ...  # MinDiffModel.

# Test batch without `min_diff_data` (i.e. just passing in a simple array)
print(model.unpack_original_inputs([1, 2, 3]))  # [1, 2, 3]

The default implementation is a pure wrapper around min_diff.keras.utils.unpack_original_inputs. See there for implementation details.

Returns:
  Inputs to be used in the call to original_model.
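If your dataset uses a layout other than the one produced by pack_min_diff_data, both unpack methods need to be overridden together so they stay consistent. The sketch below assumes a hypothetical layout where min_diff_data travels under a "min_diff_data" key of a dict batch; in practice these methods would live on a MinDiffModel subclass, but here they sit on a plain class so the example runs without TensorFlow.

```python
class DictBatchUnpacking:
    """Hypothetical overrides for a dict batch with a "min_diff_data" key."""

    KEY = "min_diff_data"

    def unpack_min_diff_data(self, inputs):
        if isinstance(inputs, dict) and self.KEY in inputs:
            return inputs[self.KEY]
        return None

    def unpack_original_inputs(self, inputs):
        if isinstance(inputs, dict) and self.KEY in inputs:
            # Everything except the MinDiff component goes to original_model.
            return {k: v for k, v in inputs.items() if k != self.KEY}
        return inputs


unpacker = DictBatchUnpacking()
batch = {"features": [1, 2, 3], "min_diff_data": ([0.1, 0.9], [0.0, 1.0])}
print(unpacker.unpack_original_inputs(batch))  # {'features': [1, 2, 3]}
print(unpacker.unpack_min_diff_data(batch))    # ([0.1, 0.9], [0.0, 1.0])
print(unpacker.unpack_min_diff_data({"features": [1, 2, 3]}))  # None
```

Keeping the two methods symmetric ensures every part of a batch goes either to original_model or to the MinDiff loss computation, never to both and never to neither.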