Loading Remote Data in TFF

In a production environment, the raw data for a federated computation is typically distributed across machines and requires special preprocessing and loading before it's usable.

This tutorial describes how to load data stored in those remote locations using TFF's DataBackend and DataExecutor interfaces. To keep the example simple, however, the dataset lives entirely in memory, and we simulate fetching it as if it were partitioned across a network.


Before we start

Before we start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.

Set up open-source environment

Import packages

import collections

import tensorflow as tf
import tensorflow_federated as tff

Preparing the input data

Let's begin by loading TFF's federated version of the EMNIST dataset from the built-in repository:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

We'll construct a preprocessing function that transforms the raw examples in the EMNIST dataset from 28x28 images into 784-element arrays. The function also shuffles the individual examples and renames the features from pixels and label to x and y for use with Keras. Finally, it repeats the dataset so that training runs for several epochs.


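def preprocess(dataset):

  def map_fn(element):
    # Flatten each 28x28 image into a [1, 784] tensor and rename the
    # features to the x/y names that Keras expects.
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  # NUM_EPOCHS and SHUFFLE_BUFFER are assumed to be defined earlier.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)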

Let's verify this works:

example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])
preprocessed_example_dataset = preprocess(example_dataset)
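To make the order of operations in preprocess concrete, here is a plain-Python sketch that mimics repeat, then a buffer-based shuffle, then map, on a toy list of dict examples. It is only an approximation: the toy_* helper names are hypothetical, the shuffle imitates the idea behind tf.data's buffered shuffle rather than its exact random sequence, and features are plain lists instead of tensors.

```python
import random
from collections import OrderedDict


def toy_repeat(elements, num_epochs):
    # Like dataset.repeat(n): replay the whole sequence n times.
    for _ in range(num_epochs):
        yield from elements


def toy_shuffle(elements, buffer_size, seed):
    # Approximates dataset.shuffle(buffer_size): hold up to buffer_size
    # elements and, once the buffer is full, emit a random one per arrival.
    rng = random.Random(seed)
    buffer = []
    for element in elements:
        buffer.append(element)
        if len(buffer) == buffer_size:
            index = rng.randrange(buffer_size)
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def toy_map_fn(element):
    # Flatten the 28x28 image to 784 values and rename pixels/label to x/y.
    flat = [pixel for row in element['pixels'] for pixel in row]
    return OrderedDict(x=flat, y=element['label'])


def toy_preprocess(elements, num_epochs, buffer_size, seed=1):
    # Same order as preprocess above: repeat, then shuffle, then map.
    repeated = toy_repeat(elements, num_epochs)
    return [toy_map_fn(e) for e in toy_shuffle(repeated, buffer_size, seed)]


raw = [{'pixels': [[float(i)] * 28 for _ in range(28)], 'label': i}
       for i in range(3)]
out = toy_preprocess(raw, num_epochs=2, buffer_size=4)
print(len(out), len(out[0]['x']))  # 6 784
```

Because repeat comes before shuffle, examples from neighboring epochs can be interleaved within the shuffle buffer; shuffling before repeating would instead keep epoch boundaries intact.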

We'll use the EMNIST dataset to train a model by loading and preprocessing individual clients (emulating distinct partitions) through an implementation of DataBackend.

Defining a DataBackend

We need an instance of DataBackend to tell TFF workers (the processes that run the client side of federated computations) how to load and transform the local data stored in remote locations. A DataBackend resolves symbolic references, represented as application-specific URIs, into materialized payloads that downstream TFF operations can process. Specifically, a DataBackend object is wrapped by a DataExecutor, which queries the backend whenever the TFF runtime encounters an operation that fetches data.
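The resolve-a-URI-to-a-payload idea can be sketched without TFF at all. The classes below are hypothetical stand-ins, not the TFF API: the real DataBackend.materialize receives a data proto and a type specification, while this toy version takes a plain string URI and looks it up in a dictionary.

```python
import abc
import asyncio


class ToyDataBackend(abc.ABC):
    """Hypothetical stand-in for the role a DataBackend plays."""

    @abc.abstractmethod
    async def materialize(self, uri):
        """Resolve a symbolic URI to a concrete payload."""


class InMemoryBackend(ToyDataBackend):
    def __init__(self, store):
        self._store = store

    async def materialize(self, uri):
        # A real backend might read from disk or the network here; the
        # sleep(0) only marks the point where such I/O would yield.
        await asyncio.sleep(0)
        if uri not in self._store:
            raise KeyError(f'no data registered for {uri!r}')
        return self._store[uri]


backend = InMemoryBackend({'uri://0': [1, 2, 3], 'uri://1': [4, 5]})
# asyncio.run needs no running event loop (in a notebook, nest_asyncio may be needed).
payload = asyncio.run(backend.materialize('uri://1'))
print(payload)  # [4, 5]
```

Making materialize a coroutine lets the caller fetch many partitions concurrently, which matters once each lookup involves real I/O.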

In this example, a client identifier is encoded in each URI. Our DataBackend parses the URI to retrieve the corresponding client data as a tf.data.Dataset and then applies our preprocess function.

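class TestDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    # `data.uri` looks like 'uri://3'. Only the last character is read, which
    # assumes single-digit client indices (uri://0 through uri://9).
    client_id = int(data.uri[-1])
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    return preprocess(client_dataset)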
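Reading only the last character of the URI works for the five clients used here but silently misreads an index such as 12. A slightly more general parse, sketched below with the standard library's urllib.parse, treats everything after `uri://` as the index and rejects anything else; the helper name client_index_from_uri is hypothetical.

```python
from urllib.parse import urlparse


def client_index_from_uri(uri):
    """Return the integer client index encoded as 'uri://<index>'."""
    parsed = urlparse(uri)  # 'uri://12' -> scheme 'uri', netloc '12'
    if parsed.scheme != 'uri' or not parsed.netloc.isdigit():
        raise ValueError(f'unexpected data URI: {uri!r}')
    return int(parsed.netloc)


print(client_index_from_uri('uri://3'))   # 3
print(client_index_from_uri('uri://12'))  # 12
print(int('uri://12'[-1]))                # 2, the last-character parse is wrong here
```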

Plugging the DataBackend into the ExecutionContext

TFF computations are invoked by an ExecutionContext. For the data URIs in a TFF computation to be understood at runtime, we must define a custom context that includes a pointer to the DataBackend we just created, so that the URIs can be resolved.

The DataBackend works in tandem with the DataExecutor: the backend supplies the executor with concrete data, which the executor then relays to the executors that request it in order to complete a TFF computation.

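def ex_fn(
    device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  # Resolve data URIs with TestDataBackend before delegating to an eager TF executor.
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device),
      data_backend=TestDataBackend())
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.ExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)  # make later computations use this context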
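The wrapping pattern behind DataExecutor (intercept data references, materialize them, and pass everything else through to a target executor) can be illustrated with a small pure-Python sketch. Everything here is hypothetical: TargetExecutor, DataResolvingExecutor, and fake_backend are not TFF classes, and the real DataExecutor recognizes data building blocks inside a computation rather than strings that start with `uri://`.

```python
import asyncio


class TargetExecutor:
    """Hypothetical leaf executor that simply embeds whatever it is given."""

    async def create_value(self, value):
        return ('embedded', value)


class DataResolvingExecutor:
    """Hypothetical sketch of the DataExecutor idea: resolve, then delegate."""

    def __init__(self, target, backend):
        self._target = target
        self._backend = backend

    async def create_value(self, value):
        # Data references are materialized by the backend first; all other
        # values pass straight through to the target executor.
        if isinstance(value, str) and value.startswith('uri://'):
            value = await self._backend(value)
        return await self._target.create_value(value)


async def fake_backend(uri):
    return {'uri://0': [0.5, 0.25]}[uri]


executor = DataResolvingExecutor(TargetExecutor(), fake_backend)
# asyncio.run needs no running event loop (in a notebook, nest_asyncio may be needed).
print(asyncio.run(executor.create_value('uri://0')))  # ('embedded', [0.5, 0.25])
print(asyncio.run(executor.create_value(7)))          # ('embedded', 7)
```

Keeping data resolution in a wrapper means the target executor never needs to know where the data came from.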

Training the model

Now we are ready to train a model in a federated fashion. Let's define a Keras model along with training hyperparameters:

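def create_keras_model():
  # Flattened 784-pixel input -> 10 class probabilities.
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax()])

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])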

We can pass this TFF-wrapped definition of our model to a Federated Averaging algorithm by invoking the helper function tff.learning.algorithms.build_weighted_fed_avg, as follows:

iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
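As its name suggests, the weighted variant of Federated Averaging combines client results in proportion to a per-client weight, by default the number of training examples. The sketch below shows only that arithmetic on plain lists, assuming each client reports a delta defined as its locally trained weights minus the starting global weights; the function names are hypothetical, and the library may internally store the negated delta and treat it as a pseudo-gradient, which gives the same result.

```python
def weighted_average(client_deltas, client_weights):
    """Average equal-length vectors, weighting each client by its example count."""
    total = float(sum(client_weights))
    dim = len(client_deltas[0])
    return [
        sum(w * delta[i] for delta, w in zip(client_deltas, client_weights)) / total
        for i in range(dim)
    ]


def server_step(global_weights, avg_delta, server_lr=1.0):
    """Apply the averaged delta; with plain SGD at lr=1.0 it is added as-is."""
    return [g + server_lr * d for g, d in zip(global_weights, avg_delta)]


deltas = [[1.0, 0.0], [0.0, 2.0]]
weights = [300, 100]  # e.g. number of training examples per client
avg = weighted_average(deltas, weights)
print(avg)                           # [0.75, 0.5]
print(server_step([0.0, 0.0], avg))  # [0.75, 0.5]
```

The client with 300 examples pulls the first coordinate three times harder than the client with 100, which is why a server learning rate of 1.0 is a natural default here: the new global model is exactly the weighted mean of the client models.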

state = iterative_process.initialize()

The initialize computation returns the initial state of the Federated Averaging process.
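The training loop relies on a simple contract: initialize() produces a starting state, and each call to next(state, data) returns a result whose state feeds the following round and whose metrics describe that round. The toy process below is a hypothetical, TFF-free illustration of that contract; its state is a counter and a running total rather than model weights and optimizer state.

```python
from typing import Any, Callable, NamedTuple


class RoundResult(NamedTuple):
    state: Any
    metrics: dict


class ToyIterativeProcess:
    """Hypothetical stand-in showing the initialize/next contract."""

    def __init__(self, initialize_fn: Callable[[], Any],
                 next_fn: Callable[[Any, Any], RoundResult]):
        self.initialize = initialize_fn
        self.next = next_fn


def toy_next(state, data):
    # A real process would update model weights here; we just count rounds
    # and accumulate the data we were given.
    new_state = {'round': state['round'] + 1, 'total': state['total'] + sum(data)}
    return RoundResult(state=new_state, metrics={'num_examples': len(data)})


toy_process = ToyIterativeProcess(lambda: {'round': 0, 'total': 0}, toy_next)
toy_state = toy_process.initialize()
for _ in range(3):
    toy_result = toy_process.next(toy_state, [1, 2, 3])
    toy_state = toy_result.state  # thread the state into the next round
print(toy_state, toy_result.metrics)  # {'round': 3, 'total': 18} {'num_examples': 3}
```

Because the process itself holds no mutable state, forgetting to reassign the state between rounds would silently rerun the same round.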

To run a round of training, we construct a sample of client data from a list of URI references together with the type of the dataset each URI resolves to:

element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

data_uris = [f'uri://{i}' for i in range(5)]
data_handle = tff.framework.CreateDataDescriptor(arg_uris=data_uris, arg_type=dataset_type)
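Conceptually, each URI in the list names one client's partition, and the runtime asks the backend to materialize each of them when the round runs. The sketch below illustrates that fan-out with asyncio.gather over a hypothetical in-memory store; STORE, toy_materialize, and toy_materialize_all are assumptions for illustration, not how TFF schedules the work internally.

```python
import asyncio

# Hypothetical in-memory store standing in for remote client partitions.
STORE = {f'uri://{i}': list(range(i + 1)) for i in range(5)}


async def toy_materialize(uri):
    await asyncio.sleep(0)  # where real network or disk I/O would happen
    return STORE[uri]


async def toy_materialize_all(uris):
    # One lookup per URI, run concurrently; results keep the order of `uris`.
    return await asyncio.gather(*(toy_materialize(uri) for uri in uris))


toy_uris = [f'uri://{i}' for i in range(5)]
# asyncio.run needs no running event loop (in a notebook, nest_asyncio may be needed).
toy_client_data = asyncio.run(toy_materialize_all(toy_uris))
print([len(d) for d in toy_client_data])  # [1, 2, 3, 4, 5]
```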

Now we can run a round of training:

result = iterative_process.next(state, data_handle)
state = result.state
metrics = result.metrics
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11625), ('loss', 12.682652), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
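The metrics are nested by process stage, and the training statistics live under client_work, then train. The snippet below rebuilds the round 1 output printed above as a Python value and pulls out accuracy and loss; in the tutorial itself you would index result.metrics the same way.

```python
from collections import OrderedDict

# The round 1 metrics printed above, rebuilt as a Python value.
round1_metrics = OrderedDict([
    ('distributor', ()),
    ('client_work', OrderedDict([('train', OrderedDict([
        ('sparse_categorical_accuracy', 0.11625),
        ('loss', 12.682652),
        ('num_examples', 2400),
        ('num_batches', 2400)]))])),
    ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])),
    ('finalizer', ()),
])

train_metrics = round1_metrics['client_work']['train']
print('accuracy={:.4f}, loss={:.4f}'.format(
    train_metrics['sparse_categorical_accuracy'], train_metrics['loss']))
```

Since num_examples equals num_batches in every round, each batch here holds a single example, which is consistent with preprocess reshaping each example to a leading dimension of 1 without calling batch.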

And we can run a few more rounds:

for round_num in range(2, NUM_ROUNDS):
  result = iterative_process.next(state, data_handle)
  state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12375), ('loss', 10.2836895), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17916666), ('loss', 7.733705), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.29458332), ('loss', 5.6188993), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.38541666), ('loss', 4.4057455), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.47041667), ('loss', 3.512454), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.535), ('loss', 3.0268242), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5729167), ('loss', 2.7468147), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round  9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.62416667), ('loss', 2.3982067), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.62333333), ('loss', 2.3998983), ('num_examples', 2400), ('num_batches', 2400)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])


This concludes the tutorial. We encourage you to explore the other tutorials we've developed to learn about the many other features of the TFF framework.