tf.tpu.experimental.embedding.TPUEmbedding

The TPUEmbedding mid level API.

This class can be used to support training large embeddings on TPU. When creating an instance of this class, you must specify the complete set of tables and features you expect to look up in those tables. See the documentation of tf.tpu.experimental.embedding.TableConfig and tf.tpu.experimental.embedding.FeatureConfig for the complete set of options. We cover the basic usage here.

table_config_one = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=...,
    dim=...)
table_config_two = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=...,
    dim=...)
feature_config = {
    'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_one),
    'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_one),
    'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config_two)}
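
To make the relationship between features and tables concrete, the pure-Python sketch below groups the feature names in a feature_config by the table they look up. It does not use TensorFlow: Table and Feature are hypothetical stand-ins for TableConfig and FeatureConfig, and the vocabulary sizes and dimensions are made-up values. It shows that feature_one and feature_two share table_config_one, so the three features above are backed by only two tables.

```python
from collections import defaultdict
from dataclasses import dataclass


# Hypothetical stand-ins for TableConfig and FeatureConfig (not the TF classes).
@dataclass(frozen=True)
class Table:
    name: str
    vocabulary_size: int
    dim: int


@dataclass(frozen=True)
class Feature:
    table: Table


def features_by_table(feature_config):
    """Map each table name to the list of feature names that look it up."""
    groups = defaultdict(list)
    for feature_name, feature in feature_config.items():
        groups[feature.table.name].append(feature_name)
    return dict(groups)


# Made-up sizes; the structure mirrors the feature_config example above.
table_config_one = Table("table_one", vocabulary_size=1000, dim=8)
table_config_two = Table("table_two", vocabulary_size=500, dim=4)
feature_config = {
    "feature_one": Feature(table=table_config_one),
    "feature_two": Feature(table=table_config_one),
    "feature_three": Feature(table=table_config_two),
}

groups = features_by_table(feature_config)
print(groups)
# {'table_one': ['feature_one', 'feature_two'], 'table_two': ['feature_three']}
```

Because the two features share one table object, updates driven by either feature land in the same weights.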

There are two modes in which the TPUEmbedding class can be used, depending on whether the object was created under a TPUStrategy scope or not.

Under TPUStrategy, the enqueue, dequeue and apply_gradients methods are available. The examples below show how to use them to train and evaluate your model. On CPU, only the embedding_tables property is available; it gives access to the embedding tables so that you can run model evaluation or prediction on CPU.

First, let's look at the TPUStrategy mode. The initial setup looks like this:

strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=1024,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))

When creating a distributed dataset that will be passed to the enqueue method, a special input option must be specified:

distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        # Prefetching to device must be disabled for datasets fed to enqueue.
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)

To use this API on TPU, you should use a custom training loop. Below are examples of a training step and an evaluation step:

@tf.function
def training_step(dataset_iterator, num_steps):
  def tpu_step(tpu_features):
    with tf.GradientTape() as tape:
      activations = embedding.dequeue()
      tape.watch(activations)
      model_output = model(activations)
      loss = ...  # some function of labels and model_output

    embedding_gradients = tape.gradient(loss, activations)
    embedding.apply_gradients(embedding_gradients)
    # Insert your model gradient and optimizer application here

  for _ in tf.range(num_steps):
    embedding_features, tpu_features = next(dataset_iterator)
    # Send the embedding ids to the TPU embedding engine; tpu_step dequeues the activations.
    embedding.enqueue(embedding_features, training=True)
    # Pass only the non-embedding features (e.g. labels) to the TPU step.
    strategy.run(tpu_step, args=(tpu_features, ))

@tf.function
def evaluation_step(dataset_iterator, num_steps):
  def tpu_step(tpu_features):
    activations = embedding.dequeue()
    model_output = model(activations)
    # Insert your evaluation code here.

  for _ in tf.range(num_steps):
    embedding_features, tpu_features = next(dataset_iterator)
    # training=False enqueues a forward-pass-only batch; do not call apply_gradients for it.
    embedding.enqueue(embedding_features, training=False)
    strategy.run(tpu_step, args=(tpu_features, ))

In the examples above, we assume that the dataset yields a tuple whose first element matches the structure of the feature_config argument passed to the initializer; the second element holds the remaining features, which are passed to tpu_step. We also iterate with tf.range so that, inside a tf.function, AutoGraph converts the loop into a tf.while_loop instead of unrolling it in Python, which improves performance.

When checkpointing your model, you should include your tf.tpu.experimental.embedding.TPUEmbedding object in the checkpoint. It is a trackable object and saving it will save the embedding tables and their optimizer slot variables:

checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.save(...)

On CPU, only the embedding_tables property is usable. It lets you restore a checkpoint into the object and then access the table variables:

model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
    feature_config=feature_config,
    batch_size=1024,
    optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)

tables = embedding.embedding_tables

You can now use the tables with functions such as tf.nn.embedding_lookup to perform the embedding lookup and pass the results to your model.
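
As a rough illustration of what such a CPU lookup computes, the sketch below gathers rows from a toy table and averages them, assuming the feature has several ids per example and uses a 'mean' combiner (the TableConfig default). It is plain Python for clarity; the table values are invented, and embedding_lookup and mean_combine are hypothetical helpers, not TensorFlow functions.

```python
def embedding_lookup(table, ids):
    """Gather one row of `table` per id (hypothetical helper)."""
    return [table[i] for i in ids]


def mean_combine(rows):
    """Average a list of equal-length rows element-wise (hypothetical helper)."""
    dim = len(rows[0])
    return [sum(row[d] for row in rows) / len(rows) for d in range(dim)]


# Toy 4 x 2 table standing in for one entry of embedding.embedding_tables.
table = [
    [0.0, 1.0],
    [2.0, 3.0],
    [4.0, 5.0],
    [6.0, 7.0],
]

rows = embedding_lookup(table, [1, 3])  # [[2.0, 3.0], [6.0, 7.0]]
pooled = mean_combine(rows)             # [4.0, 5.0]
print(pooled)
```

In TensorFlow, the gather step corresponds to tf.nn.embedding_lookup on the table variable, and tf.nn.embedding_lookup_sparse accepts a combiner argument (such as "mean") for the reduction over multiple ids.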

Args
feature_config A nested structure of tf.tpu.experimental.embedding.FeatureConfig configs.
batch_size The global batch size that you intend to use. Note that this is fixed: the same batch size must be used for both training and evaluation.
optimizer An instance of one of tf.tpu.experimental.embedding.SGD, tf.tpu.experimental.embedding.Adagrad or tf.tpu.experimental.embedding.Adam.
pipeline_execution_with_tensor_core If True, the TPU embedding computations will overlap with the TensorCore computations (and hence the embedding results will be one step old). Set to True for improved performance.
initialize_tpu_embedding If False, the TPU embedding engine will not be initialized. If this is set to False and another instance of this class has not initialized the TPU embedding engine, the creation of this object will fail.

Raises
ValueError If optimizer is not one of tf.tpu.experimental.embedding.SGD, tf.tpu.experimental.embedding.Adam or tf.tpu.experimental.embedding.Adagrad.

Attributes
embedding_tables Returns a dict of embedding tables, keyed by TableConfig.

This property only works when the TPUEmbedding object is created under a non-TPU strategy. It is intended for CPU-based lookup when creating a serving checkpoint.

Methods

apply_gradients

Applies the gradient update to the embedding tables.

If a gradient of None is passed in any position of the nested structure, a gradient update with a zero gradient is applied for that feature. For optimizers like SGD or Adagrad, this is the same as applying no update at all. For lazy Adam and other sparsely applied optimizers with decay, make sure you understand the effect of applying a zero gradient.
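
The difference between these optimizers can be seen with simplified, single-weight versions of the textbook update rules. This pure-Python sketch is not the TPU implementation: the helper names, learning rate and starting accumulator value are illustrative assumptions, and it models one weight that is updated on every step (lazy Adam's row-wise skipping is not modeled). After one real update, a zero-gradient step leaves the SGD and Adagrad weights unchanged but still moves the Adam weight, because Adam's first-moment estimate from the earlier step is still nonzero.

```python
import math


def sgd_step(w, g, lr=0.1):
    """Plain SGD: move against the gradient."""
    return w - lr * g


def adagrad_step(w, acc, g, lr=0.1):
    """Adagrad: scale the step by the root of the accumulated squared gradients."""
    acc = acc + g * g
    return w - lr * g / math.sqrt(acc), acc


def adam_step(w, m, v, g, t, lr=0.1, b1=0.9, b2=0.999, eps=1e-7):
    """Adam with bias correction; m and v decay on every call, even when g == 0."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


# Step 1: a real gradient of 1.0 for every optimizer.
w_sgd = sgd_step(1.0, 1.0)
w_ada, acc = adagrad_step(1.0, 0.1, 1.0)  # 0.1 is an arbitrary starting accumulator
w_adam, m, v = adam_step(1.0, 0.0, 0.0, 1.0, t=1)

# Step 2: a zero gradient, as applied when None is passed for a feature.
w_sgd_2 = sgd_step(w_sgd, 0.0)
w_ada_2, acc = adagrad_step(w_ada, acc, 0.0)
w_adam_2, m, v = adam_step(w_adam, m, v, 0.0, t=2)

print(w_sgd_2 == w_sgd)    # no change
print(w_ada_2 == w_ada)    # no change
print(w_adam_2 != w_adam)  # the weight still moves
```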

strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)

@tf.function
def training_step():
  def tpu_step(tpu_features):
    with tf.GradientTape() as tape:
      activations = embedding.dequeue()
      tape.watch(activations)

      loss = ... #  some computation involving activations

    embedding_gradients = tape.gradient(loss, activations)
    embedding.apply_gradients(embedding_gradients)

  embedding_features, tpu_features = next(dataset_iterator)
  embedding.enqueue(embedding_features, training=True)
  # Pass only the non-embedding features to the TPU step.
  strategy.run(tpu_step, args=(tpu_features, ))

training_step()

Args
gradients A nested structure of gradients, with structure matching the feature_config passed to this object.
name A name for the underlying op.

Raises
RuntimeError If called when the object was not created under a TPUStrategy.
ValueError If a gradient that is neither a tf.Tensor nor None is passed in, or a tf.Tensor of the incorrect shape is passed in. Also raised if the size of any sequence in gradients does not match the corresponding sequence in feature_config.
TypeError If the type of any sequence in gradients does not match the corresponding sequence in feature_config.
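
The structure rules above can be illustrated with a small validator. This is a hypothetical pure-Python sketch, not TensorFlow's implementation: it only handles a flat dict keyed like feature_config, represents each gradient by its shape tuple (or None), and the feature names and shapes in the usage lines are invented.

```python
def check_gradient_shapes(gradient_shapes, expected_shapes):
    """Hypothetical validator for a flat, dict-shaped feature_config.

    gradient_shapes maps feature name -> shape tuple, or None for a None gradient.
    expected_shapes maps feature name -> the shape dequeue produced for it.
    """
    if type(gradient_shapes) is not type(expected_shapes):
        raise TypeError("gradients must use the same container type as feature_config")
    if len(gradient_shapes) != len(expected_shapes):
        raise ValueError("gradients and feature_config have different sizes")
    for name, expected in expected_shapes.items():
        if name not in gradient_shapes:
            raise ValueError(f"no gradient entry for feature {name!r}")
        shape = gradient_shapes[name]
        # None is allowed: it is treated as a zero gradient.
        if shape is not None and tuple(shape) != tuple(expected):
            raise ValueError(f"gradient for {name!r} has shape {shape}, expected {expected}")


expected = {"feature_one": (128, 8), "feature_three": (128, 4)}
check_gradient_shapes({"feature_one": (128, 8), "feature_three": None}, expected)
```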

dequeue

Get the embedding results.

Returns a nested structure of tf.Tensor objects matching the structure of the feature_config argument to the TPUEmbedding class. The output shape of each tensor is (batch_size, dim), where batch_size is the per-core batch size and dim is the dimension of the corresponding TableConfig. If the feature's FeatureConfig has max_sequence_length greater than 0, the output is instead a sequence of shape (batch_size, max_sequence_length, dim).
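
Because the batch_size passed to the constructor is the global batch size, the per-core shape returned by dequeue depends on the number of replicas. The sketch below computes it in plain Python; dequeue_output_shape is a hypothetical helper, and it assumes the global batch size divides evenly across replicas. The 8-replica figure in the usage lines is only an example.

```python
def dequeue_output_shape(global_batch_size, num_replicas, dim, max_sequence_length=0):
    """Per-core shape of one dequeued activation tensor (hypothetical helper)."""
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must divide evenly across replicas")
    per_core_batch_size = global_batch_size // num_replicas
    if max_sequence_length > 0:
        # Sequence features keep one embedding per position.
        return (per_core_batch_size, max_sequence_length, dim)
    return (per_core_batch_size, dim)


print(dequeue_output_shape(1024, 8, dim=16))                          # (128, 16)
print(dequeue_output_shape(1024, 8, dim=16, max_sequence_length=10))  # (128, 10, 16)
```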

strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

distributed_dataset = (
    strategy.experimental_distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_prefetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)

@tf.function
def training_step():
  def tpu_step(tpu_features):
    with tf.GradientTape() as tape:
      activations = embedding.dequeue()
      tape.watch(activations)

      loss = ... #  some computation involving activations

    embedding_gradients = tape.gradient(loss, activations)
    embedding.apply_gradients(embedding_gradients)

  embedding_features, tpu_features = next(dataset_iterator)
  embedding.enqueue(embedding_features, training=True)
  # Pass only the non-embedding features to the TPU step.
  strategy.run(tpu_step, args=(tpu_features, ))

training_step()

Args
name A name for the