
High-performance simulations with TFF

This tutorial will describe how to set up high-performance simulations with TFF in a variety of common scenarios.

TODO(b/134543154): Populate the content, some of the things to cover here:

  • using GPUs in a single-machine setup,
  • multi-machine setup on GCP/GKE, with and without TPUs,
  • interfacing MapReduce-like backends,
  • current limitations and when/how they will be relaxed.

Before we begin

First, make sure your notebook is connected to a backend that has the relevant components (including gRPC dependencies for multi-machine scenarios) compiled.

Now, let's start by loading the MNIST example from the TFF website and declaring the Python function that will run a small experiment loop over a group of 10 clients.

pip install --quiet --upgrade tensorflow_federated

import collections
import time

import tensorflow as tf

import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()

def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)

train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec
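To build intuition for what the `repeat(10).shuffle(500).batch(20)` pipeline does to one client's examples, here is a rough pure-Python analogue (a sketch, not `tf.data` itself; the example count of 100 per client is an assumption for illustration):

```python
import itertools

def repeat(seq, n):
    # analogue of ds.repeat(n): iterate over the sequence n times
    return itertools.chain.from_iterable(itertools.repeat(list(seq), n))

def batch(iterable, size):
    # analogue of ds.batch(size): group consecutive elements into chunks
    it = iter(iterable)
    while chunk := list(itertools.islice(it, size)):
        yield chunk

examples = list(range(100))        # stand-in for one client's examples
stream = repeat(examples, 10)      # 10 local epochs
batches = list(batch(stream, 20))  # 20 examples per batch -> 50 batches
```

Shuffling is omitted here since it only reorders elements; the epoch and batch structure is what determines how much work each client performs per round.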

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax()])
  return tff.learning.from_keras_model(
      model, input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy())

trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))

def evaluate(num_rounds=10):
  state = trainer.initialize()
  for _ in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('loss {}, round time {}'.format(metrics.loss, t2 - t1))

Single-machine simulations

Now on by default.

evaluate()

loss 3.0367836952209473, round time 4.970079183578491
loss 2.778421401977539, round time 3.4929888248443604
loss 2.521284341812134, round time 4.029532432556152
loss 2.3498423099517822, round time 3.4987425804138184
loss 2.0624916553497314, round time 3.5738046169281006
loss 1.9093912839889526, round time 3.041914463043213
loss 1.7627369165420532, round time 3.6436498165130615
loss 1.5839917659759521, round time 3.193682909011841
loss 1.5063327550888062, round time 3.22552227973938
loss 1.4204730987548828, round time 3.399146795272827
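When comparing runtimes across setups, it can help to summarize the per-round timings rather than eyeball them. A small post-processing sketch, using two of the log lines printed above (parsing by splitting on the last space is an assumption tied to the exact print format in `evaluate`):

```python
# sample lines copied from the output above
log_lines = [
    'loss 3.0367836952209473, round time 4.970079183578491',
    'loss 2.778421401977539, round time 3.4929888248443604',
]

# the round time is the last whitespace-separated token on each line
round_times = [float(line.rsplit(' ', 1)[-1]) for line in log_lines]
mean_round_time = sum(round_times) / len(round_times)
print('mean round time: {:.2f}s'.format(mean_round_time))
```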

Multi-machine simulations on GCP/GKE, GPUs, TPUs, and beyond...

Coming very soon.