View source on GitHub |
Use VIMCO to lower the variance of gradient[csiszar_function(log(Avg(u))].
tfp.substrates.jax.vi.csiszar_vimco(
f, p_log_prob, q, num_draws, num_batch_draws=1, seed=None, name=None
)
This function generalizes VIMCO [(Mnih and Rezende, 2016)][1] to Csiszar f-Divergences.
The VIMCO loss is:
vimco = f(log(Avg{u[i] : i=0,...,m-1}))
where,
logu[i] = log( p(x, h[i]) / q(h[i] | x) )
h[i] iid~ q(H | x)
Interestingly, the VIMCO gradient is not the naive gradient of vimco
.
Rather, it is characterized by:
grad[vimco] - variance_reducing_term
where,
variance_reducing_term = Sum{ grad[log q(h[i] | x)] *
(vimco - f(log Avg{h[j;i] : j=0,...,m-1}))
: i=0, ..., m-1 }
h[j;i] = { u[j] j!=i
{ GeometricAverage{ u[k] : k!=i} j==i
(We omitted stop_gradient
for brevity. See implementation for more details.)
The Avg{h[j;i] : j}
term is a kind of "swap-out average" where the i
-th
element has been replaced by the leave-i
-out Geometric-average.
This implementation prefers numerical precision over efficiency, i.e.,
O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))
.
(The constant may be fairly large, perhaps around 12.)
Args | |
---|---|
f
|
Python callable representing a Csiszar-function in log-space.
|
p_log_prob
|
Python callable representing the natural-log of the
probability under distribution p . (In variational inference p is the
joint distribution.)
|
q
|
tf.Distribution -like instance; must implement: sample(n, seed) , and
log_prob(x) . (In variational inference q is the approximate posterior
distribution.)
|
num_draws
|
Integer scalar number of draws used to approximate the f-Divergence expectation. |
num_batch_draws
|
Integer scalar number of draws used to approximate the f-Divergence expectation. |
seed
|
PRNG seed for q.sample ; see tfp.random.sanitize_seed for details.
|
name
|
Python str name prefixed to Ops created by this function.
|
Returns | |
---|---|
vimco
|
The Csiszar f-Divergence generalized VIMCO objective. |
Raises | |
---|---|
ValueError
|
if num_draws < 2 .
|
References
[1]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo objectives. In International Conference on Machine Learning, 2016. https://arxiv.org/abs/1602.06725