Builds a tff.learning.Model for a given input type.

tff.learning.from_keras_model(
    keras_model, loss, input_spec=None, loss_weights=None, metrics=None,
    dummy_batch=None
)

from_keras_model validates its arguments, normalizes them as appropriate, and instantiates a tff.learning.Model that uses keras_model for the forward pass and for automatic differentiation. To do this, it needs three pieces of information: a tf.keras.Model to use for the forward pass; a loss function (or group of loss functions), loss; and input_spec, from which TFF infers the type signatures of the tff.Computations in which this model will appear.
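
A framework-free sketch can make these three ingredients concrete. Everything in it is hypothetical and not part of TFF: a plain-Python linear model stands in for keras_model, mean squared error stands in for loss, and a small dict of element types stands in for input_spec. It shows that a single forward pass consumes both the inputs and the ground truth of a batch.

```python
from typing import List, NamedTuple


class Batch(NamedTuple):
    """Hypothetical batch layout: model inputs plus ground-truth labels."""
    x: List[float]  # fed to the model to produce predictions
    y: List[float]  # consumed only by the loss


# Hypothetical stand-in for input_spec: the element type of each batch field.
INPUT_SPEC = {"x": float, "y": float}


def model(x: List[float], weight: float = 2.0) -> List[float]:
    """Stand-in for keras_model's forward pass: p = weight * x."""
    return [weight * xi for xi in x]


def loss(y_true: List[float], y_pred: List[float]) -> float:
    """Stand-in for loss(y_true, y_pred): mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def forward_pass(batch: Batch) -> float:
    """Check the batch against the spec, run the model on x, score against y."""
    for name, element_type in INPUT_SPEC.items():
        values = getattr(batch, name)
        if not all(isinstance(v, element_type) for v in values):
            raise TypeError(f"batch.{name} does not match the spec")
    predictions = model(batch.x)
    return loss(batch.y, predictions)
```

For example, forward_pass(Batch(x=[1.0, 2.0], y=[2.0, 5.0])) predicts [2.0, 4.0] and returns a loss of 0.5; the loss could not be computed from x alone.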

Because TFF couples the tf.keras.Model with its loss, TFF needs a slightly different notion of a "fully specified type" than pure Keras does. The model M takes inputs of type x and produces predictions of type p, while the loss function L takes inputs of type <p, y> and produces a scalar. To fully specify the type signatures of the computations in which the generated tff.learning.Model will appear, TFF therefore needs the type y in addition to the type x. TFF currently accepts two ways of specifying this input type: input_spec and dummy_batch. input_spec is strictly more general, and TFF is in the process of deprecating the dummy_batch argument of from_keras_model.
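
The difference between the two options can be illustrated with shapes alone. This is a simplified sketch under stated assumptions: shapes are modeled as Python tuples rather than tf.TensorSpec objects, spec_from_dummy_batch is a hypothetical helper that records a dummy batch's concrete shapes verbatim, and the sketch does not reproduce how TFF itself converts a dummy_batch. It shows one way a spec can be more general than an example batch: a spec may leave a dimension unknown (None), whereas a concrete batch can only state fixed sizes.

```python
from collections import OrderedDict
from typing import Dict, Mapping, Optional, Tuple

Shape = Tuple[Optional[int], ...]


def shape_of(value) -> Tuple[int, ...]:
    """Shape of a nested Python list, e.g. [[0.0, 0.0]] -> (1, 2)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def spec_from_dummy_batch(dummy_batch: Mapping[str, list]) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical: record the concrete shape of every element of a dummy batch."""
    return OrderedDict((name, shape_of(value)) for name, value in dummy_batch.items())


def is_compatible(spec_shape: Shape, actual: Tuple[int, ...]) -> bool:
    """True if `actual` fits `spec_shape`; None in the spec means 'any size'."""
    return len(spec_shape) == len(actual) and all(
        s is None or s == a for s, a in zip(spec_shape, actual)
    )


# A dummy batch of two examples pins every dimension, including the batch size.
dummy_batch = OrderedDict(x=[[0.0, 0.0], [0.0, 0.0]], y=[[0], [0]])
from_dummy = spec_from_dummy_batch(dummy_batch)  # x: (2, 2), y: (2, 1)

# An explicit spec with the same two elements can leave the batch size unknown.
input_spec = OrderedDict(x=(None, 2), y=(None, 1))

batch_of_three = shape_of([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (3, 2)
```

Here a batch of three examples fits input_spec["x"] but not the shape recorded from the two-example dummy batch.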


Args:
  • keras_model: A tf.keras.Model object that is not compiled.
  • loss: A callable that takes two batched tensor parameters, y_true and y_pred, and returns the loss. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value minimized by the model is then the sum of all individual losses, each weighted by loss_weights.
  • input_spec: (Optional) a value convertible to tff.Type specifying the type of arguments the model expects. This must be a compound structure of two elements: the first specifies the data fed into the model to generate predictions, and the second specifies the expected type of the ground truth. This argument will become required once dummy_batch is removed; currently, exactly one of input_spec and dummy_batch must be specified.
  • loss_weights: (Optional) a list or dictionary of scalar coefficients (Python floats) that weight the loss contributions of different model outputs. The loss value minimized by the model is then the weighted sum of all individual losses. If a list, it is expected to map 1:1 to the model's outputs. If a dictionary, it is expected to map output names (strings) to scalar coefficients.
  • metrics: (Optional) a list of tf.keras.metrics.Metric objects.
  • dummy_batch: (Optional, deprecated) a nested structure of values that are convertible to batched tensors with the same shapes and types as would be input to keras_model. The values of the tensors are not important and can be filled with any reasonable input value.
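
The way loss and loss_weights combine for a multi-output model can be sketched in plain Python. weighted_total_loss is a hypothetical helper, not a TFF or Keras function; scalars stand in for batched tensors, and defaulting missing weights to 1.0 is an assumption of this sketch.

```python
def weighted_total_loss(losses, y_true, y_pred, loss_weights=None) -> float:
    """Hypothetical sketch: sum of per-output losses, each scaled by its weight.

    Either every argument is a list aligned 1:1 with the model's outputs, or
    every argument is a dict keyed by output name. In this sketch, outputs
    without an explicit weight get a weight of 1.0.
    """
    if isinstance(losses, dict):
        # Dict form: losses, labels, predictions, and weights pair up by name.
        weights = loss_weights or {}
        return sum(
            weights.get(name, 1.0) * fn(y_true[name], y_pred[name])
            for name, fn in losses.items()
        )
    # List form: everything pairs up by position.
    weights = loss_weights or [1.0] * len(losses)
    return sum(w * fn(t, p) for w, fn, t, p in zip(weights, losses, y_true, y_pred))


def abs_err(t: float, p: float) -> float:
    return abs(t - p)


def sq_err(t: float, p: float) -> float:
    return (t - p) ** 2
```

With a list, weights pair with outputs by position, so weighted_total_loss([abs_err, sq_err], [1.0, 1.0], [3.0, 4.0], [0.5, 2.0]) gives 0.5 * 2.0 + 2.0 * 9.0 = 19.0. With a dict, they pair by output name, which is why a dict of losses has to use the same keys as the model's outputs.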


Returns:
A tff.learning.Model object.


Raises:
  • TypeError: If keras_model is not an instance of tf.keras.Model.
  • ValueError: If keras_model was compiled.
  • KeyError: If loss is a dict and does not have the same keys as keras_model.outputs.
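
The documented argument rules can be summarized as a small validation routine. check_inputs is hypothetical: the Keras-specific TypeError and compiled-model ValueError checks are omitted because they need TensorFlow, output names are passed in as plain strings, and raising ValueError for the input_spec/dummy_batch rules is an assumption, since the list above does not say which exception those rules produce.

```python
def check_inputs(loss, output_names, input_spec=None, dummy_batch=None) -> None:
    """Hypothetical sketch of the documented argument rules (Keras checks omitted)."""
    # Exactly one of input_spec and dummy_batch must be given.
    # (Assumption: ValueError is used here; the exception type is not documented.)
    if (input_spec is None) == (dummy_batch is None):
        raise ValueError("Specify exactly one of input_spec and dummy_batch.")
    # input_spec is a two-element structure: (model inputs, ground truth).
    if input_spec is not None and len(input_spec) != 2:
        raise ValueError("input_spec must have exactly two elements.")
    # A dict of losses must be keyed exactly by the model's output names.
    if isinstance(loss, dict) and set(loss) != set(output_names):
        raise KeyError(
            f"loss keys {sorted(loss)} do not match outputs {sorted(output_names)}"
        )
```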