High-Performance Simulation with Kubernetes

This tutorial describes how to set up high-performance simulations using a TFF runtime deployed on Kubernetes.

For demonstrative purposes, we'll use the TFF simulation for image classification from the tutorial, Federated Learning for Image Classification, but we'll run it against a multi-machine setup consisting of two TFF workers running in Kubernetes. We'll use the same EMNIST dataset for training, but split into two partitions, one for each TFF worker.
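The partitions used in this tutorial come pre-built, so you don't need to split anything yourself. Purely to illustrate what "one partition per worker" means, the sketch below divides client IDs into two disjoint groups. `partition_clients` is a hypothetical helper and is not part of TFF. Its stable-hash assignment is an assumption and is not necessarily how the published EMNIST partitions were produced.

```python
import hashlib
from typing import Dict, List, Sequence


def partition_clients(client_ids: Sequence[str],
                      num_partitions: int = 2) -> Dict[int, List[str]]:
  """Assigns each client ID to exactly one partition (hypothetical helper)."""
  partitions: Dict[int, List[str]] = {i: [] for i in range(num_partitions)}
  for client_id in client_ids:
    # md5 gives a hash that is the same across processes; Python's built-in
    # hash() of a str is randomized per process, so it isn't reproducible.
    digest = hashlib.md5(client_id.encode('utf-8')).hexdigest()
    partitions[int(digest, 16) % num_partitions].append(client_id)
  return partitions
```

Because the assignment depends only on the client ID, rerunning it always puts each client on the same worker.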

This tutorial refers to the following Google Cloud services:

  • GKE to create the Kubernetes cluster, though all the steps after the cluster is created can be used with any Kubernetes installation.
  • Filestore to serve the training data, though any storage medium that can be mounted as a Kubernetes persistent volume will work.

Launch the TFF Workers on Kubernetes

Package TFF Worker Binary

worker_service.py contains the source code for our custom TFF worker. It runs a simulation server with custom logic for loading a dataset partition and sampling from it for each round of federated learning. (To learn more, see Loading Remote Data in TFF.)
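worker_service.py itself isn't reproduced here. The sketch below shows the general kind of per-round sampling it performs: it reads a partition stored in SQLite and draws a random subset of clients. The table name `examples`, its columns, and the functions `sample_clients` and `load_client_examples` are all assumptions for illustration. They are not the actual schema of the partition files or the worker's API.

```python
import random
import sqlite3
from typing import List


def sample_clients(conn: sqlite3.Connection, num_clients: int,
                   rng: random.Random) -> List[str]:
  """Draws `num_clients` distinct client IDs uniformly at random.

  Assumes a hypothetical table `examples(client_id TEXT, serialized_example BLOB)`.
  """
  rows = conn.execute('SELECT DISTINCT client_id FROM examples').fetchall()
  client_ids = sorted(row[0] for row in rows)
  # random.sample raises ValueError if num_clients exceeds the population.
  return rng.sample(client_ids, num_clients)


def load_client_examples(conn: sqlite3.Connection,
                         client_id: str) -> List[bytes]:
  """Loads the serialized examples for one client (hypothetical schema)."""
  rows = conn.execute(
      'SELECT serialized_example FROM examples WHERE client_id = ?',
      (client_id,)).fetchall()
  return [row[0] for row in rows]
```

Passing an explicit `random.Random` makes a run reproducible when seeded, while still giving different clients from round to round.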

We're going to deploy our TFF worker as a containerized application on Kubernetes. Let's start by building a Docker image. Using this Dockerfile, we can package the code by running:

$ WORKER_IMAGE=tff-worker-service:latest

$ docker build --tag $WORKER_IMAGE --file "./Dockerfile" .

(Assuming worker_service.py and Dockerfile are located in your working directory.)

Then publish the image to a container registry that the Kubernetes cluster we're about to create can access. For a remote registry, WORKER_IMAGE must include the registry's path, not just the image name. For example:

$ docker push $WORKER_IMAGE

Create a Kubernetes Cluster

The following step only needs to be done once. The cluster can be re-used for future workloads.

Follow the GKE instructions to create a cluster with the Filestore CSI driver enabled. For example:

gcloud container clusters create tff-cluster --addons=GcpFilestoreCsiDriver

The commands to interact with GCP can be run locally or in the Google Cloud Shell. We recommend the Google Cloud Shell since it doesn't require additional setup.

The rest of this tutorial assumes that the cluster is named tff-cluster, but the actual name isn't important.

Deploy the TFF Worker Application

worker_deployment.yaml declares the configuration for standing up two TFF workers. Each worker is its own Kubernetes deployment with two pod replicas. We can apply this configuration to our running cluster:

kubectl apply -f worker_deployment.yaml
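This tutorial doesn't reproduce worker_deployment.yaml. As a hedged sketch of its overall shape, the Python below builds a Deployment plus a LoadBalancer Service for each worker and wraps them in a `kind: List` object, which `kubectl apply -f` also accepts as JSON.

The following values come from this tutorial:

  • the resource names
  • the image tag
  • service port 80
  • the /root/worker/data directory

The following values are assumptions:

  • container port 8000
  • the label values
  • the per-worker claim names `tff-data-claim-N`

```python
import json
from typing import Any, Dict, List

WORKER_IMAGE = 'tff-worker-service:latest'  # tag from the docker build step
CONTAINER_PORT = 8000  # assumption: the port worker_service.py listens on
DATA_DIR = '/root/worker/data'  # directory targeted by `kubectl cp` later


def worker_resources(index: int, replicas: int = 2) -> List[Dict[str, Any]]:
  """Builds a Deployment and a LoadBalancer Service for one TFF worker."""
  app = f'tff-workers-{index}'  # hypothetical label value
  deployment = {
      'apiVersion': 'apps/v1',
      'kind': 'Deployment',
      'metadata': {'name': f'tff-workers-deployment-{index}'},
      'spec': {
          'replicas': replicas,
          'selector': {'matchLabels': {'app': app}},
          'template': {
              'metadata': {'labels': {'app': app}},
              'spec': {
                  'containers': [{
                      'name': 'tff-worker',
                      'image': WORKER_IMAGE,
                      'ports': [{'containerPort': CONTAINER_PORT}],
                      'volumeMounts': [{'name': 'data', 'mountPath': DATA_DIR}],
                  }],
                  'volumes': [{
                      'name': 'data',
                      # Hypothetical claim; one per worker so that the two
                      # partitions don't overwrite each other at DATA_DIR.
                      'persistentVolumeClaim': {
                          'claimName': f'tff-data-claim-{index}'},
                  }],
              },
          },
      },
  }
  service = {
      'apiVersion': 'v1',
      'kind': 'Service',
      'metadata': {'name': f'tff-workers-service-{index}'},
      'spec': {
          'type': 'LoadBalancer',
          'selector': {'app': app},
          'ports': [{'port': 80, 'targetPort': CONTAINER_PORT,
                     'protocol': 'TCP'}],
      },
  }
  return [deployment, service]


def build_manifest(num_workers: int = 2) -> Dict[str, Any]:
  """Wraps every worker's resources in a single v1 List object."""
  items: List[Dict[str, Any]] = []
  for index in range(1, num_workers + 1):
    items.extend(worker_resources(index))
  return {'apiVersion': 'v1', 'kind': 'List', 'items': items}
```

Writing `json.dumps(build_manifest(), indent=2)` to a file gives something you could pass to `kubectl apply -f`. The Filestore-backed PersistentVolumeClaims themselves would still need to be defined separately.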

Once the changes have been requested, check that the pods are ready:

kubectl get pod
NAME                                        READY   STATUS    RESTARTS   AGE
tff-workers-deployment-1-6bb8d458d5-hjl9d   1/1     Running   0          5m
tff-workers-deployment-1-6bb8d458d5-jgt4b   1/1     Running   0          5m
tff-workers-deployment-2-6cb76c6f5d-hqt88   1/1     Running   0          5m
tff-workers-deployment-2-6cb76c6f5d-xk92h   1/1     Running   0          5m
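If you'd rather check readiness from a script than by eye, here is a minimal sketch that parses the default table printed by kubectl get pod, as shown above. `pods_ready` is a hypothetical helper. For automation, kubectl's structured output (-o json) or `kubectl wait --for=condition=Ready pod --all` is generally more robust than parsing columns.

```python
from typing import Dict


def pods_ready(kubectl_output: str) -> Dict[str, bool]:
  """Maps each pod name to True if all its containers are ready and Running."""
  lines = [line for line in kubectl_output.strip().splitlines() if line.strip()]
  header = lines[0].split()
  name_col, ready_col, status_col = (
      header.index(h) for h in ('NAME', 'READY', 'STATUS'))
  result = {}
  for line in lines[1:]:
    fields = line.split()
    ready, total = fields[ready_col].split('/')  # e.g. "1/1"
    result[fields[name_col]] = ready == total and fields[status_col] == 'Running'
  return result
```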

Each worker deployment runs behind its own load balancer service with an external endpoint. Look up the external IP addresses of the load balancers:

kubectl get service
NAME                    TYPE           CLUSTER-IP    EXTERNAL-IP     PORT(S)        AGE
tff-workers-service-1   LoadBalancer   XX.XX.X.XXX   XX.XXX.XX.XXX   80:31830/TCP   6m
tff-workers-service-2   LoadBalancer   XX.XX.X.XXX   XX.XXX.XX.XXX   80:31319/TCP   6m

You'll need these addresses later to connect the training loop to the running workers.
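To pull the addresses out of that table programmatically, the sketch below parses the default kubectl get service output. It skips services whose external IP is still `<pending>`, which is what the column shows while the cloud load balancer is being provisioned. `external_ips` is a hypothetical helper. kubectl can also return the value directly, for example with `kubectl get service tff-workers-service-1 -o jsonpath='{.status.loadBalancer.ingress[0].ip}'`.

```python
from typing import Dict


def external_ips(kubectl_output: str) -> Dict[str, str]:
  """Maps LoadBalancer service names to their assigned EXTERNAL-IP."""
  lines = [line for line in kubectl_output.strip().splitlines() if line.strip()]
  header = lines[0].split()
  name_col = header.index('NAME')
  type_col = header.index('TYPE')
  ip_col = header.index('EXTERNAL-IP')
  ips = {}
  for line in lines[1:]:
    fields = line.split()
    # Skip non-LoadBalancer services and IPs that aren't assigned yet.
    if fields[type_col] == 'LoadBalancer' and fields[ip_col] != '<pending>':
      ips[fields[name_col]] = fields[ip_col]
  return ips
```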

Prepare Training Data

The EMNIST partitions we'll use for training can be downloaded from TFF's public dataset repository:

gsutil cp -r gs://tff-datasets-public/emnist-partitions/2-partition .

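You can then upload each partition to its worker by copying it into one replica of the corresponding deployment. Run the commands from the directory containing the downloaded partition files, and substitute the pod names reported by kubectl get pod in your cluster. For example: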

kubectl cp emnist_part_1.sqlite tff-workers-deployment-1-6bb8d458d5-hjl9d:/root/worker/data/emnist_partition.sqlite

kubectl cp emnist_part_2.sqlite tff-workers-deployment-2-6cb76c6f5d-hqt88:/root/worker/data/emnist_partition.sqlite
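With more workers, writing these commands by hand gets error-prone. The sketch below builds one kubectl cp command per worker, pairing partition N with a pod from deployment N. `copy_commands` is a hypothetical helper. The assumption that the files sit in a local 2-partition directory named emnist_part_N.sqlite follows the commands above but should be checked against your download.

```python
from typing import List, Sequence

# Destination path inside each worker pod, as used in the commands above.
REMOTE_PATH = '/root/worker/data/emnist_partition.sqlite'


def copy_commands(pod_names: Sequence[str],
                  local_dir: str = '2-partition') -> List[List[str]]:
  """Builds `kubectl cp` argument lists; pod_names[i] belongs to worker i + 1."""
  commands = []
  for index, pod in enumerate(pod_names, start=1):
    local_file = f'{local_dir}/emnist_part_{index}.sqlite'
    commands.append(['kubectl', 'cp', local_file, f'{pod}:{REMOTE_PATH}'])
  return commands
```

Each argument list can be run with `subprocess.run(cmd, check=True)`, so that a failed copy stops the script instead of going unnoticed.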

Run Simulation

Now we're ready to run simulations against our cluster.

Set Up the TFF Environment

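!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio

# Allow asyncio event loops to be nested inside the notebook's already-running loop.
nest_asyncio.apply()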

Define the Training Procedure

The following code defines the dataset iteration methodology, the model architecture, and the round-over-round process for federated learning. (For more detail, see Federated Learning for Image Classification.)

import collections
from typing import Any, List, Optional

import tensorflow as tf
import tensorflow_federated as tff


class FederatedData(tff.program.FederatedDataSource,
                    tff.program.FederatedDataSourceIterator):
  """Interface for interacting with the federated training data."""

  def __init__(self, type_spec: tff.FederatedType):
    self._type_spec = type_spec
    self._capabilities = [tff.program.Capability.RANDOM_UNIFORM]

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._type_spec

  @property
  def capabilities(self) -> List[tff.program.Capability]:
    return self._capabilities

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    # This object is both the data source and its own iterator.
    return self

  def select(self, num_clients: Optional[int] = None) -> Any:
    if num_clients is None:
      raise ValueError('num_clients must be specified.')
    # One URI per client; the custom logic in worker_service.py interprets
    # these URIs against the worker's local dataset partition.
    data_uris = [f'uri://{i}' for i in range(num_clients)]
    return tff.framework.CreateDataDescriptor(
        arg_uris=data_uris, arg_type=self._type_spec)


# Each element is a batch of one flattened 28x28 EMNIST image and its label.
input_spec = collections.OrderedDict([
    ('x', tf.TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)),
    ('y', tf.TensorSpec(shape=(1, 1), dtype=tf.int32, name=None)),
])
element_type = tff.types.StructWithPythonType(
    input_spec, container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

train_data_source = FederatedData(type_spec=dataset_type)
train_data_iterator = train_data_source.iterator()


def model_fn():
  # Single-layer softmax classifier; the Softmax output matches
  # SparseCategoricalCrossentropy's default from_logits=False.
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=element_type,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))


def train_loop(num_rounds=10, num_clients=10):
  state = trainer.initialize()
  for round_num in range(1, num_rounds + 1):
    # Select a fresh set of client data descriptors for this round.
    train_data = train_data_iterator.select(num_clients)
    result = trainer.next(state, train_data)
    state = result.state
    train_metrics = result.metrics['client_work']['train']
    print('round {:2d}, metrics={}'.format(round_num, train_metrics))
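By default, build_weighted_fed_avg weights each client's model update by that client's number of training examples. The plain-Python sketch below illustrates that aggregation for flat weight vectors, followed by a server step. Here a delta means trained weights minus initial weights. With the server learning rate of 1.0 configured above, the server simply adds the weighted mean delta to the global model. This is a conceptual illustration with hypothetical helper names, not TFF's implementation or its internal sign convention.

```python
from typing import List, Sequence


def weighted_mean_delta(deltas: Sequence[Sequence[float]],
                        num_examples: Sequence[int]) -> List[float]:
  """Averages per-client deltas, weighting each by its example count."""
  if len(deltas) != len(num_examples):
    raise ValueError('Need one example count per client delta.')
  total = sum(num_examples)
  if total <= 0:
    raise ValueError('Total example count must be positive.')
  mean = [0.0] * len(deltas[0])
  for delta, n in zip(deltas, num_examples):
    for j, value in enumerate(delta):
      mean[j] += value * n / total
  return mean


def server_update(weights: Sequence[float], mean_delta: Sequence[float],
                  server_lr: float = 1.0) -> List[float]:
  """Moves the global weights along the averaged client delta."""
  return [w + server_lr * d for w, d in zip(weights, mean_delta)]
```

For example, two clients with deltas [1.0, 0.0] and [0.0, 2.0], trained on 1 and 3 examples, average to [0.25, 1.5]. The client with more data pulls the result further.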

Connect to TFF Workers

By default, TFF executes all computations locally. In this step we tell TFF to connect to the Kubernetes services we set up above. Be sure to copy the external IP addresses of your services here.

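import grpc

# Paste the EXTERNAL-IP values reported by `kubectl get service` above.
ip_address_1 = ''  # external IP of tff-workers-service-1
ip_address_2 = ''  # external IP of tff-workers-service-2
port = 80  # port exposed by each LoadBalancer service

# One (unencrypted) gRPC channel per TFF worker service.
channels = [
    grpc.insecure_channel(f'{ip_address_1}:{port}'),
    grpc.insecure_channel(f'{ip_address_2}:{port}'),
]

# Route TFF computations to the remote workers instead of running them locally.
tff.backends.native.set_remote_python_execution_context(channels)
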
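Before starting training, a quick reachability check can save a confusing failure later. The sketch below only tests whether a TCP connection to each load balancer can be opened. It says nothing about whether the gRPC worker behind it is healthy; gRPC's own `grpc.channel_ready_future(channel).result(timeout=...)` is a closer check. `is_reachable` is a hypothetical helper.

```python
import socket


def is_reachable(host: str, port: int, timeout: float = 3.0) -> bool:
  """Returns True if a TCP connection to host:port can be opened in time."""
  try:
    with socket.create_connection((host, port), timeout=timeout):
      return True
  except OSError:
    # Covers refused connections, timeouts, and unresolvable host names.
    return False
```

For instance, `[is_reachable(ip, port) for ip in (ip_address_1, ip_address_2)]` should be all True before you run the training loop.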

Execute Training

train_loop()

round  1, metrics=OrderedDict([('sparse_categorical_accuracy', 0.10557769), ('loss', 12.475689), ('num_examples', 5020), ('num_batches', 5020)])
round  2, metrics=OrderedDict([('sparse_categorical_accuracy', 0.11940298), ('loss', 10.497084), ('num_examples', 5360), ('num_batches', 5360)])
round  3, metrics=OrderedDict([('sparse_categorical_accuracy', 0.16223507), ('loss', 7.569645), ('num_examples', 5190), ('num_batches', 5190)])
round  4, metrics=OrderedDict([('sparse_categorical_accuracy', 0.2648384), ('loss', 6.0947175), ('num_examples', 5105), ('num_batches', 5105)])
round  5, metrics=OrderedDict([('sparse_categorical_accuracy', 0.29003084), ('loss', 6.2815433), ('num_examples', 4865), ('num_batches', 4865)])
round  6, metrics=OrderedDict([('sparse_categorical_accuracy', 0.40237388), ('loss', 4.630901), ('num_examples', 5055), ('num_batches', 5055)])
round  7, metrics=OrderedDict([('sparse_categorical_accuracy', 0.4288425), ('loss', 4.2358975), ('num_examples', 5270), ('num_batches', 5270)])
round  8, metrics=OrderedDict([('sparse_categorical_accuracy', 0.46349892), ('loss', 4.3829923), ('num_examples', 4630), ('num_batches', 4630)])
round  9, metrics=OrderedDict([('sparse_categorical_accuracy', 0.492094), ('loss', 3.8121278), ('num_examples', 4680), ('num_batches', 4680)])
round 10, metrics=OrderedDict([('sparse_categorical_accuracy', 0.5872674), ('loss', 3.058461), ('num_examples', 5105), ('num_batches', 5105)])
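To compare runs (for example, with different numbers of clients), it can help to turn this printed log into data. Here is a minimal sketch: `parse_round_metrics` is a hypothetical helper that parses the line format printed by train_loop above. It accepts both the older OrderedDict([...]) repr and the OrderedDict({...}) repr used by Python 3.12 and later. Collecting result.metrics directly inside train_loop would be more robust than parsing text.

```python
import ast
import re
from typing import Any, Dict, List, Tuple

# Matches e.g. "round  1, metrics=OrderedDict([('loss', 12.47), ...])".
_ROUND_LINE = re.compile(r'round\s+(\d+), metrics=OrderedDict\((.*)\)$')


def parse_round_metrics(log: str) -> List[Tuple[int, Dict[str, Any]]]:
  """Extracts (round number, metrics dict) pairs from train_loop output."""
  results = []
  for line in log.strip().splitlines():
    match = _ROUND_LINE.match(line.strip())
    if match:
      # The payload is a list of (name, value) pairs or a dict literal.
      payload = ast.literal_eval(match.group(2))
      results.append((int(match.group(1)), dict(payload)))
  return results
```

Applied to the ten lines above, this yields the per-round accuracy, which rises from about 0.106 in round 1 to about 0.587 in round 10.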