

Computes the triplet loss with hard negative and hard positive mining.
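In batch-hard mining, every sample in the batch acts as an anchor and is paired with its farthest positive (same label) and its closest negative (different label). The loss can be summarized as follows. This is a sketch in standard batch-hard notation, and the symbols are introduced here rather than taken from the API: D is the pairwise matrix produced by distance_metric, y holds the labels from y_true, m is margin, and B is the batch size. The soft variant is assumed to be the log(1 + exp(.)) form.

```latex
d_i^{+} = \max_{j \neq i,\; y_j = y_i} D_{ij},
\qquad
d_i^{-} = \min_{k:\; y_k \neq y_i} D_{ik}

\ell_i =
\begin{cases}
\max\left(d_i^{+} - d_i^{-} + m,\; 0\right) & \text{(hinge, soft = False)} \\
\log\left(1 + \exp\left(d_i^{+} - d_i^{-}\right)\right) & \text{(soft margin, soft = True)}
\end{cases}

\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \ell_i
```

This summary does not cover anchors that have no positive or no negative in the batch; how those are handled is an implementation detail.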


>>> import tensorflow as tf
>>> import tensorflow_addons as tfa
>>> y_true = tf.convert_to_tensor([0, 0])
>>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
>>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric="L2")
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
>>> # Calling with a callable `distance_metric`
>>> # (a Gram matrix, used here only to illustrate the interface)
>>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
>>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric=distance_metric)
<tf.Tensor: shape=(), dtype=float32, numpy=0.0>
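To make the mining step concrete, here is a plain-Python sketch of batch-hard triplet loss. It is not the TF Addons implementation: it works on lists instead of tensors, and the names batch_hard_triplet_loss, l2_pairwise and dot_product are hypothetical. Two further assumptions: the soft variant is taken to be log(1 + exp(d_pos - d_neg)), and an anchor with no positive (or no negative) falls back to the minimum (or maximum) of its distance row, mirroring a masked min/max formulation. With these choices the sketch reproduces both outputs in the example above.

```python
import math
from typing import Callable, List, Sequence

Embeddings = Sequence[Sequence[float]]
Matrix = List[List[float]]


def l2_pairwise(emb: Embeddings) -> Matrix:
    """Pairwise Euclidean (L2) distances between the rows of `emb`."""
    return [[math.dist(u, v) for v in emb] for u in emb]


def dot_product(emb: Embeddings) -> Matrix:
    """Gram matrix x @ x.T (a similarity, mirroring the callable example)."""
    return [[sum(a * b for a, b in zip(u, v)) for v in emb] for u in emb]


def batch_hard_triplet_loss(
    labels: Sequence[int],
    emb: Embeddings,
    margin: float = 1.0,
    soft: bool = False,
    pairwise: Callable[[Embeddings], Matrix] = l2_pairwise,
) -> float:
    """Hypothetical plain-Python sketch of batch-hard triplet loss."""
    dist = pairwise(emb)
    n = len(labels)
    losses = []
    for i in range(n):
        row = dist[i]
        positives = [row[j] for j in range(n) if j != i and labels[j] == labels[i]]
        negatives = [row[k] for k in range(n) if labels[k] != labels[i]]
        # Hardest positive: the farthest sample sharing the anchor's label.
        # Assumed fallback when there is none: the row minimum.
        hard_pos = max(positives) if positives else min(row)
        # Hardest negative: the closest sample with a different label.
        # Assumed fallback when there is none: the row maximum.
        hard_neg = min(negatives) if negatives else max(row)
        if soft:
            # Assumed soft-margin variant: log(1 + exp(d_pos - d_neg)); margin unused.
            losses.append(math.log1p(math.exp(hard_pos - hard_neg)))
        else:
            # Hinge variant: max(d_pos - d_neg + margin, 0).
            losses.append(max(hard_pos - hard_neg + margin, 0.0))
    # Average the per-anchor losses over the batch.
    return sum(losses) / n


labels = [0, 0]
emb = [[0.0, 1.0], [1.0, 0.0]]
print(batch_hard_triplet_loss(labels, emb))  # 1.0
print(batch_hard_triplet_loss(labels, emb, pairwise=dot_product))  # 0.0
```

Any function with the same batch-in, matrix-out shape can be passed as `pairwise`, in the same spirit as a callable `distance_metric`.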

Args:
  y_true: 1-D integer Tensor with shape [batch_size] of multiclass integer labels.
  y_pred: 2-D float Tensor of embedding vectors. Embeddings should be L2-normalized.
  margin: Float, the margin term in the loss definition. Defaults to 1.0.
  soft: Boolean; if set, use the soft-margin version of the loss.
  distance_metric: str or Callable that determines the distance metric. Valid strings are "L2" for L2-norm distance, "squared-L2" for squared L2-norm distance, and "angular" for cosine distance (1 minus cosine similarity). A Callable should take a batch of embeddings as input and return the pairwise distance matrix.
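The three named metrics can be written as plain functions that take a batch of embeddings and return the pairwise matrix, which is the same contract a custom callable is expected to follow. This is a plain-Python sketch for illustration only: the function names are hypothetical, and the exact definition of "angular" used here (1 minus the cosine similarity of L2-normalized embeddings, clipped at 0) is an assumption. A callable actually passed as distance_metric would receive and return TensorFlow tensors rather than lists.

```python
import math
from typing import List, Sequence

Embeddings = Sequence[Sequence[float]]
Matrix = List[List[float]]


def l2_normalize(emb: Embeddings) -> Matrix:
    """Scale each embedding to unit L2 norm (assumes no all-zero rows)."""
    out = []
    for v in emb:
        norm = math.sqrt(sum(x * x for x in v))
        out.append([x / norm for x in v])
    return out


def squared_l2(emb: Embeddings) -> Matrix:
    """Pairwise squared Euclidean distances ("squared-L2")."""
    return [[sum((a - b) ** 2 for a, b in zip(u, v)) for v in emb] for u in emb]


def l2(emb: Embeddings) -> Matrix:
    """Pairwise Euclidean distances ("L2")."""
    return [[math.sqrt(d) for d in row] for row in squared_l2(emb)]


def angular(emb: Embeddings) -> Matrix:
    """Assumed "angular" distance: 1 - cosine similarity, clipped at 0."""
    unit = l2_normalize(emb)
    return [
        # Clip tiny negative values caused by floating-point rounding.
        [max(1.0 - sum(a * b for a, b in zip(u, v)), 0.0) for v in unit]
        for u in unit
    ]


batch = [[3.0, 4.0], [0.0, 2.0]]
for fn in (l2, squared_l2, angular):
    print(fn.__name__, fn(batch))
```

Because squaring is monotonic on non-negative values, "squared-L2" selects the same hardest positives and negatives as "L2"; only the loss values, and therefore the effective scale of margin, differ.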

Returns:
  triplet_loss: float scalar with the same dtype as y_pred.