Before you start
Before you start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.
pip install --quiet --upgrade tensorflow-federated
from collections.abc import Callable

import tensorflow as tf
import tensorflow_federated as tff
Composing Learning Algorithms
The Building Your Own Federated Learning Algorithm Tutorial used TFF's federated core to directly implement a version of the Federated Averaging (FedAvg) algorithm.
In this tutorial, you will use federated learning components in TFF's API to build federated learning algorithms in a modular manner, without having to re-implement everything from scratch.
For the purposes of this tutorial, you will implement a variant of FedAvg that employs gradient clipping during local training.
Learning Algorithm Building Blocks
At a high level, many learning algorithms can be separated into four components, referred to as building blocks:
- Distributor (i.e., server-to-client communication)
- Client work (i.e., local client computation)
- Aggregator (i.e., client-to-server communication)
- Finalizer (i.e., server computation using aggregated client outputs)
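The same decomposition can be illustrated without TFF. The following is a minimal, framework-free sketch, not TFF's API: compose_round and the four toy block functions are hypothetical, the "model" is a single float, and each client's "training" simply jumps to the mean of its local data.

```python
from typing import Any, Callable, List, Sequence


def compose_round(
    distributor: Callable[[Any, int], List[Any]],
    client_work: Callable[[Any, Any], Any],
    aggregator: Callable[[Sequence[Any]], Any],
    finalizer: Callable[[Any, Any], Any],
) -> Callable[[Any, Sequence[Any]], Any]:
  """Wires four (hypothetical) building blocks into one training round."""

  def run_round(server_weights: Any, client_data: Sequence[Any]) -> Any:
    # Distributor: server-to-client communication.
    client_weights = distributor(server_weights, len(client_data))
    # Client work: independent local computation on each client.
    client_outputs = [client_work(w, d) for w, d in zip(client_weights, client_data)]
    # Aggregator: client-to-server communication, reducing many outputs to one.
    aggregate = aggregator(client_outputs)
    # Finalizer: server computation using the aggregated client outputs.
    return finalizer(server_weights, aggregate)

  return run_round


# Toy instances of each block.
def broadcast(weights: float, num_clients: int) -> List[float]:
  return [weights] * num_clients


def jump_to_local_mean(weights: float, data: Sequence[float]) -> float:
  # "Train" by moving to the local mean; report the delta initial - trained.
  return weights - sum(data) / len(data)


def mean(updates: Sequence[float]) -> float:
  return sum(updates) / len(updates)


def apply_update(weights: float, update: float) -> float:
  # Plain SGD step with learning rate 1.0 on the averaged delta.
  return weights - update


run_round = compose_round(broadcast, jump_to_local_mean, mean, apply_update)
print(run_round(0.0, [[1.0, 3.0], [6.0]]))  # 4.0
```

Replacing jump_to_local_mean with a different client function while leaving the other three blocks untouched mirrors what this tutorial does with gradient clipping.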
While the Building Your Own Federated Learning Algorithm Tutorial implemented all of these building blocks from scratch, this is often unnecessary. Instead, you can re-use building blocks from similar algorithms.
In this case, to implement FedAvg with gradient clipping, you only need to modify the client work building block. The remaining blocks can be identical to what is used in "vanilla" FedAvg.
Implementing the Client Work
First, let's write TF logic that does local model training with gradient clipping. For simplicity, gradients will be clipped to have a global norm of at most 1.
@tf.function
def client_update(model: tff.learning.models.VariableModel,
                  dataset: tf.data.Dataset,
                  server_weights: tff.learning.models.ModelWeights,
                  client_optimizer: tf.keras.optimizers.Optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = tff.learning.models.ModelWeights.from_model(model)
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  # Keep track of the number of examples as well.
  num_examples = 0.0
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)
      num_examples += tf.cast(outputs.num_examples, tf.float32)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights.trainable)

    # Compute the global gradient norm; if it exceeds 1, rescale all
    # gradients so that their global norm is exactly 1.
    gradient_norm = tf.linalg.global_norm(grads)
    if gradient_norm > 1:
      grads = tf.nest.map_structure(lambda x: x / gradient_norm, grads)

    grads_and_vars = zip(grads, client_weights.trainable)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  # Compute the difference between the server weights and the client weights
  # (server - client). The finalizer applies this delta as a pseudo-gradient,
  # so a positive server learning rate moves the global model toward the
  # client models.
  model_delta = tf.nest.map_structure(tf.subtract,
                                      server_weights.trainable,
                                      client_weights.trainable)

  return tff.learning.templates.ClientResult(
      update=model_delta, update_weight=num_examples)
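There are a few important points about the code above. First, it keeps track of the number of examples seen, as this will constitute the weight of the client update (when computing an average across clients). Second, it packages the model update and that weight into a tff.learning.templates.ClientResult, the output type that client work building blocks in tff.learning are expected to return.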
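The clipping rule in client_update rescales the whole list of gradients by its global norm (the square root of the sum of squares of every entry across all tensors) whenever that norm exceeds 1. Below is a pure-Python sketch of that arithmetic, with each gradient tensor flattened into a list of floats; global_norm and clip_to_unit_norm are hypothetical helpers standing in for tf.linalg.global_norm and the if block above.

```python
import math
from typing import List

Tensor = List[float]  # a flattened gradient tensor


def global_norm(grads: List[Tensor]) -> float:
  """Square root of the sum of squares of every entry in every tensor."""
  return math.sqrt(sum(g * g for tensor in grads for g in tensor))


def clip_to_unit_norm(grads: List[Tensor]) -> List[Tensor]:
  """Divides every entry by the global norm if that norm exceeds 1."""
  norm = global_norm(grads)
  if norm > 1:
    return [[g / norm for g in tensor] for tensor in grads]
  return grads


print(clip_to_unit_norm([[3.0], [4.0]]))  # global norm 5.0 -> [[0.6], [0.8]]
print(clip_to_unit_norm([[0.3], [0.4]]))  # global norm about 0.5 -> unchanged
```

Because every entry is divided by the same scalar, the direction of the combined gradient is preserved and only its length shrinks. TensorFlow also provides tf.clip_by_global_norm(grads, clip_norm=1.0), which applies the same rescaling and returns the clipped list together with the global norm.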
Creating a ClientWorkProcess
While the TF logic above will do local training with clipping, it still needs to be wrapped in TFF code in order to create the necessary building block.
Specifically, the four building blocks are represented as a tff.templates.MeasuredProcess. This means that all four blocks have both an initialize and a next function, used to instantiate and run the computation, respectively.
This allows each building block to keep track of its own state (stored at the server) as needed to perform its operations. While it will not be used in this tutorial, it can be used for things like tracking how many iterations have occurred, or keeping track of optimizer states.
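The initialize/next contract can be illustrated with a plain-Python stand-in. RoundCountingProcess and MeasuredOutput below are hypothetical; MeasuredOutput only mirrors the (state, result, measurements) fields of tff.templates.MeasuredProcessOutput. The point it shows is that the process object itself holds nothing: the state is an ordinary value returned by initialize and passed explicitly into every call to next.

```python
from typing import Any, Dict, NamedTuple


class MeasuredOutput(NamedTuple):
  """Stand-in for the (state, result, measurements) triple returned by next."""
  state: Any
  result: Any
  measurements: Dict[str, Any]


class RoundCountingProcess:
  """Hypothetical process whose state counts how many times next has run."""

  def initialize(self) -> int:
    return 0

  def next(self, state: int, value: float) -> MeasuredOutput:
    new_state = state + 1
    # Pass the input through unchanged and report the round count.
    return MeasuredOutput(state=new_state, result=value,
                          measurements={'rounds': new_state})


process = RoundCountingProcess()
state = process.initialize()
output = process.next(state, 3.5)
print(output.state, output.result, output.measurements)  # 1 3.5 {'rounds': 1}
```

A building block with nothing to track, such as the client work in this tutorial, can simply use an empty state.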
Client work TF logic should generally be wrapped as a tff.learning.templates.ClientWorkProcess, which codifies the expected types going into and out of the client's local training. It can be parameterized by a model and optimizer, as below.
def build_gradient_clipping_client_work(
    model_fn: Callable[[], tff.learning.models.VariableModel],
    optimizer_fn: Callable[[], tf.keras.optimizers.Optimizer],
) -> tff.learning.templates.ClientWorkProcess:
  """Creates a client work process that uses gradient clipping."""

  with tf.Graph().as_default():
    # Wrap model construction in a graph to avoid polluting the global context
    # with variables created for this model.
    model = model_fn()
  data_type = tff.SequenceType(model.input_spec)
  model_weights_type = tff.learning.models.weights_type_from_model(model)

  @tff.federated_computation
  def initialize_fn():
    # This client work is stateless, so its state is an empty tuple.
    return tff.federated_value((), tff.SERVER)

  @tff.tf_computation(model_weights_type, data_type)
  def client_update_computation(model_weights, dataset):
    model = model_fn()
    optimizer = optimizer_fn()
    return client_update(model, dataset, model_weights, optimizer)

  @tff.federated_computation(
      initialize_fn.type_signature.result,
      tff.type_at_clients(model_weights_type),
      tff.type_at_clients(data_type)
  )
  def next_fn(state, model_weights, client_dataset):
    # Run the local training independently on every client.
    client_result = tff.federated_map(
        client_update_computation, (model_weights, client_dataset))
    # Return empty measurements, though a more complete algorithm might
    # measure something here.
    measurements = tff.federated_value((), tff.SERVER)
    return tff.templates.MeasuredProcessOutput(state, client_result,
                                               measurements)

  return tff.learning.templates.ClientWorkProcess(
      initialize_fn, next_fn)
Composing a Learning Algorithm
Let's put the client work above into a full-fledged algorithm. First, let's set up our data and model.
Preparing the input data
Load and preprocess the EMNIST dataset included in TFF. For more details, see the image classification tutorial.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
In order to feed the dataset into our model, each batch is flattened and converted into a tuple of the form (features, label), where features holds each image's 784 pixel values as a single vector.
Let's select a small number of clients, and apply the preprocessing above to their datasets.
NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
                        for x in client_ids]
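To see what preprocess produces, the sketch below computes the (features, label) shapes of each batch for a client with a given number of examples. batch_shapes is a hypothetical helper, not part of TFF or tf.data; it assumes EMNIST's 28 x 28 images (28 * 28 = 784) and relies on dataset.batch keeping a final partial batch, which is its default (drop_remainder=False). The -1 in each tf.reshape call is what lets TensorFlow infer that leading batch dimension.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def batch_shapes(num_examples: int, batch_size: int = 20,
                 image_size: int = 28) -> List[Tuple[Shape, Shape]]:
  """Shapes of the (features, label) tuples produced by preprocess."""
  num_pixels = image_size * image_size  # 784 for EMNIST
  shapes = []
  for start in range(0, num_examples, batch_size):
    # The last batch may be smaller than batch_size.
    rows = min(batch_size, num_examples - start)
    shapes.append(((rows, num_pixels), (rows, 1)))
  return shapes


print(batch_shapes(45))
# [((20, 784), (20, 1)), ((20, 784), (20, 1)), ((5, 784), (5, 1))]
```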
Preparing the model
This uses the same model as in the image classification tutorial. This model (implemented via tf.keras) has a single dense layer followed by a softmax layer. To use this model in TFF, the Keras model is wrapped as a tff.learning.models.VariableModel. This allows us to perform the model's forward pass within TFF and extract model outputs. For more details, also see the image classification tutorial.
def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.models.from_keras_model(
      keras_model,
      # federated_train_data is a list of datasets that share one element
      # spec, so take it from the first client's dataset.
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
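Because the model is a single Dense(10) layer on 784 inputs followed by a parameter-free Softmax, its trainable weights (what the client update, aggregator and finalizer operate on) are just a 784 x 10 kernel and a 10-element bias. The arithmetic below counts them in plain Python; dense_trainable_shapes is a hypothetical helper.

```python
import math
from typing import List, Tuple


def dense_trainable_shapes(input_dim: int, units: int) -> List[Tuple[int, ...]]:
  """Kernel and bias shapes of a Keras Dense layer with use_bias=True."""
  return [(input_dim, units), (units,)]


shapes = dense_trainable_shapes(784, 10)
num_params = sum(math.prod(shape) for shape in shapes)
print(shapes, num_params)  # [(784, 10), (10,)] 7850
```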
Preparing the optimizers
Just as in tff.learning.algorithms.build_weighted_fed_avg, there are two optimizers here: a client optimizer and a server optimizer. For simplicity, both will be SGD, with different learning rates.
client_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=0.01)
server_optimizer_fn = lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
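The server learning rate of 1.0 is not arbitrary. Assuming each client reports the model delta "server weights minus client weights" (the convention of TFF's built-in FedAvg client work) and the finalizer applies the averaged delta as a pseudo-gradient with SGD, a learning rate of 1.0 sets the new global model to exactly the average of the client models, while smaller values move only part of the way. The pure-Python sketch below checks this for two equally weighted clients; server_sgd_step is a hypothetical helper.

```python
from typing import List

Weights = List[float]


def server_sgd_step(server: Weights, mean_delta: Weights,
                    learning_rate: float) -> Weights:
  """Applies the averaged delta as a pseudo-gradient: w - lr * delta."""
  return [w - learning_rate * d for w, d in zip(server, mean_delta)]


server = [1.0, -2.0]
clients = [[2.0, 0.0], [4.0, -1.0]]  # locally trained weights, equal weight

# Each client's delta is server - client; average them coordinate-wise.
deltas = [[s - c for s, c in zip(server, client)] for client in clients]
mean_delta = [sum(column) / len(column) for column in zip(*deltas)]

print(server_sgd_step(server, mean_delta, learning_rate=1.0))  # [3.0, -0.5]
print(server_sgd_step(server, mean_delta, learning_rate=0.5))  # [2.0, -1.25]
```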
Defining the building blocks
Now that the client work building block, data, model, and optimizers are set up, it remains to create building blocks for the distributor, the aggregator, and the finalizer. This can be done simply by borrowing defaults that TFF provides and that FedAvg itself uses.
@tff.tf_computation()
def initial_model_weights_fn():
  return tff.learning.models.ModelWeights.from_model(model_fn())

model_weights_type = initial_model_weights_fn.type_signature.result

distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(model_fn, client_optimizer_fn)

# TFF aggregators use a factory pattern, which creates an aggregator
# based on the output type of the client work. This also uses a float (the
# number of examples) to govern the weight in the average being computed.
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(model_weights_type.trainable,
                                       tff.TensorType(tf.float32))
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
    server_optimizer_fn, model_weights_type)
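Because the aggregator is created with a float weight type, tff.aggregators.MeanFactory computes a weighted mean, using each client's update_weight (here, its number of examples) as the weight. The arithmetic is sketched below in plain Python; weighted_mean is a hypothetical helper, and its handling of a non-positive total weight is an assumption of this sketch, not a model of what TFF does.

```python
from typing import List, Sequence


def weighted_mean(updates: Sequence[List[float]],
                  weights: Sequence[float]) -> List[float]:
  """Coordinate-wise sum_i(weight_i * update_i) / sum_i(weight_i)."""
  total = sum(weights)
  if total <= 0:
    raise ValueError('total weight must be positive')
  return [
      sum(weight * update[i] for update, weight in zip(updates, weights)) / total
      for i in range(len(updates[0]))
  ]


updates = [[1.0, 0.0], [4.0, 2.0]]
num_examples = [1.0, 3.0]
print(weighted_mean(updates, num_examples))  # [3.25, 1.5]
print(weighted_mean(updates, [1.0, 1.0]))    # [2.5, 1.0], the unweighted mean
```

The client with three times as many examples pulls the result three times as hard, which is why client_update reports num_examples as its update weight.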
Composing the building blocks
Finally, you can use a built-in composer in TFF for putting the building blocks together. This one is a relatively simple composer, which takes the 4 building blocks above and wires their types together.
fed_avg_with_clipping = tff.learning.templates.compose_learning_process( initial_model_weights_fn, distributor, client_work, aggregator, finalizer )
Running the algorithm
Now that the algorithm is done, let's run it. First, initialize the algorithm. The state of this algorithm has a component for each building block, along with one for the global model weights.
state = fed_avg_with_clipping.initialize()

state.client_work
As expected, the client work has an empty state (its initialize_fn above returns an empty tuple). However, other building blocks may have non-empty state. For example, the finalizer keeps track of how many iterations have occurred. Since next has not been run yet, that count is still 0.
Now run a training round.
learning_process_output = fed_avg_with_clipping.next(state, federated_train_data)
The output of this (a tff.learning.templates.LearningProcessOutput) has both a .state and a .metrics output. Let's look at both.
Inspecting learning_process_output.state.finalizer shows that the finalizer state has incremented by one, as one round of .next has been run. The metrics, learning_process_output.metrics, look like this:
OrderedDict([('distributor', ()), ('client_work', ()), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
While the metrics are empty, for more complex and practical algorithms they'll generally be full of useful information.
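In practice next is called repeatedly, feeding each round's output state into the following round. Below is a hedged sketch of such a driver loop: run_rounds is a hypothetical helper that relies only on next(state, data) returning an object with .state and .metrics attributes, as a LearningProcessOutput does, and StubProcess is a stand-in so the snippet runs without TFF.

```python
from typing import Any, List, NamedTuple, Tuple


class StubOutput(NamedTuple):
  """Mimics the .state and .metrics fields of a learning process output."""
  state: Any
  metrics: Any


class StubProcess:
  """Hypothetical stand-in whose state is just a round counter."""

  def initialize(self) -> int:
    return 0

  def next(self, state: int, data: Any) -> StubOutput:
    return StubOutput(state=state + 1, metrics={'round': state + 1})


def run_rounds(process: Any, state: Any, data: Any,
               num_rounds: int) -> Tuple[Any, List[Any]]:
  """Runs num_rounds calls to process.next, threading the state through."""
  metrics_history = []
  for _ in range(num_rounds):
    output = process.next(state, data)
    state = output.state  # the next round must start from the updated state
    metrics_history.append(output.metrics)
  return state, metrics_history


stub = StubProcess()
final_state, history = run_rounds(stub, stub.initialize(), data=None, num_rounds=3)
print(final_state, history)  # 3 [{'round': 1}, {'round': 2}, {'round': 3}]
```

With TFF installed, the same helper should apply unchanged as run_rounds(fed_avg_with_clipping, state, federated_train_data, num_rounds=5).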
By using the building block/composers framework above, you can create entirely new learning algorithms without having to re-do everything from scratch. However, this is only the starting point. This framework makes it much easier to express algorithms as simple modifications of FedAvg. For more algorithms, see tff.learning.algorithms, which contains algorithms such as FedProx and FedAvg with client learning rate scheduling. These APIs can even aid implementations of entirely new algorithms, such as federated k-means clustering.