tfa.activations.sparsemax


Sparsemax activation function [1].


tfa.activations.sparsemax(
    logits,
    axis=-1
)

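For each batch i and class j, the output is

\(\mathrm{sparsemax}[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0)\)

where \(\tau(\cdot)\) is the threshold chosen so that the outputs in each row are non-negative and sum to 1.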
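The threshold \(\tau\) can be computed exactly by sorting. The following minimal pure-Python sketch assumes the sort-based rule from the original sparsemax paper. It sorts a row z in descending order, takes the largest k with \(1 + k z_{(k)} > \sum_{j \le k} z_{(j)}\), and sets \(\tau = (\sum_{j \le k} z_{(j)} - 1)/k\). The helper names sparsemax_threshold and sparsemax_1d are hypothetical. This illustrates the formula and is not the TensorFlow Addons implementation.

```python
from typing import List, Sequence


def sparsemax_threshold(z: Sequence[float]) -> float:
    """Return tau(z), the threshold that makes max(z_j - tau, 0) sum to 1."""
    z_sorted = sorted(z, reverse=True)
    cumsum = 0.0       # running sum of the k largest values
    support_sum = 0.0  # sum over the support found so far
    k_z = 0            # size of the support (largest k satisfying the test)
    for k, z_k in enumerate(z_sorted, start=1):
        cumsum += z_k
        # z_(k) is in the support while 1 + k * z_(k) > sum of the top k values.
        if 1 + k * z_k > cumsum:
            k_z = k
            support_sum = cumsum
    return (support_sum - 1) / k_z


def sparsemax_1d(z: Sequence[float]) -> List[float]:
    """Apply sparsemax to a single row of logits (hypothetical helper)."""
    tau = sparsemax_threshold(z)
    return [max(z_j - tau, 0.0) for z_j in z]


print(sparsemax_1d([0.0, 1.0, 1.5]))  # [0.0, 0.25, 0.75]
print(sparsemax_1d([1.0, 2.0, 3.0]))  # [0.0, 0.0, 1.0]
```

Unlike softmax, entries that fall below \(\tau\) become exactly zero, as the second call shows.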

Args:

  • logits: Input tensor with rank of at least 2.
  • axis: Integer, the axis along which the sparsemax operation is applied. Defaults to -1 (the last axis).

Returns:

A tensor holding the sparsemax transformation of logits along axis, with the same dtype and shape as logits.

Raises:

  • ValueError: If logits is one-dimensional (dim(logits) == 1).
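The next block illustrates the axis argument, the shape-preserving output, and the documented ValueError. It is a hypothetical pure-Python stand-in restricted to rank-2 nested lists, an assumption made for brevity. The library function itself operates on TensorFlow tensors. The stand-in computes \(\tau\) with the sort-based rule from the sparsemax paper in a private helper. It applies that helper to rows for axis=-1 and to columns for axis=0, and it rejects one-dimensional input. The names sparsemax_2d and _sparsemax_row are illustrative only.

```python
from typing import List, Sequence


def _sparsemax_row(z: Sequence[float]) -> List[float]:
    # Sort-based threshold: keep the largest k with 1 + k * z_(k) > sum of top k.
    z_sorted = sorted(z, reverse=True)
    cumsum, support_sum, k_z = 0.0, 0.0, 0
    for k, z_k in enumerate(z_sorted, start=1):
        cumsum += z_k
        if 1 + k * z_k > cumsum:
            k_z, support_sum = k, cumsum
    tau = (support_sum - 1) / k_z
    return [max(z_j - tau, 0.0) for z_j in z]


def sparsemax_2d(logits: Sequence[Sequence[float]], axis: int = -1) -> List[List[float]]:
    """Hypothetical pure-Python stand-in for sparsemax on rank-2 input."""
    # Assumption: input is a list of equal-length rows (rank 2 only).
    # Mirror the documented ValueError for one-dimensional input.
    if not logits or not all(isinstance(row, (list, tuple)) for row in logits):
        raise ValueError("logits must have rank 2, got one-dimensional input")
    if axis in (-1, 1):
        # Normalize each row independently.
        return [_sparsemax_row(row) for row in logits]
    if axis in (-2, 0):
        # Normalize each column: transpose, apply row-wise, transpose back.
        columns = [_sparsemax_row(col) for col in zip(*logits)]
        return [list(row) for row in zip(*columns)]
    raise ValueError(f"axis {axis} is out of range for rank-2 input")


x = [[0.0, 1.0, 1.5],
     [1.0, 2.0, 3.0]]
print(sparsemax_2d(x))          # [[0.0, 0.25, 0.75], [0.0, 0.0, 1.0]]
print(sparsemax_2d(x, axis=0))  # [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```

In both calls the output has the same shape as x, and the values along the chosen axis sum to 1. Calling sparsemax_2d([1.0, 2.0]) raises ValueError.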