
Computes the triplet loss with semi-hard negative mining.
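The semi-hard mining rule can be illustrated in plain Python. This is a minimal sketch, not the TF Addons implementation. The function name semihard_triplet_loss is hypothetical, and the sketch uses Euclidean distance only. It follows the semi-hard selection rule as it is commonly implemented; we believe this matches TF Addons, but treat that as an assumption. For each anchor-positive pair, the sketch picks the closest negative that is still farther from the anchor than the positive. When no such negative exists, it falls back to the farthest negative. It then averages max(d(a, p) - d(a, n) + margin, 0) over all anchor-positive pairs.

```python
import math
from typing import Sequence


def semihard_triplet_loss(
    labels: Sequence[int],
    embeddings: Sequence[Sequence[float]],
    margin: float = 1.0,
) -> float:
    """Hypothetical sketch of triplet loss with semi-hard negative mining."""
    n = len(labels)
    # Pairwise Euclidean (l2-norm) distances between all embeddings.
    dist = [[math.dist(embeddings[i], embeddings[j]) for j in range(n)] for i in range(n)]

    total, num_pairs = 0.0, 0
    for a in range(n):
        # Distances from anchor a to every sample with a different label.
        neg = [dist[a][k] for k in range(n) if labels[k] != labels[a]]
        if not neg:
            continue  # assumption: anchors with no negative in the batch are skipped
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue  # only anchor-positive pairs with a != p
            d_ap = dist[a][p]
            # Semi-hard negatives lie farther from the anchor than the positive.
            semi_hard = [d for d in neg if d > d_ap]
            # Closest semi-hard negative, else fall back to the farthest negative.
            d_an = min(semi_hard) if semi_hard else max(neg)
            total += max(d_ap - d_an + margin, 0.0)
            num_pairs += 1
    return total / num_pairs if num_pairs else 0.0


labels = [0, 0, 1, 1]
embeddings = [(0.0,), (1.0,), (2.0,), (5.0,)]
print(semihard_triplet_loss(labels, embeddings))  # 0.5
```

With these toy 1-D embeddings, only the anchor at 2.0 (paired with the positive at 5.0) has no semi-hard negative. It falls back to its farthest negative and contributes the only non-zero term, 2.0, which is averaged over four anchor-positive pairs.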

Args:
y_true: 1-D integer Tensor with shape [batch_size] of multiclass integer labels.
y_pred: 2-D float Tensor of embedding vectors. Embeddings should be l2-normalized.
margin: Float, margin term in the loss definition.
distance_metric: str or function that determines the distance metric:
- "L2" for l2-norm distance;
- "squared-L2" for squared l2-norm distance;
- "angular" for cosine-similarity-based distance (1 minus cosine similarity).
A custom function that returns a 2-D adjacency matrix of the chosen distance metric can also be passed here. For example:

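```python
import tensorflow as tf
from tensorflow_addons.losses import triplet_semihard_loss

def custom_distance(batch):
    # Pairwise cosine distance; assumes the rows of `batch` are already l2-normalized.
    return 1 - tf.matmul(batch, batch, transpose_b=True)

# Labels (y_true) come first, then the embeddings `batch` (y_pred).
triplet_semihard_loss(labels, batch, distance_metric=custom_distance)
```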
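To see what the three built-in string options for distance_metric measure, they can be mirrored in plain Python. This is a minimal sketch rather than the TF Addons code. The helper pairwise_distances is hypothetical, and the "angular" branch assumes the library's angular distance is 1 minus the cosine similarity of l2-normalized rows, clipped at zero. Each option returns a [batch_size, batch_size] matrix, the same 2-D shape a custom distance function is expected to return.

```python
import math
from typing import List, Sequence


def pairwise_distances(
    embeddings: Sequence[Sequence[float]], metric: str = "L2"
) -> List[List[float]]:
    """Hypothetical sketch of the built-in distance options as n x n matrices."""
    n = len(embeddings)
    if metric == "angular":
        # l2-normalize each row (assumes no all-zero rows), then use
        # 1 - cosine similarity, clipped at 0.
        unit = []
        for v in embeddings:
            norm = math.sqrt(sum(x * x for x in v))
            unit.append([x / norm for x in v])
        return [
            [max(1.0 - sum(x * y for x, y in zip(unit[i], unit[j])), 0.0) for j in range(n)]
            for i in range(n)
        ]
    # Squared l2-norm distance between every pair of rows.
    sq = [
        [sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j])) for j in range(n)]
        for i in range(n)
    ]
    if metric == "squared-L2":
        return sq
    if metric == "L2":
        return [[math.sqrt(d) for d in row] for row in sq]
    raise ValueError(f"unknown metric: {metric!r}")


emb = [(1.0, 0.0), (0.0, 1.0), (3.0, 4.0)]
print(pairwise_distances(emb, "squared-L2")[0])  # [0.0, 2.0, 20.0]
```

"L2" is the square root of "squared-L2", so both rank pairs identically. Because the margin is added in distance units, however, the same margin value behaves differently under each metric.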

Returns:
triplet_loss: float scalar with the same dtype as y_pred.