tfl.lattice_lib.torsion_regularizer


Returns the torsion regularization loss for a Lattice layer.

The lattice torsion regularizer penalizes how much the lattice function twists from side to side (see publication).

Consider a 3 x 2 lattice with weights w:

w[3]-----w[4]-----w[5]
  |        |        |
  |        |        |
w[0]-----w[1]-----w[2]

In this case, the torsion regularizer is defined as:

l1 * (|w[4] + w[0] - w[3] - w[1]| + |w[5] + w[1] - w[4] - w[2]|) +
l2 * ((w[4] + w[0] - w[3] - w[1])^2 + (w[5] + w[1] - w[4] - w[2])^2)
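The formula can be checked with a short NumPy sketch. The function name `torsion_penalty_sketch` is hypothetical, and this is not the TensorFlow Lattice implementation. It assumes scalar `l1` and `l2`, and it assumes that the first entry of `lattice_sizes` varies fastest in the flat weight index, as the diagram suggests (w[0], w[1], w[2] form the bottom row). Applying the 3 x 2 pattern to every pair of dimensions of a higher-rank lattice is an assumed generalization that this page does not spell out.

```python
import numpy as np


def torsion_penalty_sketch(weights, lattice_sizes, l1=0.0, l2=0.0):
    """Hypothetical NumPy sketch of the torsion penalty (not TFL's code).

    Assumptions: scalar l1/l2, and the first lattice dimension varies
    fastest in the flat vertex index, as in the 3 x 2 diagram.
    """
    weights = np.asarray(weights, dtype=float)
    units = weights.shape[1]
    rank = len(lattice_sizes)
    # Fortran-order reshape: flat index k = i0 + s0*i1 + s0*s1*i2 + ...
    grid = weights.reshape(list(lattice_sizes) + [units], order="F")

    def corner(i, j, shift_i, shift_j):
        # For every cell spanned by dimensions i and j, pick one corner:
        # shift 0 selects the lower vertex, shift 1 the upper vertex.
        index = [slice(None)] * grid.ndim
        index[i] = slice(1, None) if shift_i else slice(None, -1)
        index[j] = slice(1, None) if shift_j else slice(None, -1)
        return grid[tuple(index)]

    loss = 0.0
    for i in range(rank - 1):
        for j in range(i + 1, rank):
            # Cross difference per cell, e.g. w[4] + w[0] - w[3] - w[1].
            torsion = (corner(i, j, 1, 1) + corner(i, j, 0, 0)
                       - corner(i, j, 0, 1) - corner(i, j, 1, 0))
            loss += l1 * np.abs(torsion).sum() + l2 * np.square(torsion).sum()
    return float(loss)


# 3 x 2 lattice, one unit. Cell terms: 5+0-3-1 = 1 and 4+1-5-2 = -2.
w = np.array([0.0, 1.0, 2.0, 3.0, 5.0, 4.0]).reshape(6, 1)
print(torsion_penalty_sketch(w, [3, 2], l1=0.5, l2=2.0))  # 0.5*3 + 2.0*5 = 11.5
```

Weights that are additive in the two inputs, such as w[k] = k (that is, f(x, y) = x + 3y), give zero torsion here because each cell's cross difference cancels, so only the "twist" between inputs is penalized.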

Args:
  weights: Lattice weights tensor of shape (prod(lattice_sizes), units).
  lattice_sizes: List or tuple of integers giving the number of vertices along each lattice dimension.
  l1: l1 regularization amount. Either a single float, or a list or tuple of floats specifying a different regularization amount per dimension.
  l2: l2 regularization amount. Either a single float, or a list or tuple of floats specifying a different regularization amount per dimension. The amount for the interaction term between dimensions i and j is the product of the per-dimension amounts for i and j.

Returns:
  Torsion regularization loss.
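To illustrate the product rule stated for a per-dimension l2, the hypothetical helper `interaction_amounts` below (not part of TensorFlow Lattice) lists the amount applied to each interaction term of a 3-D lattice. It covers only the list or tuple form of the argument.

```python
from itertools import combinations


def interaction_amounts(per_dim_amounts):
    """Hypothetical helper: amount for each (i, j) interaction term.

    Applies the documented rule for per-dimension l2 values: the amount
    for the pair (i, j) is per_dim_amounts[i] * per_dim_amounts[j].
    """
    return {
        (i, j): per_dim_amounts[i] * per_dim_amounts[j]
        for i, j in combinations(range(len(per_dim_amounts)), 2)
    }


# A 3-D lattice with a different l2 amount per dimension.
print(interaction_amounts([0.5, 2.0, 4.0]))
# {(0, 1): 1.0, (0, 2): 2.0, (1, 2): 8.0}
```

One consequence of the product rule: setting a dimension's amount to 0.0 removes every interaction term that involves that dimension.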