Lower bound on Kullback-Leibler (KL) divergence from Nguyen et al.
```python
tfp.vi.mutual_information.lower_bound_nguyen_wainwright_jordan(
    logu, joint_sample_mask=None, validate_args=False, name=None
)
```
The lower bound was introduced by Nguyen, Wainwright, and Jordan (NWJ) in [Nguyen et al. (2010)], and is also known as the f-GAN KL [(Nowozin et al., 2016)] and MINE-f [(Belghazi et al., 2018)]. The bound is:
```
I_NWJ = E_p(x,y)[f(x, y)] - 1/e * E_p(y)[Z(y)],
```

where `f(x, y)` is a critic function that scores pairs of samples `(x, y)`, and `Z(y)` is the corresponding partition function:

```
Z(y) = E_p(x)[exp(f(x, y))].
```
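To make the estimator concrete, here is a minimal, illustrative sketch (not the library's implementation) of how `I_NWJ` can be estimated from a matrix of critic scores and a joint-sample mask; the helper name `nwj_lower_bound_sketch` is hypothetical.

```python
import tensorflow as tf


def nwj_lower_bound_sketch(scores, joint_sample_mask):
  """Illustrative estimate of I_NWJ from a matrix of critic scores.

  `scores[i, j] = f(x[i], y[j])`; entries marked `True` in
  `joint_sample_mask` are treated as samples from p(x, y), the rest as
  samples from p(x)p(y).
  """
  scores = tf.convert_to_tensor(scores, dtype_hint=tf.float32)
  joint_sample_mask = tf.convert_to_tensor(joint_sample_mask, dtype=tf.bool)
  # First term: E_{p(x,y)}[f(x, y)], estimated from the joint samples.
  joint_term = tf.reduce_mean(tf.boolean_mask(scores, joint_sample_mask))
  # Second term: 1/e * E_{p(y)}[Z(y)], with Z(y) = E_{p(x)}[exp(f(x, y))],
  # estimated from the remaining (marginal) samples.
  marginal_scores = tf.boolean_mask(
      scores, tf.logical_not(joint_sample_mask))
  partition_term = tf.exp(-1.) * tf.reduce_mean(tf.exp(marginal_scores))
  return joint_term - partition_term
```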
For example, suppose `X` and `Y` are samples from a joint Gaussian distribution with correlation `0.8`, each of dimension `1`:
```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
lower_bound_nguyen_wainwright_jordan = (
    tfp.vi.mutual_information.lower_bound_nguyen_wainwright_jordan)

batch_size, rho, dim = 10000, 0.8, 1
y, eps = tf.split(
    value=tf.random.normal(shape=(2 * batch_size, dim), seed=7),
    num_or_size_splits=2, axis=0)
mean, conditional_stddev = rho * y, tf.sqrt(1. - tf.square(rho))
x = mean + conditional_stddev * eps

# Scores/unnormalized likelihood of pairs of samples `x[i], y[j]`.
# (For the NWJ lower bound, the optimal critic is of the form
# `f(x, y) = 1 + log(p(x | y) / p(x))` [(Poole et al., 2019)].)
conditional_dist = tfd.MultivariateNormalDiag(
    mean, scale_identity_multiplier=conditional_stddev)
conditional_scores = conditional_dist.log_prob(x[:, tf.newaxis, :])
marginal_dist = tfd.MultivariateNormalDiag(tf.zeros(dim), tf.ones(dim))
marginal_scores = marginal_dist.log_prob(x)[:, tf.newaxis]
scores = 1. + conditional_scores - marginal_scores

# Mask for joint samples in the score tensor.
# (`scores` has shape [x_batch_size, y_batch_size], i.e.
# `scores[i, j] = f(x[i], y[j]) = 1 + log(p(x[i] | y[j]) / p(x[i]))`.)
joint_sample_mask = tf.eye(batch_size, dtype=bool)

# Lower bound on the KL divergence between p(x,y) and p(x)p(y),
# i.e. the mutual information between `X` and `Y`.
lower_bound_nguyen_wainwright_jordan(
    logu=scores,
    joint_sample_mask=joint_sample_mask)
```
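Assuming the variables from the example above (`rho`, `dim`) are still in scope, the estimate can be sanity-checked against the closed-form mutual information of a jointly Gaussian pair:

```python
import numpy as np

# Ground-truth MI for jointly Gaussian (X, Y) with per-dimension
# correlation `rho`: I(X; Y) = -0.5 * dim * log(1 - rho^2).
true_mi = -0.5 * dim * np.log(1. - rho**2)  # ~0.51 nats for rho=0.8, dim=1
```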
Args:
  logu: `float`-like `Tensor` of shape `[batch_size_1, batch_size_2]` representing critic scores (scores) for pairs of points `(x, y)` with `logu[i, j] = f(x[i], y[j])`.
  joint_sample_mask: `bool`-like `Tensor` of the same shape as `logu`, masking the positive samples, i.e. samples from the joint distribution `p(x, y)`, with `True`. Default value: `None`. By default, an identity matrix is constructed as the mask.
  validate_args: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False` and the inputs are invalid, correct behavior is not guaranteed.
  name: Python `str` name prefixed to Ops created by this function. Default value: `None`.

Returns:
  Scalar `float`-like `Tensor` giving the lower bound on the KL divergence between the joint and marginal distributions.
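When the positive pairs do not lie on the diagonal, the mask can be supplied explicitly. A minimal sketch under that assumption, using arbitrary placeholder scores (purely illustrative, unrelated to the Gaussian example above):

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Hypothetical 3 x 4 score matrix whose joint (positive) pairs sit in the
# first column, so the default identity mask does not apply.
scores = tf.random.normal([3, 4], seed=17)
joint_sample_mask = tf.concat(
    [tf.ones([3, 1], dtype=tf.bool), tf.zeros([3, 3], dtype=tf.bool)],
    axis=-1)
estimate = tfp.vi.mutual_information.lower_bound_nguyen_wainwright_jordan(
    logu=scores, joint_sample_mask=joint_sample_mask, validate_args=True)
```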
References:

[Nguyen et al., 2010]: XuanLong Nguyen, Martin J. Wainwright, Michael I. Jordan. Estimating Divergence Functionals and the Likelihood Ratio by Convex Risk Minimization. IEEE Transactions on Information Theory, 56(11):5847-5861, 2010. https://arxiv.org/abs/0809.0853.

[Nowozin et al., 2016]: Sebastian Nowozin, Botond Cseke, Ryota Tomioka. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. In Conference on Neural Information Processing Systems, 2016. https://arxiv.org/abs/1606.00709.

[Belghazi et al., 2018]: Mohamed Ishmael Belghazi, et al. MINE: Mutual Information Neural Estimation. In International Conference on Machine Learning, 2018. https://arxiv.org/abs/1801.04062.

[Poole et al., 2019]: Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A. Alemi, George Tucker. On Variational Bounds of Mutual Information. In International Conference on Machine Learning, 2019. https://arxiv.org/abs/1905.06922.