A factorized retrieval task.
tfrs.tasks.Retrieval( loss: Optional[tf.keras.losses.Loss] = None, metrics: Optional[tfrs.metrics.FactorizedTopK] = None, temperature: Optional[float] = None, num_hard_negatives: Optional[int] = None, name: Optional[Text] = None ) -> None
Recommender systems are often composed of two components:
- a retrieval model, retrieving O(thousands) candidates from a corpus of O(millions) candidates.
- a ranker model, scoring the candidates retrieved by the retrieval model to return a ranked shortlist of a few dozen candidates.
This task defines models that facilitate efficient retrieval of candidates from large corpora by maintaining a two-tower, factorized structure: separate query and candidate representation towers, joined at the top via a lightweight scoring function.
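To make the factorized structure concrete, here is a minimal pure-Python sketch. It assumes the lightweight scoring function is a dot product (a common choice, but an assumption here). The two "towers" are hypothetical fixed lookup tables standing in for learned networks, and `dot` and `retrieve` are hypothetical helpers; none of these names come from TensorFlow Recommenders.

```python
from typing import Dict, List, Tuple

# Hypothetical towers: real models would use learned networks; fixed
# lookup tables stand in for them here.
QUERY_TOWER: Dict[str, List[float]] = {
    "user_a": [1.0, 0.0],
    "user_b": [0.0, 1.0],
}
CANDIDATE_TOWER: Dict[str, List[float]] = {
    "movie_1": [0.9, 0.1],
    "movie_2": [0.2, 0.8],
    "movie_3": [0.5, 0.5],
}


def dot(u: List[float], v: List[float]) -> float:
    """Lightweight scoring function joining the two towers (assumed dot product)."""
    return sum(a * b for a, b in zip(u, v))


def retrieve(query_id: str, k: int) -> List[Tuple[str, float]]:
    """Return the k highest-scoring (candidate_id, score) pairs for a query."""
    query_embedding = QUERY_TOWER[query_id]
    scored = [(cid, dot(query_embedding, emb)) for cid, emb in CANDIDATE_TOWER.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


print(retrieve("user_a", k=2))  # [('movie_1', 0.9), ('movie_3', 0.5)]
```

Because the candidate tower never sees the query, candidate embeddings can be computed once ahead of time, and only the query embedding has to be computed per request. That independence is what keeps retrieval from a large corpus cheap.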
|Args||
|---|---|
|`loss`|Loss function. Defaults to|
|`metrics`|Object for evaluating top-K metrics over a corpus of candidates. These metrics measure how good the model is at picking the true candidate out of all possible candidates in the system. Because the metrics range over the entire candidate set, they are usually slow to compute. Consider setting `compute_metrics=False` when calling the task during training.|
|`temperature`|Temperature of the softmax.|
|`num_hard_negatives`|If positive, the|
|`name`|Optional task name.|
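The sketch below illustrates, for a single query row, how a softmax temperature and hard-negative selection could change a cross-entropy loss. It assumes that the temperature is applied by dividing the logits by `temperature`, and that `num_hard_negatives` keeps only the highest-scoring negatives alongside the positive; the description of `num_hard_negatives` is truncated, so treat that reading as an assumption. `softmax_cross_entropy` and `row_loss` are hypothetical plain-Python helpers, not library functions.

```python
import math
from typing import List, Optional


def softmax_cross_entropy(logits: List[float], positive_index: int) -> float:
    """Return -log(softmax(logits)[positive_index]), computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[positive_index]


def row_loss(
    logits: List[float],
    positive_index: int,
    temperature: Optional[float] = None,
    num_hard_negatives: Optional[int] = None,
) -> float:
    # Assumption: the temperature divides the logits before the softmax.
    if temperature is not None:
        logits = [x / temperature for x in logits]
    # Assumption: keep the positive plus the num_hard_negatives largest negatives.
    if num_hard_negatives is not None and num_hard_negatives > 0:
        negatives = [x for i, x in enumerate(logits) if i != positive_index]
        hardest = sorted(negatives, reverse=True)[:num_hard_negatives]
        logits = [logits[positive_index]] + hardest
        positive_index = 0
    return softmax_cross_entropy(logits, positive_index)


logits = [2.0, 0.0, 1.5, -3.0]
print(row_loss(logits, 0))
print(row_loss(logits, 0, temperature=0.5))
print(row_loss(logits, 0, num_hard_negatives=1))
```

A temperature below 1 sharpens the softmax, so a correctly ranked positive yields a smaller loss. Keeping only hard negatives drops easy negatives, whose near-zero softmax probabilities contribute little gradient anyway.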
The task also exposes, as an attribute, the metrics object used to compute retrieval metrics.
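As a rough illustration of what a top-K retrieval metric measures, the following sketch computes top-K accuracy by scoring each query against every candidate in a small hypothetical corpus and checking whether the true candidate lands in the top K. It assumes dot-product scoring and breaks score ties by candidate id; `top_k_accuracy` is a hypothetical helper, not the `tfrs.metrics.FactorizedTopK` implementation.

```python
from typing import Dict, List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def top_k_accuracy(
    query_embeddings: List[List[float]],
    true_candidate_ids: List[str],
    corpus: Dict[str, List[float]],
    k: int,
) -> float:
    """Fraction of queries whose true candidate ranks in the top k of the corpus."""
    hits = 0
    for query, true_id in zip(query_embeddings, true_candidate_ids):
        # Score this query against every candidate in the corpus.
        ranked = sorted(
            ((dot(query, emb), cid) for cid, emb in corpus.items()),
            key=lambda pair: (-pair[0], pair[1]),  # highest score first, ties by id
        )
        top_ids = [cid for _, cid in ranked[:k]]
        hits += true_id in top_ids
    return hits / len(query_embeddings)


corpus = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [0.7, 0.7], "d": [-1.0, 0.0]}
queries = [[1.0, 0.0], [0.0, 1.0]]
true_ids = ["c", "d"]
print(top_k_accuracy(queries, true_ids, corpus, k=2))  # 0.5
```

The nested loop scores `num_queries × corpus size` pairs, which is why metrics of this kind are expensive when the candidate set is large.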
call( query_embeddings: tf.Tensor, candidate_embeddings: tf.Tensor, sample_weight: Optional[tf.Tensor] = None, candidate_sampling_probability: Optional[tf.Tensor] = None, candidate_ids: Optional[tf.Tensor] = None, compute_metrics: bool = True ) -> tf.Tensor
Computes the task loss and metrics.
The main arguments are pairs of query and candidate embeddings: the first row of `query_embeddings` denotes a query for which the candidate in the first row of `candidate_embeddings` was selected by the user.
The task tries to maximize the affinity of these query-candidate pairs while minimizing the affinity between each query and the candidates belonging to other queries in the batch.
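This objective can be sketched in plain Python: each query row is scored against every candidate in the batch, the diagonal entry is the positive, and the other rows' candidates act as negatives in a softmax cross-entropy. This is a minimal sketch assuming dot-product scores and a summed, per-row weighted loss; `in_batch_softmax_loss` is a hypothetical helper, not part of the library.

```python
import math
from typing import List, Optional


def in_batch_softmax_loss(
    query_embeddings: List[List[float]],
    candidate_embeddings: List[List[float]],
    sample_weight: Optional[List[float]] = None,
) -> float:
    """Summed, weighted softmax cross-entropy with in-batch negatives.

    Row i of candidate_embeddings is the positive for row i of query_embeddings;
    every other row's candidate serves as a negative for that query.
    """
    num_queries = len(query_embeddings)
    weights = sample_weight if sample_weight is not None else [1.0] * num_queries
    total = 0.0
    for i, query in enumerate(query_embeddings):
        # Row i of the [num_queries, num_queries] score matrix.
        logits = [sum(a * b for a, b in zip(query, cand)) for cand in candidate_embeddings]
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        # The positive sits on the diagonal, at column i.
        total += weights[i] * (log_sum_exp - logits[i])
    return total


queries = [[1.0, 0.0], [0.0, 1.0]]
matched = [[1.0, 0.0], [0.0, 1.0]]
mismatched = [[0.0, 1.0], [1.0, 0.0]]
print(in_batch_softmax_loss(queries, matched))
print(in_batch_softmax_loss(queries, mismatched))
```

Lowering this value pushes each diagonal score up relative to the off-diagonal scores in its row, so the matched pairs yield a smaller loss than the mismatched ones.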
|Args||
|---|---|
|`query_embeddings`|`[num_queries, embedding_dim]` tensor of query representations.|
|`candidate_embeddings`|`[num_queries, embedding_dim]` tensor of candidate representations.|
|`sample_weight`|`[num_queries]` tensor of sample weights.|
|`candidate_sampling_probability`|Optional tensor of candidate sampling probabilities. When given, it is used to correct the logits to reflect the sampling probability of negative candidates.|
|`candidate_ids`|Optional tensor containing candidate ids. When given, it enables removing accidental hits of examples used as negatives. An accidental hit is a candidate that is used as an in-batch negative but has the same id as the positive candidate.|
|`compute_metrics`|Whether to compute metrics. Set this to `False` during training to speed it up.|
|Returns|
|---|
|Tensor of loss values.|
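The two optional corrections described for `candidate_sampling_probability` and `candidate_ids` can be illustrated on one row of logits. The sketch assumes the sampling correction subtracts the log of each candidate's sampling probability (the usual "logQ" correction) and that accidental hits are removed by setting their logits to negative infinity so they receive zero softmax weight. Both are assumptions about the mechanics, and `corrected_logits_row` is a hypothetical helper.

```python
import math
from typing import List, Optional, Sequence


def corrected_logits_row(
    row_index: int,
    logits: List[float],
    candidate_sampling_probability: Optional[Sequence[float]] = None,
    candidate_ids: Optional[Sequence[str]] = None,
) -> List[float]:
    """Apply the optional corrections to one row of in-batch logits.

    The positive for this row is assumed to sit at column row_index.
    """
    out = list(logits)
    if candidate_sampling_probability is not None:
        # Assumed logQ correction: subtract log(p) of each candidate column.
        out = [x - math.log(p) for x, p in zip(out, candidate_sampling_probability)]
    if candidate_ids is not None:
        positive_id = candidate_ids[row_index]
        # Mask negatives that share the positive's id (accidental hits);
        # math.exp(-math.inf) is 0.0, so they drop out of the softmax.
        out = [
            -math.inf if j != row_index and cid == positive_id else x
            for j, (x, cid) in enumerate(zip(out, candidate_ids))
        ]
    return out


print(corrected_logits_row(0, [1.0, 1.0, 1.0], candidate_ids=["x", "y", "x"]))  # [1.0, 1.0, -inf]
# The rarer candidate (p = 0.25) ends up with the larger corrected logit.
print(corrected_logits_row(0, [0.0, 0.0], candidate_sampling_probability=[0.5, 0.25]))
```

Subtracting `log p` boosts rarely sampled candidates and dampens popular ones, which offsets the fact that popular items appear as in-batch negatives more often than their share of the corpus would suggest.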