tfrs.layers.loss.HardNegativeMining

Transforms logits and labels so that, for each query, only the positive candidate and its hard negatives remain.

Args
num_hard_negatives How many hard negatives to return for each query.

Methods

call

Filters logits and labels with per-query hard negative mining.

The result contains the logits and labels of the positive candidate together with those of the num_hard_negatives negatives that received the highest logits, i.e. the hardest negatives for that query.
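To make the per-query filtering concrete, here is a minimal plain-Python sketch of the same idea. It is not the library's implementation. It works on nested lists rather than tf.Tensor values. The function name hard_negative_mining is hypothetical. It places the positive column first, which is a choice of this sketch; the real layer's column order is not specified on this page. When a row has fewer negatives than requested, the sketch simply keeps all of them.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def hard_negative_mining(
    logits: Matrix, labels: Matrix, num_hard_negatives: int
) -> Tuple[Matrix, Matrix]:
    """Hypothetical sketch: keep the positive plus the highest-scoring negatives per row."""
    mined_logits: Matrix = []
    mined_labels: Matrix = []
    for row_logits, row_labels in zip(logits, labels):
        # Column index of the positive candidate in this one-hot label row.
        positive = row_labels.index(1.0)
        # Negative column indices, ordered from highest to lowest logit.
        negatives = sorted(
            (j for j in range(len(row_logits)) if j != positive),
            key=lambda j: row_logits[j],
            reverse=True,
        )
        # Positive first (a choice of this sketch), then the hardest negatives.
        keep = [positive] + negatives[:num_hard_negatives]
        mined_logits.append([row_logits[j] for j in keep])
        mined_labels.append([row_labels[j] for j in keep])
    return mined_logits, mined_labels


# Two queries, four candidates each, keeping two hard negatives per query.
logits = [[2.0, 0.5, 3.0, 1.0], [0.1, 0.9, 0.4, 0.8]]
labels = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
mined_logits, mined_labels = hard_negative_mining(logits, labels, 2)
# mined_logits == [[2.0, 3.0, 1.0], [0.4, 0.9, 0.8]]
# mined_labels == [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
```

Each output row has num_hard_negatives + 1 = 3 columns, matching the shape described under Returns. Note that a negative can score higher than the positive (3.0 versus 2.0 in the first row); those are exactly the candidates mining is meant to keep.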

Args
logits [batch_size, number_of_candidates] tensor of logits.
labels [batch_size, number_of_candidates] one-hot tensor of labels.

Returns
logits [batch_size, num_hard_negatives + 1] tensor of logits.
labels [batch_size, num_hard_negatives + 1] one-hot tensor of labels.
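Because the positive candidate is always kept and the returned labels stay one-hot, the filtered pair can be passed directly to a categorical loss. The sketch below assumes a softmax cross-entropy computed row by row in plain Python; both the loss choice and the function name softmax_cross_entropy are illustrative assumptions, not something this page specifies. The input is a hand-written example of mined output with num_hard_negatives = 2, so each row has 3 columns.

```python
import math
from typing import List


def softmax_cross_entropy(
    logits: List[List[float]], labels: List[List[float]]
) -> List[float]:
    """Hypothetical sketch: per-query softmax cross-entropy over mined columns."""
    losses: List[float] = []
    for row_logits, row_labels in zip(logits, labels):
        # Subtract the row maximum before exponentiating for numerical stability.
        m = max(row_logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row_logits))
        # With one-hot labels, only the positive column contributes to the sum.
        losses.append(
            -sum(y * (x - log_sum_exp) for x, y in zip(row_logits, row_labels))
        )
    return losses


# Mined output: [batch_size, num_hard_negatives + 1] = [2, 3].
mined_logits = [[2.0, 3.0, 1.0], [0.4, 0.9, 0.8]]
mined_labels = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
losses = softmax_cross_entropy(mined_logits, mined_labels)
```

If mining dropped the positive column, that query's label row would be all zeros and its loss (and gradient) would be zero, so the query would contribute no training signal. This is why the positive candidate is always retained alongside the hard negatives.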