Class Dataset
Represents a potentially large set of elements.
Inherits From: Dataset
Aliases:
- Class tf.compat.v1.data.Dataset
A Dataset can be used to represent an input pipeline as a
collection of elements (nested structures of tensors) and a "logical
plan" of transformations that act on those elements.
__init__
__init__()
Properties
output_classes
Returns the class of each component of an element of this dataset. (deprecated)
The expected values are tf.Tensor and tf.SparseTensor.
Returns:
A nested structure of Python type objects corresponding to each
component of an element of this dataset.
output_shapes
Returns the shape of each component of an element of this dataset. (deprecated)
Returns:
A nested structure of tf.TensorShape objects corresponding to each
component of an element of this dataset.
output_types
Returns the type of each component of an element of this dataset. (deprecated)
Returns:
A nested structure of tf.DType objects corresponding to each component
of an element of this dataset.
Methods
__iter__
__iter__()
Creates an Iterator for enumerating the elements of this dataset.
The returned iterator implements the Python iterator protocol and therefore can only be used in eager mode.
Returns:
An Iterator over the elements of this dataset.
Raises:
RuntimeError: If eager execution is not enabled.
apply
apply(transformation_func)
Applies a transformation function to this dataset.
apply enables chaining of custom Dataset transformations, which are
represented as functions that take one Dataset argument and return a
transformed Dataset.
For example:
dataset = (dataset.map(lambda x: x ** 2)
           .apply(group_by_window(key_func, reduce_func, window_size))
           .map(lambda x: x ** 3))
Args:
transformation_func: A function that takes one Dataset argument and returns a Dataset.
Returns:
Dataset: The Dataset returned by applying transformation_func to this dataset.
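The chaining contract described above can be sketched in plain Python. ToyDataset and keep_every are hypothetical names introduced only for this illustration; they are not part of TensorFlow and model nothing beyond the "function from dataset to dataset" idea.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset; wraps a list of elements."""

    def __init__(self, elements):
        self.elements = list(elements)

    def map(self, fn):
        return ToyDataset(fn(x) for x in self.elements)

    def apply(self, transformation_func):
        # Mirrors the documented contract: the function receives the
        # dataset itself and must hand back a dataset.
        result = transformation_func(self)
        if not isinstance(result, ToyDataset):
            raise TypeError("transformation_func must return a dataset")
        return result


def keep_every(n):
    """Builds a reusable transformation that keeps every n-th element."""
    def transformation(dataset):
        return ToyDataset(dataset.elements[::n])
    return transformation


pipeline = (ToyDataset([1, 2, 3, 4])
            .map(lambda x: x ** 2)
            .apply(keep_every(2))
            .map(lambda x: x + 1))
print(pipeline.elements)  # [2, 10]
```

Packaging a transformation as a function lets it sit in the middle of a method chain without having to be a method of the dataset class.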
batch
batch(
batch_size,
drop_remainder=False
)
Combines consecutive elements of this dataset into batches.
The tensors in the resulting element will have an additional outer
dimension, which will be batch_size (or N % batch_size for the last
element if batch_size does not divide the number of input elements N
evenly and drop_remainder is False). If your program depends on the
batches having the same outer dimension, you should set the drop_remainder
argument to True to prevent the smaller batch from being produced.
Args:
batch_size: A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
Returns:
Dataset: A Dataset.
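The remainder rule can be checked with a small plain-Python model. This sketch assumes list elements rather than tensors, and the helper name batch is reused here only for readability; it is not the TensorFlow implementation.

```python
def batch(elements, batch_size, drop_remainder=False):
    """Groups consecutive elements into lists of length batch_size."""
    elements = list(elements)
    batches = [elements[i:i + batch_size]
               for i in range(0, len(elements), batch_size)]
    # The final batch holds N % batch_size elements when N is not a
    # multiple of batch_size; drop it only if the caller asked to.
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


print(batch(range(7), 3))                       # [[0, 1, 2], [3, 4, 5], [6]]
print(batch(range(7), 3, drop_remainder=True))  # [[0, 1, 2], [3, 4, 5]]
```

With N = 7 and batch_size = 3 the last batch has 7 % 3 = 1 element, which is the case drop_remainder=True removes.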
cache
cache(filename='')
Caches the elements in this dataset.
Args:
filename: A tf.string scalar tf.Tensor, representing the name of a directory on the filesystem to use for caching tensors in this Dataset. If a filename is not provided, the dataset will be cached in memory.
Returns:
Dataset: A Dataset.
concatenate
concatenate(dataset)
Creates a Dataset by concatenating the given dataset with this dataset.
a = Dataset.range(1, 4) # ==> [ 1, 2, 3 ]
b = Dataset.range(4, 8) # ==> [ 4, 5, 6, 7 ]
# Input dataset and dataset to be concatenated should have same
# nested structures and output types.
# c = Dataset.range(8, 14).batch(2) # ==> [ [8, 9], [10, 11], [12, 13] ]
# d = Dataset.from_tensor_slices([14.0, 15.0, 16.0])
# a.concatenate(c) and a.concatenate(d) would result in error.
a.concatenate(b) # ==> [ 1, 2, 3, 4, 5, 6, 7 ]
Args:
dataset: Dataset to be concatenated.
Returns:
Dataset: A Dataset.
enumerate
enumerate(start=0)
Enumerates the elements of this dataset.
It is similar to Python's enumerate.
For example:
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { (7, 8), (9, 10) }
# Each element is paired with its index, counting up from `start`.
a.enumerate(start=5) == { (5, 1), (6, 2), (7, 3) }
b.enumerate() == { (0, (7, 8)), (1, (9, 10)) }
Args:
start: A tf.int64 scalar tf.Tensor, representing the start value for enumeration.
Returns:
Dataset: A Dataset.
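Because the behavior is modelled on Python's built-in enumerate, the two examples above can be reproduced with ordinary lists. The helper enumerate_elements is a hypothetical name used only in this sketch.

```python
def enumerate_elements(elements, start=0):
    """Pairs each element with a running index, beginning at start."""
    return [(index, element) for index, element in enumerate(elements, start)]


a = [1, 2, 3]
b = [(7, 8), (9, 10)]
print(enumerate_elements(a, start=5))  # [(5, 1), (6, 2), (7, 3)]
print(enumerate_elements(b))           # [(0, (7, 8)), (1, (9, 10))]
```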
filter
filter(predicate)
Filters this dataset according to predicate.
d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
d = d.filter(lambda x: x < 3) # ==> [1, 2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
    return tf.math.equal(x, 1)
d = d.filter(filter_fn) # ==> [1]
Args:
predicate: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a scalar tf.bool tensor.
Returns:
Dataset: The Dataset containing the elements of this dataset for which predicate is True.
filter_with_legacy_function
filter_with_legacy_function(predicate)
Filters this dataset according to predicate. (deprecated)
NOTE: This is an escape hatch for existing uses of filter that do not work
with V2 functions. New uses are strongly discouraged and existing uses
should migrate to filter as this method will be removed in V2.
Args:
predicate: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a scalar tf.bool tensor.
Returns:
Dataset: The Dataset containing the elements of this dataset for which predicate is True.
flat_map
flat_map(map_func)
Maps map_func across this dataset and flattens the result.
Use flat_map if you want to make sure that the order of your dataset
stays the same. For example, to flatten a dataset of batches into a
dataset of their elements:
a = Dataset.from_tensor_slices([ [1, 2, 3], [4, 5, 6], [7, 8, 9] ])
a.flat_map(lambda x: Dataset.from_tensor_slices(x + 1)) # ==>
# [ 2, 3, 4, 5, 6, 7, 8, 9, 10 ]
tf.data.Dataset.interleave() is a generalization of flat_map, since
flat_map produces the same output as
tf.data.Dataset.interleave(cycle_length=1).
Args:
map_func: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a Dataset.
Returns:
Dataset: A Dataset.
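The order-preserving flattening can be modelled with itertools.chain. This is a plain-Python sketch under the assumption that datasets are ordinary lists; the helper name flat_map mirrors the method purely for readability.

```python
from itertools import chain


def flat_map(elements, map_func):
    # Visit the inputs in order and exhaust each mapped result before
    # moving on, so the overall element order is preserved.
    return list(chain.from_iterable(map_func(x) for x in elements))


a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
result = flat_map(a, lambda row: [value + 1 for value in row])
print(result)  # [2, 3, 4, 5, 6, 7, 8, 9, 10]
```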
from_generator
@staticmethod
from_generator(
generator,
output_types,
output_shapes=None,
args=None
)
Creates a Dataset whose elements are generated by generator.
The generator argument must be a callable object that returns
an object that supports the iter() protocol (e.g. a generator function).
The elements generated by generator must be compatible with the given
output_types and (optional) output_shapes arguments.
For example:
import itertools

import tensorflow as tf

tf.compat.v1.enable_eager_execution()

def gen():
    # Yields (i, [1, ..., 1]) with i ones, for i = 1, 2, 3, ...
    for i in itertools.count(1):
        yield (i, [1] * i)

ds = tf.data.Dataset.from_generator(
    gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
for value in ds.take(2):
    print(value)
# (1, array([1]))
# (2, array([1, 1]))
NOTE: The current implementation of Dataset.from_generator() uses
tf.compat.v1.py_func and inherits the same constraints. In particular, it
requires the Dataset- and Iterator-related operations to be placed
on a device in the same process as the Python program that called
Dataset.from_generator(). The body of generator will not be
serialized in a GraphDef, and you should not use this method if you
need to serialize your model and restore it in a different environment.
NOTE: If generator depends on mutable global variables or other external
state, be aware that the runtime may invoke generator multiple times
(in order to support repeating the Dataset) and at any time
between the call to Dataset.from_generator() and the production of the
first element from the generator. Mutating global variables or external
state can cause undefined behavior, and we recommend that you explicitly
cache any external state in generator before calling
Dataset.from_generator().
Args:
generator: A callable object that returns an object that supports the iter() protocol. If args is not specified, generator must take no arguments; otherwise it must take as many arguments as there are values in args.
output_types: A nested structure of tf.DType objects corresponding to each component of an element yielded by generator.
output_shapes: (Optional.) A nested structure of tf.TensorShape objects corresponding to each component of an element yielded by generator.
args: (Optional.) A tuple of tf.Tensor objects that will be evaluated and passed to generator as NumPy-array arguments.
Returns:
Dataset: A Dataset.
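The warning about repeated invocation is easier to see in a plain-Python model. The sketch below assumes a pipeline that calls generator() once per pass; take_repeated and stateful_gen are hypothetical helpers, and nothing here reflects how TensorFlow schedules the calls internally.

```python
import itertools


def gen():
    for i in itertools.count(1):
        yield (i, [1] * i)


def take_repeated(generator, take, repeats):
    """Calls generator() afresh for every pass, as a repeating pipeline might."""
    out = []
    for _ in range(repeats):
        out.extend(itertools.islice(generator(), take))
    return out


# A generator without external state yields the same elements on each pass.
print(take_repeated(gen, take=2, repeats=2))
# [(1, [1]), (2, [1, 1]), (1, [1]), (2, [1, 1])]

# A generator that mutates external state yields different elements per pass.
counter = {"calls": 0}


def stateful_gen():
    counter["calls"] += 1
    yield counter["calls"]


print(take_repeated(stateful_gen, take=1, repeats=3))  # [1, 2, 3]
```

This is also why the argument must be a callable rather than an already-started iterator: a fresh iterator is needed for every pass.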
from_sparse_tensor_slices
@staticmethod
from_sparse_tensor_slices(sparse_tensor)
Splits each rank-N tf.SparseTensor in this dataset row-wise. (deprecated)
Args:
sparse_tensor: A tf.SparseTensor.
Returns:
Dataset: A Dataset of rank-(N-1) sparse tensors.
from_tensor_slices
@staticmethod
from_tensor_slices(tensors)
Creates a Dataset whose elements are slices of the given tensors.
Note that if tensors contains a NumPy array, and eager execution is not
enabled, the values will be embedded in the graph as one or more
tf.constant operations. For large datasets (> 1 GB), this can waste
memory and run into byte limits of graph serialization. If tensors
contains one or more large NumPy arrays, consider the alternative described
in this guide.
Args:
tensors: A nested structure of tensors, each having the same size in the 0th dimension.
Returns:
Dataset: A Dataset.
from_tensors
@staticmethod
from_tensors(tensors)
Creates a Dataset with a single element, comprising the given tensors.
Note that if tensors contains a NumPy array, and eager execution is not
enabled, the values will be embedded in the graph as one or more
tf.constant operations. For large datasets (> 1 GB), this can waste
memory and run into byte limits of graph serialization. If tensors
contains one or more large NumPy arrays, consider the alternative described
in this guide.
Args:
tensors: A nested structure of tensors.
Returns:
Dataset: A Dataset.
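The difference between from_tensor_slices and from_tensors comes down to whether the 0th dimension is split into elements. A plain-Python sketch, assuming nested lists in place of tensors and using the hypothetical helpers tensor_slices and single_element:

```python
features = [[1, 2], [3, 4], [5, 6]]
labels = [0, 1, 0]


def tensor_slices(*components):
    """One element per index along the 0th dimension of every component."""
    sizes = {len(component) for component in components}
    if len(sizes) != 1:
        raise ValueError("all components need the same size in dimension 0")
    return list(zip(*components))


def single_element(*components):
    """A dataset with exactly one element that holds all components whole."""
    return [tuple(components)]


print(tensor_slices(features, labels))
# [([1, 2], 0), ([3, 4], 1), ([5, 6], 0)]
print(len(single_element(features, labels)))  # 1
```

The size check mirrors the documented requirement that every tensor passed to from_tensor_slices has the same size in the 0th dimension; from_tensors has no such requirement.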
interleave
interleave(
map_func,
cycle_length,
block_length=1,
num_parallel_calls=None
)
Maps map_func across this dataset, and interleaves the results.
For example, you can use Dataset.interleave() to process many input files
concurrently:
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
# each file.
filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
dataset = (Dataset.from_tensor_slices(filenames)
           .interleave(lambda x:
               TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
               cycle_length=4, block_length=16))
The cycle_length and block_length arguments control the order in which
elements are produced. cycle_length controls the number of input elements
that are processed concurrently. If you set cycle_length to 1, this
transformation will handle one input element at a time, and will produce
identical results to tf.data.Dataset.flat_map. In general,
this transformation will apply map_func to cycle_length input elements,
open iterators on the returned Dataset objects, and cycle through them
producing block_length consecutive elements from each iterator, and
consuming the next input element each time it reaches the end of an
iterator.
For example:
a = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
             cycle_length=2, block_length=4)  # ==> [1, 1, 1, 1,
                                              #      2, 2, 2, 2,
                                              #      1, 1,
                                              #      2, 2,
                                              #      3, 3, 3, 3,
                                              #      4, 4, 4, 4,
                                              #      3, 3,
                                              #      4, 4,
                                              #      5, 5, 5, 5,
                                              #      5, 5]
NOTE: The order of elements yielded by this transformation is
deterministic, as long as map_func is a pure function. If
map_func contains any stateful operations, the order in which
that state is accessed is undefined.
Args:
map_func: A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a Dataset.
cycle_length: The number of elements from this dataset that will be processed concurrently.
block_length: The number of consecutive elements to produce from each input element before cycling to another input element.
num_parallel_calls: (Optional.) If specified, the implementation creates a threadpool, which is used to fetch inputs from cycle elements asynchronously and in parallel. The default behavior is to fetch inputs from cycle elements synchronously with no parallelism. If the value tf.data.experimental.AUTOTUNE is used, then the number of parallel calls is set dynamically based on available CPU.
Returns:
Dataset: A Dataset.
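The ordering produced by cycle_length and block_length can be traced with a sequential plain-Python model. This sketch assumes list inputs and no parallelism; the function interleave below is a hypothetical re-implementation written to match the documented example, not TensorFlow's code.

```python
def interleave(elements, map_func, cycle_length, block_length=1):
    inputs = iter(elements)
    slots = [None] * cycle_length  # one open iterator per cycle slot
    inputs_exhausted = False
    output = []
    while True:
        any_open = False
        for i in range(cycle_length):
            # Refill an empty slot with the next input element, if any.
            if slots[i] is None and not inputs_exhausted:
                try:
                    slots[i] = iter(map_func(next(inputs)))
                except StopIteration:
                    inputs_exhausted = True
            if slots[i] is None:
                continue
            any_open = True
            # Emit up to block_length consecutive elements from this slot.
            for _ in range(block_length):
                try:
                    output.append(next(slots[i]))
                except StopIteration:
                    slots[i] = None  # refilled on the next visit
                    break
        if not any_open:
            return output


result = interleave(range(1, 6), lambda x: [x] * 6,
                    cycle_length=2, block_length=4)
print(result)
```

With cycle_length=1 the same function visits one input at a time and therefore returns the flattened inputs in their original order, which matches the flat_map equivalence stated above.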
list_files
@staticmethod
list_files(
file_pattern,
shuffle=None,
seed=None
)
A dataset of all files matching one or more glob patterns.
NOTE: The default behavior of this method is to return filenames in
a non-deterministic random shuffled order. Pass a seed or shuffle=False
to get results in a deterministic order.
Example:
If we had the following files on our filesystem:
- /path/to/dir/a.txt
- /path/to/dir/b.py
- /path/to/dir/c.py
If we pass "/path/to/dir/*.py" as the file_pattern, the dataset would produce:
- /path/to/dir/b.py
- /path/to/dir/c.py
Args:
file_pattern: A string, a list of strings, or a tf.Tensor of string type (scalar or vector), representing the filename glob (i.e. shell wildcard) pattern(s) that will be matched.
shuffle: (Optional.) If True, the file names will be shuffled randomly. Defaults to True.
seed: (Optional.) A tf.int64 scalar tf.Tensor, representing the random seed that will be used to create the distribution. See tf.compat.v1.set_random_seed for behavior.
Returns:
Dataset: A Dataset of strings corresponding to file names.
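The matching and ordering behavior can be approximated with the standard library. This sketch is an assumption-laden model: it uses fnmatch (whose * also matches path separators, unlike a shell glob) and random.Random for shuffling, neither of which is claimed to be what TensorFlow uses, and it works on an in-memory list instead of the filesystem.

```python
import fnmatch
import random

files = ["/path/to/dir/a.txt", "/path/to/dir/b.py", "/path/to/dir/c.py"]


def list_files(all_files, file_pattern, shuffle=True, seed=None):
    """Returns the names matching file_pattern, optionally shuffled."""
    matches = sorted(name for name in all_files
                     if fnmatch.fnmatchcase(name, file_pattern))
    if shuffle:
        # A fixed seed makes the shuffled order reproducible.
        random.Random(seed).shuffle(matches)
    return matches


print(list_files(files, "/path/to/dir/*.py", shuffle=False))
# ['/path/to/dir/b.py', '/path/to/dir/c.py']
```

Passing shuffle=False or a fixed seed are the two documented ways to get a reproducible order; the model above reflects both.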
make_initializable_iterator
make_initializable_iterator(shared_name=None)
Creates an Iterator for enumerating the elements of this dataset. (deprecated)
dataset = ...
iterator = dataset.make_initializable_iterator()
# ...
sess.run(iterator.initializer)
Args:
shared_name: (Optional.) If non-empty, the returned iterator will be shared under the given name across multiple sessions that share the same devices (e.g. when using a remote server).
Returns:
An Iterator over the elements of this dataset.
Raises:
RuntimeError<