AllToAll

public final class AllToAll

An Op to exchange data across TPU replicas.

On each replica, the input is split into `split_count` blocks along `split_dimension`, and the blocks are sent to the other replicas as specified by `group_assignment`. After receiving `split_count` - 1 blocks from the other replicas, each replica concatenates them, together with the block it kept for itself, along `concat_dimension` to form the output.

For example, suppose there are 2 TPU replicas:

replica 0 receives input: `[[A, B]]`
replica 1 receives input: `[[C, D]]`

With group_assignment=`[[0, 1]]`, concat_dimension=0, split_dimension=1, and split_count=2, the outputs are:

replica 0's output: `[[A], [C]]`
replica 1's output: `[[B], [D]]`
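The exchange in this example can be reproduced with a small plain-Java simulation. This is a sketch, not the TensorFlow API: it assumes 2-D inputs, `split_dimension=1`, `concat_dimension=0`, and a single group containing every replica in order (`group_assignment = [[0, 1, ..., n-1]]`). It also assumes the split dimension divides evenly by `split_count`. `AllToAllSimulation` and `allToAll` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class AllToAllSimulation {

    // Hypothetical simulation of AllToAll for 2-D inputs, splitting along
    // dimension 1 and concatenating along dimension 0. Assumes one group
    // holding all replicas in order, so splitCount must equal the replica count.
    static String[][][] allToAll(String[][][] inputs, int splitCount) {
        int numReplicas = inputs.length;
        if (numReplicas != splitCount) {
            throw new IllegalArgumentException("splitCount must equal the group size");
        }
        String[][][] outputs = new String[numReplicas][][];
        for (int dst = 0; dst < numReplicas; dst++) {
            List<String[]> rows = new ArrayList<>();
            // Received blocks are concatenated in the order of the sending replica,
            // including the block the destination replica kept for itself.
            for (int src = 0; src < numReplicas; src++) {
                String[][] in = inputs[src];
                int cols = in[0].length;
                if (cols % splitCount != 0) {
                    throw new IllegalArgumentException("split dimension not divisible by splitCount");
                }
                int blockWidth = cols / splitCount;
                // Block number `dst` of the source replica is sent to replica `dst`.
                for (String[] row : in) {
                    rows.add(Arrays.copyOfRange(row, dst * blockWidth, (dst + 1) * blockWidth));
                }
            }
            outputs[dst] = rows.toArray(new String[0][]);
        }
        return outputs;
    }

    public static void main(String[] args) {
        String[][][] inputs = {
            {{"A", "B"}}, // replica 0
            {{"C", "D"}}, // replica 1
        };
        String[][][] outputs = allToAll(inputs, 2);
        for (int r = 0; r < outputs.length; r++) {
            System.out.println("replica " + r + ": " + Arrays.deepToString(outputs[r]));
        }
    }
}
```

Running `main` prints `replica 0: [[A], [C]]` and `replica 1: [[B], [D]]`, matching the outputs listed above.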

Public Methods

Output<T> asOutput()
Returns the symbolic handle of a tensor.

static <T> AllToAll<T> create(Scope scope, Operand<T> input, Operand<Integer> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)
Factory method to create a class wrapping a new AllToAll operation.

Output<T> output()
The exchanged result.


Public Methods

public Output<T> asOutput ()

Returns the symbolic handle of a tensor.

Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is used to obtain a symbolic handle that represents the computation of the input.

public static <T> AllToAll<T> create (Scope scope, Operand<T> input, Operand<Integer> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)

Factory method to create a class wrapping a new AllToAll operation.

Parameters
scope current scope
input The local input to be exchanged.
groupAssignment An int32 tensor with shape [num_groups, num_replicas_per_group]. `group_assignment[i]` lists the replica ids in the i-th subgroup.
concatDimension The dimension along which the received blocks are concatenated.
splitDimension The dimension along which the input is split.
splitCount The number of splits. This must equal the sub-group size (`group_assignment.get_shape()[1]`).
Returns
  • a new instance of AllToAll
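How `groupAssignment` and `splitCount` interact can be sketched as a routing table. The sketch below assumes, consistent with the worked example, that the replica at position p within its group receives block p from every member of that group, concatenated in the order the members are listed. It also checks the documented constraint that `splitCount` equals the sub-group size. `AllToAllRouting`, `Source`, and `sourcesFor` are hypothetical names; this is plain Java (records need Java 16+), not part of the TensorFlow API.

```java
import java.util.ArrayList;
import java.util.List;

final class AllToAllRouting {

    /** One received block: the sending replica and the index of the block it sent. */
    record Source(int replica, int blockIndex) {}

    // Lists, in concatenation order, the blocks that `replica` receives under
    // the given group assignment (assumed semantics, see the lead-in).
    static List<Source> sourcesFor(int replica, int[][] groupAssignment, long splitCount) {
        // Every sub-group must have exactly splitCount members.
        for (int[] group : groupAssignment) {
            if (group.length != splitCount) {
                throw new IllegalArgumentException(
                    "splitCount must equal the sub-group size " + group.length);
            }
        }
        for (int[] group : groupAssignment) {
            for (int position = 0; position < group.length; position++) {
                if (group[position] == replica) {
                    List<Source> sources = new ArrayList<>();
                    // Each group member sends the block whose index equals
                    // this replica's position within the group.
                    for (int sender : group) {
                        sources.add(new Source(sender, position));
                    }
                    return sources;
                }
            }
        }
        throw new IllegalArgumentException("replica " + replica + " is not in groupAssignment");
    }

    public static void main(String[] args) {
        int[][] groups = {{0, 1}, {2, 3}};
        // Replica 3 only exchanges data with replica 2, its fellow group member.
        System.out.println(sourcesFor(3, groups, 2));
    }
}
```

With two groups `[[0, 1], [2, 3]]`, replicas 0 and 1 never exchange data with replicas 2 and 3; each group performs its own independent exchange.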

public Output<T> output ()

The exchanged result.