Lower bound on Jensen-Shannon (JS) divergence.
tfp.vi.mutual_information.lower_bound_jensen_shannon(
    logu, joint_sample_mask=None, validate_args=False, name=None
)
This lower bound on the JS divergence was proposed in [Goodfellow et al. (2014)] and [Nowozin et al. (2016)]. When estimating lower bounds on mutual information, one can use a different objective for training the critic than for evaluating mutual information [(Poole et al., 2019)]. In that scheme, the critic is trained with the standard lower bound on the Jensen-Shannon divergence, as used in GANs, and then evaluated with the NWJ lower bound on the KL divergence, i.e. on mutual information. Following Eq. 7 and Eq. 8 of [Nowozin et al. (2016)], the bound is given by
I_JS = E_p(x,y)[log( D(x,y) )] + E_p(x)p(y)[log( 1 - D(x,y) )]
where the first term is the expectation over samples from the joint distribution (positive samples), and the second over samples from the product of the marginal distributions (negative samples), with
D(x, y) = sigmoid(f(x, y)),
log(D(x, y)) = -softplus(-f(x, y)),
log(1 - D(x, y)) = -softplus(f(x, y)).
f(x, y) is a critic function that scores all pairs of samples.
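For a square matrix of critic scores whose diagonal holds the joint (positive) pairs, the two expectations above can be estimated directly from the softplus identities. The following is a minimal sketch of the estimator implied by these equations, not necessarily the library's exact implementation; the helper name and the identity mask are illustrative assumptions:

import tensorflow as tf

def js_lower_bound_sketch(scores):
  # `scores[i, j] = f(x[i], y[j])`; diagonal entries are joint (positive) pairs.
  mask = tf.eye(tf.shape(scores)[0], dtype=tf.bool)
  joint_scores = tf.boolean_mask(scores, mask)       # samples from p(x, y)
  marginal_scores = tf.boolean_mask(scores, ~mask)   # samples from p(x)p(y)
  # E_p(x,y)[log D] = E[-softplus(-f)];  E_p(x)p(y)[log(1 - D)] = E[-softplus(f)]
  return (tf.reduce_mean(-tf.math.softplus(-joint_scores)) -
          tf.reduce_mean(tf.math.softplus(marginal_scores)))

Maximizing this quantity with respect to the critic f is the standard GAN discriminator objective referred to above.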
Example: X, Y are samples from a joint Gaussian distribution, with correlation 0.8 and both of dimension 1.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
lower_bound_jensen_shannon = tfp.vi.mutual_information.lower_bound_jensen_shannon

batch_size, rho, dim = 10000, 0.8, 1
y, eps = tf.split(
    value=tf.random.normal(shape=(2 * batch_size, dim), seed=7),
    num_or_size_splits=2, axis=0)
mean, conditional_stddev = rho * y, tf.sqrt(1. - tf.square(rho))
x = mean + conditional_stddev * eps

# Scores/unnormalized likelihood of pairs of samples `x[i], y[j]`
# (For the JS lower bound, the optimal critic is of the form
# `f(x, y) = 1 + log(p(x | y) / p(x))` [(Poole et al., 2019)].)
conditional_dist = tfd.MultivariateNormalDiag(
    mean, scale_identity_multiplier=conditional_stddev)
conditional_scores = conditional_dist.log_prob(x[:, tf.newaxis, :])
marginal_dist = tfd.MultivariateNormalDiag(tf.zeros(dim), tf.ones(dim))
marginal_scores = marginal_dist.log_prob(x)[:, tf.newaxis]
scores = 1 + conditional_scores - marginal_scores

# Mask for joint samples in the score tensor
# (`scores` has shape [x_batch_size, y_batch_size], i.e.
# `scores[i, j] = f(x[i], y[j]) = 1 + log(p(x[i] | y[j]) / p(x[i]))`.)
joint_sample_mask = tf.eye(batch_size, dtype=bool)

# Lower bound on the Jensen-Shannon divergence
lower_bound_jensen_shannon(logu=scores, joint_sample_mask=joint_sample_mask)
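Per the description above, this bound is typically used to train the critic, while mutual information itself is evaluated with the NWJ lower bound on the KL divergence. Since the scores in this example already have the optimal-critic form f(x, y) = 1 + log(p(x | y) / p(x)), they can be fed to an NWJ estimator directly. A hedged sketch, assuming lower_bound_nwj from the same module with the same logu/joint_sample_mask interface; the analytic mutual information of the bivariate Gaussian is shown only for comparison:

# Mutual information estimate via the NWJ (KL) lower bound, reusing the same
# critic scores (assumed interface: logu / joint_sample_mask as above).
mi_nwj = tfp.vi.mutual_information.lower_bound_nwj(
    logu=scores, joint_sample_mask=joint_sample_mask)

# Ground-truth MI of a bivariate Gaussian with correlation rho:
# I(X; Y) = -0.5 * log(1 - rho^2), about 0.51 nats for rho = 0.8.
mi_true = -0.5 * tf.math.log(1. - tf.square(rho))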
Args:
  logu: float-like Tensor of size [batch_size_1, batch_size_2] representing critic scores (scores) for pairs of points (x, y), with logu[i, j] = f(x[i], y[j]).
  joint_sample_mask: bool-like Tensor of the same size as logu, masking the positive samples by True, i.e. samples from the joint distribution p(x, y). Default value: None. By default, an identity matrix is constructed as the mask.
  validate_args: Python bool, default False. Whether to validate input with asserts. If validate_args is False and the inputs are invalid, correct behavior is not guaranteed.
  name: Python str name prefixed to Ops created by this function. Default value: None.

Returns:
  A float-like scalar for the lower bound on the JS divergence.
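Because the mask defaults to an identity matrix, the explicit joint_sample_mask in the example above can be omitted whenever logu is square with the joint pairs on the diagonal, e.g.:

# Equivalent to passing tf.eye(batch_size, dtype=bool) explicitly.
lower_bound_jensen_shannon(logu=scores)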
References:

[1]: Ian J. Goodfellow, et al. Generative Adversarial Nets. In Conference on Neural Information Processing Systems, 2014. https://arxiv.org/abs/1406.2661
[2]: Sebastian Nowozin, Botond Cseke, Ryota Tomioka. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. In Conference on Neural Information Processing Systems, 2016. https://arxiv.org/abs/1606.00709
[3]: Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A. Alemi, George Tucker. On Variational Bounds of Mutual Information. In International Conference on Machine Learning, 2019. https://arxiv.org/abs/1905.06922