
High-performance simulations with TFF

This tutorial describes how to set up high-performance simulations with TFF in a variety of common scenarios.

NOTE: The mechanisms covered here are not included in the latest release, have not been tested yet, and the API may evolve. To follow this tutorial, you will need to build a TFF pip package from the latest sources and install it in a Jupyter notebook with a Python 3 runtime. The new executor stack is not compatible with Python 2.

TODO(b/134543154): Populate the content. Some of the things to cover here:

  • using GPUs in a single-machine setup,
  • multi-machine setup on GCP/GKE, with and without TPUs,
  • interfacing MapReduce-like backends,
  • current limitations and when/how they will be relaxed.

Before we begin

First, make sure your notebook is connected to a backend that has the relevant components (including gRPC dependencies for multi-machine scenarios) compiled.

Now, let's start by loading the MNIST example from the TFF website (which uses the federated EMNIST dataset) and declaring the Python function that will run a small experiment loop over a group of 10 clients.

!pip install --quiet --upgrade tensorflow_federated

# NOTE: Jupyter requires a patch to asyncio.
!pip install --quiet --upgrade nest_asyncio
import nest_asyncio
# Allow TFF's asyncio event loop to run inside Jupyter's already-running loop.
nest_asyncio.apply()

import collections
import time

import tensorflow as tf
import tensorflow_federated as tff


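# Federated EMNIST training data, partitioned by client.
source, _ = tff.simulation.datasets.emnist.load_data()

def map_fn(example):
  # Flatten each 28x28 image into a 784-element vector and keep its label.
  return collections.OrderedDict([
    ('x', tf.reshape(example['pixels'], [-1])),
    ('y', example['label']),
  ])

def client_data(n):
  # Build the preprocessed training dataset for the n-th client.
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).map(map_fn).shuffle(500).batch(20)

# The same 10 clients participate in every round.
train_data = [client_data(n) for n in range(10)]

# A sample batch (as NumPy arrays), used to tell TFF the model's input type.
batch = tf.nest.map_structure(lambda x: x.numpy(), next(iter(train_data[0])))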

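def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
                            kernel_initializer='zeros')
  ])
  # from_compiled_keras_model() expects a compiled model; these compile
  # settings are a simple choice for this example.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return tff.learning.from_compiled_keras_model(model, batch)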

trainer = tff.learning.build_federated_averaging_process(model_fn)

def evaluate(num_rounds=10):
  state = trainer.initialize()
  for _ in range(num_rounds):
    t1 = time.time()
    # Run one round of Federated Averaging over the 10 clients.
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('loss {}, round time {}'.format(metrics.loss, t2 - t1))
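Each call to trainer.next above runs one round of Federated Averaging. Conceptually, the server combines the clients' locally trained weights into an average that, by default, is weighted by how many examples each client trained on. The sketch below illustrates only that aggregation step in plain Python; it is not TFF's implementation, and both the flat-list weight representation and the `federated_average` helper are hypothetical.

```python
from typing import List, Sequence


def federated_average(client_weights: Sequence[Sequence[float]],
                      num_examples: Sequence[int]) -> List[float]:
  """Hypothetical helper: example-weighted mean of flat client weight vectors."""
  if len(client_weights) != len(num_examples):
    raise ValueError('one example count is required per client')
  total = sum(num_examples)
  if total <= 0:
    raise ValueError('total number of examples must be positive')
  dim = len(client_weights[0])
  averaged = [0.0] * dim
  for weights, n in zip(client_weights, num_examples):
    if len(weights) != dim:
      raise ValueError('all clients must report weights of the same size')
    # Each client contributes in proportion to its share of the examples.
    for i, w in enumerate(weights):
      averaged[i] += w * n / total
  return averaged


# Two hypothetical clients; the second trained on 3x as many examples.
print(federated_average([[1.0, 0.0], [3.0, 4.0]], [100, 300]))  # [2.5, 3.0]
```

A client with more data pulls the average further toward its own weights, which is why the result sits closer to the second client's values.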

Single-machine simulations

A simple local multi-threaded executor can be created with a new, currently undocumented framework function, tff.framework.create_local_executor(), and made the default by calling tff.framework.set_default_executor(), as follows.

tff.framework.set_default_executor(tff.framework.create_local_executor())

evaluate()

loss 2.9510040283203125, round time 49.65723657608032
loss 2.777134656906128, round time 45.5357563495636
loss 2.5103652477264404, round time 29.720882892608643
loss 2.2921206951141357, round time 30.4314706325531
loss 2.0617873668670654, round time 32.21593737602234
loss 1.9325430393218994, round time 43.6105010509491
loss 1.7762397527694702, round time 23.19011878967285
loss 1.6028356552124023, round time 25.11474061012268
loss 1.5010586977005005, round time 24.695493936538696
loss 1.4369142055511475, round time 22.34806251525879
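A multi-threaded executor can shorten rounds because the per-client work within a round is independent and can overlap. The following TFF-independent sketch illustrates that idea with Python's concurrent.futures: a hypothetical `train_client` function (which only sleeps, standing in for real work) is run for 10 clients, first on a single worker thread and then on a pool of 10. It is an analogy under those assumptions, not a description of how TFF's executor is implemented.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def train_client(client_id: int, work_seconds: float = 0.05) -> int:
  """Hypothetical stand-in for one client's local training."""
  # time.sleep releases the GIL, so sleeping threads can overlap.
  time.sleep(work_seconds)
  return client_id


def run_round(num_clients: int, max_workers: int) -> float:
  """Runs one simulated round and returns its wall-clock duration in seconds."""
  start = time.perf_counter()
  with ThreadPoolExecutor(max_workers=max_workers) as pool:
    # Consume the results so every client finishes inside the timed region.
    list(pool.map(train_client, range(num_clients)))
  return time.perf_counter() - start


sequential = run_round(num_clients=10, max_workers=1)
threaded = run_round(num_clients=10, max_workers=10)
print('sequential {:.2f}s, threaded {:.2f}s'.format(sequential, threaded))
```

With one worker the round takes roughly the sum of the clients' work; with one thread per client it approaches the duration of the slowest client.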

The reference executor can be reinstalled as the default by calling tff.framework.set_default_executor() without an argument, as follows.

tff.framework.set_default_executor()
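Because the default executor is global state, it can be convenient to restore the reference executor automatically, even if an experiment raises an exception. The sketch below is a hypothetical helper, `temporary_default`, that is not part of TFF: it accepts a setter following the calling convention described above (passing an executor installs it; calling with no argument restores the reference executor) and a zero-argument factory, and it resets the default in a `finally` clause.

```python
import contextlib
from typing import Any, Callable, Iterator


@contextlib.contextmanager
def temporary_default(set_default: Callable[..., Any],
                      factory: Callable[[], Any]) -> Iterator[None]:
  """Hypothetical helper: installs factory() as the default, then resets it.

  Assumes set_default(executor) installs an executor and set_default()
  with no argument reinstalls the reference executor.
  """
  set_default(factory())
  try:
    yield
  finally:
    # Runs even if the body raises, so the global default is always reset.
    set_default()
```

In this tutorial it would be used as `with temporary_default(tff.framework.set_default_executor, tff.framework.create_local_executor): evaluate()`.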


Multi-machine simulations on GCP/GKE, GPUs, TPUs, and beyond...

Coming very soon.