Builds a tff.learning.Model from a tf.keras.Model.

The tff.learning.Model returned by this function uses keras_model for its forward pass and autodifferentiation steps.

Because TFF couples the tf.keras.Model with its loss, TFF needs a slightly different notion of "fully specified type" than pure Keras does. The model M takes inputs of type x and produces predictions of type p; the loss function L takes inputs of type <p, y> (where y is the ground-truth label type) and produces a scalar. Therefore, to fully specify the type signatures of computations in which the generated tff.learning.Model appears, TFF needs the type y in addition to the type x.
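Written as type signatures, the argument above reduces to one line. This only restates the paragraph's own notation, with \mathbb{R} standing in for "a scalar": the composed training objective consumes the pair <x, y>, not x alone.

$$
M : x \to p, \qquad L : \langle p, y \rangle \to \mathbb{R}, \qquad (x, y) \mapsto L(M(x), y) \; : \; \langle x, y \rangle \to \mathbb{R}
$$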

Args:
keras_model: A tf.keras.Model object that is not compiled.
loss: A single tf.keras.losses.Loss or a list of losses, one per output. If a single loss is provided, then all model outputs (as well as all prediction information) are passed to that loss; this includes situations with multiple model outputs and/or predictions. If multiple losses are provided as a list, each loss is expected to correspond to one model output, and the model will attempt to minimize the sum of all individual losses (optionally weighted using the loss_weights argument).
input_spec: A structure of tf.TensorSpecs or a tff.Type specifying the type of arguments the model expects. If input_spec is a tff.Type, its leaf nodes must be tff.TensorTypes. Note that input_spec must be a compound structure of two elements, specifying both the data fed into the model (x) to generate predictions and the expected type of the ground truth (y). If provided as a list, it must be in the order [x, y]. If provided as a dictionary, the keys must be named exactly 'x' and 'y'.
loss_weights: (Optional) A list of Python floats used to weight the loss contribution of each model output (only applicable when a list of losses is provided for the loss argument).
metrics: (Optional) A list of tf.keras.metrics.Metric objects.
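To make the list-of-losses behavior concrete, here is a minimal pure-Python sketch of "minimize the sum of all individual losses, optionally weighted". It is not TFF's implementation: combined_loss, squared_error and absolute_error are hypothetical names, plain callables stand in for tf.keras.losses.Loss objects, floats stand in for per-output tensors, and it assumes that omitting loss_weights means every output gets weight 1.0.

```python
from typing import Callable, Optional, Sequence

# Hypothetical stand-in for a tf.keras.losses.Loss instance:
# any callable mapping (ground_truth, prediction) to a scalar.
LossFn = Callable[[float, float], float]


def combined_loss(
    losses: Sequence[LossFn],
    y_true: Sequence[float],
    y_pred: Sequence[float],
    loss_weights: Optional[Sequence[float]] = None,
) -> float:
    """Hypothetical helper: weighted sum of per-output losses.

    losses[i] is applied only to output i; weights default to 1.0 (an assumption).
    """
    if len(losses) != len(y_pred) or len(y_true) != len(y_pred):
        raise ValueError("expected exactly one loss and one label per model output")
    if loss_weights is None:
        loss_weights = [1.0] * len(losses)
    elif len(loss_weights) != len(losses):
        raise ValueError("expected exactly one weight per loss")
    return sum(
        weight * loss(target, prediction)
        for loss, weight, target, prediction in zip(losses, loss_weights, y_true, y_pred)
    )


def squared_error(target: float, prediction: float) -> float:
    return (target - prediction) ** 2


def absolute_error(target: float, prediction: float) -> float:
    return abs(target - prediction)


# Two model outputs, each scored by its own loss.
unweighted = combined_loss([squared_error, absolute_error], [1.0, 2.0], [3.0, 0.5])
weighted = combined_loss(
    [squared_error, absolute_error], [1.0, 2.0], [3.0, 0.5], loss_weights=[0.5, 2.0]
)
print(unweighted, weighted)  # 5.5 5.0
```

Here the per-output losses are 4.0 and 1.5, so the unweighted total is 5.5 and the weighted total is 0.5 * 4.0 + 2.0 * 1.5 = 5.0. With a single (non-list) loss, by contrast, the whole model output would go to that one loss and no per-output split would happen.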

Returns:
A tff.learning.Model object.

Raises:
TypeError: If keras_model is not an instance of tf.keras.Model, if loss is neither an instance of tf.keras.losses.Loss nor a list of such instances, if input_spec is a tff.Type whose leaf nodes are not tff.TensorTypes, if loss_weights is provided but is not a list of floats, or if metrics is provided but is not a list of instances of tf.keras.metrics.Metric.
ValueError: If keras_model was compiled, if loss is a list whose length differs from the number of outputs of keras_model, if loss_weights is specified but loss is not a list, if input_spec does not contain exactly two elements, or if input_spec is a dictionary that does not contain the keys 'x' and 'y'.
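The input_spec rules behind the last two ValueError cases can be sketched in plain Python. This is an illustrative assumption, not TFF's internal validation: split_input_spec is a hypothetical helper, strings stand in for tf.TensorSpec objects, and it ignores tff.Type inputs and other structure types (such as named tuples) that the real function may accept. Anything that is not a two-element list/tuple or an {'x': ..., 'y': ...} dict is rejected with ValueError.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Tuple


def split_input_spec(input_spec: Any) -> Tuple[Any, Any]:
    """Hypothetical helper: return (x_spec, y_spec) per the documented rules.

    - A list/tuple must have exactly two elements, ordered [x, y].
    - A dict must have exactly two entries, keyed 'x' and 'y'.
    """
    if isinstance(input_spec, Mapping):
        if len(input_spec) != 2:
            raise ValueError("input_spec must contain exactly two elements")
        if "x" not in input_spec or "y" not in input_spec:
            raise ValueError("a dict input_spec must use the keys 'x' and 'y'")
        return input_spec["x"], input_spec["y"]
    if isinstance(input_spec, Sequence) and not isinstance(input_spec, (str, bytes)):
        if len(input_spec) != 2:
            raise ValueError("input_spec must contain exactly two elements")
        x_spec, y_spec = input_spec  # positional order is [x, y]
        return x_spec, y_spec
    # Anything else cannot be a two-element [x, y] structure.
    raise ValueError("input_spec must be a two-element sequence or an {'x', 'y'} dict")


# Key order in a dict does not matter; the names do.
x_spec, y_spec = split_input_spec({"y": "label_spec", "x": "feature_spec"})
print(x_spec, y_spec)  # feature_spec label_spec
```

The design point the sketch illustrates is that a list relies on position while a dict relies on names, which is why the dictionary form must use exactly 'x' and 'y'.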