Builds a tff.learning.Model for a given input type.

from_keras_model validates its arguments, normalizes them as appropriate, and instantiates a tff.learning.Model that uses keras_model for the forward pass and automatic differentiation. To do this, it needs three pieces of information: keras_model, the tf.keras.Model used for the forward pass; loss, a loss function (or group of loss functions); and input_spec, from which TFF infers the type signatures of the tff.Computations in which the model will appear.

Because TFF couples the tf.keras.Model with its loss, it needs a slightly different notion of "fully specified type" than pure Keras does. The model M takes inputs of type x and produces predictions of type p, while the loss function L takes inputs of type <p, y> and produces a scalar. To fully specify the type signatures of computations in which the generated tff.learning.Model appears, TFF therefore needs the type y of the ground truth in addition to the input type x.
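To make the x/y distinction concrete, here is a plain-Python sketch; it is not TFF code. The functions `model`, `loss`, and `loss_on_batch` are hypothetical, and lists of floats stand in for batched tensors. The composed computation L(M(x), y) consumes the pair (x, y), which is why both types must be known.

```python
from typing import List, Tuple

# Hypothetical batch type: (x, y) = (model inputs, ground truth).
Batch = Tuple[List[float], List[float]]


def model(x: List[float], w: float) -> List[float]:
    """Toy forward pass M: maps inputs x to predictions p."""
    return [w * xi for xi in x]


def loss(p: List[float], y: List[float]) -> float:
    """Toy loss L: consumes <p, y> and returns a scalar (mean squared error)."""
    return sum((pi - yi) ** 2 for pi, yi in zip(p, y)) / len(y)


def loss_on_batch(batch: Batch, w: float) -> float:
    """Composition L(M(x), y): its input is the pair (x, y), not x alone."""
    x, y = batch
    return loss(model(x, w), y)


print(loss_on_batch(([1.0, 2.0], [2.0, 4.0]), w=1.0))  # 2.5
```

Knowing only the type of x would be enough to call `model`, but not `loss_on_batch`; that is the gap the second element of input_spec fills.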

Args

keras_model: A tf.keras.Model object that has not been compiled.
loss: A callable that takes two batched tensor parameters, y_true and y_pred, and returns the loss. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses; the loss minimized by the model is then the sum of the individual losses, each weighted by loss_weights.
input_spec: A value convertible to tff.Type specifying the type of arguments the model expects. It must be a compound structure of two elements: the first specifies the data fed into the model to generate predictions, and the second specifies the expected type of the ground truth.
loss_weights: (Optional) A list or dictionary of scalar coefficients (Python floats) that weight the loss contributions of the model's outputs. The loss minimized by the model is then the sum of the individual losses, each multiplied by its coefficient. A list must map 1:1 to the model's outputs; a dictionary must map output names (strings) to scalar coefficients.
metrics: (Optional) A list of tf.keras.metrics.Metric objects.
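The weighting rule described for loss and loss_weights can be sketched in plain Python. This is an illustration, not the Keras or TFF implementation: `total_loss` is a hypothetical helper, per-output losses are given as floats, and treating an omitted loss_weights as a weight of 1.0 per output is an assumption of the sketch.

```python
from typing import Dict, Optional, Sequence, Union

# A list maps 1:1 to outputs; a dict maps output names to coefficients.
LossWeights = Optional[Union[Sequence[float], Dict[str, float]]]


def total_loss(output_names: Sequence[str],
               per_output_losses: Sequence[float],
               loss_weights: LossWeights = None) -> float:
    """Returns the weighted sum of per-output losses (illustrative sketch)."""
    if len(output_names) != len(per_output_losses):
        raise ValueError("expected one loss value per model output")
    if loss_weights is None:
        # Assumption of this sketch: no weights means every output counts once.
        weights = [1.0] * len(per_output_losses)
    elif isinstance(loss_weights, dict):
        # Look up each output's coefficient by name; a missing name raises KeyError.
        weights = [loss_weights[name] for name in output_names]
    else:
        if len(loss_weights) != len(per_output_losses):
            raise ValueError("a list of loss_weights must map 1:1 to the outputs")
        weights = list(loss_weights)
    return sum(w * loss_value for w, loss_value in zip(weights, per_output_losses))


# Two hypothetical outputs, "digit" and "parity", with losses 2.0 and 3.0.
print(total_loss(["digit", "parity"], [2.0, 3.0]))              # 5.0
print(total_loss(["digit", "parity"], [2.0, 3.0], [0.5, 2.0]))  # 7.0
print(total_loss(["digit", "parity"], [2.0, 3.0],
                 {"digit": 1.0, "parity": 0.0}))                 # 2.0
```

The list form relies on output order, while the dict form relies on output names, so the dict form is less fragile if the model's outputs are reordered.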

Returns

A tff.learning.Model object.

Raises

TypeError: If keras_model is not an instance of tf.keras.Model.
ValueError: If keras_model was compiled or input_spec does not contain two elements.
KeyError: If loss is a dict and does not have the same keys as keras_model.outputs.
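The ValueError and KeyError conditions above can be mirrored in a plain-Python sketch to show when each one applies. `check_from_keras_args` and `dummy_loss` are hypothetical; the sketch takes the model's output names and a compiled flag in place of a real tf.keras.Model, so the TypeError check is omitted. It does not reflect how TFF actually implements these checks.

```python
from typing import Any, Callable, Dict, Sequence, Union

LossArg = Union[Callable[..., Any], Dict[str, Callable[..., Any]]]


def check_from_keras_args(output_names: Sequence[str],
                          is_compiled: bool,
                          loss: LossArg,
                          input_spec: Any) -> None:
    """Raises the documented errors for invalid arguments (illustrative sketch)."""
    if is_compiled:
        raise ValueError("keras_model must not be compiled")
    # input_spec may be any structure supporting len(), e.g. a tuple or dict.
    if len(input_spec) != 2:
        raise ValueError("input_spec must have two elements: (x, y)")
    if isinstance(loss, dict) and set(loss) != set(output_names):
        raise KeyError("loss keys must match the model outputs")


def dummy_loss(y_true, y_pred):
    """Hypothetical placeholder loss; the checks never call it."""
    return 0.0


cases = [
    (["out"], True, dummy_loss, ("x", "y")),             # compiled model
    (["out"], False, dummy_loss, ("x",)),                # one-element input_spec
    (["a", "b"], False, {"a": dummy_loss}, ("x", "y")),  # loss dict missing "b"
]
for args in cases:
    try:
        check_from_keras_args(*args)
    except (ValueError, KeyError) as err:
        print(type(err).__name__)  # ValueError, ValueError, KeyError
```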