Represents a potentially large set of elements.

The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.

Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
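
The three steps above can be sketched as one short pipeline. This is a minimal illustration, not a recommended configuration; the input values and the batch size of 4 are arbitrary choices made for the example.

```python
import tensorflow as tf

# 1. Create a source dataset from in-memory data.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# 2. Apply transformations: square each value, then group into batches of 4.
dataset = dataset.map(lambda x: x * x).batch(4)

# 3. Iterate over the dataset; each element is one batch.
batches = [batch.numpy().tolist() for batch in dataset]
print(batches)
```

Because the last batch holds only the two leftover elements, it is shorter than the others.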

Source Datasets:

The simplest way to create a dataset is to create it from a Python list:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)

To process lines from files, use tf.data.TextLineDataset:

dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])

To process records written in the TFRecord format, use tf.data.TFRecordDataset:

dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])

To create a dataset of all files matching a pattern, use tf.data.Dataset.list_files:

dataset = tf.data.Dataset.list_files("/path/*.txt")

See tf.data.FixedLengthRecordDataset and tf.data.Dataset.from_generator for more ways to create datasets.
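
As a sketch of the tf.data.Dataset.from_generator route mentioned above, the following wraps a plain Python generator. The generator count_up_to is a hypothetical helper invented for this example, and the output_signature argument assumes a TensorFlow version that supports it (2.4 or later).

```python
import tensorflow as tf

def count_up_to(limit):
  """Hypothetical generator for illustration: yields 0, 1, ..., limit - 1."""
  for i in range(limit):
    yield i

# from_generator expects a callable that returns a fresh generator,
# plus a signature describing each yielded element.
dataset = tf.data.Dataset.from_generator(
    lambda: count_up_to(4),
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))

values = [int(v) for v in dataset.as_numpy_iterator()]
print(values)
```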


Transformations:

Once you have a dataset, you can apply transformations to prepare the data for your model:

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x*2)
list(dataset.as_numpy_iterator())
# [2, 4, 6]

Common Terms:

Element: A single output from calling next() on a dataset iterator. Elements may be nested structures containing multiple components. For example, the element (1, (3, "apple")) has one tuple nested in another tuple. The components are 1, 3, and "apple".

Component: The leaf in the nested structure of an element.

Supported types:

Elements can be nested structures of tuples, named tuples, and dictionaries. Note that Python lists are not treated as nested structures of components. Instead, lists are converted to tensors and treated as components. For example, the element (1, [1, 2, 3]) has only two components; the tensor 1 and the tensor [1, 2, 3]. Element components can be of any type representable by tf.TypeSpec, including tf.Tensor, tf.data.Dataset, tf.sparse.SparseTensor, tf.RaggedTensor, and tf.TensorArray.

import collections
import tensorflow as tf

a = 1 # Integer element
b = 2.0 # Float element
c = (1, 2) # Tuple element with 2 components
d = {"a": (2, 2), "b": 3} # Dict element with 3 components
Point = collections.namedtuple("Point", ["x", "y"])
e = Point(1, 2) # Named tuple
f = tf.data.Dataset.range(10) # Dataset element
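
The rule that lists become a single tensor component can be checked by counting the leaves of a dataset's element_spec. This is a small sketch that uses tf.nest.flatten on the spec only; the variable names are illustrative.

```python
import tensorflow as tf

# The list [1, 2, 3] is converted to one tensor, so this element has
# two components rather than four.
with_list = tf.data.Dataset.from_tensors((1, [1, 2, 3]))
list_specs = tf.nest.flatten(with_list.element_spec)
list_shapes = [spec.shape.as_list() for spec in list_specs]

# A dict holding a tuple is a nested structure with three components.
with_dict = tf.data.Dataset.from_tensors({"a": (2, 2), "b": 3})
dict_specs = tf.nest.flatten(with_dict.element_spec)

print(len(list_specs), list_shapes)
print(len(dict_specs))
```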

variant_tensor: A DT_VARIANT tensor that represents the dataset.

element_spec: The type specification of an element of this dataset.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset.element_spec
# TensorSpec(shape=(), dtype=tf.int32, name=None)

Applies a transformation function to this dataset.

apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.

dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())
# [0, 1, 2, 3, 4]

transformation_func: A function that takes one Dataset argument and returns a Dataset.

Returns: The Dataset returned by applying transformation_func to this dataset.
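
One way to use apply is to package several steps into a reusable, parameterized transformation. The sketch below assumes a hypothetical helper, scale_and_batch, which is not part of the tf.data API; it returns a function of the form apply expects.

```python
import tensorflow as tf

def scale_and_batch(factor, batch_size):
  """Hypothetical helper: builds a Dataset -> Dataset transformation."""
  def transformation_func(ds):
    return ds.map(lambda x: x * factor).batch(batch_size)
  return transformation_func

dataset = tf.data.Dataset.range(6)
dataset = dataset.apply(scale_and_batch(factor=10, batch_size=4))

result = [batch.tolist() for batch in dataset.as_numpy_iterator()]
print(result)
```

Keeping the steps inside a single function lets the same preprocessing be applied to several datasets without repeating the chain of calls.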


Returns an iterator which converts all elements of the dataset to numpy.

Use as_numpy_iterator to inspect the content of your dataset. To see element shapes and types, print dataset elements directly instead of using as_numpy_iterator.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)

This method requires that you are running in eager mode and the dataset's element_spec contains only TensorSpec components.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)  # prints 1, 2, 3 on separate lines

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
print(list(dataset.as_numpy_iterator()))
# [1, 2, 3]

as_numpy_iterator() will preserve the nested structure of dataset elements.

dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
                                              'b': [5, 6]})
list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
                                      {'a': (2, 4), 'b': 6}]

An iterable over the elements of the dataset, with their tensors converted to numpy arrays.

TypeError: If an element contains a non-Tensor value.
RuntimeError: If eager execution is not enabled.
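
The TypeError case can be illustrated with a dataset whose elements are themselves datasets. This sketch assumes that window() produces such nested-dataset elements and that as_numpy_iterator checks the element_spec when it is called; the flat_map step is one possible way to get back to plain tensors, not the only one.

```python
import tensorflow as tf

# Each element produced by window() is a dataset, not a tensor.
windows = tf.data.Dataset.range(4).window(2)

try:
  windows.as_numpy_iterator()
  raised = False
except TypeError:
  raised = True

# Flattening each window into a batch yields ordinary tensor elements,
# which can be converted to numpy.
flat = windows.flat_map(lambda w: w.batch(2))
values = [batch.tolist() for batch in flat.as_numpy_iterator()]
print(raised, values)
```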


Combines consecutive elements of this dataset into batches.

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
# [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
# [array([0, 1, 2]), array([3, 4, 5])]
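
The effect of drop_remainder on the outer dimension can also be seen in the dataset's element_spec. The sketch below compares the static shapes and the actual batch sizes; the variable names are illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)
with_remainder = dataset.batch(3)
without_remainder = dataset.batch(3, drop_remainder=True)

# Without drop_remainder the outer dimension is statically unknown,
# because the last batch may be smaller.
static_with = with_remainder.element_spec.shape.as_list()
static_without = without_remainder.element_spec.shape.as_list()

# Actual batch sizes: the last one holds 8 % 3 elements.
sizes = [len(batch) for batch in with_remainder.as_numpy_iterator()]
print(static_with, static_without, sizes)
```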

The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimensi