Represents a potentially large set of elements.
tf.data.Dataset(
variant_tensor
)
The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:
- Create a source dataset from your input data.
- Apply dataset transformations to preprocess the data.
- Iterate over the dataset and process the elements.
Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
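For example, a minimal pipeline following this pattern (a sketch using in-memory data; real pipelines typically read from files and add further preprocessing):
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])  # 1. source
dataset = dataset.map(lambda x: x * 2).batch(2)                   # 2. transformations
for batch in dataset:                                             # 3. iteration
  print(batch.numpy())
[2 4]
[6 8]
[10 12]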
Source Datasets:
The simplest way to create a dataset is to create it from a Python list:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
To process lines from files, use tf.data.TextLineDataset:
dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
To process records written in the TFRecord format, use TFRecordDataset:
dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
To create a dataset of all files matching a pattern, use tf.data.Dataset.list_files:
dataset = tf.data.Dataset.list_files("/path/*.txt") # doctest: +SKIP
See tf.data.FixedLengthRecordDataset and tf.data.Dataset.from_generator for more ways to create datasets.
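For instance, a file of fixed-size binary records can be read with tf.data.FixedLengthRecordDataset (a sketch; the filename and record size below are illustrative placeholders):
dataset = tf.data.FixedLengthRecordDataset(["file1.bin"], record_bytes=4) # doctest: +SKIP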
Transformations:
Once you have a dataset, you can apply transformations to prepare the data for your model:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x*2)
list(dataset.as_numpy_iterator())
[2, 4, 6]
Common Terms:
Element: A single output from calling next() on a dataset iterator. Elements may be nested structures containing multiple components. For example, the element (1, (3, "apple")) has one tuple nested in another tuple. The components are 1, 3, and "apple".
Component: The leaf in the nested structure of an element.
Supported types:
Elements can be nested structures of tuples, named tuples, and dictionaries. Element components can be of any type representable by tf.TypeSpec, including tf.Tensor, tf.data.Dataset, tf.SparseTensor, tf.RaggedTensor, and tf.TensorArray.
a = 1 # Integer element
b = 2.0 # Float element
c = (1, 2) # Tuple element with 2 components
d = {"a": (2, 2), "b": 3} # Dict element with 3 components
import collections # doctest: +SKIP
Point = collections.namedtuple("Point", ["x", "y"]) # doctest: +SKIP
e = Point(1, 2) # Named tuple # doctest: +SKIP
f = tf.data.Dataset.range(10) # Dataset element
| Args | |
|---|---|
| variant_tensor | A DT_VARIANT tensor that represents the dataset. |
| Attributes | |
|---|---|
| element_spec | The type specification of an element of this dataset. |
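For example:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset.element_spec
TensorSpec(shape=(), dtype=tf.int32, name=None)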
Methods
apply
apply(
transformation_func
)
Applies a transformation function to this dataset.
apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.
dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())
[0, 1, 2, 3, 4]
| Args | |
|---|---|
| transformation_func | A function that takes one Dataset argument and returns a Dataset. |
| Returns | |
|---|---|
| Dataset | The Dataset returned by applying transformation_func to this dataset. |
as_numpy_iterator
as_numpy_iterator()
Returns an iterator which converts all elements of the dataset to numpy.
Use as_numpy_iterator to inspect the content of your dataset. To see element shapes and types, print dataset elements directly instead of using as_numpy_iterator.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
This method requires that you are running in eager mode and the dataset's element_spec contains only TensorSpec components.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)
1
2
3
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
print(list(dataset.as_numpy_iterator()))
[1, 2, 3]
as_numpy_iterator() will preserve the nested structure of dataset elements.
dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
                                               'b': [5, 6]})
list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
                                      {'a': (2, 4), 'b': 6}]
True
| Returns | |
|---|---|
| An iterable over the elements of the dataset, with their tensors converted to numpy arrays. | |
| Raises | |
|---|---|
| TypeError | if an element contains a non-Tensor value. |
| RuntimeError | if eager execution is not enabled. |
batch
batch(
batch_size, drop_remainder=False
)
Combines consecutive elements of this dataset into batches.
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]
The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.
| Args | |
|---|---|
| batch_size | A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch. |
| drop_remainder | (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
cache
cache(
filename=''
)
Caches the elements in this dataset.
The first time the dataset is iterated over, its elements will be cached either in the specified file or in memory. Subsequent iterations will use the cached data.
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache()
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())
[0, 1, 4, 9, 16]
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())
[0, 1, 4, 9, 16]
When caching to a file, the cached data will persist across runs. Even the first iteration through the data will read from the cache file. Changing the input pipeline before the call to .cache() will have no effect until the cache file is removed or the filename is changed.
dataset = tf.data.Dataset.range(5)
dataset = dataset.cache("/path/to/file") # doctest: +SKIP
list(dataset.as_numpy_iterator()) # doctest: +SKIP
[0, 1, 2, 3, 4]
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache("/path/to/file") # Same file! # doctest: +SKIP
list(dataset.as_numpy_iterator()) # doctest: +SKIP
[0, 1, 2, 3, 4]
| Args | |
|---|---|
| filename | A tf.string scalar tf.Tensor, representing the name of a directory on the filesystem to use for caching elements in this Dataset. If a filename is not provided, the dataset will be cached in memory. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
concatenate
concatenate(
dataset
)
Creates a Dataset by concatenating the given dataset with this dataset.
a = tf.data.Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())
[1, 2, 3, 4, 5, 6, 7]
# The input dataset and dataset to be concatenated should have the same
# nested structures and output types.
c = tf.data.Dataset.zip((a, b))
a.concatenate(c)
Traceback (most recent call last):
TypeError: Two datasets to concatenate have different types
<dtype: 'int64'> and (tf.int64, tf.int64)
d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
a.concatenate(d)
Traceback (most recent call last):
TypeError: Two datasets to concatenate have different types
<dtype: 'int64'> and <dtype: 'string'>
| Args | |
|---|---|
| dataset | Dataset to be concatenated. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
enumerate
enumerate(
start=0
)
Enumerates the elements of this dataset.
It is similar to Python's enumerate.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
  print(element)
(5, 1)
(6, 2)
(7, 3)
# The nested structure of the input dataset determines the structure of
# elements in the resulting dataset.
dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
dataset = dataset.enumerate()
for element in dataset.as_numpy_iterator():
  print(element)
(0, array([7, 8], dtype=int32))
(1, array([ 9, 10], dtype=int32))
| Args | |
|---|---|
| start | A tf.int64 scalar tf.Tensor, representing the start value for enumeration. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
filter
filter(
predicate
)
Filters this dataset according to predicate.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.filter(lambda x: x < 3)
list(dataset.as_numpy_iterator())
[1, 2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)
dataset = dataset.filter(filter_fn)
list(dataset.as_numpy_iterator())
[1]
| Args | |
|---|---|
| predicate | A function mapping a dataset element to a boolean. |
| Returns | |
|---|---|
| Dataset | The Dataset containing the elements of this dataset for which predicate is True. |
flat_map
flat_map(
map_func
)
Maps map_func across this dataset and flattens the result.
Use flat_map if you want to make sure that the order of your dataset stays the same. For example, to flatten a dataset of batches into a dataset of their elements:
dataset = Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dataset = dataset.flat_map(lambda x: Dataset.from_tensor_slices(x))
list(dataset.as_numpy_iterator())
[1, 2, 3, 4, 5, 6, 7, 8, 9]
tf.data.Dataset.interleave() is a generalization of flat_map, since flat_map produces the same output as tf.data.Dataset.interleave(cycle_length=1).
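For example, the flat_map pipeline above can be expressed with interleave (shown here as an equivalent sketch):
dataset = tf.data.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensor_slices(x), cycle_length=1)
list(dataset.as_numpy_iterator())
[1, 2, 3, 4, 5, 6, 7, 8, 9]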
| Args | |
|---|---|
| map_func | A function mapping a dataset element to a dataset. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
from_generator
@staticmethod
from_generator( generator, output_types, output_shapes=None, args=None )
Creates a Dataset whose elements are generated by generator.
The generator argument must be a callable object that returns an object that supports the iter() protocol (e.g. a generator function). The elements generated by generator must be compatible with the given output_types and (optional) output_shapes arguments.
import itertools
def gen():
  for i in itertools.count(1):
    yield (i, [1] * i)
dataset = tf.data.Dataset.from_generator(
    gen,
    (tf.int64, tf.int64),
    (tf.TensorShape([]), tf.TensorShape([None])))
list(dataset.take(3).as_numpy_iterator())
[(1, array([1])), (2, array([1, 1])), (3, array([1, 1, 1]))]
| Args | |
|---|---|
| generator | A callable object that returns an object that supports the iter() protocol. If args is not specified, generator must take no arguments; otherwise it must take as many arguments as there are values in args. |
| output_types | A nested structure of tf.DType objects corresponding to each component of an element yielded by generator. |
| output_shapes | (Optional.) A nested structure of tf.TensorShape objects corresponding to each component of an element yielded by generator. |
| args | (Optional.) A tuple of tf.Tensor objects that will be evaluated and passed to generator as NumPy-array arguments. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
from_tensor_slices
@staticmethod
from_tensor_slices( tensors )
Creates a Dataset whose elements are slices of the given tensors.
The given tensors are sliced along their first dimension. This operation preserves the structure of the input tensors, removing the first dimension of each tensor and using it as the dataset dimension. All input tensors must have the same size in their first dimensions.
# Slicing a 1D tensor produces scalar tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
list(dataset.as_numpy_iterator())
[1, 2, 3]
# Slicing a 2D tensor produces 1D tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
list(dataset.as_numpy_iterator())
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
# Slicing a tuple of 1D tensors produces tuple elements containing
# scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
list(dataset.as_numpy_iterator())
[(1, 3, 5), (2, 4, 6)]
# Dictionary structure is also preserved.
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
                                      {'a': 2, 'b': 4}]
True
# Two tensors can be combined into one Dataset object.
features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
labels = tf.constant(['A', 'B', 'A']) # ==> 3-element tensor
dataset = Dataset.from_tensor_slices((features, labels))
# Both the features and the labels tensors can be converted
# to a Dataset object separately and combined after.
features_dataset = Dataset.from_tensor_slices(features)
labels_dataset = Dataset.from_tensor_slices(labels)
dataset = Dataset.zip((features_dataset, labels_dataset))
# A batched feature and label set can be converted to a Dataset
# in similar fashion.
batched_features = tf.constant([[[1, 3], [2, 3]],
                                [[2, 1], [1, 2]],
                                [[3, 3], [3, 2]]], shape=(3, 2, 2))
batched_labels = tf.constant([['A', 'A'],
                              ['B', 'B'],
                              ['A', 'B']], shape=(3, 2, 1))
dataset = Dataset.from_tensor_slices((batched_features, batched_labels))
for element in dataset.as_numpy_iterator():
  print(element)
(array([[1, 3],
[2, 3]], dtype=int32), array([[b'A'],
[b'A']], dtype=object))
(array([[2, 1],
[1, 2]], dtype=int32), array([[b'B'],
[b'B']], dtype=object))
(array([[3, 3],
[3, 2]], dtype=int32), array([[b'A'],
[b'B']], dtype=object))
Note that if tensors contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more tf.constant operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in this guide.
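One common workaround (a sketch only, and not necessarily the exact approach in the linked guide) is to feed the array through from_generator so the values are produced at runtime rather than embedded in the graph:
import numpy as np
big_array = np.random.rand(1000, 10)  # stand-in for a large array
dataset = tf.data.Dataset.from_generator(
    lambda: iter(big_array), output_types=tf.float64,
    output_shapes=tf.TensorShape([10]))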
| Args | |
|---|---|
| tensors | A dataset element, with each component having the same size in the first dimension. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
from_tensors
@staticmethod
from_tensors( tensors )
Creates a Dataset with a single element, comprising the given tensors.
from_tensors produces a dataset containing only a single element. To slice the input tensor into multiple elements, use from_tensor_slices instead.
dataset = tf.data.Dataset.from_tensors([1, 2, 3])
list(dataset.as_numpy_iterator())
[array([1, 2, 3], dtype=int32)]
dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
list(dataset.as_numpy_iterator())
[(array([1, 2, 3], dtype=int32), b'A')]
# You can use `from_tensors` to produce a dataset which repeats
# the same example many times.
example = tf.constant([1,2,3])
dataset = tf.data.Dataset.from_tensors(example).repeat(2)
list(dataset.as_numpy_iterator())
[array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]
Note that if tensors contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more tf.constant operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in this guide.
| Args | |
|---|---|
| tensors | A dataset element. |
| Returns | |
|---|---|
| Dataset | A Dataset. |
interleave
interleave(
map_func, cycle_length=AUTOTUNE, block_length=1, num_parallel_calls=None,
deterministic=None
)
Maps map_func across this dataset, and interleaves the results.
For example, you can use Dataset.interleave() to process many input files concurrently:
# Preprocess 4 files concurrently, and interleave blocks of 16 records
# from each file.