tff.learning.build_personalization_eval


Builds the TFF computation for evaluating personalization strategies.

The returned TFF computation broadcasts model weights from tff.SERVER to tff.CLIENTS. Each client evaluates the personalization strategies given in personalize_fn_dict. Evaluation metrics from at most max_num_samples participating clients are collected at the server.
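The per-client half of this dataflow can be sketched in plain Python. In this illustration, evaluate_client and make_shift_strategy are hypothetical names (not TFF functions), a dict of weights stands in for a model built by model_fn and assigned the broadcast weights, and evaluating the baseline on test_data is an assumption of the sketch. What it shows is that the baseline and every strategy start from the same broadcast weights, independently of one another.

```python
import collections
import copy


def evaluate_client(server_weights, client_input, baseline_evaluate_fn, strategies):
    """Hypothetical per-client step that produces one client's metric sample."""
    train_data = client_input["train_data"]
    test_data = client_input["test_data"]
    context = client_input.get("context")  # present only if strategies use it
    sample = collections.OrderedDict()
    # Baseline: the broadcast weights, evaluated without any training.
    # (Evaluating on test_data is an assumption of this sketch.)
    sample["baseline_metrics"] = baseline_evaluate_fn(
        copy.deepcopy(server_weights), test_data)
    for name, strategy in strategies.items():
        # Each strategy starts from its own copy of the broadcast weights,
        # so training done by one strategy cannot affect another.
        sample[name] = strategy(
            copy.deepcopy(server_weights), train_data, test_data, context)
    return sample


def make_shift_strategy(delta):
    # Toy "training": shift the single weight by delta, in place.
    def strategy(weights, train_data, test_data, context):
        weights["w"] += delta
        return collections.OrderedDict(final_w=weights["w"])
    return strategy


server_weights = {"w": 0.0}
sample = evaluate_client(
    server_weights,
    collections.OrderedDict(train_data=[], test_data=[]),
    baseline_evaluate_fn=lambda weights, data: collections.OrderedDict(
        final_w=weights["w"]),
    strategies=collections.OrderedDict(
        shift_by_1=make_shift_strategy(1.0),
        shift_by_2=make_shift_strategy(2.0),
    ),
)
```

If the deep copies were removed, shift_by_2 would start from the weights already modified by shift_by_1 and report 3.0 instead of 2.0, and the server's weights would be mutated as well.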

Args

model_fn A no-arg function that returns a tff.learning.Model. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
personalize_fn_dict An OrderedDict that maps a string (the strategy name) to a no-argument function that returns a tf.function. Each tf.function represents a personalization strategy: it accepts a tff.learning.Model (whose weights are already initialized to the given model weights when the returned TFF computation is invoked), an unbatched tf.data.Dataset for training, an unbatched tf.data.Dataset for testing, and an arbitrary context object (which holds any extra information the strategy may use); it then trains a personalized model and returns the evaluation metrics. The evaluation metrics are represented as an OrderedDict (or a nested OrderedDict) that maps string metric names to scalar tf.Tensors.
baseline_evaluate_fn A tf.function that accepts a tff.learning.Model (whose weights are already initialized to the given model weights when the returned TFF computation is invoked) and an unbatched tf.data.Dataset, evaluates the model on the dataset, and returns the evaluation metrics. The evaluation metrics are represented as an OrderedDict (or a nested OrderedDict) that maps string metric names to scalar tf.Tensors. This function is used only to compute the baseline metrics of the initial model.
max_num_samples A positive int specifying the maximum number of metric samples to collect in a round. Each sample contains the personalization metrics from a single client. If the number of participating clients in a round is smaller than this value, all clients' metrics are collected.
context_tff_type The tff.Type of the optional context object used by the personalization strategies defined in personalize_fn_dict. The context object holds any extra information (in addition to the training dataset) that personalization may use. If a context is used in personalize_fn_dict, its tff.Type must be provided here.
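The argument contract above can be sketched without TensorFlow to show how the pieces fit together. In this framework-free illustration, ToyModel, mean_squared_error and build_sgd_strategy are hypothetical names; plain Python callables stand in for tf.function objects, lists of (x, y) pairs stand in for unbatched tf.data.Dataset objects, and floats stand in for scalar tf.Tensors. Reading the epoch count from context is likewise an assumed convention, not something TFF prescribes.

```python
import collections


class ToyModel:
    """Hypothetical stand-in for tff.learning.Model: predicts y = w * x."""

    def __init__(self):
        self.w = 0.0  # created fresh inside the constructor, never captured


def model_fn():
    # A new model on every call; per the model_fn description, returning one
    # shared, pre-constructed instance would be an error.
    return ToyModel()


def mean_squared_error(model, dataset):
    return sum((model.w * x - y) ** 2 for x, y in dataset) / len(dataset)


def baseline_evaluate_fn(model, dataset):
    # Evaluates the initial model as-is; no training happens here.
    return collections.OrderedDict(loss=mean_squared_error(model, dataset))


def build_sgd_strategy(learning_rate):
    """No-arg builder for a strategy whose epoch count comes from context."""

    def strategy_builder():
        def personalize(model, train_data, test_data, context):
            num_epochs = int(context)  # hypothetical convention for this sketch
            for _ in range(num_epochs):
                for x, y in train_data:
                    # One SGD step on the squared error (w * x - y) ** 2.
                    model.w -= learning_rate * 2.0 * (model.w * x - y) * x
            return collections.OrderedDict(
                loss=mean_squared_error(model, test_data), num_epochs=num_epochs)
        return personalize

    return strategy_builder


personalize_fn_dict = collections.OrderedDict(
    sgd_lr_0_1=build_sgd_strategy(learning_rate=0.1))

train_data, test_data = [(1.0, 2.0)], [(2.0, 4.0)]
baseline = baseline_evaluate_fn(model_fn(), test_data)
personalized = personalize_fn_dict["sgd_lr_0_1"]()(
    model_fn(), train_data, test_data, context=2)
```

For an integer context like this one, context_tff_type would presumably describe a scalar integer, for example tff.TensorType(tf.int32); that pairing is an assumption and is not exercised by the sketch.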

Returns

A federated tff.Computation with the functional type signature (<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER):

  • model_weights is a tff.learning.ModelWeights.
  • Each client's input is an OrderedDict with two required keys, train_data and test_data; each key is mapped to an unbatched tf.data.Dataset. If extra context (e.g., extra datasets) is used in personalize_fn_dict, then each client's input has a third key, context, that is mapped to an object whose tff.Type is provided by the context_tff_type argument.
  • personalization_metrics is an OrderedDict that maps the key 'baseline_metrics' to the evaluation metrics of the initial model (computed by baseline_evaluate_fn), and maps each key (strategy name) in personalize_fn_dict to the evaluation metrics of the corresponding personalization strategy.
  • Note: only metrics from at most max_num_samples participating clients (sampled without replacement) are collected at the SERVER. All collected metrics are stored in a single OrderedDict (personalization_metrics shown above), where each metric is mapped to a list of scalars (each scalar comes from one client). Metric values at the same position, e.g., metric_1[i], metric_2[i], ..., all come from the same client.
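The collection rule in these bullets can be mimicked in plain Python. Here collect_samples and _merge are hypothetical helpers, the metric values in the usage lines are made up, and random.sample merely stands in for whatever sampling mechanism TFF uses internally. The sketch only shows that at most max_num_samples clients are drawn without replacement and that their (possibly nested) OrderedDicts are merged into lists aligned by client.

```python
import collections
import random


def _merge(samples):
    # Turns a list of (possibly nested) OrderedDicts into one OrderedDict of lists.
    merged = collections.OrderedDict()
    if not samples:
        return merged
    for name, value in samples[0].items():
        children = [sample[name] for sample in samples]
        merged[name] = _merge(children) if isinstance(value, dict) else children
    return merged


def collect_samples(per_client_metrics, max_num_samples, seed=None):
    """Keeps at most max_num_samples client samples, drawn without replacement.

    Position i of every output list comes from the i-th drawn client, so
    metric values at the same index always belong to the same client.
    """
    rng = random.Random(seed)
    num_samples = min(len(per_client_metrics), max_num_samples)
    return _merge(rng.sample(per_client_metrics, num_samples))


# Made-up per-client results for three participating clients.
clients = [
    collections.OrderedDict(
        baseline_metrics=collections.OrderedDict(loss=0.9),
        sgd=collections.OrderedDict(loss=0.4)),
    collections.OrderedDict(
        baseline_metrics=collections.OrderedDict(loss=1.2),
        sgd=collections.OrderedDict(loss=0.7)),
    collections.OrderedDict(
        baseline_metrics=collections.OrderedDict(loss=0.5),
        sgd=collections.OrderedDict(loss=0.3)),
]

# More clients than samples allowed: only two are kept.
capped = collect_samples(clients, max_num_samples=2, seed=0)
# Fewer clients than the limit: every client's metrics are kept.
uncapped = collect_samples(clients, max_num_samples=10, seed=0)
```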

Raises

TypeError If arguments are of the wrong types.
ValueError If baseline_metrics is used as a key in personalize_fn_dict.
ValueError If max_num_samples is not positive.
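The documented failure modes can be reproduced with a small pre-flight check that fails fast before any federated computation is built. validate_personalization_args is a hypothetical helper, not part of TFF, and the exact checks TFF performs (in particular which arguments trigger a TypeError) may be broader or narrower than this sketch assumes.

```python
import collections


def validate_personalization_args(personalize_fn_dict, max_num_samples):
    """Hypothetical pre-flight check mirroring the documented error cases."""
    if not isinstance(personalize_fn_dict, collections.OrderedDict):
        raise TypeError("personalize_fn_dict must be a collections.OrderedDict.")
    if "baseline_metrics" in personalize_fn_dict:
        # The key is reserved for the metrics of the initial model.
        raise ValueError("'baseline_metrics' cannot be used as a strategy name.")
    # bool is a subclass of int in Python, so reject it explicitly.
    if not isinstance(max_num_samples, int) or isinstance(max_num_samples, bool):
        raise TypeError("max_num_samples must be an int.")
    if max_num_samples <= 0:
        raise ValueError("max_num_samples must be positive.")


# A valid configuration passes silently.
validate_personalization_args(
    collections.OrderedDict(sgd=lambda: None), max_num_samples=100)
```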