tff.learning.reconstruction.build_training_process

Builds the IterativeProcess for optimization using FedRecon.

Returns a tff.templates.IterativeProcess for Federated Reconstruction. On the client, computation can be divided into two stages: (1) reconstruction of local variables and (2) training of global variables.
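
A minimal sketch of how the pieces fit together, assuming tff.learning.reconstruction.from_keras_model is used to wrap a toy Keras model; the layer split (one local layer, one global layer), input shapes, loss, metrics, and learning rates below are arbitrary illustrations, not prescribed values.

```python
import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
  # The model must be built from scratch on every call.
  # Hypothetical split: `local_layer` weights stay on each client and are
  # reconstructed every round; `global_layer` weights are trained and
  # aggregated on the server.
  local_layer = tf.keras.layers.Dense(5, name='local_embedding')
  global_layer = tf.keras.layers.Dense(1, name='global_head')
  inputs = tf.keras.Input(shape=(10,))
  keras_model = tf.keras.Model(
      inputs=inputs, outputs=global_layer(local_layer(inputs)))
  return tff.learning.reconstruction.from_keras_model(
      keras_model=keras_model,
      global_layers=[global_layer],
      local_layers=[local_layer],
      input_spec=(
          tf.TensorSpec(shape=[None, 10], dtype=tf.float32),
          tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
      ),
  )

training_process = tff.learning.reconstruction.build_training_process(
    model_fn=model_fn,
    loss_fn=lambda: tf.keras.losses.MeanSquaredError(),
    metrics_fn=lambda: [tf.keras.metrics.MeanAbsoluteError()],
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
    reconstruction_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
```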

Args

model_fn A no-arg function that returns a tff.learning.reconstruction.Model. This method must not capture TensorFlow tensors or variables and use them; the model must be constructed entirely from scratch on each invocation. Returning the same pre-constructed model on each call will result in an error.
loss_fn A no-arg function returning a tf.keras.losses.Loss used to compute local model updates during reconstruction and post-reconstruction, and to evaluate the model during training. The final loss metric is the example-weighted mean loss across batches and across clients. Reconstruction batches are not included in this loss metric.
metrics_fn A no-arg function returning a list of tf.keras.metrics.Metric objects used to evaluate the model. Metric results are computed locally as described by each metric and aggregated across clients as in federated_aggregate_keras_metric. If None, no metrics are applied. Metrics are not computed on reconstruction batches.
server_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer for applying updates to the global model on the server.
client_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer for local client training after reconstruction.
reconstruction_optimizer_fn A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer, used to reconstruct the local variables with the global ones frozen, i.e. the first stage described above.
dataset_split_fn A tff.learning.reconstruction.DatasetSplitFn taking in a single TF dataset and producing two TF datasets. The first is iterated over during reconstruction, and the second is iterated over post-reconstruction. This can be used to preprocess datasets, e.g. to iterate over them for multiple epochs or to use disjoint data for reconstruction and post-reconstruction. If None, client data is split in half for each user, with one half used for reconstruction and the other for post-reconstruction training. See tff.learning.reconstruction.build_dataset_split_fn for common options; a minimal custom split function is sketched after this parameter list.
client_weighting A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the local metrics of the model and returns a tensor that provides the weight in the federated average of model deltas. If None, defaults to weighting by number of examples.
broadcast_process A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENT). If None (the default), the server model is broadcast to the clients using the default tff.federated_broadcast.
aggregation_factory An optional instance of tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory determining the method of aggregation to perform. If unspecified, uses a default tff.aggregators.MeanFactory which computes a stateless mean across clients (weighted depending on client_weighting).
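
As an illustration of the dataset_split_fn contract described above (one client dataset in, two datasets out), the sketch below reserves the first two batches of each client's data for reconstruction and the remainder for post-reconstruction training; the cutoff of two batches is an arbitrary choice. For common cases, prefer tff.learning.reconstruction.build_dataset_split_fn.

```python
import tensorflow as tf

def dataset_split_fn(client_dataset: tf.data.Dataset):
  # The first dataset is iterated over to reconstruct local variables;
  # the second is iterated over to train global variables afterwards.
  recon_dataset = client_dataset.take(2)       # arbitrary cutoff for illustration
  post_recon_dataset = client_dataset.skip(2)
  return recon_dataset, post_recon_dataset
```

This callable would then be passed as dataset_split_fn=dataset_split_fn when calling build_training_process.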

Raises

TypeError If broadcast_process does not have the expected signature.
TypeError If aggregation_factory does not have the expected signature.
ValueError If aggregation_factory is not a tff.aggregators.WeightedAggregationFactory or a tff.aggregators.UnweightedAggregationFactory.
ValueError If aggregation_factory is a tff.aggregators.UnweightedAggregationFactory but client_weighting is not tff.learning.ClientWeighting.UNIFORM.

Returns

A tff.templates.IterativeProcess.
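
A sketch of driving the returned process, assuming training_process was built as in the example above and federated_train_data is a list of per-client tf.data.Datasets matching the model's input_spec; the round count is arbitrary.

```python
state = training_process.initialize()
for round_num in range(1, 11):
  # Each round broadcasts the global model, reconstructs local variables
  # on each client, trains global variables post-reconstruction, and
  # aggregates client updates on the server.
  state, metrics = training_process.next(state, federated_train_data)
  print(f'Round {round_num}: {metrics}')
```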