tf.split

Splits a tensor value into a list of sub-tensors.

See also tf.unstack.

If num_or_size_splits is an integer, then value is split along dimension axis into num_or_size_splits smaller tensors of equal size. This requires that value.shape[axis] be divisible by num_or_size_splits.
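The integer rule is pure shape arithmetic: only dimension axis changes, and it is divided evenly. The sketch below illustrates just that arithmetic in plain Python. The helper name even_split_shapes is hypothetical and is not part of TensorFlow; it computes output shapes only and does not touch tensor data.

```python
# Hypothetical helper (not part of TensorFlow): computes the output shapes
# that an even split of a tensor with `shape` along `axis` would produce.
def even_split_shapes(shape, num_splits, axis):
    shape = list(shape)
    if num_splits <= 0:
        raise ValueError("num_splits must be a positive integer")
    dim = shape[axis]
    # An integer split requires the axis length to divide evenly.
    if dim % num_splits != 0:
        raise ValueError(
            f"dimension {axis} has size {dim}, not divisible by {num_splits}")
    piece = shape.copy()
    piece[axis] = dim // num_splits  # only the split dimension shrinks
    # Every piece has the same shape.
    return [list(piece) for _ in range(num_splits)]


print(even_split_shapes([5, 30], 3, axis=1))  # [[5, 10], [5, 10], [5, 10]]
```

Because Python lists accept negative indices, the sketch also handles a negative axis such as -1 (the last dimension).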

If num_or_size_splits is a 1-D Tensor (or list), then value is split into len(num_or_size_splits) elements. The i-th element has the same shape as value except along dimension axis, where its size is num_or_size_splits[i].
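The sized rule can be sketched the same way: each output copies the input shape and replaces dimension axis with the corresponding entry of the size list. The helper sized_split_shapes below is hypothetical, not a TensorFlow function. As a simplifying assumption, it requires the sizes to add up exactly to the length of dimension axis; it does not model any size inference the real op may support.

```python
# Hypothetical helper (not part of TensorFlow): computes the output shapes
# for a split into explicitly sized pieces along `axis`.
def sized_split_shapes(shape, sizes, axis):
    shape = list(shape)
    # Simplifying assumption: the sizes must cover the axis exactly.
    if sum(sizes) != shape[axis]:
        raise ValueError(
            f"sizes sum to {sum(sizes)}, but dimension {axis} "
            f"has size {shape[axis]}")
    shapes = []
    for size in sizes:
        piece = shape.copy()
        piece[axis] = size  # only the split dimension changes
        shapes.append(piece)
    return shapes


print(sized_split_shapes([5, 30], [4, 15, 11], axis=1))  # [[5, 4], [5, 15], [5, 11]]
```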

For example:

import tensorflow as tf

x = tf.Variable(tf.random.uniform([5, 30], -1, 1))

# Split `x` into 3 tensors along dimension 1
s0, s1, s2 = tf.split(x, num_or_size_splits=3, axis=1)
tf.shape(s0).numpy()
array([ 5, 10], dtype=int32)

# Split `x` into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(