Computes weights for lattice interpolation.

    inputs, lattice_sizes, clip_inputs=True

Running time: O(batch_size * prod(lattice_sizes))

If clip_inputs is True, inputs outside the range defined by lattice_sizes, i.e. [0, lattice_sizes[i] - 1] for dimension i, are clipped into the lattice input range. Otherwise, the corresponding weights decrease linearly toward 0.0 as the input moves further outside the valid range.
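
The clipping behavior can be illustrated for a single lattice dimension. The pure-Python sketch below assumes a triangular ("hat") basis, in which vertex v receives weight max(0, 1 - |x - v|). This reproduces the behavior described above: a clipped input snaps to the boundary vertex, and an unclipped one loses weight linearly. It is an illustration only, not the library's implementation, and the helper name `dimension_weights` is hypothetical.

```python
# Hypothetical helper (not part of any library API): weights over the
# vertices 0..lattice_size-1 of ONE lattice dimension, assuming a hat basis.
def dimension_weights(x, lattice_size, clip_inputs=True):
    """Return one weight per vertex for a single input coordinate x."""
    if clip_inputs:
        # Clip into the valid lattice input range [0, lattice_size - 1].
        x = min(max(x, 0.0), lattice_size - 1.0)
    # Vertex v gets max(0, 1 - |x - v|): at most two neighbouring vertices
    # are non-zero, and for x inside the range the weights sum to 1.
    return [max(0.0, 1.0 - abs(x - v)) for v in range(lattice_size)]


print(dimension_weights(1.25, 3))                     # [0.0, 0.75, 0.25]
print(dimension_weights(-0.5, 3))                     # [1.0, 0.0, 0.0]
print(dimension_weights(-0.5, 3, clip_inputs=False))  # [0.5, 0.0, 0.0]
```

Without clipping, the input -0.5 leaves vertex 0 with only 0.5 of the weight, and at -1.0 or below every weight reaches 0.0.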

Args:
  • inputs: Tensor of shape (batch_size, ..., len(lattice_sizes)), or a list of len(lattice_sizes) tensors, each of shape (batch_size, ..., 1), representing the points at which to apply lattice interpolation. A typical shape is (batch_size, len(lattice_sizes)).
  • lattice_sizes: List or tuple of integers giving the lattice sizes of the layer for which interpolation is computed.
  • clip_inputs: Whether inputs should be clipped to the input range of the lattice.

Raises:
  • ValueError: If the last dimension of inputs does not match len(lattice_sizes).
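
The two accepted input forms and the shape check can be sketched with nested Python lists standing in for tensors. In this sketch, the function `to_rows` and its `per_dimension` flag are hypothetical. The flag marks the list-of-columns form explicitly, whereas the real function presumably infers the form from the argument type. Only the 2-D case (batch_size, len(lattice_sizes)) is handled.

```python
# Hypothetical sketch: nested lists stand in for tensors; only 2-D inputs.
def to_rows(inputs, lattice_sizes, per_dimension=False):
    """Return inputs as a list of points, each of length len(lattice_sizes).

    per_dimension=False: inputs has shape (batch_size, len(lattice_sizes)).
    per_dimension=True:  inputs is a list of len(lattice_sizes) columns,
                         each of shape (batch_size, 1).
    """
    if per_dimension:
        if len(inputs) != len(lattice_sizes):
            raise ValueError(
                f"Expected {len(lattice_sizes)} input columns, got {len(inputs)}")
        batch_size = len(inputs[0]) if inputs else 0
        # Concatenate the (batch_size, 1) columns along the last axis.
        rows = [[col[b][0] for col in inputs] for b in range(batch_size)]
    else:
        rows = [list(row) for row in inputs]
    for row in rows:
        # The last dimension must equal the number of lattice dimensions.
        if len(row) != len(lattice_sizes):
            raise ValueError(
                f"Last dimension of inputs is {len(row)}, "
                f"expected len(lattice_sizes) = {len(lattice_sizes)}")
    return rows


print(to_rows([[[0.5], [1.5]], [[1.0], [0.0]]], [2, 3], per_dimension=True))
# [[0.5, 1.0], [1.5, 0.0]]
```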

Returns:
Interpolation weights tensor of shape (batch_size, ..., prod(lattice_sizes)).
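
The full weight vector for one point can be sketched in pure Python. This minimal version assumes standard multilinear interpolation, where the weight of each lattice vertex is the product of per-dimension hat-basis weights, flattened into a single vector of length prod(lattice_sizes). The function names `interpolation_weights` and `batch_weights` are hypothetical. The row-major vertex order (last dimension varying fastest) is also an assumption, and the library may order vertices differently.

```python
import itertools
import math


def interpolation_weights(point, lattice_sizes, clip_inputs=True):
    """Hypothetical sketch: weights over all prod(lattice_sizes) vertices."""
    if len(point) != len(lattice_sizes):
        raise ValueError("len(point) must equal len(lattice_sizes)")
    per_dim = []
    for x, size in zip(point, lattice_sizes):
        if clip_inputs:
            x = min(max(x, 0.0), size - 1.0)
        # Assumed hat basis: vertex v of this dimension gets max(0, 1 - |x - v|).
        per_dim.append([max(0.0, 1.0 - abs(x - v)) for v in range(size)])
    # Outer product of the per-dimension weights, flattened row-major
    # (assumed order: last dimension varies fastest).
    return [math.prod(ws) for ws in itertools.product(*per_dim)]


def batch_weights(points, lattice_sizes, clip_inputs=True):
    # prod(lattice_sizes) products per point, so the total work is
    # O(batch_size * prod(lattice_sizes)).
    return [interpolation_weights(p, lattice_sizes, clip_inputs) for p in points]


print(interpolation_weights([0.25, 1.5], [2, 3]))
# [0.0, 0.375, 0.375, 0.0, 0.125, 0.125]
```

For in-range inputs the weights sum to 1, and interpolating vertex values that are a linear function of the vertex coordinates reproduces that function exactly at the input point. Because every point yields prod(lattice_sizes) weights, this sketch has the same O(batch_size * prod(lattice_sizes)) cost as the running time stated above.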