High-performance simulations with TFF

This tutorial describes how to set up high-performance simulations with TFF in a variety of common scenarios.

NOTE: The mechanisms covered here are not included in the latest release, have not been tested yet, and the API may evolve. To follow this tutorial, you will need to build a TFF pip package from the latest sources and install it in a Jupyter notebook with a Python 3 runtime. The new executor stack is not compatible with Python 2.
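Because the new executor stack requires Python 3, a notebook can fail fast on the wrong runtime instead of failing later with a confusing error. This is a minimal sketch using only the standard library; the helper name `require_python3` is hypothetical and not part of TFF.

```python
import sys


def require_python3(version_info=sys.version_info):
    """Raises RuntimeError unless the runtime is Python 3 or newer.

    Hypothetical helper, not part of TFF: the new executor stack does not
    support Python 2, so checking at the top of the notebook gives a clear
    error message early.
    """
    if version_info[0] < 3:
        raise RuntimeError(
            'This tutorial requires a Python 3 runtime; found Python %d.%d.'
            % (version_info[0], version_info[1]))


require_python3()
```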

TODO(b/134543154): Populate the content. Some of the things to cover here:
- using GPUs in a single-machine setup,
- multi-machine setup on GCP/GKE, with and without TPUs,
- interfacing MapReduce-like backends,
- current limitations and when/how they will be relaxed.

Before we begin

First, make sure your notebook is connected to a backend that has the relevant components (including gRPC dependencies for multi-machine scenarios) compiled.
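One quick way to confirm that the backend has the relevant packages installed is to check that their modules can be located. This is a minimal sketch using `importlib.util.find_spec`; the helper name is hypothetical, and the module list (`tensorflow`, `tensorflow_federated`, and `grpc` for the gRPC Python bindings) is an assumption about what a given scenario needs. It only shows that a module is installed, not that it was compiled with any particular feature.

```python
import importlib.util


def missing_modules(names):
    """Returns the subset of `names` that cannot be located on this runtime.

    Hypothetical helper: for top-level module names, find_spec only locates
    the module without importing it, so the check has no side effects.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed module list; `grpc` is only needed for multi-machine scenarios.
required = ['tensorflow', 'tensorflow_federated', 'grpc']
missing = missing_modules(required)
if missing:
    print('Missing modules: {}'.format(', '.join(missing)))
else:
    print('All required modules are available.')
```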

Next, if you are running this notebook in Jupyter, you may need to take an extra step to work around Jupyter's limitations with asyncio: Jupyter already runs an asyncio event loop of its own, and the nest_asyncio package allows that loop to be used in a nested fashion.

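!pip install -q --upgrade nest_asyncio
import nest_asyncio
# Patch the running event loop so that nested run_until_complete() calls work.
nest_asyncio.apply()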
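To see the limitation that nest_asyncio works around, the following standard-library sketch simulates code that is already running inside an event loop (as a Jupyter cell does) and then tries to drive that same loop again with `run_until_complete`. Plain asyncio rejects the re-entrant call with a RuntimeError; `nest_asyncio.apply()` patches the loop so that such nested calls are allowed. The coroutine names are illustrative only.

```python
import asyncio


async def inner_computation():
    """Stands in for work that a library might schedule on the event loop."""
    return 42


async def notebook_cell():
    """Simulates a notebook cell: this code runs inside an active event loop."""
    coro = inner_computation()
    loop = asyncio.get_running_loop()
    try:
        # Re-entering a loop that is already running is rejected by asyncio
        # unless the loop has been patched (e.g. by nest_asyncio.apply()).
        return loop.run_until_complete(coro)
    except RuntimeError as err:
        coro.close()  # The coroutine never ran; close it to avoid a warning.
        return 'RuntimeError: {}'.format(err)


print(asyncio.run(notebook_cell()))
```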

Now, let's start by loading the federated EMNIST dataset used in the TFF examples, and declaring the Python function that will run a small experiment loop. To use all of the data and make sure there is enough work for each round, we partition the data from all users into 10 groups (NUM_CLIENTS) and assign one group to each simulated client.

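!pip install -q --upgrade tensorflow_federated
!pip install -q --upgrade tf-nightly
import collections
import warnings
import time

import tensorflow as tf

import tensorflow_federated as tff

# Silence deprecation warnings emitted by the nightly builds.
warnings.simplefilter('ignore')

# Number of simulated clients; the users are split into this many groups.
NUM_CLIENTS = 10
# Assumed values: these settings are not shown in the original listing.
NUM_EPOCHS = 10
BATCH_SIZE = 20

source, _ = tff.simulation.datasets.emnist.load_data()

def client_data(n):
  """Returns one dataset combining the data of every user in group `n`."""
  ids_per_client = len(source.client_ids) // NUM_CLIENTS
  start = ids_per_client * n
  limit = ids_per_client * (n + 1)
  combined_dataset = None
  while start < limit:
    # Flatten each 28x28 image into a 784-element vector.
    dataset = source.create_tf_dataset_for_client(source.client_ids[start]).map(
        lambda e: {
            'x': tf.reshape(e['pixels'], [-1]),
            'y': e['label'],
        })
    start = start + 1
    if combined_dataset is not None:
      combined_dataset = combined_dataset.concatenate(dataset)
    else:
      combined_dataset = dataset
  return combined_dataset.repeat(NUM_EPOCHS).batch(BATCH_SIZE)

train_data = [client_data(n) for n in range(NUM_CLIENTS)]

# A sample batch (as NumPy arrays) that tells TFF the structure and shapes of the input.
batch = tf.nest.map_structure(lambda x: x.numpy(), next(iter(train_data[0])))

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
                            kernel_initializer='zeros')
  ])
  # from_compiled_keras_model expects a compiled model; the loss, optimizer,
  # and learning rate below are assumptions.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return tff.learning.from_compiled_keras_model(model, batch)

trainer = tff.learning.build_federated_averaging_process(model_fn)

def evaluate(num_rounds=10):
  """Runs `num_rounds` rounds of federated averaging and prints loss and time."""
  state = trainer.initialize()
  for _ in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('loss {}, round time {}'.format(metrics.loss, t2 - t1))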
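The index arithmetic in `client_data` can be checked without TensorFlow. The following standard-library sketch uses hypothetical helpers (`partition_client_ids`, `num_batches`) and made-up user IDs and example counts; it mirrors the contiguous-slice grouping above and the effect of calling `repeat` before `batch`. Note that when the number of users is not divisible by `NUM_CLIENTS`, the leftover users at the end of the list are not assigned to any group.

```python
def partition_client_ids(client_ids, num_groups):
    """Splits client_ids into num_groups contiguous, equally sized groups.

    Mirrors the start/limit arithmetic in client_data(): group n covers
    indices [n * per_group, (n + 1) * per_group). Any remainder is dropped.
    """
    per_group = len(client_ids) // num_groups
    return [client_ids[n * per_group:(n + 1) * per_group]
            for n in range(num_groups)]


def num_batches(num_examples, num_epochs, batch_size):
    """Batches produced by dataset.repeat(num_epochs).batch(batch_size).

    Because repeat() comes before batch(), a batch may span an epoch
    boundary, and only the very last batch can be smaller than batch_size.
    """
    total = num_examples * num_epochs
    return (total + batch_size - 1) // batch_size  # ceiling division


# Made-up example: 23 users split across 10 simulated clients.
ids = ['user_{}'.format(i) for i in range(23)]
groups = partition_client_ids(ids, 10)
print([len(g) for g in groups])      # two users per group
print(sum(len(g) for g in groups))   # 20 of the 23 users are used

# Made-up example: a group holding 95 examples, 10 epochs, batches of 20.
print(num_batches(95, 10, 20))
```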

Single-machine simulations

A simple local multi-threaded executor can be created using a new, currently undocumented framework function, tff.framework.create_local_executor(), and made the default by calling tff.framework.set_default_executor(), as follows. Note that the local executor currently requires the number of clients to be specified at setup time; we'll relax this restriction in the near future.

tff.framework.set_default_executor(tff.framework.create_local_executor(NUM_CLIENTS))

evaluate()

loss 3.55805, round time 15.59630298614502
loss 1.20457, round time 14.756427764892578
loss 0.968885, round time 14.836261749267578
loss 0.876231, round time 15.162195682525635
loss 0.822056, round time 14.794559955596924
loss 0.782698, round time 14.633676767349243
loss 0.754549, round time 14.806968688964844
loss 0.735254, round time 14.6411874294281
loss 0.718297, round time 14.8141508102417
loss 0.703123, round time 14.478189468383789
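The speedup from a multi-threaded executor comes from running independent per-client work concurrently instead of one client at a time. The following standard-library sketch is an analogy only, not TFF's implementation: `simulated_client_work` is a hypothetical stand-in that sleeps to mimic a client's local computation, and `ThreadPoolExecutor` runs the clients concurrently.

```python
import time
from concurrent.futures import ThreadPoolExecutor

NUM_CLIENTS = 10
WORK_SECONDS = 0.05  # Assumed per-client cost for this illustration.


def simulated_client_work(client_id):
    """Hypothetical stand-in for one client's local training step."""
    time.sleep(WORK_SECONDS)  # time.sleep releases the GIL, so threads overlap.
    return client_id


def run_round(parallel):
    """Runs one simulated round and returns (results, elapsed seconds)."""
    start = time.perf_counter()
    if parallel:
        with ThreadPoolExecutor(max_workers=NUM_CLIENTS) as pool:
            # pool.map preserves input order, so results match the sequential run.
            results = list(pool.map(simulated_client_work, range(NUM_CLIENTS)))
    else:
        results = [simulated_client_work(i) for i in range(NUM_CLIENTS)]
    return results, time.perf_counter() - start


sequential_results, sequential_time = run_round(parallel=False)
parallel_results, parallel_time = run_round(parallel=True)
print('sequential {:.2f}s, threaded {:.2f}s'.format(sequential_time, parallel_time))
```

Both runs produce the same per-client results; only the wall-clock time changes.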

Now, for comparison, let's run the same training code using the reference executor. The reference executor can be reinstated as the default by calling tff.framework.set_default_executor() without an argument.

tff.framework.set_default_executor()

evaluate()

loss 3.55805, round time 273.98993945121765
loss 1.20457, round time 275.3648579120636
loss 0.968885, round time 275.70571970939636
loss 0.876231, round time 276.54428720474243
loss 0.822056, round time 277.55314350128174
loss 0.782698, round time 273.0547134876251
loss 0.754549, round time 276.6848702430725
loss 0.735254, round time 277.1024343967438
loss 0.718297, round time 273.02305459976196
loss 0.703123, round time 275.2886736392975
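The two logs can be compared directly. The following sketch copies the round times from the printed output into lists (the variable names are illustrative) and computes the mean round time for each executor and the resulting speedup.

```python
from statistics import mean

# Round times in seconds, copied from the two training logs above.
local_times = [15.59630298614502, 14.756427764892578, 14.836261749267578,
               15.162195682525635, 14.794559955596924, 14.633676767349243,
               14.806968688964844, 14.6411874294281, 14.8141508102417,
               14.478189468383789]
reference_times = [273.98993945121765, 275.3648579120636, 275.70571970939636,
                   276.54428720474243, 277.55314350128174, 273.0547134876251,
                   276.6848702430725, 277.1024343967438, 273.02305459976196,
                   275.2886736392975]

local_mean = mean(local_times)
reference_mean = mean(reference_times)
speedup = reference_mean / local_mean

print('local executor:     {:.2f} s/round'.format(local_mean))
print('reference executor: {:.2f} s/round'.format(reference_mean))
print('speedup:            {:.1f}x'.format(speedup))
```

For these logs the local executor is about 18.5 times faster per round, while the per-round losses are identical in both runs, which is consistent with only the execution backend having changed.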

Multi-machine simulations on GCP/GKE, GPUs, TPUs, and beyond...

Coming very soon.