tff.utils.StatefulAggregateFn

Class StatefulAggregateFn

A simple container for a stateful aggregation function.

A typical (though trivial) example would be:

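import tensorflow_federated as tff

# next_fn returns the state unchanged, plus the (optionally weighted) mean.
stateless_federated_mean = tff.utils.StatefulAggregateFn(
    initialize_fn=lambda: (),  # The state is an empty tuple.
    next_fn=lambda state, value, weight=None: (
        state, tff.federated_mean(value, weight=weight)))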
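The example above calls tff.federated_mean, so it can only run inside a tff.federated_computation. To show the container's contract as ordinary Python, here is a plain-Python analogue, not the TFF implementation. It assumes that a Python list stands in for a value placed at tff.CLIENTS and a plain scalar for a value at tff.SERVER. The names PyStatefulAggregateFn and py_federated_mean are hypothetical.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def py_federated_mean(values: Sequence[float],
                      weight: Optional[Sequence[float]] = None) -> float:
    """Hypothetical stand-in for tff.federated_mean over a list of client values."""
    if weight is None:
        weight = [1.0] * len(values)  # unweighted: every client counts equally
    total_weight = sum(weight)
    return sum(v * w for v, w in zip(values, weight)) / total_weight


class PyStatefulAggregateFn:
    """Hypothetical plain-Python mirror of the StatefulAggregateFn container."""

    def __init__(self, initialize_fn: Callable[[], Any],
                 next_fn: Callable[..., Tuple[Any, Any]]):
        self._initialize_fn = initialize_fn
        self._next_fn = next_fn

    def initialize(self) -> Any:
        # Delegates to the no-arg initialize_fn.
        return self._initialize_fn()

    def __call__(self, state, value, weight=None) -> Tuple[Any, Any]:
        # Delegates to next_fn, which returns (new_state, aggregate).
        return self._next_fn(state, value, weight)


# Mirrors stateless_federated_mean: the state is an empty tuple that is
# returned unchanged alongside the (optionally weighted) mean.
stateless_mean = PyStatefulAggregateFn(
    initialize_fn=lambda: (),
    next_fn=lambda state, value, weight=None: (
        state, py_federated_mean(value, weight=weight)))

state = stateless_mean.initialize()
state, aggregate = stateless_mean(state, [1.0, 5.0], weight=[1.0, 3.0])
print(state, aggregate)  # () 4.0
```

The container holds no logic of its own; all behavior lives in the two functions passed to the constructor.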

__init__

__init__(
    initialize_fn,
    next_fn
)

Creates the StatefulAggregateFn.

Args:

  • initialize_fn: A zero-argument function that returns a Python container which can be converted to a tff.Value, placed on the tff.SERVER, and passed as the first argument of __call__. It may be called from plain TensorFlow code (typically wrapped as a tff.tf_computation) as part of initializing a larger state object.
  • next_fn: A function matching the signature of __call__ (see below).
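Because initialize_fn may run as part of initializing a larger state object, the aggregator's initial state typically becomes one field of a bigger server state. The plain-Python sketch below illustrates that composition without TFF; ServerState, its field names, aggregator_initialize_fn, and the call-counter state are all hypothetical.

```python
import collections

# Hypothetical server-side state for a larger training loop; the field
# names are illustrative and not part of TFF.
ServerState = collections.namedtuple(
    'ServerState', ['model_weights', 'round_num', 'aggregator_state'])


def aggregator_initialize_fn():
    """No-arg initializer returning a Python container (here a dict)."""
    # Hypothetical aggregator state: how many aggregation calls have run.
    return {'num_calls': 0}


def server_initialize(initial_weights):
    # The aggregator's initial state is just one field of the larger state.
    return ServerState(
        model_weights=list(initial_weights),
        round_num=0,
        aggregator_state=aggregator_initialize_fn())


server_state = server_initialize([0.0, 0.0, 0.0])
print(server_state.aggregator_state)  # {'num_calls': 0}
```

In real TFF code the initializer would usually be wrapped as a tff.tf_computation, as the description above notes; the plain function here only shows the shape of the composition.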

Methods

__call__

__call__(
    state,
    value,
    weight=None
)

Aggregates value@CLIENTS into a single value at the server, aggregate@SERVER.

The aggregation is optionally parameterized by weight@CLIENTS.

This function is intended to be invoked only within a tff.federated_computation. It should be compatible with the following TFF type signature:

(state@SERVER, value@CLIENTS, weight@CLIENTS) ->
     (state@SERVER, aggregate@SERVER).

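Args:

  • state: A tff.Value placed on the tff.SERVER; the current aggregation state.
  • value: A tff.Value placed on the tff.CLIENTS; the values to aggregate.
  • weight: An optional tff.Value placed on the tff.CLIENTS, used to weight value.

Returns: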

A tuple of tff.Values (state@SERVER, aggregate@SERVER), where

  • state: The updated state.
  • aggregate: The result of aggregating value, weighted by weight if it is provided.
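Since __call__ returns the updated state rather than mutating it, the caller has to rebind the returned state before the next call. The plain-Python sketch below shows that threading over two rounds. It assumes Python lists stand in for values at tff.CLIENTS; the round-counting state and the functions initialize_fn and next_fn here are hypothetical illustrations, not TFF code.

```python
from typing import List, Optional, Tuple

State = Tuple[int, float]  # hypothetical: (num_rounds, cumulative_weight)


def initialize_fn() -> State:
    return (0, 0.0)


def next_fn(state: State, value: List[float],
            weight: Optional[List[float]] = None) -> Tuple[State, float]:
    """Weighted mean of client values; also updates a running tally."""
    if weight is None:
        weight = [1.0] * len(value)  # unweighted: every client counts equally
    total_weight = sum(weight)
    aggregate = sum(v * w for v, w in zip(value, weight)) / total_weight
    num_rounds, cumulative_weight = state
    new_state = (num_rounds + 1, cumulative_weight + total_weight)
    return new_state, aggregate


state = initialize_fn()
rounds = [
    ([1.0, 3.0], None),        # unweighted mean: (1 + 3) / 2 = 2.0
    ([2.0, 6.0], [3.0, 1.0]),  # weighted mean: (6 + 6) / 4 = 3.0
]
aggregates = []
for value, weight in rounds:
    # Rebind `state` to the returned value before the next round.
    state, aggregate = next_fn(state, value, weight)
    aggregates.append(aggregate)

print(state, aggregates)  # (2, 6.0) [2.0, 3.0]
```

Returning a fresh state instead of mutating one keeps next_fn a pure function of its inputs, which matches how values flow through a TFF computation.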

initialize

initialize()

Returns the initial state.