컨텍스트 기능 활용

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

기능화 튜토리얼에서 우리는 사용자 및 영화 식별자를 넘어 여러 기능을 모델에 통합했지만 이러한 기능이 모델 정확도를 향상시키는지 여부는 조사하지 않았습니다.

id 이상의 기능이 추천 모델에서 유용한지 여부에 영향을 미치는 많은 요인이 있습니다.

  1. 컨텍스트의 중요성 : 사용자 선호도가 컨텍스트와 시간에 걸쳐 상대적으로 안정적이라면 컨텍스트 기능은 많은 이점을 제공하지 않을 수 있습니다. 그러나 사용자 기본 설정이 매우 상황에 맞는 경우 컨텍스트를 추가하면 모델이 크게 향상됩니다. 예를 들어, 요일은 짧은 클립 또는 영화를 추천할지 여부를 결정할 때 중요한 기능일 수 있습니다. 사용자는 주중에 짧은 콘텐츠를 볼 시간이 있지만 주말에는 긴장을 풀고 장편 영화를 즐길 수 있습니다. . 유사하게, 쿼리 타임스탬프는 인기 역학을 모델링하는 데 중요한 역할을 할 수 있습니다. 한 영화가 개봉 당시에는 매우 인기가 있었지만 이후에는 빠르게 소멸됩니다. 반대로, 다른 영화는 몇 번이고 즐겁게 볼 수 있는 늘푸른 영화일 수 있습니다.
  2. 데이터 희소성 : 데이터가 희박한 경우 ID가 아닌 기능을 사용하는 것이 중요할 수 있습니다. 주어진 사용자 또는 항목에 대해 사용할 수 있는 관찰이 거의 없는 경우 모델은 사용자별 또는 항목별 표현을 추정하는 데 어려움을 겪을 수 있습니다. 정확한 모델을 구축하려면 항목 범주, 설명 및 이미지와 같은 다른 기능을 사용하여 모델이 교육 데이터 이상으로 일반화되도록 해야 합니다. 이는 일부 항목 또는 사용자에 대해 상대적으로 적은 데이터를 사용할 수 있는 콜드 스타트 상황에서 특히 관련이 있습니다.

이 튜토리얼에서는 MovieLens 모델에 영화 제목과 사용자 ID 이외의 기능을 사용하여 실험할 것입니다.

예선

먼저 필요한 패키지를 가져옵니다.

pip install -q tensorflow-recommenders
pip install -q --upgrade tensorflow-datasets
import os
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow_recommenders as tfrs

우리 는 기능화 튜토리얼 을 따르고 사용자 ID, 타임스탬프 및 영화 제목 기능을 유지합니다.

ratings = tfds.load("movielens/100k-ratings", split="train")
movies = tfds.load("movielens/100k-movies", split="train")

ratings = ratings.map(lambda x: {
    "movie_title": x["movie_title"],
    "user_id": x["user_id"],
    "timestamp": x["timestamp"],
})
movies = movies.map(lambda x: x["movie_title"])

우리는 또한 기능 어휘를 준비하기 위해 약간의 하우스키핑을 합니다.

timestamps = np.concatenate(list(ratings.map(lambda x: x["timestamp"]).batch(100)))

max_timestamp = timestamps.max()
min_timestamp = timestamps.min()

timestamp_buckets = np.linspace(
    min_timestamp, max_timestamp, num=1000,
)

unique_movie_titles = np.unique(np.concatenate(list(movies.batch(1000))))
unique_user_ids = np.unique(np.concatenate(list(ratings.batch(1_000).map(
    lambda x: x["user_id"]))))

모델 정의

쿼리 모델

원시 입력 예제를 기능 임베딩으로 변환하는 작업을 수행하는 모델의 첫 번째 계층으로 기능화 자습서 에 정의된 사용자 모델부터 시작합니다. 그러나 타임스탬프 기능을 켜거나 끌 수 있도록 약간 변경합니다. 이를 통해 타임스탬프 기능이 모델에 미치는 영향을 보다 쉽게 ​​시연할 수 있습니다. 아래 코드에서 use_timestamps 매개변수를 사용하면 타임스탬프 기능을 사용할지 여부를 제어할 수 있습니다.

class UserModel(tf.keras.Model):

  def __init__(self, use_timestamps):
    super().__init__()

    self._use_timestamps = use_timestamps

    self.user_embedding = tf.keras.Sequential([
        tf.keras.layers.StringLookup(
            vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),
    ])

    if use_timestamps:
      self.timestamp_embedding = tf.keras.Sequential([
          tf.keras.layers.Discretization(timestamp_buckets.tolist()),
          tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),
      ])
      self.normalized_timestamp = tf.keras.layers.Normalization(
          axis=None
      )

      self.normalized_timestamp.adapt(timestamps)

  def call(self, inputs):
    if not self._use_timestamps:
      return self.user_embedding(inputs["user_id"])

    return tf.concat([
        self.user_embedding(inputs["user_id"]),
        self.timestamp_embedding(inputs["timestamp"]),
        tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),
    ], axis=1)

이 자습서에서 타임스탬프 기능을 사용하면 훈련-테스트 분할 선택과 바람직하지 않은 방식으로 상호 작용합니다. 테스트 데이터 세트에 속하는 이벤트가 훈련 세트의 이벤트보다 나중에 발생하도록 하기 위해 시간순이 아닌 무작위로 데이터를 분할했기 때문에 우리 모델은 미래로부터 효과적으로 학습할 수 있습니다. 이것은 비현실적입니다. 결국 내일의 데이터로 오늘 모델을 훈련할 수 없습니다.

이는 모델에 시간 기능을 추가하면 미래의 상호 작용 패턴을 학습할 수 있음을 의미합니다. 설명 목적으로만 이 작업을 수행합니다. MovieLens 데이터 세트 자체는 매우 밀도가 높으며 많은 실제 데이터 세트와 달리 사용자 ID 및 영화 제목 이외의 기능에서 큰 이점을 얻지 못합니다.

이 경고는 제쳐두고, 실제 모델은 특히 데이터에 강한 계절적 패턴이 있는 경우 시간 또는 요일과 같은 다른 시간 기반 기능의 이점을 얻을 수 있습니다.

후보 모델

단순화를 위해 후보 모델을 고정된 상태로 유지합니다. 다시 한 번 기능화 자습서에서 복사합니다.

class MovieModel(tf.keras.Model):

  def __init__(self):
    super().__init__()

    max_tokens = 10_000

    self.title_embedding = tf.keras.Sequential([
      tf.keras.layers.StringLookup(
          vocabulary=unique_movie_titles, mask_token=None),
      tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32)
    ])

    self.title_vectorizer = tf.keras.layers.TextVectorization(
        max_tokens=max_tokens)

    self.title_text_embedding = tf.keras.Sequential([
      self.title_vectorizer,
      tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
      tf.keras.layers.GlobalAveragePooling1D(),
    ])

    self.title_vectorizer.adapt(movies)

  def call(self, titles):
    return tf.concat([
        self.title_embedding(titles),
        self.title_text_embedding(titles),
    ], axis=1)

결합 모델

UserModelMovieModel 을 모두 정의하면 결합된 모델을 조합하고 손실 및 메트릭 논리를 구현할 수 있습니다.

여기에서 검색 모델을 구축하고 있습니다. 작동 방식에 대한 복습은 기본 검색 자습서를 참조하십시오.

또한 쿼리 모델과 후보 모델이 호환되는 크기의 임베딩을 출력하는지 확인해야 합니다. 더 많은 기능을 추가하여 크기를 다양화할 것이기 때문에 이를 수행하는 가장 쉬운 방법은 각 모델 다음에 조밀한 투영 레이어를 사용하는 것입니다.

class MovielensModel(tfrs.models.Model):

  def __init__(self, use_timestamps):
    super().__init__()
    self.query_model = tf.keras.Sequential([
      UserModel(use_timestamps),
      tf.keras.layers.Dense(32)
    ])
    self.candidate_model = tf.keras.Sequential([
      MovieModel(),
      tf.keras.layers.Dense(32)
    ])
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model),
        ),
    )

  def compute_loss(self, features, training=False):
    # We only pass the user id and timestamp features into the query model. This
    # is to ensure that the training inputs would have the same keys as the
    # query inputs. Otherwise the discrepancy in input structure would cause an
    # error when loading the query model after saving it.
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(query_embeddings, movie_embeddings)

실험

데이터 준비

먼저 데이터를 훈련 세트와 테스트 세트로 나눕니다.

tf.random.set_seed(42)
shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)

cached_train = train.shuffle(100_000).batch(2048)
cached_test = test.batch(4096).cache()

기준: 타임스탬프 기능 없음

첫 번째 모델을 시험해 볼 준비가 되었습니다. 기준선을 설정하기 위해 타임스탬프 기능을 사용하지 않는 것부터 시작하겠습니다.

model = MovielensModel(use_timestamps=False)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")
Epoch 1/3
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 10s 169ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0092 - factorized_top_k/top_5_categorical_accuracy: 0.0172 - factorized_top_k/top_10_categorical_accuracy: 0.0256 - factorized_top_k/top_50_categorical_accuracy: 0.0824 - factorized_top_k/top_100_categorical_accuracy: 0.1473 - loss: 14579.4628 - regularization_loss: 0.0000e+00 - total_loss: 14579.4628
Epoch 2/3
40/40 [==============================] - 9s 173ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0020 - factorized_top_k/top_5_categorical_accuracy: 0.0126 - factorized_top_k/top_10_categorical_accuracy: 0.0251 - factorized_top_k/top_50_categorical_accuracy: 0.1129 - factorized_top_k/top_100_categorical_accuracy: 0.2133 - loss: 14136.2137 - regularization_loss: 0.0000e+00 - total_loss: 14136.2137
Epoch 3/3
40/40 [==============================] - 9s 174ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0021 - factorized_top_k/top_5_categorical_accuracy: 0.0155 - factorized_top_k/top_10_categorical_accuracy: 0.0307 - factorized_top_k/top_50_categorical_accuracy: 0.1389 - factorized_top_k/top_100_categorical_accuracy: 0.2535 - loss: 13939.9265 - regularization_loss: 0.0000e+00 - total_loss: 13939.9265
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 10s 189ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0036 - factorized_top_k/top_5_categorical_accuracy: 0.0226 - factorized_top_k/top_10_categorical_accuracy: 0.0427 - factorized_top_k/top_50_categorical_accuracy: 0.1729 - factorized_top_k/top_100_categorical_accuracy: 0.2944 - loss: 13711.3802 - regularization_loss: 0.0000e+00 - total_loss: 13711.3802
5/5 [==============================] - 3s 267ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0010 - factorized_top_k/top_5_categorical_accuracy: 0.0078 - factorized_top_k/top_10_categorical_accuracy: 0.0184 - factorized_top_k/top_50_categorical_accuracy: 0.1051 - factorized_top_k/top_100_categorical_accuracy: 0.2126 - loss: 30995.8988 - regularization_loss: 0.0000e+00 - total_loss: 30995.8988
Top-100 accuracy (train): 0.29.
Top-100 accuracy (test): 0.21.

이것은 약 0.2의 기준선 상위 100개 정확도를 제공합니다.

시간 기능으로 시간 역학 캡처

시간 기능을 추가하면 결과가 변경됩니까?

model = MovielensModel(use_timestamps=True)
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))

model.fit(cached_train, epochs=3)

train_accuracy = model.evaluate(
    cached_train, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]
test_accuracy = model.evaluate(
    cached_test, return_dict=True)["factorized_top_k/top_100_categorical_accuracy"]

print(f"Top-100 accuracy (train): {train_accuracy:.2f}.")
print(f"Top-100 accuracy (test): {test_accuracy:.2f}.")
Epoch 1/3
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 10s 175ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0057 - factorized_top_k/top_5_categorical_accuracy: 0.0148 - factorized_top_k/top_10_categorical_accuracy: 0.0238 - factorized_top_k/top_50_categorical_accuracy: 0.0812 - factorized_top_k/top_100_categorical_accuracy: 0.1487 - loss: 14606.0927 - regularization_loss: 0.0000e+00 - total_loss: 14606.0927
Epoch 2/3
40/40 [==============================] - 9s 176ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0153 - factorized_top_k/top_10_categorical_accuracy: 0.0304 - factorized_top_k/top_50_categorical_accuracy: 0.1375 - factorized_top_k/top_100_categorical_accuracy: 0.2512 - loss: 13958.5635 - regularization_loss: 0.0000e+00 - total_loss: 13958.5635
Epoch 3/3
40/40 [==============================] - 9s 177ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0026 - factorized_top_k/top_5_categorical_accuracy: 0.0189 - factorized_top_k/top_10_categorical_accuracy: 0.0393 - factorized_top_k/top_50_categorical_accuracy: 0.1713 - factorized_top_k/top_100_categorical_accuracy: 0.3015 - loss: 13696.8511 - regularization_loss: 0.0000e+00 - total_loss: 13696.8511
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor. Received: inputs={'user_id': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=string>, 'timestamp': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=int64>}. Consider rewriting this model with the Functional API.
40/40 [==============================] - 9s 172ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0050 - factorized_top_k/top_5_categorical_accuracy: 0.0323 - factorized_top_k/top_10_categorical_accuracy: 0.0606 - factorized_top_k/top_50_categorical_accuracy: 0.2254 - factorized_top_k/top_100_categorical_accuracy: 0.3637 - loss: 13382.7869 - regularization_loss: 0.0000e+00 - total_loss: 13382.7869
5/5 [==============================] - 1s 237ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0012 - factorized_top_k/top_5_categorical_accuracy: 0.0097 - factorized_top_k/top_10_categorical_accuracy: 0.0214 - factorized_top_k/top_50_categorical_accuracy: 0.1259 - factorized_top_k/top_100_categorical_accuracy: 0.2468 - loss: 30699.8529 - regularization_loss: 0.0000e+00 - total_loss: 30699.8529
Top-100 accuracy (train): 0.36.
Top-100 accuracy (test): 0.25.

이것은 훨씬 더 좋습니다. 훈련 정확도가 훨씬 높을 뿐만 아니라 테스트 정확도도 상당히 향상됩니다.

다음 단계

이 튜토리얼은 단순한 모델이라도 더 많은 기능을 통합할 때 더 정확해질 수 있음을 보여줍니다. 그러나 기능을 최대한 활용하려면 더 크고 심층적인 모델을 구축해야 하는 경우가 많습니다. 더 자세히 알아보려면 심층 검색 튜토리얼 을 살펴보세요.