tfa.losses.triplet_hard_loss


Computes the triplet loss with hard negative and hard positive mining.
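The mining step described above can be illustrated with a small plain-Python sketch. It is framework-free and is not the TensorFlow Addons implementation. The sketch takes a precomputed [batch_size, batch_size] distance matrix. For each anchor, it picks the farthest same-label sample (the hard positive) and the closest different-label sample (the hard negative). It then applies either the hinge `max(d_ap - d_an + margin, 0)` or, for the soft variant, the softplus `log(1 + exp(d_ap - d_an))`, and averages the result over anchors. The function name `batch_hard_triplet_loss` is hypothetical. Skipping anchors that have no positive or no negative is an assumption of this sketch, not documented library behavior.

```python
import math
from typing import List, Sequence


def batch_hard_triplet_loss(
    dist: Sequence[Sequence[float]],
    labels: Sequence[int],
    margin: float = 1.0,
    soft: bool = False,
) -> float:
    """Batch-hard triplet loss from a precomputed [n, n] distance matrix (hypothetical helper)."""
    n = len(labels)
    per_anchor: List[float] = []
    for i in range(n):
        # Hardest positive: the farthest sample sharing the anchor's label.
        pos = [dist[i][j] for j in range(n) if j != i and labels[j] == labels[i]]
        # Hardest negative: the closest sample with a different label.
        neg = [dist[i][j] for j in range(n) if labels[j] != labels[i]]
        if not pos or not neg:
            continue  # assumption: anchors without a positive or negative are skipped
        gap = max(pos) - min(neg)
        if soft:
            # Soft margin: softplus(gap) = log(1 + exp(gap)); margin is not used.
            per_anchor.append(math.log1p(math.exp(gap)))
        else:
            # Hinge with margin: max(gap + margin, 0).
            per_anchor.append(max(gap + margin, 0.0))
    # Reduce to a single scalar by averaging over the anchors that were kept.
    return sum(per_anchor) / len(per_anchor) if per_anchor else 0.0


# Toy 1-D embeddings with absolute-difference distances.
xs = [0.0, 1.0, 3.0, 5.0]
dist = [[abs(a - b) for b in xs] for a in xs]
print(batch_hard_triplet_loss(dist, [0, 0, 1, 1], margin=1.0))  # prints 0.25
```

In this toy batch, only the third anchor (at 3.0) has a hard positive that is not closer than its hard negative by at least the margin. Its hinge term is 1.0 and the other three are 0.0, so the mean is 0.25.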

Args:
y_true: 1-D integer Tensor of shape [batch_size] containing multiclass integer labels.
y_pred: 2-D float Tensor of embedding vectors. Embeddings should be L2-normalized.
margin: Float, the margin term in the loss definition.
soft: Boolean; if True, use the soft-margin version of the loss.
distance_metric: str or callable that determines the distance metric:
  "L1" for l1-norm distance,
  "L2" for l2-norm distance,
  "angular" for a distance based on cosine similarity.
  A custom callable that takes the embeddings and returns a 2-D [batch_size, batch_size] matrix of pairwise distances under the chosen metric can also be passed, for example:

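    import tensorflow as tf
    import tensorflow_addons as tfa

    def custom_distance(batch):
        # 1 - cosine similarity, assuming the rows of `batch` are L2-normalized.
        # tf.matmul with transpose_b avoids relying on a `.T` attribute.
        return 1.0 - tf.matmul(batch, batch, transpose_b=True)

    # labels: [batch_size] int Tensor; embeddings: [batch_size, dim] float Tensor
    loss = tfa.losses.triplet_hard_loss(
        y_true=labels, y_pred=embeddings, distance_metric=custom_distance)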
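The following plain-Python sketch makes the "2-D matrix" contract concrete. It builds the pairwise matrices that the metric names conventionally denote. The helpers `pairwise`, `l1`, `l2`, and `cosine_distance` are hypothetical and use textbook definitions. They are not the library's internal code, which may handle each string differently. A custom `distance_metric` is expected to return a matrix of this [batch_size, batch_size] shape for the whole batch.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def pairwise(embeddings: Sequence[Vector],
             fn: Callable[[Vector, Vector], float]) -> List[List[float]]:
    """Return the [batch_size, batch_size] matrix whose (i, j) entry is fn(e_i, e_j)."""
    return [[fn(a, b) for b in embeddings] for a in embeddings]


def l1(a: Vector, b: Vector) -> float:
    # l1-norm (Manhattan) distance.
    return sum(abs(x - y) for x, y in zip(a, b))


def l2(a: Vector, b: Vector) -> float:
    # l2-norm (Euclidean) distance.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_distance(a: Vector, b: Vector) -> float:
    # 1 - cosine similarity; for L2-normalized inputs the denominator is 1.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


emb = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]  # toy L2-normalized embeddings
for name, fn in (("L1", l1), ("L2", l2), ("cosine", cosine_distance)):
    print(name, pairwise(emb, fn))
```

For L2-normalized rows, `cosine_distance(a, b)` reduces to `1 - dot(a, b)`. The batched `custom_distance` example computes the same quantity for every pair in a single matrix product.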

Returns:
triplet_loss: float scalar with the same dtype as y_pred.