tfp.experimental.vi.register_asvi_substitution_rule

Registers a rule for substituting distributions in ASVI surrogates.

Args

condition: Python callable that takes a Distribution instance and returns a Python bool indicating whether or not to substitute it. May also be a class type such as tfd.Normal, in which case the condition is interpreted as lambda distribution: isinstance(distribution, class).
substitution_fn: Python callable that takes a Distribution instance and returns a new Distribution instance used to define the ASVI surrogate posterior. Note that this substitution does not modify the original model.
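The rule that a class-type `condition` is shorthand for an `isinstance` predicate can be illustrated with a small, framework-free sketch. The helper `as_predicate` and the toy classes `Distribution`, `Normal`, and `HalfCauchy` are hypothetical stand-ins written for illustration only; they are not TFP code, and TFP's internal handling may differ.

```python
from typing import Callable, Union


# Toy stand-ins for distribution classes (hypothetical; not TFP classes).
class Distribution:
    pass


class Normal(Distribution):
    pass


class HalfCauchy(Distribution):
    pass


Condition = Union[type, Callable[[object], bool]]


def as_predicate(condition: Condition) -> Callable[[object], bool]:
    """Normalize a `condition` argument into a predicate (hypothetical helper).

    A class type becomes `lambda distribution: isinstance(distribution, cls)`;
    any other callable is returned unchanged.
    """
    # Check for a class first: classes are themselves callable.
    if isinstance(condition, type):
        cls = condition
        return lambda distribution: isinstance(distribution, cls)
    if callable(condition):
        return condition
    raise TypeError(f"condition must be a class or a callable, got {condition!r}")


is_normal = as_predicate(Normal)
print(is_normal(Normal()), is_normal(HalfCauchy()))  # True False
```

Checking for `type` before `callable` matters because a class is itself callable; testing `callable` first would return the class unchanged and treat it as a predicate.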

Example

To use a Normal surrogate for all location-scale family distributions, we could register the substitution:

import tensorflow_probability as tfp
tfd = tfp.distributions

tfp.experimental.vi.register_asvi_substitution_rule(
  condition=lambda distribution: (
    hasattr(distribution, 'loc') and hasattr(distribution, 'scale')),
  substitution_fn=lambda distribution: (
    # Invoking the event space bijector applies any relevant constraints,
    # e.g., that HalfCauchy samples must be `>= loc`.
    distribution.experimental_default_event_space_bijector()(
      tfd.Normal(loc=distribution.loc, scale=distribution.scale)))
)

This rule fires whenever ASVI encounters a location-scale distribution, and instructs ASVI to build the surrogate as if the model had used a (possibly constrained) Normal in its place. A more precise condition could narrow the rule, e.g., to distributions with a specific name, if there were reason to believe that a Normal would be a good surrogate for some model variables but not others.
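A name-restricted condition might look like the following sketch. It keeps the same duck-typed location-scale check as the example and additionally requires the distribution's `name` to be in a chosen set. The factory `make_named_loc_scale_condition` and the variable names `"weights"` and `"bias"` are hypothetical, and the sketch assumes the relevant model variables were given those names through the distributions' `name` argument. Lightweight `SimpleNamespace` objects stand in for distributions so the predicate can be exercised without TensorFlow.

```python
import types
from typing import Callable, Iterable


def make_named_loc_scale_condition(
        target_names: Iterable[str]) -> Callable[[object], bool]:
    """Build a predicate for location-scale distributions with given names.

    Hypothetical helper: fires only if the distribution has both `loc` and
    `scale` attributes AND its `name` is one of `target_names`.
    """
    names = frozenset(target_names)

    def condition(distribution) -> bool:
        # Same duck-typed location-scale test as in the registration example.
        is_loc_scale = (hasattr(distribution, "loc")
                        and hasattr(distribution, "scale"))
        # Missing `name` attributes are treated as non-matching.
        return is_loc_scale and getattr(distribution, "name", None) in names

    return condition


condition = make_named_loc_scale_condition({"weights", "bias"})

# Stand-ins for distributions (not real TFP objects).
weights = types.SimpleNamespace(loc=0.0, scale=1.0, name="weights")
noise = types.SimpleNamespace(loc=0.0, scale=1.0, name="noise")
rate_only = types.SimpleNamespace(rate=1.0, name="weights")

print(condition(weights), condition(noise), condition(rate_only))  # True False False
```

The resulting predicate could then be passed as `condition=` to `tfp.experimental.vi.register_asvi_substitution_rule`, with a `substitution_fn` such as the one in the example, so that only the named variables receive a Normal surrogate.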