A piecewise rational quadratic spline, as developed in [1].
Inherits From: Bijector
```python
tfp.bijectors.RationalQuadraticSpline(
    bin_widths, bin_heights, knot_slopes, range_min=-1, validate_args=False,
    name=None
)
```
This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the bounds of `knot_x`/`knot_y`, the transform behaves as an identity function.
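For instance, here is a minimal sketch of a single spline acting on scalars, with hypothetical, hand-picked knot parameters (widths and heights each sum to 2, matching the default `[-1, 1]` range):

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

# Toy 3-bin spline on [-1, 1], with one slope per interior knot (nbins - 1 = 2).
toy = tfb.RationalQuadraticSpline(
    bin_widths=[0.5, 1.0, 0.5],
    bin_heights=[1.0, 0.5, 0.5],
    knot_slopes=[0.7, 1.3],
    range_min=-1.)

toy.forward([-2.0, 0.0, 2.0])  # -2. and 2. fall outside [-1, 1] and pass
                               # through unchanged (identity tails).
```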
Typically this bijector will be used as part of a chain, with splines for trailing `x` dimensions conditioned on some of the earlier `x` dimensions, and with the inverse then solved first for unconditioned dimensions, then using conditioning derived from those inverses, and so forth. For example, if we split a 15-D `xs` vector into 3 components, we may implement a forward and inverse as follows:
```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

nsplits = 3

class SplineParams(tf.Module):
  """Builds trainable spline parameters on first call, then reuses them."""

  def __init__(self, nbins=32):
    self._nbins = nbins
    self._built = False
    self._bin_widths = None
    self._bin_heights = None
    self._knot_slopes = None

  def __call__(self, x, nunits):
    if not self._built:

      def _bin_positions(x):
        out_shape = tf.concat((tf.shape(x)[:-1], (nunits, self._nbins)), 0)
        x = tf.reshape(x, out_shape)
        # Positive widths/heights that sum to 2 (the default [-1, 1] range).
        return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2

      def _slopes(x):
        out_shape = tf.concat((
            tf.shape(x)[:-1], (nunits, self._nbins - 1)), 0)
        x = tf.reshape(x, out_shape)
        # Strictly positive slopes for the interior knots.
        return tf.math.softplus(x) + 1e-2

      self._bin_widths = tf.keras.layers.Dense(
          nunits * self._nbins, activation=_bin_positions, name='w')
      self._bin_heights = tf.keras.layers.Dense(
          nunits * self._nbins, activation=_bin_positions, name='h')
      self._knot_slopes = tf.keras.layers.Dense(
          nunits * (self._nbins - 1), activation=_slopes, name='s')
      self._built = True

    return tfb.RationalQuadraticSpline(
        bin_widths=self._bin_widths(x),
        bin_heights=self._bin_heights(x),
        knot_slopes=self._knot_slopes(x))

xs = np.random.randn(3, 15).astype(np.float32)  # Keras won't Dense(.)(vec).
splines = [SplineParams() for _ in range(nsplits)]

def spline_flow():
  stack = tfb.Identity()
  for i in range(nsplits):
    stack = tfb.RealNVP(5 * i, bijector_fn=splines[i])(stack)
  return stack

ys = spline_flow().forward(xs)
ys_inv = spline_flow().inverse(ys)  # ys_inv ~= xs
```
For a one-at-a-time autoregressive flow as in [1], it would be profitable to implement a mask over `xs` to parallelize either the inverse or the forward pass and implement the other using a `tf.while_loop`. See `tfp.bijectors.MaskedAutoregressiveFlow` for support doing so (paired with `tfp.bijectors.Invert` depending on which direction should be parallel).
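As a minimal illustration of that pairing (reusing `spline_flow` and `ys` from the example above rather than a true autoregressive flow), `tfb.Invert` simply swaps which of `forward`/`inverse` a wrapped bijector exposes:

```python
# Hypothetical sketch: inverting a flow swaps forward and inverse, so
# whichever direction is cheap/parallel becomes the one you call directly.
inverted = tfb.Invert(spline_flow())
zs = inverted.forward(ys)  # same computation as spline_flow().inverse(ys)
```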
References
[1]: Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural Spline Flows. arXiv preprint arXiv:1906.04032, 2019. https://arxiv.org/abs/1906.04032
Args | |
---|---|
`bin_widths` | The widths of the spans between subsequent knot `x` positions, a floating point `Tensor`. Must be positive, and at least 1-D. Innermost axis must sum to the same value as `bin_heights`. The knot `x` positions will be a first knot at `range_min`, followed by knots at `range_min + cumsum(bin_widths, axis=-1)`. |
`bin_heights` | The heights of the spans between subsequent knot `y` positions, a floating point `Tensor`. Must be positive, and at least 1-D. Innermost axis must sum to the same value as `bin_widths`. The knot `y` positions will be a first knot at `range_min`, followed by knots at `range_min + cumsum(bin_heights, axis=-1)`. |
`knot_slopes` | The slope of the spline at each knot, a floating point `Tensor`. Must be positive. `1`s are implicitly padded for the first and last implicit knots corresponding to `range_min` and `range_min + sum(bin_widths, axis=-1)`. Innermost axis size should be 1 less than that of `bin_widths`/`bin_heights`, or 1 for broadcasting. |
`range_min` | The `x`/`y` position of the first knot, which has implicit slope `1`. `range_max` is implicit, and can be computed as `range_min + sum(bin_widths, axis=-1)`. Scalar floating point `Tensor`. |
`validate_args` | Toggles argument validation (can hurt performance). |
`name` | Optional name scope for associated ops. (Defaults to `'RationalQuadraticSpline'`.) |
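For example, a `knot_slopes` with innermost axis size 1 broadcasts a single slope across all interior knots; a hypothetical sketch:

```python
# Hypothetical: one shared slope for both interior knots of a 3-bin spline.
shared_slope = tfb.RationalQuadraticSpline(
    bin_widths=[0.5, 1.0, 0.5],
    bin_heights=[1.0, 0.5, 0.5],
    knot_slopes=[1.],  # broadcasts against the nbins - 1 = 2 interior knots
    range_min=-1.)
```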
Attributes | |
---|---|
`bin_heights` | |
`bin_widths` | |
`dtype` | |
`forward_min_event_ndims` | Returns the minimal number of dimensions `bijector.forward` operates on. Multipart bijectors return structured `min_event_ndims`. |
`graph_parents` | Returns this `Bijector`'s `graph_parents` as a Python list. |
`has_static_min_event_ndims` | Returns `True` if the bijector has statically-known `min_event_ndims`. |
`inverse_min_event_ndims` | Returns the minimal number of dimensions `bijector.inverse` operates on. Multipart bijectors return structured `min_event_ndims`. |
`is_constant_jacobian` | Returns `True` iff the Jacobian matrix is not a function of `x`. |
`knot_slopes` | |
`name` | Returns the string name of this `Bijector`. |
`name_scope` | Returns a `tf.name_scope` instance for this class. |
`parameters` | Dictionary of parameters used to instantiate this `Bijector`. |
`range_min` | |
`submodules` | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
`trainable_variables` | Sequence of trainable variables owned by this module and its submodules. |
`validate_args` | Returns `True` if `Tensor` arguments will be validated. |
`variables` | Sequence of variables owned by this module and its submodules. |
Methods

forward

```python
forward(
    x, name='forward', **kwargs
)
```

Returns the forward `Bijector` evaluation, i.e., `Y = g(X)`.
Args | |
---|---|
`x` | `Tensor` (structure). The input to the 'forward' evaluation. |
`name` | The name to give this op. |
`**kwargs` | Named arguments forwarded to subclass implementation. |

Returns | |
---|---|
`Tensor` (structure). | |

Raises | |
---|---|
`TypeError` | if `self.dtype` is specified and `x.dtype` is not `self.dtype`. |
`NotImplementedError` | if `_forward` is not implemented. |
forward_dtype

```python
forward_dtype(
    dtype=UNSPECIFIED, name='forward_dtype', **kwargs
)
```

Returns the dtype returned by `forward` for the provided input.

forward_event_ndims

```python
forward_event_ndims(
    event_ndims, **kwargs
)
```

Returns the number of event dimensions produced by `forward`.
forward_event_shape

```python
forward_event_shape(
    input_shape
)
```

Shape of a single sample from a single batch as a `TensorShape`. Same meaning as `forward_event_shape_tensor`. May be only partially defined.

Args | |
---|---|
`input_shape` | `TensorShape` (structure) indicating event-portion shape passed into `forward` function. |

Returns | |
---|---|
`forward_event_shape_tensor` | `TensorShape` (structure) indicating event-portion shape after applying `forward`. Possibly unknown. |
forward_event_shape_tensor

```python
forward_event_shape_tensor(
    input_shape, name='forward_event_shape_tensor'
)
```

Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

Args | |
---|---|
`input_shape` | `Tensor`, `int32` vector (structure) indicating event-portion shape passed into `forward` function. |
`name` | The name to give to the op. |

Returns | |
---|---|
`forward_event_shape_tensor` | `Tensor`, `int32` vector (structure) indicating event-portion shape after applying `forward`. |
forward_log_det_jacobian

```python
forward_log_det_jacobian(
    x, event_ndims, name='forward_log_det_jacobian', **kwargs
)
```

Returns the (log o det o Jacobian o forward)(x). Mathematically, returns: `log(det(dY/dX))(X)`. (Recall that: `Y = g(X)`.)
Args | |
---|---|
`x` | `Tensor` (structure). The input to the 'forward' Jacobian determinant evaluation. |
`event_ndims` | Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to `self.forward_min_event_ndims`. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. it has shape `rank(x) - event_ndims` dimensions. Multipart bijectors require *structured* event_ndims, such that `rank(x[i]) - event_ndims[i]` is the same for all elements `i` of the structured input. Furthermore, the first `event_ndims[i]` dimensions of each `x[i].shape` must be the same for all `i` (broadcasting is not allowed). |
`name` | The name to give this op. |
`**kwargs` | Named arguments forwarded to subclass implementation. |

Returns | |
---|---|
`Tensor` (structure), if this bijector is injective. If not injective, this is not implemented. | |

Raises | |
---|---|
`TypeError` | if `y`'s dtype is incompatible with the expected output dtype. |
`NotImplementedError` | if neither `_forward_log_det_jacobian` nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or this is a non-injective bijector. |
inverse

```python
inverse(
    y, name='inverse', **kwargs
)
```

Returns the inverse `Bijector` evaluation, i.e., `X = g^{-1}(Y)`.
Args | |
---|---|
`y` | `Tensor` (structure). The input to the 'inverse' evaluation. |
`name` | The name to give this op. |
`**kwargs` | Named arguments forwarded to subclass implementation. |

Returns | |
---|---|
`Tensor` (structure), if this bijector is injective. If not injective, returns the k-tuple containing the unique `k` points `(x1, ..., xk)` such that `g(xi) = y`. | |

Raises | |
---|---|
`TypeError` | if `y`'s structured dtype is incompatible with the expected output dtype. |
`NotImplementedError` | if `_inverse` is not implemented. |
inverse_dtype

```python
inverse_dtype(
    dtype=UNSPECIFIED, name='inverse_dtype', **kwargs
)
```

Returns the dtype returned by `inverse` for the provided input.

inverse_event_ndims

```python
inverse_event_ndims(
    event_ndims, **kwargs
)
```

Returns the number of event dimensions produced by `inverse`.
inverse_event_shape

```python
inverse_event_shape(
    output_shape
)
```

Shape of a single sample from a single batch as a `TensorShape`. Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

Args | |
---|---|
`output_shape` | `TensorShape` (structure) indicating event-portion shape passed into `inverse` function. |

Returns | |
---|---|
`inverse_event_shape_tensor` | `TensorShape` (structure) indicating event-portion shape after applying `inverse`. Possibly unknown. |
inverse_event_shape_tensor

```python
inverse_event_shape_tensor(
    output_shape, name='inverse_event_shape_tensor'
)
```

Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

Args | |
---|---|
`output_shape` | `Tensor`, `int32` vector (structure) indicating event-portion shape passed into `inverse` function. |
`name` | The name to give to the op. |

Returns | |
---|---|
`inverse_event_shape_tensor` | `Tensor`, `int32` vector (structure) indicating event-portion shape after applying `inverse`. |
inverse_log_det_jacobian

```python
inverse_log_det_jacobian(
    y, event_ndims, name='inverse_log_det_jacobian', **kwargs
)
```

Returns the (log o det o Jacobian o inverse)(y). Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X = g^{-1}(Y)`.)

Note that `forward_log_det_jacobian` is the negative of this function, evaluated at `g^{-1}(y)`.
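A quick numerical check of that identity, sketched with a simple stand-in bijector (`tfb.Exp`) rather than a spline:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

b = tfb.Exp()
x = tf.constant([0.5, 1.5])
fldj = b.forward_log_det_jacobian(x, event_ndims=0)
ildj = b.inverse_log_det_jacobian(b.forward(x), event_ndims=0)
# fldj == -ildj, elementwise.
```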
Args | |
---|---|
`y` | `Tensor` (structure). The input to the 'inverse' Jacobian determinant evaluation. |
`event_ndims` | Number of dimensions in the probabilistic events being transformed. Must be greater than or equal to `self.inverse_min_event_ndims`. The result is summed over the final dimensions to produce a scalar Jacobian determinant for each event, i.e. it has shape `rank(y) - event_ndims` dimensions. Multipart bijectors require *structured* event_ndims, such that `rank(y[i]) - event_ndims[i]` is the same for all elements `i` of the structured input. Furthermore, the first `event_ndims[i]` dimensions of each `y[i].shape` must be the same for all `i` (broadcasting is not allowed). |
`name` | The name to give this op. |
`**kwargs` | Named arguments forwarded to subclass implementation. |

Returns | |
---|---|
`ildj` | `Tensor`, if this bijector is injective. If not injective, returns the tuple of local log det Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction of `g` to the `i`th partition `Di`. |

Raises | |
---|---|
`TypeError` | if `x`'s dtype is incompatible with the expected inverse-dtype. |
`NotImplementedError` | if `_inverse_log_det_jacobian` is not implemented. |
with_name_scope

```python
@classmethod
with_name_scope(
    method
)
```

Decorator to automatically enter the module name scope.

```python
class MyModule(tf.Module):

  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
```

Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names included the module name:

```python
mod = MyModule()
mod(tf.ones([1, 2]))
# ==> <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
# ==> <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=..., dtype=float32)>
```
Args | |
---|---|
`method` | The method to wrap. |

Returns | |
---|---|
The original method wrapped such that it enters the module's name scope. | |
__call__

```python
__call__(
    value, name=None, **kwargs
)
```

Applies or composes the `Bijector`, depending on input type.

This is a convenience function which applies the `Bijector` instance in three different ways, depending on the input:

1. If the input is a `tfd.Distribution` instance, return `tfd.TransformedDistribution(distribution=input, bijector=self)`.
2. If the input is a `tfb.Bijector` instance, return `tfb.Chain([self, input])`.
3. Otherwise, return `self.forward(input)`.

Args | |
---|---|
`value` | A `tfd.Distribution`, `tfb.Bijector`, or a (structure of) `Tensor`. |
`name` | Python `str` name given to ops created by this function. |
`**kwargs` | Additional keyword arguments passed into the created `tfd.TransformedDistribution`, `tfb.Bijector`, or `self.forward`. |

Returns | |
---|---|
`composition` | A `tfd.TransformedDistribution` if the input was a `tfd.Distribution`, a `tfb.Chain` if the input was a `tfb.Bijector`, or a (structure of) `Tensor` computed by `self.forward`. |
Examples

```python
sigmoid = tfb.Reciprocal()(
    tfb.Shift(shift=1.)(
      tfb.Exp()(
        tfb.Scale(scale=-1.))))
# ==> `tfb.Chain([
#        tfb.Reciprocal(),
#        tfb.Shift(shift=1.),
#        tfb.Exp(),
#        tfb.Scale(scale=-1.),
#      ])`  # ie, `tfb.Sigmoid()`

log_normal = tfb.Exp()(tfd.Normal(0, 1))
# ==> `tfd.TransformedDistribution(tfd.Normal(0, 1), tfb.Exp())`

tfb.Exp()([-1., 0., 1.])
# ==> tf.exp([-1., 0., 1.])
```
__eq__

```python
__eq__(
    other
)
```

Return self == value.