oryx.bijectors.tree_flatten

Returns a Bijector variant of tf.nest.flatten.

To act as a Bijector, it must also know how to "unflatten". Unlike tf.nest.flatten itself, the returned bijector can flatten or unflatten only one specific structure, which the example argument defines.

See also the Restructure bijector for general rearrangements.

Args

example: A Tensor or (potentially nested) collection of Tensors.
name: An optional Python string, inserted into names of TF ops created by this bijector.

Returns

flatten: A Bijector whose forward method flattens structures parallel to example into a list of Tensors, and whose inverse method packs a list of Tensors of the right length into a structure parallel to example.
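The forward/inverse contract can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the library's implementation: the helper names flatten and pack_like are hypothetical, and the sketch assumes dict values are visited in insertion order, which may differ from the ordering tf.nest uses. It shows why the inverse needs a fixed example: the flat list alone does not record the nesting it came from.

```python
import collections


def flatten(structure):
    """Return the leaves of a nested list/tuple/dict as a flat list."""
    if isinstance(structure, dict):
        # Assumption of this sketch: dict values are visited in insertion order.
        return [leaf for value in structure.values() for leaf in flatten(value)]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def pack_like(example, flat):
    """Rebuild a structure nested like `example` from the flat list `flat`."""
    expected = len(flatten(example))
    if len(flat) != expected:
        raise ValueError(f"expected {expected} leaves, got {len(flat)}")
    leaves = iter(flat)

    def build(node):
        if isinstance(node, dict):
            return type(node)((key, build(value)) for key, value in node.items())
        if isinstance(node, (list, tuple)):
            return type(node)(build(item) for item in node)
        return next(leaves)  # Each leaf of `example` consumes one flat value.

    return build(example)


# Only the nesting of `example` matters; its leaf values are placeholders.
example = collections.OrderedDict([("a", [0, 0, 0]), ("b", 0)])
ys = collections.OrderedDict([("a", [1, 2, 3]), ("b", 4.0)])

flat = flatten(ys)                    # Plays the role of the forward method.
restored = pack_like(example, flat)   # Plays the role of the inverse method.

# A differently nested input flattens to the same list, so the flat list
# cannot be unflattened without knowing the target structure.
same_flat = flatten([[1, 2], [3, 4.0]])
```

Here flat and same_flat are equal lists even though their sources are nested differently, and restored equals ys because pack_like reuses the nesting of example. Passing a list of the wrong length to pack_like raises ValueError in this sketch; how the actual bijector reports such a mismatch is not specified here.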

Example

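import collections

import tensorflow as tf
import tensorflow_probability as tfp

# Assumed alias: the original snippet uses tfb without defining it.
tfb = tfp.bijectors

x = tf.constant(1)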
example = collections.OrderedDict([
    ('a', [x, x, x]),
    ('b', x)])
bij = tfb.tree_flatten(example)
ys = collections.OrderedDict([
    ('a', [1, 2, 3]),
    ('b', 4.)])
bij.forward(ys)
# Returns [1, 2, 3, 4.]