
Experimental support for JAX in TFF


In addition to being a part of the TensorFlow ecosystem, TFF aims to enable interoperability with other frontend and backend ML frameworks. At the moment, support for other ML frameworks is still in the incubation phase, and the APIs and the functionality supported may change (largely as a function of demand from the users of TFF). This tutorial describes how to use TFF with JAX as an alternative ML frontend, and the XLA compiler as an alternative backend. The examples shown here are based on an entirely native JAX/XLA stack, end-to-end. The possibility of mixing code across frameworks (e.g., JAX with TensorFlow) will be discussed in a future tutorial.

As always, we welcome your contributions. If support for JAX/XLA or the ability to interoperate with other ML frameworks is important for you, please consider helping us evolve these capabilities towards parity with the remainder of TFF.

Before we begin

Please consult the main body of TFF documentation for how to configure your environment. Depending on where you are running this tutorial, you may want to uncomment and run some or all of the code below.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

This tutorial also assumes you have reviewed TFF's primary TensorFlow tutorials, and that you are familiar with the core TFF concepts. If you have not done this yet, please consider reviewing at least one of them.

JAX computations

Support for JAX in TFF is designed to be symmetric with the manner in which TFF interoperates with TensorFlow, starting with imports:

import jax
import numpy as np
import tensorflow_federated as tff

Also, just like with TensorFlow, the foundation for expressing any TFF code is the logic that runs locally. You can express this logic in JAX, as shown below, using the @tff.experimental.jax_computation wrapper. It behaves similarly to the @tff.tf_computation wrapper that you are already familiar with. Let's start with something simple, e.g., a computation that adds two integers:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

You can use the JAX computation defined above just like you would normally use a TFF computation. For example, you can check its type signature, as follows:

str(add_numbers.type_signature)

'(<x=int32,y=int32> -> int32)'

Note that we used np.int32 to define the argument types. TFF does not distinguish between NumPy types (such as np.int32) and TensorFlow types (such as tf.int32). From TFF's perspective, they are just two ways to refer to the same thing.

Now, remember that TFF is not Python (if this doesn't ring a bell, please review some of our earlier tutorials, e.g., on custom algorithms). You can use the @tff.experimental.jax_computation wrapper with any JAX code that can be traced and serialized, i.e., with code that you would normally annotate with @jax.jit and expect to be compiled into XLA (but you don't need to actually use the @jax.jit annotation to embed your JAX code in TFF).

Indeed, under the hood, TFF instantly compiles JAX computations to XLA. You can check this for yourself by manually extracting and printing the serialized XLA code from add_numbers, as follows:

comp_pb = tff.framework.serialize_computation(add_numbers)
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())

HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Think of the XLA representation of a JAX computation as the functional equivalent of tf.GraphDef for computations expressed in TensorFlow. It is portable and can be executed in any environment that supports XLA, just as a tf.GraphDef can be executed on any TensorFlow runtime.
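To read the HLO listing above, it may help to see the same data flow written as ordinary NumPy code: the two int32 arguments arrive packed in a single tuple parameter, get-tuple-element unpacks them, add sums them, and the result is returned wrapped in a one-element tuple. The function hlo_add_numbers is a hypothetical illustration only; it is not how TFF or XLA executes the module, and it ignores the unused constant.4.

```python
import numpy as np


def hlo_add_numbers(parameter_1):
  """Hypothetical NumPy mirror of the HLO module for add_numbers."""
  # get-tuple-element.2 / get-tuple-element.3: unpack the (s32[], s32[]) tuple.
  x = np.int32(parameter_1[0])
  y = np.int32(parameter_1[1])
  # add.5: 32-bit integer addition.
  total = np.add(x, y)
  # ROOT tuple.6: the result is returned as a one-element tuple.
  return (total,)


result = hlo_add_numbers((2, 3))
print(result[0])  # 5
```

The one-element output tuple is why a TFF computation with a single result still appears as a tuple at the XLA level.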

TFF provides a runtime stack based on the XLA compiler as a backend. You can activate it as follows:


Now, you can execute the computation we defined above:

add_numbers(2, 3)

Easy enough. Let's build on this and do something more complicated, such as MNIST.

Example of MNIST training with the canned API

As usual, we start by defining a few TFF types for batches of data and for the model (remember, TFF is a strongly typed framework).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])
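Since TFF is strongly typed, the values you feed into a computation are expected to match these declared types. When generating batches by hand, a quick local check against the same shapes and dtypes can catch mistakes early. The sketch below uses a hypothetical helper, matches_spec, and a plain-NumPy mirror of BATCH_TYPE; neither is part of the TFF API.

```python
import collections

import numpy as np

# Plain-NumPy mirror of BATCH_TYPE: field name -> (dtype, shape).
BATCH_SPEC = collections.OrderedDict([
    ('pixels', (np.float32, (50, 784))),
    ('labels', (np.int32, (50,))),
])


def matches_spec(batch, spec):
  """Hypothetical helper: True if batch has exactly spec's fields, dtypes and shapes."""
  if list(batch.keys()) != list(spec.keys()):
    return False
  for name, (dtype, shape) in spec.items():
    value = np.asarray(batch[name])
    if value.dtype != dtype or value.shape != shape:
      return False
  return True


good = collections.OrderedDict([
    ('pixels', np.zeros((50, 784), dtype=np.float32)),
    ('labels', np.zeros((50,), dtype=np.int32)),
])
bad = collections.OrderedDict([
    ('pixels', np.zeros((50, 784), dtype=np.float64)),  # Wrong dtype.
    ('labels', np.zeros((50,), dtype=np.int32)),
])
print(matches_spec(good, BATCH_SPEC), matches_spec(bad, BATCH_SPEC))  # True False
```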

Now, let's define a loss function for the model in JAX, taking the model and a single batch of data as parameters:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))
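As a sanity check on the loss, here is a NumPy mirror of the same softmax cross-entropy (softmax and numpy_loss are hypothetical helper names). With an all-zero model, which is what the trainer's initial model looks like, every class receives probability 1/10, so the loss on any batch should be log(10), roughly 2.3026. That gives a useful reference point for the initial loss printed later.

```python
import numpy as np


def softmax(logits):
  # Subtracting the row max does not change the result but avoids overflow.
  shifted = logits - logits.max(axis=1, keepdims=True)
  exp = np.exp(shifted)
  return exp / exp.sum(axis=1, keepdims=True)


def numpy_loss(model, batch):
  """NumPy mirror of the JAX `loss`: mean softmax cross-entropy over the batch."""
  y = softmax(batch['pixels'] @ model['weights'] + model['bias'])
  targets = np.eye(10, dtype=np.float32)[batch['labels'].reshape(-1)]
  return -np.mean(np.sum(targets * np.log(y), axis=1))


rng = np.random.default_rng(0)
batch = {
    'pixels': rng.uniform(0.0, 1.0, size=(50, 784)).astype(np.float32),
    'labels': rng.integers(0, 10, size=(50,)).astype(np.int32),
}
zero_model = {
    'weights': np.zeros((784, 10), dtype=np.float32),
    'bias': np.zeros((10,), dtype=np.float32),
}
# All-zero parameters give uniform probabilities, hence a loss of log(10).
print(numpy_loss(zero_model, batch), np.log(10))
```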

Now, one way to go is to use a canned API. Here's an example of how you can use our API to create a training process based on the loss function just defined.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

You can use the above just as you would use a trainer built from a tf.keras model in TensorFlow. For example, here's how you can create the initial model for training:

initial_model = trainer.initialize()
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

In order to perform actual training, we need some data. To keep things simple, we use randomly generated data. Because the data is random, we evaluate on the training data itself; on separate random evaluation data, there would be no reason to expect the model to perform well. Also, for this small-scale demo, we will not worry about randomly sampling clients (we leave exploring those types of changes, following the templates from other tutorials, as an exercise for the reader):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  # randint's upper bound is exclusive, so high=10 covers all ten MNIST labels.
  labels = np.random.randint(low=0, high=10, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

# Small values, chosen for illustration.
NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

With that, we can perform a single round of training, as follows:

trained_model = trainer.next(initial_model, train_data)
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Let's evaluate the result of the training round. To keep it simple, we evaluate it in a centralized fashion:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))

The loss is decreasing. Great! Now, let's run this over multiple rounds:

NUM_ROUNDS = 5  # Chosen for illustration.

for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))

As you can see, using JAX with TFF is not that different, although the experimental APIs are not yet on par with the TensorFlow APIs in terms of functionality.

Under the hood

If you prefer not to use our canned API, you can implement your own custom computations, much as you have seen in the custom algorithms tutorials for TensorFlow, except that you will use JAX's automatic differentiation (jax.grad) to compute the gradients for gradient descent. For example, below is how you can define a JAX computation that updates the model on a single minibatch:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  # Gradient of the loss with respect to every model parameter.
  grads = jax.grad(loss)(model, batch)
  # Plain SGD update applied to each parameter.
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']])
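For this softmax regression model, the gradient that jax.grad computes in train_on_one_batch has a simple closed form. Writing X for the pixels, P for the softmax probabilities, T for the one-hot targets and N for the batch size, the weight gradient is X^T (P - T) / N and the bias gradient is the column sum of (P - T) / N. The NumPy sketch below, with hypothetical helpers loss_np and grad_np, implements that formula and checks one weight entry against a central finite difference. It only illustrates what the JAX gradient represents; inside TFF you would keep using jax.grad.

```python
import numpy as np


def softmax(logits):
  shifted = logits - logits.max(axis=1, keepdims=True)
  exp = np.exp(shifted)
  return exp / exp.sum(axis=1, keepdims=True)


def loss_np(weights, bias, pixels, labels):
  """Mean softmax cross-entropy, same formula as the JAX `loss`."""
  probs = softmax(pixels @ weights + bias)
  targets = np.eye(weights.shape[1])[labels]
  return -np.mean(np.sum(targets * np.log(probs), axis=1))


def grad_np(weights, bias, pixels, labels):
  """Closed-form gradient of loss_np with respect to (weights, bias)."""
  n = pixels.shape[0]
  probs = softmax(pixels @ weights + bias)
  targets = np.eye(weights.shape[1])[labels]
  delta = (probs - targets) / n  # dLoss/dLogits, shape (N, classes).
  return pixels.T @ delta, delta.sum(axis=0)


rng = np.random.default_rng(0)
pixels = rng.uniform(size=(50, 784))
labels = rng.integers(0, 10, size=50)
weights = rng.normal(scale=0.01, size=(784, 10))
bias = np.zeros(10)

grad_w, grad_b = grad_np(weights, bias, pixels, labels)

# Central finite difference on a single weight entry.
eps = 1e-6
w_plus, w_minus = weights.copy(), weights.copy()
w_plus[3, 7] += eps
w_minus[3, 7] -= eps
numeric = (loss_np(w_plus, bias, pixels, labels) -
           loss_np(w_minus, bias, pixels, labels)) / (2 * eps)
print(grad_w[3, 7], numeric)
```

Because each row of P - T sums to zero, the bias gradient always sums to zero across classes, which is another quick check.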

Here's how you can test that it works:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))

One caveat of working with JAX is that it does not offer an equivalent of tf.data. Thus, in order to iterate over datasets, you will need to use TFF's declarative constructs for operations on sequences, such as tff.sequence_reduce used below:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Let's see that it works:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
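Conceptually, tff.sequence_reduce(batches, model, train_on_one_batch) folds the client's batches into the model: the model plays the role of the accumulator and each batch updates it in turn, much like Python's functools.reduce. The NumPy sketch below mimics that fold. sgd_step is a hypothetical NumPy analogue of train_on_one_batch (same loss, same STEP_SIZE of 0.001), and the sketch says nothing about how the TFF runtime actually schedules the reduction.

```python
import functools

import numpy as np

STEP_SIZE = 0.001  # Same value as in the tutorial.


def softmax(logits):
  shifted = logits - logits.max(axis=1, keepdims=True)
  exp = np.exp(shifted)
  return exp / exp.sum(axis=1, keepdims=True)


def sgd_step(model, batch):
  """Hypothetical NumPy analogue of train_on_one_batch: one SGD step."""
  n = batch['pixels'].shape[0]
  probs = softmax(batch['pixels'] @ model['weights'] + model['bias'])
  targets = np.eye(10)[batch['labels']]
  delta = (probs - targets) / n
  return {
      'weights': model['weights'] - STEP_SIZE * (batch['pixels'].T @ delta),
      'bias': model['bias'] - STEP_SIZE * delta.sum(axis=0),
  }


def train_on_one_client_np(model, batches):
  # Conceptual equivalent of tff.sequence_reduce(batches, model, sgd_step).
  return functools.reduce(sgd_step, batches, model)


rng = np.random.default_rng(0)
batches = [
    {'pixels': rng.uniform(size=(50, 784)), 'labels': rng.integers(0, 10, size=50)}
    for _ in range(5)
]
model = {'weights': np.zeros((784, 10)), 'bias': np.zeros(10)}

trained = train_on_one_client_np(model, batches)

# The fold gives the same result as an explicit loop over the batches.
expected = model
for b in batches:
  expected = sgd_step(expected, b)
print(np.allclose(trained['weights'], expected['weights']))  # True
```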

The computation that performs a single round of training looks just like the one you may have seen in the TensorFlow tutorials:

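@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  # Broadcast the server model and train it locally on each client's batches.
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  # Average the locally trained models back on the server.
  return tff.federated_mean(locally_trained_models)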

Let's see that it works:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
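The data flow of train_one_round can also be simulated in plain Python to make the placements concrete: the server model is broadcast to every client, each client trains locally, and tff.federated_mean averages the results. Since no weight is passed to tff.federated_mean here, every client counts equally, regardless of how much data it holds. In the sketch below, federated_round and toy_update are hypothetical names, and toy_update is a deliberately trivial stand-in for train_on_one_client so the averaged result is easy to check by hand.

```python
import numpy as np


def federated_round(server_model, client_datasets, client_update):
  """Simulate one round: broadcast, local training, unweighted mean.

  server_model: dict of NumPy arrays (the value placed at tff.SERVER).
  client_datasets: one dataset per client (the values placed at tff.CLIENTS).
  client_update: hypothetical stand-in for train_on_one_client.
  """
  # Like tff.federated_broadcast: every client starts from the server model.
  local_models = [client_update(server_model, data) for data in client_datasets]
  # Like an unweighted tff.federated_mean: average each parameter over clients.
  return {
      key: np.mean([m[key] for m in local_models], axis=0)
      for key in server_model
  }


def toy_update(model, data):
  # Toy local "training": shift every parameter by the client's data mean.
  return {key: value + np.mean(data) for key, value in model.items()}


server = {'weights': np.zeros((2, 3)), 'bias': np.zeros(3)}
clients = [np.array([1.0, 3.0]), np.array([5.0])]  # Client means: 2.0 and 5.0.
new_server = federated_round(server, clients, toy_update)
print(new_server['bias'])  # [3.5 3.5 3.5]
```

The two clients shift the model by 2.0 and 5.0, so the averaged parameters are 3.5, even though the first client holds twice as many data points.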

As you see, using JAX in TFF, whether via canned APIs, or directly using the low-level TFF constructs, is similar to using TFF with TensorFlow. Stay tuned for future updates, and if you'd like to see better support for interoperability across ML frameworks, feel free to send us a pull request!