Builds the IterativeProcess for optimization using FedRecon.
tff.learning.reconstruction.build_training_process(
    model_fn: tff.learning.reconstruction.Model,
    *,
    loss_fn: LossFn,
    metrics_fn: Optional[MetricsFn] = None,
    server_optimizer_fn: OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 1.0),
    client_optimizer_fn: OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 0.1),
    reconstruction_optimizer_fn: OptimizerFn = functools.partial(tf.keras.optimizers.SGD, 0.1),
    dataset_split_fn: Optional[tff.learning.reconstruction.DatasetSplitFn] = None,
    client_weighting: Optional[client_weight_lib.ClientWeightType] = None,
    broadcast_process: Optional[tff.templates.MeasuredProcess] = None,
    aggregation_factory: Optional[AggregationFactory] = None
) -> tff.templates.IterativeProcess
Returns a tff.templates.IterativeProcess for Federated Reconstruction. On the client, computation can be divided into two stages: (1) reconstruction of local variables and (2) training of global variables.
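To make the two client-side stages concrete, here is a minimal pure-Python sketch. It is not TFF code and does not call any TFF API. It assumes a toy model y_hat = w * x + b in which w plays the role of a global variable and b a client-local one, mean squared error as the loss, plain SGD with learning rate 0.1 for both stages (matching the documented optimizer defaults), and local variables re-initialized to zero every round so that no client state is kept. The names `client_update` and `_grads` are hypothetical.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # one (x, y) pair


def _grads(w: float, b: float, batch: List[Example]) -> Tuple[float, float]:
    """Gradients of mean squared error for the toy model y_hat = w * x + b."""
    n = len(batch)
    dw = sum(2.0 * (w * x + b - y) * x for x, y in batch) / n
    db = sum(2.0 * (w * x + b - y) for x, y in batch) / n
    return dw, db


def client_update(
    global_w: float,
    recon_batches: List[List[Example]],
    post_recon_batches: List[List[Example]],
    recon_lr: float = 0.1,
    client_lr: float = 0.1,
) -> float:
    """Return the change to the global weight produced by one client."""
    # Stage 1: reconstruct the local bias from scratch (assumed zero init),
    # keeping the global weight frozen.
    b = 0.0
    for batch in recon_batches:
        _, db = _grads(global_w, b, batch)
        b -= recon_lr * db
    # Stage 2: train the global weight, keeping the reconstructed bias frozen.
    w = global_w
    for batch in post_recon_batches:
        dw, _ = _grads(w, b, batch)
        w -= client_lr * dw
    # Only the global delta leaves the client; the local bias is discarded.
    return w - global_w
```

For a client whose data satisfies y = x, calling `client_update(0.0, [[(1.0, 1.0)]], [[(1.0, 1.0)]])` first moves the local bias to 0.2 and then returns a global delta of about 0.16. Keeping the bias out of the returned value is the point of the split: the server only ever sees updates to the global variables.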
Args:
model_fn: A no-arg function that returns a tff.learning.reconstruction.Model. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
loss_fn: A no-arg function returning a tf.keras.losses.Loss, used to compute local model updates during both reconstruction and post-reconstruction, and to evaluate the model during training. The final loss metric is the example-weighted mean loss across batches and across clients. Reconstruction batches are not included in the loss metric.
metrics_fn: A no-arg function returning a list of tf.keras.metrics.Metric objects used to evaluate the model. Metric results are computed locally as described by each metric and are aggregated across clients as in federated_aggregate_keras_metric. If None, no metrics are applied. Metrics are not computed on reconstruction batches.
server_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer, for applying updates to the global model on the server.
client_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer, for local client training after reconstruction.
reconstruction_optimizer_fn: A tff.learning.optimizers.Optimizer, or a no-arg function that returns a tf.keras.optimizers.Optimizer, used to reconstruct the local variables while the global variables are frozen (the first of the two client stages described above).
dataset_split_fn: A tff.learning.reconstruction.DatasetSplitFn that takes a single TF dataset and produces two TF datasets. The first is iterated over during reconstruction, and the second is iterated over post-reconstruction. This can be used to preprocess datasets, e.g. to iterate over them for multiple epochs or to use disjoint data for reconstruction and post-reconstruction. If None, each client's data is split in half, with one half used for reconstruction and the other for evaluation. See tff.learning.reconstruction.build_dataset_split_fn for options.
client_weighting: A value of tff.learning.ClientWeighting that specifies a built-in weighting method, or a callable that takes the local metrics of the model and returns a tensor giving the client's weight in the federated average of model deltas. If None, clients are weighted by their number of examples.
broadcast_process: A tff.templates.MeasuredProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENT). If None (the default), the server model is broadcast to the clients using tff.federated_broadcast.
aggregation_factory: An optional instance of tff.aggregators.WeightedAggregationFactory or tff.aggregators.UnweightedAggregationFactory that determines how client updates are aggregated. If unspecified, a default tff.aggregators.MeanFactory is used, which computes a stateless mean across clients (weighted according to client_weighting).
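The contract described for dataset_split_fn can be illustrated without TensorFlow. In the following sketch, plain Python lists of batches stand in for tf.data.Dataset, and `split_in_half` and `make_split_fn` are hypothetical helpers, not part of TFF. The first mirrors the documented None behavior of halving each client's data; the second shows the two customizations the text mentions, disjoint data and multiple epochs. For the real options, tff.learning.reconstruction.build_dataset_split_fn is the documented entry point.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-ins: a "batch" is one item of a client's data, and a split
# function maps one client's data to (reconstruction data, post-reconstruction data).
Batch = List[float]
SplitFn = Callable[[List[Batch]], Tuple[List[Batch], List[Batch]]]


def split_in_half(client_data: List[Batch]) -> Tuple[List[Batch], List[Batch]]:
    """First half for reconstruction, second half for the stage after it."""
    mid = len(client_data) // 2  # with an odd count, the second half gets the extra batch
    return client_data[:mid], client_data[mid:]


def make_split_fn(recon_epochs: int = 1, post_recon_epochs: int = 1) -> SplitFn:
    """Build a split function that uses disjoint halves, each repeated for several epochs."""

    def split_fn(client_data: List[Batch]) -> Tuple[List[Batch], List[Batch]]:
        recon, post = split_in_half(client_data)
        # Repeating a list concatenates copies, mimicking multiple passes over the data.
        return recon * recon_epochs, post * post_recon_epochs

    return split_fn
```

For example, `make_split_fn(recon_epochs=2)([[1.0], [2.0], [3.0], [4.0]])` yields `[[1.0], [2.0], [1.0], [2.0]]` for reconstruction and `[[3.0], [4.0]]` for the second stage, so the two stages never see the same batches.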
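The loss metric described for loss_fn is an example-weighted mean across batches and clients that skips reconstruction batches. A minimal sketch of that arithmetic, assuming each batch reports its mean loss, its example count, and whether it was a reconstruction batch (this record layout and the function name `example_weighted_loss` are hypothetical, not TFF APIs):

```python
from typing import List, Tuple

# Hypothetical per-batch record: (mean loss over the batch, number of examples,
# whether the batch was used for reconstruction).
BatchStat = Tuple[float, int, bool]


def example_weighted_loss(clients: List[List[BatchStat]]) -> float:
    """Example-weighted mean loss over all non-reconstruction batches of all clients."""
    total_loss = 0.0
    total_examples = 0
    for batches in clients:
        for mean_loss, num_examples, is_recon in batches:
            if is_recon:
                continue  # reconstruction batches do not count toward the metric
            total_loss += mean_loss * num_examples
            total_examples += num_examples
    return total_loss / total_examples if total_examples else 0.0
```

With `[[(9.0, 10, True), (1.0, 2, False)], [(4.0, 6, False)]]` the 10-example reconstruction batch is ignored entirely, and the result is (1.0 * 2 + 4.0 * 6) / 8 = 3.25. Weighting by examples rather than averaging per-batch or per-client losses means a client with more post-reconstruction data contributes proportionally more.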
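Together, client_weighting and the default aggregation_factory amount to a weighted mean of client model deltas. The following plain-Python sketch assumes each client reports a metrics dictionary with a hypothetical "num_examples" key. `num_examples_weight` mirrors the documented default of weighting by example count, and `uniform_weight` stands in for a custom weighting callable. None of these names are TFF APIs.

```python
from typing import Callable, Dict, List, Optional

Metrics = Dict[str, float]
WeightFn = Callable[[Metrics], float]


def num_examples_weight(metrics: Metrics) -> float:
    """Weight a client by its example count (the hypothetical "num_examples" key)."""
    return float(metrics["num_examples"])


def uniform_weight(metrics: Metrics) -> float:
    """Give every client the same weight, ignoring its metrics."""
    return 1.0


def weighted_mean_delta(
    deltas: List[List[float]],
    client_metrics: List[Metrics],
    weight_fn: Optional[WeightFn] = None,
) -> List[float]:
    """Average per-client model deltas, weighting each client by weight_fn(metrics)."""
    if weight_fn is None:
        weight_fn = num_examples_weight  # mirrors the documented default weighting
    weights = [weight_fn(m) for m in client_metrics]
    total = sum(weights)
    return [
        sum(w * delta[i] for w, delta in zip(weights, deltas)) / total
        for i in range(len(deltas[0]))
    ]
```

For deltas `[[1.0, 0.0], [4.0, 3.0]]` from clients with 1 and 2 examples, the default weighting gives `[3.0, 2.0]`, while `uniform_weight` gives `[2.5, 1.5]`. If, as in TFF's federated averaging processes, the server treats the negated averaged delta as a gradient, then the default server_optimizer_fn (SGD with learning rate 1.0) simply adds this averaged delta to the global weights.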