
Sending Different Data To Particular Clients With tff.federated_select


This tutorial demonstrates how to implement custom federated algorithms in TFF that require sending different data to different clients. You may already be familiar with tff.federated_broadcast, which sends a single server-placed value to all clients. This tutorial focuses on cases where different parts of a server-placed value are sent to different clients. This can be useful for dividing a model across clients so that no single client has to receive the whole model.
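To make the contrast concrete, here is a plain-Python mental model of the two operations. It is a sketch only, not how TFF implements them: the names `broadcast_model` and `select_model` are hypothetical, clients are modeled as entries in a list, and the check that every key lies in the range [0, max_key) is an assumption about the contract that the examples later in this tutorial happen to satisfy.

```python
def broadcast_model(server_value, num_clients):
    """Every client receives the same, complete server value."""
    return [server_value for _ in range(num_clients)]


def select_model(client_keys, max_key, server_value, select_fn):
    """Each client receives only the pieces named by its own keys."""
    results = []
    for keys in client_keys:
        for key in keys:
            # Assumed contract: keys are non-negative and below max_key.
            if not 0 <= key < max_key:
                raise ValueError(f"key {key} is outside [0, {max_key})")
        # One selected piece per key, in key order.
        results.append([select_fn(server_value, key) for key in keys])
    return results


server_value = ['a', 'b', 'c']

# Broadcast: both clients get the whole list.
everyone_gets_all = broadcast_model(server_value, num_clients=2)

# Select: the first client asks for items 0 and 1, the second for 2 and 0.
each_gets_slices = select_model(
    client_keys=[[0, 1], [2, 0]],
    max_key=len(server_value),
    server_value=server_value,
    select_fn=lambda val, index: val[index],
)
```

In this model the broadcast result grows with the size of the server value for every client, while the select result grows only with the number of keys each client requests.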

Let's get started by importing both tensorflow and tensorflow_federated.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
# Allow TFF's event loop to run inside the notebook's already-running loop.
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

Sending Different Values Based On Client Data

Consider the case where we have some server-placed list from which we want to send a few elements to each client based on some client-placed data. For example, a list of strings on the server, and on the clients, a comma-separated list of indices to download. We can implement that as follows:

list_of_strings_type = tff.TensorType(tf.string, [None])
# We only ever send exactly two values to each client. The number of keys per
# client must be a fixed number across all clients.
number_of_keys_per_client = 2
keys_type = tff.TensorType(tf.int32, [number_of_keys_per_client])
get_size = tff.tf_computation(lambda x: tf.size(x))
select_fn = tff.tf_computation(lambda val, index: tf.gather(val, index))
client_data_type = tf.string

# A function from our client data to the indices of the values we'd like to
# select from the server.
@tff.tf_computation(client_data_type)
@tff.check_returns_type(keys_type)
def keys_for_client(client_string):
  # We assume our client data is a single string consisting of exactly two
  # comma-separated integers indicating which values to grab from the server.
  split = tf.strings.split([client_string], sep=',')[0]
  return tf.strings.to_number([split[0], split[1]], tf.int32)

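# Joins the sequence of selected strings into one comma-separated string.
@tff.tf_computation(tff.SequenceType(tf.string))
@tff.check_returns_type(tf.string)
def concatenate(values):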
  def reduce_fn(acc, item):
    return tf.cond(tf.math.equal(acc, ''),
                   lambda: item,
                   lambda: tf.strings.join([acc, item], ','))
  return values.reduce('', reduce_fn)

@tff.federated_computation(tff.type_at_server(list_of_strings_type), tff.type_at_clients(client_data_type))
def broadcast_based_on_client_data(list_of_strings_at_server, client_data):
  keys_at_clients = tff.federated_map(keys_for_client, client_data)
  max_key = tff.federated_map(get_size, list_of_strings_at_server)
  values_at_clients = tff.federated_select(keys_at_clients, max_key, list_of_strings_at_server, select_fn)
  value_at_clients = tff.federated_map(concatenate, values_at_clients)
  return value_at_clients

Then we can simulate our computation by providing the server-placed list of strings as well as string data for each client:

client_data = ['0,1', '1,2', '2,0']
broadcast_based_on_client_data(['a', 'b', 'c'], client_data)
[<tf.Tensor: shape=(), dtype=string, numpy=b'a,b'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'b,c'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'c,a'>]
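As a sanity check on the output above, the same result can be derived without TFF. The sketch below is a plain-Python reference for what each client should end up with; the helper names `parse_keys` and `expected_client_outputs` are hypothetical and exist only for this illustration. It returns ordinary Python strings, whereas the TFF computation returns string tensors whose values are byte strings.

```python
def parse_keys(client_string, number_of_keys=2):
    """Parses a string such as '0,1' into a list of integer keys."""
    parts = client_string.split(',')
    # Mirrors the tutorial's assumption of a fixed number of keys per client.
    if len(parts) != number_of_keys:
        raise ValueError(
            f"expected {number_of_keys} keys, got {len(parts)}: {client_string!r}")
    return [int(part) for part in parts]


def expected_client_outputs(list_of_strings, client_data):
    """Computes, per client, the comma-joined strings selected by its keys."""
    outputs = []
    for client_string in client_data:
        keys = parse_keys(client_string)
        selected = [list_of_strings[key] for key in keys]
        outputs.append(','.join(selected))
    return outputs


expected = expected_client_outputs(['a', 'b', 'c'], ['0,1', '1,2', '2,0'])
print(expected)  # ['a,b', 'b,c', 'c,a']
```

Changing a client string to something like '0,1,2' raises an error here, which reflects the requirement that every client supplies the same fixed number of keys.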

Sending A Randomized Element To Each Client

Alternatively, it may be useful to send a random portion of the server data to each client. We can implement that by first generating a random key on each client and then following a similar selection process to the one used above:

@tff.tf_computation(tf.int32)
@tff.check_returns_type(tff.TensorType(tf.int32, [1]))
def get_random_key(max_key):
  return tf.random.uniform(shape=[1], minval=0, maxval=max_key, dtype=tf.int32)

list_of_strings_type = tff.TensorType(tf.string, [None])
get_size = tff.tf_computation(lambda x: tf.size(x))
select_fn = tff.tf_computation(lambda val, index: tf.gather(val, index))

@tff.tf_computation(tff.SequenceType(tf.string))
@tff.check_returns_type(tf.string)
def get_last_element(sequence):
  return sequence.reduce('', lambda _initial_state, val: val)

@tff.federated_computation(tff.type_at_server(list_of_strings_type))
def broadcast_random_element(list_of_strings_at_server):
  max_key_at_server = tff.federated_map(get_size, list_of_strings_at_server)
  max_key_at_clients = tff.federated_broadcast(max_key_at_server)
  key_at_clients = tff.federated_map(get_random_key, max_key_at_clients)
  random_string_sequence_at_clients = tff.federated_select(
      key_at_clients, max_key_at_server, list_of_strings_at_server, select_fn)
  # Even though we only passed in a single key, `federated_select` returns a
  # sequence for each client. We only care about the last (and only) element.
  random_string_at_clients = tff.federated_map(get_last_element, random_string_sequence_at_clients)
  return random_string_at_clients

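Since our broadcast_random_element function doesn't take in any client-placed data, we have to configure the TFF Simulation Runtime with a default number of clients to use:

# The value 3 is an arbitrary example; any positive number of clients works.
tff.backends.native.set_local_python_execution_context(default_num_clients=3)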


Then we can simulate the selection. We can change default_num_clients above and the list of strings below to generate different results, or simply re-run the computation to generate different random outputs.

broadcast_random_element(tf.convert_to_tensor(['foo', 'bar', 'baz']))
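To reason about what kind of result to expect, here is a plain-Python sketch of the same idea: each client independently draws one key below the size of the list and receives the matching element. The function `random_element_per_client` and its `seed` parameter are hypothetical names for this illustration, and the sketch uses Python's `random` module rather than TensorFlow, so the actual draws will not match those of the TFF computation.

```python
import random


def random_element_per_client(list_of_strings, num_clients, seed=None):
    """Returns one randomly chosen element per simulated client."""
    rng = random.Random(seed)
    max_key = len(list_of_strings)
    results = []
    for _ in range(num_clients):
        # Draw a key in [0, max_key); the upper bound is exclusive, matching
        # the integer behavior of tf.random.uniform used in get_random_key.
        key = rng.randrange(max_key)
        results.append(list_of_strings[key])
    return results


picks = random_element_per_client(['foo', 'bar', 'baz'], num_clients=3, seed=0)
print(picks)
```

The output has one entry per client, and every entry is one of the server's strings. Clients draw independently, so repeated values are possible.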