Recommending Movies: Recommender Models in TFX

TFRS Tutorial Ported to TFX

This is a port of the basic TensorFlow Recommenders (TFRS) retrieval tutorial to TFX. It mirrors that tutorial and is designed to demonstrate how to use TFRS in a TFX pipeline.

For context, real-world recommender systems are often composed of two stages:

  1. The retrieval stage is responsible for selecting an initial set of hundreds of candidates from all possible candidates. The main objective of this model is to efficiently weed out all candidates that the user is not interested in. Because the retrieval model may be dealing with millions of candidates, it has to be computationally efficient.
  2. The ranking stage takes the outputs of the retrieval model and refines them to select the best possible handful of recommendations. Its task is to narrow down the set of items the user may be interested in to a shortlist of likely candidates.

In this tutorial, we're going to focus on the first stage, retrieval. Retrieval models are often composed of two sub-models:

  1. A query model computing the query representation (normally a fixed-dimensionality embedding vector) using query features.
  2. A candidate model computing the candidate representation (an equally-sized vector) using the candidate features.

The outputs of the two models are then combined with a dot product to give a query-candidate affinity score, with higher scores expressing a better match between the candidate and the query.
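To make the scoring step concrete, here is a minimal pure-Python sketch in which the affinity score is the dot product of a query embedding and a candidate embedding. The embeddings and movie names are made up for illustration; in the real model, both towers produce learned vectors of the same size.

```python
def affinity(query_embedding, candidate_embedding):
    """Return the dot product of two equally sized embedding vectors."""
    if len(query_embedding) != len(candidate_embedding):
        raise ValueError("query and candidate embeddings must have the same size")
    return sum(q * c for q, c in zip(query_embedding, candidate_embedding))


def rank_candidates(query_embedding, candidates):
    """Return candidate names ordered from best to worst match."""
    scores = {name: affinity(query_embedding, emb) for name, emb in candidates.items()}
    return sorted(scores, key=scores.get, reverse=True)


# Made-up 3-dimensional embeddings; a trained model would produce these.
user = [0.5, 1.0, -0.2]
movies = {
    "movie_a": [1.0, 0.0, 0.0],  # affinity 0.5
    "movie_b": [0.0, 1.0, 0.0],  # affinity 1.0
    "movie_c": [0.0, 0.0, 1.0],  # affinity -0.2
}
print(rank_candidates(user, movies))  # ['movie_b', 'movie_a', 'movie_c']
```

Retrieval ranks candidates by this kind of score, only over a much larger catalogue.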

In this tutorial, we're going to build and train such a two-tower model using the MovieLens dataset.

We're going to:

  1. Ingest and inspect the MovieLens dataset.
  2. Implement a retrieval model.
  3. Train and export the model.
  4. Make predictions.

The dataset

The MovieLens dataset is a classic dataset from the GroupLens research group at the University of Minnesota. It contains a set of ratings given to movies by a set of users, and is a workhorse of recommender system research.

The data can be treated in two ways:

  1. It can be interpreted as expressing which movies the users watched (and rated), and which they did not. This is a form of implicit feedback, where users' watches tell us which things they prefer to see and which they'd rather not see.
  2. It can also be seen as expressing how much the users liked the movies they did watch. This is a form of explicit feedback: given that a user watched a movie, we can tell roughly how much they liked it by looking at the rating they gave.

In this tutorial, we are focusing on a retrieval system: a model that predicts a set of movies from the catalogue that the user is likely to watch. Implicit data is often more useful here, so we are going to treat MovieLens as an implicit system. This means that every movie a user watched is a positive example, and every movie they have not seen is an implicit negative example.
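As a rough illustration of the implicit-feedback view, the sketch below discards the rating value and keeps only (user, movie) pairs, then derives a user's implicit negatives as the catalogue movies they have not watched. The rows and catalogue are made-up stand-ins, not actual MovieLens records; two of the titles are borrowed from the examples inspected later in this tutorial. In practice, retrieval training does not enumerate every unwatched movie; the sketch only shows what counts as a positive and a negative.

```python
# Made-up (user_id, movie_title, user_rating) rows in the shape of MovieLens ratings.
ratings = [
    ("138", "One Flew Over the Cuckoo's Nest (1975)", 4.0),
    ("138", "Toy Story (1995)", 2.0),
    ("92", "Toy Story (1995)", 5.0),
]
catalog = {
    "One Flew Over the Cuckoo's Nest (1975)",
    "Toy Story (1995)",
    "You So Crazy (1994)",
}


def to_implicit(rows):
    """Keep only which user watched which movie; the rating value is dropped."""
    return {(user_id, title) for user_id, title, _rating in rows}


def implicit_negatives(user_id, positives, catalog):
    """Return the catalogue movies this user has not watched, sorted by title."""
    watched = {title for uid, title in positives if uid == user_id}
    return sorted(catalog - watched)


positives = to_implicit(ratings)
print(implicit_negatives("92", positives, catalog))
# ["One Flew Over the Cuckoo's Nest (1975)", 'You So Crazy (1994)']
```

Note that user 138's low rating of 2.0 still produces a positive pair: under the implicit interpretation, watching is what matters.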

Imports

Let's first get our imports out of the way.

pip install -Uq tfx
pip install -Uq tensorflow-recommenders
pip install -Uq tensorflow-datasets

Uninstall shapely

TODO(b/263441833): This is a temporary workaround to avoid an ImportError. Ultimately, it should be handled by supporting a recent version of BigQuery, instead of uninstalling other extra dependencies.

pip uninstall shapely -y

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.

import os
import absl
import json
import pprint
import tempfile

from typing import Any, Dict, List, Text

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_recommenders as tfrs
import apache_beam as beam

from absl import logging

from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor
from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.components.example_gen import utils
from tfx.dsl.components.base import executor_spec

from tfx.types import artifact
from tfx.types import artifact_utils
from tfx.types import channel
from tfx.types import standard_artifacts
from tfx.types.standard_artifacts import Examples

from tfx.dsl.component.experimental.annotations import InputArtifact
from tfx.dsl.component.experimental.annotations import OutputArtifact
from tfx.dsl.component.experimental.annotations import Parameter
from tfx.dsl.component.experimental.decorators import component
from tfx.types.experimental.simple_artifacts import Dataset

from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

# Set up logging.
tf.get_logger().propagate = False
absl.logging.set_verbosity(absl.logging.INFO)
pp = pprint.PrettyPrinter()

print(f"TensorFlow version: {tf.__version__}")
print(f"TFX version: {tfx.__version__}")
print(f"TensorFlow Recommenders version: {tfrs.__version__}")

%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip
Using TensorFlow backend
TensorFlow version: 2.13.1
TFX version: 1.14.0
TensorFlow Recommenders version: v0.7.3

Create a TFDS ExampleGen

We create a custom ExampleGen component to load a TensorFlow Datasets (TFDS) dataset. It works by plugging a custom executor, TFDSExecutor, into a FileBasedExampleGen.

@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(tf.train.Example)
def _TFDatasetToExample(  # pylint: disable=invalid-name
    pipeline: beam.Pipeline,
    exec_properties: Dict[str, Any],
    split_pattern: str
    ) -> beam.pvalue.PCollection:
    """Read a TensorFlow Dataset and create tf.Examples"""
    custom_config = json.loads(exec_properties['custom_config'])
    dataset_name = custom_config['dataset']
    split_name = custom_config['split']

    builder = tfds.builder(dataset_name)
    builder.download_and_prepare()

    return (pipeline
            | 'MakeExamples' >> tfds.beam.ReadFromTFDS(builder, split=split_name)
            | 'AsNumpy' >> beam.Map(tfds.as_numpy)
            | 'ToDict' >> beam.Map(dict)
            | 'ToTFExample' >> beam.Map(utils.dict_to_example)
            )

class TFDSExecutor(BaseExampleGenExecutor):
  def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
    """Returns PTransform for TF Dataset to TF examples."""
    return _TFDatasetToExample
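The last step of the Beam pipeline, utils.dict_to_example, turns each record dictionary into a tf.train.Example, choosing a bytes_list, int64_list, or float_list for each value. The pure-Python sketch below, written by us, only illustrates that type dispatch; feature_kind is a hypothetical helper, not the TFX implementation, which also handles NumPy types. The record values are taken from the ratings example inspected later in this tutorial.

```python
def feature_kind(value):
    """Return which tf.train.Feature list a Python value would be stored in (sketch)."""
    values = value if isinstance(value, (list, tuple)) else [value]
    if not values:
        return None  # empty feature: no list type can be inferred
    first = values[0]
    if isinstance(first, (bytes, str)):
        return "bytes_list"
    if isinstance(first, int):
        return "int64_list"
    if isinstance(first, float):
        return "float_list"
    raise TypeError(f"unsupported feature value type: {type(first).__name__}")


# One ratings record, shaped like the output of tfds.as_numpy followed by dict.
record = {"movie_id": b"357", "movie_genres": [7], "user_rating": 4.0}
print({key: feature_kind(value) for key, value in record.items()})
# {'movie_id': 'bytes_list', 'movie_genres': 'int64_list', 'user_rating': 'float_list'}
```

These are the same list types that appear in the inspected examples below.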

Init TFX Pipeline Context

context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/metadata.sqlite.

Preparing the dataset

We will use our custom executor in a FileBasedExampleGen to load our datasets from TFDS. Since we have two datasets, we will create two ExampleGen components.

# Ratings data.
ratings_example_gen = FileBasedExampleGen(
    input_base='dummy',
    custom_config={'dataset':'movielens/100k-ratings', 'split':'train'},
    custom_executor_spec=executor_spec.ExecutorClassSpec(TFDSExecutor))
context.run(ratings_example_gen, enable_cache=True)
INFO:absl:Running driver for FileBasedExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:Running executor for FileBasedExampleGen
INFO:absl:Generating examples.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/movielens/100k-ratings/0.1.1
INFO:absl:Reusing dataset movielens (gs://tensorflow-datasets/datasets/movielens/100k-ratings/0.1.1)
2023-10-03 09:15:41.495727: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
INFO:absl:Constructing tf.data.Dataset movielens for split train[0shard], from gs://tensorflow-datasets/datasets/movielens/100k-ratings/0.1.1
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
INFO:absl:Examples generated.
INFO:absl:Running publisher for FileBasedExampleGen
INFO:absl:MetadataStore with DB connection initialized
# Features of all the available movies.
movies_example_gen = FileBasedExampleGen(
    input_base='dummy',
    custom_config={'dataset':'movielens/100k-movies', 'split':'train'},
    custom_executor_spec=executor_spec.ExecutorClassSpec(TFDSExecutor))
context.run(movies_example_gen, enable_cache=True)
INFO:absl:Running driver for FileBasedExampleGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:Running executor for FileBasedExampleGen
INFO:absl:Generating examples.
INFO:absl:Load dataset info from gs://tensorflow-datasets/datasets/movielens/100k-movies/0.1.1
INFO:absl:Reusing dataset movielens (gs://tensorflow-datasets/datasets/movielens/100k-movies/0.1.1)
INFO:absl:Constructing tf.data.Dataset movielens for split train[0shard], from gs://tensorflow-datasets/datasets/movielens/100k-movies/0.1.1
INFO:absl:Examples generated.
INFO:absl:Running publisher for FileBasedExampleGen
INFO:absl:MetadataStore with DB connection initialized

Create inspect_examples utility

We create a convenience utility to inspect datasets of tf.train.Example records. The ratings dataset returns a dictionary of movie ID, user ID, the assigned rating, timestamp, movie information, and user information:

def inspect_examples(component,
                     channel_name='examples',
                     split_name='train',
                     num_examples=1):
  # Get the URI of the output artifact, which is a directory
  full_split_name = 'Split-{}'.format(split_name)
  print('channel_name: {}, split_name: {} (\"{}\"), num_examples: {}\n'.format(
      channel_name, split_name, full_split_name, num_examples))
  train_uri = os.path.join(
      component.outputs[channel_name].get()[0].uri, full_split_name)

  # Get the list of files in this directory (all compressed TFRecord files)
  tfrecord_filenames = [os.path.join(train_uri, name)
                        for name in os.listdir(train_uri)]

  # Create a `TFRecordDataset` to read these files
  dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

  # Iterate over the records and print them
  for tfrecord in dataset.take(num_examples):
    serialized_example = tfrecord.numpy()
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    pp.pprint(example)

inspect_examples(ratings_example_gen)
channel_name: examples, split_name: train ("Split-train"), num_examples: 1

features {
  feature {
    key: "bucketized_user_age"
    value {
      float_list {
        value: 45.0
      }
    }
  }
  feature {
    key: "movie_genres"
    value {
      int64_list {
        value: 7
      }
    }
  }
  feature {
    key: "movie_id"
    value {
      bytes_list {
        value: "357"
      }
    }
  }
  feature {
    key: "movie_title"
    value {
      bytes_list {
        value: "One Flew Over the Cuckoo\'s Nest (1975)"
      }
    }
  }
  feature {
    key: "raw_user_age"
    value {
      float_list {
        value: 46.0
      }
    }
  }
  feature {
    key: "timestamp"
    value {
      int64_list {
        value: 879024327
      }
    }
  }
  feature {
    key: "user_gender"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "user_id"
    value {
      bytes_list {
        value: "138"
      }
    }
  }
  feature {
    key: "user_occupation_label"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "user_occupation_text"
    value {
      bytes_list {
        value: "doctor"
      }
    }
  }
  feature {
    key: "user_rating"
    value {
      float_list {
        value: 4.0
      }
    }
  }
  feature {
    key: "user_zip_code"
    value {
      bytes_list {
        value: "53211"
      }
    }
  }
}

The movies dataset contains the movie id, movie title, and data on what genres it belongs to. Note that the genres are encoded with integer labels.

inspect_examples(movies_example_gen)
channel_name: examples, split_name: train ("Split-train"), num_examples: 1

features {
  feature {
    key: "movie_genres"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "movie_id"
    value {
      bytes_list {
        value: "1681"
      }
    }
  }
  feature {
    key: "movie_title"
    value {
      bytes_list {
        value: "You So Crazy (1994)"
      }
    }
  }
}

ExampleGen did the split

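When we ingested the MovieLens dataset, our ExampleGen component split the data into train and eval splits, named Split-train and Split-eval. By default, ExampleGen hashes each example into three buckets and assigns two of them to training and one to evaluation, so roughly two-thirds (about 67%) of the data goes to training and one-third (about 33%) to evaluation.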
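The hash-based split can be sketched in a few lines of standard-library Python. This is an illustration under stated assumptions: TFX fingerprints examples with its own internal hash function, while this sketch substitutes hashlib.md5. The 2 train buckets to 1 eval bucket used here correspond to TFX's default split. Because the bucket depends only on the example's bytes, the same example always lands in the same split.

```python
import hashlib
from collections import Counter


def assign_split(serialized_example, splits=(("train", 2), ("eval", 1))):
    """Deterministically map an example to a split by hashing it into buckets."""
    total_buckets = sum(count for _, count in splits)
    # md5 is a stand-in for TFX's internal fingerprint function.
    digest = hashlib.md5(serialized_example).digest()
    bucket = int.from_bytes(digest[:8], "big") % total_buckets
    upper = 0
    for name, count in splits:
        upper += count
        if bucket < upper:
            return name
    raise RuntimeError("unreachable: bucket is always below total_buckets")


examples = [f"example-{i}".encode() for i in range(3000)]
counts = Counter(assign_split(e) for e in examples)
print(counts["train"] / len(examples))  # close to 2/3
```

In a real pipeline you would not write this yourself; the ratio can be changed through ExampleGen's output_config, which takes a tfx.proto.Output containing a SplitConfig with per-split hash_buckets.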

Generate statistics for movies and ratings

For a TFX pipeline we need to generate statistics for the dataset. We do that by using a StatisticsGen component. These will be used by the SchemaGen component below when we generate a schema for our dataset. This is good practice anyway, because it's important to examine and analyze your data on an ongoing basis. Since we have two datasets we will create two StatisticsGen components.
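StatisticsGen computes its statistics with TensorFlow Data Validation, and those statistics are far richer than anything shown here. As a toy illustration of the kind of per-feature summary it produces, the sketch below counts values, distinct values, and numeric min/max over a few made-up ratings rows. feature_stats is a hypothetical helper, not part of TFX or TFDV.

```python
def feature_stats(examples):
    """Per-feature count, number of distinct values, and numeric min/max (toy version)."""
    values_by_feature = {}
    for example in examples:
        for name, value in example.items():
            values_by_feature.setdefault(name, []).append(value)

    summary = {}
    for name, values in values_by_feature.items():
        info = {"count": len(values), "unique": len(set(values))}
        # Only numeric features get a min/max, as a stand-in for numeric statistics.
        if all(isinstance(v, (int, float)) for v in values):
            info["min"], info["max"] = min(values), max(values)
        summary[name] = info
    return summary


# Made-up ratings rows, reduced to two features.
rows = [
    {"user_id": "138", "user_rating": 4.0},
    {"user_id": "92", "user_rating": 5.0},
    {"user_id": "138", "user_rating": 2.0},
]
print(feature_stats(rows))
# {'user_id': {'count': 3, 'unique': 2}, 'user_rating': {'count': 3, 'unique': 3, 'min': 2.0, 'max': 5.0}}
```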

movies_stats_gen = tfx.components.StatisticsGen(
    examples=movies_example_gen.outputs['examples'])
context.run(movies_stats_gen, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for StatisticsGen
INFO:absl:Generating statistics for split train.
INFO:absl:Statistics for split train written to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/StatisticsGen/statistics/3/Split-train.
INFO:absl:Generating statistics for split eval.
INFO:absl:Statistics for split eval written to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/StatisticsGen/statistics/3/Split-eval.
INFO:absl:Running publisher for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized
context.show(movies_stats_gen.outputs['statistics'])
ratings_stats_gen = tfx.components.StatisticsGen(
    examples=ratings_example_gen.outputs['examples'])
context.run(ratings_stats_gen, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for StatisticsGen
INFO:absl:Generating statistics for split train.
INFO:absl:Statistics for split train written to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/StatisticsGen/statistics/4/Split-train.
INFO:absl:Generating statistics for split eval.
INFO:absl:Statistics for split eval written to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/StatisticsGen/statistics/4/Split-eval.
INFO:absl:Running publisher for StatisticsGen
INFO:absl:MetadataStore with DB connection initialized
context.show(ratings_stats_gen.outputs['statistics'])

Create schemas for movies and ratings

For a TFX pipeline we need to generate a data schema from our dataset. We do that by using a SchemaGen component. This will be used by the Transform component below to do our feature engineering in a way that is highly scalable to large datasets, and avoids training/serving skew. Since we have two datasets we will create two SchemaGen components.

movies_schema_gen = tfx.components.SchemaGen(
    statistics=movies_stats_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(movies_schema_gen, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for SchemaGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for SchemaGen
INFO:absl:Processing schema from statistics for split train.
INFO:absl:Processing schema from statistics for split eval.
INFO:absl:Schema written to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/SchemaGen/schema/5/schema.pbtxt.
INFO:absl:Running publisher for SchemaGen
INFO:absl:MetadataStore with DB connection initialized
context.show(movies_schema_gen.outputs['schema'])
ratings_schema_gen = tfx.components.SchemaGen(
    statistics=ratings_stats_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(ratings_schema_gen, enable_cache=True)
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Running driver for SchemaGen
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for SchemaGen
INFO:absl:Processing schema from statistics for split train.
INFO:absl:Processing schema from statistics for split eval.
INFO:absl:Schema written to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/SchemaGen/schema/6/schema.pbtxt.
INFO:absl:Running publisher for SchemaGen
INFO:absl:MetadataStore with DB connection initialized
context.show(ratings_schema_gen.outputs['schema'])

Feature Engineering using Transform

For a structured and repeatable TFX pipeline design, we need a scalable approach to feature engineering. This lets us handle the large datasets that are common in recommender systems, and it also avoids training/serving skew. We will do that using the Transform component.

The Transform component uses a module file to supply the user code for the feature engineering we want to do, so our first step is to create that module file. Since we have two datasets, we will create two of these module files and two Transform components.

One of the things our recommender needs is vocabularies for the user_id and movie_title fields. In the basic_retrieval tutorial those are created with inline NumPy, but here we will use Transform.
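Conceptually, a vocabulary maps each distinct string to an integer ID. The pure-Python sketch below approximates what tft.compute_and_apply_vocabulary does with num_oov_buckets=1: build a frequency-ordered vocabulary, then map known tokens to their index and unseen tokens to the single out-of-vocabulary ID, len(vocab). Transform computes the vocabulary over the whole dataset with Beam; this in-memory version is only an illustration, the alphabetical tie-breaking is our own choice, and the user IDs are made up.

```python
from collections import Counter


def build_vocabulary(values):
    """Vocabulary ordered by descending frequency; ties broken alphabetically."""
    counts = Counter(values)
    return sorted(counts, key=lambda token: (-counts[token], token))


def apply_vocabulary(values, vocab):
    """Map tokens to integer IDs; unseen tokens share the single OOV ID len(vocab)."""
    index = {token: i for i, token in enumerate(vocab)}
    return [index.get(token, len(vocab)) for token in values]


user_ids = ["138", "92", "138", "7", "92", "138"]
vocab = build_vocabulary(user_ids)
print(vocab)                                          # ['138', '92', '7']
print(apply_vocabulary(["92", "999", "138"], vocab))  # [1, 3, 0]
```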

_movies_transform_module_file = 'movies_transform_module.py'
%%writefile {_movies_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

def preprocessing_fn(inputs):
  # We only want the movie title
  return {'movie_title':inputs['movie_title']}
Writing movies_transform_module.py
movies_transform = tfx.components.Transform(
    examples=movies_example_gen.outputs['examples'],
    schema=movies_schema_gen.outputs['schema'],
    module_file=os.path.abspath(_movies_transform_module_file))
context.run(movies_transform, enable_cache=True)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/movies_transform_module.py' (including modules: ['movies_transform_module']).
INFO:absl:User module package has hash fingerprint version 5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmp3o3_oocw/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpq77qz33n', '--dist-dir', '/tmpfs/tmp/tmpr18bi7ov']
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl'; target user module is 'movies_transform_module'.
INFO:absl:Full user module path is 'movies_transform_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl'
INFO:absl:Running driver for Transform
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Transform
INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'movies_transform_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpddkg9_x4', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl']
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying movies_transform_module.py -> build/lib
installing to /tmpfs/tmp/tmpq77qz33n
running install
running install_lib
copying build/lib/movies_transform_module.py -> /tmpfs/tmp/tmpq77qz33n
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpq77qz33n/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpq77qz33n/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204.dist-info/WHEEL
creating '/tmpfs/tmp/tmpr18bi7ov/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl' and adding '/tmpfs/tmp/tmpq77qz33n' to it
adding 'movies_transform_module.py'
adding 'tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204.dist-info/RECORD'
removing /tmpfs/tmp/tmpq77qz33n
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl'.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'movies_transform_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp87aq_qj7', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl'.
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpf2fwliio', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+5eb30f0529e01ad72232bd9acba34fc83d7fa66b99898a3d3ee424fbdf388204-py3-none-any.whl'.
INFO:absl:Feature movie_genres has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature movie_id has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature movie_title has no shape. Setting to varlen_sparse_tensor.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transform_graph/7/.temp_path/tftransform_tmp/696c000bdd594a648af4deca581e8478/assets
INFO:absl:Writing fingerprint to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transform_graph/7/.temp_path/tftransform_tmp/696c000bdd594a648af4deca581e8478/fingerprint.pb
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:Running publisher for Transform
INFO:absl:MetadataStore with DB connection initialized
context.show(movies_transform.outputs['post_transform_schema'])
inspect_examples(movies_transform, channel_name='transformed_examples')
channel_name: transformed_examples, split_name: train ("Split-train"), num_examples: 1

features {
  feature {
    key: "movie_title"
    value {
      bytes_list {
        value: "You So Crazy (1994)"
      }
    }
  }
}
_ratings_transform_module_file = 'ratings_transform_module.py'
%%writefile {_ratings_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

NUM_OOV_BUCKETS = 1

def preprocessing_fn(inputs):
  # We only want the user ID and the movie title, but we also need vocabularies
  # for both of them.  The vocabularies aren't features, they're only used by
  # the lookup.
  outputs = {}
  outputs['user_id'] = tft.sparse_tensor_to_dense_with_shape(inputs['user_id'], [None, 1], '-1')
  outputs['movie_title'] = tft.sparse_tensor_to_dense_with_shape(inputs['movie_title'], [None, 1], '-1')

  tft.compute_and_apply_vocabulary(
      inputs['user_id'],
      num_oov_buckets=NUM_OOV_BUCKETS,
      vocab_filename='user_id_vocab')

  tft.compute_and_apply_vocabulary(
      inputs['movie_title'],
      num_oov_buckets=NUM_OOV_BUCKETS,
      vocab_filename='movie_title_vocab')

  return outputs
Writing ratings_transform_module.py
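To make the vocabulary step concrete, here is a plain-Python sketch of what we understand the Transform calls above to produce for one string feature. Note that compute_and_apply_vocabulary is called only for its side effect of writing the user_id_vocab and movie_title_vocab files; the integer indices it returns are discarded. The sketch assumes a vocabulary ordered by descending frequency (the alphabetical tie-breaking is our own choice), and it only models the single out-of-vocabulary bucket used here (NUM_OOV_BUCKETS = 1), where every unknown token maps to index len(vocab). The helpers densify, build_vocab, and apply_vocab are hypothetical.

```python
from collections import Counter
from typing import List


def densify(values: List[List[str]], default: str = "-1") -> List[List[str]]:
    """Turn per-example value lists into shape [n, 1], filling gaps with `default`.

    Assumes each example has at most one value, as user_id and movie_title do.
    """
    out = []
    for v in values:
        if len(v) > 1:
            raise ValueError("expected at most one value per example")
        out.append([v[0] if v else default])
    return out


def build_vocab(tokens: List[str]) -> List[str]:
    """Most frequent token first; ties broken alphabetically (an assumption)."""
    counts = Counter(tokens)
    return sorted(counts, key=lambda t: (-counts[t], t))


def apply_vocab(tokens: List[str], vocab: List[str]) -> List[int]:
    """Map tokens to vocabulary indices; with one OOV bucket, unknowns share len(vocab)."""
    index = {t: i for i, t in enumerate(vocab)}
    return [index.get(t, len(vocab)) for t in tokens]


user_ids = ["138", "92", "138", "301", "92", "138"]
vocab = build_vocab(user_ids)             # ['138', '92', '301']
print(apply_vocab(["92", "999"], vocab))  # [1, 3]: '999' is out of vocabulary
print(densify([["138"], []]))             # [['138'], ['-1']]
```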
ratings_transform = tfx.components.Transform(
    examples=ratings_example_gen.outputs['examples'],
    schema=ratings_schema_gen.outputs['schema'],
    module_file=os.path.abspath(_ratings_transform_module_file))
context.run(ratings_transform, enable_cache=True)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/ratings_transform_module.py' (including modules: ['movies_transform_module', 'ratings_transform_module']).
INFO:absl:User module package has hash fingerprint version 4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmpamokpvuj/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpl51qye5y', '--dist-dir', '/tmpfs/tmp/tmp0us0_9gf']
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl'; target user module is 'ratings_transform_module'.
INFO:absl:Full user module path is 'ratings_transform_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl'
INFO:absl:Running driver for Transform
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Transform
INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'ratings_transform_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpgw7hahy7', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl']
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying movies_transform_module.py -> build/lib
copying ratings_transform_module.py -> build/lib
installing to /tmpfs/tmp/tmpl51qye5y
running install
running install_lib
copying build/lib/movies_transform_module.py -> /tmpfs/tmp/tmpl51qye5y
copying build/lib/ratings_transform_module.py -> /tmpfs/tmp/tmpl51qye5y
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpl51qye5y/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpl51qye5y/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51.dist-info/WHEEL
creating '/tmpfs/tmp/tmp0us0_9gf/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl' and adding '/tmpfs/tmp/tmpl51qye5y' to it
adding 'movies_transform_module.py'
adding 'ratings_transform_module.py'
adding 'tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51.dist-info/RECORD'
removing /tmpfs/tmp/tmpl51qye5y
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl'.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'ratings_transform_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp6v984tvy', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl'.
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpghhslpe6', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51-py3-none-any.whl'.
INFO:absl:Feature bucketized_user_age has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature movie_genres has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature movie_id has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature movie_title has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature raw_user_age has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature timestamp has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_gender has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_id has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_occupation_label has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_occupation_text has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_rating has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_zip_code has no shape. Setting to varlen_sparse_tensor.
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+4a5113f0b8c14180b5cd46cfa8cc0e3d065b2031e1567b99a9df81abd4940b51
INFO:absl:Feature movie_title has no shape. Setting to varlen_sparse_tensor.
INFO:absl:Feature user_id has no shape. Setting to varlen_sparse_tensor.
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: compute_and_apply_vocabulary/apply_vocab/text_file_init/InitializeTableFromTextFileV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: compute_and_apply_vocabulary_1/apply_vocab/text_file_init/InitializeTableFromTextFileV2
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transform_graph/8/.temp_path/tftransform_tmp/2a033ce5c7904cac9a10b635c5b3090c/assets
INFO:absl:Writing fingerprint to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transform_graph/8/.temp_path/tftransform_tmp/2a033ce5c7904cac9a10b635c5b3090c/fingerprint.pb
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transform_graph/8/.temp_path/tftransform_tmp/2c9d9b3909a141c18a8c913f050c8bd3/assets
INFO:absl:Writing fingerprint to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transform_graph/8/.temp_path/tftransform_tmp/2c9d9b3909a141c18a8c913f050c8bd3/fingerprint.pb
INFO:absl:Feature movie_title has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature user_id has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:Running publisher for Transform
INFO:absl:MetadataStore with DB connection initialized
context.show(ratings_transform.outputs['post_transform_schema'])
inspect_examples(ratings_transform, channel_name='transformed_examples')
channel_name: transformed_examples, split_name: train ("Split-train"), num_examples: 1

features {
  feature {
    key: "movie_title"
    value {
      bytes_list {
        value: "One Flew Over the Cuckoo\'s Nest (1975)"
      }
    }
  }
  feature {
    key: "user_id"
    value {
      bytes_list {
        value: "138"
      }
    }
  }
}

Implementing a model in TFX

In the basic_retrieval tutorial, the model was created inline in the Python runtime. In a TFX pipeline, the model, metric, and loss are instead defined in the module file of a pipeline component called Trainer, which also trains the model. This makes the model, metric, and loss part of a repeatable process that can be automated and monitored.
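Concretely, the contract is that Trainer imports the module file and calls its run_fn with a tfx.components.FnArgs object. The plain-Python stand-in below mimics only the FnArgs attributes that trainer_module.py reads (train_files, eval_files, transform_output, serving_model_dir, train_steps, eval_steps, custom_config), so the shape of that contract is visible without installing TFX. FakeFnArgs, the toy run_fn, and the paths are hypothetical; a real run_fn builds, trains, and exports a model instead of returning a dictionary.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeFnArgs:
    """Hypothetical stand-in for tfx.components.FnArgs (only the fields used here)."""
    train_files: List[str]
    eval_files: List[str]
    transform_output: str
    serving_model_dir: str
    train_steps: int
    eval_steps: int
    custom_config: Dict[str, Any] = field(default_factory=dict)


def run_fn(fn_args: FakeFnArgs) -> Dict[str, Any]:
    """Toy run_fn: reports what a real run would train with instead of training."""
    return {
        "epochs": fn_args.custom_config.get("epochs", 1),
        "steps_per_epoch": fn_args.train_steps,
        "validation_steps": fn_args.eval_steps,
        "export_dir": fn_args.serving_model_dir,
    }


args = FakeFnArgs(
    train_files=["/hypothetical/Split-train/*"],
    eval_files=["/hypothetical/Split-eval/*"],
    transform_output="/hypothetical/transform_graph",
    serving_model_dir="/hypothetical/serving_model",
    train_steps=500,
    eval_steps=10,
    custom_config={"epochs": 5},
)
plan = run_fn(args)
print(plan)
```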

TensorFlow Recommenders model architecture

We are going to build a two-tower retrieval model. Two-tower means that a query tower computes the user representation from user features, and a candidate tower computes the movie representation from movie features. We build each tower separately (in the _build_user_model() and _build_movie_model() functions below) and then combine them in the final model (the MovielensModel class). MovielensModel is a subclass of the tfrs.Model base class, which streamlines building models: all we need to do is set up the components in the __init__ method and implement the compute_loss method, which takes in the transformed features and returns a loss value.
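The arithmetic behind the two towers and the retrieval loss can be sketched in plain Python. Each toy tower mirrors the StringLookup + Embedding pair used in the module below: with mask_token=None, Keras's StringLookup reserves index 0 for out-of-vocabulary strings, which is why the embedding table has len(vocab) + 1 rows. The loss is an in-batch softmax: within a batch of (user, movie) pairs, each user's own movie is the positive and the other movies in the batch act as negatives. We understand this to be the default behavior of tfrs.tasks.Retrieval, but summing over rows, the small embedding size, the movie titles, and the helper names make_tower and in_batch_softmax_loss are assumptions of this sketch.

```python
import math
import random
from typing import Callable, List, Sequence

EMBEDDING_DIMENSION = 4  # the tutorial uses 32; smaller here for readability

Vector = List[float]


def make_tower(vocab: Sequence[str], dim: int, seed: int) -> Callable[[Sequence[str]], List[Vector]]:
    """Toy tower: string lookup (unknown -> row 0) followed by an embedding table."""
    rng = random.Random(seed)
    index = {tok: i + 1 for i, tok in enumerate(vocab)}  # row 0 is reserved for OOV
    table = [[rng.uniform(-0.05, 0.05) for _ in range(dim)]
             for _ in range(len(vocab) + 1)]

    def tower(tokens: Sequence[str]) -> List[Vector]:
        return [table[index.get(t, 0)] for t in tokens]

    return tower


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def in_batch_softmax_loss(queries: List[Vector], candidates: List[Vector]) -> float:
    """Softmax cross-entropy where row i's positive is candidate i, summed over rows."""
    total = 0.0
    for i, q in enumerate(queries):
        scores = [dot(q, c) for c in candidates]  # affinity of user i with every movie
        m = max(scores)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[i]
    return total


user_tower = make_tower(["138", "92"], EMBEDDING_DIMENSION, seed=0)
movie_tower = make_tower(["Movie A", "Movie B"], EMBEDDING_DIMENSION, seed=1)  # hypothetical titles

loss = in_batch_softmax_loss(user_tower(["138", "92"]), movie_tower(["Movie A", "Movie B"]))
single = in_batch_softmax_loss(user_tower(["138"]), movie_tower(["Movie A"]))
print(loss, single)  # loss is positive; single is exactly 0.0
```

One consequence worth noting: with a batch that contains a single pair, the positive is the only candidate, so the in-batch softmax loss in this sketch is exactly 0 whatever the embeddings are. The trainer module below sets INPUT_FN_BATCH_SIZE = 1, so it may be worth checking whether a larger batch size trains better.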

# We're now going to create the module file for Trainer, which includes the
# model code from the basic_retrieval tutorial with some modifications for TFX.

_trainer_module_file = 'trainer_module.py'
%%writefile {_trainer_module_file}

from typing import Dict, List, Text

import os
import absl
import glob
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_recommenders as tfrs

from absl import logging
from tfx.types import artifact_utils

from tfx import v1 as tfx
from tfx_bsl.coders import example_coder
from tfx_bsl.public import tfxio

absl.logging.set_verbosity(absl.logging.INFO)

EMBEDDING_DIMENSION = 32
INPUT_FN_BATCH_SIZE = 1


def extract_str_feature(dataset, feature_name):
  np_dataset = []
  for example in dataset:
    np_example = example_coder.ExampleToNumpyDict(example.numpy())
    np_dataset.append(np_example[feature_name][0].decode())
  return tf.data.Dataset.from_tensor_slices(np_dataset)


class MovielensModel(tfrs.Model):

  def __init__(self, user_model, movie_model, tf_transform_output, movies_uri):
    super().__init__()
    self.movie_model: tf.keras.Model = movie_model
    self.user_model: tf.keras.Model = user_model

    movies_artifact = movies_uri.get()[0]
    input_dir = artifact_utils.get_split_uri([movies_artifact], 'train')
    movie_files = glob.glob(os.path.join(input_dir, '*'))
    movies = tf.data.TFRecordDataset(movie_files, compression_type="GZIP")
    movies_dataset = extract_str_feature(movies, 'movie_title')

    loss_metrics = tfrs.metrics.FactorizedTopK(
        candidates=movies_dataset.batch(128).map(movie_model)
        )

    self.task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
        metrics=loss_metrics
        )


  def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
    # We pick out the user features and pass them into the user model.
    try:
      # Transform densified each feature to shape [batch, 1], so each tower
      # returns [batch, 1, EMBEDDING_DIMENSION]; squeeze out the middle axis.
      user_embeddings = tf.squeeze(self.user_model(features['user_id']), axis=1)
      # And pick out the movie features and pass them into the movie model,
      # getting embeddings back.
      positive_movie_embeddings = tf.squeeze(
          self.movie_model(features['movie_title']), axis=1)

      # The task computes the loss and the metrics.
      _task = self.task(user_embeddings, positive_movie_embeddings)
    except BaseException as err:
      logging.error('######## ERROR IN compute_loss:\n{}\n###############'.format(err))
      raise

    return _task


# This function will apply the same transform operation to training data
# and serving requests.
def _apply_preprocessing(raw_features, tft_layer):
  try:
    transformed_features = tft_layer(raw_features)
  except BaseException as err:
    logging.error('######## ERROR IN _apply_preprocessing:\n{}\n###############'.format(err))
    raise

  return transformed_features


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates batched features for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: The number of consecutive elements of the returned dataset to
      combine into a single batch.

  Returns:
    A dataset of dictionaries mapping each transformed feature name to a
      batched Tensor. No label key is set, so no separate label is returned.
  """
  try:
    return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size),
      tf_transform_output.transformed_metadata.schema)
  except BaseException as err:
    logging.error('######## ERROR IN _input_fn:\n{}\n###############'.format(err))
    raise


def _get_serve_tf_examples_fn(model, tf_transform_output):
  """Returns a function that parses a serialized tf.Example and applies TFT."""
  try:
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
      """Returns the output to be used in the serving signature."""
      try:
        feature_spec = tf_transform_output.raw_feature_spec()
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        result = model(transformed_features)
      except BaseException as err:
        logging.error('######## ERROR IN serve_tf_examples_fn:\n{}\n###############'.format(err))
        raise
      return result
  except BaseException as err:
    logging.error('######## ERROR IN _get_serve_tf_examples_fn:\n{}\n###############'.format(err))
    raise

  return serve_tf_examples_fn


def _build_user_model(
    tf_transform_output: tft.TFTransformOutput, # Specific to ratings
    embedding_dimension: int = 32) -> tf.keras.Model:
  """Creates a Keras model for the query tower.

  Args:
    tf_transform_output: [tft.TFTransformOutput], the results of Transform
    embedding_dimension: [int], the dimensionality of the embedding space

  Returns:
    A keras Model.
  """
  try:
    unique_user_ids = tf_transform_output.vocabulary_by_name('user_id_vocab')
    users_vocab_str = [b.decode() for b in unique_user_ids]

    model = tf.keras.Sequential(
        [
         tf.keras.layers.StringLookup(
             vocabulary=users_vocab_str, mask_token=None),
         # We add an additional embedding to account for unknown tokens.
         tf.keras.layers.Embedding(len(users_vocab_str) + 1, embedding_dimension)
         ])
  except BaseException as err:
    logging.error('######## ERROR IN _build_user_model:\n{}\n###############'.format(err))
    raise

  return model


def _build_movie_model(
    tf_transform_output: tft.TFTransformOutput, # Specific to movies
    embedding_dimension: int = 32) -> tf.keras.Model:
  """Creates a Keras model for the candidate tower.

  Args:
    tf_transform_output: [tft.TFTransformOutput], the results of Transform
    embedding_dimension: [int], the dimensionality of the embedding space

  Returns:
    A keras Model.
  """
  try:
    unique_movie_titles = tf_transform_output.vocabulary_by_name('movie_title_vocab')
    titles_vocab_str = [b.decode() for b in unique_movie_titles]

    model = tf.keras.Sequential(
        [
         tf.keras.layers.StringLookup(
             vocabulary=titles_vocab_str, mask_token=None),
         # We add an additional embedding to account for unknown tokens.
         tf.keras.layers.Embedding(len(titles_vocab_str) + 1, embedding_dimension)
        ])
  except BaseException as err:
    logging.error('######## ERROR IN _build_movie_model:\n{}\n###############'.format(err))
    raise
  return model


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  try:
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                              tf_transform_output, INPUT_FN_BATCH_SIZE)
    eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                            tf_transform_output, INPUT_FN_BATCH_SIZE)

    model = MovielensModel(
        _build_user_model(tf_transform_output, EMBEDDING_DIMENSION),
        _build_movie_model(tf_transform_output, EMBEDDING_DIMENSION),
        tf_transform_output,
        fn_args.custom_config['movies']
        )

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq='batch')

    model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
  except BaseException as err:
    logging.error('######## ERROR IN run_fn before fit:\n{}\n###############'.format(err))
    raise

  try:
    model.fit(
        train_dataset,
        epochs=fn_args.custom_config['epochs'],
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])
  except BaseException as err:
    logging.error('######## ERROR IN run_fn during fit:\n{}\n###############'.format(err))
    raise

  try:
    index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)

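    movies_artifact = fn_args.custom_config['movies'].get()[0]
    # Only the movies in the 'eval' split are indexed here; the FactorizedTopK
    # metric in MovielensModel is built from the 'train' split.
    input_dir = artifact_utils.get_split_uri([movies_artifact], 'eval')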
    movie_files = glob.glob(os.path.join(input_dir, '*'))
    movies = tf.data.TFRecordDataset(movie_files, compression_type="GZIP")

    movies_dataset = extract_str_feature(movies, 'movie_title')

    index.index_from_dataset(
      tf.data.Dataset.zip((
          movies_dataset.batch(100),
          movies_dataset.batch(100).map(model.movie_model))
      )
    )

    # Run once so that we can get the right signatures into SavedModel
    _, titles = index(tf.constant(["42"]))
    print(f"Recommendations for user 42: {titles[0, :3]}")

    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(index,
                                      tf_transform_output).get_concrete_function(
                                          tf.TensorSpec(
                                              shape=[None],
                                              dtype=tf.string,
                                              name='examples')),
    }
    index.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)

  except BaseException as err:
    logging.error('######## ERROR IN run_fn during export:\n{}\n###############'.format(err))
    raise
Writing trainer_module.py
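At export time, run_fn wraps the trained user tower in tfrs.layers.factorized_top_k.BruteForce, indexes the movie embeddings with index_from_dataset, and calls the index once for user "42". The idea behind brute-force retrieval is simple: score the query embedding against every indexed candidate and keep the k highest. The plain-Python ToyBruteForce below is a hypothetical illustration of that idea, not the TFRS implementation; the real layer works on batched tensors and returns a (scores, identifiers) pair rather than a list of tuples. The user vectors and movie titles are made up.

```python
import heapq
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]


class ToyBruteForce:
    """Score a query against every indexed candidate and keep the top k."""

    def __init__(self, query_model: Callable[[Sequence[str]], List[Vector]], k: int = 10):
        self.query_model = query_model
        self.k = k
        self.ids: List[str] = []
        self.embeddings: List[Vector] = []

    def index_candidates(self, ids: Sequence[str], embeddings: Sequence[Vector]) -> None:
        self.ids = list(ids)
        self.embeddings = [list(e) for e in embeddings]

    def __call__(self, queries: Sequence[str]) -> List[List[Tuple[float, str]]]:
        results = []
        for q in self.query_model(queries):
            scored = [(sum(a * b for a, b in zip(q, e)), cid)
                      for cid, e in zip(self.ids, self.embeddings)]
            results.append(heapq.nlargest(self.k, scored))  # highest scores first
        return results


# Hypothetical two-dimensional embeddings standing in for a trained user tower.
user_vectors: Dict[str, Vector] = {"42": [1.0, 0.0], "7": [0.0, 1.0]}


def user_model(user_ids: Sequence[str]) -> List[Vector]:
    return [user_vectors[u] for u in user_ids]


toy_index = ToyBruteForce(user_model, k=2)
toy_index.index_candidates(
    ["Movie A (1990)", "Movie B (1991)", "Movie C (1992)"],
    [[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]],
)
top = toy_index(["42"])[0]
print([title for _, title in top])  # ['Movie A (1990)', 'Movie C (1992)']
```

Brute force is exact but costs one dot product per candidate per query, which is manageable for a catalog the size of MovieLens.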

Training the model

After defining the model, we can run the Trainer component to train it.
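Before running it, it helps to work out how much data the Trainer configuration below actually feeds the model. run_fn passes fn_args.train_steps as steps_per_epoch and custom_config['epochs'] as epochs to model.fit, and _input_fn batches INPUT_FN_BATCH_SIZE examples per step. This sketch assumes the training dataset repeats rather than running out, which we believe is the tfxio default when num_epochs is not set; examples_seen is a hypothetical helper.

```python
def examples_seen(steps_per_epoch: int, batch_size: int, epochs: int) -> int:
    """Total examples consumed across all epochs (assumes the dataset repeats)."""
    return steps_per_epoch * batch_size * epochs


# Values from this tutorial: TrainArgs(num_steps=500), INPUT_FN_BATCH_SIZE = 1,
# custom_config['epochs'] = 5, EvalArgs(num_steps=10) per validation pass.
train_total = examples_seen(steps_per_epoch=500, batch_size=1, epochs=5)
eval_per_epoch = examples_seen(steps_per_epoch=10, batch_size=1, epochs=1)
print(train_total, eval_per_epoch)  # 2500 10
```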

trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=ratings_transform.outputs['transformed_examples'],
    transform_graph=ratings_transform.outputs['transform_graph'],
    schema=ratings_transform.outputs['post_transform_schema'],
    train_args=tfx.proto.TrainArgs(num_steps=500),
    eval_args=tfx.proto.EvalArgs(num_steps=10),
    custom_config={
        'epochs':5,
        'movies':movies_transform.outputs['transformed_examples'],
        'movie_schema':movies_transform.outputs['post_transform_schema'],
        'ratings':ratings_transform.outputs['transformed_examples'],
        'ratings_schema':ratings_transform.outputs['post_transform_schema']
        })

context.run(trainer, enable_cache=False)
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/trainer_module.py' (including modules: ['trainer_module', 'movies_transform_module', 'ratings_transform_module']).
INFO:absl:User module package has hash fingerprint version 4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmprw4uiyim/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpvlz5rebl', '--dist-dir', '/tmpfs/tmp/tmpdi3u2ekq']
INFO:absl:Successfully built user code wheel distribution at '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl'; target user module is 'trainer_module'.
INFO:absl:Full user module path is 'trainer_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl'
INFO:absl:Running driver for Trainer
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Trainer
INFO:absl:Train on the 'train' split when train_args.splits is not set.
INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set.
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying trainer_module.py -> build/lib
copying movies_transform_module.py -> build/lib
copying ratings_transform_module.py -> build/lib
installing to /tmpfs/tmp/tmpvlz5rebl
running install
running install_lib
copying build/lib/trainer_module.py -> /tmpfs/tmp/tmpvlz5rebl
copying build/lib/movies_transform_module.py -> /tmpfs/tmp/tmpvlz5rebl
copying build/lib/ratings_transform_module.py -> /tmpfs/tmp/tmpvlz5rebl
running install_egg_info
running egg_info
creating tfx_user_code_Trainer.egg-info
writing tfx_user_code_Trainer.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
Copying tfx_user_code_Trainer.egg-info to /tmpfs/tmp/tmpvlz5rebl/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpvlz5rebl/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918.dist-info/WHEEL
creating '/tmpfs/tmp/tmpdi3u2ekq/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl' and adding '/tmpfs/tmp/tmpvlz5rebl' to it
adding 'movies_transform_module.py'
adding 'ratings_transform_module.py'
adding 'trainer_module.py'
adding 'tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918.dist-info/METADATA'
adding 'tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918.dist-info/WHEEL'
adding 'tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918.dist-info/top_level.txt'
adding 'tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918.dist-info/RECORD'
removing /tmpfs/tmp/tmpvlz5rebl
WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE
INFO:absl:udf_utils.get_fn {'train_args': '{\n  "num_steps": 500\n}', 'eval_args': '{\n  "num_steps": 10\n}', 'module_file': None, 'run_fn': None, 'trainer_fn': None, 'custom_config': '{"epochs": 5, "movie_schema": {"__class__": "OutputChannel", "__module__": "tfx.types.channel", "__tfx_object_type__": "jsonable", "additional_custom_properties": {}, "additional_properties": {}, "artifacts": [{"__artifact_class_module__": "tfx.types.standard_artifacts", "__artifact_class_name__": "Schema", "artifact": {"custom_properties": {"name": {"string_value": "post_transform_schema:2023-10-03T09:17:10.175929"}, "producer_component": {"string_value": "Transform"}, "tfx_version": {"string_value": "1.14.0"} }, "id": "12", "name": "post_transform_schema:2023-10-03T09:17:10.175929", "state": "LIVE", "type_id": "18", "uri": "/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/post_transform_schema/7"}, "artifact_type": {"id": "18", "name": "Schema"} }], "output_key": "post_transform_schema", "producer_component_id": "Transform", "type": {"name": "Schema"} }, "movies": {"__class__": "OutputChannel", "__module__": "tfx.types.channel", "__tfx_object_type__": "jsonable", "additional_custom_properties": {}, "additional_properties": {}, "artifacts": [{"__artifact_class_module__": "tfx.types.standard_artifacts", "__artifact_class_name__": "Examples", "artifact": {"custom_properties": {"name": {"string_value": "transformed_examples:2023-10-03T09:17:10.175929"}, "producer_component": {"string_value": "Transform"}, "tfx_version": {"string_value": "1.14.0"} }, "id": "8", "name": "transformed_examples:2023-10-03T09:17:10.175929", "properties": {"split_names": {"string_value": "[\\"eval\\", \\"train\\"]"} }, "state": "LIVE", "type_id": "14", "uri": "/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transformed_examples/7"}, "artifact_type": {"base_type": "DATASET", "id": "14", "name": "Examples", "properties": {"span": "INT", "split_names": "STRING", 
"version": "INT"} } }], "output_key": "transformed_examples", "producer_component_id": "Transform", "type": {"base_type": "DATASET", "name": "Examples", "properties": {"span": "INT", "split_names": "STRING", "version": "INT"} } }, "ratings": {"__class__": "OutputChannel", "__module__": "tfx.types.channel", "__tfx_object_type__": "jsonable", "additional_custom_properties": {}, "additional_properties": {}, "artifacts": [{"__artifact_class_module__": "tfx.types.standard_artifacts", "__artifact_class_name__": "Examples", "artifact": {"custom_properties": {"name": {"string_value": "transformed_examples:2023-10-03T09:17:23.150949"}, "producer_component": {"string_value": "Transform"}, "tfx_version": {"string_value": "1.14.0"} }, "id": "16", "name": "transformed_examples:2023-10-03T09:17:23.150949", "properties": {"split_names": {"string_value": "[\\"eval\\", \\"train\\"]"} }, "state": "LIVE", "type_id": "14", "uri": "/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/transformed_examples/8"}, "artifact_type": {"base_type": "DATASET", "id": "14", "name": "Examples", "properties": {"span": "INT", "split_names": "STRING", "version": "INT"} } }], "output_key": "transformed_examples", "producer_component_id": "Transform", "type": {"base_type": "DATASET", "name": "Examples", "properties": {"span": "INT", "split_names": "STRING", "version": "INT"} } }, "ratings_schema": {"__class__": "OutputChannel", "__module__": "tfx.types.channel", "__tfx_object_type__": "jsonable", "additional_custom_properties": {}, "additional_properties": {}, "artifacts": [{"__artifact_class_module__": "tfx.types.standard_artifacts", "__artifact_class_name__": "Schema", "artifact": {"custom_properties": {"name": {"string_value": "post_transform_schema:2023-10-03T09:17:23.150949"}, "producer_component": {"string_value": "Transform"}, "tfx_version": {"string_value": "1.14.0"} }, "id": "20", "name": "post_transform_schema:2023-10-03T09:17:23.150949", "state": "LIVE", "type_id": "18", 
"uri": "/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Transform/post_transform_schema/8"}, "artifact_type": {"id": "18", "name": "Schema"} }], "output_key": "post_transform_schema", "producer_component_id": "Transform", "type": {"name": "Schema"} } }', 'module_path': 'trainer_module@/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl'} 'run_fn'
INFO:absl:Installing '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpbm1nxqo5', '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl']
Processing /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl
INFO:absl:Successfully installed '/tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/_wheels/tfx_user_code_Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918-py3-none-any.whl'.
INFO:absl:Training model.
INFO:absl:Feature movie_title has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature user_id has a shape dim {
  size: 1
}
. Setting to DenseTensor.
Installing collected packages: tfx-user-code-Trainer
Successfully installed tfx-user-code-Trainer-0.0+4c202258fc2c517eea8b489d39d665ef0cf758328d1dec40e9e9f405bfb5b918
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:343: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.
INFO:absl:Feature movie_title has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature user_id has a shape dim {
  size: 1
}
. Setting to DenseTensor.
Epoch 1/5
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1176: SyntaxWarning: In loss categorical_crossentropy, expected y_pred.shape to be (batch_size, num_classes) with num_classes > 1. Received: y_pred.shape=(1, 1). Consider using 'binary_crossentropy' if you only have 2 classes.
  return dispatch_target(*args, **kwargs)
500/500 [==============================] - 25s 46ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0020 - factorized_top_k/top_10_categorical_accuracy: 0.0020 - factorized_top_k/top_50_categorical_accuracy: 0.0360 - factorized_top_k/top_100_categorical_accuracy: 0.0840 - loss: 0.0000e+00 - regularization_loss: 0.0000e+00 - total_loss: 0.0000e+00 - val_factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_10_categorical_accuracy: 0.1000 - val_factorized_top_k/top_50_categorical_accuracy: 0.2000 - val_factorized_top_k/top_100_categorical_accuracy: 0.2000 - val_loss: 0.0000e+00 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.0000e+00
Epoch 2/5
500/500 [==============================] - 23s 46ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0020 - factorized_top_k/top_10_categorical_accuracy: 0.0040 - factorized_top_k/top_50_categorical_accuracy: 0.0400 - factorized_top_k/top_100_categorical_accuracy: 0.0900 - loss: 0.0000e+00 - regularization_loss: 0.0000e+00 - total_loss: 0.0000e+00 - val_factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_50_categorical_accuracy: 0.2000 - val_factorized_top_k/top_100_categorical_accuracy: 0.2000 - val_loss: 0.0000e+00 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.0000e+00
Epoch 3/5
500/500 [==============================] - 23s 46ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0060 - factorized_top_k/top_10_categorical_accuracy: 0.0140 - factorized_top_k/top_50_categorical_accuracy: 0.0500 - factorized_top_k/top_100_categorical_accuracy: 0.0800 - loss: 0.0000e+00 - regularization_loss: 0.0000e+00 - total_loss: 0.0000e+00 - val_factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_50_categorical_accuracy: 0.1000 - val_factorized_top_k/top_100_categorical_accuracy: 0.1000 - val_loss: 0.0000e+00 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.0000e+00
Epoch 4/5
500/500 [==============================] - 23s 46ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0060 - factorized_top_k/top_10_categorical_accuracy: 0.0120 - factorized_top_k/top_50_categorical_accuracy: 0.0320 - factorized_top_k/top_100_categorical_accuracy: 0.0840 - loss: 0.0000e+00 - regularization_loss: 0.0000e+00 - total_loss: 0.0000e+00 - val_factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - val_loss: 0.0000e+00 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.0000e+00
Epoch 5/5
500/500 [==============================] - 23s 46ms/step - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0040 - factorized_top_k/top_10_categorical_accuracy: 0.0120 - factorized_top_k/top_50_categorical_accuracy: 0.0520 - factorized_top_k/top_100_categorical_accuracy: 0.0760 - loss: 0.0000e+00 - regularization_loss: 0.0000e+00 - total_loss: 0.0000e+00 - val_factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_50_categorical_accuracy: 0.1000 - val_factorized_top_k/top_100_categorical_accuracy: 0.2000 - val_loss: 0.0000e+00 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.0000e+00
Recommendations for user 42: [[b'Moll Flanders (1996)' b'Lone Star (1996)' b'Program, The (1993)'
  b"Joe's Apartment (1996)" b'Flirting With Disaster (1996)'
  b'Made in America (1993)' b"It's My Party (1995)"
  b'Maximum Risk (1996)' b'Killing Zoe (1994)'
  b'Land Before Time III: The Time of the Great Giving (1995) (V)']]
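Every epoch above reports `loss: 0.0000e+00`, and the warning raised in the first epoch reports `y_pred.shape=(1, 1)`. One plausible explanation, assuming the `Retrieval` task's in-batch softmax loss, is that each training batch held a single example: the query is then scored only against its own positive candidate, the softmax over one logit is exactly 1, and the cross-entropy (and its gradient) is 0 whatever the embeddings are. The pure-Python sketch below illustrates this; `in_batch_softmax_loss` is a hypothetical helper that averages over the batch, whereas the library's reduction may differ. If you see this pattern, it is worth checking the batch size used by the trainer module's input function.

```python
import math
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def in_batch_softmax_loss(queries: List[List[float]],
                          candidates: List[List[float]]) -> float:
    """Mean softmax cross-entropy where candidate i is the positive for query i
    and every other candidate in the batch acts as a negative."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [dot(q, c) for c in candidates]  # row i of the (B, B) score matrix
        m = max(logits)                            # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[i]           # -log softmax(logits)[i]
    return total / len(queries)


# One (query, candidate) pair: the score matrix is 1x1, the softmax is exactly 1,
# and the loss is 0 no matter what the embeddings are.
print(in_batch_softmax_loss([[0.3, -1.2]], [[2.0, 0.5]]))  # 0.0

# Two pairs per batch: each query now has a negative, so the loss is informative.
queries = [[1.0, 0.0], [0.0, 1.0]]
candidates = [[0.5, 0.5], [1.0, 0.0]]
print(in_batch_softmax_loss(queries, candidates) > 0)  # True
```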

Exporting the model

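Once the model is trained, we can use the Pusher component to export it to a serving directory. Pusher copies the model into a versioned subdirectory of `_serving_model_dir`; because this pipeline has no Evaluator or InfraValidator, the model is pushed without any validation gate.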

_serving_model_dir = os.path.join(tempfile.mkdtemp(), 'serving_model/tfrs_retrieval')

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
context.run(pusher, enable_cache=True)
INFO:absl:Running driver for Pusher
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Running executor for Pusher
WARNING:absl:Pusher is going to push the model without validation. Consider using Evaluator or InfraValidator in your pipeline.
INFO:absl:Model version: 1696324790
INFO:absl:Model written to serving path /tmpfs/tmp/tmpsi4_3ool/serving_model/tfrs_retrieval/1696324790.
INFO:absl:Model pushed to /tmpfs/tmp/tfx-interactive-2023-10-03T09_15_37.313254-ro4_raev/Pusher/pushed_model/10.
INFO:absl:Running publisher for Pusher
INFO:absl:MetadataStore with DB connection initialized
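Pusher writes each push to a numbered subdirectory of `base_directory` (here `1696324790`, per the log above), which is the layout TensorFlow Serving expects. To load the newest pushed version from `_serving_model_dir` rather than from the artifact URI, one could pick the highest numeric subdirectory. `latest_model_version` below is a hypothetical helper, not part of TFX, and the demo simply simulates a serving directory in a temporary folder.

```python
import os
import tempfile


def latest_model_version(base_dir: str) -> str:
    """Return the path of the highest-numbered version directory in base_dir.

    Hypothetical helper: assumes each pushed model lives in <base_dir>/<version>,
    where <version> is a string of digits.
    """
    versions = [name for name in os.listdir(base_dir)
                if name.isdigit() and os.path.isdir(os.path.join(base_dir, name))]
    if not versions:
        raise FileNotFoundError(f"no numeric version directories in {base_dir}")
    # Compare numerically, not lexicographically.
    return os.path.join(base_dir, max(versions, key=int))


# Demo: a fake serving directory with two pushed versions and one unrelated folder.
demo_dir = tempfile.mkdtemp()
for name in ("1696320000", "1696324790", "tmp-export"):
    os.makedirs(os.path.join(demo_dir, name))
print(latest_model_version(demo_dir))  # ends with 1696324790
```

In this tutorial, `tf.saved_model.load(latest_model_version(_serving_model_dir))` should then load the model that was just pushed.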

Make predictions

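Now that the model has been pushed, we can load it back from the pushed model artifact and retrieve recommendations for user 42. The call returns the affinity scores and the titles of the top-scoring movies.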

loaded = tf.saved_model.load(pusher.outputs['pushed_model'].get()[0].uri)
scores, titles = loaded(["42"])

print(f"Recommendations: {titles[0][:3]}")
Recommendations: [[b'Moll Flanders (1996)' b'Lone Star (1996)' b'Program, The (1993)'
  b"Joe's Apartment (1996)" b'Flirting With Disaster (1996)'
  b'Made in America (1993)' b"It's My Party (1995)"
  b'Maximum Risk (1996)' b'Killing Zoe (1994)'
  b'Land Before Time III: The Time of the Great Giving (1995) (V)']]
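Assuming `index` in the trainer module is a TFRS brute-force index (as in the basic TFRS retrieval tutorial), calling the loaded model embeds the query, scores it against every precomputed movie embedding with a dot product, and returns the k best scores together with their titles. The sketch below shows that idea in plain Python with toy 2-d embeddings; `brute_force_top_k` and the `movie_*` ids are hypothetical, and a real index performs the same scan with batched tensor operations.

```python
import heapq
from typing import Dict, List, Sequence, Tuple


def brute_force_top_k(query: Sequence[float],
                      candidates: Dict[str, Sequence[float]],
                      k: int) -> List[Tuple[float, str]]:
    """Score every candidate against the query; return the k best (score, id) pairs."""
    scored = ((sum(q * c for q, c in zip(query, emb)), cid)
              for cid, emb in candidates.items())
    # nlargest keeps only k items while scanning all candidates.
    return heapq.nlargest(k, scored)


# Toy embeddings standing in for the learned user and movie towers.
user_42 = [1.0, 0.5]
movies = {
    "movie_a": [0.9, 0.1],   # score 0.95
    "movie_b": [0.2, 0.8],   # score 0.60
    "movie_c": [-0.5, 0.3],  # score -0.35
}
for score, title in brute_force_top_k(user_42, movies, k=2):
    print(f"{score:.2f}  {title}")
```

Because every candidate is scored, the cost grows linearly with the catalogue size, which is fine for MovieLens but is why large systems switch to approximate nearest-neighbour indexes.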

Next step

In this tutorial, you have learned how to implement a retrieval model with TensorFlow Recommenders and TFX. To go further and build the second, ranking stage, see the TFRS ranking with TFX tutorial.