tfg.geometry.convolution.utils.flatten_batch_to_2d


Reshapes a batch of 2d Tensors by flattening across the batch dimensions.

Note:

In the following, A1 to An are optional batch dimensions.

A tensor with shape [A1, ..., An, D1, D2] is reshaped to one with shape [A1*...*An*D1, D2]. The function also returns an inverse function that reshapes any tensor with shape [A1*...*An*D1, D3] back to one with shape [A1, ..., An, D1, D3].
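To make the reshape and its inverse concrete, here is a minimal plain-Python sketch of the same semantics on nested lists. It is not the library's implementation, which operates on TensorFlow tensors. The helper names `flatten_to_2d` and `unflatten` are hypothetical, and the sketch assumes a non-empty, rectangular nested list.

```python
from typing import Any, Callable, List, Tuple


def _rank(x: Any) -> int:
    """Count nesting levels of a non-empty rectangular nested list (scalars: 0)."""
    rank = 0
    while isinstance(x, list):
        rank += 1
        x = x[0]
    return rank


def flatten_to_2d(data: list) -> Tuple[List[list], Callable[[List[list]], list]]:
    """Flatten nested lists of shape [A1, ..., An, D1, D2] into rows of length D2.

    Also returns a hypothetical inverse mapping [A1*...*An*D1, D3] rows back to
    [A1, ..., An, D1, D3]; D3 may differ from D2.
    """
    rank = _rank(data)
    if rank < 2:
        raise ValueError("data must have rank >= 2")

    # Record the leading shape [A1, ..., An, D1].
    lead_shape = []
    level = data
    for _ in range(rank - 1):
        lead_shape.append(len(level))
        level = level[0]

    def rows(x: list, depth: int) -> List[list]:
        if depth == rank - 1:  # x is a single row of length D2
            return [x]
        return [row for child in x for row in rows(child, depth + 1)]

    flat = rows(data, 0)

    def unflatten(flat_rows: List[list]) -> list:
        if len(flat_rows) != len(flat):
            raise ValueError("expected %d rows, got %d" % (len(flat), len(flat_rows)))
        out: list = list(flat_rows)
        # Regroup from the innermost leading dimension (D1) outward to A2.
        for dim in reversed(lead_shape[1:]):
            out = [out[i:i + dim] for i in range(0, len(out), dim)]
        return out

    return flat, unflatten


data = [[[1., 2.], [3., 4.]],
        [[5., 6.], [7., 8.]],
        [[9., 10.], [11., 12.]]]
flat, unflatten = flatten_to_2d(data)
# flat is [[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.], [11., 12.]]
sums = unflatten([[sum(row)] for row in flat])  # per-row results with D3 = 1
# sums is [[[3.], [7.]], [[11.], [15.]], [[19.], [23.]]]
```

The inverse only needs the recorded leading shape, which is why per-row results of any width D3 can be mapped back into the original batch layout.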

Inputs padded along dimension D1 are allowed: sizes gives, for each batch element, the number of leading entries along D1 to keep, and the remaining (padded) entries are dropped.
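The effect of sizes can be sketched in plain Python as a masked selection: keep the first sizes[a1, ..., an] rows of each [D1, D2] block and concatenate them. This is an illustrative sketch, not the library's code; `flatten_padded_to_2d` is a hypothetical name, and writing zeros into the padded slots when inverting is a choice made for this sketch.

```python
from typing import Callable, List, Tuple, Union

# sizes is either an int (one [D1, D2] block) or a nested list of ints [A1, ..., An].
Sizes = Union[int, list]


def flatten_padded_to_2d(data: list, sizes: Sizes) -> Tuple[List[list], Callable[[List[list]], list]]:
    """Keep the first `size` rows of each [D1, D2] block and concatenate them.

    The hypothetical inverse scatters sum(sizes) rows of length D3 back into
    [A1, ..., An, D1, D3], filling the padded slots with zeros.
    """
    flat: List[list] = []
    d1 = 0  # assumes every block shares the same D1, as in a rectangular tensor

    def collect(x: list, s: Sizes) -> None:
        nonlocal d1
        if isinstance(s, int):
            # x is one [D1, D2] block; only its first s rows are valid.
            d1 = len(x)
            if not 0 <= s <= d1:
                raise ValueError("each entry of sizes must satisfy 0 <= size <= D1")
            flat.extend(x[:s])
        else:
            if len(x) != len(s):
                raise ValueError("batch dimensions of data and sizes do not match")
            for x_i, s_i in zip(x, s):
                collect(x_i, s_i)

    collect(data, sizes)

    def unflatten(flat_rows: List[list]) -> list:
        if len(flat_rows) != len(flat):
            raise ValueError("expected sum(sizes) = %d rows" % len(flat))
        width = len(flat_rows[0]) if flat_rows else 0  # D3
        remaining = iter(flat_rows)

        def build(s: Sizes) -> list:
            if isinstance(s, int):
                valid = [next(remaining) for _ in range(s)]
                padding = [[0.0] * width for _ in range(d1 - s)]
                return valid + padding
            return [build(s_i) for s_i in s]

        return build(sizes)

    return flat, unflatten


data = [[[1., 2.], [0., 0.]],
        [[5., 6.], [7., 8.]],
        [[9., 10.], [0., 0.]]]
flat, unflatten = flatten_padded_to_2d(data, [1, 2, 1])
# flat is [[1., 2.], [5., 6.], [7., 8.], [9., 10.]]
restored = unflatten([[10.], [20.], [30.], [40.]])
# restored is [[[10.], [0.]], [[20.], [30.]], [[40.], [0.]]]
```

Because only valid rows are kept, the flattened output has sum(sizes) rows rather than A1*...*An*D1.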

Examples:

from tensorflow_graphics.geometry.convolution.utils import flatten_batch_to_2d

data = [[[1., 2.], [3., 4.]],
        [[5., 6.], [7., 8.]],
        [[9., 10.], [11., 12.]]]
sizes = None
output, unflatten = flatten_batch_to_2d(data, sizes)
print(output)
>>> [[1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.], [11., 12.]]

data = [[[1., 2.], [0., 0.]],
        [[5., 6.], [7., 8.]],
        [[9., 10.], [0., 0.]]]
sizes = [1, 2, 1]
output, unflatten = flatten_batch_to_2d(data, sizes)
print(output)
>>> [[1., 2.], [5., 6.], [7., 8.], [9., 10.]]

Args:

data: A tensor with shape [A1, ..., An, D1, D2].
sizes: An int tensor with shape [A1, ..., An], or None. Each entry must satisfy sizes[i] <= D1.
name: A name for this op. Defaults to 'utils_flatten_batch_to_2d'.

Returns:

A tensor with shape [A1*...*An*D1, D2] if sizes is None, otherwise a tensor with shape [sum(sizes), D2].

A function that reshapes a tensor with shape [A1*...*An*D1, D3] to one with shape [A1, ..., An, D1, D3] if sizes is None; otherwise it reshapes a tensor with shape [sum(sizes), D3] to one with shape [A1, ..., An, D1, D3].

Raises:

ValueError: if the input tensor dimensions are invalid.
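The shape bookkeeping above can be checked with a small plain-Python helper. This is a hypothetical sketch that works on shape lists rather than tensors; `flattened_shape` and `restored_shape` are illustrative names, and the validation shown only approximates the kinds of dimension checks that would raise ValueError. The library's exact checks may differ.

```python
import math
from typing import List, Sequence, Tuple


def _flat_ints(sizes) -> List[int]:
    """Flatten a nested list of ints with shape [A1, ..., An] into a flat list."""
    if isinstance(sizes, int):
        return [sizes]
    return [v for item in sizes for v in _flat_ints(item)]


def flattened_shape(data_shape: Sequence[int], sizes=None) -> Tuple[int, int]:
    """Shape of the flattened output for data of shape [A1, ..., An, D1, D2]."""
    if len(data_shape) < 2:
        raise ValueError("data must have rank >= 2")
    *batch, d1, d2 = data_shape
    if sizes is None:
        return math.prod(batch) * d1, d2
    counts = _flat_ints(sizes)
    # Checks only the number of entries, not how they are nested.
    if len(counts) != math.prod(batch):
        raise ValueError("sizes must have shape [A1, ..., An]")
    if any(not 0 <= c <= d1 for c in counts):
        raise ValueError("every entry of sizes must satisfy sizes[i] <= D1")
    return sum(counts), d2


def restored_shape(data_shape: Sequence[int], d3: int) -> List[int]:
    """Shape produced by the inverse function for rows of width D3."""
    *batch, d1, _ = data_shape
    return [*batch, d1, d3]


print(flattened_shape([3, 2, 2]))             # (6, 2)
print(flattened_shape([3, 2, 2], [1, 2, 1]))  # (4, 2)
print(restored_shape([3, 2, 2], 5))           # [3, 2, 5]
```

For the two examples above, this gives 6 rows without sizes and sum([1, 2, 1]) = 4 rows with them, while the inverse always restores the full [A1, ..., An, D1, D3] layout.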