tfg.geometry.convolution.utils.unflatten_2d_to_batch


Reshapes a 2d Tensor into a batch of 2d Tensors.

The data tensor with shape [D1, D2] is mapped to a tensor with shape [A1, ..., An, max_rows, D2], where max_rows defaults to max(sizes). Each entry of sizes gives the number of consecutive rows of data assigned to the corresponding batch entry, so sum(sizes) must equal D1. Batch entries with fewer than max_rows rows are padded with zeros.
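For a one-dimensional sizes (n = 1), the mapping can be sketched in plain Python: compute each segment's start offset from the running sum of sizes, slice out that many rows, and zero-pad to max_rows. The helper name unflatten_rows is hypothetical, and this is an illustrative reference on nested lists, not the TensorFlow Graphics implementation.

```python
from itertools import accumulate
from typing import List, Optional, Sequence


def unflatten_rows(data: Sequence[Sequence[float]],
                   sizes: Sequence[int],
                   max_rows: Optional[int] = None) -> List[List[List[float]]]:
    """Hypothetical plain-Python reference for a 1-D `sizes` (n == 1).

    Segment i is rows [start_i, start_i + sizes[i]) of `data`, zero-padded
    to `max_rows` rows. Not the TensorFlow Graphics implementation.
    """
    if max_rows is None:
        max_rows = max(sizes)  # documented default
    width = len(data[0])
    # Start offset of each segment: [0, s0, s0 + s1, ...].
    starts = [0] + list(accumulate(sizes))[:-1]
    batch = []
    for start, size in zip(starts, sizes):
        segment = [list(row) for row in data[start:start + size]]
        segment += [[0.0] * width for _ in range(max_rows - size)]  # zero padding
        batch.append(segment)
    return batch


data = [[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.], [11., 12.]]
print(unflatten_rows(data, [2, 3, 1])[2])  # [[11.0, 12.0], [0.0, 0.0], [0.0, 0.0]]
```

Slicing by running offsets makes the sum(sizes) == D1 requirement visible: every input row belongs to exactly one segment.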

Examples:

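```python
import tensorflow as tf
from tensorflow_graphics.geometry.convolution import utils

data = tf.constant([[1., 2.],
                    [3., 4.],
                    [5., 6.],
                    [7., 8.],
                    [9., 10.],
                    [11., 12.]])
sizes = tf.constant([2, 3, 1])

# max_rows=None pads every segment to max(sizes) = 3 rows.
output = utils.unflatten_2d_to_batch(data, sizes, max_rows=None)
print(output.shape)  # (3, 3, 2)
# Values of output:
# [[[1., 2.],
#   [3., 4.],
#   [0., 0.]],
#  [[5., 6.],
#   [7., 8.],
#   [9., 10.]],
#  [[11., 12.],
#   [0., 0.],
#   [0., 0.]]]

# An explicit max_rows=4 adds one more row of zero padding per segment.
output = utils.unflatten_2d_to_batch(data, sizes, max_rows=4)
print(output.shape)  # (3, 4, 2)
# Values of output:
# [[[1., 2.],
#   [3., 4.],
#   [0., 0.],
#   [0., 0.]],
#  [[5., 6.],
#   [7., 8.],
#   [9., 10.],
#   [0., 0.]],
#  [[11., 12.],
#   [0., 0.],
#   [0., 0.],
#   [0., 0.]]]
```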
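When sizes has more than one dimension, the op can be viewed as a scatter: each output slot (a1, ..., an, row) is filled by the next unused input row when row < sizes[a1, ..., an], and is left at zero otherwise. The sketch below lists the destination index of every input row, assuming batch entries are visited in row-major order over [A1, ..., An] (an assumption; the ordering is not stated above). The helper scatter_targets is hypothetical; the same idea could be expressed in TensorFlow with tf.sequence_mask and tf.scatter_nd, but this sketch does not claim that is how the library implements it.

```python
from itertools import product
from typing import List, Sequence, Tuple


def scatter_targets(batch_shape: Sequence[int],
                    flat_sizes: Sequence[int],
                    max_rows: int) -> List[Tuple[int, ...]]:
    """Hypothetical helper: destination index in the output for each input row.

    `flat_sizes` lists the entries of `sizes` in row-major order over
    `batch_shape` ([A1, ..., An]); row-major order is an assumption.
    """
    targets = []
    batch_indices = product(*(range(dim) for dim in batch_shape))
    for batch_index, size in zip(batch_indices, flat_sizes):
        # Keep the slots where a sequence mask would be True: row < size.
        for row in range(max_rows):
            if row < size:
                targets.append(batch_index + (row,))
    return targets


# sizes = [[2, 1], [0, 3]] has shape [A1, A2] = [2, 2]; D1 = 6.
targets = scatter_targets([2, 2], [2, 1, 0, 3], max_rows=3)
# Input row i is written to output[targets[i]]; every other slot stays zero.
print(targets)
# [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 1, 0), (1, 1, 1), (1, 1, 2)]
```

Under this reading, the output for the [2, 2] sizes above would have shape [2, 2, 3, 2]. For the one-dimensional example, scatter_targets([3], [2, 3, 1], max_rows=3) yields (0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 0), which matches where the rows of data land in the first output.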

Args:
  data: A tensor with shape [D1, D2].
  sizes: An int tensor with shape [A1, ..., An] whose entries sum to D1.
  max_rows: An int specifying the maximum number of rows in the unflattened output, with max_rows >= max(sizes). If None, max(sizes) is used.
  name: A name for this op. Defaults to 'utils_unflatten_2d_to_batch'.

Returns:
  A tensor with shape [A1, ..., An, max_rows, D2], in which the rows past each batch entry's size are filled with zeros.
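The constraints on data, sizes, and max_rows listed above, and the resulting output shape, can be checked before calling the op. The hypothetical helper below works on plain Python shapes, enforces sum(sizes) == D1 and max_rows >= max(sizes), applies the documented max(sizes) default, and returns [A1, ..., An, max_rows, D2]. Whether the library itself raises on these conditions is not stated here, so treat the error handling as an assumption.

```python
from math import prod
from typing import List, Optional, Sequence


def expected_output_shape(data_shape: Sequence[int],
                          sizes_shape: Sequence[int],
                          flat_sizes: Sequence[int],
                          max_rows: Optional[int] = None) -> List[int]:
    """Hypothetical pre-check: validate arguments and return the output shape.

    `flat_sizes` holds the entries of `sizes` flattened, so its length must
    equal the product of `sizes_shape` ([A1, ..., An]).
    """
    if len(data_shape) != 2:
        raise ValueError("data must have shape [D1, D2]")
    d1, d2 = data_shape
    if prod(sizes_shape) != len(flat_sizes):
        raise ValueError("flat_sizes does not match sizes_shape")
    if sum(flat_sizes) != d1:
        raise ValueError("sum(sizes) must equal D1")
    if max_rows is None:
        max_rows = max(flat_sizes)  # documented default: max(sizes)
    elif max_rows < max(flat_sizes):
        raise ValueError("max_rows must be >= max(sizes)")
    return list(sizes_shape) + [max_rows, d2]


print(expected_output_shape([6, 2], [3], [2, 3, 1]))              # [3, 3, 2]
print(expected_output_shape([6, 2], [3], [2, 3, 1], max_rows=4))  # [3, 4, 2]
```

Both calls reproduce the output shapes of the examples for this page.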