Random noise generation in TFF

This tutorial will discuss the recommended best practices for random noise generation in TFF. Random noise generation is an important component of many privacy protection techniques in federated learning algorithms, e.g., differential privacy.

Before we begin

First, let us make sure the notebook is connected to a backend that has the relevant components compiled.

!pip install --quiet --upgrade tensorflow-federated
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Run the following "Hello World" example to make sure the TFF environment is set up correctly. If it doesn't work, please refer to the Installation guide for instructions.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

Random noise on clients

The need for noise on clients generally falls into two cases: identical noise and i.i.d. noise.

  • For identical noise, the recommended pattern is to maintain a seed on the server, broadcast it to clients, and use the tf.random.stateless functions to generate noise.
  • For i.i.d. noise, use a tf.random.Generator initialized on the client with from_non_deterministic_state, in keeping with TF's recommendation to avoid the tf.random.<distribution> functions.

Client behavior differs from server behavior (clients do not suffer from the pitfalls discussed later) because each client builds its own computation graph and initializes its own default seed.
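The two patterns above can be sketched in plain TensorFlow, outside of TFF. This is only an illustration: the names shared_seed, client_a, client_b, gen_a and gen_b are hypothetical and simply stand in for values that would live on different clients.

```python
import tensorflow as tf

# Identical noise: every party that receives the same seed computes the
# same value, because stateless ops are pure functions of (shape, seed).
shared_seed = tf.constant([1, 1], dtype=tf.int64)
client_a = tf.random.stateless_normal((), seed=shared_seed)
client_b = tf.random.stateless_normal((), seed=shared_seed)
identical = bool(client_a == client_b)

# I.i.d. noise: each generator is initialized from a non-deterministic
# state, so two generators produce unrelated samples.
gen_a = tf.random.Generator.from_non_deterministic_state()
gen_b = tf.random.Generator.from_non_deterministic_state()
independent = bool(gen_a.normal(()) != gen_b.normal(()))

print(identical, independent)
```

In TFF the shared seed would be broadcast from the server and each generator would be created inside the client computation, as the sections below show.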

Identical noise on clients

# Set to use 10 clients.
tff.backends.native.set_sync_local_cpp_execution_context(default_num_clients=10)

@tff.tf_computation
def noise_from_seed(seed):
  return tf.random.stateless_normal((), seed=seed)

seed_type_at_server = tff.FederatedType(tff.to_type((np.int64, [2])), tff.SERVER)

@tff.federated_computation(seed_type_at_server)
def get_random_min_and_max_deterministic(seed):
  # Broadcast seed to all clients.
  seed_on_clients = tff.federated_broadcast(seed)

  # Clients generate noise from the seed deterministically.
  noise_on_clients = tff.federated_map(noise_from_seed, seed_on_clients)

  # Aggregate and return the min and max of the values generated on clients.
  min = tff.federated_min(noise_on_clients)
  max = tff.federated_max(noise_on_clients)
  return min, max

seed = tf.constant([1, 1], dtype=tf.int64)
min, max = get_random_min_and_max_deterministic(seed)
assert min == max
print(f'Seed: {seed.numpy()}. All clients sampled value {min:8.3f}.')

seed += 1
min, max = get_random_min_and_max_deterministic(seed)
assert min == max
print(f'Seed: {seed.numpy()}. All clients sampled value {min:8.3f}.')
Seed: [1 1]. All clients sampled value    1.665.
Seed: [2 2]. All clients sampled value   -0.219.

Independent noise on clients

@tff.tf_computation
def nondeterministic_noise():
  gen = tf.random.Generator.from_non_deterministic_state()
  return gen.normal(())

@tff.federated_computation
def get_random_min_and_max_nondeterministic():
  noise_on_clients = tff.federated_eval(nondeterministic_noise, tff.CLIENTS)
  min = tff.federated_min(noise_on_clients)
  max = tff.federated_max(noise_on_clients)
  return min, max

min, max = get_random_min_and_max_nondeterministic()
assert min != max
print(f'Values differ across clients. {min:8.3f},{max:8.3f}.')

new_min, new_max = get_random_min_and_max_nondeterministic()
assert new_min != new_max
assert new_min != min and new_max != max
print(f'Values differ across rounds.  {new_min:8.3f},{new_max:8.3f}.')
Values differ across clients.   -1.490,   1.172.
Values differ across rounds.    -1.358,   1.208.

Model initializer on clients

def _keras_model():
  inputs = tf.keras.Input(shape=(1,))
  outputs = tf.keras.layers.Dense(1)(inputs)
  return tf.keras.Model(inputs=inputs, outputs=outputs)

@tff.tf_computation
def tff_return_model_init():
  model = _keras_model()
  # return the initialized single weight value of the dense layer
  return tf.reshape(
      tff.learning.models.ModelWeights.from_model(model).trainable[0], [-1])[0]

@tff.federated_computation
def get_random_min_and_max_nondeterministic():
  noise_on_clients = tff.federated_eval(tff_return_model_init, tff.CLIENTS)
  min = tff.federated_min(noise_on_clients)
  max = tff.federated_max(noise_on_clients)
  return min, max

min, max = get_random_min_and_max_nondeterministic()
assert min != max
print(f'Values differ across clients. {min:8.3f},{max:8.3f}.')

new_min, new_max = get_random_min_and_max_nondeterministic()
assert new_min != new_max
assert new_min != min and new_max != max
print(f'Values differ across rounds.  {new_min:8.3f},{new_max:8.3f}.')
Values differ across clients.   -1.022,   1.567.
Values differ across rounds.    -1.675,   1.550.

Random noise on the server

Discouraged usage: directly using tf.random.normal

TF1.x-style APIs such as tf.random.normal are strongly discouraged for random noise generation in TF2, according to the random noise generation tutorial in TF. Surprising behavior may occur when these APIs are used together with tf.function and tf.random.set_seed. For example, the following code generates the same value on each call. This behavior is expected in TF; an explanation can be found in the documentation of tf.random.set_seed.

tf.random.set_seed(1)

@tf.function
def return_one_noise(_):
  return tf.random.normal([])

n1 = return_one_noise(1)
n2 = return_one_noise(2)
assert n1 == n2
print(n1.numpy(), n2.numpy())
0.3052047 0.3052047
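One reading of this result, based on the tf.random.set_seed documentation, is that each distinct Python integer argument triggers a fresh trace of the tf.function, and every fresh trace derives the same operation seed from the global seed. The sketch below assumes that explanation and passes tensors instead, so that both calls share a single trace; the names m1, m2 and differs are introduced here for illustration only.

```python
import tensorflow as tf

tf.random.set_seed(1)

@tf.function
def return_one_noise(_):
  return tf.random.normal([])

# Tensor arguments with the same dtype and shape reuse one traced graph,
# so the second call continues the random stream instead of restarting it.
m1 = return_one_noise(tf.constant(1))
m2 = return_one_noise(tf.constant(2))
differs = bool(m1 != m2)
print(m1.numpy(), m2.numpy())
```

The values now differ between calls, which shows how sensitive this API is to tracing details and is part of the reason it is discouraged.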

In TFF, things are slightly different. If we wrap the noise generation in tff.tf_computation instead of tf.function, non-deterministic random noise is generated: running this code snippet multiple times produces a different pair (n1, n2) each time. There is no easy way to set a global random seed for TFF.

tf.random.set_seed(1)

@tff.tf_computation
def return_one_noise(_):
  return tf.random.normal([])

n1 = return_one_noise(1)
n2 = return_one_noise(2)
assert n1 != n2
print(n1, n2)
0.11990704 1.9185987

Moreover, deterministic noise can be generated in TFF without explicitly setting a seed. A function such as return_two_noise, which calls the same tff.tf_computation wrapping tf.random.normal twice, returns two identical noise values. This is expected behavior because TFF builds the computation graph in advance, before execution. It does mean, however, that users have to pay attention to how tf.random.normal is used in TFF.

Usage with care: tf.random.Generator

We can use tf.random.Generator as suggested in the TF tutorial.

@tff.tf_computation
def tff_return_one_noise(i):
  g=tf.random.Generator.from_seed(i)
  @tf.function
  def tf_return_one_noise():
    return g.normal([])
  return tf_return_one_noise()

@tff.federated_computation
def return_two_noise():
  return (tff_return_one_noise(1), tff_return_one_noise(2))

n1, n2 = return_two_noise() 
assert n1 != n2
print(n1, n2)
0.3052047 -0.38260335

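However, users need to be careful with its usage: tf.random.Generator maintains its RNG state in a tf.Variable, so it is stateful rather than functional.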
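Part of the care needed comes from the generator being a stateful object. The following sketch, in plain TensorFlow and with hypothetical names (g1, g2, first_a and so on), illustrates two consequences: generators built from the same seed replay the same stream, and every draw mutates the generator's internal state.

```python
import tensorflow as tf

g1 = tf.random.Generator.from_seed(1)
g2 = tf.random.Generator.from_seed(1)

# Snapshot the internal RNG state before drawing anything.
state_before = g1.state.numpy().copy()

first_a = g1.normal([])
first_b = g2.normal([])
second_a = g1.normal([])

# Same seed, same position in the stream: identical samples.
same_seed_same_value = bool(first_a == first_b)
# Drawing advanced g1, so its next sample and its state both changed.
state_advanced = bool(first_a != second_a)
state_changed = bool((g1.state.numpy() != state_before).any())

print(same_seed_same_value, state_advanced, state_changed)
```

Seeds therefore have to be chosen deliberately, as the snippet above does by passing a different seed to each call.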

In general, TFF prefers functional operations and we will showcase the usage of tf.random.stateless_* functions in the following sections.

In TFF for federated learning, we often work with nested structures instead of scalars and the previous code snippet can be naturally extended to nested structures.

@tff.tf_computation
def tff_return_one_noise(i):
  g=tf.random.Generator.from_seed(i)
  weights = [
         tf.ones([2, 2], dtype=tf.float32),
         tf.constant([2], dtype=tf.float32)
     ]
  @tf.function
  def tf_return_one_noise():
    return tf.nest.map_structure(lambda x: g.normal(tf.shape(x)), weights)
  return tf_return_one_noise()

@tff.federated_computation
def return_two_noise():
  return (tff_return_one_noise(1), tff_return_one_noise(2))

n1, n2 = return_two_noise() 
assert n1[1] != n2[1]
print('n1', n1)
print('n2', n2)
n1 [array([[0.3052047 , 0.5671378 ],
       [0.41852272, 0.2326421 ]], dtype=float32), array([1.1675092], dtype=float32)]
n2 [array([[-0.38260335, -0.4780486 ],
       [-0.5187485 , -1.8471988 ]], dtype=float32), array([-0.77835274], dtype=float32)]

A general recommendation in TFF is to use the functional tf.random.stateless_* functions for random noise generation. These functions take seed (a Tensor with shape [2] or a tuple of two scalar tensors) as an explicit input argument to generate random noise. We first define a helper class to maintain the seed as pseudo state. The helper RandomSeedGenerator has functional operators in a state-in-state-out fashion. It is reasonable to use a counter as pseudo state for tf.random.stateless_* as these functions scramble the seed before using it to make noises generated by correlated seeds statistically uncorrelated.
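A minimal sketch of the counter idea in plain TensorFlow is shown below. The base seed value 42 and the variable names are arbitrary choices made for illustration; they are not part of the tutorial.

```python
import tensorflow as tf

base = tf.constant([42, 0], dtype=tf.int64)
step = tf.constant([0, 1], dtype=tf.int64)

# Treat the second seed component as a counter.
seed0 = base
seed1 = base + step

x0 = tf.random.stateless_normal([3], seed=seed0)
x0_again = tf.random.stateless_normal([3], seed=seed0)
x1 = tf.random.stateless_normal([3], seed=seed1)

# The same seed always reproduces the same noise.
reproducible = bool(tf.reduce_all(x0 == x0_again))
# Incrementing the counter yields a different draw.
changed = bool(tf.reduce_any(x0 != x1))

print(reproducible, changed)
```

Because the output depends only on the shape and the seed, the seed can be passed around as ordinary data, which is what makes this pattern a good fit for TFF.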

def timestamp_seed():
  # tf.timestamp returns microseconds as decimal places, thus scaling by 1e6.
  return tf.cast(tf.timestamp() * 1e6, tf.int64)

class RandomSeedGenerator():

  def initialize(self, seed=None):
    if seed is None:
      return tf.stack([timestamp_seed(), 0])
    else:
      return tf.constant(seed, dtype=tf.int64, shape=(2,))

  def next(self, state):
    return state + tf.constant([0, 1], tf.int64)

  def structure_next(self, state, nest_structure):
    """Returns seeds in a nested structure and the next seed state."""
    flat_structure = tf.nest.flatten(nest_structure)
    flat_seeds = [state + tf.constant([0, i], tf.int64) for
                  i in range(len(flat_structure))]
    nest_seeds = tf.nest.pack_sequence_as(nest_structure, flat_seeds)
    return nest_seeds, flat_seeds[-1] + tf.constant([0, 1], tf.int64)
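To make the seed bookkeeping concrete, the sketch below copies the logic of structure_next into a standalone function and runs it eagerly on a two-element structure. The starting state [7, 0] and the weights list are arbitrary illustrative values.

```python
import tensorflow as tf

def structure_next(state, nest_structure):
  """Standalone copy of the helper's logic, for illustration only."""
  flat_structure = tf.nest.flatten(nest_structure)
  # One seed per leaf: the counter component grows by the leaf index.
  flat_seeds = [state + tf.constant([0, i], tf.int64)
                for i in range(len(flat_structure))]
  nest_seeds = tf.nest.pack_sequence_as(nest_structure, flat_seeds)
  # The next state starts right after the last seed that was handed out.
  return nest_seeds, flat_seeds[-1] + tf.constant([0, 1], tf.int64)

state = tf.constant([7, 0], dtype=tf.int64)
weights = [tf.ones([2, 2]), tf.constant([2.0])]

seeds, next_state = structure_next(state, weights)
seed_list = [s.numpy().tolist() for s in seeds]
print(seed_list, next_state.numpy().tolist())
# prints [[7, 0], [7, 1]] [7, 2]
```

Each leaf receives its own seed, and the returned state never reuses a counter value that has already been consumed.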

Now let us use the helper class and tf.random.stateless_normal to generate a (nested structure of) random noise in TFF. The following code snippet looks a lot like a TFF iterative process; see simple_fedavg for an example of expressing a federated learning algorithm as a TFF iterative process. The pseudo seed state used here for random noise generation is a tf.Tensor, which can be easily transported between TFF and TF functions.

@tff.tf_computation
def tff_return_one_noise(seed_state):
  g=RandomSeedGenerator()
  weights = [
         tf.ones([2, 2], dtype=tf.float32),
         tf.constant([2], dtype=tf.float32)
     ]
  @tf.function
  def tf_return_one_noise():
    nest_seeds, updated_state = g.structure_next(seed_state, weights)
    nest_noise = tf.nest.map_structure(lambda x,s: tf.random.stateless_normal(
        shape=tf.shape(x), seed=s), weights, nest_seeds)
    return nest_noise, updated_state
  return tf_return_one_noise()

@tff.tf_computation
def tff_init_state():
  g=RandomSeedGenerator()
  return g.initialize()

@tff.federated_computation
def return_two_noise():
  seed_state = tff_init_state()
  n1, seed_state = tff_return_one_noise(seed_state)
  n2, seed_state = tff_return_one_noise(seed_state)
  return (n1, n2)

n1, n2 = return_two_noise() 
assert n1[1] != n2[1]
print('n1', n1)
print('n2', n2)
n1 [array([[ 0.86828816,  0.8535084 ],
       [ 1.0053564 , -0.42096713]], dtype=float32), array([0.18048067], dtype=float32)]
n2 [array([[-1.1973879 , -0.2974589 ],
       [ 1.8309833 ,  0.17024393]], dtype=float32), array([0.68991095], dtype=float32)]