Module: tfa.losses

Additional losses that conform to the Keras API.

Modules

contrastive module: Implements contrastive loss.

focal_loss module: Implements focal loss.

lifted module: Implements lifted_struct_loss.

metric_learning module: Functions for metric learning.

npairs module: Implements npairs loss.

triplet module: Implements triplet loss.
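The idea behind the contrastive module can be illustrated with a small NumPy sketch of the standard contrastive formulation. It assumes that y_pred holds the distance between the two embeddings of each pair, and that y_true is 1 for similar pairs and 0 for dissimilar ones. The helper name contrastive_loss_sketch and the default margin of 1.0 are illustrative assumptions, and the sketch is not a copy of the library code.

```python
import numpy as np

def contrastive_loss_sketch(y_true, y_pred, margin=1.0):
    """Per-pair contrastive loss (illustrative sketch, not the TFA source).

    y_true: 1 for similar pairs, 0 for dissimilar pairs.
    y_pred: distance between the two embeddings of each pair.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # Similar pairs are penalized by their squared distance.
    similar_term = y_true * np.square(y_pred)
    # Dissimilar pairs are penalized only while they are closer than `margin`.
    dissimilar_term = (1.0 - y_true) * np.square(np.maximum(margin - y_pred, 0.0))
    return similar_term + dissimilar_term

# Two similar pairs followed by two dissimilar pairs.
losses = contrastive_loss_sketch([1, 1, 0, 0], [0.0, 0.5, 0.2, 1.5])
print(losses)
```

The squared-distance term pulls similar pairs together. Dissimilar pairs stop contributing once they are at least margin apart, which is why the last pair above has zero loss.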

Classes

class ContrastiveLoss: Computes the contrastive loss between y_true and y_pred.

class LiftedStructLoss: Computes the lifted structured loss.

class NpairsLoss: Computes the npairs loss between y_true and y_pred.

class NpairsMultilabelLoss: Computes the npairs loss between multilabel data y_true and y_pred.

class SigmoidFocalCrossEntropy: Implements the focal loss function.

class SparsemaxLoss: Sparsemax loss function.

class TripletSemiHardLoss: Computes the triplet loss with semi-hard negative mining.
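SigmoidFocalCrossEntropy scales binary cross-entropy by a factor that shrinks for well-classified examples. The NumPy sketch below shows that idea under several stated assumptions: y_pred holds sigmoid probabilities (not logits), alpha=0.25 and gamma=2.0 are the commonly used focal-loss defaults, and per-element losses are summed over the last axis. The function name is hypothetical, and the sketch is not the TFA source.

```python
import numpy as np

def sigmoid_focal_crossentropy_sketch(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """Focal loss on sigmoid probabilities, summed over the last axis (sketch)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip probabilities so log() never receives exactly 0 or 1.
    p = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    # Ordinary element-wise binary cross-entropy.
    ce = -(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))
    # p_t is the probability the model assigns to the true class.
    p_t = y_true * p + (1.0 - y_true) * (1.0 - p)
    # alpha_t balances the weight of positive and negative targets.
    alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)
    # (1 - p_t)^gamma down-weights easy, confidently correct predictions.
    modulating = np.power(1.0 - p_t, gamma)
    return np.sum(alpha_t * modulating * ce, axis=-1)

# A confident correct prediction versus a poor one for a positive target.
easy = sigmoid_focal_crossentropy_sketch([[1.0]], [[0.95]])
hard = sigmoid_focal_crossentropy_sketch([[1.0]], [[0.30]])
print(easy, hard)  # the easy example contributes far less than the hard one
```

With gamma=0 and alpha=0.5, the sketch reduces to half of plain binary cross-entropy, which makes the role of the modulating factor easy to check. With TensorFlow Addons installed, the library class would typically be passed to Keras directly, for example model.compile(loss=tfa.losses.SigmoidFocalCrossEntropy()). That path is not exercised by this sketch.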

Functions

contrastive_loss(...): Computes the contrastive loss between y_true and y_pred.

lifted_struct_loss(...): Computes the lifted structured loss.

npairs_loss(...): Computes the npairs loss between y_true and y_pred.

npairs_multilabel_loss(...): Computes the npairs loss between multilabel data y_true and y_pred.

sigmoid_focal_crossentropy(...): Implements the focal loss function.

sparsemax_loss(...): Sparsemax loss function [1].

triplet_semihard_loss(...): Computes the triplet loss with semi-hard negative mining.
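sparsemax_loss is paired with the sparsemax activation. Sparsemax projects logits onto the probability simplex and can assign exactly zero probability to some classes. The following NumPy sketch computes both for a single 1-D logit vector, following the sparsemax loss as defined by Martins and Astudillo (2016). The helper names are hypothetical, and the TFA function's own signature and batching are not reproduced here.

```python
import numpy as np

def sparsemax_threshold(z):
    """Return tau such that max(z - tau, 0) sums to 1 (1-D input assumed)."""
    z = np.asarray(z, dtype=np.float64)
    z_sorted = np.sort(z)[::-1]  # logits in descending order
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    # k(z) is the largest k with 1 + k * z_(k) > sum of the top-k logits.
    in_support = 1.0 + k * z_sorted > cumsum
    k_z = k[in_support].max()
    return (cumsum[k_z - 1] - 1.0) / k_z

def sparsemax_sketch(z):
    """Euclidean projection of z onto the probability simplex."""
    z = np.asarray(z, dtype=np.float64)
    return np.maximum(z - sparsemax_threshold(z), 0.0)

def sparsemax_loss_sketch(z, q):
    """Sparsemax loss for logits z and target distribution q (e.g. one-hot).

    L(z; q) = -q.z + 1/2 * sum_{j in S(z)} (z_j^2 - tau^2) + 1/2 * ||q||^2,
    where S(z) is the set of indices that sparsemax(z) keeps nonzero.
    """
    z = np.asarray(z, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    tau = sparsemax_threshold(z)
    support = z > tau
    return -q @ z + 0.5 * np.sum(z[support] ** 2 - tau ** 2) + 0.5 * q @ q

print(sparsemax_sketch([0.5, 0.2, -1.0]))                     # the last class gets exactly 0
print(sparsemax_loss_sketch([3.0, 1.0, 0.0], [1.0, 0.0, 0.0]))  # prediction matches target
```

For the logits [3.0, 1.0, 0.0] with class 0 as the target, sparsemax already returns the one-hot vector, so the loss is 0. Uniform logits [0, 0, 0] spread probability evenly and give a loss of 1/3 for the same target.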