Evaluates a lattice using simplex interpolation.
```python
tfl.lattice_lib.evaluate_with_simplex_interpolation(
    inputs, kernel, units, lattice_sizes, clip_inputs
)
```
Each cell of the lattice is a d-dimensional hypercube, where d = len(lattice_sizes). Within each cell, we partition the hypercube into d! simplices, each with d+1 vertices. Relative to the lower corner of the hypercube, each simplex includes the all-zeros vertex, a vertex with a single one, a vertex with two ones, ..., and the all-ones vertex. For example, the 3! = 6 simplices of a three-dimensional unit hypercube are:
- `[0,0,0], [0,0,1], [0,1,1], [1,1,1]`
- `[0,0,0], [0,0,1], [1,0,1], [1,1,1]`
- `[0,0,0], [0,1,0], [0,1,1], [1,1,1]`
- `[0,0,0], [0,1,0], [1,1,0], [1,1,1]`
- `[0,0,0], [1,0,0], [1,1,0], [1,1,1]`
- `[0,0,0], [1,0,0], [1,0,1], [1,1,1]`
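The six simplices above can be generated mechanically: each permutation of the axes gives one simplex, which starts at the all-zeros vertex and turns coordinates on one at a time in that order. The sketch below uses only the Python standard library; the helper name `hypercube_simplices` is hypothetical and not part of TF Lattice.

```python
from itertools import permutations


def hypercube_simplices(d):
    """Enumerate the d! simplices of the unit d-cube, one per axis ordering.

    Hypothetical helper for illustration. Each simplex starts at the
    all-zeros vertex and sets one more coordinate to 1 at each step,
    following the permutation, until it reaches the all-ones vertex.
    """
    simplices = []
    for order in permutations(range(d)):
        vertex = [0] * d
        simplex = [tuple(vertex)]
        for axis in order:
            vertex[axis] = 1  # turn on the next axis in this ordering
            simplex.append(tuple(vertex))
        simplices.append(simplex)
    return simplices


# Print the 3! = 6 simplices of the unit cube (one line per simplex).
for simplex in hypercube_simplices(3):
    print([list(v) for v in simplex])
```

The enumeration order differs from the list above, but the set of simplices is the same.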
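A point x in the hypercube is contained in the simplex corresponding to the decreasing order of x's components. For example, x = [0.4, 0.2, 0.8] is contained in the simplex specified by the axis order (2, 0, 1), since x[2] > x[0] > x[1]; this is the simplex [0,0,0], [0,0,1], [1,0,1], [1,1,1]. The weight associated with each vertex in the simplex is the difference between consecutive decreasingly sorted coordinates of the input, with 1 prepended and 0 appended: here [0,0,0] gets 1 - 0.8 = 0.2, [0,0,1] gets 0.8 - 0.4 = 0.4, [1,0,1] gets 0.4 - 0.2 = 0.2, and [1,1,1] gets 0.2 - 0 = 0.2. For details, see e.g. "Dissection of the hypercube into simplices", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979.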
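To make the point-location rule and the weights concrete, here is a minimal plain-Python sketch for a single point in the unit hypercube. `simplex_weights` is a hypothetical helper written for illustration, not a TF Lattice function, and it assumes the input already lies in [0, 1]^d.

```python
def simplex_weights(x):
    """Locate x in [0, 1]^d within its simplex and compute vertex weights.

    Hypothetical helper for illustration. Returns (order, vertices, weights):
    order    -- axis indices sorted by decreasing coordinate value
    vertices -- the d + 1 simplex vertices, from all-zeros to all-ones
    weights  -- differences of consecutive sorted coordinates, padded with
                1 in front and 0 at the end, so they sum to 1
    """
    d = len(x)
    # Axis indices ordered by decreasing coordinate (stable for ties).
    order = sorted(range(d), key=lambda i: -x[i])
    padded = [1.0] + [x[i] for i in order] + [0.0]
    # Walk from the all-zeros vertex, turning on axes in sorted order.
    vertex = [0] * d
    vertices = [tuple(vertex)]
    for axis in order:
        vertex[axis] = 1
        vertices.append(tuple(vertex))
    weights = [padded[k] - padded[k + 1] for k in range(d + 1)]
    return order, vertices, weights


order, vertices, weights = simplex_weights([0.4, 0.2, 0.8])
print(order)                            # → [2, 0, 1]
print(vertices)                         # → [(0, 0, 0), (0, 0, 1), (1, 0, 1), (1, 1, 1)]
print([round(w, 6) for w in weights])   # → [0.2, 0.4, 0.2, 0.2]
```

Ties between coordinates produce zero-weight vertices, so any consistent tie-breaking (here, Python's stable sort) gives the same interpolated value.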
Args | |
---|---|
`inputs` | Tensor of shape `(batch_size, ..., len(lattice_sizes))`, or a list of `len(lattice_sizes)` tensors, each of shape `(batch_size, ..., 1)`, representing the points at which to evaluate the lattice interpolation. A typical shape is `(batch_size, len(lattice_sizes))`. |
`kernel` | Lattice kernel of shape `(num_params_per_lattice, units)`, where `num_params_per_lattice` is the product of `lattice_sizes`. |
`units` | Output dimension of the lattice. |
`lattice_sizes` | List or tuple of integers giving the lattice sizes of the layer for which interpolation is being computed. |
`clip_inputs` | Whether to clip inputs to the input range of the lattice, `[0, lattice_sizes[i] - 1]` in each dimension `i`. |
Returns | |
---|---|
Tensor of shape `(batch_size, ..., units)`. | |
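Putting the pieces together for a full lattice, the sketch below evaluates a batch of points, with pure-Python lists standing in for tensors. For each point it optionally clips every coordinate to `[0, size - 1]`, finds the cell containing the point, and applies the simplex weights to the `d + 1` kernel rows of that cell's simplex for each of the `units` outputs. Several details are assumptions rather than TF Lattice's documented behavior: the function name `simplex_lattice_eval` is hypothetical, inputs are restricted to the typical `(batch_size, len(lattice_sizes))` shape, and the kernel rows are assumed to be laid out with the first lattice dimension varying fastest (TF Lattice's actual parameter ordering may differ).

```python
import math


def simplex_lattice_eval(inputs, kernel, units, lattice_sizes, clip_inputs=True):
    """Hypothetical pure-Python simplex interpolation over a full lattice.

    inputs: list of points, each a list of len(lattice_sizes) floats.
    kernel: list of prod(lattice_sizes) rows, each a list of `units` floats.
    ASSUMPTION: kernel rows are ordered with dimension 0 varying fastest.
    """
    d = len(lattice_sizes)
    if any(size < 2 for size in lattice_sizes):
        raise ValueError("each lattice size must be at least 2")
    # Stride of each dimension in the assumed flattened kernel layout.
    strides, num_params = [], 1
    for size in lattice_sizes:
        strides.append(num_params)
        num_params *= size
    if len(kernel) != num_params:
        raise ValueError("kernel must have prod(lattice_sizes) rows")

    outputs = []
    for point in inputs:
        lower, frac = [], []
        for x, size in zip(point, lattice_sizes):
            if clip_inputs:
                x = min(max(x, 0.0), size - 1.0)  # input range is [0, size - 1]
            # Lower corner of the cell; clamped so unclipped inputs extrapolate.
            base = min(max(math.floor(x), 0), size - 2)
            lower.append(base)
            frac.append(x - base)  # position inside the cell
        # Visit the cell's axes in decreasing order of fractional coordinate.
        order = sorted(range(d), key=lambda i: -frac[i])
        padded = [1.0] + [frac[i] for i in order] + [0.0]
        # Flat kernel index of the cell's all-zeros vertex.
        index = sum(b * s for b, s in zip(lower, strides))
        result = [0.0] * units
        for k in range(d + 1):
            weight = padded[k] - padded[k + 1]
            for u in range(units):
                result[u] += weight * kernel[index][u]
            if k < d:
                index += strides[order[k]]  # step to the next simplex vertex
        outputs.append(result)
    return outputs


# Usage: a 3x2 lattice with one output unit whose vertex values follow the
# linear function f(i, j) = 2i + 3j + 1, which simplex interpolation
# reproduces exactly. Row i + 3j holds vertex (i, j) under the assumed layout.
sizes = [3, 2]
kernel = [None] * (3 * 2)
for i in range(3):
    for j in range(2):
        kernel[i + 3 * j] = [2.0 * i + 3.0 * j + 1.0]
print(simplex_lattice_eval([[1.5, 0.25], [5.0, -1.0]], kernel, 1, sizes))
# → [[4.75], [5.0]]
```

The second point is clipped to the lattice vertex (2, 0). Because only d + 1 of a cell's 2^d vertices contribute, each evaluation needs one sort of d coordinates plus O(d · units) multiply-adds, instead of touching all 2^d vertices as full multilinear interpolation would.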