tfl.lattice_lib.evaluate_with_simplex_interpolation

Evaluates a lattice using simplex interpolation.

Within each cell of the lattice, we partition the d-dimensional hypercube into d! simplices, where d = len(lattice_sizes) and each simplex has d+1 vertices. Each simplex (relative to the lower corner of the hypercube) includes the all-zeros vertex, a vertex with a single one, a vertex with two ones, ... and the all-ones vertex, with each vertex obtained from the previous one by setting one more coordinate to one. For example, for a three-dimensional unit hypercube the 3! = 6 simplices are:

[0,0,0], [0,0,1], [0,1,1], [1,1,1]
[0,0,0], [0,0,1], [1,0,1], [1,1,1]
[0,0,0], [0,1,0], [0,1,1], [1,1,1]
[0,0,0], [0,1,0], [1,1,0], [1,1,1]
[0,0,0], [1,0,0], [1,1,0], [1,1,1]
[0,0,0], [1,0,0], [1,0,1], [1,1,1]
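
The enumeration above can be reproduced with a few lines of plain Python. This is only an illustrative sketch, not the library's implementation; `hypercube_simplices` is a hypothetical helper name. Each permutation of the axes gives one simplex: start at the all-zeros vertex and set one coordinate to one at a time, in the order given by the permutation.

```python
from itertools import permutations


def hypercube_simplices(d):
    """Return the d! simplices of the unit d-cube as lists of d+1 vertices."""
    simplices = []
    for order in permutations(range(d)):
        vertex = [0] * d
        chain = [tuple(vertex)]          # every simplex starts at all zeros
        for axis in order:
            vertex[axis] = 1             # switch on one more coordinate
            chain.append(tuple(vertex))
        simplices.append(chain)          # chain ends at the all-ones vertex
    return simplices


simplices = hypercube_simplices(3)
for simplex in simplices:
    print(simplex)
```

The six printed simplices are the same as those listed above, though not necessarily in the same order.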

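A point x in the hypercube is contained in the simplex corresponding to the decreasing order of x's components. For example, x = [0.4,0.2,0.8] has its components in decreasing order at indices 2,0,1 (0.8 >= 0.4 >= 0.2), so it is contained in the simplex specified by 2,0,1, that is [0,0,0], [0,0,1], [1,0,1], [1,1,1]. The weight associated with each vertex in the simplex is the difference between consecutive entries of the input's coordinates sorted in decreasing order, with 1 prepended and 0 appended. For the example above the sequence is 1, 0.8, 0.4, 0.2, 0, which gives the weights 0.2, 0.4, 0.2, 0.2. For details, see e.g. "Dissection of the hypercube into simplices", D.G. Mead, Proceedings of the AMS, 76:2, Sep. 1979.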
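
The weight rule can be checked numerically. The NumPy sketch below handles a single point inside one unit hypercube and is not the library's batched TensorFlow implementation; `simplex_weights` is a hypothetical name introduced for illustration.

```python
import numpy as np


def simplex_weights(x):
    """Return (order, vertices, weights) for a point x in the unit hypercube."""
    x = np.asarray(x, dtype=float)
    d = len(x)
    # Indices of the components sorted from largest to smallest.
    order = np.argsort(-x, kind="stable")
    # Pad the sorted coordinates with 1 in front and 0 at the end, then take
    # consecutive differences: one weight per simplex vertex (d + 1 in total).
    padded = np.concatenate(([1.0], x[order], [0.0]))
    weights = padded[:-1] - padded[1:]
    # Vertex k has ones on the first k axes of `order`.
    vertices = np.zeros((d + 1, d), dtype=int)
    for k, axis in enumerate(order):
        vertices[k + 1:, axis] = 1
    return order, vertices, weights


order, vertices, weights = simplex_weights([0.4, 0.2, 0.8])
print(order)     # axis order that selects the simplex
print(vertices)  # the d + 1 vertices of that simplex
print(weights)   # barycentric weights, which sum to 1
```

The weights are the barycentric coordinates of x in its simplex, so `weights @ vertices` recovers x, and the interpolated value is the same weighted sum applied to the lattice parameters stored at those vertices.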

Args:
  inputs: Tensor of shape (batch_size, ..., len(lattice_sizes)), or a list of len(lattice_sizes) tensors, each of shape (batch_size, ..., 1), representing the points at which to apply lattice interpolation. A typical shape is (batch_size, len(lattice_sizes)).
  kernel: Lattice kernel of shape (num_params_per_lattice, units).
  units: Output dimension of the lattice.
  lattice_sizes: List or tuple of integers giving the lattice sizes of the layer for which interpolation is being computed.
  clip_inputs: Whether inputs should be clipped to the input range of the lattice.

Returns:
  Tensor of shape (batch_size, ..., units).
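
As a rough illustration of what clipping and "each cell of the lattice" mean for a single point, the sketch below maps a point to the lower corner of its cell and its offset inside that cell. It assumes the usual TensorFlow Lattice convention that dimension i accepts inputs in [0, lattice_sizes[i] - 1] and that every lattice size is at least 2; `locate_cell` is a hypothetical helper, not part of the library.

```python
import numpy as np


def locate_cell(point, lattice_sizes, clip_inputs=True):
    """Return (lower_corner, offset) of the lattice cell containing `point`."""
    point = np.asarray(point, dtype=float)
    # Assumed input range per dimension: [0, size - 1].
    upper = np.asarray(lattice_sizes, dtype=float) - 1.0
    if clip_inputs:
        point = np.clip(point, 0.0, upper)
    # A point on the upper boundary belongs to the last cell, so the lower
    # corner is capped at size - 2.
    lower_corner = np.minimum(np.floor(point), upper - 1.0)
    offset = point - lower_corner  # lies in [0, 1] when the point is in range
    return lower_corner.astype(int), offset


corner, offset = locate_cell([2.7, -0.5, 1.0], lattice_sizes=[3, 2, 2])
print(corner, offset)
```

The offset is the point's position in the unit hypercube relative to the cell's lower corner, which is the quantity the simplex and weight selection described above operate on.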