
tff.learning.assign_weights_to_keras_model


Assigns a nested structure of TFF weights to a Keras model.

tff.learning.assign_weights_to_keras_model(
    keras_model, tff_weights
)
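Conceptually, the assignment pairs each tensor in the weight structure with the model variable in the same position and overwrites that variable's value. The plain-Python sketch below models this under stated assumptions; it is not TFF's implementation. It assumes, as with `tff.learning.ModelWeights`, that the weights are split into `trainable` and `non_trainable` sequences ordered like the model's `trainable_weights` and `non_trainable_weights`. `ModelWeights`, `FakeVariable`, `FakeModel`, and `assign_weights` here are hypothetical stand-ins, not TFF or Keras APIs.

```python
from dataclasses import dataclass, field
from typing import List, NamedTuple


class ModelWeights(NamedTuple):
    """Hypothetical stand-in mirroring a trainable/non_trainable weight split."""
    trainable: List[List[float]]
    non_trainable: List[List[float]]


@dataclass
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable holding a flat list of floats."""
    value: List[float]

    def assign(self, new_value: List[float]) -> None:
        # Reject values whose size differs from the variable's current size.
        if len(new_value) != len(self.value):
            raise ValueError(
                f"shape mismatch: expected {len(self.value)}, got {len(new_value)}")
        self.value = list(new_value)


@dataclass
class FakeModel:
    """Hypothetical stand-in exposing Keras-style weight lists."""
    trainable_weights: List[FakeVariable] = field(default_factory=list)
    non_trainable_weights: List[FakeVariable] = field(default_factory=list)


def assign_weights(model: FakeModel, weights: ModelWeights) -> None:
    """Assign each weight tensor to the model variable at the same position."""
    for variables, values in (
        (model.trainable_weights, weights.trainable),
        (model.non_trainable_weights, weights.non_trainable),
    ):
        if len(variables) != len(values):
            raise ValueError("weight structure does not match the model")
        for var, value in zip(variables, values):
            var.assign(value)


# Usage: copy a two-tensor trainable structure and one non-trainable tensor.
model = FakeModel(
    trainable_weights=[FakeVariable([0.0, 0.0]), FakeVariable([0.0])],
    non_trainable_weights=[FakeVariable([0.0])],
)
assign_weights(model, ModelWeights(trainable=[[1.0, 2.0], [3.0]],
                                   non_trainable=[[4.0]]))
print([v.value for v in model.trainable_weights])  # [[1.0, 2.0], [3.0]]
```

Because matching is purely positional in this sketch, the weight structure must come from a model with the same architecture and variable order; a size or count mismatch raises `ValueError` rather than silently misassigning.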

Use this function to copy the model parameters trained by the federated averaging process into an existing tf.keras.models.Model, for example:

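import tensorflow as tf
import tensorflow_federated as tff

# Build (or load) the Keras model that should receive the trained weights.
keras_model = tf.keras.models.Model(inputs=..., outputs=...)

def model_fn():
  # Wrap the Keras model as a TFF model. from_keras_model also requires
  # further arguments (such as a loss), which are elided here.
  return tff.learning.from_keras_model(keras_model, ...)

fed_avg = tff.learning.build_federated_averaging_process(model_fn, ...)
state = fed_avg.initialize()
# Run one or more rounds of federated averaging.
state = fed_avg.next(state, ...)
...
# Copy the trained global model weights back into the Keras model.
tff.learning.assign_weights_to_keras_model(keras_model, state.model)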
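To make the role of `state.model` more concrete, the sketch below shows a simplified view of federated averaging in which the new global weights are the example-weighted mean of the clients' weights, position by position. This is an illustrative assumption, not TFF's implementation (which, among other things, averages model updates and applies a server optimizer). The function name `federated_average` and the toy client values are hypothetical.

```python
from typing import List, Sequence


def federated_average(
    client_weights: Sequence[List[List[float]]],
    num_examples: Sequence[int],
) -> List[List[float]]:
    """Example-weighted mean of per-client weight lists (simplified FedAvg)."""
    total = sum(num_examples)
    if total <= 0:
        raise ValueError("total number of examples must be positive")
    averaged = []
    # Walk the weight tensors position by position across all clients.
    for tensors in zip(*client_weights):
        mean = [0.0] * len(tensors[0])
        for tensor, n in zip(tensors, num_examples):
            for i, x in enumerate(tensor):
                mean[i] += x * n / total
        averaged.append(mean)
    return averaged


# Usage: two hypothetical clients; the second trained on three times as much data.
client_a = [[1.0, 2.0], [0.0]]
client_b = [[3.0, 4.0], [2.0]]
global_weights = federated_average([client_a, client_b], num_examples=[1, 3])
print(global_weights)  # [[2.5, 3.5], [1.5]]
```

In this simplified picture, `global_weights` plays the part of the trained weights that the example above hands to `assign_weights_to_keras_model`; the structure keeps the same positional layout as each client's weights, which is what lets it be assigned back into a model of the same architecture.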

Args:

  • keras_model: A tf.keras.models.Model instance to assign weights to.
  • tff_weights: A TFF value representing the weights of a model, such as state.model in the example above.

Raises: