Client TensorFlow logic for Federated Averaging.
Inherits From: ClientDeltaFn
tff.learning.ClientFedAvg(
    model: tff.learning.Model,
    optimizer: tf.keras.optimizers.Optimizer,
    client_weight_fn: Optional[Callable[[Any], tf.Tensor]] = None,
    use_experimental_simulation_loop: bool = False
)
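To make the client side of Federated Averaging concrete, here is a minimal plain-Python sketch of what one client computes. Starting from the initial weights, it runs local gradient steps over its batches and reports the change in weights together with a weight for averaging. It assumes a toy linear model with squared-error loss and a fixed learning rate. The name client_fedavg_update is hypothetical, and the code uses neither TensorFlow nor the ClientFedAvg API.

```python
from typing import List, Sequence, Tuple

# A batch is a list of (features, label) pairs for a linear model y ~ dot(w, x).
Batch = List[Tuple[List[float], float]]


def client_fedavg_update(
    dataset: Sequence[Batch],
    initial_weights: List[float],
    learning_rate: float = 0.1,
) -> Tuple[List[float], int]:
    """Hypothetical client step: local SGD starting from initial_weights.

    Returns (weights_delta, num_examples), where
    weights_delta = trained_weights - initial_weights.
    """
    weights = list(initial_weights)  # copy so the caller's list is untouched
    num_examples = 0
    for batch in dataset:
        # Gradient of the mean squared error over this batch,
        # computed with the weights held fixed for the whole batch.
        grad = [0.0] * len(weights)
        for x, y in batch:
            err = sum(w * xi for w, xi in zip(weights, x)) - y
            for i, xi in enumerate(x):
                grad[i] += 2.0 * err * xi / len(batch)
        # One SGD step per batch.
        weights = [w - learning_rate * g for w, g in zip(weights, grad)]
        num_examples += len(batch)
    delta = [w - w0 for w, w0 in zip(weights, initial_weights)]
    return delta, num_examples
```

For example, client_fedavg_update([[([1.0], 2.0), ([2.0], 4.0)]], [0.0]) takes one SGD step on a single batch of two examples and returns ([1.0], 2). Reporting a delta rather than the trained weights lets the server combine updates from clients that all started from the same weights.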
Args | |
---|---|
model | A tff.learning.Model instance. |
optimizer | A tf.keras.optimizers.Optimizer instance. |
client_weight_fn | An optional callable that takes the output of model.report_local_outputs and returns a tensor giving this client's weight in the federated average of model deltas. If not provided, the weight defaults to the total number of examples processed on device. |
use_experimental_simulation_loop | Controls which reduce loop is used to iterate over the input dataset. If True, an experimental reduce loop intended for simulation is used. |
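The client_weight_fn argument determines how much each client's delta counts when deltas are averaged. The plain-Python sketch below computes a weighted average of deltas, first with the documented default (number of examples processed on device) and then with a custom function. It assumes each client's local outputs are a dict with a hypothetical num_examples key; the real structure returned by model.report_local_outputs depends on the model. The names default_client_weight and federated_average are also hypothetical.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def default_client_weight(local_outputs: Dict[str, Any]) -> float:
    # Mirrors the documented default: weight by examples processed on device.
    # "num_examples" is a hypothetical key standing in for whatever
    # report_local_outputs actually returns.
    return float(local_outputs["num_examples"])


def federated_average(
    deltas: Sequence[List[float]],
    local_outputs: Sequence[Dict[str, Any]],
    client_weight_fn: Optional[Callable[[Dict[str, Any]], float]] = None,
) -> List[float]:
    """Weighted mean of per-client deltas, one weight per client."""
    if client_weight_fn is None:
        client_weight_fn = default_client_weight
    weights = [client_weight_fn(out) for out in local_outputs]
    total = sum(weights)
    if total == 0:
        raise ValueError("client weights sum to zero")
    dim = len(deltas[0])
    # Average each coordinate of the deltas, weighted by client weight.
    return [
        sum(w * d[i] for w, d in zip(weights, deltas)) / total
        for i in range(dim)
    ]
```

With deltas [[1.0], [4.0]] from clients that processed 1 and 3 examples, the default weighting gives [3.25], while passing client_weight_fn=lambda _: 1.0 weights the clients uniformly and gives [2.5].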
Attributes | |
---|---|
variables | Returns all the variables of this object. Note that this includes only variables that are part of this object's state, not the model variables themselves. |
Methods
__call__
@tf.function
__call__(dataset, initial_weights)
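__call__ consumes dataset batch by batch, and the use_experimental_simulation_loop argument chooses the loop that does this. The documentation does not spell out the two strategies, so the following is only a plain-Python analogy. One version folds a per-batch step over the dataset with a single reduce call; the other pulls batches from an iterator in an explicit loop. Which strategy the flag selects in TFF is an assumption here, as are the names reduce_loop, iterator_loop, and build_loop_fn.

```python
import functools
from typing import Any, Callable, Iterable, TypeVar

S = TypeVar("S")  # loop state, e.g. model weights plus counters
B = TypeVar("B")  # one batch from the dataset


def reduce_loop(dataset: Iterable[B], initial_state: S,
                step: Callable[[S, B], S]) -> S:
    # Fold style: the whole pass over the dataset is one reduce call.
    return functools.reduce(step, dataset, initial_state)


def iterator_loop(dataset: Iterable[B], initial_state: S,
                  step: Callable[[S, B], S]) -> S:
    # Explicit style: pull batches one at a time from an iterator.
    state = initial_state
    for batch in iter(dataset):
        state = step(state, batch)
    return state


def build_loop_fn(use_experimental_simulation_loop: bool = False) -> Callable[..., Any]:
    # Hypothetical selector that only borrows the constructor flag's name;
    # the mapping of True to the iterator loop is an assumption.
    return iterator_loop if use_experimental_simulation_loop else reduce_loop
```

Both functions give the same result for a pure step function. For example, summing [[1, 2], [3], [4, 5]] with step=lambda total, batch: total + sum(batch) and an initial state of 0 returns 15 either way.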