Decorates/wraps Python functions containing JAX code as TFF computations.
tff.jax_computation(*args, tff_internal_types=None)
This wrapper can be used in a similar manner to …, with the exception of the following:
The code in the wrapped Python function must be JAX code that can be compiled to XLA (e.g., code that one would expect to be able to annotate with …).
The inputs and outputs must be tensors, or (possibly recursively) nested structures of tensors. Sequences are currently not supported.
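To make the structure requirement concrete, here is a plain-JAX sketch (no TFF involved) of a function whose input and output are nested structures of tensors. It uses a `collections.namedtuple` and a dict, both of which JAX treats as pytrees. The `Batch` type and `summarize` function are hypothetical, and how TFF maps such structures onto its own type signatures is not shown here.

```python
import collections

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical nested structure: a namedtuple holding a tensor
# and a further tuple of scalar tensors.
Batch = collections.namedtuple('Batch', ['features', 'stats'])


def summarize(batch):
  # Every leaf is an array, so the whole structure can be traced by jax.jit.
  total = jnp.sum(batch.features)
  lo, hi = batch.stats
  # The output is itself a (dict-shaped) structure of tensors.
  return {'total': total, 'range': hi - lo}


batch = Batch(
    features=np.array([1.0, 2.0, 3.0], dtype=np.float32),
    stats=(np.float32(0.5), np.float32(4.5)),
)

# The nested input flattens into three leaves, all of them arrays.
leaves = jax.tree_util.tree_leaves(batch)
print(len(leaves))  # 3

result = jax.jit(summarize)(batch)
print(float(result['total']), float(result['range']))  # 6.0 4.0
```

Nesting can go as deep as needed; what matters for the requirement above is that every leaf is a tensor.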
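For example:

```python
import jax
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

@tff.jax_computation(tf.int32)
def comp(x):
  return jax.numpy.add(x, np.int32(10))
```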
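The body of `comp` is ordinary `jax.numpy` code. As a TFF-free sketch of what the XLA-compilability requirement rules in and out, the snippet below uses `jax.jit(fn).lower(...)`, which traces a function and lowers it for XLA without executing it. It assumes that successful lowering under `jax.jit` is a reasonable proxy for the requirement; TFF's own checks are not shown, and the helper names (`add_ten`, `relu_python_if`, `relu_where`, `lowers_to_xla`) are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def add_ten(x):
  # Same body as `comp`: pure jax.numpy arithmetic traces and lowers cleanly.
  return jnp.add(x, np.int32(10))


def relu_python_if(x):
  # A Python `if` on a traced value needs a concrete boolean,
  # which is not available while JAX traces the function.
  if x > 0:
    return x
  return jnp.zeros_like(x)


def relu_where(x):
  # XLA-friendly rewrite: jnp.where selects inside the compiled graph.
  return jnp.where(x > 0, x, jnp.zeros_like(x))


def lowers_to_xla(fn, *example_args):
  """Returns True if `fn` can be traced and lowered by jax.jit."""
  try:
    jax.jit(fn).lower(*example_args)
    return True
  except jax.errors.ConcretizationTypeError:
    return False


x = np.int32(3)
print(lowers_to_xla(add_ten, x))         # True
print(lowers_to_xla(relu_python_if, x))  # False
print(lowers_to_xla(relu_where, x))      # True
```

The failing case shows the usual pattern to avoid: Python branching on traced values. Rewriting it with `jnp.where` (or `jax.lax.cond`) keeps the control flow inside the traced computation.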