A simple container for a stateful aggregation function.
A typical (though trivial) example would be:
```python
import tensorflow_federated as tff

stateless_federated_mean = tff.utils.StatefulAggregateFn(
    initialize_fn=lambda: (),  # The state is an empty tuple.
    next_fn=lambda state, value, weight=None: (
        state, tff.federated_mean(value, weight=weight)))
```
__init__( initialize_fn, next_fn )
Creates the StatefulAggregateFn.
Args:
initialize_fn: A no-arg function that returns a Python container which can be converted to a tff.Value, placed on the tff.SERVER, and passed as the first argument of __call__. This may be called in vanilla TensorFlow code, typically wrapped as a tff.tf_computation, as part of the initialization of a larger state object.
next_fn: A function matching the signature of __call__, see below.
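To make the container contract concrete, here is a minimal plain-Python model of it. This is a sketch, not TFF's source: the class name `StatefulAggregateFnSketch` and the helper `weighted_mean` are hypothetical, and real TFF values, placements, and tracing are not modeled. It only stores the two functions, returns the initial state, and delegates calls to `next_fn`.

```python
from typing import Any, Callable, Optional, Tuple


class StatefulAggregateFnSketch:
    """Hypothetical plain-Python stand-in for tff.utils.StatefulAggregateFn.

    Mirrors only the container contract: store two functions, expose the
    initial state, and delegate calls to next_fn. No TFF types are involved.
    """

    def __init__(self, initialize_fn: Callable[[], Any],
                 next_fn: Callable[..., Tuple[Any, Any]]):
        self._initialize_fn = initialize_fn
        self._next_fn = next_fn

    def initialize(self) -> Any:
        # Returns the initial state by calling the no-arg initialize_fn.
        return self._initialize_fn()

    def __call__(self, state: Any, value: Any,
                 weight: Optional[Any] = None) -> Tuple[Any, Any]:
        # Delegates to next_fn and returns its (new_state, aggregate) pair.
        return self._next_fn(state, value, weight)


def weighted_mean(values, weights=None):
    # Hypothetical stand-in for tff.federated_mean over a list of client values.
    if weights is None:
        weights = [1.0] * len(values)
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


# Same shape as the stateless example: the state is an empty tuple.
stateless_mean = StatefulAggregateFnSketch(
    initialize_fn=lambda: (),
    next_fn=lambda state, value, weight=None: (
        state, weighted_mean(value, weight)))

state = stateless_mean.initialize()
state, aggregate = stateless_mean(state, [1.0, 2.0, 3.0])
print(state, aggregate)  # () 2.0
```

Keeping the container this thin means all aggregation logic lives in `next_fn`, so swapping one aggregation for another only changes the two functions passed to the constructor.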
__call__( state, value, weight=None )
Performs an aggregate of value@CLIENTS, with optional weight@CLIENTS.
The aggregation is optionally parameterized by state@SERVER.
This function is intended to be invoked only in the context of a tff.federated_computation. It should be compatible with the TFF type signature
(state@SERVER, value@CLIENTS, weight@CLIENTS) -> (state@SERVER, aggregate@SERVER).
Args:
state: A tff.Value placed on the tff.SERVER.
value: A tff.Value to be aggregated, placed on the tff.CLIENTS.
weight: An optional tff.Value for weighting the values, placed on the tff.CLIENTS.
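The following plain-Python sketch illustrates the signature, including how server state can parameterize the aggregation. It assumes a hypothetical placement model (a SERVER value is one Python object, a CLIENTS value is a list with one entry per client) and a hypothetical state holding a clipping bound; none of this is TFF API.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical placement model: a SERVER value is one Python object and a
# CLIENTS value is a list holding one entry per client.
ServerState = Dict[str, float]


def initialize_fn() -> ServerState:
    # Hypothetical state: a clipping bound applied to every client value.
    return {"clip": 5.0}


def next_fn(state: ServerState,
            value: List[float],
            weight: Optional[List[float]] = None) -> Tuple[ServerState, float]:
    """(state@SERVER, value@CLIENTS, weight@CLIENTS) -> (state@SERVER, aggregate@SERVER)."""
    if weight is None:
        weight = [1.0] * len(value)  # Unweighted: every client counts equally.
    if len(weight) != len(value):
        raise ValueError("value and weight need one entry per client")
    bound = state["clip"]
    # The server state parameterizes the aggregation: clip, then average.
    clipped = [max(-bound, min(bound, v)) for v in value]
    aggregate = sum(v * w for v, w in zip(clipped, weight)) / sum(weight)
    return state, aggregate  # State is passed through unchanged here.


state = initialize_fn()
state, aggregate = next_fn(state, [1.0, 9.0], weight=[3.0, 1.0])
print(aggregate)  # (1.0 * 3 + 5.0 * 1) / 4 = 2.0
```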
Returns:
A tuple of tff.Values (state@SERVER, aggregate@SERVER), where
state: The updated state.
aggregate: The result of the aggregation of value, weighted by weight if provided.
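Because the updated state is returned alongside the aggregate, a caller passes it back in on the next invocation. This plain-Python sketch assumes a hypothetical state of (rounds seen, running sum of per-round aggregates) and models CLIENTS values as Python lists; it is not TFF code and runs outside any federated computation.

```python
from typing import List, Optional, Tuple

State = Tuple[int, float]


def initialize_fn() -> State:
    # Hypothetical state: (rounds seen, running sum of per-round aggregates).
    return (0, 0.0)


def next_fn(state: State, value: List[float],
            weight: Optional[List[float]] = None) -> Tuple[State, float]:
    rounds, running_sum = state
    if weight is None:
        weight = [1.0] * len(value)
    aggregate = sum(v * w for v, w in zip(value, weight)) / sum(weight)
    # The returned tuple is (updated state, aggregate).
    return (rounds + 1, running_sum + aggregate), aggregate


# The updated state from one call is fed into the next call.
client_values_per_round = [[1.0, 3.0], [2.0, 2.0], [0.0, 6.0]]
state = initialize_fn()
for values in client_values_per_round:
    state, aggregate = next_fn(state, values)
print(state)  # (3, 7.0)
```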
initialize()
Returns the initial state.