tff.learning.reconstruction.from_keras_model

Builds a tff.learning.reconstruction.Model from a tf.keras.Model.

The tff.learning.reconstruction.Model returned by this function uses keras_model for its forward pass and autodifferentiation steps. During reconstruction, variables in local_layers are initialized and trained. After reconstruction, variables in global_layers are trained and aggregated on the server. Every model variable must belong to exactly one of global_layers or local_layers; the two groups must cover all variables and must not overlap.
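
The two phases can be illustrated with a framework-free toy sketch (plain Python, not TFF) on the model y = w * x + b. The mapping is an assumption made for illustration: w stands in for a variable in global_layers and b for a variable in local_layers. The local fit is done in closed form rather than with an optimizer, only one gradient step is taken on w per round, and the names reconstruct_local, global_update and run are hypothetical. This is not how TFF implements its training process.

```python
from statistics import fmean

# Toy model: prediction = w * x + b.
# Illustrative assumption: w plays the role of a variable in global_layers
# (aggregated on the server), b a variable in local_layers (never sent).


def reconstruct_local(w, xs, ys):
    # Reconstruction phase: with the global w frozen, fit this client's local b.
    # For mean squared error the optimum has a closed form: b = mean(y - w * x).
    return fmean([y - w * x for x, y in zip(xs, ys)])


def global_update(w, b, xs, ys, lr):
    # Post-reconstruction phase: with the local b frozen, take one gradient
    # step on w for the mean squared error and return the resulting change in w.
    grad = fmean([2.0 * (w * x + b - y) * x for x, y in zip(xs, ys)])
    return -lr * grad


def run(clients, rounds=50, lr=0.1):
    w = 0.0  # global variable held by the server
    for _ in range(rounds):
        deltas = []
        for xs, ys in clients:
            b = reconstruct_local(w, xs, ys)  # rebuilt each round, never shared
            deltas.append(global_update(w, b, xs, ys, lr))
        # Server aggregation: only the updates to the global w are averaged.
        w += fmean(deltas)
    return w


xs = [1.0, 2.0, 3.0]
# Three clients share the slope 2.0 but each has its own offset.
clients = [(xs, [2.0 * x + offset for x in xs]) for offset in (-1.0, 0.5, 3.0)]
w = run(clients)
print(round(w, 2))  # -> 2.0
```

The shared slope is learned on the server, while each client's offset is recovered locally every round and never leaves the client, which is the split that global_layers and local_layers express.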

Args
keras_model: A tf.keras.Model object that has not been compiled.
global_layers: Iterable of global layers to be aggregated across users. All trainable and non-trainable model variables that can be aggregated on the server should be included in these layers.
local_layers: Iterable of local layers not shared with the server. All trainable and non-trainable model variables that should not be aggregated on the server should be included in these layers.
input_spec: A structure of tf.TensorSpecs specifying the type of arguments the model expects. It must be a compound structure of two elements: the first specifies the data fed into the model to generate predictions, and the second specifies the expected type of the ground truth.
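
The partition rule on global_layers and local_layers can be restated as a small pre-flight check. This is a hypothetical helper, not TFF's own validation: variables are modeled as plain names (with Keras they would be gathered from each layer's trainable_weights and non_trainable_weights), and the variable names in the example are made up.

```python
# Hypothetical helper: restates the rule that every model variable belongs to
# exactly one of the global or local groups. Variables are modeled as names.


def check_partition(model_vars, global_vars, local_vars):
    model_vars, global_vars, local_vars = set(model_vars), set(global_vars), set(local_vars)
    overlap = global_vars & local_vars
    if overlap:
        raise ValueError(f"in both global and local layers: {sorted(overlap)}")
    missing = model_vars - (global_vars | local_vars)
    if missing:
        raise ValueError(f"in neither global nor local layers: {sorted(missing)}")


# Made-up names, loosely modeled on a matrix-factorization setup.
model_vars = ["item_embedding/embeddings", "user_embedding/embedding"]
check_partition(model_vars, ["item_embedding/embeddings"], ["user_embedding/embedding"])  # passes

try:
    check_partition(model_vars, model_vars, ["user_embedding/embedding"])
except ValueError as err:
    print(err)  # in both global and local layers: ['user_embedding/embedding']
```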

Returns
A tff.learning.reconstruction.Model object.

Raises
TypeError: If keras_model is not an instance of tf.keras.Model.
ValueError: If keras_model has already been compiled.