tfl.lattice_lib.compute_interpolation_weights


Computes weights for lattice interpolation.

tfl.lattice_lib.compute_interpolation_weights(
    inputs,
    lattice_sizes,
    clip_inputs=True
)

Running time: O(batch_size * prod(lattice_sizes))
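The bound is the size of the output: every input point gets one weight per lattice vertex, and a lattice has prod(lattice_sizes) vertices. A quick arithmetic check, using illustrative sizes that are not from the original:

```python
import math

# Illustrative values only; any lattice_sizes / batch_size work the same way.
lattice_sizes = [2, 3, 4]
batch_size = 1000

# Each input point receives one weight per lattice vertex.
weights_per_point = math.prod(lattice_sizes)
total_weights = batch_size * weights_per_point

print(weights_per_point)  # 24
print(total_weights)      # 24000
```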

If clip_inputs is True, inputs outside the range defined by lattice_sizes (that is, [0, lattice_sizes[i] - 1] along dimension i) are clipped into the lattice input range. Otherwise, the corresponding weights decrease linearly toward 0.0 as an input moves further outside the valid range.
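The clipping behavior can be illustrated along a single dimension. This is a minimal sketch that assumes each vertex k receives the hat-function weight max(0, 1 - |x - k|); that form is an assumption chosen because it reproduces the described behavior, not a copy of the library's code. The helper `one_d_weights` is hypothetical.

```python
def one_d_weights(x, size, clip_inputs=True):
    """Hypothetical helper: weight of each vertex 0..size-1 for a scalar input x.

    Assumes the hat-function form max(0, 1 - |x - k|) for vertex k.
    """
    if clip_inputs:
        # Clip into the lattice input range [0, size - 1].
        x = min(max(x, 0.0), float(size - 1))
    return [max(0.0, 1.0 - abs(x - k)) for k in range(size)]


print(one_d_weights(1.25, 3))                    # [0.0, 0.75, 0.25]
print(one_d_weights(2.5, 3))                     # [0.0, 0.0, 1.0]
print(one_d_weights(2.5, 3, clip_inputs=False))  # [0.0, 0.0, 0.5]
```

With clipping, the out-of-range input 2.5 behaves like the boundary value 2.0. Without clipping, the weight of the nearest vertex shrinks linearly and reaches 0.0 once the input is a full unit past the boundary.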

Args:

  • inputs: A tensor of shape (batch_size, ..., len(lattice_sizes)), or a list of len(lattice_sizes) tensors, each of shape (batch_size, ..., 1), representing the points at which to compute lattice interpolation. A typical shape is (batch_size, len(lattice_sizes)).
  • lattice_sizes: List or tuple of integers giving the lattice size along each dimension of the layer for which interpolation is being computed.
  • clip_inputs: Whether inputs should be clipped to the input range of the lattice.
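The two accepted forms of inputs describe the same points. In TensorFlow, a list of (batch_size, 1) tensors can be joined into one (batch_size, len(lattice_sizes)) tensor with tf.concat(columns, axis=-1). The sketch below shows that correspondence with nested Python lists standing in for tensors; the helper `stack_columns` and the sample values are hypothetical.

```python
# Form 1: one "tensor" of shape (batch_size, len(lattice_sizes)).
stacked = [[0.5, 1.0], [1.5, 2.0]]  # batch_size=2, two lattice dimensions

# Form 2: a list of len(lattice_sizes) "tensors", each of shape (batch_size, 1).
columns = [[[0.5], [1.5]], [[1.0], [2.0]]]


def stack_columns(cols):
    """Hypothetical helper: concatenate (batch_size, 1) columns along the last axis."""
    batch_size = len(cols[0])
    return [[col[b][0] for col in cols] for b in range(batch_size)]


print(stack_columns(columns) == stacked)  # True
```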

Raises:

  • ValueError: If the last dimension of inputs does not match the length of lattice_sizes.

Returns:

A tensor of interpolation weights with shape (batch_size, ..., prod(lattice_sizes)).
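A minimal pure-Python sketch of how a weights tensor of this shape can arise, assuming standard multilinear interpolation: each vertex weight is the product of per-dimension hat-function weights. The function names, the vertex enumeration order (itertools.product order, last dimension fastest), and the error message are assumptions of this sketch. The library operates on TensorFlow tensors and may order vertices differently.

```python
import itertools
import math


def one_d_weights(x, size, clip_inputs=True):
    """Assumed per-dimension form: max(0, 1 - |x - k|) for vertices k = 0..size-1."""
    if clip_inputs:
        x = min(max(x, 0.0), float(size - 1))
    return [max(0.0, 1.0 - abs(x - k)) for k in range(size)]


def interpolation_weights(inputs, lattice_sizes, clip_inputs=True):
    """Hypothetical sketch of multilinear interpolation weights.

    inputs: list of points, each a list of len(lattice_sizes) floats
            (the typical (batch_size, len(lattice_sizes)) case).
    Returns one list of prod(lattice_sizes) weights per point.
    """
    weights = []
    for point in inputs:
        if len(point) != len(lattice_sizes):
            raise ValueError(
                "last dimension of inputs (%d) does not match len(lattice_sizes) (%d)"
                % (len(point), len(lattice_sizes)))
        per_dim = [one_d_weights(x, s, clip_inputs)
                   for x, s in zip(point, lattice_sizes)]
        # A vertex's weight is the product of its per-dimension weights.
        weights.append([math.prod(combo) for combo in itertools.product(*per_dim)])
    return weights


w = interpolation_weights([[0.25, 1.5]], lattice_sizes=[2, 3])
print(len(w[0]))  # 6
print(w[0])       # [0.0, 0.375, 0.375, 0.0, 0.125, 0.125]
```

Each point yields prod(lattice_sizes) weights, which matches the Returns shape. For inputs inside the lattice range, the weights sum to 1.0. With clip_inputs=False and out-of-range inputs, they sum to less than 1.0.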