tfp.experimental.distributions.marginal_fns.ps.smart_where

As tf.where, but calls both x_fn and y_fn only when condition is not statically known. If condition is statically known to be True (or False), only x_fn (or y_fn) is called.

Args:
  condition: A bool Tensor.
  x_fn: A callable returning a Tensor, for locations where condition is True.
  y_fn: A callable returning a Tensor, for locations where condition is False.

Returns:
  A Tensor equivalent to tf.where(condition, x_fn(), y_fn()).
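To make the dispatch rule concrete, here is a minimal NumPy sketch. It is not TFP's implementation: the helper name smart_where_sketch is hypothetical, and it assumes that a scalar condition plays the role of a "statically known" value while an array condition stands in for one whose value is only known at run time. The call log shows which of x_fn and y_fn actually run.

```python
import numpy as np


def smart_where_sketch(condition, x_fn, y_fn):
    """Hypothetical NumPy analogue of the smart_where dispatch rule.

    Assumption: a scalar condition models a statically known value; any
    array-valued condition models one that is not statically known.
    """
    if np.ndim(condition) == 0:
        # "Statically known" scalar: evaluate only the selected branch.
        return x_fn() if bool(condition) else y_fn()
    # Not statically known: an elementwise select needs both branches.
    return np.where(condition, x_fn(), y_fn())


calls = []  # Records which branch functions were invoked.


def x_fn():
    calls.append("x")
    return np.array([1.0, 2.0, 3.0])


def y_fn():
    calls.append("y")
    return np.array([-1.0, -2.0, -3.0])


out = smart_where_sketch(True, x_fn, y_fn)
# calls is now ['x']; y_fn was never evaluated.

calls.clear()
out = smart_where_sketch(np.array([True, False, True]), x_fn, y_fn)
# calls is now ['x', 'y']; both branches ran to build the elementwise result.
```

Skipping the unselected function avoids paying for, or failing inside, a branch whose result would be discarded anyway.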