ScaNN approximate retrieval index for a factorized retrieval model.
Inherits From: TopK
tfrs.layers.factorized_top_k.ScaNN(
    query_model: Optional[tf.keras.Model] = None,
    k: int = 10,
    distance_measure: Text = 'dot_product',
    num_leaves: int = 100,
    num_leaves_to_search: int = 10,
    dimensions_per_block: int = 2,
    num_reordering_candidates: Optional[int] = None,
    parallelize_batch_searches: bool = True,
    name: Optional[Text] = None
)
This layer uses the state-of-the-art ScaNN library to retrieve the best candidates for a given query.
Args:
  query_model: Optional Keras model for representing queries. If provided,
    it will be used to transform raw features into query embeddings when
    querying the layer. If not provided, the layer will expect to be given
    pre-computed query embeddings as inputs.
  k: Default number of results to retrieve. Can be overridden in call.
  distance_measure: Distance metric to use.
  num_leaves: Number of leaves.
  num_leaves_to_search: Number of leaves to search.
  dimensions_per_block: Controls the dataset compression ratio. A higher
    number results in greater compression, leading to faster scoring and a
    smaller index at the cost of accuracy.
  num_reordering_candidates: If set, the index will perform a final
    refinement pass on num_reordering_candidates candidates after
    retrieving an initial set of neighbours. This improves accuracy, but
    requires the original candidate representations to be kept, and so
    increases the final model size.
  parallelize_batch_searches: Whether batch queries should be executed in
    parallel.
  name: Name of the layer.
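As a rough, back-of-the-envelope illustration of the dimensions_per_block trade-off (assuming one quantization code of one byte per block, a typical product-quantization layout; the exact on-disk format is a ScaNN implementation detail):

```python
# Hypothetical sketch: product quantization stores one small code per block,
# so packing more dimensions into each block yields fewer blocks and a
# smaller compressed representation.

def compressed_bytes(dim, dimensions_per_block, bytes_per_block=1):
    # Assumes dim is divisible by dimensions_per_block, one byte per code.
    return (dim // dimensions_per_block) * bytes_per_block

raw_bytes = 128 * 4  # a 128-dim float32 embedding: 512 bytes
print(raw_bytes, compressed_bytes(128, 2), compressed_bytes(128, 4))
# -> 512 64 32
```

Under these assumptions, dimensions_per_block=2 shrinks a 128-dimensional float32 embedding roughly 8x, and doubling the block size halves the compressed size again, with a corresponding accuracy cost.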
Raises:
  ImportError: if the scann library is not installed.
Methods
call
call(
    queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
    k: Optional[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]
Query the index.
Args:
  queries: Query features. If query_model was provided in the constructor,
    these can be raw query features that will be processed by the query
    model before performing retrieval. If query_model was not provided,
    these should be pre-computed query embeddings.
  k: The number of candidates to retrieve. Defaults to the constructor's k
    parameter if not supplied.
Returns:
  Tuple of (top candidate scores, top candidate identifiers).
Raises:
  ValueError: if index has not been called.
  ValueError: if queries is not a tensor (after being passed through the
    query model) or is not rank 2.
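The contract above can be sketched in plain Python as exhaustive dot-product scoring over the indexed candidates (ScaNN approximates this search with tree partitioning and quantization; the helper names below are hypothetical, not part of TFRS):

```python
# Conceptual sketch of querying a retrieval index: score every candidate
# against the query by dot product, keep the k highest, and return the
# (scores, identifiers) pair that the layer's call method also returns.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def top_k(query, candidates, identifiers, k=2):
    """Return (scores, ids) of the k best candidates for one query."""
    scored = sorted(
        ((dot(query, c), i) for c, i in zip(candidates, identifiers)),
        reverse=True,
    )[:k]
    return [s for s, _ in scored], [i for _, i in scored]

query = [1.0, 0.0]
candidates = [[0.9, 0.1], [0.1, 0.9], [0.7, 0.3]]
ids = ["a", "b", "c"]
print(top_k(query, candidates, ids, k=2))  # -> ([0.9, 0.7], ['a', 'c'])
```

Unlike this exhaustive scan, ScaNN only scores candidates in the leaves it searches, which is what makes retrieval approximate but fast.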
index
index(
    candidates: Union[tf.Tensor, tf.data.Dataset],
    identifiers: Optional[Union[tf.Tensor, tf.data.Dataset]] = None
) -> "ScaNN"
Builds the retrieval index.
When called multiple times, the existing index will be dropped and a new one
created.
Args:
  candidates: Matrix (or dataset) of candidate embeddings.
  identifiers: Optional tensor (or dataset) of candidate identifiers. If
    given, these will be used as the identifiers of the top candidates
    returned when performing searches. If not given, indices into the
    candidates tensor will be returned instead.
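The identifier fallback described above can be sketched in pure Python (hypothetical helper, not the TFRS implementation): when no identifiers are supplied, row indices into the candidate matrix stand in for them.

```python
def resolve_identifiers(candidates, identifiers=None):
    """Pair each candidate with an identifier, defaulting to its row index."""
    if identifiers is None:
        # No identifiers given: fall back to positional indices.
        return list(range(len(candidates)))
    if len(identifiers) != len(candidates):
        raise ValueError("identifiers must match candidates in length")
    return list(identifiers)

print(resolve_identifiers([[0.1], [0.2], [0.3]]))       # -> [0, 1, 2]
print(resolve_identifiers([[0.1], [0.2]], ["a", "b"]))  # -> ['a', 'b']
```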
query_with_exclusions
@tf.function
query_with_exclusions(
    queries: Union[tf.Tensor, Dict[Text, tf.Tensor]],
    exclusions: tf.Tensor,
    k: Optional[int] = None
) -> Tuple[tf.Tensor, tf.Tensor]
Query the index.
Args:
  queries: Query features. If query_model was provided in the constructor,
    these can be raw query features that will be processed by the query
    model before performing retrieval. If query_model was not provided,
    these should be pre-computed query embeddings.
  exclusions: [query_batch_size, num_to_exclude] tensor of identifiers to
    be excluded from the top-k calculation. This is most commonly used to
    exclude previously seen candidates from retrieval. For example, if a
    user has already seen items with ids "42" and "43", you could set
    exclusions to [["42", "43"]].
  k: The number of candidates to retrieve. Defaults to the constructor's k
    parameter if not supplied.
Returns:
  Tuple of (top candidate scores, top candidate identifiers).
Raises:
  ValueError: if index has not been called.
  ValueError: if queries is not a tensor (after being passed through the
    query model).
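Conceptually, the exclusion step masks the scores of excluded identifiers before taking the top-k, as in this pure-Python sketch (hypothetical helper, not how ScaNN implements it):

```python
def top_k_with_exclusions(query, candidates, identifiers, exclusions, k=2):
    """Score all candidates, mask out excluded ids, then keep the top-k."""
    excluded = set(exclusions)
    scored = []
    for cand, ident in zip(candidates, identifiers):
        score = sum(a * b for a, b in zip(query, cand))
        if ident in excluded:
            score = float("-inf")  # excluded items can never be selected
        scored.append((score, ident))
    scored.sort(reverse=True)
    top = scored[:k]
    return [s for s, _ in top], [i for _, i in top]

query = [1.0, 0.0]
candidates = [[0.9, 0.1], [0.1, 0.9], [0.7, 0.3]]
ids = ["42", "43", "44"]
# "42" is the best match, but the user has already seen it.
print(top_k_with_exclusions(query, candidates, ids, exclusions=["42"], k=2))
# -> ([0.7, 0.1], ['44', '43'])
```

This mirrors the common use case above: previously seen items drop out of the results, and the next-best candidates take their place.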