tfl.lattice_lib.compute_interpolation_weights

Computes weights for hypercube lattice interpolation.

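Running time: O(batch_size * prod(lattice_sizes)), since a weight is produced for each of the prod(lattice_sizes) lattice vertices for every input point.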
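To make the cost concrete, here is a minimal pure-Python sketch for a single point. It assumes the standard multilinear ("hat") weights, where each vertex weight is the product over dimensions of max(0, 1 - |x_d - v_d|). The function name `interpolation_weights` and the row-major vertex order are illustrative choices, not the library's API or necessarily its vertex ordering, and the real function operates on batched tensors.

```python
import itertools
import math


def interpolation_weights(point, lattice_sizes, clip_inputs=True):
    """Hypothetical sketch: one weight per lattice vertex for a single point.

    Vertices are enumerated in row-major order (last dimension fastest),
    so the result has prod(lattice_sizes) entries.
    """
    if len(point) != len(lattice_sizes):
        raise ValueError("len(point) must equal len(lattice_sizes)")
    if clip_inputs:
        # Clip each coordinate into its valid range [0, size - 1].
        point = [min(max(x, 0.0), size - 1.0) for x, size in zip(point, lattice_sizes)]
    # 1-D "hat" weights: max(0, 1 - |x - k|) for each vertex index k.
    one_d = [
        [max(0.0, 1.0 - abs(x - k)) for k in range(size)]
        for x, size in zip(point, lattice_sizes)
    ]
    # Each vertex weight is the product of its per-dimension weights; visiting
    # every vertex is what makes the cost proportional to prod(lattice_sizes).
    return [math.prod(ws) for ws in itertools.product(*one_d)]
```

For lattice_sizes [2, 3] and the point [0.5, 1.25], the sketch returns six weights, [0.0, 0.375, 0.125, 0.0, 0.375, 0.125], which sum to 1.0.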

If clip_inputs is True, inputs outside the lattice input range ([0, lattice_sizes[i] - 1] for dimension i) are clipped into that range. Otherwise, the corresponding weights decrease linearly toward 0.0 as the input moves away from the valid input range.
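The effect of clip_inputs is easiest to see in one dimension. The sketch below is hypothetical (`one_d_weights` is not a library function) and assumes hat-function weights max(0, 1 - |x - k|), which fall off linearly in the way described above.

```python
def one_d_weights(x, size, clip_inputs):
    """Hypothetical 1-D sketch: weights of x on vertices 0..size-1."""
    if clip_inputs:
        # Project x onto the valid input range [0, size - 1].
        x = min(max(x, 0.0), size - 1.0)
    # Hat function: a vertex's weight drops linearly from 1 at the vertex
    # to 0 at distance 1 and never goes negative.
    return [max(0.0, 1.0 - abs(x - k)) for k in range(size)]


# Lattice size 3, so the valid input range is [0, 2].
for x in (1.5, 2.5, 3.0):
    print(x, one_d_weights(x, 3, True), one_d_weights(x, 3, False))
```

For x = 2.5 the clipped weights are [0.0, 0.0, 1.0] and the unclipped weights are [0.0, 0.0, 0.5]; at x = 3.0 the unclipped weights have all reached 0.0.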

Args:
  inputs: Tensor of shape (batch_size, ..., len(lattice_sizes)), or a list of len(lattice_sizes) tensors that all have shape (batch_size, ..., 1), representing the points to apply lattice interpolation to. A typical shape is (batch_size, len(lattice_sizes)).
  lattice_sizes: List or tuple of integers giving the lattice sizes of the layer for which interpolation is being computed.
  clip_inputs: Whether inputs should be clipped to the input range of the lattice.

Raises:
  ValueError: If the last dimension of inputs does not equal len(lattice_sizes).

Returns:
  Interpolation weights tensor of shape (batch_size, ..., prod(lattice_sizes)).
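The shape contract can be checked without TensorFlow. The hypothetical helper below works on plain shape tuples and accepts either input form; only the last-dimension ValueError is documented here, so the checks for the list form are an assumption.

```python
import math
from typing import Sequence, Tuple, Union

Shape = Tuple[int, ...]


def interpolation_output_shape(
    inputs_shape: Union[Shape, Sequence[Shape]],
    lattice_sizes: Sequence[int],
) -> Shape:
    """Hypothetical helper: expected weights shape for the given input shape(s).

    inputs_shape is either one shape (batch_size, ..., len(lattice_sizes)) or
    a list of len(lattice_sizes) shapes, each (batch_size, ..., 1).
    """
    if isinstance(inputs_shape, list):
        # List form (assumed checks): one tensor per lattice dimension,
        # all sharing a shape whose last dimension is 1.
        if len(inputs_shape) != len(lattice_sizes):
            raise ValueError("expected one input tensor per lattice dimension")
        if any(s != inputs_shape[0] or s[-1] != 1 for s in inputs_shape):
            raise ValueError("list inputs must share one shape ending in 1")
        leading = tuple(inputs_shape[0][:-1])
    else:
        # Single-tensor form: the last dimension holds all coordinates.
        if inputs_shape[-1] != len(lattice_sizes):
            raise ValueError("last dimension of inputs must equal len(lattice_sizes)")
        leading = tuple(inputs_shape[:-1])
    # One weight per lattice vertex replaces the coordinate dimension.
    return leading + (math.prod(lattice_sizes),)
```

For example, inputs of shape (32, 5, 3) with lattice_sizes [4, 4, 4] yield weights of shape (32, 5, 64).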