
TPU version of tf.compat.v1.feature_column.shared_embedding_columns.

tf.tpu.experimental.shared_embedding_columns(
    categorical_columns, dimension, combiner='mean', initializer=None,
    shared_embedding_collection_name=None, max_sequence_lengths=None,
    learning_rate_fn=None, embedding_lookup_device=None, tensor_core_shape=None
)

Note that the interface of tf.tpu.experimental.shared_embedding_columns differs from that of tf.compat.v1.feature_column.shared_embedding_columns: the arguments ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable are NOT supported.

Use this function in place of tf.compat.v1.feature_column.shared_embedding_columns when you want to use the TPU to accelerate your embedding lookups via TPU embeddings.

import tensorflow as tf

column_a = tf.feature_column.categorical_column_with_identity(...)
column_b = tf.feature_column.categorical_column_with_identity(...)
tpu_columns = tf.tpu.experimental.shared_embedding_columns(
    [column_a, column_b], 10)

def model_fn(features):
  # DenseFeatures applies the shared TPU embedding columns to the input features.
  dense_feature = tf.keras.layers.DenseFeatures(tpu_columns)
  embedded_feature = dense_feature(features)
  ...

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    ...)


Args:
  • categorical_columns: A list of categorical columns returned from categorical_column_with_identity, weighted_categorical_column, categorical_column_with_vocabulary_file, categorical_column_with_vocabulary_list, sequence_categorical_column_with_identity, sequence_categorical_column_with_vocabulary_file, or sequence_categorical_column_with_vocabulary_list.
  • dimension: An integer specifying the dimension of the embedding; must be > 0.
  • combiner: A string specifying how to reduce if there are multiple entries in a single row for a non-sequence column. For more information, see tf.feature_column.embedding_column.
  • initializer: A variable initializer function to be used in embedding variable initialization. If not specified, defaults to tf.truncated_normal_initializer with mean 0.0 and standard deviation 1/sqrt(dimension).
  • shared_embedding_collection_name: Optional name of the collection where shared embedding weights are added. If not given, a reasonable name will be chosen based on the names of categorical_columns. This is also used in variable_scope when creating shared embedding weights.
  • max_sequence_lengths: A list of non-negative integers, either None or empty or the same length as the argument categorical_columns. Entries corresponding to non-sequence columns must be 0, and entries corresponding to sequence columns specify the max sequence length for the column. Any sequence shorter than this will be padded with 0 embeddings and any sequence longer will be truncated.
  • learning_rate_fn: A function that takes global step and returns learning rate for the embedding table.
  • embedding_lookup_device: The device on which to run the embedding lookup. Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If "tpu_tensor_core" is specified, a tensor_core_shape must be supplied. If not specified, the default behavior is embedding lookup on "tpu_embedding_core" for training and "cpu" for inference. Valid options for training: ["tpu_embedding_core", "tpu_tensor_core"]. Valid options for serving: ["cpu", "tpu_tensor_core"]. For training, tpu_embedding_core is good for large embedding vocabularies (>1M); otherwise, tpu_tensor_core is often sufficient. For serving, running the embedding lookup on tpu_tensor_core is a way to reduce host CPU usage in cases where that is a bottleneck.
  • tensor_core_shape: If supplied, a list of integers specifying the intended dense shape for running the embedding lookup for this feature on TensorCore. The batch dimension can be left as None or -1 to indicate a dynamic shape. Only rank 2 shapes are currently supported.
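
The padding and truncation rule described for max_sequence_lengths can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the library's implementation: the helper name pad_or_truncate, the list-of-lists representation of embeddings, and the choice to pad at the end of the sequence are all assumptions made for illustration.

```python
from typing import List


def pad_or_truncate(sequence: List[List[float]],
                    max_sequence_length: int,
                    dimension: int) -> List[List[float]]:
    """Hypothetical helper mirroring the documented max_sequence_lengths rule.

    Sequences longer than max_sequence_length are truncated; shorter ones are
    padded with all-zero embeddings of size `dimension`. Padding at the end
    is an assumption of this sketch.
    """
    if max_sequence_length <= 0:
        raise ValueError(
            "max_sequence_length must be positive for a sequence column")
    truncated = sequence[:max_sequence_length]
    padding = [[0.0] * dimension
               for _ in range(max_sequence_length - len(truncated))]
    return truncated + padding


# Three embedded steps of dimension 2 (values are made up).
steps = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
short = pad_or_truncate(steps, max_sequence_length=5, dimension=2)
long = pad_or_truncate(steps, max_sequence_length=2, dimension=2)
```

Here short keeps the three original steps followed by two zero embeddings, while long keeps only the first two steps.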


Returns: a list of _TPUSharedEmbeddingColumnV2.


Raises:
  • ValueError: if dimension is not > 0.
  • ValueError: if initializer is specified but is not callable.
  • ValueError: if max_sequence_lengths is specified and is not the same length as categorical_columns.
  • ValueError: if max_sequence_lengths is positive for a non-sequence column or 0 for a sequence column.
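
The error conditions above can be checked before calling the function. The sketch below restates them in plain Python under stated assumptions: check_arguments, raises_value_error and the is_sequence_column flag list are hypothetical names that are not part of TensorFlow, and a None or empty max_sequence_lengths is simply skipped, which is a simplification.

```python
from typing import Callable, Optional, Sequence


def check_arguments(dimension: int,
                    is_sequence_column: Sequence[bool],
                    initializer: Optional[Callable] = None,
                    max_sequence_lengths: Optional[Sequence[int]] = None) -> None:
    """Hypothetical pre-flight check mirroring the documented ValueError cases."""
    if dimension is None or dimension < 1:
        raise ValueError(f"dimension must be > 0, got {dimension}")
    if initializer is not None and not callable(initializer):
        raise ValueError("initializer must be callable if specified")
    if not max_sequence_lengths:
        # None or empty: nothing further to check in this sketch.
        return
    if len(max_sequence_lengths) != len(is_sequence_column):
        raise ValueError(
            "max_sequence_lengths must have one entry per categorical column")
    for index, (is_sequence, length) in enumerate(
            zip(is_sequence_column, max_sequence_lengths)):
        if not is_sequence and length > 0:
            raise ValueError(
                f"column {index} is not a sequence column; its entry must be 0")
        if is_sequence and length == 0:
            raise ValueError(
                f"column {index} is a sequence column; its entry must be > 0")


def raises_value_error(**kwargs) -> bool:
    """Return True if check_arguments rejects the given arguments."""
    try:
        check_arguments(**kwargs)
    except ValueError:
        return True
    return False


# One non-sequence column and one sequence column with max length 20: accepted.
check_arguments(dimension=10,
                is_sequence_column=[False, True],
                max_sequence_lengths=[0, 20])
```

Keeping the checks in a separate function makes it possible to validate a configuration on a machine without TPU access.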