![]() |
Represents a reconstruction model for use in Tensorflow Federated.
tff.learning.reconstruction.Model
s are used to train models that reconstruct
a set of their variables on device, never sharing those variables with the
server.
Each tff.learning.reconstruction.Model
will work on a set of tf.Variables
,
and each method should be a computation that can be implemented as a
tf.function
; this implies the class should essentially be stateless from a
Python perspective, as each method will generally only be traced once (per set
of arguments) to create the corresponding TensorFlow graph functions. Thus,
tff.learning.reconstruction.Model
instances should behave as expected in
both eager and graph (TF 1.0) usage.
In general, tf.Variables
may be either:
- Weights, the variables needed to make predictions with the model.
- Local variables, e.g. to accumulate aggregated metrics across calls to forward_pass.
The weights can be broken down into:
- Global variables: Variables that are allowed to be aggregated on the server.
- Local variables: Variables that cannot leave the device.
Furthermore, both of these types of variables can be:
- Trainable variables: These can and should be trained using gradient-based methods.
- Non-trainable variables: Could include fixed pre-trained layers or static model data.
These variables are provided via:
global_trainable_variables
global_non_trainable_variables
local_trainable_variables
local_non_trainable_variables
properties, and must be initialized by the user of the
tff.learning.reconstruction.Model
.
While training a reconstruction model, global trainable variables will generally be provided by the server. Local trainable variables will then be reconstructed locally. Updates to the global trainable variables will be sent back to the server. Local variables are not transmitted.
All tf.Variables
should be introduced in __init__
; this could move to a
build
method more inline with Keras (see
https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) in
the future.
Attributes | |
---|---|
global_non_trainable_variables
|
An iterable of tf.Variable objects, see class comment for details.
|
global_trainable_variables
|
An iterable of tf.Variable objects, see class comment for details.
|
input_spec
|
The type specification of the batch_input parameter for forward_pass .
A nested structure of |
local_non_trainable_variables
|
An iterable of tf.Variable objects, see class comment for details.
|
local_trainable_variables
|
An iterable of tf.Variable objects, see class comment for details.
|
Methods
forward_pass
@abc.abstractmethod
forward_pass( batch_input, training=True )
Runs the forward pass and returns results.
This method should not modify any variables that are part of the model parameters, that is, variables that influence the predictions. Rather, this is done by the training loop.
Args | |
---|---|
batch_input
|
A nested structure that matches the structure of
Model.input_spec and each tensor in batch_input satisfies
tf.TensorSpec.is_compatible_with() for the corresponding
tf.TensorSpec in Model.input_spec .
|
training
|
If True , run the training forward pass, otherwise, run in
evaluation mode. The semantics are generally the same as the training
argument to keras.Model.__call__ ; this might e.g. influence how
dropout or batch normalization is handled.
|
Returns | |
---|---|
A BatchOutput object.
|