TensorFlow 2 TPUEmbeddingLayer: Quick Start



This Colab gives a brief introduction to the TPUEmbeddingLayer of TensorFlow 2.

The TPUEmbeddingLayer can use the embedding accelerator on the Cloud TPU to speed up embedding lookups when you have many large embedding tables. This is particularly useful for recommendation models, which typically have very large embedding tables.
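To get a feel for what "very large" means, the plain-Python sketch below computes the memory footprint of a float32 embedding table (vocabulary_size × dim × 4 bytes). The 10-million-row table is a hypothetical size chosen for illustration, not a figure from this Colab. Optimizers such as Adagrad also keep a slot variable of the same shape per table, which roughly doubles the total.

```python
BYTES_PER_FLOAT32 = 4

def table_bytes(vocabulary_size, dim):
    """Bytes needed to store one float32 embedding table."""
    return vocabulary_size * dim * BYTES_PER_FLOAT32

# A table of the size used later in this Colab (2048 IDs, 64 dims).
small = table_bytes(2048, 64)
# A hypothetical recommendation-scale table with 10 million IDs.
large = table_bytes(10_000_000, 64)

print(small)        # 524288 bytes (512 KiB)
print(large / 1e9)  # 2.56 (GB)
```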

Please follow the Google Cloud TPU quickstart to learn how to create a GCP account and a GCS bucket. You have $300 in free credit to get started with any GCP product. You can learn more about Cloud TPU at https://cloud.google.com/tpu/docs.


Install TensorFlow 2.0 and TensorFlow Recommenders

pip install -U tensorflow-recommenders
import numpy as np
import tensorflow as tf
import tensorflow_recommenders as tfrs

Connect to the TPU node or local TPU and initialize the TPU system.

resolver = tf.distribute.cluster_resolver.TPUClusterResolver('').connect('')

Create the TPU strategy. Any model that needs to run on the TPU must be created under the TPUStrategy scope.

strategy = tf.distribute.TPUStrategy(resolver)

You can also check the TPU hardware features through the TPUStrategy object.

For example, you can check which version of the embedding feature this TPU supports. See tf.tpu.experimental.HardwareFeature for detailed documentation.

embedding_feature = strategy.extended.tpu_hardware_feature.embedding_feature
assert embedding_feature == tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1, 'Make sure that you have the right TPU Hardware'

TPUEmbedding API breakdown

Feature and table configuration

When creating an instance of this layer, you must specify:

  1. The complete set of embedding tables,
  2. The features you expect to look up in those tables, and
  3. The optimizer(s) you wish to use on the tables.

See the documentation of tf.tpu.experimental.embedding.TableConfig and tf.tpu.experimental.embedding.FeatureConfig for more details on the complete set of options. We will cover the basic usage here.

Multiple FeatureConfig objects can use the same TableConfig object, allowing different features to share the same table:

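table_config_one = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=8, dim=8)
table_config_two = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=16, dim=4)
feature_config = {
    # feature_one and feature_two share table_config_one.
    'feature_one': tf.tpu.experimental.embedding.FeatureConfig(table=table_config_one),
    'feature_two': tf.tpu.experimental.embedding.FeatureConfig(table=table_config_one),
    'feature_three': tf.tpu.experimental.embedding.FeatureConfig(table=table_config_two)
}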
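The effect of sharing a table can be illustrated without a TPU. In this plain-Python sketch, the Table and Feature dataclasses are hypothetical stand-ins for TableConfig and FeatureConfig, and the row values are dummies: features that point at the same table read the same rows, storage is allocated once per distinct table, and each feature's lookup result is as wide as its table's dim.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for TableConfig and FeatureConfig.
@dataclass(frozen=True)
class Table:
    name: str
    vocabulary_size: int
    dim: int

@dataclass(frozen=True)
class Feature:
    table: Table

table_one = Table("table_one", vocabulary_size=8, dim=8)
table_two = Table("table_two", vocabulary_size=16, dim=4)

features = {
    "feature_one": Feature(table_one),
    "feature_two": Feature(table_one),  # shares table_one with feature_one
    "feature_three": Feature(table_two),
}

# Storage is allocated once per distinct table, not once per feature.
distinct_tables = {f.table for f in features.values()}
rows_allocated = sum(t.vocabulary_size for t in distinct_tables)

# Toy storage: row i of a table is a list of `dim` copies of float(i).
storage = {t.name: [[float(i)] * t.dim for i in range(t.vocabulary_size)]
           for t in distinct_tables}

def lookup(feature_name, row_id):
    """Return the embedding row that `feature_name` reads for `row_id`."""
    table = features[feature_name].table
    return storage[table.name][row_id]

# Each feature's lookup result is as wide as its table's dim.
output_dims = {name: f.table.dim for name, f in features.items()}

print(rows_allocated)  # 24
print(output_dims)     # {'feature_one': 8, 'feature_two': 8, 'feature_three': 4}
print(lookup("feature_one", 3) is lookup("feature_two", 3))  # True
```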


An optimizer can be globally specified by passing one of the following types of input to the optimizer argument:

  1. A string, one of 'sgd', 'adagrad' or 'adam', which uses the given optimizer with the default parameters.
  2. An instance of a Keras optimizer.
  3. An instance of an optimizer class from the tf.tpu.experimental.embedding module.

You may also specify an optimizer at the table level via the optimizer argument of tf.tpu.experimental.embedding.TableConfig. This completely overrides the global optimizer for that table. For performance reasons, it is recommended that you minimize the total number of distinct optimizers.
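The override rule can be sketched in plain Python. Here resolve_optimizers is a hypothetical helper, strings stand in for optimizer objects, and the table names are made up; the sketch only shows how a table-level setting takes precedence over the global one and how to count the distinct optimizers you end up with.

```python
def resolve_optimizers(global_optimizer, table_optimizers):
    """Hypothetical helper: pick the optimizer each table actually uses.

    A table-level optimizer (not None) replaces the global one entirely.
    """
    return {table: (opt if opt is not None else global_optimizer)
            for table, opt in table_optimizers.items()}

# Strings stand in for optimizer objects; the table names are made up.
effective = resolve_optimizers(
    "adagrad",
    {"user_table": None, "movie_table": "sgd", "genre_table": None})

# Fewer distinct optimizers is better for performance.
distinct = set(effective.values())

print(effective)      # {'user_table': 'adagrad', 'movie_table': 'sgd', 'genre_table': 'adagrad'}
print(len(distinct))  # 2
```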


Model Creation

Here are two examples of creating a Keras model with a TPU embedding layer in it.

For a functional style Keras model:

# Illustrative choice: any of the optimizer types listed above works here.
optimizer = tf.tpu.experimental.embedding.SGD(learning_rate=0.1)

with strategy.scope():
  embedding_inputs = {
      'feature_one':
          tf.keras.Input(batch_size=1024, shape=(1,), dtype=tf.int32),
      'feature_two':
          tf.keras.Input(
              batch_size=1024, shape=(1,), dtype=tf.int32, ragged=True),
      'feature_three':
          tf.keras.Input(batch_size=1024, shape=(1,), dtype=tf.int32)
  }
  # embedding, feature_config and embedding_inputs all have the same nested
  # structure.
  embedding = tfrs.layers.embedding.TPUEmbedding(
      feature_config=feature_config, optimizer=optimizer)(
          embedding_inputs)
  logits = tf.keras.layers.Dense(1)(
      tf.concat(tf.nest.flatten(embedding), axis=1))
  model = tf.keras.Model(embedding_inputs, logits)
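tf.nest.flatten returns the values of a dict in sorted-key order, not in the order the keys were written, and that order fixes the column layout of the concatenated embeddings fed to the Dense layer. The plain-Python sketch below mimics that ordering for a single example; flatten_dict is a hypothetical helper and the vector values are dummies. With the table dims above, the concatenated width is 8 + 8 + 4 = 20.

```python
def flatten_dict(structure):
    """Hypothetical helper: mimics tf.nest.flatten for a flat dict."""
    # tf.nest sorts dict keys to get a deterministic order.
    return [structure[key] for key in sorted(structure)]

# Dummy per-example lookup results, one vector per feature, with the
# widths of the tables above (dim=8, 8 and 4). Key order is deliberate.
embedding = {
    "feature_three": [0.3] * 4,
    "feature_one": [0.1] * 8,
    "feature_two": [0.2] * 8,
}

# For one example, tf.concat(..., axis=1) amounts to list concatenation.
concatenated = [x for vector in flatten_dict(embedding) for x in vector]

print(sorted(embedding))   # ['feature_one', 'feature_three', 'feature_two']
print(len(concatenated))   # 20
```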

For a subclass style model:

class ModelWithEmbeddings(tf.keras.Model):

  def __init__(self):
    super(ModelWithEmbeddings, self).__init__()
    self.embedding_layer = tfrs.layers.embedding.TPUEmbedding(
        feature_config=feature_config, optimizer=optimizer)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    embedding = self.embedding_layer(inputs)
    logits = self.dense(tf.concat(tf.nest.flatten(embedding), axis=1))
    return logits

# Make sure that the TPU is reinitialized when you try to create another model.
tf.tpu.experimental.initialize_tpu_system(resolver)
with strategy.scope():
  model = ModelWithEmbeddings()
WARNING:tensorflow:TPU system grpc:// has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
<tensorflow.python.tpu.topology.Topology at 0x7f2085f74400>

Simple TPUEmbeddingLayer example

In this tutorial, we build a simple ranking model with the TPUEmbeddingLayer using the MovieLens 100K dataset. The model predicts ratings from user_id and movie_id.

Install and import TensorFlow Datasets

pip install -q --upgrade tensorflow-datasets
import tensorflow_datasets as tfds

Read the data

To make the dataset accessible to the Cloud TPU worker, you need to create a GCS bucket and download the dataset into it. Follow these instructions to create your GCS bucket.

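gcs_bucket = 'gs://YOUR-BUCKET-NAME'
from google.colab import auth
# Authenticate with your Google credentials so that the notebook can access
# your GCS bucket.
auth.authenticate_user()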

First, we fetch the data using tensorflow_datasets. The fields we need are movie_id, user_id and user_rating.

Then we preprocess the data and convert the IDs to integers.

# Ratings data.
ratings = tfds.load(
    "movielens/100k-ratings", data_dir=gcs_bucket, split="train")

# Select the basic features and convert the string IDs to integers.
ratings = ratings.map(
    lambda x: {
        "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
        "user_rating": x["user_rating"],
    })

Prepare the dataset and model

Here we define some hyperparameters for the model.

per_replica_batch_size = 16
movie_vocabulary_size = 2048
movie_embedding_size = 64
user_vocabulary_size = 2048
user_embedding_size = 64
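The vocabulary sizes must cover every ID you look up. MovieLens 100K has 943 users and 1,682 movies with 1-based IDs, so a vocabulary of 2048 rows leaves headroom for both. A quick plain-Python check of that arithmetic:

```python
# MovieLens 100K: user IDs run from 1 to 943, movie IDs from 1 to 1682.
max_user_id, max_movie_id = 943, 1682
user_vocabulary_size = movie_vocabulary_size = 2048

# Every ID must fall in [0, vocabulary_size) to index a table row.
assert 0 <= max_user_id < user_vocabulary_size
assert 0 <= max_movie_id < movie_vocabulary_size

# Rows above the largest ID are allocated but never looked up.
unused_user_rows = user_vocabulary_size - (max_user_id + 1)
unused_movie_rows = movie_vocabulary_size - (max_movie_id + 1)

print(unused_user_rows, unused_movie_rows)  # 1104 365
```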

We'll split the data by putting 80% of the ratings in the train set, and 20% in the test set.

shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)

train = shuffled.take(80_000)
test = shuffled.skip(80_000).take(20_000)
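Setting reshuffle_each_iteration=False matters here: take and skip only produce disjoint sets if both see the same shuffled order. The plain-Python analogue below uses random.Random as a stand-in for tf.data's shuffle (the resulting orders differ from TensorFlow's). It shows that a fixed order keeps the split disjoint, while a different order for the test pass would leak training examples into the test set.

```python
import random

NUM_RATINGS = 100_000  # size of MovieLens 100K
ids = list(range(NUM_RATINGS))

# Stand-in for shuffle(100_000, seed=42, reshuffle_each_iteration=False):
# the order is fixed once and reused by both take and skip.
shuffled = ids[:]
random.Random(42).shuffle(shuffled)

train = shuffled[:80_000]                # like .take(80_000)
test = shuffled[80_000:80_000 + 20_000]  # like .skip(80_000).take(20_000)

# Same order for both passes: the two sets are disjoint.
print(len(train), len(test), len(set(train) & set(test)))  # 80000 20000 0

# If the test pass saw a different order, the sets would overlap.
reshuffled = ids[:]
random.Random(7).shuffle(reshuffled)
leaky_test = reshuffled[80_000:]
overlap = len(set(train) & set(leaky_test))
print(overlap > 0)  # True
```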

Batch the dataset and convert it to a distributed dataset.

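train_dataset = train.batch(
    per_replica_batch_size * strategy.num_replicas_in_sync,
    drop_remainder=True)
test_dataset = test.batch(
    per_replica_batch_size * strategy.num_replicas_in_sync,
    drop_remainder=True)
# TPUEmbedding expects host-side inputs, so don't fetch them to the TPU.
input_options = tf.distribute.InputOptions(experimental_fetch_to_device=False)
distribute_train_dataset = strategy.experimental_distribute_dataset(
    train_dataset, options=input_options)
distribute_test_dataset = strategy.experimental_distribute_dataset(
    test_dataset, options=input_options)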
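The global batch size is per_replica_batch_size × strategy.num_replicas_in_sync. The arithmetic below assumes 8 replicas (an 8-core TPU such as a v2-8 or v3-8; check strategy.num_replicas_in_sync on your own TPU) and batches built with drop_remainder=True, a common choice on TPUs because they need static shapes.

```python
# Assumption: 8 replicas in sync (an 8-core TPU such as a v2-8 or v3-8).
num_replicas_in_sync = 8
per_replica_batch_size = 16
global_batch_size = per_replica_batch_size * num_replicas_in_sync

train_examples, test_examples = 80_000, 20_000

# With drop_remainder=True only full batches are kept.
train_batches = train_examples // global_batch_size
test_batches = test_examples // global_batch_size
dropped_test_examples = test_examples % global_batch_size

# model.fit(..., steps_per_epoch=10, epochs=10) runs 100 steps in total.
examples_processed = 10 * 10 * global_batch_size

print(global_batch_size)            # 128
print(train_batches, test_batches)  # 625 156
print(dropped_test_examples)        # 32
print(examples_processed)           # 12800
```

Under these assumptions, the short training run later in this Colab covers only a fraction of the 625 available training batches.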

Here we create the optimizer and specify the feature and table configs. Then we define the model with the embedding layer.

optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)

user_table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=user_vocabulary_size, dim=user_embedding_size)
movie_table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=movie_vocabulary_size, dim=movie_embedding_size)
feature_config = {
    "movie_id": tf.tpu.experimental.embedding.FeatureConfig(table=movie_table),
    "user_id": tf.tpu.experimental.embedding.FeatureConfig(table=user_table)
}
# Define a ranking model with embedding layer.
class EmbeddingModel(tfrs.models.Model):

  def __init__(self):
    super().__init__()
    self.embedding_layer = tfrs.layers.embedding.TPUEmbedding(
        feature_config=feature_config, optimizer=optimizer)
    self.ratings = tf.keras.Sequential([
        # Learn multiple dense layers.
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(64, activation="relu"),
        # Make rating predictions in the final layer.
        tf.keras.layers.Dense(1)
    ])
    self.task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
        loss=tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE),
        metrics=[tf.keras.metrics.RootMeanSquaredError()])

  def compute_loss(self, features, training=False):
    embedding = self.embedding_layer({
        "user_id": features["user_id"],
        "movie_id": features["movie_id"]
    })
    rating_predictions = self.ratings(
        tf.concat([embedding["user_id"], embedding["movie_id"]], axis=1))

    # Scale the summed loss by 1 / global batch size.
    return tf.reduce_sum(
        self.task(
            labels=features["user_rating"], predictions=rating_predictions)) * (
                1 / (per_replica_batch_size * strategy.num_replicas_in_sync))

  def call(self, features, serving_config=None):
    # serving_config is forwarded to the embedding layer for CPU serving.
    embedding = self.embedding_layer(
        {
            "user_id": features["user_id"],
            "movie_id": features["movie_id"]
        },
        serving_config=serving_config)
    return self.ratings(
        tf.concat([embedding["user_id"], embedding["movie_id"]], axis=1))
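The scaling by 1 / (per_replica_batch_size * strategy.num_replicas_in_sync) in compute_loss comes from the usual tf.distribute convention: when each replica sums its per-example losses and divides by the global batch size, adding up the replicas' contributions (which is how gradients are aggregated across replicas) gives the mean over the global batch. A plain-Python check with made-up numbers for two replicas:

```python
def replica_loss(per_example_losses, global_batch_size):
    """Sum one replica's per-example losses, scaled by 1 / global batch size."""
    return sum(per_example_losses) * (1 / global_batch_size)

# Hypothetical per-example squared errors on two replicas, 4 examples each.
replica_a = [1.0, 0.5, 2.0, 0.5]
replica_b = [1.5, 1.0, 0.0, 1.5]
global_batch_size = len(replica_a) + len(replica_b)  # 8

# Summing the scaled replica losses gives the mean over the global batch.
summed = (replica_loss(replica_a, global_batch_size)
          + replica_loss(replica_b, global_batch_size))
global_mean = sum(replica_a + replica_b) / global_batch_size

print(summed, global_mean)  # 1.0 1.0
```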

Make sure that you create the model under the TPUStrategy scope.

with strategy.scope():
  model = EmbeddingModel()

Train and evaluate the model

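import os

# Compile the model with the optimizer defined above before training.
model.compile(optimizer=optimizer)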

Train the model

model.fit(distribute_train_dataset, steps_per_epoch=10, epochs=10)
Epoch 1/10
10/10 [==============================] - 7s 32ms/step - root_mean_squared_error: 2.7897 - loss: 0.0564 - regularization_loss: 0.0000e+00 - total_loss: 0.0564
Epoch 2/10
10/10 [==============================] - 0s 26ms/step - root_mean_squared_error: 1.1963 - loss: 0.0088 - regularization_loss: 0.0000e+00 - total_loss: 0.0088
Epoch 3/10
10/10 [==============================] - 0s 25ms/step - root_mean_squared_error: 1.1261 - loss: 0.0089 - regularization_loss: 0.0000e+00 - total_loss: 0.0089
Epoch 4/10
10/10 [==============================] - 0s 35ms/step - root_mean_squared_error: 1.1403 - loss: 0.0094 - regularization_loss: 0.0000e+00 - total_loss: 0.0094
Epoch 5/10
10/10 [==============================] - 0s 40ms/step - root_mean_squared_error: 1.1269 - loss: 0.0103 - regularization_loss: 0.0000e+00 - total_loss: 0.0103
Epoch 6/10
10/10 [==============================] - 0s 36ms/step - root_mean_squared_error: 1.1162 - loss: 0.0100 - regularization_loss: 0.0000e+00 - total_loss: 0.0100
Epoch 7/10
10/10 [==============================] - 0s 36ms/step - root_mean_squared_error: 1.1365 - loss: 0.0097 - regularization_loss: 0.0000e+00 - total_loss: 0.0097
Epoch 8/10
10/10 [==============================] - 0s 47ms/step - root_mean_squared_error: 1.1171 - loss: 0.0110 - regularization_loss: 0.0000e+00 - total_loss: 0.0110
Epoch 9/10
10/10 [==============================] - 0s 48ms/step - root_mean_squared_error: 1.1037 - loss: 0.0100 - regularization_loss: 0.0000e+00 - total_loss: 0.0100
Epoch 10/10
10/10 [==============================] - 0s 51ms/step - root_mean_squared_error: 1.0953 - loss: 0.0092 - regularization_loss: 0.0000e+00 - total_loss: 0.0092
<keras.callbacks.History at 0x7f2084d7ddf0>

Evaluate the model on the test dataset

model.evaluate(distribute_test_dataset, steps=10)
10/10 [==============================] - 4s 27ms/step - root_mean_squared_error: 1.1339 - loss: 0.0090 - regularization_loss: 0.0000e+00 - total_loss: 0.0090
[1.1338995695114136, 0.009662957862019539, 0, 0.009662957862019539]

Save and restore the checkpoint

You can use a GCS bucket to store your checkpoint.

Make sure that you give the TPU worker access to the bucket by following the instructions.

model_dir = os.path.join(gcs_bucket, "saved_model")

Create the checkpoint for the TPU model and save the model to the bucket.

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
saved_tpu_model_path = checkpoint.save(os.path.join(model_dir, "ckpt"))

You can list the variables that get stored in that path.

tf.train.list_variables(saved_tpu_model_path)

  [2048, 64]),
  [2048, 64]),
  [2048, 64]),
  [2048, 64]),
 ('model/ratings/layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE', [256]),
  [128, 256]),
  [128, 256]),
 ('model/ratings/layer_with_weights-1/bias/.ATTRIBUTES/VARIABLE_VALUE', [64]),
  [256, 64]),
  [256, 64]),
 ('model/ratings/layer_with_weights-2/bias/.ATTRIBUTES/VARIABLE_VALUE', [1]),
  [64, 1]),
  [64, 1]),
 ('model/task/_ranking_metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('model/task/_ranking_metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', [])]
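The shapes in this listing follow directly from the hyperparameters. Each embedding table is vocabulary_size × embedding_size, and the first Dense kernel takes the 64 + 64 = 128-wide concatenation of the user and movie embeddings. Shapes that appear more than once presumably belong to optimizer slots, since Adagrad keeps an accumulator with the same shape as each variable it trains. A plain-Python derivation:

```python
vocabulary_size = 2048      # user_vocabulary_size and movie_vocabulary_size
embedding_size = 64         # user_embedding_size and movie_embedding_size
dense_units = [256, 64, 1]  # the three Dense layers in self.ratings

# Each embedding table has one row per ID and one column per dimension.
table_shape = [vocabulary_size, embedding_size]

# The first Dense layer consumes the user and movie embeddings concatenated.
fan_in = embedding_size + embedding_size
kernel_shapes, bias_shapes = [], []
for units in dense_units:
    kernel_shapes.append([fan_in, units])  # Dense kernel: [inputs, units]
    bias_shapes.append([units])            # Dense bias: [units]
    fan_in = units

print(table_shape)    # [2048, 64]
print(kernel_shapes)  # [[128, 256], [256, 64], [64, 1]]
print(bias_shapes)    # [[256], [64], [1]]
```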

You can restore the checkpoint later. It is common practice to checkpoint your model every epoch and restore it afterwards.

with strategy.scope():
  checkpoint.restore(saved_tpu_model_path)

Additionally, you can create a CPU model and restore the weights that were trained on the TPU.

cpu_model = EmbeddingModel()

# Create the CPU checkpoint and restore the TPU checkpoint.
cpu_checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=cpu_model)
cpu_checkpoint.restore(saved_tpu_model_path)
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f20830fe5b0>

You can also save and restore only the embedding weights.

embedding_checkpoint = tf.train.Checkpoint(embedding=model.embedding_layer)
saved_embedding_path = embedding_checkpoint.save(
    os.path.join(model_dir, 'tpu-embedding'))
# Restore the embedding parameters on the CPU model.
cpu_embedding_checkpoint = tf.train.Checkpoint(
    embedding=cpu_model.embedding_layer)
cpu_embedding_checkpoint.restore(saved_embedding_path)
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f20831bbeb0>
# Save the embedding parameters of the CPU model and restore them into the
# TPU model.
saved_cpu_embedding_path = cpu_embedding_checkpoint.save(
    os.path.join(model_dir, 'cpu-embedding'))
with strategy.scope():
  embedding_checkpoint.restore(saved_cpu_embedding_path)


Finally, you can export the CPU model for serving. Serving is done through the tf.saved_model API.

def serve_tensors(features):
  return cpu_model(features)

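signatures = {
    'serving':
        tf.function(serve_tensors).get_concrete_function(features={
            'movie_id': tf.TensorSpec(shape=(1,), dtype=tf.int32, name='movie_id'),
            'user_id': tf.TensorSpec(shape=(1,), dtype=tf.int32, name='user_id'),
        }),
}
tf.saved_model.save(
    cpu_model,
    export_dir=os.path.join(model_dir, 'exported_model'),
    signatures=signatures)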
WARNING:tensorflow:Skipping full serialization of Keras layer <tensorflow_recommenders.tasks.ranking.Ranking object at 0x7f20831ead00>, because it is not built.

The exported model can now be loaded (in Python or C) and used for serving.

imported = tf.saved_model.load(os.path.join(model_dir, 'exported_model'))
predict_fn = imported.signatures['serving']

# Dummy serving data.
input_batch = {
    'movie_id': tf.constant(np.array([100]), dtype=tf.int32),
    'user_id': tf.constant(np.array([30]), dtype=tf.int32)
}
# The prediction it generates.
prediction = predict_fn(**input_batch)['output_0']
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:An attribute in the restored object could not be found in the checkpoint. Object: (root).embedding_layer._tpu_embedding, attribute: ['TPUEmbedding_saveable']

Additionally, you can pass a serving config when calling the CPU model.

Note that a serving config lets you serve with only a subset of the trained embedding tables.

serving_config = {
    'movie_id': tf.tpu.experimental.embedding.FeatureConfig(table=movie_table),
    'user_id': tf.tpu.experimental.embedding.FeatureConfig(table=user_table)
}
prediction = cpu_model(input_batch, serving_config=serving_config)