Registers a rule for substituting distributions in ASVI (automatic structured variational inference) surrogates.
tfp.experimental.vi.register_asvi_substitution_rule( condition, substitution_fn )
To use a Normal surrogate for all location-scale family distributions, we could register the substitution:
```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# Use a (possibly constrained) Normal surrogate for every distribution that
# exposes both `loc` and `scale` parameters.
tfp.experimental.vi.register_asvi_substitution_rule(
    condition=lambda distribution: (
        hasattr(distribution, 'loc') and hasattr(distribution, 'scale')),
    substitution_fn=lambda distribution: (
        # Invoking the event space bijector applies any relevant constraints,
        # e.g., that HalfCauchy samples must be `>= loc`.
        distribution.experimental_default_event_space_bijector()(
            tfd.Normal(loc=distribution.loc, scale=distribution.scale))))
```
This rule will fire when ASVI encounters a location-scale distribution,
and instructs ASVI to build a surrogate 'as if' the model had just used a
(possibly constrained) Normal in its place. Note that we could have used a
more precise condition, e.g., to limit the substitution to distributions with
a specific `name`, if we had reason to think that a Normal distribution would
be a good surrogate for some model variables but not others.
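A minimal sketch of such a more precise condition follows. It is plain Python with no TFP import, so it only builds the predicate. `make_name_restricted_condition` and the variable names `'z'` and `'w'` are hypothetical, chosen for illustration. The sketch also assumes each distribution's `name` attribute matches the `name=` you passed when building the model, which is worth checking for your own model.

```python
def make_name_restricted_condition(names):
  """Returns a predicate matching only opted-in location-scale distributions.

  `names` is a hypothetical, user-chosen collection of variable names; this
  assumes each model distribution was constructed with a matching `name=`.
  """
  allowed = frozenset(names)

  def condition(distribution):
    # The same duck-typed location-scale check as the rule above...
    is_loc_scale = (hasattr(distribution, 'loc') and
                    hasattr(distribution, 'scale'))
    # ...further restricted to distributions whose name was opted in.
    return is_loc_scale and getattr(distribution, 'name', None) in allowed

  return condition


# Hypothetical usage: only variables named 'z' or 'w' get a Normal surrogate.
normal_surrogate_condition = make_name_restricted_condition({'z', 'w'})
```

The resulting predicate would then be passed as `condition=normal_surrogate_condition` to `tfp.experimental.vi.register_asvi_substitution_rule`, together with the same `substitution_fn` as in the example above. Other location-scale variables would keep ASVI's default surrogate.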