
tff.learning.framework.build_model_delta_optimizer_process


Constructs a tff.utils.IterativeProcess for Federated Averaging or Federated SGD.

tff.learning.framework.build_model_delta_optimizer_process(
    model_fn,
    model_to_client_delta_fn,
    server_optimizer_fn,
    stateful_delta_aggregate_fn=build_stateless_mean(),
    stateful_model_broadcast_fn=build_stateless_broadcaster()
)

This provides the TFF orchestration logic that connects two pieces: the common server logic, which applies aggregated model deltas to the server model, and a ClientDeltaFn, which specifies how weight_deltas are computed on device.
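The data flow of one round can be sketched in plain Python. This is a minimal sketch, not TFF's implementation. It makes several assumptions: model weights are flat lists of floats, the server optimizer is plain SGD, and `run_round`, `weighted_mean`, `sgd_apply_gradients`, and `toy_client_delta_fn` are hypothetical helpers. The client function loosely mirrors a ClientDeltaFn by returning a weights delta plus a weight for that delta. It also assumes, as in Federated Averaging, that the server hands the negated mean delta to apply_gradients as a gradient, so SGD with a learning rate of 1.0 simply adds the mean delta to the server model.

```python
from typing import Callable, List, Sequence, Tuple

Weights = List[float]  # assumption: model weights as a flat list of floats


def weighted_mean(deltas: Sequence[Weights], weights: Sequence[float]) -> Weights:
    """Weighted average of client deltas, as a stateless mean aggregator would compute."""
    total = sum(weights)
    return [
        sum(w * d[i] for d, w in zip(deltas, weights)) / total
        for i in range(len(deltas[0]))
    ]


def sgd_apply_gradients(model: Weights, grads: Weights, lr: float) -> Weights:
    """Plain SGD step, standing in for server_optimizer.apply_gradients."""
    return [m - lr * g for m, g in zip(model, grads)]


def run_round(
    server_model: Weights,
    client_datasets: Sequence[Sequence[float]],
    client_delta_fn: Callable[[Weights, Sequence[float]], Tuple[Weights, float]],
    server_lr: float = 1.0,
) -> Weights:
    # Broadcast: every client starts from a copy of the current server model,
    # then computes (weights_delta, delta_weight) on device.
    results = [client_delta_fn(list(server_model), data) for data in client_datasets]
    deltas = [delta for delta, _ in results]
    delta_weights = [w for _, w in results]
    # Aggregate: weighted mean of the client deltas.
    mean_delta = weighted_mean(deltas, delta_weights)
    # Server update: treat the negated delta as a gradient, so SGD with
    # lr=1.0 adds the mean delta to the server model.
    return sgd_apply_gradients(server_model, [-d for d in mean_delta], server_lr)


def toy_client_delta_fn(model: Weights, data: Sequence[float]) -> Tuple[Weights, float]:
    """Hypothetical client: one local step pulling each weight halfway toward the data mean."""
    target = sum(data) / len(data)
    new_model = [m + 0.5 * (target - m) for m in model]
    delta = [n - m for n, m in zip(new_model, model)]
    return delta, float(len(data))  # weight each client by its number of examples


new_model = run_round([0.0, 0.0], [[2.0, 2.0], [4.0]], toy_client_delta_fn)
```

Here the two clients report deltas of 1.0 and 2.0 per weight with 2 and 1 examples respectively, so the example-weighted mean delta is 4/3 and each server weight moves from 0.0 to about 1.333. Swapping `server_lr` or the optimizer changes only the last step, which is the separation of concerns the server_optimizer_fn argument provides.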

Args:

  • model_fn: A no-arg function that returns a tff.learning.Model.
  • model_to_client_delta_fn: A function from a model_fn to a ClientDeltaFn.
  • server_optimizer_fn: A no-arg function that returns a tf.Optimizer. The apply_gradients method of this optimizer is used to apply client updates to the server model.
  • stateful_delta_aggregate_fn: A tff.utils.StatefulAggregateFn whose next_fn performs a federated aggregation and updates state. That is, it has TFF type (state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER), where the value type is tff.learning.framework.ModelWeights.trainable, corresponding to the object returned by model_fn.
  • stateful_model_broadcast_fn: A tff.utils.StatefulBroadcastFn whose next_fn performs a federated broadcast and updates state. That is, it has TFF type (state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS), where the value type is tff.learning.framework.ModelWeights, corresponding to the object returned by model_fn.
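The two stateful contracts above can be illustrated with plain Python functions whose signatures mirror the TFF types, with placements reduced to comments. This is a hypothetical sketch, not the tff.utils API: `counting_mean_next` is an invented aggregator whose state counts completed rounds, `stateless_broadcast_next` takes an extra `num_clients` argument because the sketch does not model CLIENTS placement, and weights are plain lists of floats.

```python
from typing import List, Sequence, Tuple

Weights = List[float]  # assumption: stand-in for ModelWeights / its trainable part


def counting_mean_next(
    state: int, values: Sequence[Weights], weights: Sequence[float]
) -> Tuple[int, Weights]:
    """(state@SERVER, value@CLIENTS, weights@CLIENTS) -> (state@SERVER, aggregate@SERVER).

    Hypothetical aggregator: a weighted mean whose state counts rounds.
    """
    total = sum(weights)
    aggregate = [
        sum(w * v[i] for v, w in zip(values, weights)) / total
        for i in range(len(values[0]))
    ]
    return state + 1, aggregate  # the updated state is returned with the result


def stateless_broadcast_next(
    state: Tuple[()], value: Weights, num_clients: int
) -> Tuple[Tuple[()], List[Weights]]:
    """(state@SERVER, value@SERVER) -> (state@SERVER, value@CLIENTS).

    The state is empty and passed through unchanged; each client gets its own copy.
    """
    return state, [list(value) for _ in range(num_clients)]
```

Threading the returned state into the next call is what makes these functions stateful; a stateless variant simply returns its input state untouched, as the broadcaster does.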

Returns:

A tff.utils.IterativeProcess.
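A tff.utils.IterativeProcess pairs an initialize computation, which builds the initial server state, with a next computation, which advances that state by one round. The loop below mimics how such a process is typically driven, using a pure-Python stand-in rather than TFF itself. `IterativeProcessSketch`, `run_rounds`, and `toy_process` are hypothetical names, and the real next computation returned here may produce additional outputs (such as training metrics) alongside the new state; the sketch tracks state only.

```python
from typing import Any, Callable, NamedTuple, Sequence


class IterativeProcessSketch(NamedTuple):
    """Pure-Python stand-in for the (initialize, next) pair of an IterativeProcess."""
    initialize: Callable[[], Any]
    next: Callable[[Any, Any], Any]


def run_rounds(process: IterativeProcessSketch, data_per_round: Sequence[Any]) -> Any:
    state = process.initialize()  # build the initial server state once
    for round_data in data_per_round:  # one federated round per element
        state = process.next(state, round_data)
    return state


# Hypothetical process: the server state is a single float that moves
# halfway toward the mean of that round's client values.
toy_process = IterativeProcessSketch(
    initialize=lambda: 0.0,
    next=lambda state, values: state + 0.5 * (sum(values) / len(values) - state),
)

final_state = run_rounds(toy_process, [[4.0], [4.0]])  # 0.0 -> 2.0 -> 3.0
```

Keeping the state outside the process and passing it back in on every call is the design that lets the same next computation be invoked repeatedly, one round at a time.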