
nsl.lib.replicate_embeddings


Replicates each row of embeddings according to replicate_times.

nsl.lib.replicate_embeddings(
    embeddings,
    replicate_times
)

This function is useful when comparing the same instance with multiple other instances. For example, given a seed and its neighbors, this function can replicate the seed's embedding once per neighbor, so that the distances between the seed and each of its neighbors can be computed efficiently.
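The seed/neighbor use case can be sketched in plain Python as follows. This is an illustrative sketch, not the NSL implementation: the flattened neighbor layout, the helper names `squared_l2` and `seed_neighbor_distances`, and the choice of squared Euclidean distance are all assumptions made for the example.

```python
from typing import List, Sequence


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    # Squared Euclidean distance between two equal-length vectors
    # (an assumed choice of distance for this sketch).
    return sum((x - y) ** 2 for x, y in zip(a, b))


def seed_neighbor_distances(seeds: Sequence[Sequence[float]],
                            neighbors: Sequence[Sequence[float]],
                            num_neighbors: Sequence[int]) -> List[float]:
    # Hypothetical layout: `neighbors` is flattened, so the neighbors of
    # seed 0 come first, then those of seed 1, and so on.
    # Step 1: replicate each seed once per neighbor so the rows line up.
    replicated = [seed for seed, n in zip(seeds, num_neighbors) for _ in range(n)]
    if len(replicated) != len(neighbors):
        raise ValueError("sum(num_neighbors) must equal len(neighbors)")
    # Step 2: compare aligned rows one-to-one.
    return [squared_l2(s, nbr) for s, nbr in zip(replicated, neighbors)]


seeds = [[0.0, 0.0], [1.0, 1.0]]
neighbors = [[1.0, 0.0], [0.0, 2.0], [1.0, 2.0]]  # 2 for seed 0, 1 for seed 1
print(seed_neighbor_distances(seeds, neighbors, [2, 1]))  # [1.0, 4.0, 1.0]
```

With tensors instead of lists, step 2 becomes a single elementwise operation over two aligned tensors of the same shape, which is the efficiency the replication buys.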

replicate_times is either a scalar or a 1-D tensor. For example, given embeddings = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] and replicate_times = 2, the returned tensor is [[0, 1, 2], [0, 1, 2], [3, 4, 5], [3, 4, 5], [6, 7, 8], [6, 7, 8]]. When replicate_times = [3, 0, 1], the returned tensor is [[0, 1, 2], [0, 1, 2], [0, 1, 2], [6, 7, 8]]; a row with a count of 0 does not appear in the output.
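The replication rule in the examples above can be restated as a short pure-Python reference sketch. The function name `replicate_rows` is hypothetical and this is not how NSL implements the op; it only mirrors the documented behavior on nested lists.

```python
from typing import List, Sequence, Union

Row = List[float]


def replicate_rows(embeddings: Sequence[Sequence[float]],
                   replicate_times: Union[int, Sequence[int]]) -> List[Row]:
    # A scalar applies the same count to every row.
    if isinstance(replicate_times, int):
        counts = [replicate_times] * len(embeddings)
    else:
        counts = list(replicate_times)
    if len(counts) != len(embeddings):
        raise ValueError("replicate_times must have one entry per row")
    result: List[Row] = []
    for row, count in zip(embeddings, counts):
        # Emit each row `count` times, preserving the original row order.
        result.extend(list(row) for _ in range(count))
    return result


embeddings = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(replicate_rows(embeddings, 2))
print(replicate_rows(embeddings, [3, 0, 1]))
```

In NumPy terms, this behavior is comparable to `np.repeat(embeddings, replicate_times, axis=0)`.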

Args:

  • embeddings: A Tensor of shape {batch size, d1, ..., dN}.
  • replicate_times: A scalar or a 1-D Tensor of shape {batch size}. Each element indicates the number of times the corresponding row should be replicated; a scalar applies the same count to every row.

Returns:

A Tensor of shape {sum of replicate_times, d1, ..., dN}. When replicate_times is a scalar, the first dimension is batch size * replicate_times.

Raises:

  • InvalidArgumentError: If replicate_times contains any negative value.
  • TypeError: If replicate_times cannot be cast to int32.
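The Returns and Raises contracts can be summarized as a small check that computes the leading dimension of the output. This is a simplified sketch with a hypothetical helper name, `output_num_rows`: it uses Python's built-in `ValueError` as a stand-in for the documented InvalidArgumentError, and an `isinstance` check as a rough stand-in for the int32 cast that triggers TypeError.

```python
from typing import Sequence, Union


def output_num_rows(batch_size: int,
                    replicate_times: Union[int, Sequence[int]]) -> int:
    # A scalar count is broadcast to every row of the batch.
    if isinstance(replicate_times, int):
        counts = [replicate_times] * batch_size
    else:
        counts = list(replicate_times)
        if len(counts) != batch_size:
            raise ValueError("replicate_times must have shape {batch size}")
    for c in counts:
        # Simplified stand-in for the int32 cast check (TypeError).
        if not isinstance(c, int):
            raise TypeError(f"cannot cast {c!r} to int32")
        # Stand-in for InvalidArgumentError on negative counts.
        if c < 0:
            raise ValueError(f"negative replicate count: {c}")
    # Only the first dimension changes; d1, ..., dN are unchanged.
    return sum(counts)


print(output_num_rows(3, 2))          # 6
print(output_num_rows(3, [3, 0, 1]))  # 4
```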