tfr.utils.reshape_first_ndims

Reshapes the first n dimensions (n = first_ndims) of the input tensor to new_shape. The remaining dimensions are left unchanged.

Args:
  tensor: The input Tensor.
  first_ndims: An int denoting the number of leading dimensions to reshape.
  new_shape: A list of ints representing the new shape of those leading dimensions.

Returns:
  A reshaped Tensor.
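To illustrate the semantics, here is a minimal NumPy sketch. It is not the TensorFlow Ranking implementation. The function name reshape_first_ndims_np is hypothetical. The sketch assumes the operation behaves like an ordinary reshape whose target shape is new_shape followed by the input's trailing dimensions (those after the first first_ndims). A typical use is flattening a [batch_size, list_size, feature_dim] array into [batch_size * list_size, feature_dim] and back again.

```python
import numpy as np


def reshape_first_ndims_np(array, first_ndims, new_shape):
    """Hypothetical NumPy analogue: reshape only the leading `first_ndims` dims."""
    array = np.asarray(array)
    if array.ndim < first_ndims:
        raise ValueError(f"Array has fewer than {first_ndims} dims.")
    # Prepend the new leading shape and keep the trailing dimensions as-is.
    target_shape = list(new_shape) + list(array.shape[first_ndims:])
    return array.reshape(target_shape)


# Flatten [batch_size=2, list_size=3, feature_dim=4] into [6, 4].
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
y = reshape_first_ndims_np(x, 2, [6])
print(y.shape)  # (6, 4)

# Restore the original leading dimensions: [6, 4] -> [2, 3, 4].
z = reshape_first_ndims_np(y, 1, [2, 3])
print(z.shape)  # (2, 3, 4)
```

Because only the leading dimensions are rewritten, the per-item feature vectors stay intact. For example, row 1 of the flattened array is the second item of the first list.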