
A transformation that batches ragged elements into tf.RaggedTensors.
tf.data.experimental.dense_to_ragged_batch(
    batch_size, drop_remainder=False, row_splits_dtype=tf.dtypes.int64
)

This transformation combines multiple consecutive elements of the input dataset into a single element.

Like tf.data.Dataset.batch, the components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, set the drop_remainder argument to True to prevent the smaller batch from being produced.
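The batch-size rule above can be sketched in plain Python. The helper batch_sizes below is hypothetical and not part of TensorFlow; it only mirrors the arithmetic described for N input elements and makes no claim about how the library computes it.

```python
def batch_sizes(num_elements, batch_size, drop_remainder=False):
    """Return the outer dimension of each batch for num_elements inputs.

    Hypothetical helper for illustration only; not a TensorFlow API.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Number of full batches, and how many elements are left over.
    full, remainder = divmod(num_elements, batch_size)
    sizes = [batch_size] * full
    # The smaller final batch is kept unless drop_remainder is set.
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes


print(batch_sizes(7, 2))                       # [2, 2, 2, 1]
print(batch_sizes(7, 2, drop_remainder=True))  # [2, 2, 2]
```

With N = 7 and batch_size = 2, the last batch has 7 % 2 = 1 element unless drop_remainder is True, in which case every batch has the same outer dimension.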

Unlike tf.data.Dataset.batch, the input elements to be batched may have different shapes, and each batch will be encoded as a tf.RaggedTensor. Example:

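>>> import numpy as np
>>> import tensorflow as tf
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
>>> dataset = dataset.map(lambda x: tf.range(x))
>>> dataset = dataset.apply(
...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
>>> for batch in dataset:
...   print(batch)
<tf.RaggedTensor [[], [0]]>
<tf.RaggedTensor [[0, 1], [0, 1, 2]]>
<tf.RaggedTensor [[0, 1, 2, 3], [0, 1, 2, 3, 4]]>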


Args

  • batch_size: A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
  • drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
  • row_splits_dtype: The dtype that should be used for the row_splits of any new ragged tensors. Existing tf.RaggedTensor elements do not have their row_splits dtype changed.


Returns

  • Dataset: A Dataset transformation function, which can be passed to tf.data.Dataset.apply.
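To make the row_splits mentioned under row_splits_dtype more concrete, here is a minimal pure-Python sketch of a row_splits-style encoding: a flat list of values plus the offsets at which each row starts and ends. The function to_row_splits is hypothetical and is not how TensorFlow builds a tf.RaggedTensor internally; it only illustrates the layout.

```python
def to_row_splits(rows):
    """Flatten rows into (values, row_splits).

    Hypothetical helper for illustration only; not a TensorFlow API.
    Row i of the input equals values[row_splits[i]:row_splits[i + 1]].
    """
    values = []
    row_splits = [0]
    for row in rows:
        values.extend(row)
        # Each entry records where the next row begins in the flat values.
        row_splits.append(len(values))
    return values, row_splits


# One ragged batch of two variable-length rows.
values, row_splits = to_row_splits([[0, 1], [0, 1, 2]])
print(values)      # [0, 1, 0, 1, 2]
print(row_splits)  # [0, 2, 5]
```

In this layout the offsets are plain integers; row_splits_dtype selects the integer dtype used to store them for newly created ragged tensors.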