In this tutorial, we build a simple two-tower ranking model on the MovieLens 100K dataset with TF-Ranking. We can use this model to rank and recommend movies for a given user according to their predicted user ratings.
Setup
Install and import the TF-Ranking library:
pip install -q tensorflow-ranking
pip install -q --upgrade tensorflow-datasets
from typing import Dict, Tuple
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_ranking as tfr
Read the data
Prepare to train a model by creating a ratings dataset and a movies dataset. Use user_id as the query input feature, movie_title as the document input feature, and user_rating as the label to train the ranking model.
%%capture --no-display
# Ratings data.
ratings = tfds.load('movielens/100k-ratings', split="train")
# Features of all the available movies.
movies = tfds.load('movielens/100k-movies', split="train")

# Select the basic features.
ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "user_rating": x["user_rating"]
})
Build vocabularies to convert all user ids and all movie titles into integer indices for embedding layers:
movies = movies.map(lambda x: x["movie_title"])
users = ratings.map(lambda x: x["user_id"])
user_ids_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
user_ids_vocabulary.adapt(users.batch(1000))

movie_titles_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
movie_titles_vocabulary.adapt(movies.batch(1000))
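As a quick, purely illustrative check, an adapted StringLookup layer maps raw string ids to integer indices; with mask_token=None and the default single OOV bucket, index 0 is reserved for unseen values (the ids below are sample values from the dataset):

# Illustrative only: map a couple of raw user ids to vocabulary indices.
print(user_ids_vocabulary(tf.constant(["42", "405"])))
# Unseen ids fall into the out-of-vocabulary bucket at index 0.
print(user_ids_vocabulary(tf.constant(["not-a-user"])))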
Group by user_id to form lists for ranking models:
key_func = lambda x: user_ids_vocabulary(x["user_id"])
reduce_func = lambda key, dataset: dataset.batch(100)

ds_train = ratings.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=100)
for x in ds_train.take(1):
  for key, value in x.items():
    print(f"Shape of {key}: {value.shape}")
    print(f"Example values of {key}: {value[:5].numpy()}")
    print()
Shape of movie_title: (100,)
Example values of movie_title: [b'Man Who Would Be King, The (1975)'
 b'Silence of the Lambs, The (1991)' b'Next Karate Kid, The (1994)'
 b'2001: A Space Odyssey (1968)' b'Usual Suspects, The (1995)']

Shape of user_id: (100,)
Example values of user_id: [b'405' b'405' b'405' b'405' b'405']

Shape of user_rating: (100,)
Example values of user_rating: [1. 4. 1. 5. 5.]
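If group_by_window is unfamiliar, the following toy sketch (using a small hypothetical dataset, not MovieLens) shows the mechanics: elements that share a key are collected into windows of up to window_size elements, and reduce_func turns each window into one batched element:

# Illustrative only: group a tiny hypothetical dataset by user.
toy = tf.data.Dataset.from_tensor_slices(
    {"user": [1, 2, 1, 1, 2], "rating": [5.0, 3.0, 4.0, 2.0, 1.0]})
toy_grouped = toy.group_by_window(
    key_func=lambda x: tf.cast(x["user"], tf.int64),
    reduce_func=lambda key, ds: ds.batch(3),
    window_size=3)
for elem in toy_grouped:
  print(elem)  # One element per user: that user's ratings batched together.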
Generate batched features and labels:
def _features_and_labels(
    x: Dict[str, tf.Tensor]) -> Tuple[Dict[str, tf.Tensor], tf.Tensor]:
  labels = x.pop("user_rating")
  return x, labels

ds_train = ds_train.map(_features_and_labels)

ds_train = ds_train.apply(
    tf.data.experimental.dense_to_ragged_batch(batch_size=32))
The user_id and movie_title tensors generated in ds_train have shape [32, None], where the second dimension is 100 in most cases, except for batches where fewer than 100 items were grouped into a list. A model that works on ragged tensors is therefore used.
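As a minimal illustration of such a ragged batch (with toy values unrelated to MovieLens):

# Illustrative only: a ragged batch holds lists of different lengths.
rt = tf.ragged.constant([[5.0, 3.0, 1.0], [4.0, 2.0]])
print(rt.shape)          # (2, None): two lists with different lengths.
print(rt.row_lengths())  # [3 2]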
for x, label in ds_train.take(1):
  for key, value in x.items():
    print(f"Shape of {key}: {value.shape}")
    print(f"Example values of {key}: {value[:3, :3].numpy()}")
    print()
  print(f"Shape of label: {label.shape}")
  print(f"Example values of label: {label[:3, :3].numpy()}")
Shape of movie_title: (32, None)
Example values of movie_title: [[b'Man Who Would Be King, The (1975)'
  b'Silence of the Lambs, The (1991)' b'Next Karate Kid, The (1994)']
 [b'Flower of My Secret, The (Flor de mi secreto, La) (1995)'
  b'Little Princess, The (1939)' b'Time to Kill, A (1996)']
 [b'Kundun (1997)' b'Scream (1996)' b'Power 98 (1995)']]

Shape of user_id: (32, None)
Example values of user_id: [[b'405' b'405' b'405']
 [b'655' b'655' b'655']
 [b'13' b'13' b'13']]

Shape of label: (32, None)
Example values of label: [[1. 4. 1.]
 [3. 3. 3.]
 [5. 1. 1.]]
Define a model
Define a ranking model by inheriting from tf.keras.Model and implementing the call method:
class MovieLensRankingModel(tf.keras.Model):

  def __init__(self, user_vocab, movie_vocab):
    super().__init__()

    # Set up user and movie vocabulary and embedding.
    self.user_vocab = user_vocab
    self.movie_vocab = movie_vocab
    self.user_embed = tf.keras.layers.Embedding(
        user_vocab.vocabulary_size(), 64)
    self.movie_embed = tf.keras.layers.Embedding(
        movie_vocab.vocabulary_size(), 64)

  def call(self, features: Dict[str, tf.Tensor]) -> tf.Tensor:
    # Define how the ranking scores are computed:
    # Take the dot-product of the user embeddings with the movie embeddings.
    user_embeddings = self.user_embed(self.user_vocab(features["user_id"]))
    movie_embeddings = self.movie_embed(
        self.movie_vocab(features["movie_title"]))
    return tf.reduce_sum(user_embeddings * movie_embeddings, axis=2)
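If the axis=2 reduction looks opaque, here is a toy shape check (random tensors standing in for the embedding outputs, illustrative only): each embedding lookup yields a [batch_size, list_size, 64] tensor, and summing the elementwise product over the last axis leaves one dot-product score per list item.

u = tf.random.normal([2, 3, 64])  # Stand-in for user embeddings.
m = tf.random.normal([2, 3, 64])  # Stand-in for movie embeddings.
print(tf.reduce_sum(u * m, axis=2).shape)  # (2, 3): one score per item.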
Create the model, and then compile it with a ranking loss from tfr.keras.losses and ranking metrics from tfr.keras.metrics, which are the core of the TF-Ranking package.
This example uses a ranking-specific softmax loss, a listwise loss that promotes all relevant items in the ranking list, giving them a better chance of landing above the irrelevant ones. In contrast to the softmax loss used in multi-class classification, where exactly one class is positive and the rest are negative, the TF-Ranking library supports multiple relevant documents in a query list, as well as non-binary relevance labels.
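To build intuition, the following is a minimal sketch of the listwise softmax idea on a single hand-made list. It conveys the shape of the computation, not TF-Ranking's exact implementation, which also handles ragged inputs, padding, and optional lambda weights; the label and score values below are made up.

labels = tf.constant([[3.0, 1.0, 0.0]])   # Graded relevance for one list.
scores = tf.constant([[1.2, 0.3, -0.5]])  # Predicted ranking scores.

# Cross-entropy between normalized labels and the softmax of the scores:
# every relevant item, not just a single positive class, pulls its score up.
probs = labels / tf.reduce_sum(labels, axis=1, keepdims=True)
loss = -tf.reduce_sum(probs * tf.nn.log_softmax(scores, axis=1), axis=1)
print(loss.numpy())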
For ranking metrics, this example specifically uses Normalized Discounted Cumulative Gain (NDCG) and Mean Reciprocal Rank (MRR), which measure the utility of a ranked query list to the user, with position discounts. For more details about ranking metrics, review the offline metrics discussed under evaluation measures for information retrieval.
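As a hand-worked illustration (with made-up relevance values), the sketch below computes NDCG using the common exponential-gain, log2-discount form, and MRR as the reciprocal rank of the first relevant item; exact gain and discount conventions are configurable in TF-Ranking.

import numpy as np

labels = np.array([0.0, 3.0, 1.0])      # Relevance, in ranked order.
ranks = np.arange(1, len(labels) + 1)

# DCG discounts each gain by position; NDCG normalizes by the ideal order.
dcg = np.sum((2.0 ** labels - 1.0) / np.log2(ranks + 1))
ideal = np.sort(labels)[::-1]           # Best possible ordering.
idcg = np.sum((2.0 ** ideal - 1.0) / np.log2(ranks + 1))
print(f"NDCG: {dcg / idcg:.4f}")

# MRR: reciprocal rank of the first relevant (label > 0) item.
print(f"MRR: {1.0 / ranks[labels > 0][0]:.4f}")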
# Create the ranking model, trained with a ranking loss and evaluated with
# ranking metrics.
model = MovieLensRankingModel(user_ids_vocabulary, movie_titles_vocabulary)
optimizer = tf.keras.optimizers.Adagrad(0.5)
loss = tfr.keras.losses.get(
    loss=tfr.keras.losses.RankingLossKey.SOFTMAX_LOSS, ragged=True)
eval_metrics = [
    tfr.keras.metrics.get(key="ndcg", name="metric/ndcg", ragged=True),
    tfr.keras.metrics.get(key="mrr", name="metric/mrr", ragged=True)
]
model.compile(optimizer=optimizer, loss=loss, metrics=eval_metrics)
Train and evaluate the model
Train the model with model.fit.
model.fit(ds_train, epochs=3)
Epoch 1/3
48/48 [==============================] - 7s 61ms/step - loss: 998.7476 - metric/ndcg: 0.8254 - metric/mrr: 1.0000
Epoch 2/3
48/48 [==============================] - 3s 55ms/step - loss: 997.0419 - metric/ndcg: 0.9166 - metric/mrr: 1.0000
Epoch 3/3
48/48 [==============================] - 3s 57ms/step - loss: 994.8353 - metric/ndcg: 0.9388 - metric/mrr: 1.0000
<keras.callbacks.History at 0x7f53fa36aac0>
Generate predictions and evaluate.
# Get movie title candidate list.
for movie_titles in movies.batch(2000):
  break

# Generate the input for user 42.
inputs = {
    "user_id":
        tf.expand_dims(tf.repeat("42", repeats=movie_titles.shape[0]), axis=0),
    "movie_title":
        tf.expand_dims(movie_titles, axis=0)
}

# Get movie recommendations for user 42.
scores = model(inputs)
titles = tfr.utils.sort_by_scores(
    scores, [tf.expand_dims(movie_titles, axis=0)])[0]
print(f"Top 5 recommendations for user 42: {titles[0, :5]}")
Top 5 recommendations for user 42: [b'Air Force One (1997)' b'Star Wars (1977)' b'Titanic (1997)' b'Raiders of the Lost Ark (1981)' b'Rock, The (1996)']