Computes log(average(to_probs(logits)))
in a numerically stable manner.
tfp.substrates.numpy.stats.log_average_probs(
    logits,
    sample_axis=0,
    event_axis=None,
    keepdims=False,
    validate_args=False,
    name=None
)
The meaning of to_probs is controlled by the event_axis argument. When event_axis is None, to_probs = tf.math.sigmoid; otherwise, to_probs = lambda x: tf.math.softmax(x, axis=event_axis).
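The identities below are one standard way to evaluate log(average(to_probs(logits))) entirely in log space, so that no probability is formed explicitly where it could underflow or overflow. This is a sketch of the underlying math under stated assumptions, not a statement of the exact operations TFP performs. Here x_i denotes the logits of sample i along sample_axis, n is the number of samples, σ is the sigmoid, and K is the number of categories along event_axis.

```latex
% Bernoulli case (event_axis is None):
\log\Big(\frac{1}{n}\sum_{i=1}^{n}\sigma(x_i)\Big)
  = \operatorname{logsumexp}_{i}\big(\log\sigma(x_i)\big) - \log n,
\qquad \log\sigma(x) = -\operatorname{softplus}(-x) = -\log\big(1 + e^{-x}\big)

% Categorical case, category k in 1..K:
\log\Big(\frac{1}{n}\sum_{i=1}^{n}\operatorname{softmax}(x_i)_k\Big)
  = \operatorname{logsumexp}_{i}\Big(x_{ik} - \operatorname{logsumexp}_{j=1}^{K} x_{ij}\Big) - \log n

% Max-shifted log-sum-exp, with m = \max_i a_i:
\operatorname{logsumexp}_{i}(a_i) = m + \log\sum_{i} e^{a_i - m}
```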
sample_axis and event_axis should have an empty intersection, that is, no axis may be both a sample axis and an event axis. This requirement is always verified when validate_args is True.
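To make the "empty intersection" requirement concrete, here is a minimal pure-Python sketch of such a check. The helper names _as_axis_set and axes_are_disjoint are hypothetical and not part of TFP. Normalizing negative axes NumPy-style, and treating sample_axis=None as "every axis" and event_axis=None as "no event axis", are assumptions drawn from the argument descriptions on this page.

```python
def _as_axis_set(axis, ndims):
    """Hypothetical helper: turn an int or list of ints into a set of
    non-negative axis indices (negative axes count from the end)."""
    if isinstance(axis, int):
        axis = [axis]
    return {a % ndims for a in axis}


def axes_are_disjoint(sample_axis, event_axis, ndims):
    """Hypothetical check that no axis is both a sample axis and an event axis."""
    if event_axis is None:
        # Bernoulli logits: there is no event axis, so nothing can overlap.
        return True
    if sample_axis is None:
        # None means every axis holds samples.
        samples = set(range(ndims))
    else:
        samples = _as_axis_set(sample_axis, ndims)
    return samples.isdisjoint(_as_axis_set(event_axis, ndims))


# Logits of shape [num_samples, num_classes]:
ok = axes_are_disjoint(sample_axis=0, event_axis=-1, ndims=2)   # disjoint
bad = axes_are_disjoint(sample_axis=1, event_axis=-1, ndims=2)  # both are axis 1
```

Normalizing first matters because sample_axis=1 and event_axis=-1 name the same axis of a 2-D input even though the integers differ.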
Args:
  logits: A float Tensor representing logits.
  sample_axis: Scalar or vector Tensor designating the axis (or axes) holding samples, or None (meaning all axes hold samples).
    Default value: 0 (leftmost dimension).
  event_axis: Scalar or vector Tensor designating the axis representing categorical logits.
    Default value: None (i.e., Bernoulli logits).
  keepdims: Boolean. Whether to keep the sample axes as singleton dimensions.
    Default value: False (i.e., squeeze the reduced dimensions).
  validate_args: Python bool. When True, arguments are checked for validity, possibly at the cost of runtime performance. When False, invalid inputs may silently produce incorrect outputs.
    Default value: False (i.e., do not validate args).
  name: Python str name prefixed to Ops created by this function.
    Default value: None (i.e., 'log_average_probs').
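The interaction of sample_axis and keepdims can be illustrated with a small shape calculator. This is a sketch under the assumption that the result follows ordinary reduction semantics: the sample axes are reduced away (or kept as size 1 when keepdims is True), and all other axes, including any event axis, are kept. The helper reduced_shape is hypothetical and not a TFP function.

```python
def reduced_shape(shape, sample_axis=0, keepdims=False):
    """Hypothetical helper: expected result shape for an input of `shape`.

    sample_axis may be an int, a list of ints, or None (reduce every axis).
    Negative axes count from the end, as in NumPy.
    """
    ndims = len(shape)
    if sample_axis is None:
        axes = set(range(ndims))
    else:
        if isinstance(sample_axis, int):
            sample_axis = [sample_axis]
        axes = {a % ndims for a in sample_axis}
    if keepdims:
        # Reduced axes survive as singleton dimensions.
        return tuple(1 if i in axes else d for i, d in enumerate(shape))
    # Reduced axes are squeezed out.
    return tuple(d for i, d in enumerate(shape) if i not in axes)


# 100 samples of logits for a batch of 8 examples over 10 classes:
squeezed = reduced_shape((100, 8, 10))              # (8, 10)
kept = reduced_shape((100, 8, 10), keepdims=True)   # (1, 8, 10)
```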
Returns:
  log_avg_probs: The natural log of the average of probs computed from logits.
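The following is a minimal reference sketch in plain Python (no TFP or NumPy) of what the returned value represents, restricted to a single sample axis at position 0. The function names log_average_probs_bernoulli and log_average_probs_categorical are hypothetical, and the log-space approach is one standard stable technique, not necessarily TFP's exact implementation.

```python
import math


def _logsumexp(xs):
    """Stable log(sum(exp(xs))) for a non-empty list of floats."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    # Shift by the max so every exponent is <= 0 and cannot overflow.
    return m + math.log(sum(math.exp(x - m) for x in xs))


def _log_sigmoid(x):
    """Stable log(sigmoid(x)), branching so exp() only sees non-positive inputs."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def log_average_probs_bernoulli(logits):
    """log(mean(sigmoid(logits))) over a 1-D list of per-sample logits."""
    log_probs = [_log_sigmoid(x) for x in logits]
    return _logsumexp(log_probs) - math.log(len(log_probs))


def log_average_probs_categorical(logits):
    """log(mean(softmax(row))) over samples.

    `logits` is a list of rows shaped [num_samples][num_classes]
    (sample axis 0, event axis 1). Returns one log-probability per class.
    """
    log_probs = []
    for row in logits:
        lse = _logsumexp(row)
        log_probs.append([x - lse for x in row])  # log_softmax of this row
    n = len(log_probs)
    return [
        _logsumexp([lp[k] for lp in log_probs]) - math.log(n)
        for k in range(len(logits[0]))
    ]


# Extreme logits: the naive 1 / (1 + math.exp(-x)) raises OverflowError for
# x = -800, but the log-space version stays finite.
extreme = log_average_probs_bernoulli([-800.0, -900.0])  # approximately -800 - log(2)
```

Because each category's result is the log of an average of probability vectors, exponentiating the categorical output yields values that sum to 1, which is a quick sanity check on any implementation.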