Conditions a probabilistic program on random variables.

conditional is a probabilistic program transformation that converts latent random variables into conditional inputs to the program. The random variables to move to the input are specified by a list of names, each matching a tagged random sample in the program. The output program takes the original arguments first; its final arguments correspond, one per name, to the list of names passed into conditional.

Random variables that are conditioned are no longer random variables. This means that if a variable x is conditioned on, it will no longer appear in the joint_sample of a program and its log_prob will no longer be computed as part of a program's log_prob.
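
One way to read this statement is as a short derivation for a hypothetical program with two tagged variables, z and x. The density symbol p and the conditioned value z_0 are notation introduced here for illustration and do not come from the library. Before conditioning, the joint log-density over the tagged variables has one term per variable; after conditioning on z, the term for z is dropped and z_0 enters only as a fixed input:

```latex
\begin{align}
\text{before conditioning:}\quad & \log p(z, x) = \log p(z) + \log p(x \mid z) \\
\text{after conditioning on } z = z_0\text{:}\quad & \log p(x \mid z = z_0)
\end{align}
```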


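# Assumes JAX and the Oryx PPL namespace; adjust the import path to your install.
from jax import random
from oryx.core.ppl import conditional, random_variable

def model(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

# Condition on z only: x is still sampled from the key.
conditional(model, ['z'])(random.PRNGKey(0), 0.)  # => -1.25153887
conditional(model, ['z'])(random.PRNGKey(0), 1.)  # => -0.25153887
# Condition on both z and x: the output no longer depends on the key.
conditional(model, ['z', 'x'])(random.PRNGKey(0), 1., 2.)  # => 3.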
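
To make the semantics concrete without the library, here is a minimal toy sketch in plain Python. It is not how the library implements conditional. The names toy_random_variable, toy_conditional, standard_normal, _active, and _sampled are hypothetical helpers invented for this illustration, and the standard library random module stands in for JAX keys. It only mirrors two behaviours described above: the final arguments supply the conditioned values, and a conditioned name is no longer recorded as a sample.

```python
import random

_active = {}   # name -> fixed value for the run in progress (toy state)
_sampled = {}  # name -> value actually drawn during the most recent run

def toy_random_variable(sampler, name):
    """Tags a sampler with a name (hypothetical stand-in, not the library API)."""
    def tagged(rng):
        if name in _active:
            return _active[name]   # conditioned: acts as a plain input
        value = sampler(rng)
        _sampled[name] = value     # still random: recorded as a sample
        return value
    return tagged

def toy_conditional(f, names):
    """Returns a program whose final arguments fix the named variables."""
    names = [names] if isinstance(names, str) else list(names)
    def conditioned(*args):
        # Split off the trailing arguments, one per conditioned name.
        split = len(args) - len(names)
        program_args, values = args[:split], args[split:]
        _active.clear()
        _active.update(zip(names, values))
        _sampled.clear()
        try:
            return f(*program_args)
        finally:
            _active.clear()
    return conditioned

def standard_normal(rng):
    return rng.gauss(0.0, 1.0)

def model(rng):
    z = toy_random_variable(standard_normal, name='z')(rng)
    return z + toy_random_variable(standard_normal, name='x')(rng)

out = toy_conditional(model, ['z'])(random.Random(0), 1.0)
sampled_names = sorted(_sampled)  # ['x']: z was an input, not a sample
both = toy_conditional(model, ['z', 'x'])(random.Random(0), 1.0, 2.0)  # 3.0
```

In this sketch, conditioning on every tagged name leaves nothing to sample, so the result is just the sum of the supplied values, matching the last line of the example above.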

Args
f: A probabilistic program.
names: A string or list of strings corresponding to random variable names in f.

Returns
A probabilistic program with additional conditional inputs.