tfr.utils.gather_per_row

Gathers values from the input tensor using per-row indices: for each batch row b, the output at position j is inputs[b, indices[b, j]].

Example Usage:

import tensorflow_ranking as tfr

scores = [[1., 3., 2.], [1., 2., 3.]]
indices = [[1, 2], [2, 1]]
tfr.utils.gather_per_row(scores, indices)

Returns [[3., 2.], [3., 2.]]
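To make the per-row semantics concrete, here is a minimal NumPy sketch for the 2D case. The function name gather_per_row_np is hypothetical and not part of TF-Ranking; it only mirrors the behavior described above and is not the library's implementation.

```python
import numpy as np

def gather_per_row_np(inputs, indices):
    """Hypothetical NumPy reference for per-row gathering (2D inputs only).

    For each batch row b: output[b, j] = inputs[b, indices[b, j]].
    """
    inputs = np.asarray(inputs)
    indices = np.asarray(indices, dtype=np.int64)
    # take_along_axis picks, within each row, the entries named by that row's indices.
    return np.take_along_axis(inputs, indices, axis=1)

scores = [[1., 3., 2.], [1., 2., 3.]]
indices = [[1, 2], [2, 1]]
print(gather_per_row_np(scores, indices).tolist())  # [[3.0, 2.0], [3.0, 2.0]]
```

Row 0 picks positions 1 and 2 of [1., 3., 2.], and row 1 picks positions 2 and 1 of [1., 2., 3.], which is why both output rows are [3., 2.].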

Args:
  inputs: (tf.Tensor) A tensor of shape [batch_size, list_size] or [batch_size, list_size, feature_dims].
  indices: (tf.Tensor) A tensor of shape [batch_size, size] giving the positions to gather from inputs. Each index refers to a position within the corresponding row of inputs.

Returns:
  A tensor of values gathered from inputs, of shape [batch_size, size] or [batch_size, size, feature_dims], depending on whether inputs is 2D or 3D.
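The shape rule above can be illustrated with a small NumPy sketch that handles both the 2D and the 3D case. As before, gather_per_row_np is a hypothetical name used only for illustration; it reproduces the documented shapes, not necessarily the library's internals.

```python
import numpy as np

def gather_per_row_np(inputs, indices):
    """Hypothetical NumPy sketch handling 2D and 3D inputs.

    inputs:  [batch_size, list_size] or [batch_size, list_size, feature_dims]
    indices: [batch_size, size]
    returns: [batch_size, size] or [batch_size, size, feature_dims]
    """
    inputs = np.asarray(inputs)
    indices = np.asarray(indices, dtype=np.int64)
    # Row selector of shape [batch_size, 1]; it broadcasts against indices [batch_size, size].
    rows = np.arange(inputs.shape[0])[:, None]
    # Advanced indexing on the first two axes; any trailing feature axis is kept whole.
    return inputs[rows, indices]

# 3D example: batch_size=2, list_size=3, feature_dims=2.
features = np.arange(12, dtype=np.float32).reshape(2, 3, 2)
indices = [[1, 2], [2, 1]]
out = gather_per_row_np(features, indices)
print(out.shape)  # (2, 2, 2)
```

Each gathered entry keeps its full feature vector, so only the list axis changes size (from list_size to size). In TensorFlow, the same per-row semantics can be expressed with tf.gather(inputs, indices, axis=1, batch_dims=1).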