Computes weights for lattice interpolation.
tfl.lattice_lib.compute_interpolation_weights(inputs, lattice_sizes, clip_inputs=True)
Running time: O(batch_size * prod(lattice_sizes)).
If clip_inputs == True, inputs outside the range defined by lattice_sizes (that is, [0, lattice_sizes[i] - 1] in dimension i) are clipped into the lattice input range. Otherwise, the corresponding weights approach 0.0 linearly as the input moves away from the valid input range.
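The clipping behavior can be illustrated along a single dimension. The sketch below assumes the standard piecewise-linear "hat" weight max(0, 1 - |x - k|) for vertex k. The helper `one_d_weights` is hypothetical and is not part of `tfl`. It is an illustration of the technique, not the library's implementation.

```python
def one_d_weights(x, size, clip_inputs=True):
    """Hat-function weights of scalar x against vertices 0..size-1.

    Hypothetical illustrative helper, not the tfl implementation.
    """
    if clip_inputs:
        # Clip x into the lattice input range [0, size - 1].
        x = min(max(x, 0.0), size - 1.0)
    # Vertex k gets weight max(0, 1 - |x - k|), a piecewise-linear "hat".
    return [max(0.0, 1.0 - abs(x - k)) for k in range(size)]


# x = -0.5 lies outside the input range [0, 2] of a size-3 dimension.
print(one_d_weights(-0.5, 3, clip_inputs=True))   # [1.0, 0.0, 0.0]
print(one_d_weights(-0.5, 3, clip_inputs=False))  # [0.5, 0.0, 0.0]
print(one_d_weights(-1.0, 3, clip_inputs=False))  # [0.0, 0.0, 0.0]
```

Without clipping, the weight of the nearest vertex falls from 1.0 toward 0.0 as x moves from 0 down to -1. With clipping, any out-of-range x behaves like the nearest boundary point.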
Args:
inputs: Tensor of shape (batch_size, ..., len(lattice_sizes)), or a list of len(lattice_sizes) tensors of the same shape (batch_size, ..., 1), representing the points to apply lattice interpolation to. A typical shape is (batch_size, len(lattice_sizes)).
lattice_sizes: List or tuple of integers representing the lattice sizes of the layer for which interpolation is being computed.
clip_inputs: Whether inputs should be clipped to the input range of the lattice.
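The two accepted layouts for `inputs` hold the same data. The sketch below converts between them using plain nested Python lists in place of tensors. The helper names `to_per_dimension_list` and `to_stacked` are hypothetical. With real tensors, `tf.split(inputs, len(lattice_sizes), axis=-1)` produces the list form.

```python
def to_per_dimension_list(batch):
    """Split a (batch_size, d) nested list into d lists of shape (batch_size, 1)."""
    num_dims = len(batch[0])
    return [[[row[i]] for row in batch] for i in range(num_dims)]


def to_stacked(per_dim):
    """Inverse: d lists of shape (batch_size, 1) -> one (batch_size, d) list."""
    batch_size = len(per_dim[0])
    return [[dim[b][0] for dim in per_dim] for b in range(batch_size)]


stacked = [[0.5, 1.25], [1.0, 0.0]]  # batch_size=2, len(lattice_sizes)=2
per_dim = to_per_dimension_list(stacked)
print(per_dim)  # [[[0.5], [1.0]], [[1.25], [0.0]]]
```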
Raises:
ValueError: If the last dimension of inputs does not match len(lattice_sizes).
Returns:
Interpolation weights tensor of shape (batch_size, ..., prod(lattice_sizes)).
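The following is a minimal pure-Python sketch of multilinear lattice interpolation weights. It shows why the output has prod(lattice_sizes) entries per point. It rests on several assumptions:

* Each vertex weight is the product of per-dimension hat weights.
* Inputs are nested lists of shape (batch_size, len(lattice_sizes)).
* Vertices are flattened in row-major order.

The library's actual vertex ordering and tensor handling may differ. The function name `compute_interpolation_weights_sketch` is hypothetical.

```python
import itertools
import math


def compute_interpolation_weights_sketch(inputs, lattice_sizes, clip_inputs=True):
    """Hypothetical sketch: multilinear weights for a batch of points.

    inputs: list of points, each a list of len(lattice_sizes) floats.
    Returns one list of prod(lattice_sizes) weights per point, with vertices
    in row-major order (last dimension varies fastest).
    """
    weights = []
    for point in inputs:
        # Mirrors the documented ValueError on a dimension mismatch.
        if len(point) != len(lattice_sizes):
            raise ValueError(
                f"Expected {len(lattice_sizes)} coordinates, got {len(point)}")
        per_dim = []
        for x, size in zip(point, lattice_sizes):
            if clip_inputs:
                # Clip into the input range [0, size - 1].
                x = min(max(x, 0.0), size - 1.0)
            # 1-D hat weights against vertices 0..size-1.
            per_dim.append([max(0.0, 1.0 - abs(x - k)) for k in range(size)])
        # Vertex weight = product of its 1-D weights across all dimensions.
        weights.append([math.prod(ws) for ws in itertools.product(*per_dim)])
    return weights


w = compute_interpolation_weights_sketch([[0.5, 1.25]], lattice_sizes=[2, 3])
print(len(w[0]))  # 6, i.e. prod(lattice_sizes)
print(w[0])       # [0.0, 0.375, 0.125, 0.0, 0.375, 0.125]
```

Each point produces one weight per lattice vertex, so the work grows as batch_size * prod(lattice_sizes). For in-range or clipped inputs, the weights sum to 1.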