Computes the Monte-Carlo approximation of E_p[f(X)].
tfp.substrates.jax.monte_carlo.expectation(
    f,
    samples,
    log_prob=None,
    use_reparameterization=True,
    axis=0,
    keepdims=False,
    name=None
)
This function computes the Monte-Carlo approximation of an expectation, i.e.,
E_p[f(X)] approx= m**-1 sum_i^m f(x_i),  x_i ~iid p(X)

where:

x_i = samples[i, ...],
log(p(samples)) = log_prob(samples), and
m = prod(shape(samples)[axis]).
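For example, here is a minimal sketch of this estimate in the JAX substrate (the import alias and the choice of f are illustrative, not part of the API):

import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Estimate E_p[X**2] = 1 for p = Normal(0, 1) by averaging f over axis 0.
p = tfd.Normal(loc=0., scale=1.)
samples = p.sample(int(1e5), seed=jax.random.PRNGKey(0))
approx_second_moment = tfp.monte_carlo.expectation(
    f=lambda x: x**2,
    samples=samples,
    axis=0)
# ==> approximately 1.0, the exact second moment of a standard normal.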
Tricks: Reparameterization and Score-Gradient
When p is "reparameterized", i.e., a diffeomorphic transformation of a
parameterless distribution (e.g., Normal(Y; m, s) <=> Y = sX + m,
X ~ Normal(0, 1)), we can swap gradient and expectation, i.e.,

grad[ Avg{ s_i : i=1...n } ] = Avg{ grad[s_i] : i=1...n }

where S_n = Avg{s_i}, s_i = f(x_i), and x_i ~iid p(X).
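For intuition, here is a hedged sketch of this swap for a reparameterized Normal in plain JAX (not part of this API); differentiating the sample average recovers the gradient of the expectation because each x_i is a differentiable function of the parameter:

import jax
import jax.numpy as jnp

def approx_expectation(m, eps):
  # Reparameterization: X = m + eps with eps ~ Normal(0, 1), so X ~ Normal(m, 1).
  x = m + eps
  return jnp.mean(x**2)  # Monte-Carlo estimate of E[X**2] = m**2 + 1.

eps = jax.random.normal(jax.random.PRNGKey(0), [100000])
grad_wrt_m = jax.grad(approx_expectation)(2., eps)
# ==> approximately 4.0, i.e. d/dm (m**2 + 1); the gradient of the average
#     equals the average of the per-sample gradients 2 * (m + eps_i).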
However, if p is not reparameterized, TensorFlow's gradient will be incorrect
since the chain rule stops at samples of non-reparameterized distributions.
(The non-differentiated result, approx_expectation, is the same regardless of
use_reparameterization.) In this circumstance, using the score-gradient trick
results in an unbiased gradient, i.e.,
grad[ E_p[f(X)] ]
= grad[ int dx p(x) f(x) ]
= int dx grad[ p(x) f(x) ]
= int dx [ p'(x) f(x) + p(x) f'(x) ]
= int dx p(x) [p'(x) / p(x) f(x) + f'(x) ]
= int dx p(x) grad[ f(x) p(x) / stop_grad[p(x)] ]
= E_p[ grad[ f(x) p(x) / stop_grad[p(x)] ] ]
When p is reparameterized, it is usually preferable to set
use_reparameterization=True.
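As a hedged sketch of the surrogate behind the score-gradient trick (the internal details of expectation may differ), multiplying f(x) by p(x) / stop_grad[p(x)] leaves the estimate's value unchanged while its gradient picks up the score term grad[log p(x)]:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

def surrogate_expectation(probs, x):
  # x holds fixed Bernoulli(0.4) samples; gradients flow only through log_prob.
  p = tfd.Bernoulli(probs=probs)
  lp = p.log_prob(x)
  # f(X) = X, so E_p[f(X)] = probs and d/dprobs E_p[f(X)] = 1.
  return jnp.mean(x * jnp.exp(lp - jax.lax.stop_gradient(lp)))

x = tfd.Bernoulli(probs=0.4).sample(
    int(1e5), seed=jax.random.PRNGKey(0)).astype(jnp.float32)
value = surrogate_expectation(0.4, x)
# ==> approximately 0.4; exp(lp - stop_grad(lp)) == 1, so the value is jnp.mean(x).
grad = jax.grad(surrogate_expectation)(0.4, x)
# ==> approximately 1.0; an unbiased estimate of d/dprobs E_p[X].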
Example Use:
# Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
num_draws = int(1e5)
p = tfp.distributions.Normal(loc=0., scale=1.)
q = tfp.distributions.Normal(loc=1., scale=2.)
exact_kl_normal_normal = tfp.distributions.kl_divergence(p, q)
# ==> 0.44314718
approx_kl_normal_normal = tfp.monte_carlo.expectation(
    f=lambda x: p.log_prob(x) - q.log_prob(x),
    samples=p.sample(num_draws, seed=42),
    log_prob=p.log_prob,
    use_reparameterization=(p.reparameterization_type
                            == tfp.distributions.FULLY_REPARAMETERIZED))
# ==> 0.44632751
# Relative Error: <1%
# Monte-Carlo approximation of non-reparameterized distribution,
# e.g., Bernoulli.
num_draws = int(1e5)
p = tfp.distributions.Bernoulli(probs=0.4)
q = tfp.distributions.Bernoulli(probs=0.8)
exact_kl_bernoulli_bernoulli = tfp.distributions.kl_divergence(p, q)
# ==> 0.38190854
approx_kl_bernoulli_bernoulli = tfp.monte_carlo.expectation(
    f=lambda x: p.log_prob(x) - q.log_prob(x),
    samples=p.sample(num_draws, seed=42),
    log_prob=p.log_prob,
    use_reparameterization=(p.reparameterization_type
                            == tfp.distributions.FULLY_REPARAMETERIZED))
# ==> 0.38336259
# Relative Error: <1%
# For comparing the gradients, see `expectation_test.py`.

# The same reverse KL can also be estimated via `tfp.vi` (argument names
# below follow recent TFP releases; consult the `tfp.vi` docs):
approx_kl_p_q = tfp.vi.monte_carlo_variational_loss(
    target_log_prob_fn=q.log_prob,
    surrogate_posterior=p,
    discrepancy_fn=tfp.vi.kl_reverse,
    sample_size=num_draws)
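In the JAX substrate, one hedged way to obtain such a gradient (an illustrative sketch, not from the docstring) is to wrap the estimate in a function of p's parameter and apply jax.grad; with use_reparameterization=False and log_prob supplied, the score-gradient surrogate yields a usable gradient even though Bernoulli sampling itself is not differentiable:

import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
q = tfd.Bernoulli(probs=0.8)

def approx_kl(p_probs, seed=jax.random.PRNGKey(42)):
  p = tfd.Bernoulli(probs=p_probs)
  samples = p.sample(int(1e5), seed=seed)
  return tfp.monte_carlo.expectation(
      f=lambda x: p.log_prob(x) - q.log_prob(x),
      samples=samples,
      log_prob=p.log_prob,
      use_reparameterization=False)

# Monte-Carlo estimate of d/dp_probs KL[Bernoulli(p_probs) || Bernoulli(0.8)].
grad_wrt_p_probs = jax.grad(approx_kl)(0.4)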
Returns |
---|---
approx_expectation | Tensor corresponding to the Monte-Carlo approximation of E_p[f(X)].
Raises |
---|---
ValueError | if f is not a Python callable.
ValueError | if use_reparameterization=False and log_prob is not a Python callable.