tff.learning.models.ReconstructionModel

Represents a reconstruction model for use in TensorFlow Federated.

tff.learning.models.ReconstructionModels are used to train models that reconstruct a set of their variables on device, never sharing those variables with the server.

Each tff.learning.models.ReconstructionModel works on a set of tf.Variables, and each method should be a computation that can be implemented as a tf.function. This implies the class should essentially be stateless from a Python perspective, because each method will generally be traced only once (per set of arguments) to create the corresponding TensorFlow graph functions. Thus, tff.learning.models.ReconstructionModel instances should behave as expected in both eager and graph (TF 1.0) usage.

In general, tf.Variables may be either:

  • Weights, the variables needed to make predictions with the model.
  • Local variables, e.g. to accumulate aggregated metrics across calls to forward_pass.

The weights can be broken down into:

  • Global variables: Variables that are allowed to be aggregated on the server.
  • Local variables: Variables that cannot leave the device.

Furthermore, both of these types of variables can be:

  • Trainable variables: These can and should be trained using gradient-based methods.
  • Non-trainable variables: Could include fixed pre-trained layers or static model data.

These variables are provided via the following properties:

  • global_trainable_variables
  • global_non_trainable_variables
  • local_trainable_variables
  • local_non_trainable_variables

They must be initialized by the user of the tff.learning.models.ReconstructionModel.
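The four-way partition described above can be pictured with a small plain-Python sketch. It is not TFF code: VariablePartition and its methods are hypothetical names, strings stand in for tf.Variable objects, and the matrix-factorization-style variable names are only an illustrative assumption.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class VariablePartition:
    """Hypothetical container mirroring the four variable properties."""

    global_trainable: List[str] = field(default_factory=list)
    global_non_trainable: List[str] = field(default_factory=list)
    local_trainable: List[str] = field(default_factory=list)
    local_non_trainable: List[str] = field(default_factory=list)

    def all_names(self) -> List[str]:
        # Concatenate the four groups in a fixed order.
        return (self.global_trainable + self.global_non_trainable
                + self.local_trainable + self.local_non_trainable)

    def is_disjoint(self) -> bool:
        # Each variable must belong to exactly one group.
        names = self.all_names()
        return len(names) == len(set(names))

    def may_leave_device(self) -> List[str]:
        # Only global variables are allowed to be aggregated on the server.
        return self.global_trainable + self.global_non_trainable


# Illustrative names only; strings stand in for tf.Variable objects.
partition = VariablePartition(
    global_trainable=["item_embeddings"],
    global_non_trainable=["item_id_vocabulary"],
    local_trainable=["user_embedding"],
)
print(partition.is_disjoint())
print(partition.may_leave_device())
```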

While training a reconstruction model, global trainable variables will generally be provided by the server. Local trainable variables will then be reconstructed locally. Updates to the global trainable variables will be sent back to the server. Local variables are not transmitted.
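The training flow described above can be sketched with a toy model in plain Python. This is an illustration under stated assumptions, not the TFF training loop: the model is prediction = w * x + b with a single global weight w and a single local variable b, the loss is squared error, the local variable is assumed to start at zero, and client_round_sketch and its arguments are hypothetical names.

```python
from typing import List, Tuple

Example = Tuple[float, float]  # (x, y) pairs; a stand-in for batches


def client_round_sketch(global_w: float,
                        recon_data: List[Example],
                        post_recon_data: List[Example],
                        lr: float = 0.1) -> float:
    """Runs one toy client round and returns only the global update."""
    # Reconstruction: the local variable starts fresh on the device and is
    # trained while the server-provided global weight stays fixed.
    local_b = 0.0
    for x, y in recon_data:
        error = (global_w * x + local_b) - y
        local_b -= lr * 2.0 * error        # gradient of error**2 w.r.t. b

    # Post-reconstruction: the global weight is trained while the
    # reconstructed local variable stays fixed.
    w = global_w
    for x, y in post_recon_data:
        error = (w * x + local_b) - y
        w -= lr * 2.0 * error * x          # gradient of error**2 w.r.t. w

    # Only the update to the global weight leaves the device; local_b is
    # discarded and never transmitted.
    return w - global_w


delta = client_round_sketch(2.0, [(1.0, 3.0)], [(1.0, 3.0)])
print(delta)
```

The function deliberately returns nothing but the global delta, which mirrors the statement that local variables are not transmitted.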

All tf.Variables should be introduced in __init__; this could move to a build method more in line with Keras (see https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) in the future.

global_non_trainable_variables An iterable of tf.Variable objects, see class comment for details.
global_trainable_variables An iterable of tf.Variable objects, see class comment for details.
input_spec The type specification of the batch_input parameter for forward_pass.

A nested structure of tf.TensorSpec objects, that matches the structure of arguments that will be passed as the batch_input argument of forward_pass. The tensors must include a batch dimension as the first dimension, but the batch dimension may be undefined.
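The remark that the batch dimension may be undefined can be illustrated with a simplified shape check. This is a plain-Python stand-in for only the shape part of a spec compatibility test; it ignores dtypes and nested structures, and shape_is_compatible is a hypothetical helper rather than a TensorFlow function.

```python
from typing import Optional, Sequence


def shape_is_compatible(spec_shape: Sequence[Optional[int]],
                        tensor_shape: Sequence[int]) -> bool:
    """Returns True if a concrete shape fits a spec shape.

    A None entry in spec_shape (for example an undefined batch dimension)
    matches any size in that position.
    """
    if len(spec_shape) != len(tensor_shape):
        return False
    return all(s is None or s == t for s, t in zip(spec_shape, tensor_shape))


# The leading None is the undefined batch dimension.
feature_spec_shape = (None, 784)
print(shape_is_compatible(feature_spec_shape, (32, 784)))
print(shape_is_compatible(feature_spec_shape, (32, 10)))
```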

local_non_trainable_variables An iterable of tf.Variable objects, see class comment for details.
local_trainable_variables An iterable of tf.Variable objects, see class comment for details.

Methods

build_dataset_split_fn

Builds a ReconstructionDatasetSplitFn for training/evaluation.

The returned ReconstructionDatasetSplitFn parameterizes training and evaluation computations and enables reconstruction for multiple local epochs, multiple epochs of post-reconstruction training, limiting the number of steps for both stages, and splitting client datasets into disjoint halves for each stage.

Note that the returned function is used during both training and evaluation: during training, "post-reconstruction" refers to training of global variables, and during evaluation, it refers to calculation of metrics using reconstructed local variables and fixed global variables.

Args
recon_epochs The integer number of iterations over the dataset to make during reconstruction.
recon_steps_max If not None, the integer maximum number of steps (batches) to iterate through during reconstruction. This maximum number of steps is across all reconstruction iterations, i.e. it is applied after recon_epochs. If None, this has no effect.
post_recon_epochs The integer constant number of iterations to make over client data after reconstruction.
post_recon_steps_max If not None, the integer maximum number of steps (batches) to iterate through after reconstruction. This maximum number of steps is across all post-reconstruction iterations, i.e. it is applied after post_recon_epochs. If None, this has no effect.
split_dataset If True, splits client_dataset in half for each user, using even-indexed entries in reconstruction and odd-indexed entries after reconstruction. If False, client_dataset is used for both reconstruction and post-reconstruction, with the above arguments applied. When True, splitting requires that multiple iterations through the dataset yield the same ordering; for example, if client_dataset.shuffle(reshuffle_each_iteration=True) has been called, the split datasets may overlap. Also when True, the dataset should have more than one batch for reasonable results, since the splitting does not occur within batches.

Returns
A ReconstructionDatasetSplitFn.
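How the arguments interact can be sketched in plain Python, with a list of batches standing in for a tf.data.Dataset. This is a minimal illustration of the documented semantics (even/odd split, then epochs, then the step cap), not the TFF implementation; build_split_fn_sketch is a hypothetical name.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Batches = List[str]


def build_split_fn_sketch(
    recon_epochs: int = 1,
    recon_steps_max: Optional[int] = None,
    post_recon_epochs: int = 1,
    post_recon_steps_max: Optional[int] = None,
    split_dataset: bool = False,
) -> Callable[[Sequence[str]], Tuple[Batches, Batches]]:
    """Returns a function that splits client batches into two stages."""

    def split_fn(client_batches: Sequence[str]) -> Tuple[Batches, Batches]:
        batches = list(client_batches)
        if split_dataset:
            # Even-indexed batches for reconstruction, odd-indexed after it.
            recon, post_recon = batches[0::2], batches[1::2]
        else:
            recon, post_recon = batches, batches
        # Repeat for the requested number of epochs.
        recon = recon * recon_epochs
        post_recon = post_recon * post_recon_epochs
        # The step caps apply across all epochs, i.e. after repeating.
        if recon_steps_max is not None:
            recon = recon[:recon_steps_max]
        if post_recon_steps_max is not None:
            post_recon = post_recon[:post_recon_steps_max]
        return recon, post_recon

    return split_fn


split_fn = build_split_fn_sketch(
    recon_epochs=2, recon_steps_max=3, split_dataset=True)
recon, post_recon = split_fn(["b0", "b1", "b2", "b3"])
print(recon)       # ['b0', 'b2', 'b0']
print(post_recon)  # ['b1', 'b3']
```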

forward_pass

Runs the forward pass and returns results.

This method should not modify any variables that are part of the model parameters, that is, variables that influence the predictions. Rather, this is done by the training loop.

Args
batch_input A nested structure that matches the structure of Model.input_spec, where each tensor in batch_input satisfies tf.TensorSpec.is_compatible_with() for the corresponding tf.TensorSpec in Model.input_spec.
training If True, run the training forward pass, otherwise, run in evaluation mode. The semantics are generally the same as the training argument to keras.Model.__call__; this might e.g. influence how dropout or batch normalization is handled.

Returns
A ReconstructionBatchOutput object.

from_keras_model_and_layers

Builds a tff.learning.models.ReconstructionModel from a tf.keras.Model.

The tff.learning.models.ReconstructionModel returned by this function uses keras_model for its forward pass and autodifferentiation steps. During reconstruction, variables in local_layers are initialized and trained. Post-reconstruction, variables in global_layers are trained and aggregated on the server. All variables must be partitioned between global and local layers, without overlap.

Args
keras_model A tf.keras.Model object that is not compiled.
global_layers Iterable of global layers to be aggregated across users. All trainable and non-trainable model variables that can be aggregated on the server should be included in these layers.
local_layers Iterable of local layers not shared with the server. All trainable and non-trainable model variables that should not be aggregated on the server should be included in these layers.
input_spec A structure of tf.TensorSpecs specifying the type of arguments the model expects. Notice this must be a compound structure of two elements, specifying both the data fed into the model to generate predictions, as its first element, as well as the expected type of the ground truth as its second.

Returns
A tff.learning.models.ReconstructionModel object.

Raises
TypeError If keras_model is not an instance of tf.keras.Model.
ValueError If keras_model was compiled, if input_spec has unexpected structure (e.g., has more than two elements), if global_layers or local_layers contains layers that are not in keras_model, or if global_layers and local_layers are not disjoint in their variables.

from_keras_model_and_variables

Builds a tff.learning.models.ReconstructionModel from a tf.keras.Model.

The tff.learning.models.ReconstructionModel returned by this function uses keras_model for its forward pass and autodifferentiation steps. During reconstruction, variables in local_trainable_variables are initialized and trained, and variables in local_non_trainable_variables are initialized. Post-reconstruction, variables in global_trainable_variables are trained and aggregated on the server. All keras_model variables must be partitioned between global_trainable_variables, global_non_trainable_variables, local_trainable_variables, and local_non_trainable_variables, without overlap.

Args
keras_model A tf.keras.Model object that is not compiled.
global_trainable_variables The trainable variables to associate with the post-reconstruction phase.
global_non_trainable_variables The non-trainable variables to associate with the post-reconstruction phase.
local_trainable_variables The trainable variables to associate with the reconstruction phase.
local_non_trainable_variables The non-trainable variables to associate with the reconstruction phase.
input_spec A structure of tf.TensorSpecs specifying the type of arguments the model expects. Notice this must be a compound structure of two elements, specifying both the data fed into the model to generate predictions, as its first element, as well as the expected type of the ground truth as its second.

Returns
A tff.learning.models.ReconstructionModel object.

Raises
TypeError If keras_model is not an instance of tf.keras.Model.
ValueError If keras_model was compiled, if keras_model is not already built, if input_spec has unexpected structure (e.g., has more than two elements), if global_layers or local_layers contains layers that are not in keras_model, or if global_layers and local_layers are not disjoint in their variables.

get_global_variables

Gets global variables from a Model as ModelWeights.

get_local_variables

Gets local variables from a Model as ModelWeights.

has_only_global_variables

Returns True if the model has no local variables.

read_metric_variables

Reads values from Keras metric variables.

simple_dataset_split_fn

A ReconstructionDatasetSplitFn that returns the original client data.

Both the reconstruction data and post-reconstruction data will result from iterating over the same tf.data.Dataset. Note that depending on any preprocessing steps applied to client tf.data.Datasets, this may not produce exactly the same data in the same order for both reconstruction and post-reconstruction. For example, if client_dataset.shuffle(reshuffle_each_iteration=True) was applied, post-reconstruction data will be in a different order than reconstruction data.

Args
client_dataset tf.data.Dataset representing client data.

Returns
A tuple of two tf.data.Datasets, the first to be used for reconstruction, the second to be used for post-reconstruction.
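The ordering caveat can be sketched without TensorFlow. In the illustration below, ReshufflingDataset is a hypothetical stand-in for a dataset that had shuffle(reshuffle_each_iteration=True) applied, and simple_split_sketch is a hypothetical plain-Python analogue of returning the same client data for both stages.

```python
import random
from typing import Iterable, Iterator, List, Tuple


class ReshufflingDataset:
    """Hypothetical stand-in for a dataset reshuffled on each iteration."""

    def __init__(self, items: Iterable[int], seed: int = 0) -> None:
        self._items: List[int] = list(items)
        self._rng = random.Random(seed)

    def __iter__(self) -> Iterator[int]:
        # A fresh shuffle is drawn every time iteration starts.
        order = self._items[:]
        self._rng.shuffle(order)
        return iter(order)


def simple_split_sketch(
    client_dataset: ReshufflingDataset,
) -> Tuple[ReshufflingDataset, ReshufflingDataset]:
    # Both stages iterate over the same underlying data.
    return client_dataset, client_dataset


recon_data, post_recon_data = simple_split_sketch(ReshufflingDataset(range(5)))
# Same elements in both stages, but the order may differ between passes.
print(list(recon_data))
print(list(post_recon_data))
```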