This tutorial describes how to set up high-performance simulations with TFF in a variety of common scenarios.
TODO(b/134543154): Populate the content. Some of the things to cover:
- using GPUs in a single-machine setup,
- multi-machine setup on GCP/GKE, with and without TPUs,
- interfacing MapReduce-like backends,
- current limitations and when/how they will be relaxed.
Before we begin
First, make sure your notebook is connected to a backend that has the relevant components (including gRPC dependencies for multi-machine scenarios) compiled.
Now, let's start by loading the MNIST example from the TFF website and declaring the Python function that will run a small experiment loop over a group of 10 clients.
pip install --quiet --upgrade tensorflow-federated
import collections
import time
import tensorflow as tf
import tensorflow_federated as tff
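As a quick sanity check (an addition to the original setup), you can run a trivial federated computation; if the runtime is wired up correctly, the greeting comes back from the TFF executor:
# If the backend is healthy, this prints b'Hello, World!'.
print(tff.federated_computation(lambda: 'Hello, World!')())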
source, _ = tff.simulation.datasets.emnist.load_data()
def map_fn(example):
  # Flatten each 28x28 EMNIST image into a 784-dimensional vector.
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])

def client_data(n):
  # Build the n-th client's local dataset: 10 epochs, shuffled, batches of 20.
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)

train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec
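Before building the model, it can help to verify the pipeline's output shapes; this quick inspection (not part of the original tutorial) pulls one batch from the first client:
# Each full batch holds 20 flattened images and 20 labels.
sample_batch = next(iter(train_data[0]))
print(sample_batch['x'].shape)  # expected: (20, 784)
print(sample_batch['y'].shape)  # expected: (20,)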
def model_fn():
  # A simple linear model: 784 pixels -> 10 classes with a softmax output.
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.models.from_keras_model(
      model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))
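build_weighted_fed_avg also takes a server_optimizer_fn, which defaults to plain SGD with a learning rate of 1.0; as a sketch (an illustration, not used below), you could add server-side momentum like so:
# Variant sketch: same algorithm, but with an explicit server optimizer.
trainer_with_momentum = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=1.0, momentum=0.9))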
def evaluate(num_rounds=10):
  state = trainer.initialize()
  for _ in range(num_rounds):
    t1 = time.time()
    # Run one round of Federated Averaging over all 10 clients and time it.
    result = trainer.next(state, train_data)
    state = result.state
    train_metrics = result.metrics['client_work']['train']
    t2 = time.time()
    print('train metrics {m}, round time {t:.2f} seconds'.format(
        m=train_metrics, t=t2 - t1))
Single-machine simulations
The high-performance local execution runtime is now enabled by default, so no extra configuration is needed; simply run the training loop.
evaluate()
train metrics OrderedDict([('sparse_categorical_accuracy', 0.15329218), ('loss', 2.918891), ('num_examples', 9720), ('num_batches', 490)]), round time 4.64 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.18004115), ('loss', 2.7677088), ('num_examples', 9720), ('num_batches', 490)]), round time 2.37 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.21841563), ('loss', 2.511075), ('num_examples', 9720), ('num_batches', 490)]), round time 2.30 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.27160493), ('loss', 2.340346), ('num_examples', 9720), ('num_batches', 490)]), round time 2.25 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.34115225), ('loss', 2.0537064), ('num_examples', 9720), ('num_batches', 490)]), round time 2.27 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.3745885), ('loss', 1.9158486), ('num_examples', 9720), ('num_batches', 490)]), round time 2.21 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.41502059), ('loss', 1.7523248), ('num_examples', 9720), ('num_batches', 490)]), round time 2.19 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.47644034), ('loss', 1.6085855), ('num_examples', 9720), ('num_batches', 490)]), round time 2.20 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.5126543), ('loss', 1.5272282), ('num_examples', 9720), ('num_batches', 490)]), round time 2.27 seconds
train metrics OrderedDict([('sparse_categorical_accuracy', 0.5576132), ('loss', 1.393721), ('num_examples', 9720), ('num_batches', 490)]), round time 2.16 seconds
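Round time grows with the size of the participating cohort; a quick way to probe scaling on your machine (an experiment added here, not part of the original tutorial) is to rebuild train_data with more clients and rerun the loop:
# Sketch: double the cohort to 20 clients and time a couple of rounds.
train_data = [client_data(n) for n in range(20)]
evaluate(num_rounds=2)
# Restore the original 10-client cohort.
train_data = [client_data(n) for n in range(10)]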
Multi-machine simulations on GCP/GKE, GPUs, TPUs, and beyond...
Coming very soon.
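In the meantime, the general shape of a multi-machine simulation is: worker machines run TFF executor services, and the notebook connects to them over gRPC before invoking the same training loop. The sketch below is an assumption-laden illustration; in particular, the worker addresses are placeholders and the name of the context-setting function has varied across TFF releases (tff.backends.native.set_remote_python_execution_context in some versions), so check it against your installed version:
# Assumptions: workers expose a TFF executor service on port 80, and this
# TFF release provides set_remote_python_execution_context.
import grpc

ip_addresses = [...]  # fill in your worker endpoints, e.g. GKE pod IPs
channels = [grpc.insecure_channel(f'{ip}:80') for ip in ip_addresses]
tff.backends.native.set_remote_python_execution_context(channels)

evaluate()  # the same loop, now executed on the remote workers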