tfrs.layers.loss.RemoveAccidentalHits

Masks the logits of accidental negatives, that is, negative candidates that share an id with the positive candidate in the same row.

Methods

call


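Masks selected logits.

For each row in the batch, masks the logits of negative candidates that have the same id as that row's positive candidate, so that these accidental hits receive (effectively) zero probability in a subsequent softmax. Accidental hits typically arise with in-batch negatives, when the same candidate appears more than once in a batch; left unmasked, they would penalize the model for scoring a true match highly.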

Args
labels: [batch_size, num_candidates] one-hot labels tensor.
logits: [batch_size, num_candidates] logits tensor.
candidate_ids: [num_candidates] candidate identifiers tensor, one id per column of labels and logits.

Returns
logits: Modified logits, with the same [batch_size, num_candidates] shape as the input logits.
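
As a hypothetical end-to-end illustration, consider a retrieval batch trained with in-batch negatives: the positive for query i is candidate i, so labels is an identity matrix, and candidate_ids lists the ids of the batch's candidates. The NumPy sketch below (the helper `softmax_cross_entropy` and the example scores are illustrative assumptions, not TFRS APIs) applies the same masking rule inline and compares the per-row softmax cross-entropy before and after masking.

```python
import numpy as np


def softmax_cross_entropy(labels, logits):
    """Per-row softmax cross-entropy, computed with a stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1)


# Three (query, candidate) pairs; rows 0 and 2 both have item "x" as positive.
candidate_ids = np.array(["x", "y", "x"])
labels = np.eye(3, dtype=np.float32)  # the positive for row i is column i
# Illustrative scores: both "x" columns score the same for the "x" queries.
logits = np.array([[3.0, 0.0, 3.0],
                   [0.0, 2.0, 0.0],
                   [3.0, 0.0, 3.0]], dtype=np.float32)

# Same rule as described for call(): mask negatives sharing the positive's id.
positive_ids = candidate_ids[labels.argmax(axis=1)]
accidental = (positive_ids[:, None] == candidate_ids[None, :]) & (labels == 0)
masked = np.where(accidental, np.finfo(np.float32).min / 100.0, logits)

loss_before = softmax_cross_entropy(labels, logits)  # roughly [0.718, 0.240, 0.718]
loss_after = softmax_cross_entropy(labels, masked)   # roughly [0.049, 0.240, 0.049]
```

Without masking, rows 0 and 2 are penalized for giving the duplicate "x" column the same high score as the positive; after masking, only row 1, which has no duplicate, keeps its original loss. A finite mask value is used rather than `-inf` because, in a cross-entropy computed this way, a `-inf` log-probability multiplied by its zero label would produce NaN.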