
tfp.stats.covariance


Sample covariance between observations indexed by event_axis.

tfp.stats.covariance(
    x,
    y=None,
    sample_axis=0,
    event_axis=-1,
    keepdims=False,
    name=None
)

Given N samples of scalar random variables X and Y, covariance may be estimated as

Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
Xbar := N^{-1} sum_{n=1}^N X_n
Ybar := N^{-1} sum_{n=1}^N Y_n
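As a cross-check of the scalar estimator above, here is a minimal NumPy sketch. It is not how tfp.stats.covariance is implemented; the helper name scalar_covariance is hypothetical. It follows the three definitions term by term, including the complex conjugate on the Y deviations and the division by N.

```python
import numpy as np


def scalar_covariance(x, y):
    """Divide-by-N sample covariance of two equal-length 1-D sequences."""
    x = np.asarray(x)
    y = np.asarray(y)
    n = x.shape[0]
    xbar = x.sum() / n  # Xbar := N^{-1} sum_n X_n
    ybar = y.sum() / n  # Ybar := N^{-1} sum_n Y_n
    # Cov[X, Y] := N^{-1} sum_n (X_n - Xbar) * Conj(Y_n - Ybar)
    return np.sum((x - xbar) * np.conj(y - ybar)) / n


print(scalar_covariance([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]))  # 2.5
```

For real inputs the conjugate has no effect; for complex inputs it makes Cov[X, X] real and non-negative.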

For vector-variate random variables X = (X1, ..., Xd), Y = (Y1, ..., Yd), one is often interested in the covariance matrix, C_{ij} := Cov[Xi, Yj].
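The matrix C_{ij} can be written as a single matrix product of the centered samples. The following NumPy sketch assumes samples are stored as rows of (N, d) arrays; the name covariance_matrix is hypothetical and this is an illustration of the definition, not tfp's code.

```python
import numpy as np


def covariance_matrix(x, y):
    """Divide-by-N cross-covariance matrix C[i, j] = Cov[X_i, Y_j].

    x and y are arrays of shape (N, d): one row per sample, one column per
    component of the vector-variate random variable.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = x.shape[0]
    xc = x - x.mean(axis=0)  # center each component X_i
    yc = y - y.mean(axis=0)  # center each component Y_j
    # C[i, j] = N^{-1} sum_n xc[n, i] * conj(yc[n, j])
    return xc.T @ np.conj(yc) / n
```

Passing the same array for x and y gives the ordinary covariance matrix, whose diagonal holds the (divide-by-N) variances.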

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.random.normal(shape=(100, 2, 3))
y = tf.random.normal(shape=(100, 2, 3))

# cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
# cov.shape == (2, 3)
cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

# cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n].
# cov_matrix.shape == (2, 3, 3)
cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)

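Note that this divides by N rather than N - 1. This matches NumPy's default for numpy.var (ddof=0), although numpy.cov divides by N - 1 unless bias=True. Dividing by N does not produce NaN when N = 1, but the estimate is slightly biased: on average it is low by a factor of (N - 1) / N.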
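The trade-off can be seen with variance, the special case Cov[X, X]. This NumPy sketch contrasts the two normalizations and shows the N = 1 behavior; rescale_to_unbiased is a hypothetical helper that assumes N > 1 and that you want the divide-by-(N - 1) estimator instead of the divide-by-N one.

```python
import warnings

import numpy as np


def rescale_to_unbiased(biased_cov, n):
    """Convert a divide-by-N estimate into the divide-by-(N - 1) estimate."""
    return biased_cov * n / (n - 1)


x = np.array([1.0, 2.0, 3.0, 4.0])
biased = np.var(x)            # ddof=0 divides by N: 5.0 / 4 = 1.25
unbiased = np.var(x, ddof=1)  # divides by N - 1: 5.0 / 3

# With a single sample, dividing by N gives 0.0, but dividing by N - 1 = 0
# gives NaN (NumPy also emits a RuntimeWarning, silenced here).
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    single_biased = np.var(np.array([5.0]))
    single_unbiased = np.var(np.array([5.0]), ddof=1)
```

The same N / (N - 1) factor applies elementwise to a covariance matrix computed with the divide-by-N convention.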

Args:

  • x: A numeric Tensor holding samples.
  • y: Optional Tensor with same dtype and shape as x. Default value: None (y is effectively set to x).
  • sample_axis: Scalar or vector Tensor designating the axes holding samples, or None (meaning all axes hold samples). Default value: 0 (leftmost dimension).
  • event_axis: Scalar or vector Tensor, or None (scalar events). Axes indexing the random events whose covariance is of interest. If a vector, its entries must form a contiguous block of dims. sample_axis and event_axis must not intersect. Default value: -1 (rightmost axis holds events).
  • keepdims: Boolean. Whether to keep the sample axes as singleton dimensions. Default value: False.
  • name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'covariance').
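To make the roles of sample_axis, event_axis, and the y=None default concrete, here is a NumPy sketch that mirrors the two calls in the example above for inputs of shape (N, B, d). The function batched_covariance and its scalar_events flag are hypothetical stand-ins for event_axis=None versus event_axis=-1 with sample_axis=0; they are not part of tfp and handle only this one layout.

```python
import numpy as np


def batched_covariance(x, y=None, scalar_events=False):
    """NumPy sketch of covariance(x, y, sample_axis=0, event_axis=-1 or None).

    x (and y) have shape (N, B, d): axis 0 holds samples, axis -1 holds events.
    If y is None it defaults to x, as documented for tfp.stats.covariance.
    """
    if y is None:
        y = x
    n = x.shape[0]
    xc = x - x.mean(axis=0, keepdims=True)
    yc = y - y.mean(axis=0, keepdims=True)
    if scalar_events:
        # Like event_axis=None: elementwise covariance, output shape (B, d).
        return np.mean(xc * np.conj(yc), axis=0)
    # Like event_axis=-1: out[b, m, k] = N^{-1} sum_n xc[n, b, m] * conj(yc[n, b, k]),
    # output shape (B, d, d).
    return np.einsum("nbm,nbk->bmk", xc, np.conj(yc)) / n
```

With y omitted, the diagonal of each (d, d) block equals the divide-by-N variance of the corresponding event component.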

Returns:

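  • cov: A Tensor of the same dtype as x. With keepdims=False its rank is rank(x) - len(sample_axis) + len(event_axis): the remaining batch dims, followed by the event dims twice (once for x, once for y). For example, x of shape (100, 2, 3) with sample_axis=0, event_axis=-1 yields shape (2, 3, 3).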
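The output shape can be predicted from the input shape alone. This plain-Python sketch assumes the layout shown in the example above (batch dims in their original order, then the event shape twice) and keepdims=False; expected_covariance_shape is a hypothetical helper, and it does not validate overlap or contiguity.

```python
def expected_covariance_shape(x_shape, sample_axis, event_axis):
    """Hypothetical helper: output shape of covariance(...) with keepdims=False.

    sample_axis is an int or list of ints; event_axis is an int, a list of
    ints, or None (scalar events). Overlap and contiguity are not checked.
    """
    rank = len(x_shape)

    def normalize(axes):
        if axes is None:
            return []
        if isinstance(axes, int):
            axes = [axes]
        return sorted(a % rank for a in axes)  # map negative axes to positive

    samples = normalize(sample_axis)
    events = normalize(event_axis)
    batch_shape = [x_shape[a] for a in range(rank)
                   if a not in samples and a not in events]
    event_shape = [x_shape[a] for a in events]
    # Batch dims, then the event shape once for x and once for y.
    return tuple(batch_shape + event_shape + event_shape)


print(expected_covariance_shape((100, 2, 3), 0, -1))    # (2, 3, 3)
print(expected_covariance_shape((100, 2, 3), 0, None))  # (2, 3)
```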

Raises:

  • AssertionError: If x and y are found to have different shapes.
  • ValueError: If sample_axis and event_axis are found to overlap.
  • ValueError: If event_axis is found to not be contiguous.
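The two ValueError conditions can be expressed as simple checks on normalized axis lists. The sketch below is hypothetical (validate_axes is not a tfp function, and tfp's own checks may differ in detail); it treats sample_axis=None as "all axes", as the Args section describes, and omits the x/y shape assertion.

```python
def validate_axes(rank, sample_axis, event_axis):
    """Hypothetical checks mirroring the documented ValueError conditions.

    Axes are ints or lists of ints. sample_axis=None means every axis holds
    samples; event_axis=None means scalar events (no event axes).
    """
    def normalize(axes):
        if isinstance(axes, int):
            axes = [axes]
        return sorted({a % rank for a in axes})  # dedupe, map negatives

    samples = list(range(rank)) if sample_axis is None else normalize(sample_axis)
    events = [] if event_axis is None else normalize(event_axis)
    if set(samples) & set(events):
        raise ValueError("sample_axis and event_axis overlap")
    # Contiguous means the sorted event axes form an unbroken run like [2, 3, 4].
    if events and events[-1] - events[0] + 1 != len(events):
        raise ValueError("event_axis must be contiguous")
```

For example, validate_axes(4, 0, [1, 3]) raises because axes 1 and 3 leave a gap at axis 2.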