tff.learning.assign_weights_to_keras_model

Assigns a nested structure of TFF weights to a Keras model.

Use this function to copy the model parameters trained by the federated averaging process into an existing tf.keras.models.Model, for example:

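import tensorflow as tf
import tensorflow_federated as tff

# Keras model that will receive the result of federated training.
keras_model = tf.keras.models.Model(inputs=..., outputs=...)

def model_fn():
  # Wrap the Keras model so TFF can train it.
  return tff.learning.from_keras_model(keras_model)

fed_avg = tff.learning.build_federated_averaging_process(model_fn, ...)
state = fed_avg.initialize()
# Each round returns the updated server state together with training metrics.
state, metrics = fed_avg.next(state, ...)
...
# Copy the trained weights held in the server state back into the Keras model.
tff.learning.assign_weights_to_keras_model(keras_model, state.model)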

Args:
  keras_model: A tf.keras.models.Model instance to assign weights to.
  tff_weights: A TFF value representing the weights of a model.

Raises:
  TypeError: If tff_weights is not a TFF value, or if keras_model is not a tf.keras.models.Model instance.
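This page does not describe how the assignment works internally. The sketch below is a rough, framework-free illustration of what "assigning a nested structure of weights" involves. It assumes that the weights are split into trainable and non-trainable lists and are matched, in order, to a model's trainable_weights and non_trainable_weights. Every class and function in it (HypotheticalVariable, HypotheticalKerasModel, HypotheticalModelWeights, assign_weights_sketch) is a hypothetical stand-in, not a TFF or Keras API.

```python
from typing import List, NamedTuple


class HypotheticalVariable:
    """Hypothetical stand-in for a model variable supporting in-place assignment."""

    def __init__(self, value: List[float]):
        self.value = list(value)

    def assign(self, new_value: List[float]) -> None:
        # Reject values whose length differs, loosely mirroring a shape check.
        if len(new_value) != len(self.value):
            raise ValueError("shape mismatch")
        self.value = list(new_value)


class HypotheticalKerasModel:
    """Hypothetical stand-in exposing trainable and non-trainable weight lists."""

    def __init__(self, trainable, non_trainable):
        self.trainable_weights = trainable
        self.non_trainable_weights = non_trainable


class HypotheticalModelWeights(NamedTuple):
    """Hypothetical nested weight structure with two parts."""
    trainable: List[List[float]]
    non_trainable: List[List[float]]


def assign_weights_sketch(keras_model, tff_weights) -> None:
    # Validate argument types, analogous to the documented TypeError.
    if not isinstance(tff_weights, HypotheticalModelWeights):
        raise TypeError("tff_weights must be a HypotheticalModelWeights")
    if not isinstance(keras_model, HypotheticalKerasModel):
        raise TypeError("keras_model must be a HypotheticalKerasModel")
    # Pair each group of model variables with the matching part of the structure.
    groups = [
        (keras_model.trainable_weights, tff_weights.trainable),
        (keras_model.non_trainable_weights, tff_weights.non_trainable),
    ]
    for variables, values in groups:
        if len(variables) != len(values):
            raise ValueError("structure mismatch")
        # Assign values positionally, one variable at a time.
        for var, value in zip(variables, values):
            var.assign(value)


# Usage: copy "trained" values into a freshly initialized model.
model = HypotheticalKerasModel(
    trainable=[HypotheticalVariable([0.0, 0.0]), HypotheticalVariable([0.0])],
    non_trainable=[HypotheticalVariable([0.0])],
)
trained = HypotheticalModelWeights(trainable=[[1.0, 2.0], [3.0]], non_trainable=[[4.0]])
assign_weights_sketch(model, trained)
print([v.value for v in model.trainable_weights])  # [[1.0, 2.0], [3.0]]
```

The positional pairing is the key design assumption here. The nested structure must list the weights in the same order as the model's own variable lists, or the values land in the wrong variables.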