
tff.learning.from_keras_model

tff.learning.from_keras_model(
    keras_model,
    dummy_batch,
    loss,
    metrics=None,
    optimizer=None
)

Defined in learning/model_utils.py.

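Builds a tff.learning.Model from an uncompiled tf.keras.Model, using an example mini-batch (dummy_batch) to determine the shapes and types of the model's inputs.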

Args:

  • keras_model: A tf.keras.Model object that is not compiled.
  • dummy_batch: A nested structure of values that are convertible to batched tensors with the same shapes and types as would be input to keras_model. The values of the tensors are not important and can be filled with any reasonable input value.
  • loss: A callable that takes two batched tensor parameters, y_true and y_pred, and returns the loss.
  • metrics: (Optional) a list of tf.keras.metrics.Metric objects.
  • optimizer: (Optional) a tf.keras.optimizers.Optimizer. If None, the returned model cannot be used for training.
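The contracts for dummy_batch and loss can be illustrated with a small NumPy-only sketch. Everything here is an assumption for illustration, not part of this API. The helper names make_dummy_batch and sparse_ce_loss are hypothetical. So are the 784-feature input (a flattened 28x28 image) and the `x`/`y` key layout for features and labels, which follows TFF examples from the same period; check the convention of the version you use. Only the shapes and dtypes of the dummy batch matter, so it is filled with zeros. The loss function only demonstrates the two-argument `(y_true, y_pred)` calling convention. An actual call would pass a TensorFlow loss such as tf.keras.losses.SparseCategoricalCrossentropy(), since the model operates on tensors.

```python
import collections

import numpy as np


def make_dummy_batch(batch_size=1):
    """Hypothetical helper: a placeholder batch for a model taking 784 floats.

    Only shapes and dtypes matter, so every value is zero. The 'x'/'y'
    key layout (features/labels) is an assumed convention.
    """
    return collections.OrderedDict(
        x=np.zeros((batch_size, 784), dtype=np.float32),  # model inputs
        y=np.zeros((batch_size, 1), dtype=np.int32),      # integer labels
    )


def sparse_ce_loss(y_true, y_pred, eps=1e-7):
    """NumPy sketch of the loss contract: (y_true, y_pred) -> scalar loss.

    y_true: integer class labels, shape (batch, 1).
    y_pred: predicted class probabilities, shape (batch, num_classes).
    Returns the mean sparse categorical cross-entropy over the batch.
    """
    labels = np.asarray(y_true).reshape(-1)
    # Clip to avoid log(0) for classes predicted with zero probability.
    probs = np.clip(np.asarray(y_pred), eps, 1.0)
    # Probability assigned to the true class of each example.
    picked = probs[np.arange(labels.shape[0]), labels]
    return float(np.mean(-np.log(picked)))


if __name__ == "__main__":
    batch = make_dummy_batch(batch_size=2)
    print(batch["x"].shape, batch["y"].shape)
    # Uniform predictions over 10 classes give a loss of log(10).
    print(sparse_ce_loss(np.array([[3], [7]]), np.full((2, 10), 0.1)))
```

Because the values of dummy_batch are never used, zeros are as good as real data. Matching the batch's nesting, shapes, and dtypes to what keras_model expects is what matters.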

Returns:

A tff.learning.Model object.

Raises:

  • TypeError: If keras_model is not an instance of tf.keras.Model.
  • ValueError: If keras_model was compiled.