
tfp.experimental.substrates.jax.distributions.mvn_conjugate_linear_update


Computes a conjugate normal posterior for a Bayesian linear regression.

tfp.experimental.substrates.jax.distributions.mvn_conjugate_linear_update(
    prior_scale, linear_transformation, likelihood_scale, observation,
    prior_mean=None, name=None
)

We assume the following model:

latent ~ MVN(loc=prior_mean, scale=prior_scale)
observation ~ MVN(loc=linear_transformation.matvec(latent),
                  scale=likelihood_scale)

For Bayesian linear regression, the latent represents the weights, and the provided linear_transformation is the design matrix.

This function computes the multivariate normal posterior p(latent | observation), using LinearOperators to perform computations efficiently when the matrices involved have special structure.
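For example, here is a minimal sketch of a single conjugate update. The shapes, scales, design matrix, and targets are illustrative placeholders, and plain float Tensors are passed in place of LinearOperators (which the Args below allow). The import path mirrors this page; newer TFP releases expose the same module as tfp.substrates.jax.distributions.

import jax.numpy as jnp
import tensorflow_probability as tfp

tfd = tfp.experimental.substrates.jax.distributions

num_features, num_outputs = 3, 10

# Placeholder design matrix and regression targets, for illustration only.
design_matrix = jnp.ones([num_outputs, num_features])
observation = jnp.linspace(-1., 1., num_outputs)

posterior_mean, posterior_prec = tfd.mvn_conjugate_linear_update(
    prior_scale=jnp.eye(num_features),            # standard normal prior on weights
    linear_transformation=design_matrix,
    likelihood_scale=0.1 * jnp.eye(num_outputs),  # observation noise scale of 0.1
    observation=observation)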

Args:

  • prior_scale: Instance of tf.linalg.LinearOperator of shape [..., num_features, num_features], specifying a scale matrix (any matrix L such that LL' = Q where Q is the covariance) for the prior on regression weights. May optionally be a float Tensor.
  • linear_transformation: Instance of tf.linalg.LinearOperator of shape [..., num_outputs, num_features], specifying a transformation of the latent values. May optionally be a float Tensor.
  • likelihood_scale: Instance of tf.linalg.LinearOperator of shape [..., num_outputs, num_outputs] specifying a scale matrix (any matrix L such that LL' = Q where Q is the covariance) for the likelihood of observed targets. May optionally be a float Tensor.
  • observation: Float Tensor of shape [..., num_outputs], specifying the observed values or regression targets.
  • prior_mean: Optional float Tensor of shape [..., num_features], specifying the prior mean. If None, the prior mean is assumed to be zero and some computation is avoided. Default value: None.
  • name: Optional Python str name given to ops created by this function. Default value: 'mvn_conjugate_linear_update'.

Returns:

  • posterior_mean: Float Tensor of shape [..., num_features], giving the mean of the multivariate normal posterior on the latent value.
  • posterior_prec: Instance of tf.linalg.LinearOperator of shape [..., num_features, num_features], giving the posterior precision (inverse covariance) matrix. See the sketch below for recovering an explicit distribution from this pair.
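Continuing the hypothetical sketch above (with posterior_mean, posterior_prec, and tfd as defined there), one way to turn the returned pair into an explicit posterior distribution is to densify and invert the precision. For large num_features you would instead exploit the LinearOperator structure rather than calling .to_dense().

import jax
import jax.numpy as jnp

posterior_cov = jnp.linalg.inv(posterior_prec.to_dense())
posterior = tfd.MultivariateNormalTriL(
    loc=posterior_mean,
    scale_tril=jnp.linalg.cholesky(posterior_cov))

# Draw posterior samples of the regression weights.
weight_samples = posterior.sample(5, seed=jax.random.PRNGKey(0))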

Mathematical details

Let the prior precision be denoted by prior_prec = prior_scale.matmul(prior_scale, adjoint_arg=True).inverse() and the likelihood precision by likelihood_prec = likelihood_scale.matmul(likelihood_scale, adjoint_arg=True).inverse(). Then the posterior p(latent | observation) is multivariate normal with precision

posterior_prec = (
  linear_transformation.matmul(
    likelihood_prec.matmul(linear_transformation), adjoint=True) +
   prior_prec)

and mean

posterior_mean = posterior_prec.solvevec(
  linear_transformation.matvec(
    likelihood_prec.matvec(observation), adjoint=True) +
  prior_prec.matvec(prior_mean))
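
As a hedged numerical check of these expressions, the posterior can be recomputed with dense arrays, reusing the placeholder variables from the first sketch (where prior_scale is the identity, likelihood_scale is 0.1 times the identity, and prior_mean is None, i.e. zero, so the prior-mean term drops out).

import jax.numpy as jnp

X = design_matrix                                      # [num_outputs, num_features]
prior_prec_dense = jnp.eye(num_features)               # prior_scale = I   =>  precision = I
likelihood_prec_dense = jnp.eye(num_outputs) / 0.1**2  # likelihood_scale = 0.1 * I

expected_prec = X.T @ likelihood_prec_dense @ X + prior_prec_dense
expected_mean = jnp.linalg.solve(expected_prec,
                                 X.T @ likelihood_prec_dense @ observation)

assert jnp.allclose(posterior_prec.to_dense(), expected_prec, rtol=1e-4)
assert jnp.allclose(posterior_mean, expected_mean, rtol=1e-4, atol=1e-4)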