
tf.gather

Gathers slices from axis `axis` of `params` according to `indices`. (deprecated arguments)

`indices` must be an integer tensor of any dimension (often 1-D).
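
To make these semantics concrete, here is a minimal pure-Python sketch that mimics `tf.gather` on nested lists. `gather_ref` is a hypothetical helper, not part of TensorFlow. It assumes a non-negative `axis`, ignores `batch_dims` and dtypes, and, unlike `tf.gather`, lets negative indices wrap around the way Python list indexing does.

```python
def gather_ref(params, indices, axis=0):
    """Hypothetical pure-Python model of tf.gather on nested lists.

    Illustration only: assumes axis >= 0, ignores batch_dims and dtypes,
    and negative indices wrap as in Python (tf.gather does not do this).
    """
    if axis == 0:
        # A (nested) list of indices yields a (nested) list of slices
        # with the same nesting; a scalar index selects a single slice.
        if isinstance(indices, list):
            return [gather_ref(params, i, 0) for i in indices]
        return params[indices]
    # For axis > 0, perform the same gather inside every slice along axis 0.
    return [gather_ref(row, indices, axis - 1) for row in params]


matrix = [[0, 1.0, 2.0], [10.0, 11.0, 12.0]]
print(gather_ref(matrix, [2, 0], axis=1))  # [[2.0, 0], [12.0, 10.0]]
```

The recursion over `indices` is what makes the output take on the shape of `indices`, and the recursion over `params` is what lets the gathered axis sit anywhere.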

`Tensor.__getitem__` works for scalars, `tf.newaxis`, and Python slices.

tf.gather extends indexing to handle tensors of indices.

In the simplest case it's identical to scalar indexing:

>>> import tensorflow as tf
>>> params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
>>> params[3].numpy()
b'p3'
>>> tf.gather(params, 3).numpy()
b'p3'

The most common case is to pass a 1-D tensor of indices. This can't be expressed as a Python slice, because the indices repeat and are not in order:

>>> indices = [2, 0, 2, 5]
>>> tf.gather(params, indices).numpy()
array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
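
In plain Python terms (an analogy using ordinary lists, not TensorFlow code), gathering from a 1-D `params` with a 1-D list of indices amounts to a list comprehension over the indices. The sketch below also shows why no slice can reproduce that result.

```python
# Plain-Python analogy (not TensorFlow): pick params[i] for each i, in order.
params = ['p0', 'p1', 'p2', 'p3', 'p4', 'p5']
indices = [2, 0, 2, 5]

gathered = [params[i] for i in indices]
print(gathered)  # ['p2', 'p0', 'p2', 'p5']

# A slice start:stop:step visits evenly spaced positions in one direction,
# so it can never repeat position 2 or return to position 0 after it.
stepped = params[0:6:2]
print(stepped)  # ['p0', 'p2', 'p4']
```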

The indices can have any shape. When `params` has a single axis, the output shape equals the shape of `indices`:

>>> tf.gather(params, [[2, 0], [2, 5]]).numpy()
array([[b'p2', b'p0'],
       [b'p2', b'p5']], dtype=object)
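
This shape rule generalizes. With the default `batch_dims=0`, the output shape is `params.shape[:axis] + indices.shape + params.shape[axis + 1:]`, so the gathered axis is replaced by the whole shape of `indices`. The helper `gather_output_shape` below is hypothetical, a sketch on plain tuples rather than a TensorFlow API. Note that a scalar index (shape `()`) removes the axis entirely, which is why `tf.gather(params, 3)` returns a single element.

```python
def gather_output_shape(params_shape, indices_shape, axis=0):
    """Hypothetical helper: shape of tf.gather(params, indices, axis=axis)
    for batch_dims=0, computed on plain tuples."""
    rank = len(params_shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # treat a negative axis as counting from the end
    # Replace the gathered axis with the full shape of the indices.
    return (tuple(params_shape[:axis]) + tuple(indices_shape)
            + tuple(params_shape[axis + 1:]))


print(gather_output_shape((6,), (2, 2)))         # (2, 2)
print(gather_output_shape((4, 3), (2,), axis=1)) # (4, 2)
print(gather_output_shape((6,), ()))             # ()
```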

`params` may also have any shape. `tf.gather` can select slices along any axis, chosen by the `axis` argument (which defaults to 0). Below it is used to gather first rows, then columns, from a matrix:

>>> params = tf.constant([[0, 1.0, 2.0],
...                       [10.0, 11.0, 12.0],
...                       [20.0, 21.0, 22.0],
...                       [30.0, 31.0, 32.0]])
>>> tf.gather(params, indices=[3,1]).numpy()
array([[30., 31., 32.],
       [10., 11., 12.]], dtype=float32)
>>> tf.gather(params, indices=[2,1], axis=1).numpy()
array([[ 2.,  1.],
       [12., 11.],