tfp.substrates.jax.sts.moments_of_masked_time_series

Computes the mean and variance of a time series along its last (time) axis, excluding masked entries.
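As a worked statement of what "excluding masked entries" means, the formulas below write the per-series moments for one series x_1, ..., x_T. Two points are assumptions and are not stated above: a mask value m_t = 1 (True) marks an entry to exclude, and the variance divides by the number of unmasked entries N rather than N - 1.

```latex
N = \sum_{t=1}^{T} (1 - m_t), \qquad
\mu = \frac{1}{N} \sum_{t=1}^{T} (1 - m_t)\, x_t, \qquad
\sigma^2 = \frac{1}{N} \sum_{t=1}^{T} (1 - m_t)\, (x_t - \mu)^2
```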

Args:
  time_series_tensor: float Tensor time series of shape concat([batch_shape, [num_timesteps]]).
  broadcast_mask: bool Tensor of the same shape as time_series_tensor.
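The mask must have exactly the series' shape. If a mask is defined only over timesteps and shared by every series in the batch, one way to get it into that shape is plain broadcasting. This is a minimal sketch with made-up data; the batch of two series and the convention that True marks an excluded timestep are assumptions for illustration.

```python
import jax.numpy as jnp

# Hypothetical batch: batch_shape = [2], num_timesteps = 4.
time_series = jnp.array([[1.0, 2.0, 3.0, 4.0],
                         [5.0, 6.0, 7.0, 8.0]])

# One mask over timesteps, shared by every series (True = excluded, assumed).
timestep_mask = jnp.array([False, True, False, False])

# Broadcast to concat([batch_shape, [num_timesteps]]) so it matches the series.
broadcast_mask = jnp.broadcast_to(timestep_mask, time_series.shape)
```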

Returns:
  mean: float Tensor of shape batch_shape.
  variance: float Tensor of shape batch_shape.
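A plain-JAX sketch of the same computation, useful for seeing how the shapes above fit together. It is not the library's source. The function name `masked_moments` is hypothetical, and it assumes that True in the mask marks an entry to exclude and that the variance divides by the count of unmasked entries.

```python
import jax.numpy as jnp


def masked_moments(time_series_tensor, broadcast_mask):
    """Hypothetical helper: mean and variance over the last axis, skipping masked entries.

    Assumes True in broadcast_mask means "exclude this entry".
    """
    keep = ~broadcast_mask
    dtype = time_series_tensor.dtype
    # Number of unmasked entries per series; shape batch_shape.
    count = jnp.sum(keep, axis=-1).astype(dtype)
    # Masked entries are replaced by zero so they contribute nothing to the sums.
    zeros = jnp.zeros_like(time_series_tensor)
    mean = jnp.sum(jnp.where(keep, time_series_tensor, zeros), axis=-1) / count
    # Squared deviations from each series' own mean, broadcast over time.
    sq_dev = (time_series_tensor - mean[..., None]) ** 2
    variance = jnp.sum(jnp.where(keep, sq_dev, zeros), axis=-1) / count
    return mean, variance


# Two series of four timesteps; the outlier 100.0 in the first series is masked.
series = jnp.array([[1.0, 2.0, 3.0, 100.0],
                    [2.0, 4.0, 6.0, 8.0]])
mask = jnp.array([[False, False, False, True],
                  [False, False, False, False]])
mean, variance = masked_moments(series, mask)
# mean -> [2.0, 5.0], variance -> [2/3, 5.0]
```

Because masked entries are zeroed rather than dropped, the result keeps a static shape, which suits JAX tracing. A series whose entries are all masked has a count of zero and yields NaN under this sketch.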