Conditions a probabilistic program on random variables.
oryx.core.ppl.conditional(
    f: oryx.core.ppl.LogProbFunction,
    names: Union[List[str], str]
) -> oryx.core.ppl.LogProbFunction
conditional is a probabilistic program transformation that converts latent random variables into conditional inputs to the program. The random variables that are moved to the input are specified by a list of names that correspond to tagged random samples in the program. The final arguments of the output program correspond, one per name, to the list of names passed into conditional.
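The argument plumbing can be pictured with a minimal pure-Python sketch. It is not oryx's implementation: normal_rv, model, toy_conditional and the explicit fixed dictionary are hypothetical stand-ins, and an integer seed plays the role of the JAX PRNG key. A tracing library would intercept tagged samples instead of threading a dictionary through the model; the dictionary only keeps the sketch self-contained. Values for the conditioned names are taken from the trailing positional arguments and replace the corresponding draws.

```python
import random
from typing import Callable, Dict, List, Optional, Union


def normal_rv(name: str, key: int, fixed: Dict[str, float]) -> float:
    """Hypothetical tagged standard-normal draw keyed by an integer seed."""
    if name in fixed:
        return fixed[name]  # conditioned: the supplied input replaces the draw
    return random.Random(key).gauss(0.0, 1.0)  # latent: sample as usual


def model(seed: int, fixed: Optional[Dict[str, float]] = None) -> float:
    """Toy version of a model that returns z + x."""
    fixed = fixed or {}
    k1, k2 = 2 * seed, 2 * seed + 1  # crude stand-in for splitting a PRNG key
    z = normal_rv("z", k1, fixed)
    return z + normal_rv("x", k2, fixed)


def toy_conditional(f: Callable[..., float],
                    names: Union[List[str], str]) -> Callable[..., float]:
    """Turn the variables in `names` into trailing positional inputs of `f`."""
    names = [names] if isinstance(names, str) else list(names)

    def conditioned(*args):
        split = len(args) - len(names)
        if split < 0:
            raise TypeError(f"expected one trailing value per name in {names}")
        model_args, values = args[:split], args[split:]
        # Pair each trailing value with its name, in the order `names` lists them.
        return f(*model_args, fixed=dict(zip(names, values)))

    return conditioned


cond_z = toy_conditional(model, ["z"])
print(cond_z(0, 0.0), cond_z(0, 1.0))  # differ by 1 (up to rounding): x uses the same seed
print(toy_conditional(model, ["z", "x"])(0, 1.0, 2.0))  # 3.0
```

Because x is always drawn from its own derived seed, conditioning z only shifts the output by the value supplied for z, which is the same pattern the random.PRNGKey(0) example on this page shows.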
Random variables that are conditioned on are no longer random variables. This means that if a variable x is conditioned on, it will no longer appear in the joint_sample of a program, and its log_prob will no longer be computed as part of the program's log_prob.
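The bookkeeping behind this can be illustrated with a self-contained toy sketch; the helpers below (normal_rv, model, run, toy_joint_sample, toy_log_prob) are hypothetical and are not oryx internals. Every latent draw is recorded in a trace together with its standard-normal log density; toy_joint_sample reads the values back out, and toy_log_prob sums the recorded log densities at the sampled values. A conditioned name returns its supplied value before reaching the trace, so it contributes to neither result.

```python
import math
import random
from typing import Dict, List, Tuple

Trace = Dict[str, Tuple[float, float]]  # name -> (sampled value, log density)


def normal_logpdf(v: float) -> float:
    """Log density of a standard normal at v."""
    return -0.5 * v * v - 0.5 * math.log(2.0 * math.pi)


def normal_rv(name: str, key: int, fixed: Dict[str, float], trace: Trace) -> float:
    if name in fixed:
        return fixed[name]  # conditioned: an input, so it is never traced
    value = random.Random(key).gauss(0.0, 1.0)
    trace[name] = (value, normal_logpdf(value))  # latent: record value and log density
    return value


def model(seed: int, fixed: Dict[str, float], trace: Trace) -> float:
    z = normal_rv("z", 2 * seed, fixed, trace)
    x = normal_rv("x", 2 * seed + 1, fixed, trace)
    return z + x


def run(names: List[str], seed: int, *values: float) -> Trace:
    """Run `model` with `names` conditioned on `values`; return the trace."""
    trace: Trace = {}
    model(seed, dict(zip(names, values)), trace)
    return trace


def toy_joint_sample(trace: Trace) -> Dict[str, float]:
    return {name: value for name, (value, _) in trace.items()}


def toy_log_prob(trace: Trace) -> float:
    return sum(lp for _, lp in trace.values())


full, cond = run([], 0), run(["z"], 0, 0.5)
print(sorted(toy_joint_sample(full)), sorted(toy_joint_sample(cond)))  # ['x', 'z'] ['x']
print(toy_log_prob(full) - toy_log_prob(cond))  # the log density of the sampled z
```

Conditioning z removes both its entry in the joint sample and its term in the summed log density, while x keeps the same draw and the same log-density term.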
Example:
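from jax import random
from oryx.core.ppl import conditional, random_variable

def model(key):
  k1, k2 = random.split(key)
  z = random_variable(random.normal, name='z')(k1)
  return z + random_variable(random.normal, name='x')(k2)

# z becomes the trailing argument; x is still sampled from k2.
conditional(model, ['z'])(random.PRNGKey(0), 0.)  # => -1.25153887
conditional(model, ['z'])(random.PRNGKey(0), 1.)  # => -0.25153887
# Conditioning on both names leaves nothing random: 1. + 2.
conditional(model, ['z', 'x'])(random.PRNGKey(0), 1., 2.)  # => 3.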
Args:
  f: A probabilistic program.
  names: A string or list of strings corresponding to random variable names in f.

Returns:
  A probabilistic program with additional conditional inputs.