
tf.gather

Gather slices from params along the axis given by axis, according to indices. (deprecated arguments)


Gather slices from params along the axis given by axis, according to indices. indices must be an integer tensor of any rank (often 1-D).

Tensor.__getitem__ works for scalars, tf.newaxis, and Python slices.

tf.gather extends indexing to handle tensors of indices.

In the simplest case it's identical to scalar indexing:

>>> import tensorflow as tf
>>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
>>> params[3].numpy()
b'p3'
>>> tf.gather(params, 3).numpy()
b'p3'

The most common case is to pass a single-axis (1-D) tensor of indices (this can't be expressed as a Python slice because the indices are not sequential):

>>> indices = [2, 0, 2, 5]
>>> tf.gather(params, indices).numpy()
array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
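For a 1-D params, the gather above amounts to picking one element per index, with repeats and reordering allowed. The following pure-Python sketch illustrates that behavior on a plain list. gather_1d is a hypothetical helper written for illustration, not part of TensorFlow; it explicitly rejects negative or too-large indices so that Python's own list indexing does not silently wrap them.

```python
def gather_1d(params, indices):
    """Hypothetical sketch of tf.gather for a 1-D list and axis=0."""
    result = []
    for i in indices:
        # Reject negative or too-large indices instead of letting Python wrap them.
        if not 0 <= i < len(params):
            raise IndexError(f"index {i} is out of range [0, {len(params)})")
        # Each index selects one element; indices may repeat and need not be sorted.
        result.append(params[i])
    return result


params = ['p0', 'p1', 'p2', 'p3', 'p4', 'p5']
print(gather_1d(params, [2, 0, 2, 5]))  # ['p2', 'p0', 'p2', 'p5']
```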

The indices can have any shape. When params has 1 axis, the output shape is equal to the shape of indices:

>>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
array([[b'p2', b'p0'],
       [b'p2', b'p5']], dtype=object)
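Why the output takes the shape of indices can be seen from a small recursive sketch: a scalar index selects one element, and a nested list of indices maps that same selection over every entry while preserving its nesting. gather_nested is a hypothetical helper that works on Python lists and is assumed here purely for illustration.

```python
def gather_nested(params, indices):
    """Hypothetical sketch: gather from a 1-D list using arbitrarily nested indices."""
    if isinstance(indices, int):
        # A scalar index selects a single element.
        return params[indices]
    # A list of indices maps the gather over each entry, keeping the nesting,
    # so the result has the same shape as `indices`.
    return [gather_nested(params, i) for i in indices]


params = ['p0', 'p1', 'p2', 'p3', 'p4', 'p5']
print(gather_nested(params, 3))                 # p3
print(gather_nested(params, [[2, 0], [2, 5]]))  # [['p2', 'p0'], ['p2', 'p5']]
```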

The params tensor may also have any shape. gather can select slices along any axis, depending on the axis argument (which defaults to 0). Below it is used to gather first rows, then columns, from a matrix:

>>> params = tf.constant([[0, 1.0, 2.0],
...                       [10.0, 11.0, 12.0],
...                       [20.0, 21.0, 22.0],
...                       [30.0, 31.0, 32.0]])
>>> tf.gather(params, indices=[3,1]).numpy()
array([[30., 31., 32.],
       [10., 11., 12.]], dtype=float32)
>>> tf.gather(params, indices=[2,1], axis=1).numpy()
array([[ 2.,  1.],
       [12., 11.],
       [22., 21.],
       [32., 31.]], dtype=float32)
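Along axis 0, gathering from a matrix selects whole rows; along axis 1 it selects the same columns from every row. The pure-Python sketch below reproduces both results above on nested lists. gather_2d is a hypothetical helper restricted to 2-D input and is not how TensorFlow implements the operation.

```python
def gather_2d(matrix, indices, axis=0):
    """Hypothetical sketch of tf.gather on a 2-D nested list."""
    if axis == 0:
        # Select whole rows (copied so the result does not alias the input).
        return [list(matrix[i]) for i in indices]
    if axis == 1:
        # Select the same columns, in the given order, from every row.
        return [[row[i] for i in indices] for row in matrix]
    raise ValueError("axis must be 0 or 1 for a 2-D matrix")


matrix = [[0.0, 1.0, 2.0],
          [10.0, 11.0, 12.0],
          [20.0, 21.0, 22.0],
          [30.0, 31.0, 32.0]]
print(gather_2d(matrix, [3, 1]))          # [[30.0, 31.0, 32.0], [10.0, 11.0, 12.0]]
print(gather_2d(matrix, [2, 1], axis=1))  # [[2.0, 1.0], [12.0, 11.0], [22.0, 21.0], [32.0, 31.0]]
```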

More generally, the output has the same shape as params, with the indexed axis replaced by the shape of indices.

>>> def result_shape(p_shape, i_shape, axis=0):
...   return p_shape[:axis] + i_shape + p_shape[axis+1:]

>>> result_shape([1, 2, 3], [], axis=1)
[1, 3]
>>> result_shape([1, 2, 3], [7], axis=1)
[1, 7, 3]
>>> result_shape([1, 2, 3], [7, 5], axis=1)
[1, 7, 5, 3]
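NumPy's np.take follows the same shape rule when it takes indices along an axis, which makes it a convenient cross-check. Assuming NumPy is installed, the sketch below compares np.take's output shapes with result_shape for the three cases above. np.take is used only as an analogue here, not as a description of how tf.gather works internally, and all-zero indices are chosen simply because they are valid along an axis of size 2.

```python
import numpy as np


def result_shape(p_shape, i_shape, axis=0):
    return p_shape[:axis] + i_shape + p_shape[axis+1:]


params = np.zeros((1, 2, 3))
for i_shape in ([], [7], [7, 5]):
    # Only the shape of the indices matters for this check; 0 is a valid index.
    indices = np.zeros(tuple(i_shape), dtype=np.int64)
    out = np.take(params, indices, axis=1)
    assert list(out.shape) == result_shape([1, 2, 3], i_shape, axis=1)
    print(list(out.shape))
```

This prints [1, 3], [1, 7, 3] and [1, 7, 5, 3], matching the results above.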

Here are some examples:

>>> params.shape.as_list()
[4, 3]
>>> indices = tf.constant([[0, 2]])
>>> tf.gather(params, indices=indices, axis=0).shape.as_list()
[1, 2, 3]
>>> tf.gather(params, indices=indices, axis=1).shape.as_list()
[4, 1, 2]
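The same rule can be traced through a recursive pure-Python sketch that handles any non-negative axis: axes before axis are kept by recursing into each sub-list, and at axis the (possibly nested) indices replace that dimension. gather and shape are hypothetical helpers for rectangular nested lists, assumed here for illustration; the sketch reproduces the two shapes above for the 4x3 matrix.

```python
def gather(params, indices, axis=0):
    """Hypothetical sketch of tf.gather on rectangular nested lists (axis >= 0)."""
    if axis > 0:
        # Keep the leading axes: apply the gather inside each sub-list.
        return [gather(sub, indices, axis - 1) for sub in params]
    if isinstance(indices, int):
        # A scalar index removes the gathered axis.
        return params[indices]
    # A (nested) list of indices replaces the gathered axis with its own shape.
    return [gather(params, i, 0) for i in indices]


def shape(x):
    """Shape of a rectangular nested list; a scalar has shape []."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        if not x:
            break
        x = x[0]
    return dims


params = [[0.0, 1.0, 2.0],
          [10.0, 11.0, 12.0],
          [20.0, 21.0, 22.0],
          [30.0, 31.0, 32.0]]
indices = [[0, 2]]
print(shape(params))                           # [4, 3]
print(shape(gather(params, indices, axis=0)))  # [1, 2, 3]
print(shape(gather(params, indices, axis=1)))  # [4, 1, 2]
```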