{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "headers" }, "source": [ "Project: /federated/_project.yaml\n", "Book: /federated/_book.yaml\n", "\n", "\n", "\n", "\n", "\n", "\n", "{% comment %}\n", "The source of truth file can be found [here]: http://google3/third_party/tensorflow_federated/g3doc\n", "{% endcomment %}" ] }, { "cell_type": "markdown", "metadata": { "id": "metadata" }, "source": [ "
\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tff.learning` - a set of\n",
"higher-level interfaces that can be used to perform common types of federated\n",
"learning tasks, such as federated training, against user-supplied models\n",
"implemented in TensorFlow.\n",
"\n",
"This tutorial, and the Federated Learning API, are intended primarily for users\n",
"who want to plug their own TensorFlow models into TFF, treating the latter\n",
"mostly as a black box. For a more in-depth understanding of TFF and how to\n",
"implement your own federated learning algorithms, see the tutorials on the Federated Core (FC) API: [Custom Federated Algorithms Part 1](custom_federated_algorithms_1.ipynb) and [Part 2](custom_federated_algorithms_2.ipynb).\n",
"\n",
"For more on `tff.learning`, continue with the\n",
"[Federated Learning for Text Generation](federated_learning_for_text_generation.ipynb)\n",
"tutorial, which, in addition to covering recurrent models, also demonstrates loading a\n",
"pre-trained serialized Keras model for refinement with federated learning\n",
"combined with evaluation using Keras."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MnUwFbCAKB2r"
},
"source": [
"## Before we start\n",
"\n",
"Please run the following to make sure that your environment is\n",
"correctly set up. If you don't see a greeting, please refer to the\n",
"[Installation](../install.md) guide for instructions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZrGitA_KnRO0"
},
"outputs": [],
"source": [
"#@test {\"skip\": true}\n",
"\n",
"!pip install --quiet --upgrade tensorflow-federated"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QLyJIaLlERJ8",
"outputId": "f15a58ae-63fb-44c8-a50e-45051608ec88"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching TensorBoard MPM version 'live'... done.\n"
]
}
],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8BKyHkMxKHfV",
"outputId": "d396cd17-6b11-4ecc-a026-16140a5f6929"
},
"outputs": [
{
"data": {
"text/plain": [
"b'Hello, World!'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import collections\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_federated as tff\n",
"\n",
"np.random.seed(0)\n",
"\n",
"tff.federated_computation(lambda: 'Hello, World!')()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5Cyy2AWbLMKj"
},
"source": [
"## Preparing the input data\n",
"\n",
"Let's start with the data. Federated learning requires a federated data set,\n",
"i.e., a collection of data from multiple users. Federated data is typically\n",
"non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables),\n",
"which poses a unique set of challenges.\n",
"\n",
"In order to facilitate experimentation, we seeded the TFF repository with a few\n",
"datasets, including a federated version of MNIST built from the [original NIST dataset](https://www.nist.gov/srd/nist-special-database-19) and re-processed using [Leaf](https://github.com/TalwalkarLab/leaf) so that the data is keyed by the original writer of the digits. Since each writer has a unique style, this dataset exhibits the kind of non-i.i.d. behavior expected of federated datasets.\n",
"\n",
"Here's how we can load it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NayDhCX6SjwE"
},
"outputs": [],
"source": [
"emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yeX8BKgPfeFw"
},
"source": [
"The data sets returned by `load_data()` are instances of\n",
"`tff.simulation.ClientData`, an interface that allows you to enumerate the set\n",
"of users, to construct a `tf.data.Dataset` that represents the data of a\n",
"particular user, and to query the structure of individual elements. Here's how\n",
"you can use this interface to explore the content of the data set. Keep in mind\n",
"that while this interface allows you to iterate over client ids, this is only a\n",
"feature of the simulation data. As you will see shortly, client identities are\n",
"not used by the federated learning framework - their only purpose is to allow\n",
"you to select subsets of the data for simulations."
]
},
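To make the three capabilities above concrete, here is a minimal, framework-free sketch. It assumes a hypothetical `ToyClientData` class that stores each client's examples in a plain dict. It only mimics the ideas described above: enumerating clients, building one client's data, and inspecting the element structure. It is not the TFF `ClientData` implementation, and its method names are illustrative.

```python
import collections


class ToyClientData:
    """Hypothetical stand-in for a ClientData-style container (not the TFF class)."""

    def __init__(self, examples_by_client):
        # Maps client id -> list of examples (each an OrderedDict of features).
        self._examples = examples_by_client

    @property
    def client_ids(self):
        # Sorted so that enumeration order is deterministic.
        return sorted(self._examples)

    def create_dataset_for_client(self, client_id):
        # Return a fresh iterator over one client's examples.
        return iter(self._examples[client_id])

    @property
    def element_structure(self):
        # Describe feature names and Python types using the first stored example.
        first = next(iter(self._examples.values()))[0]
        return collections.OrderedDict(
            (name, type(value).__name__) for name, value in first.items())


data = ToyClientData({
    'writer_b': [collections.OrderedDict(label=7, pixels=[0.0] * 784)],
    'writer_a': [collections.OrderedDict(label=1, pixels=[1.0] * 784),
                 collections.OrderedDict(label=4, pixels=[0.5] * 784)],
})
print(data.client_ids)  # ['writer_a', 'writer_b']
print(next(data.create_dataset_for_client('writer_a'))['label'])  # 1
print(data.element_structure)
```

As in the real interface, the client ids are only a handle for selecting whose data to load. Nothing about the training itself depends on them.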
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kN4-U5nJgKig",
"outputId": "0bfaf626-89f6-4000-da7e-f2ad1e64f996"
},
"outputs": [
{
"data": {
"text/plain": [
"3383"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(emnist_train.client_ids)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZyCzIrSegT62",
"outputId": "e2f5753a-e212-48a0-d7c2-a7e9d43ef7e9"
},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emnist_train.element_type_structure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EsvSXGEMgd9G",
"outputId": "6f8241ca-602c-49d5-b206-a1cf8a4b13bf"
},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"example_dataset = emnist_train.create_tf_dataset_for_client(\n",
" emnist_train.client_ids[0])\n",
"\n",
"example_element = next(iter(example_dataset))\n",
"\n",
"example_element['label'].numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OmLV0nfMg98V"
},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')\n",
"plt.grid(False)\n",
"_ = plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GGnxdUp8Cj5h"
},
"source": [
"### Exploring heterogeneity in federated data\n",
"\n",
"Federated data is typically non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables): users typically have different distributions of data depending on usage patterns. Some clients may have few training examples on device, suffering from local data paucity, while others will have more than enough training examples. Let's explore this data heterogeneity, typical of a federated system, with the EMNIST data we have available. Note that this deep analysis of a client's data is only possible because this is a simulation environment where all the data is available to us locally. In a real production federated environment you would not be able to inspect a single client's data."
]
},
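One way to put a number on the label skew described above is to compare per-client label distributions. The sketch below uses made-up label lists for two hypothetical clients. It computes the total variation distance between their label histograms, where 0 means identical distributions and 1 means no labels in common. This is only an illustrative measure; it is not something TFF computes for you.

```python
import collections


def label_distribution(labels, num_classes=10):
    """Return the fraction of examples carrying each label 0..num_classes-1."""
    counts = collections.Counter(labels)
    total = len(labels)
    return [counts[c] / total for c in range(num_classes)]


def total_variation(p, q):
    """Total variation distance: half the L1 distance between two distributions."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


# Hypothetical label lists for two writers with very different habits.
client_a = [0, 0, 1, 1, 2, 2, 3, 3]
client_b = [7, 7, 7, 8, 8, 9, 9, 9]
p = label_distribution(client_a)
q = label_distribution(client_b)
print(total_variation(p, q))  # 1.0: the clients share no labels
print(total_variation(p, p))  # 0.0: identical distributions
```

Applied to `emnist_train`, you would collect `example['label'].numpy()` for each client, as the plotting code below does. Real EMNIST writers would then typically land somewhere between these two extremes.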
{
"cell_type": "markdown",
"metadata": {
"id": "77mx33vXFrqd"
},
"source": [
"First, let's grab a sampling of one client's data to get a feel for the examples on one simulated device. Because the dataset we're using has been keyed by unique writer, the data of one client represents the handwriting of one person for a sample of the digits 0 through 9, simulating the unique \"usage pattern\" of one user."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PfRva0fsFfSX"
},
"outputs": [],
"source": [
"## Example MNIST digits for one client\n",
"figure = plt.figure(figsize=(20, 4))\n",
"j = 0\n",
"\n",
"for example in example_dataset.take(40):\n",
" plt.subplot(4, 10, j+1)\n",
" plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')\n",
" plt.axis('off')\n",
" j += 1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c6wB6PggHO3g"
},
"source": [
"Now let's visualize the number of examples on each client for each MNIST digit label. In the federated environment, the number of examples on each client can vary quite a bit, depending on user behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vrjtRk5kICeN"
},
"outputs": [],
"source": [
"# Number of examples per label for a sample of clients\n",
"f = plt.figure(figsize=(12, 7))\n",
"f.suptitle('Label Counts for a Sample of Clients')\n",
"for i in range(6):\n",
" client_dataset = emnist_train.create_tf_dataset_for_client(\n",
" emnist_train.client_ids[i])\n",
" plot_data = collections.defaultdict(list)\n",
" for example in client_dataset:\n",
" # Append counts individually per label to make plots\n",
" # more colorful instead of one color per plot.\n",
" label = example['label'].numpy()\n",
" plot_data[label].append(label)\n",
" plt.subplot(2, 3, i+1)\n",
" plt.title('Client {}'.format(i))\n",
" for j in range(10):\n",
" plt.hist(\n",
" plot_data[j],\n",
" density=False,\n",
" bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B9vBNGd2I4Kn"
},
"source": [
"Now let's visualize the mean image per client for each MNIST label. This code computes the mean of each pixel value over all of a user's examples for one label. We'll see that one client's mean image for a digit looks different from another client's mean image for the same digit, due to each person's unique handwriting style. This suggests that each local training round will nudge the model in a different direction on each client, since in that round the model learns from that user's own unique data. Later in the tutorial we'll see how to take the model updates from all the clients and aggregate them into a new global model that has learned from each client's own unique data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qfkNoBCTJ5Pl"
},
"outputs": [],
"source": [
"# Each client has different mean images, meaning each client will be nudging\n",
"# the model in their own directions locally.\n",
"\n",
"for i in range(5):\n",
" client_dataset = emnist_train.create_tf_dataset_for_client(\n",
" emnist_train.client_ids[i])\n",
" plot_data = collections.defaultdict(list)\n",
" for example in client_dataset:\n",
" plot_data[example['label'].numpy()].append(example['pixels'].numpy())\n",
" f = plt.figure(i, figsize=(12, 5))\n",
" f.suptitle(\"Client #{}'s Mean Image Per Label\".format(i))\n",
" for j in range(10):\n",
" mean_img = np.mean(plot_data[j], 0)\n",
" plt.subplot(2, 5, j+1)\n",
" plt.imshow(mean_img.reshape((28, 28)))\n",
" plt.axis('off')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HpBrx5Jn7X5E"
},
"source": [
"User data can be noisy and unreliably labeled. For example, looking at Client #2's data above, we can see that for label 2, some examples may have been mislabeled, producing a noisier mean image."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U0pwnQZUKea2"
},
"source": [
"### Preprocessing the input data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lMd01egqy9we"
},
"source": [
"Since the data is already a `tf.data.Dataset`, preprocessing can be accomplished using Dataset transformations. Here, we flatten the `28x28` images\n",
"into `784`-element arrays, shuffle the individual examples, organize them into batches, and rename the features\n",
"from `pixels` and `label` to `x` and `y` for use with Keras. We also throw in a\n",
"`repeat` over the data set to run several epochs."
]
},
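For intuition about the order of the transformations in `preprocess` (repeat, then shuffle, then batch, then flatten and rename), here is a framework-free sketch on plain Python lists. Several things are assumptions for readability. The helper names `buffered_shuffle` and `preprocess_lists` are hypothetical. The "images" are tiny 2x2 grids instead of 28x28. The buffered shuffle only approximates `tf.data`'s shuffle-buffer behavior and will not reproduce its exact order.

```python
import collections
import random


def buffered_shuffle(items, buffer_size, seed):
    """Approximate a shuffle buffer: emit a random element from a sliding buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Swap a random element to the end and emit it.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    rng.shuffle(buffer)  # Drain whatever is left, in random order.
    yield from buffer


def preprocess_lists(examples, num_epochs, batch_size, shuffle_buffer, seed=1):
    """Repeat -> shuffle -> batch -> flatten/rename, on plain Python lists."""
    repeated = [ex for _ in range(num_epochs) for ex in examples]
    shuffled = list(buffered_shuffle(repeated, shuffle_buffer, seed))
    batches = []
    for start in range(0, len(shuffled), batch_size):
        chunk = shuffled[start:start + batch_size]  # The last batch may be smaller.
        batches.append(collections.OrderedDict(
            # Flatten each 2-D 'pixels' grid into one row, like reshape([-1, 784]).
            x=[[v for row in ex['pixels'] for v in row] for ex in chunk],
            # Wrap each label in its own list, like reshape([-1, 1]).
            y=[[ex['label']] for ex in chunk]))
    return batches


examples = [collections.OrderedDict(label=d, pixels=[[d, d], [d, d]])
            for d in range(3)]
batches = preprocess_lists(examples, num_epochs=2, batch_size=4, shuffle_buffer=3)
print([len(b['y']) for b in batches])  # [4, 2]
```

Because `repeat` comes before `shuffle`, examples from different epochs can be mixed together within a batch. Placing `shuffle` first would instead reshuffle each epoch separately.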
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cyG_BMraSuu_"
},
"outputs": [],
"source": [
"NUM_CLIENTS = 10\n",
"NUM_EPOCHS = 5\n",
"BATCH_SIZE = 20\n",
"SHUFFLE_BUFFER = 100\n",
"PREFETCH_BUFFER = 10\n",
"\n",
"def preprocess(dataset):\n",
"\n",
" def batch_format_fn(element):\n",
" \"\"\"Flatten a batch of `pixels` and return the features as an `OrderedDict`.\"\"\"\n",
" return collections.OrderedDict(\n",
" x=tf.reshape(element['pixels'], [-1, 784]),\n",
" y=tf.reshape(element['label'], [-1, 1]))\n",
"\n",
" return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(\n",
" BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m9LXykN_jlJw"
},
"source": [
"Let's verify this worked."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VChB7LMQjkYz",
"outputId": "18e2de32-b2f9-4ed3-f37b-11428c067f01"
},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" ...,\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.],\n",
" [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],\n",
" [1],\n",
" [5],\n",
" [7],\n",
" [1],\n",
" [7],\n",
" [7],\n",
" [1],\n",
" [4],\n",
" [7],\n",
" [4],\n",
" [2],\n",
" [2],\n",
" [5],\n",
" [4],\n",
" [1],\n",
" [1],\n",
" [0],\n",
" [0],\n",
" [9]], dtype=int32))])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessed_example_dataset = preprocess(example_dataset)\n",
"\n",
"sample_batch = tf.nest.map_structure(lambda x: x.numpy(),\n",
" next(iter(preprocessed_example_dataset)))\n",
"\n",
"sample_batch"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGsMvRQt9Agl"
},
"source": [
"We have almost all the building blocks in place to construct federated data\n",
"sets.\n",
"\n",
"One of the ways to feed federated data to TFF in a simulation is simply as a\n",
"Python list, with each element of the list holding the data of an individual\n",
"user, whether as a list or as a `tf.data.Dataset`. Since we already have\n",
"an interface that provides the latter, let's use it.\n",
"\n",
"Here's a simple helper function that will construct a list of datasets from the\n",
"given set of users as an input to a round of training or evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_PHMvHAI9xVc"
},
"outputs": [],
"source": [
"def make_federated_data(client_data, client_ids):\n",
" return [\n",
" preprocess(client_data.create_tf_dataset_for_client(x))\n",
" for x in client_ids\n",
" ]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0M9PfjOtAVqw"
},
"source": [
"Now, how do we choose clients?\n",
"\n",
"In a typical federated training scenario, we are dealing with potentially a very\n",
"large population of user devices, only a fraction of which may be available for\n",
"training at a given point in time. This is the case, for example, when the\n",
"client devices are mobile phones that participate in training only when plugged\n",
"into a power source, off a metered network, and otherwise idle.\n",
"\n",
"Of course, we are in a simulation environment, and all the data is locally\n",
"available. Typically then, when running simulations, we would simply sample a\n",
"random subset of the clients to be involved in each round of training, generally\n",
"different in each round.\n",
"\n",
"That said, as you can find out by studying the paper on the\n",
"[Federated Averaging](https://arxiv.org/abs/1602.05629) algorithm, achieving convergence in a system with randomly sampled\n",
"subsets of clients in each round can take a while, and it would be impractical\n",
"to have to run hundreds of rounds in this interactive tutorial.\n",
"\n",
"What we'll do instead is sample the set of clients once, and\n",
"reuse the same set across rounds to speed up convergence (intentionally\n",
"over-fitting to these few users' data). We leave it as an exercise for the\n",
"reader to modify this tutorial to simulate random sampling - it is fairly easy to\n",
"do (once you do, keep in mind that getting the model to converge may take a\n",
"while)."
]
},
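The per-round random sampling left as an exercise above could start from something like the following sketch. It assumes only a list of client ids; here these are made-up strings. It uses Python's `random.Random` with a fixed seed so that runs are reproducible. The helper name `sample_clients_per_round` is hypothetical.

```python
import random


def sample_clients_per_round(client_ids, clients_per_round, num_rounds, seed=0):
    """Return one random subset of client ids per round (no repeats within a round)."""
    rng = random.Random(seed)
    return [rng.sample(client_ids, clients_per_round) for _ in range(num_rounds)]


all_ids = [f'client_{i}' for i in range(100)]  # Hypothetical client population.
rounds = sample_clients_per_round(all_ids, clients_per_round=10, num_rounds=3)
for round_num, ids in enumerate(rounds, start=1):
    print(round_num, ids[:3])
```

In the notebook, you would pass `emnist_train.client_ids` instead of `all_ids`. You would then call `make_federated_data(emnist_train, ids)` for each round's `ids` before invoking the training step. Unlike the fixed sample used in this tutorial, consecutive rounds then usually see different users.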
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GZ6NYHxB8xer",
"outputId": "da185cd5-1b8b-4dda-8ea0-ccf38abb07ba"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of client datasets: 10\n",
"First dataset: <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>\n"
]
}
],
"source": [
"sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]\n",
"\n",
"federated_train_data = make_federated_data(emnist_train, sample_clients)\n",
"\n",
"print(f'Number of client datasets: {len(federated_train_data)}')\n",
"print(f'First dataset: {federated_train_data[0]}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HOxq4tbi9m8-"
},
"source": [
"## Creating a model with Keras\n",
"\n",
"If you are using Keras, you likely already have code that constructs a Keras\n",
"model. Here's an example of a simple model that will suffice for our needs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LYCsJGJFWbqt"
},
"outputs": [],
"source": [
"def create_keras_model():\n",
" return tf.keras.models.Sequential([\n",
" tf.keras.layers.InputLayer(input_shape=(784,)),\n",
" tf.keras.layers.Dense(10, kernel_initializer='zeros'),\n",
" tf.keras.layers.Softmax(),\n",
" ])"
]
},
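For intuition about what `create_keras_model` computes, the `Dense(10)` layer produces ten scores z = xW + b for a 784-element input. `Softmax` then turns those scores into probabilities. The pure-Python sketch below reproduces that forward pass for a single example. It assumes `kernel_initializer='zeros'` as above, plus the Keras default of a zero-initialized bias. Under those assumptions, every class initially gets probability 1/10.

```python
import math


def dense_softmax(x, kernel, bias):
    """Forward pass of Dense(units) followed by Softmax for a single example."""
    units = len(bias)
    # Dense: z_j = sum_i x_i * W[i][j] + b_j
    logits = [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
              for j in range(units)]
    # Softmax, shifted by the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


x = [1.0] * 784                            # one flattened 28x28 image
kernel = [[0.0] * 10 for _ in range(784)]  # kernel_initializer='zeros'
bias = [0.0] * 10                          # zero bias (Keras default)
probs = dense_softmax(x, kernel, bias)
print(probs[0])  # 0.1
```

This uniform starting point is why the untrained model's predictions carry no information. Training then moves the weights away from zero.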
{
"cell_type": "markdown",
"metadata": {
"id": "NHdraKFH4OU2"
},
"source": [
"**Note:** we do not compile the model yet. The loss, metrics, and optimizers are introduced later.\n",
"\n",
"In order to use any model with TFF, it needs to be wrapped in an instance of the\n",
"`tff.learning.models.VariableModel` interface, which exposes methods to stamp the model's\n",
"forward pass, metadata properties, etc., similarly to Keras, but also introduces\n",
"additional elements, such as ways to control the process of computing federated\n",
"metrics. Let's not worry about this for now; if you have a Keras model like the\n",
"one we've just defined above, you can have TFF wrap it for you by invoking\n",
"`tff.learning.models.from_keras_model`, passing the model, the input spec of a\n",
"preprocessed dataset, a loss, and metrics as arguments, as shown below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Q3ynrxd53HzY"
},
"outputs": [],
"source": [
"def model_fn():\n",
" # We _must_ create a new model here, and _not_ capture it from an external\n",
" # scope. TFF will call this within different graph contexts.\n",
" keras_model = create_keras_model()\n",
" return tff.learning.models.from_keras_model(\n",
" keras_model,\n",
" input_spec=preprocessed_example_dataset.element_spec,\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
" metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])"
]
},
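The loss and metric passed to `from_keras_model` are standard Keras objects. As a sanity check on what they measure, here is a plain-Python sketch of their definitions. Sparse categorical cross-entropy is the mean negative log-probability assigned to the true label. Sparse categorical accuracy is the fraction of examples whose highest-probability class equals the label. This illustrates the definitions only; it is not the Keras implementation, which handles numerical edge cases differently. With the zero-initialized model above, every prediction is uniform, so the loss on any batch starts at ln(10), about 2.303.

```python
import math


def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(p[true label]) over the batch."""
    return sum(-math.log(p[y]) for y, p in zip(labels, probs)) / len(labels)


def sparse_categorical_accuracy(labels, probs):
    """Fraction of examples whose arg-max class equals the label."""
    hits = sum(1 for y, p in zip(labels, probs)
               if max(range(len(p)), key=p.__getitem__) == y)
    return hits / len(labels)


uniform = [[0.1] * 10 for _ in range(4)]
print(round(sparse_categorical_crossentropy([2, 1, 5, 7], uniform), 4))  # 2.3026

confident = [[0.0] * 10 for _ in range(2)]
confident[0][3] = 1.0  # predicts 3, label 3: correct
confident[1][4] = 1.0  # predicts 4, label 9: wrong
print(sparse_categorical_accuracy([3, 9], confident))  # 0.5
```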
{
"cell_type": "markdown",
"metadata": {
"id": "XJ5E3O18_JZ6"
},
"source": [
"## Training the model on federated data\n",
"\n",
"Now that we have a model wrapped as `tff.learning.models.VariableModel` for use with TFF, we\n",
"can let TFF construct a Federated Averaging algorithm by invoking the helper\n",
"function `tff.learning.algorithms.build_weighted_fed_avg`, as follows.\n",
"\n",
"Keep in mind that the argument needs to be a constructor (such as `model_fn`\n",
"above), not an already-constructed instance, so that the construction of your\n",
"model can happen in a context controlled by TFF (if you're curious about the\n",
"reasons for this, we encourage you to read the follow-up tutorial on\n",
"[custom algorithms](custom_federated_algorithms_1.ipynb)).\n",
"\n",
"One critical note on the Federated Averaging algorithm below: there are **2**\n",
"optimizers: a _client_optimizer_ and a _server_optimizer_. The\n",
"_client_optimizer_ is only used to compute local model updates on each client.\n",
"The _server_optimizer_ applies the averaged update to the global model at the\n",
"server. In particular, this means that the choice of optimizer and learning rate\n",
"used may need to be different from the ones you have used to train the model on\n",
"a standard i.i.d. dataset. We recommend starting with regular SGD, possibly with\n",
"a smaller learning rate than usual. The learning rate we use has not been\n",
"carefully tuned; feel free to experiment."
]
},
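The interplay of the two optimizers can be sketched in a few lines. The sketch below is plain Python and makes several illustrative assumptions: a one-parameter least-squares model, made-up client data, and hypothetical helper names (`client_update`, `fedavg_round`). Each client runs a few steps of SGD with `client_lr`, starting from the server weight, and reports its model delta. The server forms the example-weighted average of those deltas and applies it with an SGD step of size `server_lr`. In this sketch, `server_lr=1.0` (the value used in the cell above) makes the new server weight exactly the weighted average of the client models.

```python
def client_update(w, data, client_lr, local_steps):
    """Local SGD on mean((w*x - y)^2) for one client; returns the model delta."""
    w_local = w
    for _ in range(local_steps):
        grad = sum(2 * (w_local * x - y) * x for x, y in data) / len(data)
        w_local -= client_lr * grad
    return w_local - w


def fedavg_round(w, client_datasets, client_lr=0.02, server_lr=1.0, local_steps=5):
    """One round: example-weighted average of client deltas, then a server SGD step."""
    deltas = [client_update(w, d, client_lr, local_steps) for d in client_datasets]
    counts = [len(d) for d in client_datasets]  # weight each client by its examples
    avg_delta = sum(n * dl for n, dl in zip(counts, deltas)) / sum(counts)
    # The server step treats -avg_delta as a gradient: w - lr * (-avg_delta).
    return w + server_lr * avg_delta


# Hypothetical clients whose data all follow y = 3x, with different amounts of data.
clients = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0)]]
w = 0.0
for round_num in range(1, 21):
    w = fedavg_round(w, clients)
print(round(w, 3))  # converges toward 3.0, since every client's data fits y = 3x
```

A smaller `server_lr` would damp each round's averaged update. Larger local step counts or a larger `client_lr` let clients drift further toward their own data, which matters more when clients disagree than in this toy case.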
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sk6mjOfycX5N"
},
"outputs": [],
"source": [
"training_process = tff.learning.algorithms.build_weighted_fed_avg(\n",
" model_fn,\n",
" client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),\n",
" server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f8FpvN2n67sm"
},
"source": [
"What just happened? TFF has constructed a pair of *federated computations* and\n",
"packaged them into a `tff.templates.IterativeProcess` in which these computations\n",
"are available as a pair of properties `initialize` and `next`.\n",
"\n",
"In a nutshell, *federated computations* are programs in TFF's internal language\n",
"that can express various federated algorithms (you can find more about this in\n",
"the [custom algorithms](custom_federated_algorithms_1.ipynb) tutorial). In this\n",
"case, the two computations generated and packed into `training_process`\n",
"implement [Federated Averaging](https://arxiv.org/abs/1602.05629).\n",
"\n",
"It is a goal of TFF to define computations in a way that they could be executed\n",
"in real federated learning settings, but currently only a local execution\n",
"simulation runtime is implemented. To execute a computation in a simulator, you\n",
"simply invoke it like a Python function. This default interpreted environment is\n",
"not designed for high performance, but it will suffice for this tutorial; we\n",
"expect to provide higher-performance simulation runtimes to facilitate\n",
"larger-scale research in future releases.\n",
"\n",
"Let's start with the `initialize` computation. As is the case for all federated\n",
"computations, you can think of it as a function. The computation takes no\n",
"arguments, and returns one result - the representation of the state of the\n",
"Federated Averaging process on the server. While we don't want to dive into the\n",
"details of TFF, it may be instructive to see what this state looks like. You can\n",
"visualize it as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z4pcfWsUBp_5",
"outputId": "2f4e706a-55fb-4662-f317-04312d2a992d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"( -> <\n",
" global_model_weights=<\n",
" trainable=<\n",
" float32[784,10],\n",
" float32[10]\n",
" >,\n",
" non_trainable=<>\n",
" >,\n",
" distributor=<>,\n",
" client_work=<>,\n",
" aggregator=<\n",
" value_sum_process=<>,\n",
" weight_sum_process=<>\n",
" >,\n",
" finalizer=<\n",
" int64,\n",
" float32[784,10],\n",
" float32[10]\n",
" >\n",
">@SERVER)\n"
]
}
],
"source": [
"print(training_process.initialize.type_signature.formatted_representation())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v1gbHQ_7BiyT"
},
"source": [
"While the above type signature may at first seem a bit cryptic, you can\n",
"recognize that the server state consists of `global_model_weights` (the initial model parameters for MNIST that will be distributed to all devices), some empty parameters (like `distributor`, which governs the server-to-client communication), and a `finalizer` component. This last one governs the logic that the server uses to update its model at the end of a round, and contains an integer representing how many rounds of FedAvg have occurred.\n",
"\n",
"Let's invoke the `initialize` computation to construct the server state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6cagCWlZmcch"
},
"outputs": [],
"source": [
"train_state = training_process.initialize()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TjjxTx9e_rMd"
},
"source": [
"The second of the pair of federated computations, `next`, represents a single\n",
"round of Federated Averaging, which consists of pushing the server state\n",
"(including the model parameters) to the clients, on-device training on their\n",
"local data, collecting and averaging model updates, and producing a new updated\n",
"model at the server.\n",
"\n",
"Conceptually, you can think of `next` as having a functional type signature that\n",
"looks as follows.\n",
"\n",
"```\n",
"SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS\n",
"```\n",
"\n",
"In particular, one should think about `next()` not as being a function that runs on a server, but rather being a declarative functional representation of the entire decentralized computation - some of the inputs are provided by the server (`SERVER_STATE`), but each participating device contributes its own local dataset.\n",
"\n",
"Let's run a single round of training and visualize the results. We can use the\n",
"federated data we've already generated above for a sample of users."
]
},
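The `initialize`/`next` contract, `SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS`, can be mimicked with two plain Python functions. This sketch shows only the shape of the computation. It assumes a hypothetical scalar "model" and a round counter in the state, loosely mirroring the integer in the `finalizer` state printed earlier. The `ServerState`/`RoundOutput` containers and `next_round` are illustrative names; the real TFF computations are declarative and run under the TFF runtime, not as ordinary Python.

```python
import collections

# Hypothetical containers mirroring the `result.state` / `result.metrics` names.
ServerState = collections.namedtuple('ServerState', ['weight', 'round_num'])
RoundOutput = collections.namedtuple('RoundOutput', ['state', 'metrics'])


def initialize():
    """Takes no arguments and returns the initial server state."""
    return ServerState(weight=0.0, round_num=0)


def next_round(state, federated_data):
    """SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS."""
    # Stand-in for local training: each client moves the server weight to the
    # mean of its own values and reports the resulting delta.
    deltas = [sum(d) / len(d) - state.weight for d in federated_data]
    counts = [len(d) for d in federated_data]
    # Server side: example-weighted average of the client deltas.
    avg_delta = sum(n * dl for n, dl in zip(counts, deltas)) / sum(counts)
    new_state = ServerState(weight=state.weight + avg_delta,
                            round_num=state.round_num + 1)
    metrics = collections.OrderedDict(num_examples=sum(counts))
    return RoundOutput(state=new_state, metrics=metrics)


state = initialize()
federated_data = [[1.0, 2.0, 3.0], [10.0]]  # two hypothetical clients
for _ in range(3):
    result = next_round(state, federated_data)
    state = result.state
print(state.round_num, state.weight, result.metrics['num_examples'])  # 3 4.0 4
```

The server state is the only thing carried from round to round. Each round's data is passed in fresh, which is what makes it easy to give every round a different set of clients.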
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F3M_W9dDE6Tm",
"outputId": "e3e64677-474b-4f08-9ef5-cb563e0bfbc9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193733), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n"
]
}
],
"source": [
"result = training_process.next(train_state, federated_train_data)\n",
"train_state = result.state\n",
"train_metrics = result.metrics\n",
"print('round 1, metrics={}'.format(train_metrics))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UmhReXt9G4A5"
},
"source": [
"Let's run a few more rounds. As noted earlier, typically at this point you would\n",
"pick a subset of your simulation data from a new randomly selected sample of\n",
"users for each round in order to simulate a realistic deployment in which users\n",
"continuously come and go, but in this interactive notebook, for the sake of\n",
"demonstration we'll just reuse the same users, so that the system converges\n",
"quickly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qrJkQuCRJP9C",
"outputId": "c554c25e-ac05-4332-b127-45fffbff9882"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"round 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.14012346), ('loss', 2.9851403), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.1590535), ('loss', 2.8617127), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.17860082), ('loss', 2.7401376), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20102881), ('loss', 2.6186547), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.22345679), ('loss', 2.5006158), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.24794239), ('loss', 2.3858356), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.27160493), ('loss', 2.2757034), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.2958848), ('loss', 2.17098), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3251029), ('loss', 2.072707), ('num_examples', 4860), ('num_batches', 248)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n"
]
}
],
"source": [
"NUM_ROUNDS = 11\n",
"for round_num in range(2, NUM_ROUNDS):\n",
" result = training_process.next(train_state, federated_train_data)\n",
" train_state = result.state\n",
" train_metrics = result.metrics\n",
" print('round {:2d}, metrics={}'.format(round_num, train_metrics))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "joHYzn9jcs0Y"
},
"source": [
"Training loss is decreasing after each round of federated training, indicating\n",
"the model is converging. There are some important caveats with these training\n",
"metrics, however; see the section on *Evaluation* later in this tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ruSHJl1IjhNf"
},
"source": [
"## Displaying model metrics in TensorBoard\n",
"Next, let's visualize the metrics from these federated computations using TensorBoard.\n",
"\n",
"Let's start by creating the directory and the corresponding summary writer to write the metrics to."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E3QUBK41lWDW"
},
"outputs": [],
"source": [
"#@test {\"skip\": true}\n",
"logdir = \"/tmp/logs/scalars/training/\"\n",
"try:\n",
" tf.io.gfile.rmtree(logdir) # delete any previous results\n",
"except tf.errors.NotFoundError:\n",
" pass # Ignore if the directory didn't previously exist.\n",
"summary_writer = tf.summary.create_file_writer(logdir)\n",
"train_state = training_process.initialize()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w-2aGxUlzS_J"
},
"source": [
"Write the relevant scalar metrics with the same summary writer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JZtr4_8lzN-V"
},
"outputs": [],
"source": [
"#@test {\"skip\": true}\n",
"with summary_writer.as_default():\n",
" for round_num in range(1, NUM_ROUNDS):\n",
" result = training_process.next(train_state, federated_train_data)\n",
" train_state = result.state\n",
" train_metrics = result.metrics\n",
" for name, value in train_metrics['client_work']['train'].items():\n",
" tf.summary.scalar(name, value, step=round_num)"
]
},
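The loop above logs only the entries under `train_metrics['client_work']['train']`. If you also want the other entries, one option is to first flatten the nested `OrderedDict` into slash-separated tag names. The helper below is a hypothetical sketch, not part of TFF. It skips empty placeholders such as `distributor=()`. The sample `metrics` value is copied from the round 1 output printed earlier.

```python
import collections
import collections.abc


def flatten_metrics(metrics, prefix=''):
    """Yield ('a/b/c', value) pairs for every non-empty leaf of a nested mapping."""
    for key, value in metrics.items():
        name = f'{prefix}/{key}' if prefix else key
        if isinstance(value, collections.abc.Mapping):
            yield from flatten_metrics(value, name)
        elif isinstance(value, (tuple, list)) and not value:
            continue  # Skip empty placeholders such as distributor=().
        else:
            yield name, value


# Structure copied from the round 1 metrics printed earlier in this notebook.
metrics = collections.OrderedDict(
    distributor=(),
    client_work=collections.OrderedDict(
        train=collections.OrderedDict(
            sparse_categorical_accuracy=0.12345679, loss=3.1193733,
            num_examples=4860, num_batches=248)),
    aggregator=collections.OrderedDict(mean_value=(), mean_weight=()),
    finalizer=collections.OrderedDict(update_non_finite=0))

for name, value in flatten_metrics(metrics):
    print(name, value)
```

Inside the `with summary_writer.as_default():` loop, you could then log every entry with `for name, value in flatten_metrics(train_metrics): tf.summary.scalar(name, value, step=round_num)`.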
{
"cell_type": "markdown",
"metadata": {
"id": "iUouyAHG0Mk8"
},
"source": [
"Start TensorBoard with the root log directory specified above. It can take a few seconds for the data to load."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "urYYcmA9089p"
},
"outputs": [],
"source": [
"#@test {\"skip\": true}\n",
"!ls {logdir}\n",
"%tensorboard --logdir {logdir} --port=0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZMcV15W7b1wG"
},
"outputs": [],
"source": [
"#@test {\"skip\": true}\n",
"# Uncomment and run this cell to clean your directory of old output for\n",
"# future graphs from this directory. We don't run it by default so that if \n",
"# you do a \"Runtime > Run all\" you don't lose your results.\n",
"\n",
"# !rm -R /tmp/logs/scalars/*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jejrFEVP1EDs"
},
"source": [
"To view evaluation metrics the same way, write them with a separate summary writer to a separate directory, such as `/tmp/logs/scalars/eval`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T4hneAcb-F2l"
},
"source": [
"## Customizing the model implementation\n",
"\n",
"Keras is the [recommended high-level model API for TensorFlow](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a), and we encourage using Keras models (via \n",
"`tff.learning.models.from_keras_model`) in TFF whenever possible.\n",
"\n",
"However, `tff.learning` provides a lower-level model interface, `tff.learning.models.VariableModel`, that exposes the minimal functionality necessary for using a model for federated learning. Directly implementing this interface (possibly still using building blocks like `tf.keras.layers`) allows for maximum customization without modifying the internals of the federated learning algorithms.\n",
"\n",
"So let's do it all over again from scratch.\n",
"\n",
"### Defining model variables, forward pass, and metrics\n",
"\n",
"The first step is to identify the TensorFlow variables we're going to work with.\n",
"In order to make the following code more legible, let's define a data structure\n",
"to represent the entire set. This will include variables such as `weights` and\n",
"`bias` that we will train, as well as variables that will hold various\n",
"cumulative statistics and counters we will update during training, such as\n",
"`loss_sum`, `accuracy_sum`, and `num_examples`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uqRD72WQC4u1"
},
"outputs": [],
"source": [
"MnistVariables = collections.namedtuple(\n",
" 'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nkJfDcY5oXii"
},
"source": [
"Here's a method that creates the variables. For the sake of simplicity, we\n",
"represent all statistics as `tf.float32`, as that will eliminate the need for\n",
"type conversions at a later stage. Wrapping variable initializers as lambdas is\n",
"a requirement imposed by\n",
"[resource variables](https://www.tensorflow.org/api_docs/python/tf/enable_resource_variables)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H3GQHLNqCfMU"
},
"outputs": [],
"source": [
"def create_mnist_variables():\n",
"  return MnistVariables(\n",
"      weights=tf.Variable(\n",
"          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),\n",
"          name='weights',\n",
"          trainable=True),\n",
"      bias=tf.Variable(\n",
"          lambda: tf.zeros(dtype=tf.float32, shape=(10,)),\n",
"          name='bias',\n",
"          trainable=True),\n",
"      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),\n",
"      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),\n",
"      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SrdnR0fAre-Q"
},
"source": [
"With the variables for model parameters and cumulative statistics in place, we\n",
"can now define the forward pass method that computes loss, emits predictions,\n",
"and updates the cumulative statistics for a single batch of input data, as\n",
"follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZYSRAl-KCvC7"
},
"outputs": [],
"source": [
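"def predict_on_batch(variables, x):\n",
"  # A single dense layer followed by softmax: [batch, 784] -> [batch, 10].\n",
"  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)\n",
"\n",
"def mnist_forward_pass(variables, batch):\n",
"  y = predict_on_batch(variables, batch['x'])\n",
"  # The predicted class is the index of the largest probability.\n",
"  predictions = tf.cast(tf.argmax(y, 1), tf.int32)\n",
"\n",
"  # Labels arrive with shape [batch, 1]; flatten them to [batch].\n",
"  flat_labels = tf.reshape(batch['y'], [-1])\n",
"  # Mean categorical cross-entropy over the batch.\n",
"  loss = -tf.reduce_mean(\n",
"      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))\n",
"  accuracy = tf.reduce_mean(\n",
"      tf.cast(tf.equal(predictions, flat_labels), tf.float32))\n",
"\n",
"  num_examples = tf.cast(tf.size(batch['y']), tf.float32)\n",
"\n",
"  # Accumulate example-weighted sums so that per-batch means can later be\n",
"  # recombined into an exact per-example mean across batches.\n",
"  variables.num_examples.assign_add(num_examples)\n",
"  variables.loss_sum.assign_add(loss * num_examples)\n",
"  variables.accuracy_sum.assign_add(accuracy * num_examples)\n",
"\n",
"  return loss, predictions"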
]
},
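{
"cell_type": "markdown",
"metadata": {},
"source": [
"The example-weighted sums in `mnist_forward_pass` matter when batches have different sizes. The plain-Python sketch below (made-up numbers, no TensorFlow or TFF involved) mirrors those accumulators and shows that `loss_sum / num_examples` recovers the mean loss over all examples, whereas averaging the per-batch means does not.\n",
"\n",
"```python\n",
"# Hypothetical per-batch results: (mean loss over the batch, batch size).\n",
"batches = [(2.0, 20), (1.0, 5)]\n",
"\n",
"# Plain-Python stand-ins for the num_examples and loss_sum variables.\n",
"num_examples = 0.0\n",
"loss_sum = 0.0\n",
"for batch_loss, batch_size in batches:\n",
"  num_examples += batch_size\n",
"  loss_sum += batch_loss * batch_size\n",
"\n",
"weighted_mean_loss = loss_sum / num_examples  # 45.0 / 25.0\n",
"naive_mean_loss = sum(loss for loss, _ in batches) / len(batches)\n",
"print(weighted_mean_loss, naive_mean_loss)\n",
"```\n",
"\n",
"This prints `1.8 1.5`: the weighted value is the true mean over all 25 examples, while the naive average gives the 5-example batch as much influence as the 20-example one."
]
},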
{
"cell_type": "markdown",
"metadata": {
"id": "-gm-yx2Mr_bl"
},
"source": [
"Next, we define two functions that are related to local metrics, again using TensorFlow.\n",
"\n",
"The first function `get_local_unfinalized_metrics` returns the unfinalized metric values (in addition to model updates, which are handled automatically) that are eligible to be aggregated to the server in a federated learning or evaluation process."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RkAZXhjGEekp"
},
"outputs": [],
"source": [
"def get_local_unfinalized_metrics(variables):\n",
" return collections.OrderedDict(\n",
" num_examples=[variables.num_examples],\n",
" loss=[variables.loss_sum, variables.num_examples],\n",
" accuracy=[variables.accuracy_sum, variables.num_examples])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p-yS2g9nJQwe"
},
"source": [
"The second function `get_metric_finalizers` returns an `OrderedDict` of `tf.function`s with the same keys (i.e., metric names) as `get_local_unfinalized_metrics`. Each `tf.function` takes in the metric's unfinalized values and computes the finalized metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y0f_Hc4sJTo6"
},
"outputs": [],
"source": [
"def get_metric_finalizers():\n",
" return collections.OrderedDict(\n",
" num_examples=tf.function(func=lambda x: x[0]),\n",
" loss=tf.function(func=lambda x: x[0] / x[1]),\n",
" accuracy=tf.function(func=lambda x: x[0] / x[1]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tqnmjV3zJaeC"
},
"source": [
"How the local unfinalized metrics returned by `get_local_unfinalized_metrics` are aggregated across clients is specified by the `metrics_aggregator` parameter when defining the federated learning or evaluation processes. For example, in the [`tff.learning.algorithms.build_weighted_fed_avg`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/algorithms/build_weighted_fed_avg) API (shown in the next section), the default value for `metrics_aggregator` is [`tff.learning.metrics.sum_then_finalize`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/metrics/sum_then_finalize), which first sums the unfinalized metrics from `CLIENTS`, and then applies the metric finalizers at `SERVER`."
]
},
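{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the sum-then-finalize behavior concrete, here is a plain-Python sketch with made-up client values. It sums each client's unfinalized `[sum, count]` lists element-wise and then applies finalizers shaped like those in `get_metric_finalizers`. The function `sketch_sum_then_finalize` is hypothetical: it is a simplified model of the semantics described above, not TFF's implementation of `tff.learning.metrics.sum_then_finalize`.\n",
"\n",
"```python\n",
"import collections\n",
"\n",
"# Made-up unfinalized metrics from two clients, using the same layout as\n",
"# get_local_unfinalized_metrics: [sum, count] for loss and accuracy.\n",
"client_metrics = [\n",
"    collections.OrderedDict(\n",
"        num_examples=[10.0], loss=[20.0, 10.0], accuracy=[5.0, 10.0]),\n",
"    collections.OrderedDict(\n",
"        num_examples=[30.0], loss=[30.0, 30.0], accuracy=[27.0, 30.0]),\n",
"]\n",
"\n",
"# Plain-Python stand-ins for the finalizers returned by get_metric_finalizers.\n",
"finalizers = collections.OrderedDict(\n",
"    num_examples=lambda x: x[0],\n",
"    loss=lambda x: x[0] / x[1],\n",
"    accuracy=lambda x: x[0] / x[1])\n",
"\n",
"\n",
"def sketch_sum_then_finalize(metrics_per_client, finalizers):\n",
"  # Step 1: sum each metric's unfinalized values element-wise across clients.\n",
"  summed = collections.OrderedDict()\n",
"  for name in finalizers:\n",
"    per_client = [m[name] for m in metrics_per_client]\n",
"    summed[name] = [sum(parts) for parts in zip(*per_client)]\n",
"  # Step 2: apply each finalizer once to the summed values (in TFF, at SERVER).\n",
"  return collections.OrderedDict(\n",
"      (name, fn(summed[name])) for name, fn in finalizers.items())\n",
"\n",
"\n",
"result = sketch_sum_then_finalize(client_metrics, finalizers)\n",
"print(result)\n",
"```\n",
"\n",
"With these values the result is `num_examples=40.0`, `loss=1.25` and `accuracy=0.8`. Averaging the two clients' finalized values instead would give a loss of 1.5 and an accuracy of 0.7, weighting the 10-example client as heavily as the 30-example one; storing numerator and denominator separately is what makes the aggregate an example-weighted mean."
]
},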
{
"cell_type": "markdown",
"metadata": {
"id": "7MXGAuQRvmcp"
},
"source": [
"### Constructing an instance of `tff.learning.models.VariableModel`\n",
"\n",
"With all of the above in place, we are ready to construct a model representation\n",
"for use with TFF similar to one that's generated for you when you let TFF ingest\n",
"a Keras model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "blQGiTQFS9_r"
},
"outputs": [],
"source": [
"import collections\n",
"from collections.abc import Callable\n",
"\n",
"class MnistModel(tff.learning.models.VariableModel):\n",
"\n",
"  def __init__(self):\n",
"    self._variables = create_mnist_variables()\n",
"\n",
"  @property\n",
"  def trainable_variables(self):\n",
"    return [self._variables.weights, self._variables.bias]\n",
"\n",
"  @property\n",
"  def non_trainable_variables(self):\n",
"    return []\n",
"\n",
"  @property\n",
"  def local_variables(self):\n",
"    return [\n",
"        self._variables.num_examples, self._variables.loss_sum,\n",
"        self._variables.accuracy_sum\n",
"    ]\n",
"\n",
"  @property\n",
"  def input_spec(self):\n",
"    return collections.OrderedDict(\n",
"        x=tf.TensorSpec([None, 784], tf.float32),\n",
"        y=tf.TensorSpec([None, 1], tf.int32))\n",
"\n",
"  @tf.function\n",
"  def predict_on_batch(self, x, training=True):\n",
"    del training\n",
"    return predict_on_batch(self._variables, x)\n",
"\n",
"  @tf.function\n",
"  def forward_pass(self, batch, training=True):\n",
"    del training\n",
"    loss, predictions = mnist_forward_pass(self._variables, batch)\n",
"    num_examples = tf.shape(batch['x'])[0]\n",
"    return tff.learning.models.BatchOutput(\n",
"        loss=loss, predictions=predictions, num_examples=num_examples)\n",
"\n",
"  @tf.function\n",
"  def report_local_unfinalized_metrics(\n",
"      self) -> collections.OrderedDict[str, list[tf.Tensor]]:\n",
"    \"\"\"Creates an `OrderedDict` of metric names to unfinalized values.\"\"\"\n",
"    return get_local_unfinalized_metrics(self._variables)\n",
"\n",
"  def metric_finalizers(\n",
"      self) -> collections.OrderedDict[str, Callable[[list[tf.Tensor]], tf.Tensor]]:\n",
"    \"\"\"Creates an `OrderedDict` of metric names to finalizers.\"\"\"\n",
"    return get_metric_finalizers()\n",
"\n",
"  @tf.function\n",
"  def reset_metrics(self):\n",
"    \"\"\"Resets metrics variables to initial value.\"\"\"\n",
"    for var in self.local_variables:\n",
"      var.assign(tf.zeros_like(var))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sMN1AszMwLHL"
},
"source": [
"As you can see, the abstract methods and properties defined by\n",
"`tff.learning.models.VariableModel` correspond to the code snippets in the preceding section\n",
"that introduced the variables and defined the loss and statistics.\n",
"\n",
"Here are a few points worth highlighting:\n",
"\n",
"* All state that your model will use must be captured as TensorFlow variables,\n",
" as TFF does not use Python at runtime (remember your code should be written\n",
" such that it can be deployed to mobile devices; see the\n",
" [custom algorithms](custom_federated_algorithms_1.ipynb) tutorial for a more\n",
" in-depth commentary on the reasons).\n",
"* Your model should describe what form of data it accepts (`input_spec`), as\n",
" in general, TFF is a strongly-typed environment and wants to determine type\n",
" signatures for all components. Declaring the format of your model's input is\n",
" an essential part of it.\n",
"* Although technically not required, we recommend wrapping all TensorFlow\n",
" logic (forward pass, metric calculations, etc.) as `tf.function`s,\n",
" as this helps ensure the TensorFlow code can be serialized, and removes the need\n",
" for explicit control dependencies."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9DVhXk2Bu-GU"
},
"source": [
"The above is sufficient for evaluation and algorithms like Federated SGD.\n",
"However, for Federated Averaging, we need to specify how the model should train\n",
"locally on each batch. We will specify a local optimizer when building the Federated Averaging algorithm."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hVBugKP3yw03"
},
"source": [
"### Simulating federated training with the new model\n",
"\n",
"With all the above in place, the remainder of the process looks like what we've\n",
"seen already - just replace the model constructor with the constructor of our\n",
"new model class, and use the two federated computations in the iterative process\n",
"you created to cycle through training rounds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FK3c8_leS9_t"
},
"outputs": [],
"source": [
"training_process = tff.learning.algorithms.build_weighted_fed_avg(\n",
" MnistModel,\n",
" client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Jv_LiggwS9_u"
},
"outputs": [],
"source": [
"train_state = training_process.initialize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PtOLElmzDPxs",
"outputId": "67b45f5b-fc20-48ce-93f1-41ab1cf19386"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.119374), ('accuracy', 0.12345679)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n"
]
}
],
"source": [
"result = training_process.next(train_state, federated_train_data)\n",
"train_state = result.state\n",
"metrics = result.metrics\n",
"print('round 1, metrics={}'.format(metrics))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gFkv0yJEGhue",
"outputId": "c4101590-4b0e-4b04-d0cd-ccfe7a6aa9b9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"round 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.98514), ('accuracy', 0.14012346)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.8617127), ('accuracy', 0.1590535)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.740137), ('accuracy', 0.17860082)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6186547), ('accuracy', 0.20102881)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5006158), ('accuracy', 0.22345679)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3858361), ('accuracy', 0.24794239)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.275704), ('accuracy', 0.27160493)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1709805), ('accuracy', 0.2958848)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n",
"round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.0727067), ('accuracy', 0.3251029)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])\n"
]
}
],
"source": [
"for round_num in range(2, 11):\n",
" result = training_process.next(train_state, federated_train_data)\n",
" train_state = result.state\n",
" metrics = result.metrics\n",
" print('round {:2d}, metrics={}'.format(round_num, metrics))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Iswqa2Uj7phq"
},
"source": [
"To see these metrics within TensorBoard, refer to the steps listed above in *Displaying model metrics in TensorBoard*."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m7lz59lMJ0kj"
},
"source": [
"## Evaluation\n",
"\n",
"All of our experiments so far presented only federated training metrics - the\n",
"average metrics over all batches of data trained across all clients in the\n",
"round. This introduces the normal concerns about overfitting, especially since\n",
"we used the same set of clients on each round for simplicity, but there is an\n",
"additional notion of overfitting in training metrics specific to the Federated\n",
"Averaging algorithm. This is easiest to see if we imagine each client had a\n",
"single batch of data, and we train on that batch for many iterations (epochs).\n",
"In this case, the local model will quickly fit that one batch exactly, and so\n",
"the local accuracy metric we average will approach 1.0. Thus, these training\n",
"metrics can be taken as a sign that training is progressing, but not much more.\n",
"\n",
"To perform evaluation on federated data, you can construct another *federated\n",
"computation* designed for just this purpose, using the\n",
"`tff.learning.algorithms.build_fed_eval` function and passing in your model\n",
"constructor as an argument. Unlike Federated Averaging, evaluation doesn't\n",
"perform gradient descent, so there's no need to construct optimizers; passing\n",
"`MnistModel` suffices.\n",
"\n",
"For experimentation and research, when a centralized test dataset is available,\n",
"[Federated Learning for Text Generation](federated_learning_for_text_generation.ipynb)\n",
"demonstrates another evaluation option: taking the trained weights from\n",
"federated learning, applying them to a standard Keras model, and then simply\n",
"calling `tf.keras.models.Model.evaluate()` on a centralized dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nRiXyqnXM2VO"
},
"outputs": [],
"source": [
"evaluation_process = tff.learning.algorithms.build_fed_eval(MnistModel)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uwfINGoNQEuV"
},
"source": [
"You can inspect the abstract type signature of the evaluation process's `next` computation as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3q5ueoO0NDNb",
"outputId": "df33bf7a-1a11-4183-e3cf-adcf96cd641b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(<\n",
" state=<\n",
" global_model_weights=<\n",
" trainable=<\n",
" float32[784,10],\n",
" float32[10]\n",
" >,\n",
" non_trainable=<>\n",
" >,\n",
" distributor=<>,\n",
" client_work=<\n",
" <>,\n",
" <\n",
" num_examples=<\n",
" float32\n",
" >,\n",
" loss=<\n",
" float32,\n",
" float32\n",
" >,\n",
" accuracy=<\n",
" float32,\n",
" float32\n",
" >\n",
" >\n",
" >,\n",
" aggregator=<\n",
" value_sum_process=<>,\n",
" weight_sum_process=<>\n",
" >,\n",
" finalizer=<>\n",
" >@SERVER,\n",
" client_data={<\n",
" x=float32[?,784],\n",
" y=int32[?,1]\n",
" >*}@CLIENTS\n",
"> -> <\n",
" state=<\n",
" global_model_weights=<\n",
" trainable=<\n",
" float32[784,10],\n",
" float32[10]\n",
" >,\n",
" non_trainable=<>\n",
" >,\n",
" distributor=<>,\n",
" client_work=<\n",
" <>,\n",
" <\n",
" num_examples=<\n",
" float32\n",
" >,\n",
" loss=<\n",
" float32,\n",
" float32\n",
" >,\n",
" accuracy=<\n",
" float32,\n",
" float32\n",
" >\n",
" >\n",
" >,\n",
" aggregator=<\n",
" value_sum_process=<>,\n",
" weight_sum_process=<>\n",
" >,\n",
" finalizer=<>\n",
" >@SERVER,\n",
" metrics=<\n",
" distributor=<>,\n",
" client_work=<\n",
" eval=<\n",
" current_round_metrics=<\n",
" num_examples=float32,\n",
" loss=float32,\n",
" accuracy=float32\n",
" >,\n",
" total_rounds_metrics=<\n",
" num_examples=float32,\n",
" loss=float32,\n",
" accuracy=float32\n",
" >\n",
" >\n",
" >,\n",
" aggregator=<\n",
" mean_value=<>,\n",
" mean_weight=<>\n",
" >,\n",
" finalizer=<>\n",
" >@SERVER\n",
">)\n"
]
}
],
"source": [
"print(evaluation_process.next.type_signature.formatted_representation())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XA3v7f2SQs6q"
},
"source": [
"Be aware that the evaluation process is a `tff.learning.templates.LearningProcess` object. The object has an `initialize` method that creates the state, but this state initially contains an untrained model. Use the `set_model_weights` method to insert the weights from the training state that you want to evaluate."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OX4Sk_uyOaYa"
},
"outputs": [],
"source": [
"evaluation_state = evaluation_process.initialize()\n",
"model_weights = training_process.get_model_weights(train_state)\n",
"evaluation_state = evaluation_process.set_model_weights(evaluation_state, model_weights)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F5H66kcZRMBB"
},
"source": [
"Now with the evaluation state containing the model weights to be evaluated, we can compute evaluation metrics using evaluation datasets by calling the `next` method on the process, just like in training.\n",
"\n",
"This will again return a `tff.learning.templates.LearningProcessOutput` instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kT53YGdkRccR"
},
"outputs": [],
"source": [
"evaluation_output = evaluation_process.next(evaluation_state, federated_train_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UeEsdwJgRGMW"
},
"source": [
"Here's what we get. Note that the numbers look somewhat better than what was\n",
"reported by the last round of training above. By convention, the training\n",
"metrics reported by the iterative training process generally reflect the\n",
"performance of the model at the beginning of the training round, so the\n",
"evaluation metrics reflect one additional round of training."
]
},
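{
"cell_type": "markdown",
"metadata": {},
"source": [
"The one-round lag can be seen in a toy setting. The plain-Python sketch below uses a hypothetical one-parameter model with loss `(w - 3)**2` and a made-up learning rate (not MNIST, and not how TFF computes its metrics). It records the training metric from the weights at the start of each round, following the convention described above, and then evaluates the final weights.\n",
"\n",
"```python\n",
"def loss(w):\n",
"  # Toy quadratic loss with its minimum at w = 3.\n",
"  return (w - 3.0) ** 2\n",
"\n",
"def grad(w):\n",
"  # Derivative of the toy loss with respect to w.\n",
"  return 2.0 * (w - 3.0)\n",
"\n",
"w = 0.0\n",
"learning_rate = 0.25\n",
"train_losses = []\n",
"for _ in range(3):  # Three training rounds.\n",
"  # The metric reported for this round uses the weights at round start.\n",
"  train_losses.append(loss(w))\n",
"  w -= learning_rate * grad(w)\n",
"\n",
"# Evaluation runs on the weights produced by the final round.\n",
"eval_loss = loss(w)\n",
"print(train_losses, eval_loss)\n",
"```\n",
"\n",
"Here the metric reported for the last round is 0.5625, while evaluating the final weights gives 0.140625, because evaluation sees the result of the last update."
]
},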
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zwCy1IPxOfiT",
"outputId": "356887df-7531-4ea4-d0ca-b57d7dd34a41"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"\"OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('num_examples', 4860.0), ('loss', 1.6654209), ('accuracy', 0.3621399)])), ('total_rounds_metrics', OrderedDict([('num_examples', 4860.0), ('loss', 1.6654209), ('accuracy', 0.3621399)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])\""
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"str(evaluation_output.metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SpfgdNDoRjPy"
},
"source": [
"Now, let's compile a test sample of federated data and rerun evaluation on the\n",
"test data. The data will come from the same sample of real users, but from a\n",
"distinct held-out data set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "in8vProVNc04",
"outputId": "2d1d1487-f567-4658-e848-8ae7e63c6a72"
},
"outputs": [
{
"data": {
"text/plain": [
"(10,\n",
" <_PrefetchDataset element_spec=OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))])>)"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"federated_test_data = make_federated_data(emnist_test, sample_clients)\n",
"\n",
"len(federated_test_data), federated_test_data[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ty-ZwfE0NJfV"
},
"outputs": [],
"source": [
"evaluation_output = evaluation_process.next(evaluation_state, federated_test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e5fGtIJYNqYH",
"outputId": "8739dabe-4fcc-47d7-9da1-f2869eac6cbb"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"\"OrderedDict([('distributor', ()), ('client_work', OrderedDict([('eval', OrderedDict([('current_round_metrics', OrderedDict([('num_examples', 580.0), ('loss', 1.7750846), ('accuracy', 0.33620688)])), ('total_rounds_metrics', OrderedDict([('num_examples', 580.0), ('loss', 1.7750846), ('accuracy', 0.33620688)]))]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])\""
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"str(evaluation_output.metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "67vYxrDWzRcj"
},
"source": [
"This concludes the tutorial. We encourage you to play with the\n",
"parameters (e.g., batch sizes, number of users, epochs, learning rates, etc.), to modify the code above to simulate training on random samples of users in\n",
"each round, and to explore the other tutorials we've developed."
]
}
],
"metadata": {
"colab": {
"name": "federated_learning_for_image_classification.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}