tff.learning.from_keras_model

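Builds a tff.learning.Model from an uncompiled tf.keras.Model, using an example mini-batch (dummy_batch) that has the shapes and types of the model's input.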

tff.learning.from_keras_model(
    keras_model,
    dummy_batch,
    loss,
    loss_weights=None,
    metrics=None,
    optimizer=None
)

Args:

  • keras_model: A tf.keras.Model object that is not compiled.
  • dummy_batch: A nested structure of values that are convertible to batched tensors with the same shapes and types as would be input to keras_model. The values of the tensors are not important and can be filled with any reasonable input value.
  • loss: A callable that takes two batched tensor parameters, y_true and y_pred, and returns the loss. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses, each weighted by loss_weights.
  • loss_weights: (Optional) A list or dictionary of scalar coefficients (Python floats) that weight the loss contributions of the different model outputs. The loss value minimized by the model is then the sum of the individual losses, each multiplied by its loss_weights coefficient. If a list, it is expected to have a 1:1 mapping to the model's outputs. If a dictionary, it is expected to map output names (strings) to scalar coefficients.
  • metrics: (Optional) a list of tf.keras.metrics.Metric objects.
  • optimizer: (Optional) A tf.keras.optimizers.Optimizer. If None, the returned model cannot be used for training.
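
The way loss and loss_weights combine for a multi-output model can be illustrated with plain Python numbers. This is a minimal sketch of the arithmetic only, not the library's implementation: combined_loss is a hypothetical helper, the per-output loss values are made up, and treating a missing loss_weights as a weight of 1.0 per output is an assumption of the sketch.

```python
def combined_loss(per_output_losses, loss_weights=None):
    """Return the weighted sum of per-output loss values.

    Hypothetical helper for illustration only (not part of TFF or Keras).
    per_output_losses: list of floats, or dict mapping output name -> float.
    loss_weights: same container type as per_output_losses, or None.
    """
    if isinstance(per_output_losses, dict):
        # Assumption: with no weights given, every output counts with 1.0.
        weights = loss_weights or {name: 1.0 for name in per_output_losses}
        if set(weights) != set(per_output_losses):
            raise KeyError("loss_weights keys must match the loss keys")
        return sum(weights[name] * value
                   for name, value in per_output_losses.items())

    # List form: weights map 1:1, by position, to the model's outputs.
    weights = loss_weights or [1.0] * len(per_output_losses)
    if len(weights) != len(per_output_losses):
        raise ValueError("loss_weights must have one entry per output")
    return sum(w * value for w, value in zip(weights, per_output_losses))


# Two outputs with losses 0.5 and 2.0, weighted 1.0 and 0.25:
# 1.0 * 0.5 + 0.25 * 2.0 = 1.0
print(combined_loss([0.5, 2.0], [1.0, 0.25]))

# The same idea with named outputs (names "a" and "b" are made up):
# 2.0 * 0.5 + 0.5 * 2.0 = 2.0
print(combined_loss({"a": 0.5, "b": 2.0}, {"a": 2.0, "b": 0.5}))
```

In the list form the weights are matched to outputs by position, while in the dictionary form they are matched by output name, which is why the dictionary keys have to agree.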

Returns:

A tff.learning.Model object.

Raises:

  • TypeError: If keras_model is not an instance of tf.keras.Model.
  • ValueError: If keras_model was compiled.
  • KeyError: If loss is a dict and does not have the same keys as keras_model.outputs.