# tfp.vi.mutual_information.lower_bound_info_nce

InfoNCE lower bound on mutual information.

The InfoNCE lower bound was proposed in [van den Oord et al. (2018)] and is based on noise contrastive estimation (NCE):

```
I(X; Y) >= 1/K sum(i=1:K, log( p_joint[i] / p_marginal[i])),
```

where the numerator and the denominator are, respectively,

```
p_joint[i] = p(x[i] | y[i]) = exp( f(x[i], y[i]) ),
p_marginal[i] = 1/K sum(j=1:K, p(x[i] | y[j]) )
              = 1/K sum(j=1:K, exp( f(x[i], y[j]) ) ),
```

and `(x[i], y[i]), i=1:K` are samples from the joint distribution `p(x, y)`. Each pair of points `(x, y)` is scored by a critic function `f`.
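As a minimal sketch of the formula above (not the library's implementation), the bound can be evaluated in plain Python from a square score matrix whose diagonal holds the joint pairs. The function name `info_nce_lower_bound` and the nested-list input are assumptions made for illustration; a log-sum-exp is used so that large scores do not overflow.

```python
import math


def info_nce_lower_bound(scores):
    """Evaluate the InfoNCE bound for a K x K matrix with scores[i][j] = f(x[i], y[j]).

    Hypothetical helper: assumes the joint samples sit on the diagonal.
    """
    k = len(scores)
    total = 0.0
    for i, row in enumerate(scores):
        # Stable log(sum_j exp(row[j])) via the log-sum-exp trick.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(s - row_max) for s in row))
        # log p_joint[i] - log p_marginal[i],
        # with p_marginal[i] = 1/K sum_j exp(f(x[i], y[j])).
        total += row[i] - (log_sum_exp - math.log(k))
    return total / k


# A critic that cannot tell pairs apart gives a bound of 0 ...
uninformative = [[0.0] * 4 for _ in range(4)]
# ... while a critic that strongly favours the diagonal approaches log(K).
confident = [[50.0 if i == j else 0.0 for j in range(4)] for i in range(4)]

print(info_nce_lower_bound(uninformative))
print(info_nce_lower_bound(confident), math.log(4))
```

Each term equals `log K` plus a log-softmax value, so the estimate can never exceed `log K`; estimating a large mutual information therefore needs a correspondingly large batch.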

#### Example:

`X` and `Y` are samples from a jointly Gaussian distribution with correlation `0.8`; both have dimension `1`.
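As a reference point for this setup, two jointly Gaussian scalars with correlation `rho` have mutual information `-0.5 * log(1 - rho**2)` nats. The snippet below evaluates that value; the helper name `gaussian_mutual_information` is hypothetical, and the `dim` factor assumes independent coordinate pairs that all share the same correlation.

```python
import math


def gaussian_mutual_information(rho, dim=1):
    """Closed-form I(X; Y) in nats for `dim` independent Gaussian pairs with correlation `rho`."""
    return -0.5 * dim * math.log(1.0 - rho ** 2)


true_mi = gaussian_mutual_information(rho=0.8, dim=1)
print(f"true mutual information: {true_mi:.4f} nats")
```

A lower-bound estimate computed from samples would be expected to land at or below roughly this value, up to sampling noise.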

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_size, rho, dim = 10000, 0.8, 1
y, eps = tf.split(
    value=tf.random.normal(shape=(2 * batch_size, dim), seed=7),
    num_or_size_splits=2, axis=0)
mean, conditional_stddev = rho * y, tf.sqrt(1. - tf.square(rho))
x = mean + conditional_stddev * eps

# Conditional distribution of p(x|y)
conditional_dist = tfd.MultivariateNormalDiag(
    mean, scale_identity_multiplier=conditional_stddev)

# Scores/unnormalized likelihood of pairs of samples `x[i], y[j]`
# (`scores` has shape [x_batch_size, distribution_batch_size],
# as `lower_bound_info_nce` requires `scores[i, j] = f(x[i], y[j])
# = log p(x[i] | y[j])`.)
scores = conditional_dist.log_prob(x[:, tf.newaxis, :])

# InfoNCE lower bound on mutual information
tfp.vi.mutual_information.lower_bound_info_nce(logu=scores)
```

#### Args:

* `logu`: `float`-like `Tensor` of size `[batch_size_1, batch_size_2]` representing critic scores (scores) for pairs of points (x, y) with `logu[i, j] = f(x[i], y[j])`.
* `joint_sample_mask`: `bool`-like `Tensor` of the same size as `logu` masking the positive samples by `True`, i.e. samples from joint distribution `p(x, y)`. Default value: `None`. By default, an identity matrix is constructed as the mask.
* `validate_args`: Python `bool`, default `False`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed.
* `name`: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'lower_bound_info_nce').

#### Returns:

* `lower_bound`: `float`-like `scalar` for lower bound on mutual information.
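
To make the default for `joint_sample_mask` concrete, the sketch below builds an identity mask in plain Python and lists which entries of `logu` it marks as joint samples. The helper `identity_mask` and the sample values are hypothetical and only mirror the description above; this is not the library's code, and rectangular inputs or masks with several positives per row are not covered.

```python
def identity_mask(n):
    """Return an n x n boolean matrix that is True only on the diagonal."""
    return [[i == j for j in range(n)] for i in range(n)]


# Illustrative critic scores with logu[i][j] = f(x[i], y[j]).
logu = [
    [2.0, 0.1, -0.3],
    [0.0, 1.5, 0.2],
    [-0.4, 0.3, 1.8],
]
mask = identity_mask(len(logu))

# Entries selected by the mask are scores of joint pairs (x[i], y[i]);
# the remaining entries correspond to pairs (x[i], y[j]) with i != j.
positives = [s for row, m_row in zip(logu, mask) for s, m in zip(row, m_row) if m]
print(positives)
```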