A Keras layer for accelerating embedding lookups for large tables on TPU.

Feature and table configuration

When creating an instance of this layer, you must specify:

  1. The complete set of embedding tables,
  2. The features you expect to look up in those tables, and
  3. The optimizer(s) you wish to use on the tables.

See the documentation of tf.tpu.experimental.embedding.TableConfig and tf.tpu.experimental.embedding.FeatureConfig for more details on the complete set of options. We will cover the basic usage here.

table_config_one = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=...,
    dim=...)
table_config_two = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=...,
    dim=...)
feature_config = {
    'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_one),
    'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_one),
    'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_two)}


An optimizer can be globally specified by passing one of the following types of input to the optimizer argument:

  1. A string, one of 'sgd', 'adagrad' or 'adam', which uses the given optimizer with the default parameters.
  2. An instance of a Keras optimizer.
  3. An instance of an optimizer class from the tf.tpu.experimental.embedding module.

You may also specify an optimizer at the table level via the optimizer argument of tf.tpu.experimental.embedding.TableConfig. This completely overrides the global optimizer for that table. For performance reasons, it is recommended that you minimize the total number of distinct optimizers.
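A minimal sketch of both levels of optimizer configuration, assuming a recent TensorFlow 2.x in which the tf.tpu.experimental.embedding module is available. Building these config objects does not require a TPU. The vocabulary sizes, dimensions, learning rates, and table names are illustrative values, not recommendations.

```python
import tensorflow as tf

# Global optimizer: per the list above, any of these three forms could be
# passed as the layer's `optimizer` argument.
global_by_name = 'adagrad'  # default Adagrad hyperparameters
global_keras = tf.keras.optimizers.Adagrad(learning_rate=0.1)
global_tpu = tf.tpu.experimental.embedding.Adagrad(learning_rate=0.1)

# Per-table configuration (sizes and names are illustrative only).
# No optimizer here: this table would fall back to the global optimizer.
table_with_default = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=1000, dim=16, name='table_default')

# This table overrides the global optimizer and always uses SGD.
table_with_override = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=500, dim=8, name='table_sgd',
    optimizer=tf.tpu.experimental.embedding.SGD(learning_rate=0.05))
```

Here table_with_override would train with SGD regardless of the global setting, while table_with_default uses whichever global optimizer the layer receives.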

Dynamic Learning Rate

A dynamic learning rate is supported for all optimizers; all other hyperparameters are static. There are two ways of specifying a dynamic learning rate in your optimizer:

  1. One of the objects in the tf.keras.optimizers.schedules namespace.
  2. A Python callable taking no parameters that returns a scalar tensor of type tf.float32.


The first method, passing a tf.keras.optimizers.schedules object, is only possible when using a Keras optimizer. In this case, set the learning rate of the optimizer to your desired tf.keras.optimizers.schedules object.
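A sketch of the first method, assuming an ExponentialDecay schedule chosen purely for illustration; any object from tf.keras.optimizers.schedules is passed the same way.

```python
import tensorflow as tf

# Illustrative decay settings: multiply the rate by 0.96 every 1000 steps.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)

# The schedule object itself is the Keras optimizer's learning rate.
optimizer = tf.keras.optimizers.Adagrad(learning_rate=schedule)

# A schedule maps a step count to a learning rate.
lr_at_start = float(schedule(0))
lr_after_1000 = float(schedule(1000))
```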


The second method, passing a callable, can be used with either a Keras optimizer or one of the optimizer classes in the tf.tpu.experimental.embedding namespace.

In either case, you should create a callable that returns a tensor. This function is called once, but the ops it generates are re-evaluated on each step. It is therefore recommended that you either create a tf.Variable representing your current step counter or use the iterations property of an optimizer on which you call apply_gradients each training step.

with strategy.scope():
  step = tf.Variable(
      initial_value=0, trainable=False, dtype=tf.int64,
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
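A sketch of the second method, assuming an ExponentialDecay schedule driven by a hand-managed step counter. It runs outside any strategy scope, so the aggregation argument from the snippet above is omitted. The names step, schedule, and learning_rate are illustrative.

```python
import tensorflow as tf

# Step counter that a training loop would increment once per step.
step = tf.Variable(0, trainable=False, dtype=tf.int64)

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96)


def learning_rate():
  # Per the text above, the layer calls this once while building its graph,
  # and the ops that read `step` then run on every step. Called eagerly, as
  # in this sketch, it simply returns the rate for the current step value.
  return schedule(step)


# The embedding optimizers accept a zero-argument callable as learning rate.
optimizer = tf.tpu.experimental.embedding.Adagrad(learning_rate=learning_rate)
```

Incrementing step (for example with step.assign_add(1) inside the training step) is what moves the learning rate along the schedule.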

Model creation

For a functional style Keras model:

strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding_inputs = {
      'feature_one': tf.keras.Input(batch_size=1024, shape=(),
                                    dtype=tf.int32),
      'feature_two': tf.keras.Input(batch_size=1024, shape=(),
                                    dtype=tf.int32, ragged=True),
      'feature_three': tf.keras.Input(batch_size=1024, shape=(),
                                      dtype=tf.int32)}
  # embedding, feature_config and embedding_inputs all have the same nested
  # structure.
  embedding = tpu_embedding_layer.TPUEmbedding(
      feature_config=feature_config,
      optimizer=optimizer)(embedding_inputs)
  # Concatenate the per-feature activations along the feature axis.
  logits = tf.keras.layers.Dense(1)(
      tf.concat(tf.nest.flatten(embedding), axis=1))
  model = tf.keras.Model(embedding_inputs, logits)
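The dense head concatenates the per-feature activations along the feature axis, which is why tf.concat needs an explicit axis. The following standalone sketch uses made-up stand-in tensors in place of real TPUEmbedding output to show the resulting shape and column order.

```python
import tensorflow as tf

# Stand-in activations with the same nested structure as feature_config:
# one [batch_size, dim] tensor per feature (dims are illustrative).
batch_size = 4
embedding = {
    'feature_one': tf.zeros([batch_size, 8]),
    'feature_two': tf.ones([batch_size, 8]),
    'feature_three': tf.fill([batch_size, 16], 2.0),
}

# tf.nest.flatten orders dict entries by sorted key, so the columns are
# feature_one (0-7), feature_three (8-23), feature_two (24-31).
flat = tf.nest.flatten(embedding)
combined = tf.concat(flat, axis=1)  # shape [batch_size, 8 + 16 + 8]
logits = tf.keras.layers.Dense(1)(combined)  # shape [batch_size, 1]
```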

For a subclass style model:

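class ModelWithEmbeddings(tf.keras.Model):
  def __init__(self):
    super().__init__()
    self.embedding_layer = tpu_embedding_layer.TPUEmbedding(
        feature_config=feature_config,
        optimizer=optimizer)
    # Create the Dense layer once so its weights are tracked and reused.
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    embedding = self.embedding_layer(inputs)
    logits = self.dense(tf.concat(tf.nest.flatten(embedding), axis=1))
    return logits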

with strategy.scope():
  model = ModelWithEmbeddings()

Input data

When creating a distributed dataset to be used with a model that contains a TPUEmbedding layer, you must specify a special option when calling any of the dataset distribution methods of TPUStrategy:

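distributed_dataset = strategy.distribute_datasets_from_function(
    dataset_fn=...,
    options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)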
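As described under Training and evaluation, the dataset is assumed to yield (labels, features) tuples whose features match the structure of feature_config. A minimal CPU-only sketch of such a dataset, with made-up ids, labels, and table sizes mirroring the configuration example, might look like this:

```python
import tensorflow as tf

# Illustrative tables and features mirroring the configuration section.
table_one = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=100, dim=8)
table_two = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=50, dim=16)
feature_config = {
    'feature_one': tf.tpu.experimental.embedding.FeatureConfig(table=table_one),
    'feature_two': tf.tpu.experimental.embedding.FeatureConfig(table=table_one),
    'feature_three': tf.tpu.experimental.embedding.FeatureConfig(table=table_two),
}

# Each element is (labels, features); `features` has the same keys as
# feature_config and holds integer ids. The values are made up.
labels = tf.constant([[1.0], [0.0], [1.0], [0.0]])
features = {
    'feature_one': tf.constant([3, 7, 1, 0], dtype=tf.int32),
    'feature_two': tf.constant([9, 2, 2, 5], dtype=tf.int32),
    'feature_three': tf.constant([4, 4, 0, 49], dtype=tf.int32),
}
dataset = tf.data.Dataset.from_tensor_slices((labels, features)).batch(2)

first_labels, first_features = next(iter(dataset))
# Raises if the feature batch does not mirror feature_config's structure.
tf.nest.assert_same_structure(first_features, feature_config)
```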

Training and evaluation

To use this API on TPU you should use a custom training loop. Below is an example of a training and evaluation step:

@tf.function
def training_step(dataset_iterator, num_steps):
  def tpu_step(inputs):
    labels, features = inputs
    with tf.GradientTape() as tape:
      model_output = model(features)
      loss = ...  # some function of labels and model_output

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

@tf.function
def evaluation_step(dataset_iterator, num_steps):
  def tpu_step(inputs):
    labels, features = inputs
    model_output = model(features)
    # Insert your evaluation code here.

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

In the above examples, we assume that the dataset yields tuples whose second element matches the structure of the feature_config argument passed to the layer's initializer. We also use tf.range so that, inside a tf.function, AutoGraph converts the loop into a tf.while_loop, which improves performance.
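A standalone sketch of the tf.range pattern, assuming a toy dataset of integers in place of real training batches; run_steps is a hypothetical name. Inside tf.function the loop becomes a single tf.while_loop that pulls one element from the iterator per pass.

```python
import tensorflow as tf


@tf.function
def run_steps(iterator, num_steps):
  # tf.range returns a tensor, so AutoGraph converts this Python loop into a
  # tf.while_loop instead of unrolling num_steps copies of the body.
  total = tf.constant(0, dtype=tf.int64)
  for _ in tf.range(num_steps):
    # One next() per iteration, like next(dataset_iterator) above.
    total += next(iterator)
  return total


iterator = iter(tf.data.Dataset.range(10))
```

Because the iterator keeps its position between calls, run_steps(iterator, tf.constant(4)) consumes 0 to 3 and a following call with 3 steps consumes 4 to 6.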

The embedding layer does not affect checkpointing; checkpoint your model as usual. If you passed either a Keras optimizer or an optimizer converted from a Keras optimizer via translate_keras_optimizer, you must also checkpoint the optimizer so that your slot variables are saved.

checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)


Serving

Serving is accomplished through the tf.saved_model API. The model may be exported directly from training.

First, write a tf.function that represents the serving graph. Typically it takes as input a string tensor of serialized protos, which are parsed into tensors and then passed to the model. For example:

@tf.function(input_signature=[{'examples':
    tf.TensorSpec(
        shape=[None], dtype=tf.string, name='examples')}])
def serve_examples(examples):
  input_data = ...  # parse the examples tensor to produce input tensors.
  return model(input_data)

tf.saved_model.save(model, export_dir=...,
                    signatures={'serving': serve_examples})

The exported model can now be loaded (in Python or C) and used for serving:

imported = tf.saved_model.load(...)
predict_fn = imported.signatures['serving']
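A self-contained sketch of the export and load round trip, assuming a toy tf.Module (ToyServingModel, hypothetical) in place of the real model and tf.strings.to_number in place of proto parsing. It keeps the signature name 'serving' and the input name 'examples' used above.

```python
import tempfile

import tensorflow as tf


class ToyServingModel(tf.Module):
  """Hypothetical stand-in for a trained model; not the TPUEmbedding layer."""

  def __init__(self):
    super().__init__()
    self.scale = tf.Variable(2.0)

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')])
  def serve_examples(self, examples):
    # Stand-in for proto parsing: each string holds a single number.
    values = tf.strings.to_number(examples, out_type=tf.float32)
    return {'scores': values * self.scale}


export_dir = tempfile.mkdtemp()
module = ToyServingModel()
tf.saved_model.save(module, export_dir,
                    signatures={'serving': module.serve_examples})

imported = tf.saved_model.load(export_dir)
predict_fn = imported.signatures['serving']
# Signature functions take keyword arguments and return a dict of tensors.
outputs = predict_fn(examples=tf.constant(['1.0', '2.5']))
```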

Using this layer on CPU

This layer can also be instantiated under a CPU strategy and used for local testing or training. Models created this way are checkpoint-compatible with models created under TPUStrategy. To achieve checkpoint compatibility, you must use Keras optimizers (or ones converted by translate_keras_optimizer) as your optimizers.

In the simplest case, where you use the same optimizer for your embedding and dense layers, the training_step above will function exactly the same in both situations.

If you use a separate Keras optimizer for your embedding layers (e.g. because you want different hyperparameter settings or an entirely different algorithm), special care is needed to keep the behavior the same on both platforms. To understand why, you need to know a few technical details:

When the layer is created under TPUStrategy, the underlying table variables are not considered trainable and are not available under model.trainable_variables. The main reason is that the table variables are only a stand-in for the real data, which lives in the TPU's HBM. These variables are stale and are only updated when saving and restoring checkpoints.

Because of this, a standard optimizer.apply_gradients call will not work on these variables. Instead, a separate virtual trainable variable is added to the list of trainable variables; computing the gradient with respect to this variable causes the gradients for the embeddings to be computed and the optimizer to be applied.

When the layer is created under a CPU strategy, the table variables are created normally and are part of the model's trainable variables. In this case, if you use a different optimizer for the embedding tables, you must manually partition the variables and gradients so that the Keras optimizer you created for the embedding tables is applied only to the tables.
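The manual partitioning described above can be sketched in plain Python. partition_grads_and_vars and ElementwiseEq are hypothetical names; ElementwiseEq only mimics a value-comparing ==, and with real TensorFlow variables the same mistake may instead raise an error or silently misroute gradients.

```python
class ElementwiseEq:
  """Toy object whose == returns a list, mimicking elementwise comparison."""

  def __init__(self, values):
    self.values = values

  def __eq__(self, other):
    return [a == b for a, b in zip(self.values, other.values)]


def partition_grads_and_vars(grads_and_vars, embedding_variables):
  """Splits (gradient, variable) pairs into (dense, embedding) lists.

  Membership is tested by object identity (id) rather than ==, because
  variables may overload == to compare values elementwise.
  """
  # Materialize first: a zip() iterator can only be consumed once.
  grads_and_vars = list(grads_and_vars)
  embedding_ids = {id(v) for v in embedding_variables}
  dense = [(g, v) for g, v in grads_and_vars if id(v) not in embedding_ids]
  embedding = [(g, v) for g, v in grads_and_vars if id(v) in embedding_ids]
  return dense, embedding


table = ElementwiseEq([0.0, 0.0])   # plays the embedding table
kernel = ElementwiseEq([1.0, 2.0])  # plays a dense kernel

# The pitfall: `in` falls back to ==, which returns a non-empty (truthy)
# list, so the kernel wrongly appears to be an embedding variable.
naive_membership = kernel in [table]

dense, embedding = partition_grads_and_vars(
    zip(['g_table', 'g_kernel'], [table, kernel]), [table])
```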


class ModelWithSeparateOptimizer(tf.keras.Model):
  def __init__(self, optimizer):
    super().__init__()
    self.embedding_layer = tpu_embedding_layer.TPUEmbedding(
        feature_config=feature_config,
        optimizer=optimizer)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    embedding = self.embedding_layer(inputs)
    logits = self.dense(tf.concat(tf.nest.flatten(embedding), axis=1))
    return logits

with strategy.scope():
  embedding_optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)
  dense_optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
  model = ModelWithSeparateOptimizer(embedding_optimizer)

@tf.function
def training_step(dataset_iterator, num_steps):
  def tpu_step(inputs):
    labels, features = inputs
    with tf.GradientTape() as tape:
      model_output = model(features)
      loss = ...  # some function of labels and model_output

    gradients = tape.gradient(loss, model.trainable_variables)
    # Materialize the pairs: a zip() iterator can only be consumed once.
    grads_and_vars = list(zip(gradients, model.trainable_variables))

    # Note the use of 'id' here: 'x in y' uses x's equality method, and if x
    # is a tensor this is tf.math.equal rather than Python object equality.
    embedding_var_ids = [
        id(v) for v in model.embedding_layer.trainable_variables]
    dense_grads_and_vars = [
        (g, v) for g, v in grads_and_vars
        if id(v) not in embedding_var_ids]
    dense_optimizer.apply_gradients(dense_grads_and_vars)

    embedding_grads_and_vars = [
        (g, v) for g, v in grads_and_vars
        if id(v) in embedding_var_ids]
    embedding_optimizer.apply_gradients(embedding_grads_and_vars)

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

The above training step works both on TPU and on CPU.

Args

feature_config A nested structure of tf.tpu.experimental.embedding.FeatureConfig configs.
optimizer An instance of one of tf.tpu.experimental.embedding.SGD, tf.tpu.experimental.embedding.Adagrad or tf.tpu.experimental.embedding.Adam, a Keras optimizer, or a string name of an optimizer (see tf.keras.optimizers.get). Alternatively, if the layer is not created under a TPU strategy, None, which avoids creating the optimizer slot variables in order to reduce memory consumption during export.
pipeline_execution_with_tensor_core If True, the TPU embedding computations will overlap with the TensorCore computations (and hence will be one step old, with potential correctness drawbacks). Set to True for improved performance.
batch_size If set, this is used as the global batch size and overrides the autodetection of the batch size from the layer's input. This is necessary if all inputs to the layer's call are SparseTensors.

Attributes

embedding_tables A mapping from table configs to tables.

When instantiated under a TPU strategy, this returns a sharded variable. This variable is strictly a placeholder used for saving and restoring. Assigning values to it will not update the actual embedding tables, and reads may return a stale copy of the table. It should not be used for actual computation, only for exporting the model for serving.



call(features, weights=None, serving_config=None)

Look up features in the embedding tables and combine using weights.

Args

features A nested structure of Tensors, SparseTensors, or RaggedTensors with the same structure as feature_config. These tensors are used as ids to look up rows in the embedding tables, using the config specified in the corresponding entry of feature_config. You can mix Tensors and SparseTensors, or Tensors and RaggedTensors, but not SparseTensors and RaggedTensors.
weights None, or a nested structure of Tensors, SparseTensors, or RaggedTensors (or None) matching features. These are the weights used when combining the looked-up rows for a given feature and example. If None, weights of 1 are used.
serving_config A nested structure of tf.tpu.experimental.embedding.FeatureConfig objects. If not None, this layer uses CPU-based lookup with serving_config and the current set of embedding tables.

Returns

The combined embedding activations for the input ids passed in via features.

Raises

RuntimeError If the layer is not created under a TPU strategy and is called under a TPU strategy.
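To make the weighted combination concrete, the following sketch uses tf.nn.embedding_lookup_sparse on CPU as an analogy; it is not the layer's implementation, and the table, ids, and weights are made up. It shows how per-id weights enter the 'sum' and 'mean' combiners (TableConfig exposes a combiner argument for this choice).

```python
import tensorflow as tf

# Toy 4-row table with dim 2 (values are illustrative).
table = tf.constant([[1.0, 0.0],
                     [0.0, 1.0],
                     [2.0, 2.0],
                     [4.0, 0.0]])

# Example 0 looks up ids 1 and 3; example 1 looks up id 2.
ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([1, 3, 2], dtype=tf.int64),
    dense_shape=[2, 2])
# One weight per id, sharing the indices of `ids`.
weights = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([1.0, 3.0, 0.5]),
    dense_shape=[2, 2])

# 'sum': weighted sum of the looked-up rows for each example.
#   example 0: 1.0 * [0, 1] + 3.0 * [4, 0] = [12, 1]
summed = tf.nn.embedding_lookup_sparse(table, ids, weights, combiner='sum')
# 'mean': the weighted sum divided by the sum of that example's weights.
#   example 0: [12, 1] / 4.0 = [3, 0.25]
averaged = tf.nn.embedding_lookup_sparse(table, ids, weights, combiner='mean')
```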