tfrs.tasks.Retrieval

A factorized retrieval task.

Inherits From: Task

Recommender systems are often composed of two components:

  • a retrieval model, retrieving O(thousands) candidates from a corpus of O(millions) candidates.
  • a ranker model, scoring the candidates retrieved by the retrieval model to return a ranked shortlist of a few dozen candidates.

This task defines models that facilitate efficient retrieval of candidates from large corpora by maintaining a two-tower, factorized structure: separate query and candidate representation towers, joined at the top via a lightweight scoring function.
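As a rough illustration of the factorized structure described above, the sketch below scores every query against every candidate. It assumes the lightweight scoring function is a plain dot product between the two towers' outputs; the helper names `dot` and `score_matrix` are hypothetical and are not part of the library.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def score_matrix(query_embeddings, candidate_embeddings):
    """scores[i][j] is the affinity of query i for candidate j.

    Assumption: affinity is a dot product of the tower outputs.
    """
    return [[dot(q, c) for c in candidate_embeddings] for q in query_embeddings]


# Toy tower outputs: two queries and two candidates, embedding_dim = 2.
queries = [[1.0, 0.0], [0.0, 1.0]]
candidates = [[2.0, 0.0], [0.0, 3.0]]
scores = score_matrix(queries, candidates)
```

Because the towers only meet in this final step, candidate embeddings can be computed once and reused across queries, which is what makes retrieval over a large corpus tractable.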

Args
loss Loss function. Defaults to tf.keras.losses.CategoricalCrossentropy.
metrics Object for evaluating top-K metrics over a corpus of candidates. These metrics measure how good the model is at picking the true candidate out of all possible candidates in the system. Because the metrics range over the entire candidate set, they are usually much slower to compute; consider setting compute_metrics=False during training to skip that cost.
temperature Temperature of the softmax.
num_hard_negatives If positive, the num_hard_negatives negative examples with the largest logits are kept when computing the cross-entropy loss. If larger than the batch size or non-positive, all the negative examples are kept.
name Optional task name.
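The effect of temperature and num_hard_negatives on a single row of logits can be sketched as follows. This is a minimal illustration, not the library's implementation: it assumes the temperature divides the logits before the softmax, and the helpers `apply_temperature` and `keep_hard_negatives` are hypothetical names.

```python
def apply_temperature(scores, temperature):
    """Divide every logit by the temperature (assumed convention)."""
    return [[s / temperature for s in row] for row in scores]


def keep_hard_negatives(row, positive_index, num_hard_negatives):
    """Return the positive logit followed by the retained negative logits.

    If num_hard_negatives is non-positive or not smaller than the number
    of negatives, every negative is kept.
    """
    negatives = [s for j, s in enumerate(row) if j != positive_index]
    if 0 < num_hard_negatives < len(negatives):
        negatives = sorted(negatives, reverse=True)[:num_hard_negatives]
    return [row[positive_index]] + negatives


row = [5.0, 1.0, 4.0, 3.0]  # logits for one query; index 0 is the positive
sharpened = apply_temperature([row], temperature=0.5)[0]
hardest = keep_hard_negatives(row, positive_index=0, num_hard_negatives=2)
```

Under this convention a temperature below 1 widens the gaps between logits, and keeping only the hardest negatives focuses the loss on the candidates the model currently confuses with the positive.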

Attributes
factorized_metrics The metrics object used to compute retrieval metrics.

Methods

call

Computes the task loss and metrics.

The main arguments are pairs of query and candidate embeddings: the first row of query_embeddings denotes a query for which the candidate from the first row of candidate_embeddings was selected by the user.

The task will try to maximize the affinity of these (query, candidate) pairs while minimizing the affinity between the query and candidates belonging to other queries in the batch.
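The objective described above can be sketched as a softmax cross-entropy over a square score matrix in which the positive candidate for row i sits in column i. The function name `in_batch_softmax_losses` is hypothetical, and the sketch returns per-query losses and leaves the reduction over the batch to the caller.

```python
import math


def in_batch_softmax_losses(scores):
    """Per-query cross-entropy; the positive for row i is column i.

    Every other column in the row acts as an in-batch negative.
    """
    losses = []
    for i, row in enumerate(scores):
        # Numerically stable log-sum-exp over the row.
        m = max(row)
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        losses.append(log_z - row[i])
    return losses


# An uninformative model scores every pair equally ...
uniform = in_batch_softmax_losses([[0.0, 0.0], [0.0, 0.0]])
# ... while a well-trained one gives the diagonal pairs high affinity.
good = in_batch_softmax_losses([[10.0, 0.0], [0.0, 10.0]])
```

With two candidates per row, the uninformative model's loss is log(2) per query, and the loss shrinks toward zero as the diagonal scores pull ahead of the off-diagonal ones.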

Args
query_embeddings [num_queries, embedding_dim] tensor of query representations.
candidate_embeddings [num_queries, embedding_dim] tensor of candidate representations.
sample_weight [num_queries] tensor of sample weights.
candidate_sampling_probability Optional tensor of candidate sampling probabilities. When given, it will be used to correct the logits to reflect the sampling probability of negative candidates.
candidate_ids Optional tensor containing candidate ids. When given, it enables removing accidental hits among the examples used as negatives. An accidental hit is a candidate that is used as an in-batch negative but has the same id as the positive candidate.
compute_metrics Whether to compute metrics. Set this to False during training for faster training.
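The two optional corrections can be sketched on a plain score matrix. This is an illustration under stated assumptions rather than the library's code: the sampling correction is taken to be the common log-probability subtraction, accidental hits are masked with a large negative value, and the names `correct_for_sampling`, `mask_accidental_hits`, and `mask_value` are hypothetical.

```python
import math


def correct_for_sampling(scores, candidate_sampling_probability):
    """Subtract log(p_j) from column j (assumed form of the correction)."""
    return [
        [s - math.log(p) for s, p in zip(row, candidate_sampling_probability)]
        for row in scores
    ]


def mask_accidental_hits(scores, candidate_ids, mask_value=-1e9):
    """Mask negatives in row i that share an id with row i's positive."""
    masked = [list(row) for row in scores]
    for i, row in enumerate(masked):
        for j in range(len(row)):
            if j != i and candidate_ids[j] == candidate_ids[i]:
                row[j] = mask_value
    return masked


scores = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
# Candidates 0 and 2 are the same item, so each is an accidental hit
# for the other's query.
masked = mask_accidental_hits(scores, candidate_ids=["a", "b", "a"])
corrected = correct_for_sampling(scores, [0.5, 0.25, 0.25])
```

Popular candidates appear as in-batch negatives more often than rare ones, so subtracting the log sampling probability lowers their logits and offsets that over-representation; masking accidental hits keeps a query from being penalized for scoring its own positive item highly.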

Returns
loss Tensor of loss values.