Represents a model for use in TensorFlow Federated.
A Model works on a set of tf.Variables, and each method should be a computation that can be implemented as a tf.function; this implies the class should essentially be stateless from a Python perspective, as each method will generally only be traced once (per set of arguments) to create the corresponding TensorFlow graph functions. Thus, Model instances should behave as expected in both eager and graph (TF 1.0) usage.
The tf.Variables may be either:
- Weights: the variables needed to make predictions with the model.
- Local variables: e.g., variables used to accumulate aggregated metrics across calls to forward_pass.
The weights can be broken down into trainable variables (variables that can and should be trained using gradient-based methods) and non-trainable variables (which could include fixed pre-trained layers or static model data). These variables are provided via the corresponding properties, and must be initialized by the user of the Model.
In federated learning, model weights will generally be provided by the server, and updates to trainable model variables will be sent back to the server. Local variables are not transmitted; they are instead initialized locally on the device and then used to produce aggregated values that are sent to the server.
The tf.Variables should be introduced in __init__; this could move to a build method more in line with Keras.
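As a minimal, hedged sketch of this convention (plain Python values stand in for tf.Variable so the example is self-contained, and LinearModel and its fields are hypothetical names, not part of the TFF API):

```python
from collections import OrderedDict

class LinearModel:
    """Hypothetical sketch of a Model-like class; plain Python floats
    stand in for tf.Variable so the example runs without TensorFlow."""

    def __init__(self):
        # Weights: trainable variables used to make predictions.
        self.trainable_variables = OrderedDict(w=0.0, b=0.0)
        # Non-trainable weights (e.g., static model data).
        self.non_trainable_variables = OrderedDict()
        # Local variables: metric accumulators, never transmitted.
        self.local_variables = OrderedDict(num_examples=0.0, loss_sum=0.0)

    def predict_on_batch(self, batch_x):
        # A pure function of the weights; no variables are modified.
        w, b = self.trainable_variables['w'], self.trainable_variables['b']
        return [w * x + b for x in batch_x]
```

All variables are created once in `__init__`, so the instance itself carries no hidden Python state between method calls.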
Performs federated aggregation of the model's local outputs. This is typically used to aggregate metrics across many clients; e.g., the body of the computation might sum per-client example counts and compute an example-weighted mean of the loss. N.B. It is assumed that all TensorFlow computation happens inside the model's methods (e.g., forward_pass).
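Such an aggregation body can be sketched in plain Python (a hedged simulation with no TFF dependency; in real TFF code the primitives would be calls such as tff.federated_sum and tff.federated_mean):

```python
def aggregate_client_metrics(client_metrics):
    """Sketch of the aggregation a federated computation body might
    express: sum the per-client example counts, and take an
    example-weighted mean of the per-client loss."""
    total_examples = sum(m['num_examples'] for m in client_metrics)
    weighted_loss = sum(m['loss'] * m['num_examples'] for m in client_metrics)
    return {
        'num_examples': total_examples,
        'loss': weighted_loss / total_examples,
    }
```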
trainable_variables: An iterable of tf.Variable objects holding the trainable weights.
non_trainable_variables: An iterable of tf.Variable objects holding the non-trainable weights.
local_variables: An iterable of tf.Variable objects holding the local variables (see the class comment for details).
forward_pass( batch_input, training=True )
Runs the forward pass and returns results.
This method must be serializable in a tff.tf_computation (or other backend) decorator. Any pure-Python or unserializable logic will not be runnable in the federated system.
This method should not modify any variables that are part of the model parameters, that is, variables that influence the predictions (exceptions being updated, rather than learned, parameters such as BatchNorm means and variances); rather, updating the model parameters is done by the training loop. However, this method may update aggregated metrics computed across calls to forward_pass; the final values of such metrics can be accessed via report_local_unfinalized_metrics().
Uses in TFF:
- To implement model evaluation.
- To implement federated gradient descent and other non-Federated-Averaging algorithms, where we want the model to run the forward pass and update metrics, but there is no optimizer (we might only compute gradients on the returned loss).
- To implement Federated Averaging.
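The contract above, reading the weights without modifying them and updating only the local metric accumulators, can be sketched in plain Python (hypothetical variable layout and names, not the TFF API):

```python
def forward_pass(weights, local_vars, batch_x, batch_y):
    """Sketch of the forward_pass contract: reads the weights but never
    writes them, updates only the local metric accumulators, and
    returns the batch loss."""
    predictions = [weights['w'] * x + weights['b'] for x in batch_x]
    # Mean squared error over the batch.
    loss = sum((p - y) ** 2 for p, y in zip(predictions, batch_y)) / len(batch_x)
    # Update local variables (metric accumulators), not the weights.
    local_vars['num_examples'] += len(batch_x)
    local_vars['loss_sum'] += loss * len(batch_x)
    return loss
```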
A nested structure of tensors representing the results of the forward pass (e.g., the loss and predictions for the batch).
metric_finalizers() -> MetricFinalizersType
OrderedDict of metric names to finalizers.
This method and the
report_local_unfinalized_metrics() method should have
the same keys (i.e., metric names). A finalizer returned by this method is a function (typically a tf.function decorated callable or a tff.tf_computation decorated TFF computation) that takes in a metric's unfinalized values (returned by report_local_unfinalized_metrics()) and returns the finalized metric values.
This method and the
report_local_unfinalized_metrics() method will be used
together to build a cross-client metrics aggregator. See the documentation of
report_local_unfinalized_metrics() for more information.
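A hedged, plain-Python sketch of that pairing (the state layout here is hypothetical; real implementations would return tf.function-decorated callables over tf.Variable state):

```python
from collections import OrderedDict

def report_local_unfinalized_metrics(state):
    """Unfinalized values for each metric, keyed by metric name."""
    return OrderedDict(accuracy=[state['correct'], state['total']])

def metric_finalizers():
    """Finalizers keyed identically to the unfinalized metrics: each
    takes a metric's unfinalized values and returns the final value."""
    return OrderedDict(accuracy=lambda values: values[0] / values[1])
```

The key invariant is that the two OrderedDicts share the same keys, so each metric's unfinalized values can be piped into its finalizer.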
predict_on_batch( batch_input, training=True )
Returns tensors representing values aggregated over previous calls to forward_pass.
In federated learning, the values returned by this method will typically be further aggregated across clients and made available on the server.
This method returns results from aggregating across all previous calls to forward_pass, most typically metrics like accuracy and loss. If needed, we may add a clear_aggregated_outputs method, which would likely just run the initializers on the local variables.
In general, the tensors returned can be an arbitrary function of all tf.Variables of this model, not just the local variables; for example, this could return tensors measuring the total L2 norm of the model (which might have been updated by training).
This method may return arbitrarily shaped tensors, not just scalar metrics. For example, it could return the average feature vector or a count of how many times each feature exceeds a certain magnitude.
A structure of tensors (as supported by tf.nest).
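For instance, the L2-norm example mentioned above could look like the following plain-Python sketch (math.sqrt and nested lists standing in for the corresponding TensorFlow ops and tensors):

```python
import math

def model_l2_norm(variables):
    """Sketch of a reported value that is a function of all model
    variables: the total L2 norm over every weight tensor."""
    squares = sum(v ** 2 for tensor in variables for v in tensor)
    return math.sqrt(squares)
```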
report_local_unfinalized_metrics() -> OrderedDict[str, Any]
OrderedDict of metric names to unfinalized values.
For a metric, its unfinalized values are given as a structure (typically a
list) of tensors representing values from aggregating over all previous
forward_pass calls. For a Keras metric, its unfinalized values are
typically the tensor values of its state variables. In general, the tensors
can be an arbitrary function of all the
tf.Variables of this model.
The metric names returned by this method should be the same as those expected by metric_finalizers(): one should be able to use the unfinalized values as input to the finalizers to get the finalized values.
Taking tf.keras.metrics.CategoricalAccuracy as an example, its unfinalized values can be a list of two tensors (from its state variables): a total and a count, and the finalizer function performs a safe division total / count.
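That finalizer can be sketched in plain Python (a safe division guarding against a zero count; the [total, count] layout follows the example above):

```python
def categorical_accuracy_finalizer(unfinalized):
    """Sketch of a finalizer for CategoricalAccuracy-style state:
    unfinalized is [total, count]; performs a safe division."""
    total, count = unfinalized
    return total / count if count else 0.0
```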
In federated learning, this method returns the local results from clients,
which will typically be further aggregated across clients and made available
on the server. This method and the
metric_finalizers() method will be used
together to build a cross-client metrics aggregator. For example, a simple
"sum_then_finalize" aggregator will first sum the unfinalized metric values
from clients, and then call the finalizer functions at the server.
Because both this method and the metric_finalizers() method are defined in a per-metric manner, users have the flexibility to call the finalizers at the clients or at the server for different metrics. Users also have the freedom to define a cross-client metrics aggregator that aggregates a single metric in multiple ways.
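As a hedged, plain-Python illustration of the "sum_then_finalize" scheme described above (no TFF dependency; the function names are hypothetical):

```python
from collections import OrderedDict

def sum_then_finalize(client_unfinalized, finalizers):
    """Sketch of a 'sum_then_finalize' cross-client aggregator:
    element-wise sum the unfinalized values from every client, then
    apply each metric's finalizer to the summed values."""
    summed = OrderedDict()
    for name in finalizers:
        per_client = [client[name] for client in client_unfinalized]
        summed[name] = [sum(values) for values in zip(*per_client)]
    return OrderedDict(
        (name, fn(summed[name])) for name, fn in finalizers.items())
```

Summing [total, count] pairs before dividing gives the example-weighted accuracy across clients, which is usually what is wanted; finalizing at each client first would instead yield an unweighted per-client mean.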