Taking advantage of context features


In the featurization tutorial we incorporated multiple features beyond just user and movie identifiers into our models, but we haven't explored whether those features improve model accuracy.

Many factors affect whether features beyond ids are useful in a recommender model:

  1. Importance of context: if user preferences are relatively stable across contexts and time, context features may not provide much benefit. If, however, users' preferences are highly contextual, adding context features will improve the model significantly. For example, day of the week may be an important feature when deciding whether to recommend a short clip or a movie: users may only have time to watch short content during the week, but can relax and enjoy a full-length movie during the weekend. Similarly, query timestamps may play an important role in modelling popularity dynamics: one movie may be highly popular around the time of its release, but decay quickly afterwards. Conversely, other movies may be evergreens that are happily watched time and time again.
  2. Data sparsity: using non-id features may be critical if data is sparse. With few observations available for a given user or item, the model may struggle with estimating a good per-user or per-item representation. To build an accurate model, other features such as item categories, descriptions, and images have to be used to help the model generalize beyond the training data. This is especially relevant in cold-start situations, where relatively little data is available on some items or users.

In this tutorial, we'll experiment with adding features beyond movie titles and user ids to our MovieLens model.

Preliminaries

We first import the necessary packages.

pip install -q tensorflow-recommenders
pip install -q --upgrade tensorflow-datasets
import os
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow_recommenders as tfrs

We follow the featurization tutorial and keep the user id, timestamp, and movie title features.

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "timestamp": x["timestamp"],
})
movies = movies.map(lambda x: x["movie_title"])

We also do some housekeeping to prepare feature vocabularies.

timestamps = np.concatenate(list(ratings.map(lambda x: x["timestamp"]).batch(100)))

max_timestamp = timestamps.max()
min_timestamp = timestamps.min()

timestamp_buckets = np.linspace(
    min_timestamp, max_timestamp, num=1000,
)

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(
    lambda x: x["user_id"]))))

Model definition

Query model

We start with the user model defined in the featurization tutorial as the first layer of our model, tasked with converting raw input examples into feature embeddings. However, we change it slightly to allow us to turn timestamp features on or off. This will allow us to more easily demonstrate the effect that timestamp features have on the model. In the code below, the use_timestamps parameter gives us control over whether we use timestamp features.

class UserModel(tf.keras.Model):

  def __init__(self, use_timestamps):
    super().__init__()

    self._use_timestamps = use_timestamps

    self.user_embedding = tf.keras.Sequential([
        tf.keras.layers.StringLookup(
            vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),
    ])

    if use_timestamps:
      self.timestamp_embedding = tf.keras.Sequential([
          tf.keras.layers.Discretization(timestamp_buckets.tolist()),
          tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),
      ])
      self.normalized_timestamp = tf.keras.layers.Normalization(
          axis=None
      )

      self.normalized_timestamp.adapt(timestamps)

  def call(self, inputs):
    if not self._use_timestamps:
      return self.user_embedding(inputs["user_id"])

    return tf.concat([
        self.user_embedding(inputs["user_id"]),
        self.timestamp_embedding(inputs["timestamp"]),
        tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),
    ], axis=1)
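
As a quick sanity check (illustrative only, not part of the original tutorial), we can verify the width of the query representation: with timestamps enabled it concatenates a 32-dimensional user embedding, a 32-dimensional timestamp-bucket embedding, and a single normalized timestamp value, for 65 dimensions in total; without timestamps it is just the 32-dimensional user embedding.

for row in ratings.batch(1).take(1):
  print(UserModel(use_timestamps=True)(row).shape)   # (1, 65)
  print(UserModel(use_timestamps=False)(row).shape)  # (1, 32)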

Note that our use of timestamp features in this tutorial interacts with our choice of training-test split in an undesirable way. Because we have split our data randomly rather than chronologically (which would ensure that events in the test set happen later than those in the training set), our model can effectively learn from the future. This is unrealistic: after all, we cannot train a model today on data from tomorrow.

This means that adding time features to the model lets it learn future interaction patterns. We do this for illustration purposes only: the MovieLens dataset itself is very dense, and unlike many real-world datasets does not benefit greatly from features beyond user ids and movie titles.

This caveat aside, real-world models may well benefit from other time-based features such as time of day or day of the week, especially if the data has strong seasonal patterns.
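
For instance, such features could be derived directly from the raw Unix timestamps already present in the dataset. The sketch below is a hedged illustration of one possible approach (it is not used elsewhere in this tutorial, and the simple modulo arithmetic ignores time zones):

SECONDS_PER_DAY = 24 * 60 * 60

def add_time_features(x):
  days_since_epoch = x["timestamp"] // SECONDS_PER_DAY
  return {
      **x,
      "hour_of_day": (x["timestamp"] % SECONDS_PER_DAY) // 3600,
      # The Unix epoch (1970-01-01) fell on a Thursday, hence the offset of 3
      # to make 0 correspond to Monday.
      "day_of_week": (days_since_epoch + 3) % 7,
  }

ratings_with_time = ratings.map(add_time_features)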

Candidate model

For simplicity, we'll keep the candidate model fixed. Again, we copy it from the featurization tutorial:

class MovieModel(tf.keras.Model):

  def __init__(self):
    super().__init__()

    max_tokens = 10_000

    self.title_embedding = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
          vocabulary=unique_movie_titles, mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)
    ])

    self.title_vectorizer = tf.keras.layers.TextVectorization(
        max_tokens=max_tokens)

    self.title_text_embedding = tf.keras.Sequential([
      self.title_vectorizer,
      tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
      tf.keras.layers.GlobalAveragePooling1D(),
    ])

    self.title_vectorizer.adapt(movies)

  def call(self, titles):
    return tf.concat([
        self.title_embedding(titles),
        self.title_text_embedding(titles),
    ], axis=1)
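
As with the query model, we can do a quick, purely illustrative shape check (not part of the original tutorial): the candidate representation concatenates a 32-dimensional title embedding and a 32-dimensional averaged title-text embedding, so each movie is represented by a 64-dimensional vector.

movie_model = MovieModel()

for title in movies.batch(1).take(1):
  print(movie_model(title).shape)  # (1, 64)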

Combined model

With both UserModel and MovieModel defined, we can put together a combined model and implement our loss and metrics logic.

Here we're building a retrieval model. For a refresher on how this works, see the Basic retrieval tutorial.

Note that we also need to make sure that the query model and candidate model output embeddings of compatible size. Because we'll be varying their sizes by adding more features, the easiest way to accomplish this is to use a dense projection layer after each model:

class MovielensModel(tfrs.models.Model):

  def __init__(self, use_timestamps):
    super().__init__()
    self.query_model = tf.keras.Sequential([
      UserModel(use_timestamps),
      tf.keras.layers.Dense(32)
    ])
    self.candidate_model = tf.keras.Sequential([
      MovieModel(),
      tf.keras.layers.Dense(32)
    ])
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )

  def compute_loss(self, features, training=False):
    # We only pass the user id and timestamp features into the query model. This
    # ensures that the training inputs have the same keys as the query inputs;
    # otherwise the discrepancy in input structure would cause an error when
    # loading the query model after saving it.
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(query_embeddings, movie_embeddings)

Experiments

Prepare the data

We first split the data into a training set and a testing set.

tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

cached_train = train.shuffle(100_000).batch(2048)
cached_test = test.batch(4096).cache()

Baseline: no timestamp features

We're ready to try out our first model: let's start without timestamp features to establish our baseline.

model = MovielensModel(use_timestamps=False)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")
Epoch 1/3
40/40 [==============================] - 13s 215ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0097 - factorized_top_k/top_5_categorical_accuracy: 0.0208 - factorized_top_k/top_10_categorical_accuracy: 0.0310 - factorized_top_k/top_50_categorical_accuracy: 0.0937 - factorized_top_k/top_100_categorical_accuracy: 0.1620 - loss: 14575.4544 - regularization_loss: 0.0000e+00 - total_loss: 14575.4544
Epoch 2/3
40/40 [==============================] - 9s 194ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0027 - factorized_top_k/top_5_categorical_accuracy: 0.0145 - factorized_top_k/top_10_categorical_accuracy: 0.0274 - factorized_top_k/top_50_categorical_accuracy: 0.1229 - factorized_top_k/top_100_categorical_accuracy: 0.2267 - loss: 14112.4198 - regularization_loss: 0.0000e+00 - total_loss: 14112.4198
Epoch 3/3
40/40 [==============================] - 9s 180ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0024 - factorized_top_k/top_5_categorical_accuracy: 0.0159 - factorized_top_k/top_10_categorical_accuracy: 0.0319 - factorized_top_k/top_50_categorical_accuracy: 0.1421 - factorized_top_k/top_100_categorical_accuracy: 0.2579 - loss: 13928.9752 - regularization_loss: 0.0000e+00 - total_loss: 13928.9752
40/40 [==============================] - 8s 155ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0035 - factorized_top_k/top_5_categorical_accuracy: 0.0226 - factorized_top_k/top_10_categorical_accuracy: 0.0427 - factorized_top_k/top_50_categorical_accuracy: 0.1742 - factorized_top_k/top_100_categorical_accuracy: 0.2950 - loss: 13707.4737 - regularization_loss: 0.0000e+00 - total_loss: 13707.4737
5/5 [==============================] - 3s 212ms/step - factorized_top_k/top_1_categorical_accuracy: 5.5000e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0071 - factorized_top_k/top_10_categorical_accuracy: 0.0166 - factorized_top_k/top_50_categorical_accuracy: 0.1054 - factorized_top_k/top_100_categorical_accuracy: 0.2069 - loss: 31034.1432 - regularization_loss: 0.0000e+00 - total_loss: 31034.1432
Top-100 accuracy (train): 0.30.
Top-100 accuracy (test): 0.21.

This gives us a baseline top-100 accuracy of around 0.2.

Capturing time dynamics with time features

Do the results change if we add time features?

model = MovielensModel(use_timestamps=True)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")
Epoch 1/3
40/40 [==============================] - 12s 227ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0142 - factorized_top_k/top_5_categorical_accuracy: 0.0285 - factorized_top_k/top_10_categorical_accuracy: 0.0393 - factorized_top_k/top_50_categorical_accuracy: 0.1045 - factorized_top_k/top_100_categorical_accuracy: 0.1739 - loss: 14540.3001 - regularization_loss: 0.0000e+00 - total_loss: 14540.3001
Epoch 2/3
40/40 [==============================] - 8s 172ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0037 - factorized_top_k/top_5_categorical_accuracy: 0.0178 - factorized_top_k/top_10_categorical_accuracy: 0.0340 - factorized_top_k/top_50_categorical_accuracy: 0.1443 - factorized_top_k/top_100_categorical_accuracy: 0.2600 - loss: 13946.6094 - regularization_loss: 0.0000e+00 - total_loss: 13946.6094
Epoch 3/3
40/40 [==============================] - 9s 172ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0200 - factorized_top_k/top_10_categorical_accuracy: 0.0412 - factorized_top_k/top_50_categorical_accuracy: 0.1784 - factorized_top_k/top_100_categorical_accuracy: 0.3091 - loss: 13681.3415 - regularization_loss: 0.0000e+00 - total_loss: 13681.3415
40/40 [==============================] - 8s 157ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0057 - factorized_top_k/top_5_categorical_accuracy: 0.0344 - factorized_top_k/top_10_categorical_accuracy: 0.0638 - factorized_top_k/top_50_categorical_accuracy: 0.2331 - factorized_top_k/top_100_categorical_accuracy: 0.3749 - loss: 13357.4951 - regularization_loss: 0.0000e+00 - total_loss: 13357.4951
5/5 [==============================] - 1s 225ms/step - factorized_top_k/top_1_categorical_accuracy: 9.0000e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0093 - factorized_top_k/top_10_categorical_accuracy: 0.0228 - factorized_top_k/top_50_categorical_accuracy: 0.1311 - factorized_top_k/top_100_categorical_accuracy: 0.2531 - loss: 30674.3815 - regularization_loss: 0.0000e+00 - total_loss: 30674.3815
Top-100 accuracy (train): 0.37.
Top-100 accuracy (test): 0.25.

This is quite a bit better: not only is the training accuracy much higher, but the test accuracy is also substantially improved.
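
As a possible follow-up (not covered in this tutorial; see the Basic retrieval tutorial for details), the trained query and candidate models can be wrapped in a brute-force retrieval index to produce actual recommendations. The user id and timestamp below are illustrative values only:

index = tfrs.layers.factorized_top_k.BruteForce(model.query_model)
index.index_from_dataset(
    tf.data.Dataset.zip(
        (movies.batch(100), movies.batch(100).map(model.candidate_model))
    )
)

# The query must have the same structure the query model was trained on.
_, titles = index({
    "user_id": np.array(["42"]),
    "timestamp": np.array([879024327]),
})
print(f"Top recommendations: {titles[0, :3]}")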

Next steps

This tutorial shows that even simple models can become more accurate when incorporating more features. However, to get the most out of your features it's often necessary to build larger, deeper models. Have a look at the deep retrieval tutorial to explore this in more detail.