tff.experimental.jax_computation

Decorates/wraps Python functions containing JAX code as TFF computations.

This wrapper can be used in the same way as tff.tf_computation, with the following exceptions:

  • The code in the wrapped Python function must be JAX code that can be compiled to XLA (e.g., code that one would expect to be able to annotate with @jax.jit).

  • The inputs and outputs must be tensors, or (possibly recursively) nested structures of tensors. Sequences are currently not supported.
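One informal way to check the first two requirements before wrapping a function is to run its body under plain jax.jit. The sketch below uses only JAX and NumPy, not the TFF API. The function names are hypothetical, and passing jax.jit is assumed to be a useful heuristic rather than a guarantee that TFF will accept the function. The first function takes a nested dict of arrays and returns a tuple of arrays, so it traces cleanly. The second uses a Python `if` on a traced value, which jax.jit rejects.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scale_and_sum(batch):
    """Hypothetical example: nested dict of arrays in, tuple of arrays out.

    Uses only jax.numpy ops on statically shaped inputs, so jax.jit can
    trace it and compile it with XLA.
    """
    scaled = batch["x"] * batch["scale"]
    return scaled, jnp.sum(scaled)


def branch_on_value(x):
    """Hypothetical example: Python control flow on a traced value.

    Under jax.jit, `x > 0` is an abstract tracer, so converting it to a
    Python bool raises a ConcretizationTypeError (or a subclass of it).
    """
    if x > 0:
        return x
    return -x


batch = {"x": np.array([1, 2, 3], dtype=np.int32), "scale": np.int32(2)}
scaled, total = jax.jit(scale_and_sum)(batch)

try:
    jax.jit(branch_on_value)(np.int32(5))
    branch_compiled = True
except jax.errors.ConcretizationTypeError:
    branch_compiled = False
```

A function like `branch_on_value` would typically need to be rewritten with `jnp.where` or `jax.lax.cond` before it could be compiled to XLA.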

Example:

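import jax.numpy as jnp
import tensorflow as tf
import tensorflow_federated as tff

@tff.experimental.jax_computation(tf.int32)
def comp(x):
  return jnp.add(x, jnp.int32(10))  # adds 10 to the int32 input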