tf.batch_gather

tf.batch_gather(
    params,
    indices,
    name=None
)

Defined in tensorflow/python/ops/array_ops.py.

Gather slices from params according to indices with leading batch dims.

This operation assumes that the leading dimensions of indices are dense, and gathers along the axis corresponding to the last dimension of indices. More concretely, it computes:

result[i1, ..., in] = params[i1, ..., in-1, indices[i1, ..., in]]

Therefore, params should be a Tensor of shape [A1, ..., AN, B1, ..., BM], indices should be a Tensor of shape [A1, ..., AN-1, C], and result will be a Tensor of shape [A1, ..., AN-1, C, B1, ..., BM].
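The indexing rule and shape relationship above can be modeled on plain nested lists. The following is a minimal sketch for illustration, not TensorFlow's implementation: batch_gather_ref is a hypothetical helper name, and it assumes params and indices are well-formed nested lists whose leading dimensions match.

```python
def batch_gather_ref(params, indices):
    """Pure-Python model of the batch-gather formula on nested lists.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    # Base case: indices is a flat list of ints, so gather along the
    # first remaining axis of params.
    if not indices or not isinstance(indices[0], list):
        return [params[i] for i in indices]
    # Recursive case: walk the shared leading (batch) dimension in step.
    return [batch_gather_ref(p, idx) for p, idx in zip(params, indices)]


# params has shape [2, 3]; indices has shape [2, 2].
params = [[10, 11, 12],
          [20, 21, 22]]
indices = [[2, 0],
           [1, 1]]
result = batch_gather_ref(params, indices)
print(result)  # [[12, 10], [21, 21]]

# params with one trailing dimension: shape [2, 3, 2]; indices shape [2, 2].
params_3d = [[[0, 1], [2, 3], [4, 5]],
             [[6, 7], [8, 9], [10, 11]]]
result_3d = batch_gather_ref(params_3d, indices)
print(result_3d)  # [[[4, 5], [0, 1]], [[8, 9], [8, 9]]]
```

In the second call the result has shape [2, 2, 2], that is [A1, C, B1]: each gathered element is a whole slice of the trailing dimension.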

When indices is a 1D tensor, this operation is equivalent to tf.gather.

See also tf.gather and tf.gather_nd.

Args:

  • params: A Tensor. The tensor from which to gather values.
  • indices: A Tensor. Must be one of the following types: int32, int64. Index tensor. Values must be in the range [0, params.shape[axis]), where axis is the last dimension of indices itself (that is, rank(indices) - 1).
  • name: A name for the operation (optional).
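The range constraint on indices can be checked ahead of time. The sketch below works on nested lists and uses hypothetical helper names (nested_shape, check_indices); it only models the documented constraint and makes no claim about what TensorFlow itself does with out-of-range values.

```python
def nested_shape(x):
    """Return the shape of a rectangular nested list (hypothetical helper)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return shape


def flatten(x):
    """Yield the scalar leaves of a nested list."""
    if isinstance(x, list):
        for item in x:
            yield from flatten(item)
    else:
        yield x


def check_indices(params, indices):
    """Raise ValueError if any index is outside [0, params.shape[axis])."""
    # axis is the position of the last dimension of indices.
    axis = len(nested_shape(indices)) - 1
    limit = nested_shape(params)[axis]
    bad = [i for i in flatten(indices) if not 0 <= i < limit]
    if bad:
        raise ValueError(
            "indices %r are outside [0, %d) for axis %d" % (bad, limit, axis))
    return axis, limit


params = [[10, 11, 12],
          [20, 21, 22]]
axis, limit = check_indices(params, [[2, 0], [1, 1]])
print(axis, limit)  # 1 3
```

Here indices has rank 2, so axis is 1 and every index must lie in [0, 3).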

Returns:

A Tensor. Has the same type as params.

Raises:

  • ValueError: if indices has an unknown shape.