A simple container for a stateful aggregation function.
A typical (though trivial) example would be:
    stateless_federated_mean = tff.utils.StatefulAggregateFn(
        initialize_fn=lambda: (),  # The state is an empty tuple.
        next_fn=lambda state, value, weight=None: (
            state, tff.federated_mean(value, weight=weight)))
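Because `tff.utils.StatefulAggregateFn` and `tff.federated_mean` only run inside the TFF runtime, here is a plain-Python sketch of the same container contract. It is not TFF code. It assumes client values can be simulated as a Python list of numbers. The names `SimpleStatefulAggregateFn`, `simulated_federated_mean`, and the sketch's `initialize` method are hypothetical stand-ins, chosen only for illustration.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def simulated_federated_mean(value: Sequence[float],
                             weight: Optional[Sequence[float]] = None) -> float:
    """Hypothetical stand-in for tff.federated_mean over a list of client values."""
    if weight is None:
        weight = [1.0] * len(value)  # Unweighted: every client counts equally.
    return sum(w * v for w, v in zip(weight, value)) / sum(weight)


class SimpleStatefulAggregateFn:
    """Hypothetical plain-Python mirror of the initialize_fn / next_fn container."""

    def __init__(self, initialize_fn: Callable[[], Any],
                 next_fn: Callable[..., Tuple[Any, Any]]):
        self._initialize_fn = initialize_fn
        self._next_fn = next_fn

    def initialize(self) -> Any:
        # Produces the initial state by calling the no-arg initialize_fn.
        return self._initialize_fn()

    def __call__(self, state: Any, value: Any, weight: Any = None) -> Tuple[Any, Any]:
        # Delegates to next_fn, which returns (new_state, aggregate).
        return self._next_fn(state, value, weight)


stateless_mean = SimpleStatefulAggregateFn(
    initialize_fn=lambda: (),  # The state is an empty tuple.
    next_fn=lambda state, value, weight=None: (
        state, simulated_federated_mean(value, weight=weight)))

state = stateless_mean.initialize()
state, aggregate = stateless_mean(state, [1.0, 2.0, 6.0])
print(state, aggregate)  # () 3.0
state, aggregate = stateless_mean(state, [1.0, 3.0], weight=[3.0, 1.0])
print(state, aggregate)  # () 1.5
```

The state is threaded through each call unchanged, which is what makes this example "stateless" despite using a stateful container.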
__init__(initialize_fn, next_fn)
Creates the StatefulAggregateFn.
initialize_fn: A no-arg function that returns a Python container which can be converted to a tff.Value, placed on tff.SERVER, and passed as the first argument of __call__. This may be called in vanilla TensorFlow code, typically wrapped as a tff.tf_computation, as part of the initialization of a larger state object.
next_fn: A function matching the signature of __call__ (see below).
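To make the roles of `initialize_fn` and `next_fn` concrete, here is a plain-Python sketch (not TFF code) of a genuinely stateful aggregation. It assumes client values are simulated as a Python list, and the state is a dict counting how many aggregation rounds have run. The function bodies and the `rounds` key are illustrative assumptions, not part of the TFF API.

```python
from typing import Dict, List, Optional, Tuple


def initialize_fn() -> Dict[str, int]:
    # No-arg function returning a Python container for the initial state.
    return {"rounds": 0}


def next_fn(state: Dict[str, int], value: List[float],
            weight: Optional[List[float]] = None) -> Tuple[Dict[str, int], float]:
    # Same (state, value, weight=None) -> (state, aggregate) shape as __call__.
    if weight is None:
        weight = [1.0] * len(value)
    aggregate = sum(w * v for w, v in zip(weight, value)) / sum(weight)
    # Return an updated copy of the state rather than mutating the input.
    new_state = {"rounds": state["rounds"] + 1}
    return new_state, aggregate


state = initialize_fn()
for client_values in ([2.0, 4.0], [10.0, 20.0, 30.0]):
    state, aggregate = next_fn(state, client_values)
    print(state["rounds"], aggregate)
# 1 3.0
# 2 20.0
```

Returning a fresh state object keeps next_fn a pure function of its inputs, which matches how a server-placed state is passed in and returned by each call.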
__call__(state, value, weight=None)
Performs an aggregation of value@CLIENTS, with optional weight@CLIENTS.
This function is intended to be invoked only in the context of a federated computation. It should be compatible with the TFF type signature
(state@SERVER, value@CLIENTS, weight@CLIENTS) -> (state@SERVER, aggregate@SERVER).
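The placement contract in this type signature can be illustrated with a plain-Python sketch. It assumes a hypothetical `Placed` tuple that tags each payload with a placement label ("SERVER" or "CLIENTS"); `Placed`, `call_checked`, and `sum_next_fn` are illustrative names only and are not TFF APIs. TFF itself enforces placements through its type system, not through checks like these.

```python
from typing import Any, Callable, NamedTuple, Optional, Tuple


class Placed(NamedTuple):
    """Hypothetical stand-in for a placed tff.Value: a payload plus a placement label."""
    payload: Any
    placement: str  # "SERVER" or "CLIENTS"


def _require(v: Placed, placement: str, name: str) -> None:
    if v.placement != placement:
        raise ValueError(f"{name} must be placed at {placement}, got {v.placement}")


def call_checked(next_fn: Callable[..., Tuple[Placed, Placed]],
                 state: Placed, value: Placed,
                 weight: Optional[Placed] = None) -> Tuple[Placed, Placed]:
    """Invokes next_fn and checks the placements named in the signature."""
    _require(state, "SERVER", "state")
    _require(value, "CLIENTS", "value")
    if weight is not None:
        _require(weight, "CLIENTS", "weight")
    new_state, aggregate = next_fn(state, value, weight)
    _require(new_state, "SERVER", "returned state")
    _require(aggregate, "SERVER", "aggregate")
    return new_state, aggregate


def sum_next_fn(state: Placed, value: Placed, weight: Optional[Placed] = None):
    # Weighted sum of the client payloads, with the result placed at SERVER.
    ws = weight.payload if weight is not None else [1.0] * len(value.payload)
    total = sum(w * v for w, v in zip(ws, value.payload))
    return state, Placed(total, "SERVER")


new_state, aggregate = call_checked(
    sum_next_fn, Placed((), "SERVER"), Placed([1.0, 2.0, 3.0], "CLIENTS"))
print(aggregate)  # Placed(payload=6.0, placement='SERVER')
```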
state: A tff.Value placed on tff.SERVER.
value: A tff.Value to be aggregated, placed on tff.CLIENTS.
weight: An optional tff.Value for weighting values, placed on tff.CLIENTS.
Returns a tuple of tff.Value objects (state@SERVER, aggregate@SERVER), where:
* state: The updated state.
* aggregate: The result of aggregating value, weighted by weight.
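For the introductory tff.federated_mean example, "weighted by weight" means the usual weighted mean over the participating clients. In the formula below, $v_i$ and $w_i$ (notation introduced here) denote client $i$'s value and weight, and every $w_i = 1$ when weight is omitted:

```latex
\text{aggregate} = \frac{\sum_{i \in \text{CLIENTS}} w_i \, v_i}{\sum_{i \in \text{CLIENTS}} w_i}
```

Other next_fn implementations may aggregate differently; this formula applies only to the federated_mean case.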
Returns the initial state.