


Assigns a nested structure of TFF weights to a Keras model.

tff.learning.assign_weights_to_keras_model(
    keras_model, tff_weights
)


This function may be used to retrieve the model parameters trained by the federated averaging process for use in an existing tf.keras.models.Model, e.g.:

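import tensorflow as tf
import tensorflow_federated as tff

def create_keras_model():  # hypothetical helper, for illustration only
  # model_fn must build a new Keras model on every call, not capture one.
  return tf.keras.models.Model(inputs=..., outputs=...)

def model_fn():
  return tff.learning.from_keras_model(create_keras_model(), ...)

fed_avg = tff.learning.build_federated_averaging_process(model_fn, ...)
state = fed_avg.initialize()
state = fed_avg.next(state, ...)

# Copy the trained global weights into a separate Keras model.
keras_model = create_keras_model()
tff.learning.assign_weights_to_keras_model(keras_model, state.model)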


Args:
  • keras_model: A tf.keras.models.Model instance to assign weights to.
  • tff_weights: A TFF value representing the weights of a model, such as state.model in the example above.
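The sketch below uses plain TensorFlow to show what assigning a "nested structure of weights" to a Keras model involves. It assumes the weights arrive as a trainable/non-trainable pair of lists, ordered the same way as the Keras model's own variables. ModelWeights, assign_weights_sketch, and build_model are hypothetical names chosen for illustration. This is not TFF's implementation and it calls no TFF API.

```python
import collections

import numpy as np
import tensorflow as tf

# Hypothetical container for a trainable / non-trainable split of model
# weights. This is an assumption for illustration, not a TFF type.
ModelWeights = collections.namedtuple("ModelWeights", ["trainable", "non_trainable"])


def assign_weights_sketch(keras_model, weights):
  """Copies a nested structure of weight values into keras_model's variables.

  Assumes weights.trainable and weights.non_trainable are lists whose order
  and shapes match keras_model.trainable_weights and
  keras_model.non_trainable_weights.
  """
  # tf.nest.map_structure raises if the two lists differ in length;
  # var.assign raises if a value's shape does not match its variable.
  tf.nest.map_structure(lambda var, value: var.assign(value),
                        list(keras_model.trainable_weights),
                        list(weights.trainable))
  tf.nest.map_structure(lambda var, value: var.assign(value),
                        list(keras_model.non_trainable_weights),
                        list(weights.non_trainable))


def build_model():
  # Dense contributes 2 trainable variables (kernel, bias). BatchNormalization
  # contributes 2 trainable (gamma, beta) and 2 non-trainable (moving stats).
  return tf.keras.Sequential([
      tf.keras.layers.Input(shape=(3,)),
      tf.keras.layers.Dense(2),
      tf.keras.layers.BatchNormalization(),
  ])


# Usage: copy the weights of one model into a freshly initialized twin.
source = build_model()
target = build_model()
trained = ModelWeights(
    trainable=[v.numpy() for v in source.trainable_weights],
    non_trainable=[v.numpy() for v in source.non_trainable_weights])
assign_weights_sketch(target, trained)
```

The sketch matches variables by position. This keeps it short, but it means a length or shape mismatch raises an error while two same-shaped variables given in the wrong order would not. The caller therefore has to keep the ordering consistent.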