
A transformation that batches ragged elements into tf.RaggedTensors.

This transformation combines multiple consecutive elements of the input dataset into a single element.

Like tf.data.Dataset.batch, the components of the resulting element will have an additional outer dimension. That dimension is batch_size, except for the last element, where it is N % batch_size when batch_size does not divide the number of input elements N evenly and drop_remainder is False. If your program depends on every batch having the same outer dimension, set the drop_remainder argument to True so that the smaller final batch is not produced.
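The outer-dimension arithmetic above can be checked with a small plain-Python model. This is a sketch, not TensorFlow code: batch_sizes is a hypothetical helper that only computes how many input elements land in each batch for N inputs, with and without drop_remainder.

```python
def batch_sizes(n, batch_size, drop_remainder=False):
    """Return the outer dimension of each batch built from n input elements.

    Plain-Python model of the rule in the text (hypothetical helper, not a
    TensorFlow API): full batches hold batch_size elements, and the last
    batch holds n % batch_size elements unless drop_remainder is True.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full_batches, remainder = divmod(n, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_remainder:
        sizes.append(remainder)  # the smaller final batch
    return sizes


print(batch_sizes(7, 3))                       # [3, 3, 1]
print(batch_sizes(7, 3, drop_remainder=True))  # [3, 3]
print(batch_sizes(6, 2))                       # [2, 2, 2]
```

When batch_size divides N evenly, drop_remainder has no effect, because there is no smaller batch to drop.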

Unlike tf.data.Dataset.batch, the input elements to be batched may have different shapes:

  • If an input element is a tf.Tensor whose static tf.TensorShape is fully defined, then it is batched as normal.
  • If an input element is a tf.Tensor whose static tf.TensorShape contains one or more axes with unknown size (i.e., shape[i]=None), then the output will contain a tf.RaggedTensor that is ragged in every dimension up to and including the last such axis.
  • If an input element is a tf.RaggedTensor or any other type, then it is batched as normal.
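The rules for dense inputs can be summarized as a classification on the static shape. The sketch below is a plain-Python model, not a TensorFlow API: batched_kind is a hypothetical helper, shapes are given as lists with None for unknown axis sizes, and it assumes raggedness extends through the last axis of unknown size, with the ragged rank counted on the batched result.

```python
from typing import Optional, Sequence, Tuple


def batched_kind(shape: Sequence[Optional[int]]) -> Tuple[str, int]:
    """Classify how a dense tf.Tensor with this static shape is batched.

    Model only (hypothetical helper). `shape` lists per-axis sizes, with None
    marking an axis of unknown size. Returns ("dense", 0) for a fully defined
    shape, otherwise ("ragged", ragged_rank), where ragged_rank is the ragged
    rank of the batched tf.RaggedTensor, assuming every axis up to and
    including the last unknown one becomes ragged.
    """
    unknown_axes = [axis for axis, size in enumerate(shape) if size is None]
    if not unknown_axes:
        # Fully defined static shape: ordinary dense batching.
        return ("dense", 0)
    # +1 because batching prepends a new outer dimension.
    return ("ragged", max(unknown_axes) + 1)


print(batched_kind([3]))        # ('dense', 0)
print(batched_kind([None]))     # ('ragged', 1)
print(batched_kind([None, 4]))  # ('ragged', 1)
print(batched_kind([2, None]))  # ('ragged', 2)
```

The model deliberately ignores tf.RaggedTensor and non-tensor inputs, which the text says are batched as normal.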

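Example:

>>> import numpy as np
>>> import tensorflow as tf
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
>>> dataset = dataset.map(lambda x: tf.range(x))
>>> dataset = dataset.apply(
...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
>>> for batch in dataset:
...   print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>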
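To see what each batch in the example holds, the sketch below reproduces the same grouping in plain Python and encodes every batch as a (values, row_splits) pair, the flat representation a tf.RaggedTensor uses. ragged_batches is a hypothetical helper for illustration only; TensorFlow performs this batching itself.

```python
def ragged_batches(elements, batch_size, drop_remainder=False):
    """Group variable-length rows into batches, as in the example above.

    Plain-Python model (hypothetical helper): each batch is returned as
    (values, row_splits), where row i of the batch is
    values[row_splits[i]:row_splits[i + 1]].
    """
    batches = []
    for start in range(0, len(elements), batch_size):
        rows = elements[start:start + batch_size]
        if drop_remainder and len(rows) < batch_size:
            break  # drop the smaller final batch
        values, row_splits = [], [0]
        for row in rows:
            values.extend(row)
            row_splits.append(len(values))
        batches.append((values, row_splits))
    return batches


# Same input as the example: element x is range(x) for x in 0..5.
elements = [list(range(x)) for x in range(6)]
for values, row_splits in ragged_batches(elements, batch_size=2):
    print(values, row_splits)
# [0] [0, 0, 1]
# [0, 1, 0, 1, 2] [0, 2, 5]
# [0, 1, 2, 3, 0, 1, 2, 3, 4] [0, 4, 9]
```

In TensorFlow, tf.RaggedTensor.from_row_splits(values, row_splits) builds a ragged tensor from such a pair; an empty row, like the first one here, shows up as two equal consecutive splits.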

Args:

batch_size A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder (Optional.) A