TensorFlow Probability on JAX


TensorFlow Probability (TFP) is a library for probabilistic reasoning and statistical analysis that now works on JAX! For those not familiar, JAX is a library for accelerated numerical computing based on composable function transformations.

We have ported a lot of TFP's most useful functionality to JAX while preserving the abstractions and APIs that many TFP users are now comfortable with.


TFP on JAX does not depend on TensorFlow; let's uninstall TensorFlow from this Colab entirely.

pip uninstall tensorflow -y -q

We can install TFP on JAX with the latest nightly builds of TFP.

pip install -Uq tfp-nightly[jax] > /dev/null

Let's import some useful Python libraries.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets

Let's also import some basic JAX functionality.

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

Importing TFP on JAX

To use TFP on JAX, simply import the jax "substrate" and use it as you would normally use tfp:

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

Demo: Bayesian logistic regression

To demonstrate what we can do with the JAX backend, we'll implement Bayesian logistic regression applied to the classic Iris dataset.

First, let's import the Iris dataset and extract some metadata.

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

We can define the model using tfd.JointDistributionCoroutine. We'll put standard normal priors on both the weights and the bias term, then write a target_log_prob function that pins the sampled labels to the data.

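Root = tfd.JointDistributionCoroutine.Root
def model():
  # Standard normal priors on the weights and the bias.
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  # One categorical likelihood per example; Independent treats all examples
  # as a single event, so log_prob sums over them.
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)

dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  # Pin the last random variable (the labels) to the observed data.
  return dist.log_prob(params + (labels,))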
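To make explicit what target_log_prob computes, here is a minimal NumPy sketch of the same unnormalized log posterior: standard normal log densities for every entry of the weights and bias, plus the categorical log-likelihood of the observed labels under softmax logits. The helper names (gaussian_prior_log_prob, log_softmax, logistic_regression_log_joint) and the toy data are hypothetical illustrations, not TFP's implementation.

```python
import numpy as np

def gaussian_prior_log_prob(x):
    """Elementwise log density of a standard normal."""
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def logistic_regression_log_joint(w, b, features, labels):
    """log p(w) + log p(b) + sum_i log p(label_i | features_i, w, b)."""
    log_prior = gaussian_prior_log_prob(w).sum() + gaussian_prior_log_prob(b).sum()
    logits = features @ w + b                    # [num_examples, num_classes]
    log_probs = log_softmax(logits)
    # Pick out the log probability assigned to each observed label.
    log_lik = log_probs[np.arange(len(labels)), labels].sum()
    return log_prior + log_lik

# Hypothetical toy problem: 4 examples, 2 features, 3 classes.
rng = np.random.default_rng(0)
toy_features = rng.normal(size=(4, 2))
toy_labels = np.array([0, 2, 1, 2])
toy_w = rng.normal(size=(2, 3))
toy_b = rng.normal(size=(3,))
print(logistic_regression_log_joint(toy_w, toy_b, toy_features, toy_labels))
```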

We sample from dist to produce an initial state for MCMC. We can then define a function that takes in a random key and an initial state, and produces 500 samples from a No-U-Turn Sampler (NUTS). Note that we can use JAX transformations like jit to compile our NUTS sampler using XLA.

init_key, sample_key = random.split(random.PRNGKey(0))
# Drop the sampled labels; only the weights and bias are MCMC state.
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()


Let's use our samples to perform Bayesian model averaging (BMA) by averaging the predicted probabilities across all sampled sets of weights.

First, let's write a function that, for a given set of parameters, produces the probabilities over each class. We can use dist.sample_distributions to obtain the final distribution in the model.

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

We can vmap(classifier_probs) over the set of samples to get the predicted class probabilities for each sample. We then compute the average accuracy across samples, and the accuracy from Bayesian model averaging.

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

Looks like BMA reduces our error rate by about a third!
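The two numbers differ only in where the averaging happens: average accuracy scores each posterior sample's predictions separately and then averages the scores, while BMA first averages the predicted probabilities over samples and then takes the argmax. A small NumPy illustration with made-up probabilities (three posterior samples, two examples, two classes) shows how BMA can outperform every-sample scoring:

```python
import numpy as np

# Hypothetical predictions: [num_samples, num_examples, num_classes].
toy_probs = np.array([
    [[0.6, 0.4], [0.4, 0.6]],
    [[0.45, 0.55], [0.3, 0.7]],
    [[0.9, 0.1], [0.45, 0.55]],
])
toy_true_labels = np.array([0, 1])

# Average accuracy: score each sample's predictions, then average the scores.
per_sample_acc = (toy_probs.argmax(axis=-1) == toy_true_labels).mean(axis=-1)
average_accuracy = per_sample_acc.mean()

# BMA: average the probabilities over samples first, then predict.
bma_probs = toy_probs.mean(axis=0)
bma_accuracy = (bma_probs.argmax(axis=-1) == toy_true_labels).mean()

print(per_sample_acc, average_accuracy, bma_accuracy)
```

Here the second sample misclassifies the first example, but the averaged probabilities still favor the correct class, so BMA gets both examples right.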


TFP on JAX has the same API as TFP on TF, except that wherever TF objects such as tf.Tensor were accepted, it accepts the JAX analogue instead. For example, wherever a tf.Tensor was previously used as input, the API now expects a JAX DeviceArray, and instead of returning a tf.Tensor, TFP methods return DeviceArrays. TFP on JAX also works with nested structures of JAX objects, such as a list or dictionary of DeviceArrays.


Distributions

Most of TFP's distributions are supported in JAX with very similar semantics to their TF counterparts. They are also registered as JAX pytrees, so they can be inputs and outputs of JAX-transformed functions.

Basic distributions

The log_prob method for distributions works the same.

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))

Sampling from a distribution requires explicitly passing in a PRNGKey (or list of integers) as the seed keyword argument. Failing to explicitly pass in a seed will throw an error.

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

The shape semantics for distributions remain the same in JAX, where distributions will each have an event_shape and a batch_shape and drawing many samples will add additional sample_shape dimensions.

For example, a tfd.MultivariateNormalDiag with vector parameters will have a vector event shape and empty batch shape.

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

On the other hand, a tfd.Normal parameterized with vectors will have a scalar event shape and vector batch shape.

dist = tfd.Normal(
    loc=jnp.zeros(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

The semantics of taking log_prob of samples works the same in JAX too.

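# A batch of five scalar Normals: log_prob keeps the batch dimension.
dist = tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

# Independent turns the batch dimension into an event dimension.
dist = tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)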
(10, 2, 5)
(10, 2)
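The only difference between the two cases is which axes log_prob reduces: tfd.Independent(..., 1) reinterprets the rightmost batch dimension as an event dimension, so its log_prob sums the per-component log densities over that axis. A NumPy sketch of that reduction (normal_log_prob is a hypothetical helper written out for illustration; NumPy's RNG stands in for the JAX key):

```python
import numpy as np

def normal_log_prob(x, loc, scale):
    """Elementwise Normal(loc, scale) log density."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

loc, scale = np.zeros(5), np.ones(5)
rng = np.random.default_rng(0)
samples = rng.normal(size=(10, 2, 5))  # sample_shape (10, 2) + shape (5,)

# Batch of 5 scalar Normals: one log density per batch member.
batch_lp = normal_log_prob(samples, loc, scale)
# Rightmost axis treated as the event: sum it out.
event_lp = batch_lp.sum(axis=-1)
print(batch_lp.shape, event_lp.shape)
```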

Because JAX DeviceArrays are compatible with libraries like NumPy and Matplotlib, we can feed samples directly into a plotting function.

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))


Distribution methods are compatible with JAX transformations.

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))


x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
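The curve above is the derivative of the standard normal density, which has the closed form d/dx phi(x) = -x * phi(x). As a sanity check on what grad(tfd.Normal(0., 1.).prob) should return, this NumPy sketch compares that closed form against a central finite difference on the same grid:

```python
import numpy as np

def std_normal_pdf(x):
    return np.exp(-0.5 * x**2) / np.sqrt(2.0 * np.pi)

def std_normal_pdf_grad(x):
    # Closed-form derivative: d/dx phi(x) = -x * phi(x).
    return -x * std_normal_pdf(x)

grid = np.linspace(-5.0, 5.0, 100)
h = 1e-5
finite_diff = (std_normal_pdf(grid + h) - std_normal_pdf(grid - h)) / (2.0 * h)
print(np.max(np.abs(finite_diff - std_normal_pdf_grad(grid))))
```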


Because TFP distributions are registered as JAX pytree nodes, we can write functions with distributions as inputs or outputs and transform them using jit, but they are not yet supported as arguments to vmap-ed functions.

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

Transformed distributions

Transformed distributions, i.e., distributions whose samples are passed through a Bijector, also work out of the box (bijectors work too; see below).

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()  # Illustrative choice: maps samples into (0, 1).
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
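A transformed distribution's density follows from the change-of-variables formula: if Y = g(X), then log p_Y(y) = log p_X(g^-1(y)) + log|d g^-1(y) / dy|. The NumPy sketch below works this out for a standard normal pushed through a sigmoid (a choice made here for illustration; TFP obtains the same quantity from the bijector's inverse and inverse_log_det_jacobian), and checks that the resulting density integrates to about 1 over (0, 1):

```python
import numpy as np

def std_normal_log_prob(x):
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)

def logit(y):
    # Inverse of the sigmoid.
    return np.log(y) - np.log1p(-y)

def sigmoid_normal_log_prob(y):
    # log p_Y(y) = log p_X(logit(y)) + log|d logit(y)/dy|,
    # where d logit(y)/dy = 1 / (y * (1 - y)).
    return std_normal_log_prob(logit(y)) - np.log(y) - np.log1p(-y)

# Midpoint-rule integration over (0, 1).
num_points = 200_000
grid = (np.arange(num_points) + 0.5) / num_points
total_mass = np.exp(sigmoid_normal_log_prob(grid)).sum() / num_points
print(total_mass)
```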


Joint distributions

TFP offers JointDistributions to enable combining component distributions into a single distribution over multiple random variables. Currently, TFP offers three core variants (JointDistributionSequential, JointDistributionNamed, and JointDistributionCoroutine) all of which are supported in JAX. The AutoBatched variants are also all supported.

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
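JointDistributionSequential samples ancestrally: each lambda receives the values sampled before it, and the joint log_prob is the sum of the component log densities (the chain rule of probability). A NumPy sketch of the same two-variable model, x ~ Normal(0, 1) and y | x ~ Normal(x, 0.1); NumPy's Generator stands in for a JAX PRNG key here, which is an assumption of the sketch rather than how TFP draws samples:

```python
import numpy as np

def normal_log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

def sample_joint(rng, n):
    x = rng.normal(0.0, 1.0, size=n)   # first component: Normal(0, 1)
    y = rng.normal(x, 0.1)             # second component conditions on x
    return x, y

def joint_log_prob(x, y):
    # log p(x, y) = log p(x) + log p(y | x)
    return normal_log_prob(x, 0.0, 1.0) + normal_log_prob(y, x, 0.1)

rng = np.random.default_rng(0)
x_samples, y_samples = sample_joint(rng, 1000)
# y is x plus small noise, so the correlation should be close to 1.
print(np.corrcoef(x_samples, y_samples)[0, 1])
```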


joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)
joint.sample(seed=random.PRNGKey(0))

StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

Other distributions

Gaussian processes also work in JAX mode!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1., maxval=1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

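gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)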

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
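Conditioning a zero-mean GP on noisy observations y at points X gives a predictive mean of k(X*, X) (K + sigma^2 I)^-1 y at new points X*, where K = k(X, X). The NumPy sketch below evaluates that formula with the standard exponentiated quadratic kernel k(x, x') = amplitude^2 * exp(-||x - x'||^2 / (2 * length_scale^2)); amplitude = 1 is an assumption meant to mirror leaving it unset above, and the synthetic data uses NumPy's RNG, so values will not match the plot. This is the textbook formula, not TFP's implementation (which also handles batching and numerical jitter).

```python
import numpy as np

def exp_quad_kernel(x1, x2, length_scale=0.1, amplitude=1.0):
    # x1: [n, d], x2: [m, d] -> [n, m] kernel matrix.
    sq_dist = np.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return amplitude**2 * np.exp(-0.5 * sq_dist / length_scale**2)

def gp_posterior_mean(index_points, obs_points, obs, noise_variance, **kernel_kwargs):
    """Zero-mean GP predictive mean: k(X*, X) (K + noise * I)^-1 y."""
    k_xx = exp_quad_kernel(obs_points, obs_points, **kernel_kwargs)
    k_sx = exp_quad_kernel(index_points, obs_points, **kernel_kwargs)
    # Solve the linear system instead of forming an explicit inverse.
    alpha = np.linalg.solve(k_xx + noise_variance * np.eye(len(obs_points)), obs)
    return k_sx @ alpha

rng = np.random.default_rng(0)
toy_f = lambda x: np.sin(10 * x[..., 0]) * np.exp(-x[..., 0] ** 2)
obs_points = rng.uniform(-1.0, 1.0, size=(50, 1))
obs = toy_f(obs_points) + rng.normal(0.0, 0.1, size=50)
grid = np.linspace(-1.0, 1.0, 100)[:, None]
posterior_mean = gp_posterior_mean(grid, obs_points, obs, noise_variance=0.01)
print(posterior_mean.shape)
```

With almost no observation noise, the posterior mean interpolates the data; far from the data, it reverts to the prior mean of 0.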


Hidden Markov models are also supported.

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

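model = tfd.HiddenMarkovModel(
  initial_distribution=initial_distribution,
  transition_distribution=transition_distribution,
  observation_distribution=observation_distribution,
  num_steps=7)

print(model.mean())
print(model.sample(seed=random.PRNGKey(0)))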

[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]


[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]
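The means printed above can be reproduced by hand. With initial state probabilities pi_0 and transition matrix T, the marginal state distribution at step t is pi_t = pi_0 T^t (the initial distribution applies at step 0), and because the observation means are 0 and 15, the expected observation at step t is pi_t . [0, 15]. A NumPy sketch of that recursion over seven steps, matching the length of the printed output:

```python
import numpy as np

initial_probs = np.array([0.8, 0.2])
transition_matrix = np.array([[0.7, 0.3],
                              [0.2, 0.8]])
observation_means = np.array([0.0, 15.0])

def hmm_means(initial_probs, transition_matrix, observation_means, num_steps):
    """Mean observation at each step, via the marginal state probabilities."""
    means = []
    state_probs = initial_probs
    for _ in range(num_steps):
        # E[obs_t] = sum_k P(z_t = k) * mean_k
        means.append(state_probs @ observation_means)
        # P(z_{t+1}) = P(z_t) @ T
        state_probs = state_probs @ transition_matrix
    return np.array(means)

print(hmm_means(initial_probs, transition_matrix, observation_means, num_steps=7))
```

Up to float32 rounding, this matches the [3, 6, 7.5, ...] sequence above, converging toward the stationary mean as the chain mixes.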

A few distributions like PixelCNN are not supported yet due to strict dependencies on TensorFlow or XLA incompatibilities.


Bijectors

Most of TFP's bijectors are supported in JAX today!

DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
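Calling one bijector on another composes them, so tfb.Shift(1.)(tfb.Scale(3.)) applies the scale first and then the shift: forward(x) = 3x + 1, inverse(y) = (y - 1) / 3, and the forward log-det-Jacobian is log 3 for every element. Plain NumPy functions spelling out that composition (hypothetical helpers for illustration, not TFP code):

```python
import numpy as np

def shift_scale_forward(x):
    return 3.0 * x + 1.0               # Scale(3.) first, then Shift(1.)

def shift_scale_inverse(y):
    return (y - 1.0) / 3.0             # undo Shift, then undo Scale

def shift_scale_forward_log_det_jacobian(x):
    # d/dx (3x + 1) = 3 for every element.
    return np.full_like(x, np.log(3.0))

ones = np.ones(5)
print(shift_scale_forward(ones), shift_scale_inverse(ones))
```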

b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]

b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]
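tfb.Chain applies its bijectors from right to left, so this chain computes exp(softplus(x)) = exp(log(1 + e^x)) = 1 + e^x. For example, at x = -1 that is 1 + e^-1, approximately 1.3678794. A NumPy check of the identity:

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)

def chain_forward(x):
    # Chain([Exp(), Softplus()]) applies Softplus first, then Exp.
    return np.exp(softplus(x))

inputs = -np.ones(5)
print(chain_forward(inputs))
```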

Bijectors are compatible with JAX transformations like jit, grad and vmap.

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
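Since tfb.Sigmoid().inverse is the logit function, the plotted gradient should be d/dx logit(x) = 1 / (x (1 - x)), which blows up at both ends of [0, 1]. A NumPy check of that closed form against central finite differences on the open interval:

```python
import numpy as np

def logit(x):
    # Inverse of the sigmoid: log(x) - log(1 - x).
    return np.log(x) - np.log1p(-x)

def logit_grad(x):
    # d/dx [log x - log(1 - x)] = 1/x + 1/(1 - x) = 1 / (x * (1 - x)).
    return 1.0 / (x * (1.0 - x))

grid = np.linspace(0.01, 0.99, 99)
h = 1e-6
finite_diff = (logit(grid + h) - logit(grid - h)) / (2.0 * h)
rel_err = np.max(np.abs(finite_diff / logit_grad(grid) - 1.0))
print(rel_err)
```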


Some bijectors, like RealNVP and FFJORD, are not yet supported.


MCMC

We've ported tfp.mcmc to JAX as well, so we can run algorithms like Hamiltonian Monte Carlo (HMC) and the No-U-Turn Sampler (NUTS) in JAX.

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

Unlike TFP on TF, we are required to pass a PRNGKey into sample_chain using the seed keyword argument.

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()
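NUTS is an adaptive variant of Hamiltonian Monte Carlo. To show the moving parts that tfp.mcmc manages for us (a current state, a log density and its gradient, and an explicit source of randomness), here is a minimal plain-HMC sketch in NumPy targeting the same 2-D standard normal. It is not NUTS: the step size (0.2) and number of leapfrog steps (10) are fixed by hand as assumptions, and NumPy's Generator stands in for a JAX PRNG key.

```python
import numpy as np

def mvn_std_log_prob(x):
    # Unnormalized log density of a 2-D standard normal.
    return -0.5 * np.sum(x**2)

def mvn_std_log_prob_grad(x):
    return -x

def hmc_step(rng, x, step_size=0.2, num_leapfrog_steps=10):
    """One HMC transition: resample momentum, leapfrog, accept/reject."""
    p = rng.normal(size=x.shape)
    x_new, p_new = x.copy(), p.copy()
    # Leapfrog: half momentum step, alternating full steps, half momentum step.
    p_new += 0.5 * step_size * mvn_std_log_prob_grad(x_new)
    for i in range(num_leapfrog_steps):
        x_new += step_size * p_new
        if i < num_leapfrog_steps - 1:
            p_new += step_size * mvn_std_log_prob_grad(x_new)
    p_new += 0.5 * step_size * mvn_std_log_prob_grad(x_new)
    # Metropolis correction on the change in total energy (potential + kinetic).
    current_energy = -mvn_std_log_prob(x) + 0.5 * np.sum(p**2)
    proposed_energy = -mvn_std_log_prob(x_new) + 0.5 * np.sum(p_new**2)
    if np.log(rng.uniform()) < current_energy - proposed_energy:
        return x_new
    return x

def run_hmc(seed, initial_state, num_results=1000):
    """Run one chain; the seed fully determines the output."""
    rng = np.random.default_rng(seed)
    state = np.asarray(initial_state, dtype=float)
    samples = []
    for _ in range(num_results):
        state = hmc_step(rng, state)
        samples.append(state)
    return np.stack(samples)

hmc_samples = run_hmc(seed=0, initial_state=np.zeros(2))
print(hmc_samples.mean(axis=0), hmc_samples.var(axis=0))
```

The sample mean and variance should come out near 0 and 1; NUTS removes the need to hand-tune the trajectory length used here.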



To run multiple chains, we can either pass a batch of states into sample_chain or use vmap (though we have not yet explored performance differences between the two approaches).

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()




Optimizers

TFP on JAX supports some important optimizers like BFGS and L-BFGS. Let's set up a simple scaled quadratic loss function.

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function. Its gradient comes from value_and_grad when calling the optimizer.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.
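The TFP optimizers take a function that returns both the objective value and its gradient, which is why the calls pass value_and_grad(quadratic_loss). For this loss, the gradient has the closed form grad f(x) = 2 * scales * (x - minimum), so a hand-written NumPy equivalent (a hypothetical stand-in for value_and_grad, for illustration) looks like this:

```python
import numpy as np

minimum_np = np.array([1.0, 1.0])
scales_np = np.array([2.0, 3.0])

def quadratic_loss_and_grad(x):
    """Return (value, gradient) of sum(scales * (x - minimum)**2)."""
    diff = x - minimum_np
    value = np.sum(scales_np * diff**2)
    gradient = 2.0 * scales_np * diff
    return value, gradient

start_np = np.array([0.6, 0.8])
loss_value, gradient = quadratic_loss_and_grad(start_np)
print(loss_value, gradient)
```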

BFGS can find the minimum of this loss.

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged.
assert optim_results.converged
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5
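For intuition about what bfgs_minimize does internally, here is a minimal BFGS sketch in NumPy: it maintains an approximation to the inverse Hessian, moves along -H g with a backtracking (Armijo) line search, and updates H from the change in position and gradient. This is the textbook algorithm under simplifying assumptions (no Wolfe curvature condition, fixed Armijo constant 1e-4), not TFP's implementation, and its evaluation count will differ from the one printed above.

```python
import numpy as np

QUAD_MINIMUM = np.array([1.0, 1.0])
QUAD_SCALES = np.array([2.0, 3.0])

def scaled_quadratic_value_and_grad(x):
    diff = x - QUAD_MINIMUM
    return np.sum(QUAD_SCALES * diff**2), 2.0 * QUAD_SCALES * diff

def bfgs_sketch(value_and_gradient, x0, tolerance=1e-8, max_iterations=100):
    """Minimal BFGS with a backtracking (Armijo) line search."""
    x = np.asarray(x0, dtype=float)
    f, g = value_and_gradient(x)
    identity = np.eye(len(x))
    h_inv = identity.copy()  # Inverse-Hessian approximation, starts at I.
    for _ in range(max_iterations):
        if np.max(np.abs(g)) < tolerance:
            break
        direction = -h_inv @ g
        step = 1.0
        # Halve the step until the sufficient-decrease condition holds.
        while True:
            x_new = x + step * direction
            f_new, g_new = value_and_gradient(x_new)
            if f_new <= f + 1e-4 * step * (g @ direction) or step < 1e-10:
                break
            step *= 0.5
        s, y = x_new - x, g_new - g
        sy = s @ y
        # Skip the update unless the curvature condition s.y > 0 holds.
        if sy > 1e-10 * np.linalg.norm(s) * np.linalg.norm(y):
            rho = 1.0 / sy
            h_inv = ((identity - rho * np.outer(s, y)) @ h_inv
                     @ (identity - rho * np.outer(y, s)) + rho * np.outer(s, s))
        x, f, g = x_new, f_new, g_new
    return x

print(bfgs_sketch(scaled_quadratic_value_and_grad, np.array([0.6, 0.8])))
```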

So can L-BFGS.

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged.
assert optim_results.converged
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

To vmap L-BFGS, let's set up a function that optimizes the loss for a single starting point.

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(all_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]


Differences from TFP on TF

Because there are some fundamental differences between TF and JAX, some TFP behaviors differ between the two substrates, and not all functionality is supported. For example:

  • TFP on JAX does not support anything like tf.Variable since nothing like it exists in JAX. This also means utilities like tfp.util.TransformedVariable are not supported either.
  • tfp.layers is not supported in the backend yet, due to its dependence on Keras and tf.Variables.
  • tfp.math.minimize does not work in TFP on JAX because of its dependence on tf.Variable.
  • With TFP on JAX, tensor shapes are always concrete integer values; they are never unknown/dynamic, as they can be in TFP on TF.
  • Pseudorandomness is handled differently in TF and JAX (see appendix).
  • Libraries in tfp.experimental are not guaranteed to exist in the JAX substrate.
  • Dtype promotion rules are different between TF and JAX. TFP on JAX tries to respect TF's dtype semantics internally, for consistency.
  • Bijectors have not yet been registered as JAX pytrees.

To see the complete list of what is supported in TFP on JAX, please refer to the API documentation.


Conclusion

We've ported a lot of TFP's features to JAX and are excited to see what everyone will build. Some functionality is not yet supported; if we've missed something important to you (or if you find a bug!), please reach out to us: you can email tfprobability@tensorflow.org or file an issue on our GitHub repo.

Appendix: pseudorandomness in JAX

JAX's pseudorandom number generation (PRNG) model is stateless. Unlike a stateful model, there is no mutable global state that evolves after each random draw. In JAX's model, we start with a PRNG key, which acts like a pair of 32-bit integers. We can construct these keys by using jax.random.PRNGKey.

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

Random functions in JAX consume a key to deterministically produce a random variate, meaning the key should not be used again. For example, we can use key to sample a normally distributed value, but we should not use key again elsewhere. Furthermore, passing the same key into random.normal will always produce the same value.


So how do we ever draw multiple samples from a single key? The answer is key splitting. The basic idea is that we can split a PRNGKey into multiple keys, each of which can be treated as an independent source of randomness.

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

Key splitting is deterministic but chaotic, so each new key can now be used to draw a distinct random sample.

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389
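NumPy has no direct equivalent of JAX keys, but numpy.random.SeedSequence.spawn offers a loosely analogous splitting model, which makes the two properties above easy to check without JAX: reusing the same seed reproduces the same draw, and spawned children act as independent streams. This is only an analogy, not JAX's implementation; notably, unlike random.split, spawn is stateful, so calling it twice on the same parent yields new children.

```python
import numpy as np

def normal_from_seed(seed_seq):
    # Deterministic: the same SeedSequence always yields the same draw.
    return np.random.default_rng(seed_seq).normal()

root = np.random.SeedSequence(0)
child1, child2 = root.spawn(2)   # loosely analogous to random.split(key, num=2)

print(normal_from_seed(child1), normal_from_seed(child2))
```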

For more details about JAX's deterministic key splitting model, see this guide.