tfp.experimental.vi.build_trainable_location_scale_distribution


Builds a variational distribution from a location-scale family.

tfp.experimental.vi.build_trainable_location_scale_distribution(
    initial_loc,
    initial_scale,
    event_ndims,
    distribution_fn=tfp.distributions.Normal,
    validate_args=False,
    name=None
)

Args:

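  • initial_loc: Float Tensor giving the initial value of the location parameter.
  • initial_scale: Float Tensor giving the initial value of the scale parameter; the scale of a location-scale family must be positive.
  • event_ndims: Integer Tensor number of event dimensions in initial_loc. The rightmost event_ndims dimensions are treated as a single event (log_prob sums over them), and any remaining leading dimensions are batch dimensions.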
  • distribution_fn: Optional constructor for a tfd.Distribution instance in a location-scale family. This should have signature dist = distribution_fn(loc, scale, validate_args). Default value: tfd.Normal.
  • validate_args: Python bool. Whether to validate inputs with runtime assertions, which adds overhead. If validate_args is False and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
  • name: Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_trainable_location_scale_distribution').
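To make the roles of initial_loc, initial_scale, event_ndims and distribution_fn concrete, here is a conceptual NumPy sketch of a location-scale distribution. It is not TFP's implementation, and the names standard_normal_log_prob, standard_laplace_log_prob and LocScale are hypothetical. It relies on two facts about location-scale families: a sample is loc + scale * z with z drawn from a fixed standard base distribution, and the log-density is log p_base((x - loc) / scale) - log(scale). Summing that log-density over the rightmost event_ndims axes mirrors what tfd.Independent does with reinterpreted_batch_ndims.

```python
import numpy as np


def standard_normal_log_prob(z):
    # Elementwise log-density of N(0, 1).
    return -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)


def standard_laplace_log_prob(z):
    # Elementwise log-density of Laplace(0, 1).
    return -np.abs(z) - np.log(2.0)


class LocScale:
    """Hypothetical stand-in for a location-scale distribution (not a TFP class)."""

    def __init__(self, loc, scale, event_ndims, base_log_prob, base_sampler):
        # Broadcast loc and scale against each other so both share one shape.
        self.loc, self.scale = np.broadcast_arrays(
            np.asarray(loc, dtype=float), np.asarray(scale, dtype=float))
        if np.any(self.scale <= 0):
            raise ValueError("scale must be positive")
        if not 0 <= event_ndims <= self.loc.ndim:
            raise ValueError("event_ndims must be in [0, loc.ndim]")
        self.event_ndims = event_ndims
        self.base_log_prob = base_log_prob
        self.base_sampler = base_sampler

    def sample(self, n, rng):
        # Reparameterized sampling: x = loc + scale * z, with z from the base.
        z = self.base_sampler(rng, (n,) + self.loc.shape)
        return self.loc + self.scale * z

    def log_prob(self, x):
        # Change of variables: log p(x) = log p_base((x - loc) / scale) - log(scale).
        z = (x - self.loc) / self.scale
        elementwise = np.asarray(self.base_log_prob(z) - np.log(self.scale))
        # Sum over the rightmost event_ndims axes; the others stay as batch axes.
        axes = tuple(range(elementwise.ndim - self.event_ndims, elementwise.ndim))
        return elementwise.sum(axis=axes)


rng = np.random.default_rng(0)
# A [3, 2] location with a scalar scale that broadcasts against it;
# event_ndims=1 makes the last axis one event, so batch shape is [3].
q = LocScale(loc=np.zeros((3, 2)), scale=0.5, event_ndims=1,
             base_log_prob=standard_normal_log_prob,
             base_sampler=lambda rng, shape: rng.standard_normal(shape))
x = q.sample(4, rng)   # shape (4, 3, 2)
lp = q.log_prob(x)     # shape (4, 3)

# Swapping the base density gives another member of the family
# (compare passing distribution_fn=tfd.Laplace).
q_laplace = LocScale(loc=1.0, scale=2.0, event_ndims=0,
                     base_log_prob=standard_laplace_log_prob,
                     base_sampler=lambda rng, shape: rng.laplace(size=shape))
```

Only the base density changes between the Normal and Laplace cases; the shift-and-scale logic is shared, which is what lets a single function accept any location-scale distribution_fn.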

Returns:

  • posterior_dist: A tfd.Distribution instance.
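The returned posterior_dist is meant to act as a surrogate posterior q in variational inference, where the usual objective is the evidence lower bound, ELBO = E_q[log p(z) - log q(z)]. The NumPy sketch below estimates that expectation by Monte Carlo for a Normal surrogate using reparameterized samples loc + scale * eps. The target (a standard normal standing in for log p) and the function names normal_log_prob, target_log_prob and elbo_estimate are illustrative assumptions, not TFP APIs.

```python
import numpy as np


def normal_log_prob(x, loc, scale):
    # Elementwise log-density of N(loc, scale**2).
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def target_log_prob(z):
    # Hypothetical target density: a standard normal stands in for log p(z).
    return normal_log_prob(z, 0.0, 1.0)


def elbo_estimate(loc, scale, num_samples, rng):
    # Reparameterized draws from the surrogate q = N(loc, scale**2).
    eps = rng.standard_normal(num_samples)
    z = loc + scale * eps
    # Monte Carlo average of log p(z) - log q(z).
    return np.mean(target_log_prob(z) - normal_log_prob(z, loc, scale))


rng = np.random.default_rng(0)
# q matches the target exactly, so every term log p - log q is zero.
exact = elbo_estimate(0.0, 1.0, 1_000, rng)
# Should be close to -KL(N(1, 1) || N(0, 1)) = -0.5.
shifted = elbo_estimate(1.0, 1.0, 100_000, rng)
```

Writing z as loc + scale * eps keeps the estimate differentiable in loc and scale, which is what allows a gradient-based optimizer to train the surrogate. In TFP, that training loop would usually be handled by a utility such as tfp.vi.fit_surrogate_posterior rather than written by hand.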