AllToAll

public final class AllToAll

An Op to exchange data across TPU replicas.

On each replica, the input is split into `split_count` blocks along `split_dimension`, and the blocks are sent to the other replicas given by `group_assignment`. After receiving `split_count` - 1 blocks from the other replicas, each replica concatenates the blocks along `concat_dimension` to form its output.
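The output shape follows from this description: the size along `split_dimension` shrinks by a factor of `split_count`, and the size along `concat_dimension` grows by the same factor. The sketch below computes the per-replica output shape from an input shape. The class `AllToAllShape` and its `outputShape` helper are hypothetical, and the divisibility check is an assumption borrowed from XLA's AllToAll semantics (the `split_dimension` size must be a multiple of `split_count`).

```java
import java.util.Arrays;

/** Hypothetical helper: computes the AllToAll per-replica output shape. */
public class AllToAllShape {

    /**
     * Returns the output shape for one replica. Assumes (as in XLA's AllToAll)
     * that the input size along splitDimension is divisible by splitCount.
     */
    public static long[] outputShape(long[] inputShape, int concatDimension,
                                     int splitDimension, int splitCount) {
        if (inputShape[splitDimension] % splitCount != 0) {
            throw new IllegalArgumentException(
                "size of split dimension must be divisible by splitCount");
        }
        long[] out = inputShape.clone();
        // Each replica ends up with one block's worth of data along splitDimension...
        out[splitDimension] /= splitCount;
        // ...and stacks splitCount received blocks along concatDimension.
        out[concatDimension] *= splitCount;
        return out;
    }

    public static void main(String[] args) {
        // Input [[A, B]] has shape [1, 2]; concat_dimension=0, split_dimension=1, split_count=2.
        long[] out = outputShape(new long[] {1, 2}, 0, 1, 2);
        System.out.println(Arrays.toString(out)); // [2, 1]
    }
}
```

When `concat_dimension` equals `split_dimension`, the division and multiplication cancel, so the output shape equals the input shape.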

For example, suppose there are 2 TPU replicas:
  • replica 0 receives input: `[[A, B]]`
  • replica 1 receives input: `[[C, D]]`

and the op is called with group_assignment=`[[0, 1]]`, concat_dimension=0, split_dimension=1, split_count=2.

The outputs are then:
  • replica 0's output: `[[A], [C]]`
  • replica 1's output: `[[B], [D]]`
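The data movement in this example can be reproduced with a single-process simulation. The sketch below is a hypothetical illustration on 2-D `String` arrays, not the TensorFlow implementation (the real op runs on TPU replicas). It assumes, following XLA's AllToAll semantics, that block j of every replica is delivered to the j-th replica in the group, and that each replica concatenates what it receives in group order. `AllToAllSimulation`, `split`, `concat`, and `allToAll` are names invented for this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical single-process simulation of AllToAll on 2-D arrays. */
public class AllToAllSimulation {

    /** Splits a 2-D array into `count` equal blocks along `dim` (0 = rows, 1 = columns). */
    static String[][][] split(String[][] x, int dim, int count) {
        int rows = x.length, cols = x[0].length;
        String[][][] blocks = new String[count][][];
        for (int b = 0; b < count; b++) {
            if (dim == 0) {
                int n = rows / count;
                blocks[b] = Arrays.copyOfRange(x, b * n, (b + 1) * n);
            } else {
                int n = cols / count;
                blocks[b] = new String[rows][];
                for (int r = 0; r < rows; r++) {
                    blocks[b][r] = Arrays.copyOfRange(x[r], b * n, (b + 1) * n);
                }
            }
        }
        return blocks;
    }

    /** Concatenates 2-D blocks along `dim` (0 = rows, 1 = columns). */
    static String[][] concat(List<String[][]> blocks, int dim) {
        if (dim == 0) {
            List<String[]> stacked = new ArrayList<>();
            for (String[][] b : blocks) stacked.addAll(Arrays.asList(b));
            return stacked.toArray(new String[0][]);
        }
        int rows = blocks.get(0).length;
        String[][] out = new String[rows][];
        for (int r = 0; r < rows; r++) {
            List<String> row = new ArrayList<>();
            for (String[][] b : blocks) row.addAll(Arrays.asList(b[r]));
            out[r] = row.toArray(new String[0]);
        }
        return out;
    }

    /** inputs[k] is the input of the k-th replica in one group, in group order. */
    public static String[][][] allToAll(String[][][] inputs, int concatDim,
                                        int splitDim, int splitCount) {
        // Every replica splits its input; sent[k][j] is the block replica k sends to replica j.
        String[][][][] sent = new String[splitCount][][][];
        for (int k = 0; k < splitCount; k++) sent[k] = split(inputs[k], splitDim, splitCount);
        // Replica j gathers block j from every replica and concatenates them in group order.
        String[][][] outputs = new String[splitCount][][];
        for (int j = 0; j < splitCount; j++) {
            List<String[][]> received = new ArrayList<>();
            for (int k = 0; k < splitCount; k++) received.add(sent[k][j]);
            outputs[j] = concat(received, concatDim);
        }
        return outputs;
    }

    public static void main(String[] args) {
        String[][][] inputs = { {{"A", "B"}}, {{"C", "D"}} };
        String[][][] out = allToAll(inputs, 0, 1, 2);
        System.out.println(Arrays.deepToString(out[0])); // [[A], [C]]
        System.out.println(Arrays.deepToString(out[1])); // [[B], [D]]
    }
}
```

Swapping the dimensions (split along rows, concatenate along columns) turns inputs `[[A], [B]]` and `[[C], [D]]` into outputs `[[A, C]]` and `[[B, D]]`.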

Constants

String OP_NAME The name of this op, as known by the TensorFlow core engine

Public Methods

Output<T> asOutput()
Returns the symbolic handle of the tensor.

static <T extends TType> AllToAll<T> create(Scope scope, Operand<T> input, Operand<TInt32> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)
Factory method to create a class wrapping a new AllToAll operation.

Output<T> output()
The exchanged result.


Constants

public static final String OP_NAME

The name of this op, as known by the TensorFlow core engine

Constant Value: "AllToAll"

Public Methods

public Output<T> asOutput()

Returns the symbolic handle of the tensor.

Inputs to TensorFlow operations are outputs of other TensorFlow operations. This method obtains a symbolic handle that represents the computation of the input.

public static <T extends TType> AllToAll<T> create(Scope scope, Operand<T> input, Operand<TInt32> groupAssignment, Long concatDimension, Long splitDimension, Long splitCount)

Factory method to create a class wrapping a new AllToAll operation.

Parameters
scope current scope
input The local input tensor to be split and exchanged.
groupAssignment An int32 tensor with shape [num_groups, num_replicas_per_group]. `group_assignment[i]` lists the replica ids in the i-th subgroup.
concatDimension The dimension along which to concatenate the received blocks.
splitDimension The dimension along which to split the input.
splitCount The number of splits; this number must equal the sub-group size (group_assignment.get_shape()[1]).
Returns
  • a new instance of AllToAll
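The `groupAssignment` and `splitCount` constraints above can be checked before building the op. The sketch below is a hypothetical pre-flight check on a plain `int[][]` (not a `TInt32` operand): it verifies that every sub-group has exactly `splitCount` members and locates the group containing a given replica. The `positionInGroup` helper assumes, following XLA's AllToAll semantics, that a replica's position within its group decides which block it receives from each peer. `GroupAssignmentCheck`, `groupOf`, and `positionInGroup` are invented names.

```java
import java.util.Arrays;

/** Hypothetical checks mirroring the documented groupAssignment/splitCount constraints. */
public class GroupAssignmentCheck {

    /** Validates group sizes against splitCount and returns the group containing replicaId. */
    public static int[] groupOf(int[][] groupAssignment, int splitCount, int replicaId) {
        for (int[] group : groupAssignment) {
            // splitCount must equal the sub-group size, i.e. group_assignment.get_shape()[1].
            if (group.length != splitCount) {
                throw new IllegalArgumentException(
                    "splitCount " + splitCount + " != sub-group size " + group.length);
            }
        }
        for (int[] group : groupAssignment) {
            for (int id : group) {
                if (id == replicaId) return group;
            }
        }
        throw new IllegalArgumentException("replica " + replicaId + " is not in any group");
    }

    /** Index of replicaId within its group (assumed: block j of each peer goes to index j). */
    public static int positionInGroup(int[] group, int replicaId) {
        for (int j = 0; j < group.length; j++) {
            if (group[j] == replicaId) return j;
        }
        return -1;
    }

    public static void main(String[] args) {
        // Hypothetical layout: four replicas exchanging within two groups of two.
        int[][] groups = {{0, 1}, {2, 3}};
        int[] group = groupOf(groups, 2, 3);
        System.out.println(Arrays.toString(group));      // [2, 3]
        System.out.println(positionInGroup(group, 3));   // 1
    }
}
```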

public Output<T> output()

The exchanged result.