tf.tpu.experimental.embedding.TPUEmbeddingV0

The TPUEmbedding mid level API running on TPU without an embedding accelerator.

This class has to be created under a TPUStrategy, otherwise a RuntimeError will be raised.

strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
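
The feature_config passed above is assumed to be defined beforehand. As a minimal, hypothetical sketch (the table name, vocabulary size, dimension and feature key are placeholders, not values prescribed by this API), it could be built from TableConfig and FeatureConfig objects:

# Hypothetical feature_config sketch; the names and sizes below are
# placeholders and should be replaced with your own model's values.
table_config = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=1000,
    dim=16,
    name='video_table')
feature_config = {
    'watched': tf.tpu.experimental.embedding.FeatureConfig(
        table=table_config, name='watched')}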

When creating a distributed dataset that is to be passed to the lookup operation, a special input option must be specified:

distributed_dataset = (
    strategy.distribute_datasets_from_function(
        dataset_fn=...,
        options=tf.distribute.InputOptions(
            experimental_fetch_to_device=False)))
dataset_iterator = iter(distributed_dataset)
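
The dataset_fn above is left elided. As a rough, hypothetical sketch (the 'watched' feature key, id values and global batch size are placeholders that must match your feature_config), it might look like:

# Hypothetical dataset_fn sketch: returns a per-replica batched dataset whose
# element structure matches feature_config. Names and sizes are placeholders.
def dataset_fn(context):
  batch_size = context.get_per_replica_batch_size(global_batch_size=64)
  ids = tf.data.Dataset.from_tensor_slices(
      {'watched': tf.constant([0, 1, 2, 3], dtype=tf.int32)})
  return ids.repeat().batch(batch_size, drop_remainder=True)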

Below is an example of a training and evaluation step:

optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def training_step(dataset_iterator, num_steps):
  def tpu_step(embedding_features):
    with tf.GradientTape() as tape:
      tape.watch(embedding.embedding_tables.values())
      activations = embedding(embedding_features)
      model_output = model(activations)
      loss = ...  # some function of labels and model_output

    embedding_gradients = tape.gradient(loss,
                                        embedding.embedding_tables.values())
    optimizer.apply_gradients(list(zip(embedding_gradients,
                              embedding.embedding_tables.values())))
    # Insert your model gradient and optimizer application here

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

@tf.function
def evaluation_step(dataset_iterator, num_steps):
  def tpu_step(embedding_features):
    activations = embedding(embedding_features)
    model_output = model(activations)
    # Insert your evaluation code here.

  for _ in tf.range(num_steps):
    strategy.run(tpu_step, args=(next(dataset_iterator), ))

The optimizer used in the training step above is a Keras optimizer. To keep slot variable creation consistent between Keras optimizers and embedding optimizers, pass the Keras optimizer's add_slot method to the embedding optimizer through its slot_variable_creation_fn argument, as shown below. Note that the slot names may differ slightly between the two.

optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)

def slot_variable_creation_fn(table, slot_names, slot_initializers):
    slots = {}
    for slot, initializer in zip(slot_names, slot_initializers):
      slots[slot] = optimizer.add_slot(table, slot, initializer)
    return slots

embedding_optimizer = tf.tpu.experimental.embedding.Adagrad(
    learning_rate=0.1,
    slot_variable_creation_fn=slot_variable_creation_fn)

# Use the embedding optimizer to create mid level api and keras optimizer to
# apply gradients.
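
As a hedged sketch of that pairing (feature_config and the gradient computation are assumed to come from the earlier examples):

# Sketch only: the embedding optimizer configures the mid level API's slot
# variables, while the Keras optimizer applies the gradients.
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0(
      feature_config=feature_config,
      optimizer=embedding_optimizer)
# Inside the training step, gradients are applied with the Keras optimizer:
#   optimizer.apply_gradients(
#       list(zip(embedding_gradients, embedding.embedding_tables.values())))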

Attributes

embedding_tables
  Returns a dict of embedding tables, keyed by TableConfig.

Methods

build

Create variables and slot variables for TPU embeddings.
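
A minimal usage sketch, assuming feature_config and strategy are defined as in the earlier examples:

# Sketch: explicitly create the table (and optimizer slot) variables under the
# TPUStrategy scope before running any lookups.
with strategy.scope():
  embedding = tf.tpu.experimental.embedding.TPUEmbeddingV0(
      feature_config=feature_config,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  embedding.build()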

embedding_lookup

Apply embedding lookup on TPUs using Tensorcore.

Note that all sparse and ragged tensors will be converted to dense tensors on the CPU and then passed to the TPU to do the embedding lookup. Large embedding lookups are not supported by this API; use the TPUEmbedding mid level API instead.

Args
features
  A nested structure of Tensors, SparseTensors or RaggedTensors.
weights
  A nested structure of Tensors, SparseTensors or RaggedTensors, or None for no weights. If not None, the structure must match that of features, but entries are allowed to be None.

Returns
A nested structure of Tensors with the same structure as features.
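
As an illustrative sketch of a lookup inside a replica step (the 'watched' key and id values are placeholders that must match your feature_config):

# Hypothetical per-replica lookup; the returned activations mirror the
# structure of the features that are passed in.
def lookup_step(embedding_features):
  return embedding.embedding_lookup(embedding_features)

activations = strategy.run(
    lookup_step, args=({'watched': tf.constant([0, 2])},))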

__call__

Call the mid level API to do embedding lookup.
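
Calling the object directly is equivalent to calling embedding_lookup; for example (features is a placeholder for inputs matching your feature_config):

# The two lines below perform the same lookup.
activations = embedding(features)
activations = embedding.embedding_lookup(features)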