The returned TFF computation broadcasts model weights from tff.SERVER to
tff.CLIENTS. Each client evaluates the personalization strategies given in
personalize_fn_dict. Evaluation metrics from at most max_num_clients
participating clients are collected at the server.
Args
model_fn
A no-arg function that returns a tff.learning.Model. This method
must not capture TensorFlow tensors or variables and use them. The model
must be constructed entirely from scratch on each invocation; returning
the same pre-constructed model on each call will result in an error.
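The "construct from scratch" requirement can be illustrated with a plain-Python sketch. `TinyModel`, `model_fn`, and `shared_model_fn` are hypothetical stand-ins with no TensorFlow dependency; a real `model_fn` would typically build a fresh Keras model inside the function body and wrap it (for example with tff.learning.from_keras_model).

```python
class TinyModel:
    """Hypothetical stand-in for a tff.learning.Model with one weight."""

    def __init__(self):
        # State is created inside the constructor, not captured from outside.
        self.weights = [0.0]


def model_fn():
    # Correct pattern: builds a brand-new model on every call.
    return TinyModel()


_prebuilt = TinyModel()


def shared_model_fn():
    # Incorrect pattern: returns the same pre-constructed object on every call.
    return _prebuilt
```

Only `model_fn` satisfies the contract: two calls yield two independent objects.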
personalize_fn_dict
An OrderedDict that maps a string (representing a
strategy name) to a no-argument function that returns a tf.function.
Each tf.function represents a personalization strategy: it accepts a
tff.learning.Model (with weights already initialized to the given model
weights when users invoke the returned TFF computation), an unbatched
tf.data.Dataset for training, an unbatched tf.data.Dataset for testing, and
an arbitrary context object (which holds any extra information
that a personalization strategy may use). It trains a personalized model and
returns the evaluation metrics. The evaluation metrics are represented as
an OrderedDict (or a nested OrderedDict) of string metric names to
scalar tf.Tensors.
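The shape of personalize_fn_dict can be sketched without TensorFlow. In this hypothetical example a "model" is a dict holding one weight `w` for y = w * x, and a "dataset" is a list of (x, y) pairs; `make_finetune_strategy` and the strategy names are illustrative only. A real strategy would be a tf.function operating on a tff.learning.Model and tf.data.Dataset objects.

```python
import collections
from typing import Any, Callable, Optional, Sequence, Tuple


def make_finetune_strategy(num_epochs: int, learning_rate: float) -> Callable:
    """Returns a strategy callable (a real one would be a tf.function)."""

    def strategy(model: dict,
                 train_data: Sequence[Tuple[float, float]],
                 test_data: Sequence[Tuple[float, float]],
                 context: Optional[Any] = None) -> collections.OrderedDict:
        # Fine-tune y = w * x on the client's training data with plain SGD.
        for _ in range(num_epochs):
            for x, y in train_data:
                grad = 2.0 * (model['w'] * x - y) * x
                model['w'] -= learning_rate * grad
        # Evaluate the personalized model on the client's test data.
        mse = sum((model['w'] * x - y) ** 2 for x, y in test_data) / len(test_data)
        return collections.OrderedDict(test_mse=mse,
                                       num_test_examples=len(test_data))

    return strategy


# Each value is a no-arg function that returns a strategy.
personalize_fn_dict = collections.OrderedDict(
    finetune_1_epoch=lambda: make_finetune_strategy(1, 0.01),
    finetune_5_epochs=lambda: make_finetune_strategy(5, 0.01),
)
```

Wrapping each strategy in a no-arg factory mirrors the documented contract: the strategy object is created only when it is needed.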
baseline_evaluate_fn
A tf.function that accepts a tff.learning.Model
(with weights already initialized to the provided model weights when users
invoke the returned TFF computation) and an unbatched tf.data.Dataset,
evaluates the model on the dataset, and returns the evaluation metrics.
The evaluation metrics are represented as an OrderedDict (or a nested
OrderedDict) of string metric names to scalar tf.Tensors. This
function is only used to compute the baseline metrics of the initial
model.
max_num_clients
A positive int specifying the maximum number of clients from which
to collect metrics in a round (default is 100). The clients are sampled
without replacement. For each sampled client, all the personalization
metrics from this client will be collected. If the number of participating
clients in a round is smaller than this value, then metrics from all
clients will be collected.
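Because client results arrive one at a time during aggregation, one natural way to keep at most max_num_clients of them without replacement is reservoir sampling (Algorithm R). This is an assumed technique shown for illustration, not a description of TFF's internals; `reservoir_sample` is a hypothetical helper.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar('T')


def reservoir_sample(stream: Iterable[T], max_num_clients: int = 100,
                     seed: int = 0) -> List[T]:
    """Keeps a uniform sample of at most max_num_clients items from a stream."""
    if not isinstance(max_num_clients, int) or max_num_clients <= 0:
        raise ValueError('max_num_clients must be a positive int.')
    rng = random.Random(seed)
    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < max_num_clients:
            # Fill the reservoir with the first max_num_clients items.
            reservoir.append(item)
        else:
            # Item i replaces a random slot with probability max_num_clients / (i + 1).
            j = rng.randint(0, i)  # inclusive on both ends
            if j < max_num_clients:
                reservoir[j] = item
    return reservoir
```

When fewer than max_num_clients clients participate, every client is kept, matching the behavior described above.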
context_tff_type
The tff.Type of the optional context object used by the
personalization strategies defined in personalize_fn_dict. We use a
context object to hold any extra information (in addition to the training
dataset) that personalization may use. If context is used in
personalize_fn_dict, its tff.Type must be provided here.
Returns
A federated tff.Computation with the functional type signature
(<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER):
Each client's input is an OrderedDict of two required keys
train_data and test_data; each key is mapped to an unbatched
tf.data.Dataset. If extra context (e.g., extra datasets) is used in
personalize_fn_dict, then each client's input has a third key context that
is mapped to an object whose tff.Type is provided by the
context_tff_type argument.
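A minimal sketch of one client's input, assuming plain lists stand in for the unbatched tf.data.Dataset objects (in real use these might come from tf.data.Dataset.from_tensor_slices). `make_client_input` is a hypothetical helper, not part of TFF.

```python
import collections
from typing import Any, Optional, Sequence


def make_client_input(train_data: Sequence, test_data: Sequence,
                      context: Optional[Any] = None) -> collections.OrderedDict:
    """Builds one client's input with the keys described above."""
    client_input = collections.OrderedDict(train_data=train_data,
                                           test_data=test_data)
    # The third key is present only when strategies use extra context.
    if context is not None:
        client_input['context'] = context
    return client_input
```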
personalization_metrics is an OrderedDict that maps a key
'baseline_metrics' to the evaluation metrics of the initial model
(computed by baseline_evaluate_fn), and maps keys (strategy names) in
personalize_fn_dict to the evaluation metrics of the corresponding
personalization strategies.
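The per-client metrics structure can be sketched in plain Python. `personalization_metrics_for_client` is a hypothetical helper; two details are assumptions not stated above: that baseline_evaluate_fn is applied to the client's test_data, and that each strategy starts from its own copy of the initial weights.

```python
import collections
import copy
from typing import Any, Callable, Mapping


def personalization_metrics_for_client(
        initial_model: Any,
        client_input: Mapping[str, Any],
        personalize_fn_dict: Mapping[str, Callable[[], Callable]],
        baseline_evaluate_fn: Callable[[Any, Any], Mapping[str, Any]],
) -> collections.OrderedDict:
    """Collects the baseline and per-strategy metrics for one client."""
    metrics = collections.OrderedDict()
    # Assumption: the baseline evaluates the unmodified model on test_data.
    metrics['baseline_metrics'] = baseline_evaluate_fn(
        copy.deepcopy(initial_model), client_input['test_data'])
    context = client_input.get('context')
    for name, strategy_fn in personalize_fn_dict.items():
        strategy = strategy_fn()
        # Assumption: every strategy starts from a fresh copy of the weights.
        metrics[name] = strategy(copy.deepcopy(initial_model),
                                 client_input['train_data'],
                                 client_input['test_data'], context)
    return metrics
```

Copying the initial model keeps strategies from seeing each other's updates, so each result reflects personalization from the same starting point.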
Raises
TypeError
If arguments are of the wrong types.
ValueError
If baseline_metrics is used as a key in personalize_fn_dict.
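A sketch of argument checks that mirror the documented errors. `check_personalize_fn_dict` is hypothetical, and the exact set of type checks is an assumption; only the reserved baseline_metrics key is stated above.

```python
import collections
from typing import Any


def check_personalize_fn_dict(personalize_fn_dict: Any) -> None:
    """Hypothetical validation mirroring the documented TypeError/ValueError."""
    if not isinstance(personalize_fn_dict, collections.OrderedDict):
        raise TypeError('personalize_fn_dict must be an OrderedDict.')
    for name, fn in personalize_fn_dict.items():
        if not isinstance(name, str):
            raise TypeError(f'Strategy name {name!r} is not a string.')
        if not callable(fn):
            raise TypeError(f'Value for {name!r} must be a no-arg callable.')
    # The baseline metrics key is reserved for baseline_evaluate_fn results.
    if 'baseline_metrics' in personalize_fn_dict:
        raise ValueError("'baseline_metrics' cannot be used as a strategy name.")
```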