Represents a potentially large set of elements.
tf.data.Dataset(
variant_tensor
)
The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:
- Create a source dataset from your input data.
- Apply dataset transformations to preprocess the data.
- Iterate over the dataset and process the elements.
Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
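The three steps above can be sketched end to end as follows. This is a minimal illustration that assumes eager execution (the TensorFlow 2 default); the particular filter, map, and batch choices are arbitrary and only there to show the pattern.

```python
import tensorflow as tf

# 1. Create a source dataset from in-memory data.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# 2. Apply transformations: keep even values, square them, batch in pairs.
dataset = dataset.filter(lambda x: x % 2 == 0)
dataset = dataset.map(lambda x: x * x)
dataset = dataset.batch(2)

# 3. Iterate over the dataset; batches are produced one at a time.
batches = [batch.numpy().tolist() for batch in dataset]
print(batches)
```

Each transformation returns a new dataset rather than modifying the old one, which is why the result is reassigned at every step.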
Source Datasets:
The simplest way to create a dataset is to create it from a Python list:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
To process lines from files, use tf.data.TextLineDataset:
dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
To process records written in the TFRecord format, use TFRecordDataset:
dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
To create a dataset of all files matching a pattern, use tf.data.Dataset.list_files:
dataset = tf.data.Dataset.list_files("/path/*.txt") # doctest: +SKIP
See tf.data.FixedLengthRecordDataset and tf.data.Dataset.from_generator for more ways to create datasets.
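As one illustration of tf.data.Dataset.from_generator, the sketch below wraps a Python generator as a source dataset. The generator `count_up_to` is a hypothetical helper invented for this example, and the code assumes a TensorFlow version that supports the `output_signature` argument.

```python
import tensorflow as tf

def count_up_to(limit):
  # Hypothetical generator, used only for illustration.
  for i in range(limit):
    yield i

# `from_generator` expects a callable that returns a fresh generator,
# plus a description of the elements the generator yields.
dataset = tf.data.Dataset.from_generator(
    lambda: count_up_to(4),
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))

values = [int(v) for v in dataset.as_numpy_iterator()]
print(values)
```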
Transformations:
Once you have a dataset, you can apply transformations to prepare the data for your model:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x*2)
list(dataset.as_numpy_iterator())
[2, 4, 6]
Common Terms:
Element: A single output from calling next() on a dataset iterator. Elements may be nested structures containing multiple components. For example, the element (1, (3, "apple")) has one tuple nested in another tuple. The components are 1, 3, and "apple".
Component: The leaf in the nested structure of an element.
Supported types:
Elements can be nested structures of tuples, named tuples, and dictionaries.
Note that Python lists are not treated as nested structures of components.
Instead, lists are converted to tensors and treated as components. For example, the element (1, [1, 2, 3]) has only two components: the tensor 1 and the tensor [1, 2, 3]. Element components can be of any type representable by tf.TypeSpec, including tf.Tensor, tf.data.Dataset, tf.sparse.SparseTensor, tf.RaggedTensor, and tf.TensorArray.
a = 1 # Integer element
b = 2.0 # Float element
c = (1, 2) # Tuple element with 2 components
d = {"a": (2, 2), "b": 3} # Dict element with 3 components
import collections
Point = collections.namedtuple("Point", ["x", "y"])
e = Point(1, 2) # Named tuple
f = tf.data.Dataset.range(10) # Dataset element
Args | |
---|---|
variant_tensor | A DT_VARIANT tensor that represents the dataset. |
Attributes | |
---|---|
element_spec | The type specification of an element of this dataset. |
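To make the element and component terminology concrete, the following sketch inspects element_spec for a dataset whose elements are dictionaries containing a tuple. It assumes eager execution; the input values are arbitrary.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    {"a": ([1, 2], [3.0, 4.0]), "b": [5, 6]})

# `element_spec` mirrors the nested structure of one element.
spec = dataset.element_spec
print(spec)

# Flattening the structure lists its leaves, i.e. the components.
components = tf.nest.flatten(spec)
print(len(components))
```

Here each element is a dict with a two-component tuple under "a" and one component under "b", so the flattened spec has three entries.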
Methods
apply
apply(
transformation_func
)
Applies a transformation function to this dataset.
apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.
dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())
[0, 1, 2, 3, 4]
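Because a transformation is just a function from Dataset to Dataset, it can be built by another function and chained. The sketch below shows one possible way to do this; `keep_below` and `double` are hypothetical helpers, not part of the tf.data API.

```python
import tensorflow as tf

def keep_below(threshold):
  # Hypothetical factory: returns a Dataset -> Dataset function.
  def transformation(ds):
    return ds.filter(lambda x: x < threshold)
  return transformation

def double(ds):
  # Hypothetical transformation: multiplies every element by 2.
  return ds.map(lambda x: x * 2)

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(keep_below(5)).apply(double)

result = [int(v) for v in dataset.as_numpy_iterator()]
print(result)
```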
Args | |
---|---|
transformation_func | A function that takes one Dataset argument and returns a Dataset. |
Returns | |
---|---|
Dataset | The Dataset returned by applying transformation_func to this dataset. |
as_numpy_iterator
as_numpy_iterator()
Returns an iterator which converts all elements of the dataset to numpy.
Use as_numpy_iterator to inspect the content of your dataset. To see element shapes and types, print dataset elements directly instead of using as_numpy_iterator.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
This method requires that you are running in eager mode and the dataset's element_spec contains only TensorSpec components.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)
1
2
3
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
print(list(dataset.as_numpy_iterator()))
[1, 2, 3]
as_numpy_iterator() will preserve the nested structure of dataset elements.
dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
'b': [5, 6]})
list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
{'a': (2, 4), 'b': 6}]
True
Returns | |
---|---|
An iterable over the elements of the dataset, with their tensors converted to numpy arrays. |
Raises | |
---|---|
TypeError | if an element contains a non-Tensor value. |
RuntimeError | if eager execution is not enabled. |
batch
batch(
batch_size, drop_remainder=False
)
Combines consecutive elements of this dataset into batches.
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]
The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.
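The effect of drop_remainder on the outer dimension also shows up in the static shape reported by element_spec. The sketch below compares the two settings on a dataset of N = 8 elements with batch_size = 3; it assumes eager execution.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)

with_remainder = dataset.batch(3)
without_remainder = dataset.batch(3, drop_remainder=True)

# Without drop_remainder, the outer dimension is statically unknown,
# because the last batch may be smaller than batch_size.
print(with_remainder.element_spec.shape)

# With drop_remainder=True, every batch has exactly batch_size elements.
print(without_remainder.element_spec.shape)

# Actual batch sizes at iteration time: the last one is 8 % 3 = 2.
sizes = [int(batch.shape[0]) for batch in with_remainder]
print(sizes)
```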
Args | |
---|---|
batch_size | A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch. |
drop_remainder | (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch. |
Returns | |
---|---|
Dataset | A Dataset. |
cache
cache(
filename=''
)
Caches the elements in this dataset.
The first time the dataset is iterated over, its elements will be cached either in the specified file or in memory. Subsequent iterations will use the cached data.
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache()
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())
[0, 1, 4, 9, 16]
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())
[0, 1, 4, 9, 16]
When caching to a file, the cached data will persist across runs. Even the first iteration through the data will read from the cache file. Changing the input pipeline before the call to .cache() will have no effect until the cache file is removed or the filename is changed.
dataset = tf.data.Dataset.range(5)
dataset = dataset.cache("/path/to/file") # doctest: +SKIP
list(dataset.as_numpy_iterator()) # doctest: +SKIP
[0, 1, 2, 3, 4]
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache("/path/to/file") # Same file! # doctest: +SKIP
list(dataset.as_numpy_iterator()) # doctest: +SKIP
[0, 1, 2, 3, 4]
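One way to observe in-memory caching is to count how often the upstream source is actually run. The sketch below is an illustration under assumptions: the generator `source` and the `calls` counter are hypothetical, and the code relies on the first pass consuming the dataset completely so that the cache is fully populated.

```python
import tensorflow as tf

calls = {"count": 0}

def source():
  # Hypothetical generator; records each time it is started.
  calls["count"] += 1
  for i in range(3):
    yield i

dataset = tf.data.Dataset.from_generator(
    source, output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))
dataset = dataset.cache()

# The first full pass runs the generator and fills the cache.
first = [int(v) for v in dataset.as_numpy_iterator()]
# The second pass should be served from the cache.
second = [int(v) for v in dataset.as_numpy_iterator()]

print(first, second, calls["count"])
```

If the cache is working as described, both passes yield the same values while the generator is started only once.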
Args | |
---|---|
filename | A tf.string scalar tf.Tensor, representing the name of a directory on the filesystem to use for caching elements in this Dataset. If a filename is not provided, the dataset will be cached in memory. |
Returns | |
---|---|
Dataset | A Dataset. |
cardinality
cardinality()
Returns the cardinality of the dataset, if known.
cardinality may return tf.data.INFINITE_CARDINALITY if the dataset contains an infinite number of elements or tf.data.UNKNOWN_CARDINALITY if the analysis fails to determine the number of elements in the dataset (e.g. when the dataset source is a file).
dataset = tf.data.Dataset.range(42)
print(dataset.cardinality().numpy())
42
dataset = dataset.repeat()
cardinality = dataset.cardinality()
print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
dataset = dataset.filter(lambda x: True)
cardinality = dataset.cardinality()
print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True
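Code that consumes cardinality usually needs to handle the two named constants before treating the result as a count. The sketch below shows one possible pattern; `describe_cardinality` is a hypothetical helper and assumes eager execution so that `.numpy()` is available.

```python
import tensorflow as tf

def describe_cardinality(dataset):
  # Hypothetical helper: maps the cardinality to a readable string.
  n = dataset.cardinality().numpy()
  if n == tf.data.INFINITE_CARDINALITY:
    return "infinite"
  if n == tf.data.UNKNOWN_CARDINALITY:
    return "unknown"
  return str(n)

finite = tf.data.Dataset.range(42)
infinite = finite.repeat()
unknown = finite.filter(lambda x: True)

descriptions = [describe_cardinality(ds) for ds in (finite, infinite, unknown)]
print(descriptions)
```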
Returns | |
---|---|
A scalar tf.int64 Tensor representing the cardinality of the dataset. If the cardinality is infinite or unknown, cardinality returns the named constants tf.data.INFINITE_CARDINALITY and tf.data.UNKNOWN_CARDINALITY respectively. |
concatenate
concatenate(
dataset
)
Creates a Dataset by concatenating the given dataset with this dataset.
a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())
[1, 2, 3, 4, 5, 6, 7]
# The input dataset and dataset to be concatenated should have the same
# nested structures and output types.
c = tf.data.Dataset.zip((a, b))
a.concatenate(c)
Traceback (most recent call last):
TypeError: Two datasets to concatenate have different types
<dtype: 'int64'> and (tf.int64, tf.int64)
d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
a.concatenate(d)
Traceback (most recent call last):
TypeError: Two datasets to concatenate have different types
<dtype: 'int64'> and <dtype: 'string'>
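When two datasets differ only in dtype, one option is to compare their element_spec values and cast before concatenating. The sketch below is an illustration, not a prescribed approach; the dataset `other` and the choice to cast to tf.int64 are assumptions made for the example.

```python
import tensorflow as tf

a = tf.data.Dataset.range(1, 4)                       # int64 elements
other = tf.data.Dataset.from_tensor_slices([10, 20])  # int32 elements

# The specs differ (int64 vs. int32), so align the dtype first.
if a.element_spec != other.element_spec:
  other = other.map(lambda x: tf.cast(x, tf.int64))

combined = [int(v) for v in a.concatenate(other).as_numpy_iterator()]
print(combined)
```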
Args | |
---|---|
|