Returns a Bijector variant of tf.nest.flatten.
tfp.substrates.jax.bijectors.tree_flatten(
example, name='restructure'
)
To be a Bijector, it also has to know how to "unflatten". Unlike the real tf.nest.flatten, it can therefore only flatten or unflatten one specific structure, which is defined by the example argument.
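The reason a fixed structure is needed can be shown with a small pure-Python sketch. This is not TFP's implementation: TreeFlattenSketch and its helpers are hypothetical names. They assume trees built only from dicts, lists and tuples (namedtuples are not handled), and they order dict leaves by sorted key, as tf.nest does. The inverse works only because the example structure is captured when the object is constructed.

```python
from collections import OrderedDict


def _flatten(tree):
    """Return the leaves of `tree`; dict keys are visited in sorted order, as in tf.nest."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in _flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in _flatten(item)]
    return [tree]


def _unflatten(structure, leaves_iter):
    """Rebuild `structure`, consuming leaves in the same order _flatten produces them."""
    if isinstance(structure, dict):
        values = {key: _unflatten(structure[key], leaves_iter) for key in sorted(structure)}
        # Keep the container type (e.g. OrderedDict) and its original key order.
        return type(structure)((key, values[key]) for key in structure)
    if isinstance(structure, (list, tuple)):
        return type(structure)(_unflatten(item, leaves_iter) for item in structure)
    return next(leaves_iter)


def _skeleton(tree):
    """Replace every leaf with None so trees can be compared by structure alone."""
    return _unflatten(tree, iter([None] * len(_flatten(tree))))


class TreeFlattenSketch:
    """Hypothetical stand-in: flattens and unflattens ONE fixed structure."""

    def __init__(self, example):
        self._example = example          # only its structure is used
        self._skeleton = _skeleton(example)

    def forward(self, x):
        if _skeleton(x) != self._skeleton:
            raise ValueError("input does not match the example structure")
        return _flatten(x)

    def inverse(self, y):
        if len(y) != len(_flatten(self._example)):
            raise ValueError("wrong number of leaves for the example structure")
        return _unflatten(self._example, iter(y))
```

For example, TreeFlattenSketch(OrderedDict([('a', [0, 0, 0]), ('b', 0)])).forward(OrderedDict([('a', [1, 2, 3]), ('b', 4.)])) gives [1, 2, 3, 4.0], and inverse maps that list back to the OrderedDict. A plain flatten function could not do the inverse step, because a flat list alone does not record the structure it came from.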
See also the Restructure bijector for general rearrangements.
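The following pure-Python sketch shows how tree_flatten can be viewed as a special case of a general rearrangement. It assumes a token-matching scheme like the one Restructure describes: the leaves of an input structure and an output structure are tokens, and each input value moves to wherever its token appears in the output. The restructure function and its helpers are hypothetical and are not the TFP API. This sketch checks only the leaf count, not the full input structure, and it assumes that tokens are unique.

```python
from collections import OrderedDict


def _leaves(tree):
    """Leaves of a dict/list/tuple tree; dict keys are visited in sorted order."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in _leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in _leaves(item)]
    return [tree]


def _map_leaves(fn, tree):
    """Apply `fn` to every leaf, keeping container types and dict key order."""
    if isinstance(tree, dict):
        return type(tree)((key, _map_leaves(fn, value)) for key, value in tree.items())
    if isinstance(tree, (list, tuple)):
        return type(tree)(_map_leaves(fn, item) for item in tree)
    return fn(tree)


def restructure(x, input_structure, output_structure):
    """Hypothetical token-based rearrangement of the leaves of `x`."""
    tokens = _leaves(input_structure)
    values = _leaves(x)
    if len(tokens) != len(values):
        raise ValueError("x does not have one leaf per input token")
    lookup = dict(zip(tokens, values))  # token -> value taken from x
    return _map_leaves(lookup.__getitem__, output_structure)
```

Under these assumptions, flattening the example from this page means choosing input_structure=OrderedDict([('a', [0, 1, 2]), ('b', 3)]) and output_structure=[0, 1, 2, 3]. Swapping the two structures gives the inverse.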
Args | |
---|---|
example | A Tensor or (potentially nested) collection of Tensors. Only its structure is used. |
name | An optional Python string, inserted into names of TF ops created by this bijector. |
Example
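import collections

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors

# Only the structure of `example` matters; its leaf values are placeholders.
x = jnp.array(1)
example = collections.OrderedDict([
    ('a', [x, x, x]),
    ('b', x)])
bij = tfb.tree_flatten(example)

ys = collections.OrderedDict([
    ('a', [1, 2, 3]),
    ('b', 4.)])
bij.forward(ys)
# Returns [1, 2, 3, 4.]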