tfp.stats.log_average_probs

Computes log(average(to_probs(logits))) in a numerically stable manner.

tfp.stats.log_average_probs(
    logits,
    sample_axis=0,
    event_axis=None,
    keepdims=False,
    validate_args=False,
    name=None
)

The meaning of to_probs is controlled by the event_axis argument. When event_axis is None, to_probs = tf.math.sigmoid; otherwise, to_probs = lambda x: tf.math.softmax(x, axis=event_axis).

sample_axis and event_axis must be disjoint: no axis may appear in both. This requirement is always verified when validate_args is True.
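To make "numerically stable" concrete for the Bernoulli case (event_axis=None), here is a minimal pure-Python sketch of the same quantity, log(mean(sigmoid(logits))), computed entirely in log space. The helper names log_sigmoid, log_mean_exp, and log_average_probs_bernoulli are hypothetical and exist only for this illustration; this is not the library's implementation, just one common way to avoid underflow.

```python
import math

def log_sigmoid(x):
    """Stable log(sigmoid(x)); branches so math.exp never overflows."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def log_mean_exp(log_values):
    """Stable log(mean(exp(v))) using the log-sum-exp trick."""
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))

def log_average_probs_bernoulli(logits):
    """Hypothetical reference helper: log(mean(sigmoid(logits)))."""
    return log_mean_exp([log_sigmoid(x) for x in logits])

# For moderate logits the log-space route agrees with the direct formula.
logits = [-2.0, 0.0, 3.0]
stable = log_average_probs_bernoulli(logits)
naive = math.log(sum(1.0 / (1.0 + math.exp(-x)) for x in logits) / len(logits))

# For very negative logits the probabilities are far below the smallest
# representable float, yet the log-space result stays finite.
extreme = log_average_probs_bernoulli([-800.0, -801.0])
```

Working with log-probabilities throughout means the average is only formed after subtracting the maximum, so no intermediate value has to represent a probability as small as exp(-800).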

Args:

  • logits: A float Tensor representing logits.
  • sample_axis: Scalar or vector Tensor designating the axis or axes holding samples, or None (meaning all axes hold samples). Default value: 0 (leftmost dimension).
  • event_axis: Scalar or vector Tensor designating the axis representing categorical logits. Default value: None (i.e., Bernoulli logits).
  • keepdims: Boolean. Whether to keep the reduced sample axes as singleton (size-1) dimensions. Default value: False (i.e., squeeze the reduced dimensions).
  • validate_args: Python bool. When True, distribution parameters are checked for validity, possibly at the cost of runtime performance. When False, invalid inputs may silently produce incorrect outputs. Default value: False (i.e., do not validate args).
  • name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'log_average_probs').

Returns:

  • log_avg_probs: The natural log of the average of probs computed from logits.
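The categorical case can be sketched the same way. The pure-Python example below assumes logits laid out as [num_samples][num_classes], which corresponds to sample_axis=0 and event_axis=1, and it assumes that only the sample axis is reduced, so the result has one entry per class (the keepdims=False shape). The helper names are hypothetical and the code is an illustration of the documented quantity, not the library's implementation.

```python
import math

def log_softmax(row):
    """Stable log(softmax(row)) for one vector of class logits."""
    m = max(row)
    log_norm = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_norm for v in row]

def log_mean_exp(log_values):
    """Stable log(mean(exp(v))) using the log-sum-exp trick."""
    m = max(log_values)
    total = sum(math.exp(v - m) for v in log_values)
    return m + math.log(total) - math.log(len(log_values))

def log_average_probs_categorical(logits):
    """Hypothetical helper for logits shaped [num_samples][num_classes]."""
    # to_probs step, kept in log space: one log-softmax per sample.
    log_probs = [log_softmax(row) for row in logits]
    num_classes = len(logits[0])
    # Average over the sample axis separately for each class.
    return [log_mean_exp([lp[k] for lp in log_probs])
            for k in range(num_classes)]

logits = [[0.0, 0.0, 0.0],
          [1.0, 2.0, 3.0]]
result = log_average_probs_categorical(logits)

# Exponentiating recovers the averaged class probabilities, which still
# sum to one because each per-sample softmax does.
avg_probs = [math.exp(v) for v in result]
```

Under these assumptions the output is a list of three log-probabilities, one per class; keeping the sample axis (keepdims=True) would instead correspond to wrapping that list in an outer list of length one.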