Computes weights for hypercube lattice interpolation.
tfl.lattice_lib.compute_interpolation_weights(
    inputs, lattice_sizes, clip_inputs=True
)
Running time: O(batch_size * prod(lattice_sizes))
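To make the cost concrete, here is a minimal pure-Python sketch of hypercube (multilinear) interpolation weights for a single point. It assumes the lattice vertices sit at integer coordinates 0, 1, ..., size - 1 in each dimension and that each per-dimension weight is 1 - min(|x - k|, 1). The helper names and the vertex ordering (last dimension varying fastest, via itertools.product) are illustrative assumptions, not the library's internals. Because every vertex receives one weight, each point produces prod(lattice_sizes) values, which is where the prod(lattice_sizes) factor in the running time comes from.

```python
import itertools
import math


def hat_weights_1d(x, size):
    """Assumed per-dimension weights: 1 - min(|x - k|, 1) for keypoints k = 0..size-1."""
    return [1.0 - min(abs(x - k), 1.0) for k in range(size)]


def interpolation_weights(point, lattice_sizes):
    """Hypothetical helper: one weight per lattice vertex for a single input point.

    Vertices are enumerated with itertools.product, so the LAST dimension varies
    fastest. This ordering is an assumption and may differ from the library's layout.
    """
    if len(point) != len(lattice_sizes):
        raise ValueError("point must have len(lattice_sizes) coordinates")
    per_dim = [hat_weights_1d(x, size) for x, size in zip(point, lattice_sizes)]
    # Outer product across dimensions: prod(lattice_sizes) weights in total.
    return [math.prod(ws) for ws in itertools.product(*per_dim)]


weights = interpolation_weights([0.25, 1.5], lattice_sizes=[2, 3])
print(len(weights), weights)
```

For the point [0.25, 1.5] on a 2 x 3 lattice, only the four vertices surrounding the point receive non-zero weight, and the weights sum to 1.0.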
If clip_inputs is True, inputs outside the range defined by lattice_sizes are clipped into the lattice input range. Otherwise, the corresponding weights decrease linearly toward 0.0 as the input moves away from the valid input range.
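The two modes can be compared in one dimension with the sketch below. It assumes keypoints at 0, ..., size - 1, a per-keypoint weight of 1 - min(|x - k|, 1), and that clipping clamps x to [0, size - 1]; hat_weights_1d is a hypothetical helper, not part of tfl.lattice_lib.

```python
def hat_weights_1d(x, size, clip_inputs=True):
    """Hypothetical 1-D interpolation weights over assumed keypoints 0, 1, ..., size - 1."""
    if clip_inputs:
        # Clamp into the assumed valid range [0, size - 1].
        x = min(max(x, 0.0), size - 1.0)
    # Each keypoint's weight falls linearly from 1.0 at distance 0 to 0.0 at distance >= 1.
    return [1.0 - min(abs(x - k), 1.0) for k in range(size)]


for x in (2.0, 2.25, 2.5, 3.0):
    clipped = hat_weights_1d(x, 3, clip_inputs=True)
    unclipped = hat_weights_1d(x, 3, clip_inputs=False)
    print(f"x={x}: clipped={clipped} unclipped={unclipped}")
```

On a lattice of size 3, an input of 2.5 is clamped to 2.0 and puts full weight on the last keypoint when clipping is on; without clipping that weight falls to 0.5, and it reaches 0.0 once the input is a full unit past the range.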
Args:
  inputs: Tensor of shape (batch_size, ..., len(lattice_sizes)), or a list of
    len(lattice_sizes) tensors of the same shape (batch_size, ..., 1),
    representing the points to apply lattice interpolation to. A typical shape
    is (batch_size, len(lattice_sizes)).
  lattice_sizes: List or tuple of integers representing the lattice sizes of
    the layer for which interpolation is being computed.
  clip_inputs: Whether inputs should be clipped to the input range of the
    lattice.
Raises:
  ValueError: If the last dimension of inputs does not match
    len(lattice_sizes).
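The documented check can be sketched as a hypothetical validation helper that works on shapes rather than tensors. It accepts either a single shape tuple (the one-tensor form) or a list of shape tuples (the list-of-tensors form) and raises ValueError when the number of input dimensions differs from len(lattice_sizes). check_inputs_shape is an assumed name, and the exact checks and messages in tfl.lattice_lib may differ.

```python
from typing import Sequence, Tuple, Union

Shape = Tuple[int, ...]


def check_inputs_shape(
    input_shape: Union[Shape, Sequence[Shape]], lattice_sizes: Sequence[int]
) -> None:
    """Hypothetical helper: raise ValueError if inputs lack len(lattice_sizes) dimensions."""
    if isinstance(input_shape, tuple):
        # Single tensor of shape (batch_size, ..., len(lattice_sizes)).
        num_dims = input_shape[-1]
    else:
        # List of len(lattice_sizes) tensors, each of shape (batch_size, ..., 1).
        num_dims = len(input_shape)
    if num_dims != len(lattice_sizes):
        raise ValueError(
            f"Expected {len(lattice_sizes)} input dimensions to match "
            f"lattice_sizes, got {num_dims}."
        )


check_inputs_shape((32, 3), [2, 3, 4])          # single-tensor form: passes
check_inputs_shape([(32, 1), (32, 1)], [2, 2])  # list form: passes
try:
    check_inputs_shape((32, 2), [2, 3, 4])
except ValueError as err:
    print(err)
```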
Returns:
  Interpolation weights tensor of shape (batch_size, ..., prod(lattice_sizes)).
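Interpolation weights like these are typically consumed by taking a weighted sum with one parameter per lattice vertex. The sketch below assumes that use. It recomputes the weights with the same assumed per-dimension rule (1 - min(|x - k|, 1) over keypoints 0, ..., size - 1) as a self-contained helper and lists the parameters in the same assumed vertex order (itertools.product, last dimension fastest). lattice_output and the parameter layout are hypothetical, not the tfl.layers.Lattice implementation.

```python
import itertools
import math


def interpolation_weights(point, lattice_sizes):
    """Multilinear weights for one point, one per vertex (assumed ordering)."""
    per_dim = [
        [1.0 - min(abs(x - k), 1.0) for k in range(size)]
        for x, size in zip(point, lattice_sizes)
    ]
    return [math.prod(ws) for ws in itertools.product(*per_dim)]


def lattice_output(point, lattice_sizes, params):
    """Hypothetical lattice evaluation: weighted sum of per-vertex parameters."""
    weights = interpolation_weights(point, lattice_sizes)
    if len(params) != len(weights):
        raise ValueError("params must have prod(lattice_sizes) entries")
    return sum(w * p for w, p in zip(weights, params))


# 2 x 2 lattice; params listed for vertices (0,0), (0,1), (1,0), (1,1).
params = [0.0, 1.0, 2.0, 3.0]
print(lattice_output([0.5, 0.5], [2, 2], params))  # 1.5
print(lattice_output([1.0, 0.0], [2, 2], params))  # 2.0
```

Because the weights of an in-range point sum to 1.0, the output is a convex combination of the surrounding vertex parameters: at a vertex it returns that vertex's parameter, and at the center of a 2 x 2 lattice it returns their average.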