tff.learning.from_keras_model

Builds a tff.learning.Model from a tf.keras.Model.

The tff.learning.Model returned by this function uses keras_model for its forward pass and autodifferentiation steps.

Because TFF couples the tf.keras.Model with its loss, it needs a slightly different notion of a "fully specified type" than pure Keras does. The model M takes inputs of type x and produces predictions of type p; the loss function L takes inputs of type <p, y> and produces a scalar. Therefore, to fully specify the type signatures of computations in which the generated tff.learning.Model appears, TFF needs the type y in addition to the type x.
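The coupling can be illustrated in plain Python. The sketch below is not TFF code: `make_forward_pass`, `linear_model`, and `mean_squared_error` are hypothetical stand-ins for a Keras model and a Keras loss. It only shows that once the model and loss are composed, the resulting computation consumes a pair (x, y), so its input type cannot be written down from x alone.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-ins (not TFF or Keras APIs):
# a "model" maps inputs x to predictions p,
# a "loss" maps the pair <p, y> to a scalar.
Model = Callable[[List[float]], List[float]]
Loss = Callable[[List[float], List[float]], float]


def make_forward_pass(model: Model, loss: Loss) -> Callable[[Tuple[List[float], List[float]]], float]:
    """Couple a model with its loss. The returned function takes a batch
    (x, y), so describing its input requires the type of y as well as x."""
    def forward_pass(batch: Tuple[List[float], List[float]]) -> float:
        x, y = batch
        predictions = model(x)        # M: x -> p
        return loss(predictions, y)   # L: <p, y> -> scalar
    return forward_pass


def linear_model(x: List[float]) -> List[float]:
    # Toy model: doubles every feature.
    return [2.0 * v for v in x]


def mean_squared_error(p: List[float], y: List[float]) -> float:
    return sum((pi - yi) ** 2 for pi, yi in zip(p, y)) / len(p)


forward = make_forward_pass(linear_model, mean_squared_error)
print(forward(([1.0, 2.0], [2.0, 5.0])))  # prints 0.5
```

Here the predictions are [2.0, 4.0], the squared errors against y are 0.0 and 1.0, and their mean is 0.5.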

Args
keras_model: A tf.keras.Model object that has not been compiled.
loss: A tf.keras.losses.Loss, or a list of losses, one per output, if the model has multiple outputs. If multiple outputs are present, the model attempts to minimize the sum of all individual losses (optionally weighted using the loss_weights argument).
input_spec: A structure of tf.TensorSpecs or tff.Type specifying the type of arguments the model expects. This must be a compound structure of two elements, specifying both the data fed into the model (x) to generate predictions and the expected type of the ground truth (y). If provided as a list, it must be in the order [x, y]. If provided as a dictionary, the keys must be named 'x' and 'y'.
loss_weights: (Optional) A list of Python floats used to weight the loss contribution of each model output.
metrics: (Optional) A list of tf.keras.metrics.Metric objects.
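For a multi-output model, the loss and loss_weights arguments combine into one scalar objective. The hypothetical helper `total_loss` below sketches that weighted sum in plain Python, assuming each output's loss has already been computed as a float; it is not part of the TFF API, and the length check is this sketch's own choice rather than documented behavior.

```python
from typing import Optional, Sequence


def total_loss(per_output_losses: Sequence[float],
               loss_weights: Optional[Sequence[float]] = None) -> float:
    """Hypothetical helper: combine one loss value per model output into
    the single scalar that is minimized. Without loss_weights, every
    output contributes with weight 1.0 (a plain sum)."""
    if loss_weights is None:
        loss_weights = [1.0] * len(per_output_losses)
    # This sketch's own check; the documentation does not specify this case.
    if len(loss_weights) != len(per_output_losses):
        raise ValueError("loss_weights must have one entry per output")
    return sum(w * l for w, l in zip(loss_weights, per_output_losses))


print(total_loss([0.5, 2.0]))               # prints 2.5 (unweighted sum)
print(total_loss([0.5, 2.0], [1.0, 0.25]))  # prints 1.0 (0.5 * 1.0 + 2.0 * 0.25)
```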

Returns
A tff.learning.Model object.

Raises
TypeError: If keras_model is not an instance of tf.keras.Model, if keras_model has a single output and loss is not an instance of tf.keras.losses.Loss, or if keras_model has multiple outputs and loss is not a list of instances of tf.keras.losses.Loss.
ValueError: If keras_model was compiled, if keras_model has multiple outputs and loss is not a list of the same length (one loss per output), if input_spec does not contain exactly two elements, or if input_spec is a dictionary and does not contain the keys 'x' and 'y'.
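The input_spec conditions listed under ValueError can be restated as a small check. `check_input_spec` below is a hypothetical plain-Python sketch, not TFF's actual validation code; the placeholder strings stand in for tf.TensorSpec or tff.Type values, and rejecting non-sequence, non-dictionary inputs with ValueError is an assumption of this sketch.

```python
from collections.abc import Mapping, Sequence


def check_input_spec(input_spec):
    """Hypothetical restatement of the documented input_spec rules.
    Returns (x_spec, y_spec) or raises ValueError."""
    if isinstance(input_spec, Mapping):
        # Dictionary form: keys must be exactly 'x' and 'y'.
        if set(input_spec.keys()) != {"x", "y"}:
            raise ValueError("dict input_spec must have exactly the keys 'x' and 'y'")
        return input_spec["x"], input_spec["y"]
    if isinstance(input_spec, Sequence) and not isinstance(input_spec, str):
        # Sequence form: exactly two elements, in the order [x, y].
        if len(input_spec) != 2:
            raise ValueError("input_spec must contain exactly two elements")
        x_spec, y_spec = input_spec
        return x_spec, y_spec
    # Assumption of this sketch: anything else is rejected.
    raise ValueError("input_spec must be a two-element list or an x/y dict")


print(check_input_spec(["x_spec", "y_spec"]))             # ('x_spec', 'y_spec')
print(check_input_spec({"y": "y_spec", "x": "x_spec"}))   # ('x_spec', 'y_spec')
```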