
tfa.losses.SparsemaxLoss

Sparsemax loss function.

Computes the generalized multi-label classification loss for the sparsemax function.

Because the sparsemax loss function needs both the probability output (the sparsemax of the logits) and the logits themselves to compute the loss value, from_logits must be True.
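To show why both quantities are needed, here is a minimal NumPy sketch of sparsemax and a per-sample sparsemax loss. It follows the loss definition in Martins & Astudillo (2016), "From Softmax to Sparsemax". The function names are illustrative and are not the TensorFlow Addons internals. The exact grouping of terms is an assumption; it is algebraically equal to the paper's form. The support term uses the sparsemax output p, and the target term uses the logits z.

```python
import numpy as np

def sparsemax(z):
    """Euclidean projection of a 1-D logit vector z onto the probability simplex."""
    z = np.asarray(z, dtype=np.float64)
    z_sorted = np.sort(z)[::-1]                  # logits in descending order
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum          # True for a prefix of the sorted entries
    k_z = k[support][-1]                         # support size k(z)
    tau = (cumsum[k_z - 1] - 1) / k_z            # threshold subtracted from every logit
    return np.maximum(z - tau, 0.0)

def sparsemax_loss(z, p, q):
    """Per-sample sparsemax loss from logits z, p = sparsemax(z), and targets q.

    Illustrative sketch (assumed term grouping):
    sum over the support of p_j * (z_j - p_j / 2)  +  sum_j q_j * (q_j / 2 - z_j)
    """
    z, p, q = (np.asarray(a, dtype=np.float64) for a in (z, p, q))
    support = p > 0
    # Needs the probabilities p: equals (z_j**2 - tau**2) / 2 summed over the support.
    support_term = np.sum(p[support] * (z[support] - 0.5 * p[support]))
    # Needs the logits z: -q.z + ||q||^2 / 2.
    target_term = np.sum(q * (0.5 * q - z))
    return support_term + target_term
```

For z = [1.0, 0.5, -1.0] the sketch gives p = [0.75, 0.25, 0.0]. With the one-hot target [1, 0, 0] the loss is 0.0625. If the target equals p, the loss is 0.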

Because it computes the generalized multi-label loss, the shape of both y_pred and y_true must be [batch_size, num_classes].
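The shape requirement can be illustrated with a batched, vectorized variant of the same computation. This is a NumPy sketch with a hypothetical function name, not the TensorFlow Addons code. It checks that y_true and the logits are both [batch_size, num_classes] and returns one loss value per row. Treating each row of y_true as a target distribution, possibly spread over several classes, is an assumption of the sketch. The formula itself does not normalize it.

```python
import numpy as np

def batched_sparsemax_loss(y_true, logits):
    """Sparsemax loss per row; y_true and logits must both be [batch_size, num_classes]."""
    y_true = np.asarray(y_true, dtype=np.float64)
    z = np.asarray(logits, dtype=np.float64)
    if z.ndim != 2 or y_true.shape != z.shape:
        raise ValueError(
            f"expected matching [batch_size, num_classes] shapes, got {y_true.shape} and {z.shape}"
        )
    batch_size, num_classes = z.shape
    z_sorted = -np.sort(-z, axis=1)                        # each row in descending order
    cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, num_classes + 1)
    k_z = np.sum(1 + k * z_sorted > cumsum, axis=1)        # support size per row
    tau = (cumsum[np.arange(batch_size), k_z - 1] - 1) / k_z
    p = np.maximum(z - tau[:, None], 0.0)                  # sparsemax probabilities per row
    support_term = np.sum(np.where(p > 0, p * (z - 0.5 * p), 0.0), axis=1)
    target_term = np.sum(y_true * (0.5 * y_true - z), axis=1)
    return support_term + target_term                      # shape [batch_size]
```

With logits [[1.0, 0.5, -1.0], [0.0, 0.0, 0.0]] and targets [[1, 0, 0], [0, 1, 0]], the sketch returns approximately [0.0625, 0.3333]. Passing mismatched shapes raises ValueError.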

Args
from_logits Whether y_pred is expected to be a logits tensor. Default is True, meaning y_pred is the logits.
reduction (Optional) Type of tf.keras.losses.Reduction to apply to loss. Default value is SUM_OVER_BATCH_SIZE.
name Optional name for the op.
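The reduction argument decides how the per-sample losses, one value per batch row for this loss, are combined. The NumPy helper below is hypothetical and only mimics the behavior. It assumes the tf.keras string values "none", "sum" and "sum_over_batch_size" of tf.keras.losses.Reduction. It also assumes that SUM_OVER_BATCH_SIZE averages over the number of loss elements.

```python
import numpy as np

def reduce_loss(per_sample, reduction="sum_over_batch_size"):
    """Hypothetical helper mimicking Keras reduction modes on a [batch_size] loss vector."""
    per_sample = np.asarray(per_sample, dtype=np.float64)
    if reduction == "none":
        return per_sample                              # keeps shape [batch_size]
    if reduction == "sum":
        return per_sample.sum()                        # scalar total
    if reduction == "sum_over_batch_size":
        return per_sample.sum() / per_sample.size      # scalar mean over the batch
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With the default, a batch whose per-sample losses are [0.0625, 0.3333] reduces to their mean, about 0.1979. With "none", the [batch_size] vector is returned unchanged.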

Methods

from_config

Instantiates a Loss from its config (output of get_config()).

Args
config Output of get_config().

Returns
A Loss instance.

get_config


Returns the config dictionary for a Loss instance.
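The from_config/get_config pair is a round-trip: from_config(get_config()) should rebuild an equivalent loss object. The plain-Python stand-in class below is hypothetical. Its config keys (from_logits, reduction, name) and its default name "sparsemax_loss" are assumptions, not a copy of the real class. It also enforces from_logits=True by raising ValueError. That check is the sketch's own choice.

```python
class SparsemaxLossConfigSketch:
    """Hypothetical stand-in mirroring the Keras get_config/from_config contract."""

    def __init__(self, from_logits=True, reduction="sum_over_batch_size", name="sparsemax_loss"):
        # The sparsemax loss needs the logits, so only from_logits=True is accepted here.
        if not from_logits:
            raise ValueError("from_logits must be True for the sparsemax loss")
        self.from_logits = from_logits
        self.reduction = reduction
        self.name = name

    def get_config(self):
        # Plain, serializable dict of the constructor arguments.
        return {"from_logits": self.from_logits, "reduction": self.reduction, "name": self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild an equivalent instance from the output of get_config().
        return cls(**config)
```

For example, `SparsemaxLossConfigSketch.from_config(loss.get_config())` yields an object whose config equals the original's.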

__call__

Invokes the Loss instance.

Args
y_true Ground truth values. shape = [batch_size, d0, .. dN], except sparse loss functions such as sparse categorical crossentropy where shape = [batch_size, d0, .. dN-1]
y_pred The predicted values. shape = [batch_size, d0, .. dN]
sample_weight Optional sample_weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If sample_weight is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sample_weight vector. If the shape of sample_weight is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of y_pred is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)
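For this loss each row reduces to one value, so the per-sample loss has shape [batch_size], and the [batch_size, d0, .. dN-1] case collapses to [batch_size]. The NumPy sketch below uses a hypothetical helper, apply_sample_weight, which is not part of TensorFlow Addons. It shows scalar and per-sample weighting followed by a SUM_OVER_BATCH_SIZE reduction. It assumes that, as in tf.keras, this reduction divides by the number of loss elements rather than by the sum of the weights.

```python
import numpy as np

def apply_sample_weight(per_sample, sample_weight=None):
    """Scale a [batch_size] per-sample loss vector, then average over the batch size."""
    per_sample = np.asarray(per_sample, dtype=np.float64)
    if sample_weight is None:
        sample_weight = 1.0
    # A scalar scales every sample; a [batch_size] vector rescales each sample's loss.
    # np.broadcast_to raises ValueError when the weight shape cannot be broadcast.
    weights = np.broadcast_to(np.asarray(sample_weight, dtype=np.float64), per_sample.shape)
    weighted = per_sample * weights
    # Divide by the number of elements, not by the sum of the weights.
    return weighted.sum() / weighted.size
```

With per-sample losses [0.0625, 0.3333], weights [1.0, 0.0] give 0.0625 / 2 = 0.03125, because the zero-weighted sample still counts in the denominator. A weight vector of the wrong length, such as [1.0, 2.0, 3.0], raises ValueError.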

Returns
Weighted loss float Tensor. If reduction is NONE, this has shape [batch_size, d0, .. dN-1]; otherwise, it is scalar. (Note dN-1 because all loss functions reduce by 1 dimension, usually axis=-1.)

Raises
ValueError If the shape of sample_weight is invalid.