High-performance simulations with TFF

This tutorial describes how to set up high-performance simulations with TFF in a variety of common scenarios.

TODO: b/134543154 - Populate the content. Some of the things to cover here:

  • using GPUs in a single-machine setup,
  • multi-machine setup on GCP/GKE, with and without TPUs,
  • interfacing MapReduce-like backends,
  • current limitations and when/how they will be relaxed.

Before we begin

First, make sure your notebook is connected to a backend that has the relevant components compiled, including the gRPC dependencies needed for multi-machine scenarios.

Now, let's start by loading the federated EMNIST data from the MNIST example on the TFF website, building a simple model and a Federated Averaging trainer, and declaring the Python function that will run a small experiment loop over a group of 10 clients.

pip install --quiet --upgrade tensorflow-federated

import collections
import time

import tensorflow as tf
import tensorflow_federated as tff

# Load the federated EMNIST training split; the test split is not used here.
source, _ = tff.simulation.datasets.emnist.load_data()


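def map_fn(example):
  # Applied after batching: flatten each batch of 28x28 images into
  # 784-dimensional vectors and pair them with their labels.
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label']
  )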


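def client_data(n):
  # Client n's data: 10 passes over its examples, shuffled with a
  # 500-element buffer, grouped into batches of 20, then flattened.
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)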


# Use the first 10 clients for the experiment.
train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec


# A single dense layer followed by softmax: a linear classifier over 10 classes.
keras_model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
    tf.keras.layers.Softmax(),
])
tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)


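# Federated Averaging; by default, client updates are weighted by the number
# of examples each client trained on.
trainer = tff.learning.algorithms.build_weighted_fed_avg(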
    tff_model,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.02),
)


def evaluate(num_rounds=10):
  """Runs `num_rounds` training rounds, printing train metrics and round time."""
  state = trainer.initialize()
  for _ in range(num_rounds):
    t1 = time.time()
    result = trainer.next(state, train_data)
    state = result.state
    train_metrics = result.metrics['client_work']['train']
    t2 = time.time()
    print(
        'train metrics {m}, round time {t:.2f} seconds'.format(
            m=train_metrics, t=t2 - t1
        )
    )
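
To make the per-client input pipeline in `client_data` more concrete, here is a pure-Python sketch (not `tf.data` itself) of what `repeat(10).shuffle(500).batch(20)` does to one client's examples: the examples are repeated for 10 passes, streamed through a 500-element shuffle buffer, and grouped into batches of 20, with a smaller final batch because `batch` keeps the remainder by default. The helper names `buffered_shuffle`, `batched`, and `repeat_shuffle_batch` are hypothetical, and plain integers stand in for EMNIST examples. In the original code `map_fn` runs after batching, which is why it reshapes `pixels` to `[-1, 784]`: each element it receives is a whole batch of images.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int,
                     rng: random.Random) -> Iterator[T]:
    """Shuffle with a fixed-size buffer, in the spirit of tf.data's shuffle.

    The buffer is filled first; each output is then drawn uniformly from the
    buffer, and its slot is refilled with the next input element.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain whatever is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group items into lists of batch_size, keeping a smaller final batch."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def repeat_shuffle_batch(examples: List[T], repeats: int, buffer_size: int,
                         batch_size: int, seed: int = 0) -> List[List[T]]:
    """Mirror ds.repeat(repeats).shuffle(buffer_size).batch(batch_size)."""
    rng = random.Random(seed)
    repeated = (ex for _ in range(repeats) for ex in examples)
    return list(batched(buffered_shuffle(repeated, buffer_size, rng), batch_size))


# Hypothetical client with 93 examples: 930 examples after 10 passes.
batches = repeat_shuffle_batch(list(range(93)), repeats=10,
                               buffer_size=500, batch_size=20)
print(len(batches), len(batches[-1]))  # 47 10
```

Because `repeat` comes before `shuffle` in this ordering, the shuffle buffer can mix examples from neighboring passes over the data.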

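By default, `build_weighted_fed_avg` forms the server-side average of client model updates weighted by each client's number of training examples, i.e. the aggregate is sum_k(n_k * u_k) / sum_k(n_k). The following is a minimal pure-Python sketch of that weighted averaging step only, assuming each update is a flat list of floats; `weighted_average` and the example numbers are hypothetical, and TFF's actual implementation operates on structures of tensors inside a federated computation.

```python
from typing import List, Sequence


def weighted_average(updates: Sequence[Sequence[float]],
                     weights: Sequence[float]) -> List[float]:
    """Return sum_k w_k * u_k / sum_k w_k, computed coordinate-wise."""
    if not updates or len(updates) != len(weights):
        raise ValueError("need at least one update and one weight per update")
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    dim = len(updates[0])
    result = [0.0] * dim
    for update, w in zip(updates, weights):
        if len(update) != dim:
            raise ValueError("all updates must have the same length")
        for i, value in enumerate(update):
            result[i] += w * value
    return [value / total for value in result]


# Hypothetical round: two clients holding 100 and 300 examples.
client_updates = [[1.0, 0.0], [0.0, 1.0]]
num_examples = [100, 300]
print(weighted_average(client_updates, num_examples))  # [0.25, 0.75]
```

Weighting by example count means a client with three times as much data pulls the average three times as hard, so the result matches what a single model would see if all examples were pooled for that update.
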
Single-machine simulations

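High-performance single-machine simulation is now on by default, so no extra configuration is needed before running the experiment loop: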

evaluate()
train metrics OrderedDict([('accuracy', 0.14557613)]), round time 3.11 seconds
train metrics OrderedDict([('accuracy', 0.181893)]), round time 1.25 seconds
train metrics OrderedDict([('accuracy', 0.23374486)]), round time 1.23 seconds
train metrics OrderedDict([('accuracy', 0.26759258)]), round time 1.18 seconds
train metrics OrderedDict([('accuracy', 0.31944445)]), round time 1.14 seconds
train metrics OrderedDict([('accuracy', 0.37222221)]), round time 1.14 seconds
train metrics OrderedDict([('accuracy', 0.42685184)]), round time 1.25 seconds
train metrics OrderedDict([('accuracy', 0.4712963)]), round time 1.35 seconds
train metrics OrderedDict([('accuracy', 0.5269547)]), round time 1.31 seconds
train metrics OrderedDict([('accuracy', 0.55833334)]), round time 1.31 seconds
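
In the log above, the first round takes about 2.5 times as long as the later ones. As a worked check on the printed numbers, the short script below parses the ten log lines and compares the first round with the mean of rounds 2 to 10, along with the change in training accuracy over the run. It only reads the text output shown here; the regular expression assumes the exact `train metrics ..., round time ... seconds` format printed by `evaluate`.

```python
import re
from statistics import mean

# The ten lines printed by evaluate() above, copied verbatim.
LOG = """\
train metrics OrderedDict([('accuracy', 0.14557613)]), round time 3.11 seconds
train metrics OrderedDict([('accuracy', 0.181893)]), round time 1.25 seconds
train metrics OrderedDict([('accuracy', 0.23374486)]), round time 1.23 seconds
train metrics OrderedDict([('accuracy', 0.26759258)]), round time 1.18 seconds
train metrics OrderedDict([('accuracy', 0.31944445)]), round time 1.14 seconds
train metrics OrderedDict([('accuracy', 0.37222221)]), round time 1.14 seconds
train metrics OrderedDict([('accuracy', 0.42685184)]), round time 1.25 seconds
train metrics OrderedDict([('accuracy', 0.4712963)]), round time 1.35 seconds
train metrics OrderedDict([('accuracy', 0.5269547)]), round time 1.31 seconds
train metrics OrderedDict([('accuracy', 0.55833334)]), round time 1.31 seconds
"""

PATTERN = re.compile(
    r"\('accuracy', ([0-9.]+)\)\]\), round time ([0-9.]+) seconds")


def parse_log(text):
    """Return (accuracies, round_times) parsed from evaluate()'s output."""
    accuracies, times = [], []
    for line in text.splitlines():
        match = PATTERN.search(line)
        if match:
            accuracies.append(float(match.group(1)))
            times.append(float(match.group(2)))
    return accuracies, times


accuracies, times = parse_log(LOG)
later_mean = mean(times[1:])
print(f"rounds parsed: {len(times)}")  # rounds parsed: 10
print(f"first round: {times[0]:.2f} s, mean of later rounds: {later_mean:.2f} s")
# first round: 3.11 s, mean of later rounds: 1.24 s
print(f"accuracy: {accuracies[0]:.4f} -> {accuracies[-1]:.4f}")
# accuracy: 0.1456 -> 0.5583
```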

Multi-machine simulations on GCP/GKE, GPUs, TPUs, and beyond...

Coming very soon.