Efficient serving


Retrieval models are often built to surface a handful of top candidates out of millions or even hundreds of millions of candidates. To react to the user's context and behaviour, they need to do this on the fly, in a matter of milliseconds.

Approximate nearest neighbour search (ANN) is the technology that makes this possible. In this tutorial, we'll show how to use ScaNN, a state-of-the-art nearest neighbour retrieval package, to seamlessly scale TFRS retrieval to millions of items.

What is ScaNN?

ScaNN is a library from Google Research that performs dense vector similarity search at large scale. Given a database of candidate embeddings, ScaNN indexes these embeddings in a manner that allows them to be rapidly searched at inference time. ScaNN uses state-of-the-art vector compression techniques and carefully implemented algorithms to achieve the best speed-accuracy tradeoff. It can greatly outperform brute-force search while sacrificing little in terms of accuracy.

Building a ScaNN-powered model

To try out ScaNN in TFRS, we'll build a simple MovieLens retrieval model, just as we did in the basic retrieval tutorial. If you have followed that tutorial, this section will be familiar and can safely be skipped.

To start, install TFRS and TensorFlow Datasets:

pip install -q tensorflow-recommenders
pip install -q --upgrade tensorflow-datasets

We also need to install scann: it's an optional dependency of TFRS, and so needs to be installed separately.

pip install -q scann

Set up all the necessary imports.

from typing import Dict, Text

import os
import pprint
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs

And load the data:

# Load the MovieLens 100K data.
ratings = tfds.load(
    "movielens/100k-ratings",
    split="train"
)

# Get the ratings data.
ratings = (ratings
           # Retain only the fields we need.
           .map(lambda x: {"user_id": x["user_id"], "movie_title": x["movie_title"]})
           # Cache for efficiency.
           .cache(tempfile.NamedTemporaryFile().name)
)

# Get the movies data.
movies = tfds.load("movielens/100k-movies", split="train")
movies = (movies
          # Retain only the fields we need.
          .map(lambda x: x["movie_title"])
          # Cache for efficiency.
          .cache(tempfile.NamedTemporaryFile().name))
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f6e0c4e1620> and will run it as-is.
Cause: could not parse the source code:

           .map(lambda x: {"user_id": x["user_id"], "movie_title": x["movie_title"]})

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f6e08333d90> and will run it as-is.
Cause: could not parse the source code:

          .map(lambda x: x["movie_title"])

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Before we can build a model, we need to set up the user and movie vocabularies:

user_ids = ratings.map(lambda x: x["user_id"])

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(user_ids.batch(1000))))

We'll also set up the training and test sets:

tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

Model definition

Just as in the basic retrieval tutorial, we build a simple two-tower model.

class MovielensModel(tfrs.Model):

  def __init__(self):
    super().__init__()

    embedding_dimension = 32

    # Set up a model for representing movies.
    self.movie_model = tf.keras.Sequential([
      tf.keras.layers.experimental.preprocessing.StringLookup(
        vocabulary=unique_movie_titles, mask_token=None),
      # We add an additional embedding to account for unknown tokens.
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, embedding_dimension)
    ])

    # Set up a model for representing users.
    self.user_model = tf.keras.Sequential([
      tf.keras.layers.experimental.preprocessing.StringLookup(
        vocabulary=unique_user_ids, mask_token=None),
        # We add an additional embedding to account for unknown tokens.
      tf.keras.layers.Embedding(len(unique_user_ids) + 1, embedding_dimension)
    ])

    # Set up a task to optimize the model and compute metrics.
    self.task = tfrs.tasks.Retrieval(
      metrics=tfrs.metrics.FactorizedTopK(
        candidates=movies.batch(128).cache().map(self.movie_model)
      )
    )

  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    # We pick out the user features and pass them into the user model.
    user_embeddings = self.user_model(features["user_id"])
    # And pick out the movie features and pass them into the movie model,
    # getting embeddings back.
    positive_movie_embeddings = self.movie_model(features["movie_title"])

    # The task computes the loss and the metrics. Note that we skip metric
    # computation during training (compute_metrics=not training) to keep
    # training fast; this is why the top-k metrics read as zero in the fit logs.
    return self.task(user_embeddings, positive_movie_embeddings, compute_metrics=not training)

Fitting and evaluation

A TFRS model is just a Keras model. We can compile it:

model = MovielensModel()
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

Estimate it:

model.fit(train.batch(8192), epochs=3)
Epoch 1/3
WARNING:tensorflow:The dtype of the source tensor must be floating (e.g. tf.float32) when calling GradientTape.gradient, got tf.int32

WARNING:tensorflow:Gradients do not exist for variables ['counter:0'] when minimizing the loss.

10/10 [==============================] - 1s 149ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 69808.9688 - regularization_loss: 0.0000e+00 - total_loss: 69808.9688
Epoch 2/3
10/10 [==============================] - 1s 149ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 67485.8828 - regularization_loss: 0.0000e+00 - total_loss: 67485.8828
Epoch 3/3
10/10 [==============================] - 1s 149ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 66311.9581 - regularization_loss: 0.0000e+00 - total_loss: 66311.9581

<tensorflow.python.keras.callbacks.History at 0x7f6e0806aeb8>

And evaluate it.

model.evaluate(test.batch(8192), return_dict=True)
3/3 [==============================] - 0s 104ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0012 - factorized_top_k/top_5_categorical_accuracy: 0.0095 - factorized_top_k/top_10_categorical_accuracy: 0.0222 - factorized_top_k/top_50_categorical_accuracy: 0.1260 - factorized_top_k/top_100_categorical_accuracy: 0.2363 - loss: 49466.8779 - regularization_loss: 0.0000e+00 - total_loss: 49466.8779

{'factorized_top_k/top_1_categorical_accuracy': 0.0012000000569969416,
 'factorized_top_k/top_5_categorical_accuracy': 0.009549999609589577,
 'factorized_top_k/top_10_categorical_accuracy': 0.022199999541044235,
 'factorized_top_k/top_50_categorical_accuracy': 0.12604999542236328,
 'factorized_top_k/top_100_categorical_accuracy': 0.23634999990463257,
 'loss': 28242.833984375,
 'regularization_loss': 0,
 'total_loss': 28242.833984375}

Approximate prediction

The most straightforward way of retrieving top candidates in response to a query is to do it via brute force: compute user-movie scores for all possible movies, sort them, and return the top few as recommendations.

In TFRS, this is accomplished via the BruteForce layer:

brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
brute_force.index(movies.batch(128).map(model.movie_model), movies)
<tensorflow_recommenders.layers.factorized_top_k.BruteForce at 0x7f6e0806add8>

Once created and populated with candidates (via the index method), we can call it to get predictions out:

# Get predictions for user 42.
_, titles = brute_force(np.array(["42"]), k=3)

print(f"Top recommendations: {titles[0]}")
Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)'
 b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

On a small dataset of only about 1,700 movies, this is very fast:

%timeit _, titles = brute_force(np.array(["42"]), k=3)
787 µs ± 18.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

But what happens if we have more candidates - millions instead of thousands?

We can simulate this by indexing all of our movies multiple times:

# Construct a dataset of movies that's 1,000 times larger. We
# do this by adding well over a million dummy movie titles to the dataset.
lots_of_movies = tf.data.Dataset.concatenate(
    movies.batch(4096),
    movies.batch(4096).repeat(1_000).map(lambda x: tf.zeros_like(x))
)

# We also add lots of dummy embeddings by randomly perturbing
# the estimated embeddings for real movies.
lots_of_movies_embeddings = tf.data.Dataset.concatenate(
    movies.batch(4096).map(model.movie_model),
    movies.batch(4096).repeat(1_000)
      .map(lambda x: model.movie_model(x))
      .map(lambda x: x * tf.random.uniform(tf.shape(x)))
)
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f6e0821a0d0> and will run it as-is.
Cause: could not parse the source code:

      .map(lambda x: model.movie_model(x))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f6e0821a048> and will run it as-is.
Cause: could not parse the source code:

      .map(lambda x: x * tf.random.uniform(tf.shape(x)))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

We can build a BruteForce index on this larger dataset:

brute_force_lots = tfrs.layers.factorized_top_k.BruteForce()
brute_force_lots.index(lots_of_movies_embeddings, lots_of_movies)
<tensorflow_recommenders.layers.factorized_top_k.BruteForce at 0x7f6d20543a90>

The recommendations are still the same:

_, titles = brute_force_lots(model.user_model(np.array(["42"])), k=3)

print(f"Top recommendations: {titles[0]}")
Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)'
 b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

But they take much longer. With a candidate set of roughly 1.7 million movies, brute-force prediction becomes quite slow:

%timeit _, titles = brute_force_lots(model.user_model(np.array(["42"])), k=3)
29.5 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

As the number of candidates grows, the time needed grows linearly: with 10 million candidates, serving top candidates would take on the order of 200 milliseconds. This is clearly too slow for a live service.
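
The linear scaling is easy to see: brute force scores the query against every single candidate embedding, which amounts to one large matrix-vector product per query. Here is a minimal NumPy sketch with made-up random embeddings (the sizes are purely illustrative) showing where the cost comes from:

import numpy as np

num_candidates, dim = 1_000_000, 32

# Made-up embeddings purely for illustration.
candidate_embeddings = np.random.randn(num_candidates, dim).astype(np.float32)
query_embedding = np.random.randn(dim).astype(np.float32)

# One brute-force query is a full matrix-vector product:
# O(num_candidates * dim) work, so doubling the candidates doubles the cost.
scores = candidate_embeddings @ query_embedding

# Selecting the top 3 scores is cheap by comparison.
top_3 = np.argpartition(-scores, 3)[:3]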

This is where approximate mechanisms come in.

Using ScaNN in TFRS is accomplished via the tfrs.layers.factorized_top_k.ScaNN layer. It follows the same interface as the other top-k layers:

scann = tfrs.layers.factorized_top_k.ScaNN(num_reordering_candidates=1000)
scann.index(lots_of_movies_embeddings, lots_of_movies)
<tensorflow_recommenders.layers.factorized_top_k.ScaNN at 0x7f6d2061e0f0>

The recommendations are (approximately!) the same:

_, titles = scann(model.user_model(np.array(["42"])), k=3)

print(f"Top recommendations: {titles[0]}")
Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)'
 b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

But they are much, much faster to compute:

%timeit _, titles = scann(model.user_model(np.array(["42"])), k=3)
3.17 ms ± 9.62 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In this case, we can retrieve the top 3 movies out of a set of roughly 1.7 million in around 3 milliseconds: roughly 9 times faster than computing the best candidates via brute force. The advantage of approximate methods grows even larger for larger datasets.

Evaluating the approximation

When using approximate top K retrieval mechanisms (such as ScaNN), speed of retrieval often comes at the expense of accuracy. To understand this trade-off, it's important to measure the model's evaluation metrics when using ScaNN, and to compare them with the baseline.

Fortunately, TFRS makes this easy. We simply override the metrics on the retrieval task with metrics using ScaNN, re-compile the model, and run evaluation.

To make the comparison, let's first compute baseline results. We still need to override our metrics to make sure they use the enlarged candidate set rather than the original set of movies:

# Override the existing streaming candidate source.
model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
    candidates=lots_of_movies_embeddings
)
# Need to recompile the model for the changes to take effect.
model.compile()

%time baseline_result = model.evaluate(test.batch(8192), return_dict=True, verbose=False)
CPU times: user 21min 12s, sys: 2min 29s, total: 23min 42s
Wall time: 47.4 s

We can do the same using ScaNN:

model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
    candidates=scann
)
model.compile()

# ScaNN evaluation is much more memory-efficient, so we could afford an
# even bigger batch size here.
%time scann_result = model.evaluate(test.batch(8192), return_dict=True, verbose=False)
CPU times: user 13 s, sys: 4.37 s, total: 17.3 s
Wall time: 1.94 s

ScaNN-based evaluation is much, much quicker: it's over twenty times faster! This advantage will grow even larger for bigger datasets, so for large datasets it may be prudent to always run ScaNN-based evaluation to improve model development velocity.

But how about the results? Fortunately, in this case the results are almost the same:

print(f"Brute force top-100 accuracy: {baseline_result['factorized_top_k/top_100_categorical_accuracy']:.2f}")
print(f"ScaNN top-100 accuracy:       {scann_result['factorized_top_k/top_100_categorical_accuracy']:.2f}")
Brute force top-100 accuracy: 0.15
ScaNN top-100 accuracy:       0.16

This suggests that on this artificial dataset, there is little loss from the approximation. In general, all approximate methods exhibit speed-accuracy tradeoffs. To understand this in more depth, you can check out Erik Bernhardsson's ANN benchmarks.

Deploying the approximate model

The ScaNN-based model is fully integrated into TensorFlow models, and serving it is as easy as serving any other TensorFlow model.

We can save it as a SavedModel object:

# We re-index the ScaNN layer to include the user embeddings in the same model.
# This way we can give the saved model raw features and get valid predictions
# back.
scann = tfrs.layers.factorized_top_k.ScaNN(model.user_model, num_reordering_candidates=1000)
scann.index(lots_of_movies_embeddings, lots_of_movies)

# Need to call it to set the shapes.
_ = scann(np.array(["42"]))

with tempfile.TemporaryDirectory() as tmp:
  path = os.path.join(tmp, "model")
  scann.save(
    path,
    options=tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
  )

  loaded = tf.keras.models.load_model(path)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Assets written to: /tmp/tmpo1ur88if/model/assets

WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

and then load it and serve, getting exactly the same results back:

_, titles = loaded(tf.constant(["42"]))

print(f"Top recommendations: {titles[0][:3]}")
Top recommendations: [b'Homeward Bound: The Incredible Journey (1993)'
 b"Kid in King Arthur's Court, A (1995)" b'Rudy (1993)']

Tuning ScaNN

Now let's look into tuning our ScaNN layer to get a better performance/accuracy tradeoff. In order to do this effectively, we first need to measure our baseline performance and accuracy.

From above, we already have a measurement of our model's latency for processing a single (non-batched) query (although note that a fair amount of this latency is from non-ScaNN components of the model).

Now we need to investigate ScaNN's accuracy, which we measure through recall. A recall@k of x% means that, of the true top-k neighbours retrieved by brute force, x% also appear in ScaNN's top-k results. For example, if ScaNN returns 9 of the 10 true nearest neighbours for a query, its recall@10 on that query is 90%. Let's compute the recall for the current ScaNN searcher.

First, we need to generate the brute force, ground truth top-k:

# Process queries in groups of 1000; processing them all at once with brute force
# may lead to out-of-memory errors, because processing a batch of q queries against
# a size-n dataset takes O(nq) space with brute force.
titles_ground_truth = tf.concat([
  brute_force_lots(queries, k=10)[1] for queries in
  test.batch(1000).map(lambda x: model.user_model(x["user_id"]))
], axis=0)

Our variable titles_ground_truth now contains the top-10 movie recommendations returned by brute-force retrieval. Now we can compute the same recommendations when using ScaNN:

# Get all user_id's as a 1d tensor of strings
test_flat = np.concatenate(list(test.map(lambda x: x["user_id"]).batch(1000).as_numpy_iterator()), axis=0)

# ScaNN is much more memory efficient and has no problem processing the whole
# batch of 20000 queries at once.
_, titles = scann(test_flat, k=10)

Next, we define our function that computes recall. For each query, it counts how many results are in the intersection of the brute force and the ScaNN results and divides this by the number of brute force results. The average of this quantity over all queries is our recall.

def compute_recall(ground_truth, approx_results):
  return np.mean([
      len(np.intersect1d(truth, approx)) / len(truth)
      for truth, approx in zip(ground_truth, approx_results)
  ])

This gives us baseline recall@10 with the current ScaNN config:

print(f"Recall: {compute_recall(titles_ground_truth, titles):.3f}")
Recall: 0.920

We can also measure the baseline latency:

%timeit -n 1000 scann(np.array(["42"]), k=10)
3.21 ms ± 8.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Let's see if we can do better!

To do this, we need a model of how ScaNN's tuning knobs affect performance. Our current model uses ScaNN's tree-AH algorithm. This algorithm partitions the database of embeddings (the "tree") and then scores the most promising of these partitions using AH (asymmetric hashing), a highly optimized approximate distance computation routine.
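
As a rough mental model only (plain exhaustive scoring within partitions, not ScaNN's actual AH routine), the tree stage can be sketched in NumPy: assign every candidate to a partition up front, then at query time score only the candidates in the most promising partitions:

import numpy as np

num_leaves, num_leaves_to_search, dim = 100, 10, 32

# Made-up embeddings purely for illustration.
candidates = np.random.randn(100_000, dim).astype(np.float32)

# "Tree" stage: pick centroids and assign each candidate to its closest one.
# (ScaNN learns its partitions; random centroids are just a stand-in.)
centroids = candidates[np.random.choice(len(candidates), num_leaves, replace=False)]
assignments = np.argmax(candidates @ centroids.T, axis=1)

def search(query, k=3):
  # Keep only the partitions whose centroids score highest against the query.
  leaf_scores = centroids @ query
  best_leaves = np.argpartition(-leaf_scores, num_leaves_to_search)[:num_leaves_to_search]
  # Score just the candidates in those partitions; ScaNN replaces this
  # exact scoring with its much faster approximate AH computation.
  candidate_ids = np.nonzero(np.isin(assignments, best_leaves))[0]
  scores = candidates[candidate_ids] @ query
  return candidate_ids[np.argpartition(-scores, k)[:k]]

print(search(np.random.randn(dim).astype(np.float32)))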

By default, TensorFlow Recommenders' ScaNN Keras layer sets num_leaves=100 and num_leaves_to_search=10. This means our database is partitioned into 100 disjoint subsets, and the 10 most promising of these partitions are scored with AH. In other words, 10/100 = 10% of the dataset is searched with AH.

If we instead set num_leaves=1000 and num_leaves_to_search=100, we would also be searching 10% of the database with AH. However, compared to the previous setting, the 10% we search will contain higher-quality candidates, because a higher num_leaves allows us to make finer-grained decisions about which parts of the dataset are worth searching.

It's no surprise then that with num_leaves=1000 and num_leaves_to_search=100 we get significantly higher recall:

scann2 = tfrs.layers.factorized_top_k.ScaNN(
    model.user_model, 
    num_leaves=1000,
    num_leaves_to_search=100,
    num_reordering_candidates=1000)
scann2.index(lots_of_movies_embeddings, lots_of_movies)

_, titles2 = scann2(test_flat, k=10)

print(f"Recall: {compute_recall(titles_ground_truth, titles2):.3f}")
Recall: 0.965

However, as a tradeoff, our latency has also increased. This is because the partitioning step has become more expensive: scann picks the top 10 of 100 partitions, while scann2 picks the top 100 of 1,000 partitions, which requires scoring 10 times as many partition centroids.

%timeit -n 1000 scann2(np.array(["42"]), k=10)
3.37 ms ± 44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In general, tuning ScaNN search is about picking the right tradeoffs. Each individual parameter change generally won't make search both faster and more accurate; our goal is to tune the parameters to optimally trade off between these two conflicting goals.

In our case, scann2 significantly improved recall over scann at some cost in latency. Can we dial back some other knobs to cut down on latency, while preserving most of our recall advantage?

Let's try searching 70/1000=7% of the dataset with AH, and only rescoring the final 400 candidates:

scann3 = tfrs.layers.factorized_top_k.ScaNN(
    model.user_model,
    num_leaves=1000,
    num_leaves_to_search=70,
    num_reordering_candidates=400)
scann3.index(lots_of_movies_embeddings, lots_of_movies)

_, titles3 = scann3(test_flat, k=10)
print(f"Recall: {compute_recall(titles_ground_truth, titles3):.3f}")
Recall: 0.960

scann3 delivers about a 4% absolute recall gain over scann while also delivering lower latency:

%timeit -n 1000 scann3(np.array(["42"]), k=10)
3.19 ms ± 8.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

These knobs can be further adjusted to optimize for different points along the accuracy-performance Pareto frontier. ScaNN's algorithms can achieve state-of-the-art performance over a wide range of recall targets.
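
One simple way to explore that frontier is to sweep a small grid of configurations, measuring recall and latency for each. The sketch below reuses the objects defined earlier in this tutorial; the particular grid of settings is just an illustrative choice:

import time

# An illustrative grid of settings; in practice, pick a grid that brackets
# your latency budget.
configs = [
    {"num_leaves": 1000, "num_leaves_to_search": 50, "num_reordering_candidates": 200},
    {"num_leaves": 1000, "num_leaves_to_search": 70, "num_reordering_candidates": 400},
    {"num_leaves": 1000, "num_leaves_to_search": 100, "num_reordering_candidates": 1000},
]

for config in configs:
  searcher = tfrs.layers.factorized_top_k.ScaNN(model.user_model, **config)
  searcher.index(lots_of_movies_embeddings, lots_of_movies)

  # Recall against the brute-force ground truth computed above.
  _, approx_titles = searcher(test_flat, k=10)
  recall = compute_recall(titles_ground_truth, approx_titles)

  # Rough single-query latency.
  start = time.time()
  for _ in range(100):
    searcher(np.array(["42"]), k=10)
  latency_ms = (time.time() - start) / 100 * 1000

  print(f"{config}: recall@10={recall:.3f}, latency={latency_ms:.2f}ms")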

Further reading

ScaNN uses advanced vector quantization techniques and a highly optimized implementation to achieve its results. The field of vector quantization has a rich history, with a variety of approaches. ScaNN's current quantization technique is detailed in this paper, published at ICML 2020. The paper was also released along with this blog article, which gives a high-level overview of our technique.

Many related quantization techniques are mentioned in the references of our ICML 2020 paper, and other ScaNN-related research is listed at http://sanjivk.com/.