コンテキスト機能を利用する

TensorFlow.orgで表示GoogleColabで実行GitHubでソースを表示 ノートブックをダウンロード

機能化チュートリアルでは、ユーザーと映画の識別子だけでなく複数の機能をモデルに組み込みましたが、これらの機能がモデルの精度を向上させるかどうかについては調査していません。

ID以外の機能がレコメンダーモデルで役立つかどうかには、多くの要因が影響します。

  1. コンテキストの重要性:ユーザーの好みがコンテキストと時間にわたって比較的安定している場合、コンテキスト機能はあまりメリットがない可能性があります。ただし、ユーザー設定が高度にコンテキストに依存している場合は、コンテキストを追加するとモデルが大幅に改善されます。たとえば、曜日は、短いクリップと映画のどちらを推奨するかを決定する際の重要な機能です。ユーザーは、その週は短いコンテンツを見る時間しかありませんが、週末はリラックスして長編映画を楽しむことができます。 。同様に、クエリのタイムスタンプは、人気のダイナミクスをモデル化する上で重要な役割を果たす可能性があります。1つの映画は、リリースの前後で非常に人気がありますが、その後すぐに衰退します。逆に、他の映画は何度も何度も楽しく見られている常緑樹かもしれません。
  2. データのスパース性:データがスパースである場合、id以外の機能を使用することが重要になる場合があります。特定のユーザーまたはアイテムで利用できる観測値がほとんどないため、モデルは、ユーザーごとまたはアイテムごとの適切な表現を推定するのに苦労する可能性があります。正確なモデルを構築するには、アイテムカテゴリ、説明、画像などの他の機能を使用して、モデルがトレーニングデータを超えて一般化できるようにする必要があります。これは、一部のアイテムまたはユーザーで利用できるデータが比較的少ないコールドスタートの状況で特に関係があります。

このチュートリアルでは、映画のタイトルやユーザーID以外の機能をMovieLensモデルに使用して実験します。

予選

まず、必要なパッケージをインポートします。

pip install -q tensorflow-recommenders
pip install -q --upgrade tensorflow-datasets
import os
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow_recommenders as tfrs

機能化のチュートリアルに従い、ユーザーID、タイムスタンプ、および映画のタイトル機能を保持します。

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "timestamp": x["timestamp"],
})
movies = movies.map(lambda x: x["movie_title"])

また、機能の語彙を準備するためにいくつかのハウスキーピングを行います。

timestamps = np.concatenate(list(ratings.map(lambda x: x["timestamp"]).batch(100)))

max_timestamp = timestamps.max()
min_timestamp = timestamps.min()

timestamp_buckets = np.linspace(
    min_timestamp, max_timestamp, num=1000,
)

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(
    lambda x: x["user_id"]))))

モデル定義

クエリモデル

まず、機能化チュートリアルでモデルの最初のレイヤーとして定義されたユーザーモデルから始め、生の入力例を機能の埋め込みに変換します。ただし、タイムスタンプ機能をオンまたはオフにできるように少し変更しています。これにより、タイムスタンプ機能がモデルに与える影響をより簡単に示すことができます。以下のコードでは、 use_timestampsパラメーターを使用して、タイムスタンプ機能を使用するかどうかを制御できます。

class UserModel(tf.keras.Model):

  def __init__(self, use_timestamps):
    super().__init__()

    self._use_timestamps = use_timestamps

    self.user_embedding = tf.keras.Sequential([
        tf.keras.layers.StringLookup(
            vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),
    ])

    if use_timestamps:
      self.timestamp_embedding = tf.keras.Sequential([
          tf.keras.layers.Discretization(timestamp_buckets.tolist()),
          tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),
      ])
      self.normalized_timestamp = tf.keras.layers.Normalization(
          axis=None
      )

      self.normalized_timestamp.adapt(timestamps)

  def call(self, inputs):
    if not self._use_timestamps:
      return self.user_embedding(inputs["user_id"])

    return tf.concat([
        self.user_embedding(inputs["user_id"]),
        self.timestamp_embedding(inputs["timestamp"]),
        tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),
    ], axis=1)

このチュートリアルでのタイムスタンプ機能の使用は、望ましくない方法でトレーニングテスト分割の選択と相互作用することに注意してください。データを時系列ではなくランダムに分割しているため(テストデータセットに属するイベントがトレーニングセットのイベントよりも遅く発生するようにするため)、モデルは将来から効果的に学習できます。これは非現実的です。結局のところ、明日からのデータで今日モデルをトレーニングすることはできません。

これは、モデルに時間機能を追加すると、将来の相互作用パターンを学習できることを意味します。これは説明の目的でのみ行います。MovieLensデータセット自体は非常に高密度であり、多くの実際のデータセットとは異なり、ユーザーIDや映画のタイトル以外の機能の恩恵をあまり受けません。

この警告はさておき、実際のモデルは、特にデータに強い季節パターンがある場合、時刻や曜日などの他の時間ベースの機能の恩恵を受ける可能性があります。

候補モデル

簡単にするために、候補モデルは固定しておきます。繰り返しになりますが、機能化チュートリアルからコピーします。

class MovieModel(tf.keras.Model):

  def __init__(self):
    super().__init__()

    max_tokens = 10_000

    self.title_embedding = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
          vocabulary=unique_movie_titles, mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)
    ])

    self.title_vectorizer = tf.keras.layers.TextVectorization(
        max_tokens=max_tokens)

    self.title_text_embedding = tf.keras.Sequential([
      self.title_vectorizer,
      tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
      tf.keras.layers.GlobalAveragePooling1D(),
    ])

    self.title_vectorizer.adapt(movies)

  def call(self, titles):
    return tf.concat([
        self.title_embedding(titles),
        self.title_text_embedding(titles),
    ], axis=1)

複合モデル

UserModelMovieModelの両方を定義すると、結合されたモデルをまとめて、損失とメトリックのロジックを実装できます。

ここでは、検索モデルを構築しています。これがどのように機能するかについての復習については、基本的な検索チュートリアルを参照してください。

また、クエリモデルと候補モデルが互換性のあるサイズの埋め込みを出力することを確認する必要があることに注意してください。機能を追加してサイズを変更するため、これを実現する最も簡単な方法は、各モデルの後に高密度の投影レイヤーを使用することです。

class MovielensModel(tfrs.models.Model):

  def __init__(self, use_timestamps):
    super().__init__()
    self.query_model = tf.keras.Sequential([
      UserModel(use_timestamps),
      tf.keras.layers.Dense(32)
    ])
    self.candidate_model = tf.keras.Sequential([
      MovieModel(),
      tf.keras.layers.Dense(32)
    ])
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )

  def compute_loss(self, features, training=False):
    # We only pass the user id and timestamp features into the query model. This
    # is to ensure that the training inputs would have the same keys as the
    # query inputs. Otherwise the discrepancy in input structure would cause an
    # error when loading the query model after saving it.
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(query_embeddings, movie_embeddings)

実験

データを準備する

まず、データをトレーニングセットとテストセットに分割します。

tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

cached_train = train.shuffle(100_000).batch(2048)
cached_test = test.batch(4096).cache()

ベースライン:タイムスタンプ機能なし

最初のモデルを試す準備ができました。ベースラインを確立するためにタイムスタンプ機能を使用しないことから始めましょう。

model = MovielensModel(use_timestamps=False)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")
Epoch 1/3
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 10s 169ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0092 - factorized_top_k/top_5_categorical_accuracy: 0.0172 - factorized_top_k/top_10_categorical_accuracy: 0.0256 - factorized_top_k/top_50_categorical_accuracy: 0.0824 - factorized_top_k/top_100_categorical_accuracy: 0.1473 - loss: 14579.4628 - regularization_loss: 0.0000e+00 - total_loss: 14579.4628
Epoch 2/3
40/40 [==============================] - 9s 173ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0020 - factorized_top_k/top_5_categorical_accuracy: 0.0126 - factorized_top_k/top_10_categorical_accuracy: 0.0251 - factorized_top_k/top_50_categorical_accuracy: 0.1129 - factorized_top_k/top_100_categorical_accuracy: 0.2133 - loss: 14136.2137 - regularization_loss: 0.0000e+00 - total_loss: 14136.2137
Epoch 3/3
40/40 [==============================] - 9s 174ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0021 - factorized_top_k/top_5_categorical_accuracy: 0.0155 - factorized_top_k/top_10_categorical_accuracy: 0.0307 - factorized_top_k/top_50_categorical_accuracy: 0.1389 - factorized_top_k/top_100_categorical_accuracy: 0.2535 - loss: 13939.9265 - regularization_loss: 0.0000e+00 - total_loss: 13939.9265
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 10s 189ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0036 - factorized_top_k/top_5_categorical_accuracy: 0.0226 - factorized_top_k/top_10_categorical_accuracy: 0.0427 - factorized_top_k/top_50_categorical_accuracy: 0.1729 - factorized_top_k/top_100_categorical_accuracy: 0.2944 - loss: 13711.3802 - regularization_loss: 0.0000e+00 - total_loss: 13711.3802
5/5 [==============================] - 3s 267ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0010 - factorized_top_k/top_5_categorical_accuracy: 0.0078 - factorized_top_k/top_10_categorical_accuracy: 0.0184 - factorized_top_k/top_50_categorical_accuracy: 0.1051 - factorized_top_k/top_100_categorical_accuracy: 0.2126 - loss: 30995.8988 - regularization_loss: 0.0000e+00 - total_loss: 30995.8988
Top-100 accuracy (train): 0.29.
Top-100 accuracy (test): 0.21.

これにより、ベースラインのトップ100の精度は約0.2になります。

時間機能を使用して時間ダイナミクスをキャプチャする

時間機能を追加すると、結果は変わりますか?

model = MovielensModel(use_timestamps=True)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")
Epoch 1/3
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 10s 175ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0057 - factorized_top_k/top_5_categorical_accuracy: 0.0148 - factorized_top_k/top_10_categorical_accuracy: 0.0238 - factorized_top_k/top_50_categorical_accuracy: 0.0812 - factorized_top_k/top_100_categorical_accuracy: 0.1487 - loss: 14606.0927 - regularization_loss: 0.0000e+00 - total_loss: 14606.0927
Epoch 2/3
40/40 [==============================] - 9s 176ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0153 - factorized_top_k/top_10_categorical_accuracy: 0.0304 - factorized_top_k/top_50_categorical_accuracy: 0.1375 - factorized_top_k/top_100_categorical_accuracy: 0.2512 - loss: 13958.5635 - regularization_loss: 0.0000e+00 - total_loss: 13958.5635
Epoch 3/3
40/40 [==============================] - 9s 177ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0189 - factorized_top_k/top_10_categorical_accuracy: 0.0393 - factorized_top_k/top_50_categorical_accuracy: 0.1713 - factorized_top_k/top_100_categorical_accuracy: 0.3015 - loss: 13696.8511 - regularization_loss: 0.0000e+00 - total_loss: 13696.8511
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 9s 172ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0050 - factorized_top_k/top_5_categorical_accuracy: 0.0323 - factorized_top_k/top_10_categorical_accuracy: 0.0606 - factorized_top_k/top_50_categorical_accuracy: 0.2254 - factorized_top_k/top_100_categorical_accuracy: 0.3637 - loss: 13382.7869 - regularization_loss: 0.0000e+00 - total_loss: 13382.7869
5/5 [==============================] - 1s 237ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0012 - factorized_top_k/top_5_categorical_accuracy: 0.0097 - factorized_top_k/top_10_categorical_accuracy: 0.0214 - factorized_top_k/top_50_categorical_accuracy: 0.1259 - factorized_top_k/top_100_categorical_accuracy: 0.2468 - loss: 30699.8529 - regularization_loss: 0.0000e+00 - total_loss: 30699.8529
Top-100 accuracy (train): 0.36.
Top-100 accuracy (test): 0.25.

これはかなり優れています。トレーニングの精度がはるかに高いだけでなく、テストの精度も大幅に向上しています。

次のステップ

このチュートリアルは、より多くの機能を組み込むと、単純なモデルでもより正確になる可能性があることを示しています。ただし、機能を最大限に活用するには、多くの場合、より大きく、より深いモデルを構築する必要があります。これをより詳細に調べるには、詳細検索チュートリアルを参照してください。