oryx.distributions.mvn_conjugate_linear_update

Computes a conjugate normal posterior for a Bayesian linear regression.

We assume the following model:

latent ~ MVN(loc=prior_mean, scale=prior_scale)
observation ~ MVN(loc=linear_transformation.matvec(latent),
                  scale=likelihood_scale)

For Bayesian linear regression, the latent represents the weights, and the provided linear_transformation is the design matrix.

This method computes the multivariate normal posterior p(latent | observation), using LinearOperators to perform computations efficiently when the matrices involved have special structure.
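The generative model above can be sketched in plain NumPy. This is only an illustration with dense matrices: the variable names, sizes, and the choice of an identity prior scale and isotropic noise are assumptions made for the example, not part of the library.

```python
import numpy as np

rng = np.random.default_rng(0)
num_features, num_outputs = 3, 50  # illustrative sizes (assumed)

# The design matrix plays the role of `linear_transformation`.
design_matrix = rng.normal(size=(num_outputs, num_features))

# Prior on the regression weights: mean zero, identity scale (assumed).
prior_mean = np.zeros(num_features)
prior_scale = np.eye(num_features)

# Observation noise scale: isotropic with standard deviation 0.1 (assumed).
likelihood_scale = 0.1 * np.eye(num_outputs)

# latent ~ MVN(loc=prior_mean, scale=prior_scale)
latent = prior_mean + prior_scale @ rng.normal(size=num_features)

# observation ~ MVN(loc=design_matrix @ latent, scale=likelihood_scale)
observation = (design_matrix @ latent
               + likelihood_scale @ rng.normal(size=num_outputs))
```

Here `latent` holds the sampled weights and `observation` the noisy regression targets that would be passed to the update.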

prior_scale Instance of tf.linalg.LinearOperator of shape [..., num_features, num_features], specifying a scale matrix (any matrix L such that LL' = Q where Q is the covariance) for the prior on regression weights. May optionally be a float Tensor.
linear_transformation Instance of tf.linalg.LinearOperator of shape [..., num_outputs, num_features], specifying a transformation of the latent values. May optionally be a float Tensor.
likelihood_scale Instance of tf.linalg.LinearOperator of shape [..., num_outputs, num_outputs] specifying a scale matrix (any matrix L such that LL' = Q where Q is the covariance) for the likelihood of observed targets. May optionally be a float Tensor.
observation Float Tensor of shape [..., num_outputs], specifying the observed values or regression targets.
prior_mean Optional float Tensor of shape [..., num_features], specifying the prior mean. If None, the prior mean is assumed to be zero and some computation is avoided. Default value: None.
name Optional Python str name given to ops created by this function. Default value: 'mvn_conjugate_linear_update'.

posterior_mean Float Tensor of shape [..., num_features], giving the mean of the multivariate normal posterior on the latent value.
posterior_prec Instance of tf.linalg.LinearOperator of shape [..., num_features, num_features], giving the posterior precision (inverse covariance) matrix.

Mathematical details

Let the prior precision be denoted by prior_prec = prior_scale.matmul(prior_scale, adjoint_arg=True).inverse() and the likelihood precision by likelihood_prec = likelihood_scale.matmul(likelihood_scale, adjoint_arg=True).inverse(). Then the posterior p(latent | observation) is multivariate normal with precision

posterior_prec = (
  linear_transformation.matmul(
    likelihood_prec.matmul(linear_transformation), adjoint=True) +
   prior_prec)

and mean

posterior_mean = posterior_prec.solvevec(
  linear_transformation.matvec(
    likelihood_prec.matvec(observation), adjoint=True) +
  prior_prec.matvec(prior_mean))
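
As a sanity check on the posterior precision and mean, the following is a minimal dense NumPy sketch, not the library implementation. It assumes real, non-batched matrices (so the adjoint is a plain transpose) and computes the mean by solving posterior_prec x = A' likelihood_prec observation + prior_prec prior_mean, where A is the linear transformation. The helper name `dense_conjugate_update` and all test values are hypothetical. The result is cross-checked against the standard Gaussian conditioning formula.

```python
import numpy as np


def dense_conjugate_update(prior_scale, linear_transformation,
                           likelihood_scale, observation, prior_mean=None):
    """Hypothetical dense reference for the update (real, non-batched)."""
    a = linear_transformation
    # Precision is the inverse covariance, with covariance = scale @ scale.T.
    prior_prec = np.linalg.inv(prior_scale @ prior_scale.T)
    likelihood_prec = np.linalg.inv(likelihood_scale @ likelihood_scale.T)
    # posterior_prec = A' likelihood_prec A + prior_prec
    posterior_prec = a.T @ likelihood_prec @ a + prior_prec
    # Right-hand side: A' likelihood_prec y (+ prior_prec prior_mean).
    rhs = a.T @ (likelihood_prec @ observation)
    if prior_mean is not None:
        rhs = rhs + prior_prec @ prior_mean
    posterior_mean = np.linalg.solve(posterior_prec, rhs)
    return posterior_mean, posterior_prec


rng = np.random.default_rng(0)
num_features, num_outputs = 3, 5  # illustrative sizes (assumed)
prior_scale = np.array([[1.0, 0.0, 0.0],
                        [0.5, 2.0, 0.0],
                        [0.1, -0.3, 1.5]])
likelihood_scale = 0.5 * np.eye(num_outputs)
design_matrix = rng.normal(size=(num_outputs, num_features))
observation = rng.normal(size=num_outputs)
prior_mean = np.array([0.2, -1.0, 0.5])

posterior_mean, posterior_prec = dense_conjugate_update(
    prior_scale, design_matrix, likelihood_scale, observation, prior_mean)

# Cross-check with the covariance-form conditioning formula:
# mean = mu + S A' (A S A' + R)^{-1} (y - A mu)
prior_cov = prior_scale @ prior_scale.T
likelihood_cov = likelihood_scale @ likelihood_scale.T
marginal_cov = design_matrix @ prior_cov @ design_matrix.T + likelihood_cov
gain = prior_cov @ design_matrix.T @ np.linalg.inv(marginal_cov)
mean_check = prior_mean + gain @ (observation - design_matrix @ prior_mean)
```

The two routes agree up to floating-point error. The dense inverses are used only for clarity; the point of the LinearOperator-based function is to avoid them when the matrices are structured.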