Composing Learning Algorithms

Before you start

Before you start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.

!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
from typing import Callable

import tensorflow as tf
import tensorflow_federated as tff

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()

Composing Learning Algorithms

The Building Your Own Federated Learning Algorithm Tutorial used TFF's federated core to directly implement a version of the Federated Averaging (FedAvg) algorithm.

In this tutorial, you will use federated learning components in TFF's API to build federated learning algorithms in a modular manner, without having to re-implement everything from scratch.

For the purposes of this tutorial, you will implement a variant of FedAvg that applies gradient clipping during local training.

Learning Algorithm Building Blocks

At a high level, many learning algorithms can be decomposed into four components, referred to as building blocks:

  1. Distributor (i.e., server-to-client communication)
  2. Client work (i.e., local client computation)
  3. Aggregator (i.e., client-to-server communication)
  4. Finalizer (i.e., server computation using aggregated client outputs)
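To make the division of labor concrete, here is a minimal, framework-free sketch of one FedAvg-style round in plain Python. It assumes that model weights are just a list of floats and that each client's local training is a toy callable supplied by the caller. The function names (distribute, client_work, aggregate, finalize, run_round) are hypothetical and are not TFF APIs. TFF's real building blocks also handle placement, types, and state, and this sketch omits all of that.

```python
from typing import Callable, List, Tuple

Weights = List[float]
# A hypothetical local trainer: takes starting weights and returns
# (trained weights, number of examples seen).
LocalTrainer = Callable[[Weights], Tuple[Weights, float]]


def distribute(server_weights: Weights, num_clients: int) -> List[Weights]:
  """Distributor: send a copy of the server weights to every client."""
  return [list(server_weights) for _ in range(num_clients)]


def client_work(weights: Weights,
                local_train: LocalTrainer) -> Tuple[Weights, float]:
  """Client work: train locally, return (server - client delta, weight)."""
  trained, num_examples = local_train(list(weights))
  delta = [w0 - w1 for w0, w1 in zip(weights, trained)]
  return delta, num_examples


def aggregate(results: List[Tuple[Weights, float]]) -> Weights:
  """Aggregator: example-weighted mean of the client deltas."""
  total = sum(weight for _, weight in results)
  dim = len(results[0][0])
  return [sum(delta[i] * weight for delta, weight in results) / total
          for i in range(dim)]


def finalize(server_weights: Weights, mean_delta: Weights,
             server_lr: float) -> Weights:
  """Finalizer: apply the averaged delta like a gradient (server SGD)."""
  return [w - server_lr * d for w, d in zip(server_weights, mean_delta)]


def run_round(server_weights: Weights, trainers: List[LocalTrainer],
              server_lr: float = 1.0) -> Weights:
  """One round: distributor -> client work -> aggregator -> finalizer."""
  copies = distribute(server_weights, len(trainers))
  results = [client_work(w, train) for w, train in zip(copies, trainers)]
  return finalize(server_weights, aggregate(results), server_lr)


# Two toy clients that ignore their input and return fixed weights:
# one saw 1 example, the other saw 3.
trainers = [lambda w: ([1.0, 1.0], 1.0), lambda w: ([3.0, 3.0], 3.0)]
new_weights = run_round([0.0, 0.0], trainers)
print(new_weights)  # [2.5, 2.5]
```

With a server learning rate of 1.0, the new server weights equal the example-weighted average of the client models: (1 * 1.0 + 3 * 3.0) / 4 = 2.5. Swapping out only client_work, as this tutorial does, leaves the other three functions untouched.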

While the Building Your Own Federated Learning Algorithm Tutorial implemented all of these building blocks from scratch, this is often unnecessary. Instead, you can re-use building blocks from similar algorithms.

In this case, to implement FedAvg with gradient clipping, you only need to modify the client work building block. The remaining blocks can be identical to what is used in "vanilla" FedAvg.

Implementing the Client Work

First, let's write TF logic that does local model training with gradient clipping. For simplicity, gradients will be clipped to have norm at most 1.
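Clipping here means rescaling the entire gradient by its global norm, which is the L2 norm of all gradient entries taken together. This shortens the gradient without changing its direction. The plain-Python sketch below assumes gradients are represented as a list of flat float lists, and its helper names are hypothetical. In TensorFlow, tf.linalg.global_norm computes this norm. tf.clip_by_global_norm(grads, 1.0) returns the gradients rescaled the same way, along with the norm.

```python
import math
from typing import List

Grads = List[List[float]]  # one flat list per (hypothetical) gradient tensor


def global_norm(grads: Grads) -> float:
  """L2 norm of all gradient entries taken together."""
  return math.sqrt(sum(g * g for tensor in grads for g in tensor))


def clip_by_global_norm(grads: Grads, max_norm: float = 1.0) -> Grads:
  """Rescale grads so that their global norm is at most max_norm."""
  norm = global_norm(grads)
  if norm <= max_norm:
    return grads  # already small enough; leave unchanged
  return [[g * max_norm / norm for g in tensor] for tensor in grads]


grads = [[3.0], [4.0]]            # global norm is sqrt(9 + 16) = 5
clipped = clip_by_global_norm(grads)
print(clipped)                    # [[0.6], [0.8]]
```

With max_norm set to 1, this is the same rule as dividing by the norm whenever it exceeds 1.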

TF Logic

@tf.function
def client_update(model: tff.learning.Model,
                  dataset: tf.data.Dataset,
                  server_weights: tff.learning.ModelWeights,
                  client_optimizer: tf.keras.optimizers.Optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = tff.learning.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  # Keep track of the number of examples as well.
  num_examples = 0.0
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)
      num_examples += tf.cast(outputs.num_examples, tf.float32)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights.trainable)

    # Compute the global gradient norm; if it exceeds 1, rescale the
    # gradients so that their global norm is exactly 1.
    gradient_norm = tf.linalg.global_norm(grads)
    if gradient_norm > 1:
      grads = tf.nest.map_structure(lambda x: x/gradient_norm, grads)

    grads_and_vars = zip(grads, client_weights.trainable)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  # Compute the difference between the server weights and the client weights.
  # The server-side finalizer applies this difference like a gradient.
  client_update = tf.nest.map_structure(tf.subtract,
                                        server_weights.trainable,
                                        client_weights.trainable)

  return tff.learning.templates.ClientResult(
      update=client_update, update_weight=num_examples)

There are a few important points about the code above. First, it keeps track of the number of examples seen, as this will constitute the weight of the client update (when computing an average across clients).

Second, it uses tff.learning.templates.ClientResult to package the output. This return type is used to standardize client work building blocks in tff.learning.
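Each ClientResult carries an update_weight equal to the number of examples, so clients with more data pull the average harder. Below is a small sketch of that example-weighted mean. It uses a hypothetical ToyClientResult named tuple that only mirrors ClientResult's two fields and is not the TFF class.

```python
from typing import List, NamedTuple


class ToyClientResult(NamedTuple):
  """Hypothetical stand-in mirroring the two fields of ClientResult."""
  update: List[float]     # server weights minus client weights
  update_weight: float    # number of examples seen by the client


def weighted_mean(results: List[ToyClientResult]) -> List[float]:
  """Average client updates, weighting each by its number of examples."""
  total_weight = sum(r.update_weight for r in results)
  dim = len(results[0].update)
  return [
      sum(r.update[i] * r.update_weight for r in results) / total_weight
      for i in range(dim)
  ]


results = [
    ToyClientResult(update=[0.5, -1.0], update_weight=10.0),
    ToyClientResult(update=[1.5, 1.0], update_weight=30.0),
]
print(weighted_mean(results))  # [1.25, 0.5]
```

An unweighted mean of the same two updates would be [1.0, 0.0]. Weighting by example count moves the result toward the client that saw 30 examples.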

Creating a ClientWorkProcess

While the TF logic above will do local training with clipping, it still needs to be wrapped in TFF code in order to create the necessary building block.

Specifically, each of the 4 building blocks is represented as a tff.templates.MeasuredProcess. This means that every block has both an initialize and a next function, used to instantiate and run the computation.

This allows each building block to keep track of its own state (stored at the server) as needed to perform its operations. While it will not be used in this tutorial, it can be used for things like tracking how many iterations have occurred, or keeping track of optimizer states.
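The initialize/next contract can be illustrated without TFF. In the sketch below, ToyMeasuredOutput and ToyCountingProcess are hypothetical stand-ins. ToyMeasuredOutput only mirrors the state, result, and measurements fields of tff.templates.MeasuredProcessOutput. As in TFF, the state is not stored inside the process object. Instead, initialize returns it and it is threaded explicitly through each call to next.

```python
from typing import Any, NamedTuple


class ToyMeasuredOutput(NamedTuple):
  """Hypothetical analogue of tff.templates.MeasuredProcessOutput."""
  state: Any
  result: Any
  measurements: Any


class ToyCountingProcess:
  """A toy stateful process whose state counts calls to next."""

  def initialize(self) -> int:
    return 0

  def next(self, state: int, value: float) -> ToyMeasuredOutput:
    # The "computation" doubles its input; the state is advanced and
    # returned rather than stored on self.
    return ToyMeasuredOutput(state=state + 1,
                             result=2.0 * value,
                             measurements={'input': value})


process = ToyCountingProcess()
state = process.initialize()
for value in [1.0, 2.0, 3.0]:
  output = process.next(state, value)
  state = output.state
print(state, output.result)  # 3 6.0
```

Here the state counts calls to next, much like the iteration counter a stateful building block might keep.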

Client work TF logic should generally be wrapped as a tff.learning.templates.ClientWorkProcess, which codifies the expected types going into and out of the client's local training. It can be parameterized by a model and optimizer, as below.

def build_gradient_clipping_client_work(
    model_fn: Callable[[], tff.learning.Model],
    optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
) -> tff.learning.templates.ClientWorkProcess:
  """Creates a client work process that uses gradient clipping."""

  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    model = model_fn()
  data_type = tff.SequenceType(model.input_spec)
  model_weights_type = tff.learning.framework.weights_type_from_model(model)

  @tff.federated_computation
  def initialize_fn():
    # This client work is stateless, so its server state is an empty tuple.
    return tff.federated_value((), tff.SERVER)

  @tff.tf_computation(model_weights_type, data_type)
  def client_update_computation(model_weights, dataset):
    model = model_fn()
    optimizer = optimizer_fn()
    return client_update(model, dataset, model_weights, optimizer)

  @tff.federated_computation(
      initialize_fn.type_signature.result,
      tff.type_at_clients(model_weights_type),
      tff.type_at_clients(data_type))
  def next_fn(state, model_weights, client_dataset):
    client_result = tff.federated_map(
        client_update_computation, (model_weights, client_dataset))
    # Return empty measurements, though a more complete algorithm might
    # measure something here.
    measurements = tff.federated_value((), tff.SERVER)
    return tff.templates.MeasuredProcessOutput(state, client_result,
                                               measurements)

  return tff.learning.templates.ClientWorkProcess(
      initialize_fn, next_fn)

Composing a Learning Algorithm

Let's put the client work above into a full-fledged algorithm. First, let's set up our data and model.

Preparing the input data

Load and preprocess the EMNIST dataset included in TFF. For more details, see the image classification tutorial.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In order to feed the dataset into our model, the data is flattened and converted into tuples of the form (flattened_image_vector, label).

Let's select a small number of clients, and apply the preprocessing above to their datasets.


NUM_CLIENTS = 10  # Example values; adjust as needed.
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]
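The shape logic of preprocess can be mirrored in plain Python: batch first, then flatten each 28x28 image into 784 values and give each label its own trailing dimension. This is a sketch only. The helpers batch and batch_format are hypothetical, and the toy elements are dicts of nested lists standing in for EMNIST's 'pixels' and 'label' features.

```python
from typing import Any, Dict, List, Tuple

Element = Dict[str, Any]  # hypothetical stand-in for one EMNIST example


def batch(elements: List[Element], batch_size: int) -> List[List[Element]]:
  """Group consecutive elements; the last batch may be smaller."""
  return [elements[i:i + batch_size]
          for i in range(0, len(elements), batch_size)]


def batch_format(batch_elems: List[Element]
                 ) -> Tuple[List[List[float]], List[List[int]]]:
  """Flatten each 28x28 image to 784 values and wrap each label as [label]."""
  features = [[p for row in e['pixels'] for p in row] for e in batch_elems]
  labels = [[e['label']] for e in batch_elems]
  return features, labels


# Five blank 28x28 "images" with labels 0..4, in batches of 2.
toy = [{'pixels': [[0.0] * 28 for _ in range(28)], 'label': k}
       for k in range(5)]
batches = [batch_format(b) for b in batch(toy, batch_size=2)]
features, labels = batches[0]
print(len(batches), len(features), len(features[0]), labels)
# 3 2 784 [[0], [1]]
```

As with dataset.batch, the final batch can be smaller than the batch size when the number of examples is not a multiple of it.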

Preparing the model

This uses the same model as in the image classification tutorial. This model (implemented via tf.keras) has a single dense layer followed by a softmax layer. In order to use this model in TFF, the Keras model is wrapped as a tff.learning.Model. This allows us to perform the model's forward pass within TFF and extract model outputs. For more details, see the image classification tutorial.

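def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])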

Preparing the optimizers

Just as in tff.learning.build_federated_averaging_process, there are two optimizers here: a client optimizer and a server optimizer. For simplicity, both optimizers will be SGD, with different learning rates.

client_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=0.01)
server_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
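Here is a scalar sketch of how the two learning rates interact. It assumes, as in client_update, that each client reports server weights minus client weights, and that the server optimizer applies the averaged update as if it were a gradient. The client learning rate (0.01 above) controls how far clients move locally. With a server learning rate of 1.0, one server SGD step lands exactly on the averaged client weights, which recovers plain FedAvg. The function name server_sgd_step and the numbers are illustrative only. The clients are assumed to have equal example counts, so a plain mean suffices.

```python
def server_sgd_step(server_w: float, mean_update: float,
                    server_lr: float) -> float:
  """Apply the averaged (server - client) update as if it were a gradient."""
  return server_w - server_lr * mean_update


server_w = 2.0
client_ws = [1.0, 1.5]  # toy client weights after local training
mean_client_w = sum(client_ws) / len(client_ws)  # 1.25
mean_update = server_w - mean_client_w           # 0.75

print(server_sgd_step(server_w, mean_update, server_lr=1.0))  # 1.25
print(server_sgd_step(server_w, mean_update, server_lr=0.5))  # 1.625
```

A server learning rate below 1.0 moves the server only part of the way toward the client average.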

Defining the building blocks

Now that the client work building block, data, model, and optimizers are set up, it remains to create building blocks for the distributor, the aggregator, and the finalizer. This can be done by borrowing defaults that are available in TFF and used by FedAvg.

@tff.tf_computation()
def initial_model_weights_fn():
  return tff.learning.ModelWeights.from_model(model_fn())

model_weights_type = initial_model_weights_fn.type_signature.result

distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(model_fn, client_optimizer_fn)

# TFF aggregators use a factory pattern, which creates an aggregator
# based on the output type of the client work. This also uses a float (the number
# of examples) to govern the weight in the average being computed.
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(model_weights_type.trainable,
                                       tff.TensorType(tf.float32))
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
    server_optimizer_fn, model_weights_type)

Composing the building blocks

Finally, you can use a built-in composer in TFF for putting the building blocks together. This one is a relatively simple composer, which takes the 4 building blocks above and wires their types together.

fed_avg_with_clipping = tff.learning.templates.compose_learning_process(
    initial_model_weights_fn,
    distributor,
    client_work,
    aggregator,
    finalizer
)
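Conceptually, a composer of this kind keeps one state entry per building block plus the global model weights. It chains the blocks in order each round and reports metrics keyed by block. The toy compose function below is a hypothetical plain-Python sketch of that bookkeeping only. It does not reproduce TFF's type checking or federated placements. Its blocks pass a single value along, rather than the separate weights and updates that the real blocks exchange. The block names echo this tutorial's, and global_model_weights is a hypothetical key.

```python
from typing import Any, Callable, Dict, Tuple

# A toy block is a pair (initialize_fn, next_fn), where
# next_fn(state, value) -> (new_state, new_value, metrics).
Block = Tuple[Callable[[], Any], Callable[[Any, Any], Tuple[Any, Any, Any]]]


def compose(initial_weights_fn: Callable[[], Any], blocks: Dict[str, Block]):
  """Hypothetical composer: chains blocks in insertion order each round."""
  names = list(blocks)

  def initialize() -> Dict[str, Any]:
    state = {name: blocks[name][0]() for name in names}
    state['global_model_weights'] = initial_weights_fn()
    return state

  def next_fn(state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    value = state['global_model_weights']
    new_state: Dict[str, Any] = {}
    metrics: Dict[str, Any] = {}
    for name in names:
      new_state[name], value, metrics[name] = blocks[name][1](state[name], value)
    new_state['global_model_weights'] = value
    return new_state, metrics

  return initialize, next_fn


def stateless(fn: Callable[[Any], Any]) -> Block:
  """A block with empty state and empty metrics that applies fn."""
  return (lambda: (), lambda state, value: (state, fn(value), ()))


def round_counter() -> Block:
  """A pass-through block whose state counts completed rounds."""
  return (lambda: 0, lambda state, value: (state + 1, value, ()))


initialize, next_fn = compose(
    lambda: 1.0,
    {'distributor': stateless(lambda w: w),
     'client_work': stateless(lambda w: w - 0.25),  # toy "training" step
     'aggregator': stateless(lambda w: w),
     'finalizer': round_counter()})

state = initialize()
print(state['finalizer'])  # 0
state, metrics = next_fn(state)
print(state['finalizer'], state['global_model_weights'])  # 1 0.75
print(metrics)
# {'distributor': (), 'client_work': (), 'aggregator': (), 'finalizer': ()}
```

The toy finalizer's counter stands in for the iteration count that the real finalizer keeps. The metrics dict has the same per-block layout as the metrics inspected later in this tutorial.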

Running the algorithm

Now that the algorithm is done, let's run it. First, initialize the algorithm. The state of this algorithm has a component for each building block, along with one for the global model weights.

state = fed_avg_with_clipping.initialize()

state.client_work

As expected, the client work has an empty state (remember the client work code above!). However, other building blocks may have non-empty state. For example, the finalizer keeps track of how many iterations have occurred. Since next has not been run yet, it has a state of 0.

state.finalizer

Now run a training round.

learning_process_output = fed_avg_with_clipping.next(state, federated_train_data)

The output of this (tff.learning.templates.LearningProcessOutput) has both a .state and .metrics output. Let's look at both.

learning_process_output.state.finalizer

The finalizer state has incremented by one, since one round of .next has been run.

learning_process_output.metrics

OrderedDict([('distributor', ()),
             ('client_work', ()),
             ('aggregator',
              OrderedDict([('mean_value', ()), ('mean_weight', ())])),
             ('finalizer', ())])

While the metrics are empty, for more complex and practical algorithms they'll generally be full of useful information.


By using the building blocks and composers framework above, you can create entirely new learning algorithms without having to re-do everything from scratch. However, this is only the starting point. This framework makes it much easier to express algorithms as simple modifications of FedAvg. For more algorithms, see tff.learning.algorithms, which contains algorithms such as FedProx and FedAvg with client learning rate scheduling. These APIs can even aid implementations of entirely new algorithms, such as federated k-means clustering.