tf.data.Dataset

Represents a potentially large set of elements.

The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:

  1. Create a source dataset from your input data.
  2. Apply dataset transformations to preprocess the data.
  3. Iterate over the dataset and process the elements.

Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
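The three-step pattern and its streaming behavior can be illustrated without TensorFlow. The sketch below is a plain-Python analogy, not how tf.data is implemented. The helpers source and map_stage are hypothetical generators, so each element is produced only when the consumer asks for it.

```python
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
S = TypeVar("S")


def source(n: int) -> Iterator[int]:
    """Hypothetical source stage: yields 0, 1, ..., n - 1 lazily."""
    for i in range(n):
        yield i


def map_stage(fn: Callable[[T], S], upstream: Iterable[T]) -> Iterator[S]:
    """Hypothetical transformation stage: applies fn as elements stream through."""
    for element in upstream:
        yield fn(element)


# 1. Create a source, 2. transform it, 3. iterate over the result.
# The source nominally has 10**12 elements, but only the three that are
# requested are ever computed, so the full sequence is never held in memory.
pipeline = map_stage(lambda x: x * 2, source(10**12))
first_three = [next(pipeline) for _ in range(3)]
print(first_three)  # [0, 2, 4]
```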

Source Datasets:

The simplest way to create a dataset is to create it from a Python list:

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

To process lines from files, use tf.data.TextLineDataset:

dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])

To process records written in the TFRecord format, use tf.data.TFRecordDataset:

dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])

To create a dataset of all files matching a pattern, use tf.data.Dataset.list_files:

dataset = tf.data.Dataset.list_files("/path/*.txt")

See tf.data.FixedLengthRecordDataset and tf.data.Dataset.from_generator for more ways to create datasets.

Transformations:

Once you have a dataset, you can apply transformations to prepare the data for your model:

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.map(lambda x: x*2)
list(dataset.as_numpy_iterator())
[2, 4, 6]

Common Terms:

Element: A single output from calling next() on a dataset iterator. Elements may be nested structures containing multiple components. For example, the element (1, (3, "apple")) has one tuple nested in another tuple. The components are 1, 3, and "apple".

Component: The leaf in the nested structure of an element.

Supported types:

Elements can be nested structures of tuples, named tuples, and dictionaries. Note that Python lists are not treated as nested structures of components. Instead, lists are converted to tensors and treated as components. For example, the element (1, [1, 2, 3]) has only two components: the tensor 1 and the tensor [1, 2, 3]. Element components can be of any type representable by tf.TypeSpec, including tf.Tensor, tf.data.Dataset, tf.sparse.SparseTensor, tf.RaggedTensor, and tf.TensorArray.

a = 1 # Integer element
b = 2.0 # Float element
c = (1, 2) # Tuple element with 2 components
d = {"a": (2, 2), "b": 3} # Dict element with 3 components
Point = collections.namedtuple("Point", ["x", "y"])
e = Point(1, 2) # Named tuple
f = tf.data.Dataset.range(10) # Dataset element
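The difference between elements and components can be sketched in plain Python. The hypothetical components helper below treats tuples (including named tuples) and dicts as nesting, and treats everything else, including a Python list, as a single leaf. Visiting dict values in sorted-key order is an assumption made here for determinism.

```python
import collections
from typing import Any, List


def components(element: Any) -> List[Any]:
    """Hypothetical helper: returns the leaf components of an element."""
    if isinstance(element, dict):
        # Dict values are visited in sorted-key order (an assumption).
        return [leaf for key in sorted(element) for leaf in components(element[key])]
    if isinstance(element, tuple):
        # Plain tuples and named tuples are both nested structure.
        return [leaf for item in element for leaf in components(item)]
    # Anything else, including a Python list, is a single component.
    return [element]


Point = collections.namedtuple("Point", ["x", "y"])
print(components((1, (3, "apple"))))           # [1, 3, 'apple']
print(components((1, [1, 2, 3])))              # [1, [1, 2, 3]]
print(len(components({"a": (2, 2), "b": 3})))  # 3
print(components(Point(1, 2)))                 # [1, 2]
```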

For more information, read this guide.

Args
variant_tensor A DT_VARIANT tensor that represents the dataset.

Attributes
element_spec The type specification of an element of this dataset.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset.element_spec
TensorSpec(shape=(), dtype=tf.int32, name=None)

For more information, read this guide.

Methods

apply

Applies a transformation function to this dataset.

apply enables chaining of custom Dataset transformations, which are represented as functions that take one Dataset argument and return a transformed Dataset.

dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
  return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())
[0, 1, 2, 3, 4]

Args
transformation_func A function that takes one Dataset argument and returns a Dataset.

Returns
Dataset The Dataset returned by applying transformation_func to this dataset.

as_numpy_iterator

Returns an iterator which converts all elements of the dataset to numpy.

Use as_numpy_iterator to inspect the content of your dataset. To see element shapes and types, print dataset elements directly instead of using as_numpy_iterator.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

This method requires that you are running in eager mode and the dataset's element_spec contains only TensorSpec components.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
  print(element)
1
2
3
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
print(list(dataset.as_numpy_iterator()))
[1, 2, 3]

as_numpy_iterator() will preserve the nested structure of dataset elements.

dataset = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]),
                                              'b': [5, 6]})
list(dataset.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
                                      {'a': (2, 4), 'b': 6}]
True

Returns
An iterable over the elements of the dataset, with their tensors converted to numpy arrays.

Raises
TypeError if an element contains a non-Tensor value.
RuntimeError if eager execution is not enabled.

batch

Combines consecutive elements of this dataset into batches.

dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]

The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.
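The remainder rule can be modeled in plain Python. The batch function below is a hypothetical sketch of the documented semantics, not TensorFlow's code. It groups consecutive elements and either keeps or drops the final N % batch_size leftovers.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def batch(elements: Iterable[T], batch_size: int,
          drop_remainder: bool = False) -> List[List[T]]:
    """Hypothetical pure-Python model of Dataset.batch (not TensorFlow's code)."""
    batches: List[List[T]] = []
    current: List[T] = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    # If N is not a multiple of batch_size, N % batch_size elements remain.
    if current and not drop_remainder:
        batches.append(current)
    return batches


print(batch(range(8), 3))                       # [[0, 1, 2], [3, 4, 5], [6, 7]]
print(batch(range(8), 3, drop_remainder=True))  # [[0, 1, 2], [3, 4, 5]]
```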

Args
batch_size A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
drop_remainder (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.
num_parallel_calls (Optional.) A tf.int64 scalar tf.Tensor, representing the number of batches to compute asynchronously in parallel. If not specified, batches will be computed sequentially. If the value tf.data.AUTOTUNE is used, then the number of parallel calls is set dynamically based on available resources.
deterministic (Optional.) When num_parallel_calls is specified, if this boolean is specified (True or False), it controls the order in which the transformation produces elements. If set to False, the transformation is allowed to yield elements out of order to trade determinism for performance. If not specified, the tf.data.Options.experimental_deterministic option (True by default) controls the behavior.

Returns
Dataset A Dataset.

bucket_by_sequence_length

A transformation that buckets elements in a Dataset by length.

Elements of the Dataset are grouped together by length and then are padded and batched.

This is useful for sequence tasks in which the elements have variable length. Grouping together elements that have similar lengths reduces the total fraction of padding in a batch which increases training step efficiency.

The example below sorts the input data into 3 buckets, "[0, 3), [3, 5), [5, inf)", based on sequence length, with a batch size of 2 for each bucket.

elements = [
  [0], [1, 2, 3, 4], [5, 6, 7],
  [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
dataset = tf.data.Dataset.from_generator(
    lambda: elements, tf.int64, output_shapes=[None])
dataset = dataset.bucket_by_sequence_length(
        element_length_func=lambda elem: tf.shape(elem)[0],
        bucket_boundaries=[3, 5],
        bucket_batch_sizes=[2, 2, 2])
for elem in dataset.as_numpy_iterator():
  print(elem)
[[1 2 3 4]
[5 6 7 0]]
[[ 7  8  9 10 11  0]
[13 14 15 16 19 20]]
[[ 0  0]
[21 22]]
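The bucketing in the example above can be reproduced with a plain-Python model. An element of length L goes into bucket bisect_right(bucket_boundaries, L). Each bucket emits a batch, padded with 0 to its longest sequence (the pad_to_bucket_boundary=False behavior), as soon as it holds bucket_batch_sizes[i] elements. The helper names are hypothetical, and flushing leftover buckets in bucket-index order is an assumption. This is not TensorFlow's implementation.

```python
import bisect
from typing import List, Sequence


def bucket_index(length: int, boundaries: Sequence[int]) -> int:
    """Bucket i covers [boundaries[i-1], boundaries[i]); the last one is unbounded."""
    return bisect.bisect_right(boundaries, length)


def pad_batch(batch: List[Sequence[int]], value: int = 0) -> List[List[int]]:
    """Pad every sequence to the longest length in this batch."""
    width = max(len(seq) for seq in batch)
    return [list(seq) + [value] * (width - len(seq)) for seq in batch]


def bucket_by_length(sequences: Sequence[Sequence[int]], boundaries: Sequence[int],
                     batch_sizes: Sequence[int]) -> List[List[List[int]]]:
    """Hypothetical model of bucket_by_sequence_length with default padding."""
    if len(batch_sizes) != len(boundaries) + 1:
        raise ValueError("need len(bucket_boundaries) + 1 batch sizes")
    buckets: List[List[Sequence[int]]] = [[] for _ in batch_sizes]
    out = []
    for seq in sequences:
        i = bucket_index(len(seq), boundaries)
        buckets[i].append(seq)
        if len(buckets[i]) == batch_sizes[i]:
            # A bucket emits a padded batch as soon as it is full.
            out.append(pad_batch(buckets[i]))
            buckets[i] = []
    # Flush partially filled buckets (assumed order: by bucket index).
    out.extend(pad_batch(b) for b in buckets if b)
    return out


elements = [[0], [1, 2, 3, 4], [5, 6, 7],
            [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
batches = bucket_by_length(elements, [3, 5], [2, 2, 2])
for b in batches:
    print(b)
# [[1, 2, 3, 4], [5, 6, 7, 0]]
# [[7, 8, 9, 10, 11, 0], [13, 14, 15, 16, 19, 20]]
# [[0, 0], [21, 22]]
```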

Args
element_length_func A function mapping an element of the Dataset to a tf.int32 scalar giving its length, which determines the bucket the element goes into.
bucket_boundaries list<int>, upper length boundaries of the buckets.
bucket_batch_sizes list<int>, batch size per bucket. Length should be len(bucket_boundaries) + 1.
padded_shapes Nested structure of tf.TensorShape to pass to tf.data.Dataset.padded_batch. If not provided, will use dataset.output_shapes, which will result in variable length dimensions being padded out to the maximum length in each batch.
padding_values Values to pad with, passed to tf.data.Dataset.padded_batch. Defaults to padding with 0.
pad_to_bucket_boundary bool, if False, will pad dimensions with unknown size to maximum length in batch. If True, will pad dimensions with unknown size to bucket boundary minus 1 (i.e., the maximum length in each bucket), and caller must ensure that the source Dataset does not contain any elements with length longer than max(bucket_boundaries).
no_padding bool, indicates whether to skip padding the batch features. If True, the features must either be of type tf.sparse.SparseTensor or all have the same shape.
drop_remainder (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.

Returns
A Dataset.

Raises
ValueError if len(bucket_batch_sizes) != len(bucket_boundaries) + 1.

cache

Caches the elements in this dataset.

The first time the dataset is iterated over, its elements will be cached either in the specified file or in memory. Subsequent iterations will use the cached data.

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(lambda x: x**2)
dataset = dataset.cache()
# The first time reading through the data will generate the data using
# `range` and `map`.
list(dataset.as_numpy_iterator())
[0, 1, 4, 9, 16]
# Subsequent iterations read from the cache.
list(dataset.as_numpy_iterator())
[0, 1, 4, 9, 16]

When caching to a file, the cached data will persist across runs. Even the first iteration through the data will read from the cache file. Changing the input pipeline before the call to .cache() will have no effect until the cache file is removed or the filename is changed.

dataset = tf.data.Dataset.range(5)
dataset = dataset.cache("/path/to/file")
list(dataset.as_numpy_iterator())
# [0, 1, 2, 3, 4]
dataset = tf.data.Dataset.range(10)
dataset = dataset.cache("/path/to/file")  # Same file!
list(dataset.as_numpy_iterator())
# [0, 1, 2, 3, 4]
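Why a stale cache file wins can be modeled with a hypothetical cached helper that stores the elements as JSON. Once the file exists, the upstream pipeline is ignored, so changing it has no effect. TensorFlow's on-disk cache format is different; this sketch only illustrates the staleness rule.

```python
import json
import os
import tempfile
from typing import Callable, Iterable, List


def cached(make_elements: Callable[[], Iterable[int]], path: str) -> List[int]:
    """Hypothetical model of file caching: once `path` exists, the upstream
    pipeline (`make_elements`) is ignored and the stored elements are returned."""
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    elements = list(make_elements())  # first run: compute the elements...
    with open(path, "w") as f:
        json.dump(elements, f)        # ...and persist them for later runs
    return elements


cache_path = os.path.join(tempfile.mkdtemp(), "cache.json")
first = cached(lambda: range(5), cache_path)    # computed and written
second = cached(lambda: range(10), cache_path)  # same file: stale data is read
print(first, second)  # [0, 1, 2, 3, 4] [0, 1, 2, 3, 4]
```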

Args
filename A tf.string scalar tf.Tensor, representing the name of a directory on the filesystem to use for caching elements in this Dataset. If a filename is not provided, the dataset will be cached in memory.

Returns
Dataset A Dataset.

cardinality

Returns the cardinality of the dataset, if known.

cardinality may return tf.data.INFINITE_CARDINALITY if the dataset contains an infinite number of elements or tf.data.UNKNOWN_CARDINALITY if the analysis fails to determine the number of elements in the dataset (e.g. when the dataset source is a file).

dataset = tf.data.Dataset.range(42)
print(dataset.cardinality().numpy())
42
dataset = dataset.repeat()
cardinality = dataset.cardinality()
print((cardinality == tf.data.INFINITE_CARDINALITY).numpy())
True
dataset = dataset.filter(lambda x: True)
cardinality = dataset.cardinality()
print((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy())
True

Returns
A scalar tf.int64 Tensor representing the cardinality of the dataset. If the cardinality is infinite or unknown, cardinality returns the named constants tf.data.INFINITE_CARDINALITY and tf.data.UNKNOWN_CARDINALITY respectively.

concatenate

Creates a Dataset by concatenating the given dataset with this dataset.

a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 8)  # ==> [ 4, 5, 6, 7 ]
ds = a.concatenate(b)
list(ds.as_numpy_iterator())
[1, 2, 3, 4, 5, 6, 7]
# The input dataset and dataset to be concatenated should have
# compatible element specs.
c = tf.data.Dataset.zip((a, b))
a.concatenate(c)
Traceback (most recent call last):
TypeError: Two datasets to concatenate have different types
<dtype: 'int64'> and (tf.int64, tf.int64)
d = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])
a.concatenate(d)
Traceback (most recent call last):
TypeError: Two datasets to concatenate have different types
<dtype: 'int64'> and <dtype: 'string'>

Args
dataset Dataset to be concatenated.

Returns
Dataset A Dataset.

enumerate

Enumerates the elements of this dataset.

It is similar to Python's enumerate.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.enumerate(start=5)
for element in dataset.as_numpy_iterator():
  print(element)
(5, 1)
(6, 2)
(7, 3)
# The (nested) structure of the input dataset determines the
# structure of elements in the resulting dataset.
dataset = tf.data.Dataset.from_tensor_slices([(7, 8), (9, 10)])
dataset = dataset.enumerate()
for element in dataset.as_numpy_iterator():
  print(element)
(0, array([7, 8], dtype=int32))
(1, array([ 9, 10], dtype=int32))

Args
start A tf.int64 scalar tf.Tensor, representing the start value for enumeration.

Returns
Dataset A Dataset.

filter

Filters this dataset according to predicate.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.filter(lambda x: x < 3)
list(dataset.as_numpy_iterator())
[1, 2]
# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)
dataset = dataset.filter(filter_fn)
list(dataset.as_numpy_iterator())
[1]

Args
predicate A function mapping a dataset element to a boolean.

Returns
Dataset The Dataset containing the elements of this dataset for which predicate is True.

flat_map

Maps map_func across this dataset and flattens the result.

The type signature is:

def flat_map(
  self: Dataset[T],
  map_func: Callable[[T], Dataset[S]]
) -> Dataset[S]

Use flat_map if you want to make sure that the order of your dataset stays the same. For example, to flatten a dataset of batches into a dataset of their elements:

dataset = tf.data.Dataset.from_tensor_slices(
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensor_slices(x))
list(dataset.as_numpy_iterator())
[1, 2, 3, 4, 5, 6, 7, 8, 9]

tf.data.Dataset.interleave() is a generalization of flat_map: flat_map produces the same output as tf.data.Dataset.interleave(cycle_length=1).

Args
map_func A function mapping a dataset element to a dataset.

Returns
Dataset A Dataset.

from_generator

Creates a Dataset whose elements are generated by generator. (deprecated arguments)

The generator argument must be a callable object that returns an object that supports the iter() protocol (e.g. a generator function).

The elements generated by generator must be compatible with either the given output_signature argument or with the given output_types and (optionally) output_shapes arguments, whichever was specified.

The recommended way to call from_generator is to use the output_signature argument. In this case the output will be assumed to consist of objects with the classes, shapes and types defined by the tf.TypeSpec objects in the output_signature argument:

def gen():
  ragged_tensor = tf.ragged.constant([[1, 2], [3]])
  yield 42, ragged_tensor

dataset = tf.data.Dataset.from_generator(
     gen,
     output_signature=(
         tf.TensorSpec(shape=(), dtype=tf.int32),
         tf.RaggedTensorSpec(shape=(2, None), dtype=tf.int32)))

list(dataset.take(1))
[(<tf.Tensor: shape=(), dtype=int32, numpy=42>,
<tf.RaggedTensor [[1, 2], [3]]>)]

There is also a deprecated way to call from_generator: with the output_types argument alone, or together with the output_shapes argument. In this case the output of the function will be assumed to consist of tf.Tensor objects with the types defined by output_types and with the shapes which are either unknown or defined by output_shapes.

Args
generator A callable object that returns an object that supports the iter() protocol. If args is not specified, generator must take no arguments; otherwise it must take as many arguments as there are values in args.
output_types (Optional.) A (nested) structure of tf.DType objects corresponding to each component of an element yielded by generator.
output_shapes (Optional.) A (nested) structure of tf.TensorShape objects corresponding to each component of an element yielded by generator.
args (Optional.) A tuple of tf.Tensor objects that will be evaluated and passed to generator as NumPy-array arguments.
output_signature (Optional.) A (nested) structure of tf.TypeSpec objects corresponding to each component of an element yielded by generator.

Returns
Dataset A Dataset.

from_tensor_slices

Creates a Dataset whose elements are slices of the given tensors.

The given tensors are sliced along their first dimension. This operation preserves the structure of the input tensors, removing the first dimension of each tensor and using it as the dataset dimension. All input tensors must have the same size in their first dimensions.

# Slicing a 1D tensor produces scalar tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
list(dataset.as_numpy_iterator())
[1, 2, 3]
# Slicing a 2D tensor produces 1D tensor elements.
dataset = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])
list(dataset.as_numpy_iterator())
[array([1, 2], dtype=int32), array([3, 4], dtype=int32)]
# Slicing a tuple of 1D tensors produces tuple elements containing
# scalar tensors.
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
list(dataset.as_numpy_iterator())
[(1, 3, 5), (2, 4, 6)]
# Dictionary structure is also preserved.
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
list(dataset.as_numpy_iterator()) == [{'a': 1, 'b': 3},
                                      {'a': 2, 'b': 4}]
True
# Two tensors can be combined into one Dataset object.
features = tf.constant([[1, 3], [2, 1], [3, 3]]) # ==> 3x2 tensor
labels = tf.constant(['A', 'B', 'A']) # ==> tensor of shape (3,)
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# Both the features and the labels tensors can be converted
# to a Dataset object separately and combined after.
features_dataset = tf.data.Dataset.from_tensor_slices(features)
labels_dataset = tf.data.Dataset.from_tensor_slices(labels)
dataset = tf.data.Dataset.zip((features_dataset, labels_dataset))
# A batched feature and label set can be converted to a Dataset
# in similar fashion.
batched_features = tf.constant([[[1, 3], [2, 3]],
                                [[2, 1], [1, 2]],
                                [[3, 3], [3, 2]]], shape=(3, 2, 2))
batched_labels = tf.constant([['A', 'A'],
                              ['B', 'B'],
                              ['A', 'B']], shape=(3, 2, 1))
dataset = tf.data.Dataset.from_tensor_slices((batched_features, batched_labels))
for element in dataset.as_numpy_iterator():
  print(element)
(array([[1, 3],
       [2, 3]], dtype=int32), array([[b'A'],
       [b'A']], dtype=object))
(array([[2, 1],
       [1, 2]], dtype=int32), array([[b'B'],
       [b'B']], dtype=object))
(array([[3, 3],
       [3, 2]], dtype=int32), array([[b'A'],
       [b'B']], dtype=object))
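The slicing rule can be modeled on nested Python data. The hypothetical tensor_slices helper below treats each list as a leaf "tensor" and tuples and dicts as structure. It checks that every leaf has the same first dimension, then builds element i by indexing i into every leaf while keeping the nesting. This is a sketch of the documented behavior, not TensorFlow's implementation.

```python
from typing import Any, List, Set


def _first_dims(structure: Any) -> Set[int]:
    """Collect the first-dimension sizes of every leaf in a nested structure."""
    if isinstance(structure, dict):
        return set().union(*(_first_dims(v) for v in structure.values()))
    if isinstance(structure, tuple):
        return set().union(*(_first_dims(v) for v in structure))
    return {len(structure)}  # a list is a leaf; its length is the first dimension


def _take(structure: Any, i: int) -> Any:
    """Index i along the first dimension of every leaf, keeping the nesting."""
    if isinstance(structure, dict):
        return {k: _take(v, i) for k, v in structure.items()}
    if isinstance(structure, tuple):
        return tuple(_take(v, i) for v in structure)
    return structure[i]


def tensor_slices(structure: Any) -> List[Any]:
    """Hypothetical model of from_tensor_slices on nested lists/tuples/dicts."""
    sizes = _first_dims(structure)
    if len(sizes) != 1:
        raise ValueError(f"all leaves need the same first dimension, got {sorted(sizes)}")
    (n,) = sizes
    return [_take(structure, i) for i in range(n)]


print(tensor_slices([[1, 2], [3, 4]]))            # [[1, 2], [3, 4]]
print(tensor_slices(([1, 2], [3, 4], [5, 6])))    # [(1, 3, 5), (2, 4, 6)]
print(tensor_slices({"a": [1, 2], "b": [3, 4]}))  # [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
```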

Note that if tensors contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more tf.constant operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in this guide.

Args
tensors A dataset element, whose components have the same first dimension. Supported values are documented here.

Returns
Dataset A Dataset.

from_tensors

Creates a Dataset with a single element, comprising the given tensors.

from_tensors produces a dataset containing only a single element. To slice the input tensor into multiple elements, use from_tensor_slices instead.

dataset = tf.data.Dataset.from_tensors([1, 2, 3])
list(dataset.as_numpy_iterator())
[array([1, 2, 3], dtype=int32)]
dataset = tf.data.Dataset.from_tensors(([1, 2, 3], 'A'))
list(dataset.as_numpy_iterator())
[(array([1, 2, 3], dtype=int32), b'A')]
# You can use `from_tensors` to produce a dataset which repeats
# the same example many times.
example = tf.constant([1,2,3])
dataset = tf.data.Dataset.from_tensors(example).repeat(2)
list(dataset.as_numpy_iterator())
[array([1, 2, 3], dtype=int32), array([1, 2, 3], dtype=int32)]

Note that if tensors contains a NumPy array, and eager execution is not enabled, the values will be embedded in the graph as one or more tf.constant operations. For large datasets (> 1 GB), this can waste memory and run into byte limits of graph serialization. If tensors contains one or more large NumPy arrays, consider the alternative described in this guide.

Args
tensors A dataset "element". Supported values are documented here.

Returns
Dataset A Dataset.

get_single_element

Returns the single element of the dataset as a nested structure of tensors.

The function enables you to use a tf.data.Dataset in a stateless "tensor-in tensor-out" expression, without creating an iterator. This lets you apply the optimized tf.data.Dataset transformations directly to tensors.

For example, consider a preprocessing_fn that takes the raw features as input and returns the processed feature along with its label.

def preprocessing_fn(raw_feature):
  # ... the raw_feature is preprocessed as per the use-case
  return feature

raw_features = ...  # input batch of BATCH_SIZE elements.
dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
          .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
          .batch(BATCH_SIZE))

processed_features = dataset.get_single_element()

In the above example, the raw_features tensor of length BATCH_SIZE was converted to a tf.data.Dataset. Next, each raw feature was mapped using the preprocessing_fn and the processed features were grouped into a single batch. The final dataset contains only one element, which is a batch of all the processed features.

Now, instead of creating an iterator for the dataset and retrieving the batch of features, dataset.get_single_element() is used to skip the iterator creation process and directly output the batch of features.

This can be particularly useful when your tensor transformations are expressed as tf.data.Dataset operations, and you want to use those transformations while serving your model.

Keras


model = ... # A pre-built or custom model

class PreprocessingModel(tf.keras.Model):
  def __init__(self, model):
    super().__init__()
    self.model = model

  @tf.function(input_signature=[...])
  def serving_fn(self, data):
    ds = tf.data.Dataset.from_tensor_slices(data)
    ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
    ds = ds.batch(batch_size=BATCH_SIZE)
    return tf.argmax(self.model(ds.get_single_element()), axis=-1)

preprocessing_model = PreprocessingModel(model)
your_exported_model_dir = ... # save the model to this path.
tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                    signatures={'serving_default': preprocessing_model.serving_fn})

Estimator

In the case of estimators, you generally need to define a serving_input_fn, which requires the features to be processed by the model during inference.

def serving_input_fn():

  raw_feature_spec = ... # Spec for the raw_features
  input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = input_fn()
  raw_features = serving_input_receiver.features

  def preprocessing_fn(raw_feature):
    # ... the raw_feature is preprocessed as per the use-case
    return feature

  dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
            .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
            .batch(BATCH_SIZE))

  processed_features = dataset.get_single_element()

  # Please note that the value of `BATCH_SIZE` should be equal to
  # the size of the leading dimension of `raw_features`. This ensures
  # that `dataset` has only one element, which is a prerequisite for
  # using `dataset.get_single_element()`.

  return tf.estimator.export.ServingInputReceiver(
      processed_features, serving_input_receiver.receiver_tensors)

estimator = ... # A pre-built or custom estimator
estimator.export_saved_model(your_exported_model_dir, serving_input_fn)

Returns
A nested structure of tf.Tensor objects, corresponding to the single element of dataset.

Raises
InvalidArgumentError (at runtime) if dataset does not contain exactly one element.
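The "exactly one element" requirement can be sketched in plain Python. The single_element helper below is hypothetical and raises ValueError where TensorFlow raises InvalidArgumentError. The demo uses made-up values and a stand-in for preprocessing_fn, and batches BATCH_SIZE raw features into one batch, which is why the batch size must match the leading dimension.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def single_element(elements: Iterable[T]) -> T:
    """Hypothetical model: return the only element, or raise if there are 0 or 2+."""
    iterator = iter(elements)
    try:
        first = next(iterator)
    except StopIteration:
        raise ValueError("dataset is empty") from None
    # Any further element means the dataset is not a singleton.
    for _ in iterator:
        raise ValueError("dataset has more than one element")
    return first


BATCH_SIZE = 4  # hypothetical; equals the number of raw features below
raw_features = [10, 20, 30, 40]
processed = [x + 1 for x in raw_features]  # stands in for preprocessing_fn
batches: List[List[int]] = [processed[i:i + BATCH_SIZE]
                            for i in range(0, len(processed), BATCH_SIZE)]
print(single_element(batches))  # [11, 21, 31, 41]
```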

group_by_window

Groups windows of elements by key and reduces them.

This transformation maps each consecutive element in a dataset to a key using key_func and groups the elements by key. It then applies reduce_func to at most window_size_func(key) elements matching the same key. All except the final window for each key will contain window_size_func(key) elements; the final window may be smaller.

You may provide either a constant window_size or a window size determined by the key through window_size_func.

dataset = tf.data.Dataset.range(10)
window_size = 5
key_func = lambda x: x%2
reduce_func = lambda key, dataset: dataset.batch(window_size)
dataset = dataset.group_by_window(
          key_func=key_func,
          reduce_func=reduce_func,
          window_size=window_size)
for elem in dataset.as_numpy_iterator():
  print(elem)
[0 2 4 6 8]
[1 3 5 7 9]
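The windowing can be modeled in plain Python with a hypothetical group_by_window function. Each key collects elements until its window holds window_size of them, then reduce_func runs on that window. Leftover windows are reduced at the end; emitting them in order of first appearance is an assumption of this sketch. Here reduce_func returns a list, standing in for the dataset that the real reduce_func returns.

```python
from typing import Callable, Dict, Hashable, Iterable, List, TypeVar

T = TypeVar("T")
S = TypeVar("S")


def group_by_window(elements: Iterable[T],
                    key_func: Callable[[T], Hashable],
                    reduce_func: Callable[[Hashable, List[T]], Iterable[S]],
                    window_size: int) -> List[S]:
    """Hypothetical model: collect up to window_size elements per key, then reduce."""
    windows: Dict[Hashable, List[T]] = {}
    out: List[S] = []
    for element in elements:
        key = key_func(element)
        windows.setdefault(key, []).append(element)
        if len(windows[key]) == window_size:
            # A full window is reduced as soon as it fills up.
            out.extend(reduce_func(key, windows.pop(key)))
    # Whatever is left forms the final, possibly smaller, window for each key.
    for key, window in windows.items():
        out.extend(reduce_func(key, window))
    return out


def batch_window(key, window):
    """Analogue of reduce_func=lambda key, ds: ds.batch(window_size)."""
    return [window]


print(group_by_window(range(10), lambda x: x % 2, batch_window, 5))
# [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
print(group_by_window(range(10), lambda x: x % 2, batch_window, 3))
# [[0, 2, 4], [1, 3, 5], [6, 8], [7, 9]]
```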

Args
key_func A function mapping a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a scalar tf.int64 tensor.
reduce_func A function mapping a key and a dataset of up to window_size consecutive elements matching that key to another dataset.
window_size A tf.int64 scalar tf.Tensor, representing the number of consecutive elements matching the same key to combine in a single batch, which will be passed to reduce_func. Mutually exclusive with window_size_func.
window_size_func A function mapping a key to a tf.int64 scalar tf.Tensor, representing the number of consecutive elements matching the same key to combine in a single batch, which will be passed to reduce_func. Mutually exclusive with window_size.

Returns
A Dataset.

Raises
ValueError if neither or both of {window_size, window_size_func} are passed.

interleave

Maps map_func across this dataset, and interleaves the results.

The type signature is:

def interleave(
  self: Dataset[T],
  map_func: Callable[[T], Dataset[S]]
) -> Dataset[S]

For example, you can use Dataset.interleave() to process many input files concurrently:

# Preprocess 4 files concurrently, and interleave blocks of 16 records
# from each file.
filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
             "/var/data/file3.txt", "/var/data/file4.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
def parse_fn(filename):
  return tf.data.Dataset.range(10)
dataset = dataset.interleave(lambda x:
    tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
    cycle_length=4, block_length=16)

The cycle_length and block_length arguments control the order in which elements are produced. cycle_length controls the number of input elements that are processed concurrently. If you set cycle_length to 1, this transformation will handle one input element at a time, and will produce identical results to tf.data.Dataset.flat_map. In general, this transformation will apply map_func to cycle_length input elements, open iterators on the returned Dataset objects, and cycle through them producing block_length consecutive elements from each iterator, and consuming the next input element each time it reaches the end of an iterator.

For example:

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=2, block_length=4)
list(dataset.as_numpy_iterator())
[1, 1, 1, 1,
 2, 2, 2, 2,
 1, 1,
 2, 2,
 3, 3, 3, 3,
 4, 4, 4, 4,
 3, 3,
 4, 4,
 5, 5, 5, 5,
 5, 5]
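The ordering rules can be checked with a sequential plain-Python model. The hypothetical interleave function below follows the description. It keeps up to cycle_length open iterators and takes up to block_length elements from each before moving to the next slot. When an iterator runs out, its slot is freed and then refilled from the next input element when the cycle returns to it. The model ignores num_parallel_calls and deterministic and is not TensorFlow's implementation.

```python
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")
S = TypeVar("S")


def interleave(elements: Iterable[T], map_func: Callable[[T], Iterable[S]],
               cycle_length: int, block_length: int = 1) -> List[S]:
    """Hypothetical sequential model of interleave (no parallelism)."""
    inputs = iter(elements)
    slots: List[Optional[Iterator[S]]] = [None] * cycle_length
    end_of_input, num_open = False, 0
    cycle_index, block_index = 0, 0
    out: List[S] = []

    def next_slot() -> None:
        nonlocal cycle_index, block_index
        block_index = 0
        cycle_index = (cycle_index + 1) % cycle_length

    while not end_of_input or num_open > 0:
        it = slots[cycle_index]
        if it is not None:
            try:
                out.append(next(it))
            except StopIteration:
                # This input element is exhausted: free its slot and move on.
                slots[cycle_index] = None
                num_open -= 1
                next_slot()
                continue
            block_index += 1
            if block_index == block_length:
                next_slot()
        elif not end_of_input:
            # Lazily open an iterator for the next input element in this slot.
            try:
                slots[cycle_index] = iter(map_func(next(inputs)))
                num_open += 1
            except StopIteration:
                end_of_input = True
        else:
            next_slot()
    return out


result = interleave(range(1, 6), lambda x: [x] * 6, cycle_length=2, block_length=4)
print(result)
# [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 5, 5, 5, 5, 5, 5]

# With cycle_length=1 the output matches flat_map: plain concatenation.
flat = interleave([[1, 2, 3], [4, 5, 6], [7, 8, 9]], lambda x: x,
                  cycle_length=1, block_length=2)
print(flat)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```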

Performance can often be improved by setting num_parallel_calls so that interleave will use multiple threads to fetch elements. If determinism isn't required, it can also improve performance to set deterministic=False.

filenames = ["/var/data/file1.txt", "/var/data/file2.txt",
             "/var/data/file3.txt", "/var/data/file4.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.interleave(lambda x: tf.data.TFRecordDataset(x),
    cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False)

Args
map_func A function mapping a dataset element to a dataset.
cycle_length (Optional.) The number of input elements that will be processed concurrently. If not set, the tf.data runtime decides what it should be based on available CPU. If num_parallel_calls is set to tf.data.AUTOTUNE, the cycle_length argument identifies the maximum degree of parallelism.
block_length (Optional.) The number of consecutive elements to produce from each input element before cycling to another input element. If not set, defaults to 1.
num_parallel_calls (Optional.) If specified, the implementation creates a threadpool, which is used to fetch inputs from cycle elements asynchronously and in parallel. The default behavior is to fetch inputs from cycle elements synchronously with no parallelism. If the value tf.data.AUTOTUNE is used, then the number of parallel calls is set dynamically based on available CPU.
deterministic (Optional.) When num_parallel_calls is specified, if this boolean is specified (True or False), it controls the order in which the transformation produces elements. If set to False, the transformation is allowed to yield elements out of order to trade determinism for performance. If not specified, the tf.data.Options.experimental_deterministic option (True by default) controls the behavior.

Returns
Dataset A Dataset.

list_files

A dataset of all files matching one or more glob patterns.

The file_pattern argument should be a small number of glob patterns. If your filenames have already been globbed, use Dataset.from_tensor_slices(filenames) instead, as re-globbing every filename with list_files may result in poor performance with remote storage systems.

Example:

If we had the following files on our filesystem:

  • /path/to/dir/a.txt
  • /path/to/dir/b.py
  • /path/to/dir/c.py

If we pass "/path/to/dir/*.py" as the file_pattern, the dataset would produce:

  • /path/to/dir/b.py
  • /path/to/dir/c.py
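The matching and shuffling steps can be modeled with Python's fnmatch and random modules. The list_files_model function is hypothetical. Unlike a real filesystem glob, fnmatch's * also matches "/", and the seeded shuffle here is not the same permutation TensorFlow would produce.

```python
import fnmatch
import random
from typing import List, Optional, Sequence


def list_files_model(paths: Sequence[str], pattern: str,
                     shuffle: bool = True, seed: Optional[int] = None) -> List[str]:
    """Hypothetical model: keep names matching a shell-style pattern, optionally shuffled."""
    matched = sorted(fnmatch.filter(paths, pattern))
    if shuffle:
        # A fixed seed makes the shuffled order reproducible across runs.
        random.Random(seed).shuffle(matched)
    return matched


files = ["/path/to/dir/a.txt", "/path/to/dir/b.py", "/path/to/dir/c.py"]
print(list_files_model(files, "/path/to/dir/*.py", shuffle=False))
# ['/path/to/dir/b.py', '/path/to/dir/c.py']
```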

Args
file_pattern A string, a list of strings, or a tf.Tensor of string type (scalar or vector), representing the filename glob (i.e. shell wildcard) pattern(s) that will be matched.
shuffle (Optional.) If True, the file names will be shuffled randomly. Defaults to True.
seed (Optional.) A tf.int64 scalar tf.Tensor, representing the random seed that will be used to create the distribution. See tf.random.set_seed for behavior.

Returns
Dataset A Dataset of strings corresponding to file names.

map

Maps map_func across the elements of this dataset.

This transformation applies map_func to each element of this dataset, and returns a new dataset containing the transformed elements, in the same order as they appeared in the input. map_func can be used to change both the values and the structure of a dataset's elements. Supported structure constructs are documented here.

For example, map can be used for adding 1 to each element, or projecting a subset of element components.

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())
[2, 3, 4, 5, 6]

The input signature of map_func is determined by the structure of each element in this dataset.

dataset = tf.data.Dataset.range(5)
# `map_func` takes a single argument of type `tf.Tensor` with the same
# shape and dtype.
result = dataset.map(lambda x: x + 1)
# Each element is a tuple containing two `tf.Tensor` objects.
elements = [(1, "foo"), (2, "bar"), (3, "baz")]
dataset = tf.data.Dataset.from_generator(
    lambda: elements, (tf.int32, tf.string))
# `map_func` takes two arguments of type `tf.Tensor`. This function
# projects out just the first component.
result = dataset.map(lambda x_int, y_str: x_int)
list(result.as_numpy_iterator())
[1, 2, 3]
# Each element is a dictionary mapping strings to `tf.Tensor` objects.
elements =  ([{"a": 1, "b": "foo"},
              {"a": 2, "b": "bar"},
              {"a": 3, "b": "baz"}])
dataset = tf.data.Dataset.from_generator(
    lambda: elements, {"a": tf.int32, "b": tf.string})
# `map_func` takes a single argument of type `dict` with the same keys
# as the elements.
result = dataset.map(lambda d: tf.strings.as_string(d["a"]) + d["b"])

The value or values returned by map_func determine the structure of each element in the returned dataset.

dataset = tf.data.Dataset.range(3)
# `map_func` returns two `tf.Tensor` objects.
def g(x):
  return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
result = dataset.map(g)
result.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(3,), dtype=tf.string, name=None))
# Python primitives, lists, and NumPy arrays are implicitly converted to
# `tf.Tensor`.
def h(x):
  return 37.0, ["Foo", "Bar"], np.array([1.0, 2.0], dtype=np.float64)
result = dataset.map(h)
result.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(2,), dtype=tf.string, name=None), TensorSpec(shape=(2,), dtype=tf.float64, name=None))
# `map_func` can return nested structures.
def i(x):
  return (37.0, [42, 16]), "foo"
result = dataset.map(i)
result.element_spec
((TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(2,), dtype=tf.int32, name=None)),
 TensorSpec(shape=(), dtype=tf.string, name=None))

map_func can accept as arguments and return any type of dataset element.

Note that irrespective of the context in which map_func is defined (eager vs. graph), tf.data traces the function and executes it as a graph. To use Python code inside of the function you have a few options:

1) Rely on AutoGraph to convert Python code into an equivalent graph computation. The downside of this approach is that AutoGraph can convert some but not all Python code.

2) Use tf.py_function, which allows you to write arbitrary Python code but will generally result in worse performance than 1). For example:

d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
# transform a string tensor to upper case string using a Python function
def upper_case_fn(t: tf.Tensor):
  return t.numpy().decode('utf-8').upper()
d = d.map(lambda x: tf.py_function(func=upper_case_fn,
          inp=[x], Tout=tf.string))
list(d.as_numpy_iterator())
[b'HELLO', b'WORLD']

3) Use tf.numpy_function, which also allows you to write arbitrary Python code. Note that tf.py_function accepts tf.Tensor whereas tf.numpy_function accepts numpy arrays and returns only numpy arrays. For example:

d = tf.data.Dataset.from_tensor_slices(['hello', 'world'])
def upper_case_fn(t: np.ndarray):
  return t.decode('utf-8').upper()
d = d.map(lambda x: tf.numpy_function(func=upper_case_fn,
          inp=[x], Tout=tf.string))
list(d.as_numpy_iterator())
[b'HELLO', b'WORLD']

Note that the use of tf.numpy_function and tf.py_function in general precludes the possibility of executing user-defined transformations in parallel (because of the Python GIL).

Performance can often be improved by setting num_parallel_calls so that map will use multiple threads to process elements. If deterministic order isn't required, it can also improve performance to set deterministic=False.

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=False)

The order of elements yielded by this transformation is deterministic if deterministic=True. If map_func contains stateful operations and num_parallel_calls > 1, the order in which that state is accessed is undefined, so the values of output elements may not be deterministic regardless of the deterministic flag value.
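The trade-off behind deterministic=False resembles collecting results from a thread pool in completion order rather than in submission order. The following plain-Python sketch is an analogy using concurrent.futures, not a description of how tf.data schedules its threads; the function slow_add_one is hypothetical. Both strategies produce the same values, but only the first guarantees input order:

```python
import concurrent.futures
import random
import time

def slow_add_one(x):
    # Simulate per-element work that takes a varying amount of time.
    time.sleep(random.uniform(0, 0.01))
    return x + 1

inputs = [1, 2, 3, 4, 5]

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    # Like deterministic=True: results come back in input order.
    ordered = list(pool.map(slow_add_one, inputs))

    # Like deterministic=False: results come back as soon as each finishes.
    futures = [pool.submit(slow_add_one, x) for x in inputs]
    unordered = [f.result() for f in concurrent.futures.as_completed(futures)]

print(ordered)            # [2, 3, 4, 5, 6]
print(sorted(unordered))  # [2, 3, 4, 5, 6]
```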

Args
map_func A function mapping a dataset element to another dataset element.
num_parallel_calls (Optional.) A tf.int64 scalar tf.Tensor, representing the number of elements to process asynchronously in parallel. If not specified, elements will be processed sequentially. If the value tf.data.AUTOTUNE is used, then the number of parallel calls is set dynamically based on available CPU.
deterministic (Optional.) When num_parallel_calls is specified, if this boolean is specified (True or False), it controls the order in which the transformation produces elements. If set to False, the transformation is allowed to yield elements out of order to trade determinism for performance. If not specified, the tf.data.Options.experimental_deterministic option (True by default) controls the behavior.

Returns
Dataset A Dataset.

options

Returns the options for this dataset and its inputs.

Returns
A tf.data.Options object representing the dataset options.

padded_batch

Combines consecutive elements of this dataset into padded batches.

This transformation combines multiple consecutive elements of the input dataset into a single element.

Like tf.data.Dataset.batch, the components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.

Unlike tf.data.Dataset.batch, the input elements to be batched may have different shapes, and this transformation will pad each component to the respective shape in padded_shapes. The padded_shapes argument determines the resulting shape for each dimension of each component in an output element:

  • If the dimension is a constant, the component will be padded out to that length in that dimension.
  • If the dimension is unknown, the component will be padded out to the maximum length of all elements in that dimension.
A = (tf.data.Dataset
     .range(1, 5, output_type=tf.int32)
     .map(lambda x: tf.fill([x], x)))
# Pad to the smallest per-batch size that fits all elements.
B = A.padded_batch(2)
for element in B.as_numpy_iterator():
  print(element)
[[1 0]
 [2 2]]
[[3 3 3 0]
 [4 4 4 4]]
# Pad to a fixed size.
C = A.padded_batch(2, padded_shapes=5)
for element in C.as_numpy_iterator():
  print(element)
[[1 0 0 0 0]
 [2 2 0 0 0]]
[[3 3 3 0 0]
 [4 4 4 4 0]]
# Pad with a custom value.
D = A.padded_batch(2, padded_shapes=5, padding_values=-1)
for element in D.as_numpy_iterator():
  print(element)
[[ 1 -1 -1 -1 -1]
 [ 2  2 -1 -1 -1]]
[[ 3  3  3 -1 -1]
 [ 4  4  4  4 -1]]
# Components of nested elements can be padded independently.
elements = [([1, 2, 3], [10]),
            ([4, 5], [11, 12])]
dataset = tf.data.Dataset.from_generator(
    lambda: iter(elements), (tf.int32, tf.int32))
# Pad the first component of the tuple to length 4, and the second
# component to the smallest size that fits.
dataset = dataset.padded_batch(2,
    padded_shapes=([4], [None]),
    padding_values=(-1, 100))
list(dataset.as_numpy_iterator())
[(array([[ 1,  2,  3, -1], [ 4,  5, -1, -1]], dtype=int32),
  array([[ 10, 100], [ 11,  12]], dtype=int32))]
# Pad with a single value and multiple components.
E = tf.data.Dataset.zip((A, A)).padded_batch(2, padding_values=-1)
for element in E.as_numpy_iterator():
  print(element)
(array([[ 1, -1],
       [ 2,  2]], dtype=int32), array([[ 1, -1],
       [ 2,  2]], dtype=int32))
(array([[ 3,  3,  3, -1],
       [ 4,  4,  4,  4]], dtype=int32), array([[ 3,  3,  3, -1],
       [ 4,  4,  4,  4]], dtype=int32))

See also tf.data.experimental.dense_to_sparse_batch, which combines elements that may have different shapes into a tf.sparse.SparseTensor.
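For 1-D components, the padding rule can be modeled in plain Python. This is a simplified sketch that handles only a flat list of 1-D integer lists (no nested structures or multi-dimensional padding); its parameter padded_length=None plays the role of an unknown dimension, and it is not TensorFlow's implementation:

```python
def padded_batch(elements, batch_size, padded_length=None,
                 padding_value=0, drop_remainder=False):
    """Batch 1-D lists, padding every element of a batch to one length.

    padded_length=None pads to the longest element in each batch;
    an integer pads every element to that fixed length.
    """
    batches = []
    for start in range(0, len(elements), batch_size):
        batch = elements[start:start + batch_size]
        if drop_remainder and len(batch) < batch_size:
            break  # Drop the final, smaller batch.
        length = padded_length
        if length is None:
            length = max(len(e) for e in batch)
        # This sketch rejects elements longer than the padded length.
        if any(len(e) > length for e in batch):
            raise ValueError("element longer than padded length")
        batches.append([list(e) + [padding_value] * (length - len(e))
                        for e in batch])
    return batches

# Same input as the example above: [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]].
A = [[x] * x for x in range(1, 5)]
print(padded_batch(A, 2))
# [[[1, 0], [2, 2]], [[3, 3, 3, 0], [4, 4, 4, 4]]]
print(padded_batch(A, 2, padded_length=5, padding_value=-1))
# [[[1, -1, -1, -1, -1], [2, 2, -1, -1, -1]], [[3, 3, 3, -1, -1], [4, 4, 4, 4, -1]]]
```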

Args
batch_size A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch.
padded_shapes (Optional.) A (nested) structure of tf.TensorShape or tf.int64 vector tensor-like objects representing the shape to which the respective component of each input element should be padded prior to batching. Any unknown dimensions will be padded to the maximum size of that dimension in each batch. If unset, all dimensions of all components are padded to the maximum size in the batch. padded_shapes must be set if any component has an unknown rank.
padding_values (Optional.) A (nested) structure of scalar-shaped tf.Tensor, representing the padding values to use for the respective components. None represents that the (nested) structure should be padded with default values. Defaults are 0 for numeric types and the empty string for string types. The padding_values should have the same (nested) structure as the input dataset. If padding_values is a single element and the input dataset has multiple components, then the same padding_values will be used to pad every component of the dataset. If padding_values is a scalar, then its value will be broadcast to match the shape of each component.
drop_remainder (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case it has fewer than batch_size elements; the default behavior is not to drop the smaller batch.

Returns
Dataset A Dataset.

Raises
ValueError If a component has an unknown rank, and the padded_shapes argument is not set.

prefetch

Creates a Dataset that prefetches elements from this dataset.

Most dataset input pipelines should end with a call to prefetch. This allows later elements to be prepared while the current element is being processed. This often improves latency and throughput, at the cost of using additional memory to store prefetched elements.

dataset = tf.data.Dataset.range(3)
dataset = dataset.prefetch(2)
list(dataset.as_numpy_iterator())
[0, 1, 2]
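The idea of preparing later elements while the current one is consumed can be sketched as a producer thread feeding a bounded queue. This is a plain-Python model under simplifying assumptions (one producer thread, no propagation of producer errors), not how tf.data implements prefetching; the bounded queue is what caps the extra memory at buffer_size elements:

```python
import queue
import threading

_DONE = object()  # Sentinel that marks the end of the input.

def prefetch(iterable, buffer_size):
    """Yield items of `iterable` that a background thread produces ahead of time.

    The producer blocks once `buffer_size` items are waiting, which bounds
    the memory used by prefetched items. Producer exceptions are not
    propagated in this sketch.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be positive")
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        for item in iterable:
            buf.put(item)
        buf.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _DONE:
            return
        yield item

print(list(prefetch(range(3), buffer_size=2)))  # [0, 1, 2]
```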

Args
buffer_size A tf.int64 scalar tf.Tensor, representing the maximum number of elements that will be buffered when prefetching. If the value tf.data.AUTOTUNE is used, then the buffer size is dynamically tuned.

Returns
Dataset A Dataset.

random

Creates a Dataset of pseudorandom values.

The dataset generates a sequence of uniformly distributed integer values.

ds1 = tf.data.Dataset.random(seed=4).take(10)
ds2 = tf.data.Dataset.random(seed=4).take(10)
print(list(ds1.as_numpy_iterator())==list(ds2.as_numpy_iterator()))
True

Args
seed (Optional) If specified, the dataset produces a deterministic sequence of values.

Returns
Dataset A Dataset.

range

Creates a Dataset of a step-separated range of values.

list(tf.data.Dataset.range(5).as_numpy_iterator())
[0, 1, 2, 3, 4]
list(tf.data.Dataset.range(2, 5).as_numpy_iterator())
[2, 3, 4]
list(tf.data.Dataset.range(1, 5, 2).as_numpy_iterator())
[1, 3]
list(tf.data.Dataset.range(1, 5, -2).as_numpy_iterator())
[]
list(tf.data.Dataset.range(5, 1).as_numpy_iterator())
[]
list(tf.data.Dataset.range(5, 1, -2).as_numpy_iterator())
[5, 3]
list(tf.data.Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
[2, 3, 4]
list(tf.data.Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
[1.0, 3.0]
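Because the positional arguments follow Python's range semantics, the built-in range reproduces the integer examples above. output_type has no counterpart in Python's range, so the float example is omitted from this quick plain-Python check:

```python
# Each tuple: (arguments passed to range, elements expected from the examples above).
cases = [
    ((5,), [0, 1, 2, 3, 4]),
    ((2, 5), [2, 3, 4]),
    ((1, 5, 2), [1, 3]),
    ((1, 5, -2), []),
    ((5, 1), []),
    ((5, 1, -2), [5, 3]),
]
for args, expected in cases:
    assert list(range(*args)) == expected, args
print("all cases match")
```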

Args
*args follows the same semantics as Python's range (xrange in Python 2). len(args) == 1 -> start = 0, stop = args[0], step = 1. len(args) == 2 -> start = args[0], stop = args[1], step = 1. len(args) == 3 -> start = args[0], stop = args[1], step = args[2].
**kwargs

  • output_type: The dtype of the produced elements. (Optional, default: tf.int64).

Returns
Dataset A RangeDataset.

Raises
ValueError if len(args) == 0.

reduce

Reduces the input dataset to a single element.

The transformation calls reduce_func successively on every element of the input dataset until the dataset is exhausted, aggregating information in its internal state. The initial_state argument is used for the initial state and the final state is returned as the result.

tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1).numpy()
5
tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y).numpy()
10
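Dataset.reduce folds the elements into a state in the same way as Python's functools.reduce with an initial value. This plain-Python analogy reproduces the two results above, with range(5) standing in for the dataset:

```python
import functools

# reduce_func(old_state, element) -> new_state, applied left to right.
count = functools.reduce(lambda state, _: state + 1, range(5), 0)
total = functools.reduce(lambda state, x: state + x, range(5), 0)
print(count, total)  # 5 10
```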

Args
initial_state An element representing the initial state of the transformation.
reduce_func A function that maps (old_state, input_element) to new_state. It must take two arguments and return a new element. The structure of new_state must match the structure of initial_state.

Returns
A dataset element corresponding to the final state of the transformation.

repeat

Repeats this dataset so each original value is seen count times.

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.repeat(3)
list(dataset.as_numpy_iterator())
[1, 2, 3, 1, 2, 3, 1, 2, 3]
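The counting rule, including the indefinite case for None or -1, can be modeled with a plain-Python generator. This sketch assumes a re-iterable input such as a list; tf.data instead iterates its input again for each repetition, which is why a preceding shuffle can produce a new order in each epoch:

```python
import itertools

def repeat(items, count=None):
    """Yield the elements of `items` `count` times; forever if count is None or -1.

    Assumes `items` can be iterated more than once (for example, a list).
    """
    passes = itertools.count() if count in (None, -1) else range(count)
    for _ in passes:
        yield from items

print(list(repeat([1, 2, 3], 3)))  # [1, 2, 3, 1, 2, 3, 1, 2, 3]
# An indefinite repeat must be truncated before it can be listed.
print(list(itertools.islice(repeat([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]
```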

Args
count (Optional.) A tf.int64 scalar tf.Tensor, representing the number of times the dataset should be repeated. The default behavior (if count is None or -1) is for the dataset to be repeated indefinitely.

Returns
Dataset A Dataset.

scan

A transformation that scans a function across an input dataset.

This transformation is a stateful relative of tf.data.Dataset.map. In addition to mapping scan_func across the elements of the input dataset, scan() accumulates one or more state tensors, whose initial values are initial_state.

dataset = tf.data.Dataset.range(10)
initial_state = tf.constant(0, dtype=tf.int64)
scan_func = lambda state, i: (state + i, state + i)
dataset = dataset.scan(initial_state=initial_state, scan_func=scan_func)
list(dataset.as_numpy_iterator())
[0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
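The state-threading rule can be written as a short plain-Python generator. This is a model of the documented semantics, not TensorFlow's implementation. The second call shows that the emitted output need not equal the new state: emitting the old state turns the running sum into an exclusive prefix sum.

```python
import itertools

def scan(initial_state, scan_func, elements):
    """Yield output_element for each element, threading state through scan_func."""
    state = initial_state
    for element in elements:
        state, output = scan_func(state, element)
        yield output

# Running sum, as in the example above: emit the new state.
running = list(scan(0, lambda state, i: (state + i, state + i), range(10)))
print(running)  # [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]

# Emit the old state instead: an exclusive prefix sum.
exclusive = list(scan(0, lambda state, i: (state + i, state), range(10)))
print(exclusive)  # [0, 0, 1, 3, 6, 10, 15, 21, 28, 36]

# For the running sum, the result equals itertools.accumulate.
assert running == list(itertools.accumulate(range(10)))
```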

Args
initial_state A nested structure of tensors, representing the initial state of the accumulator.
scan_func A function that maps (old_state, input_element) to (new_state, output_element). It must take two arguments and return a pair of nested structures of tensors. The new_state must match the structure of initial_state.

Returns
A Dataset.

shard

Creates a Dataset that includes only 1/num_shards of this dataset.

shard is deterministic. The Dataset produced by A.shard(n, i) will contain all elements of A whose index mod n = i.

A = tf.data.Dataset.range(10)
B = A.shard(num_shards=3, index=0)
list(B.as_numpy_iterator())
[0, 3, 6, 9]
C = A.shard(num_shards=3, index=1)
list(C.as_numpy_iterator())
[1, 4, 7]
D = A.shard(num_shards=3, index=2)
list(D.as_numpy_iterator())
[2, 5, 8]
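The selection rule refers to an element's position in the dataset, not its value; the two coincide in the example above only because range(10) starts at 0. A plain-Python model of the rule, with a check that the shards are disjoint and together cover the input:

```python
def shard(elements, num_shards, index):
    """Keep the elements whose position modulo num_shards equals index."""
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [x for i, x in enumerate(elements) if i % num_shards == index]

A = range(10)
shards = [shard(A, 3, i) for i in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]

# Positions, not values, decide the shard.
print(shard(["a", "b", "c", "d"], 2, 1))  # ['b', 'd']
```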

This dataset operator is very useful when running distributed training, as it allows each worker to read a unique subset.

When reading a single input file, you can shard elements as follows:

d = tf.data.TFRecordDataset(input_file)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

Important caveats:

  • Be sure to shard before you use any randomizing operator (such as shuffle).
  • Generally it is best if the shard operator is used early in the dataset pipeline. For example, when reading from a set of TFRecord files, shard before converting the dataset to input samples. This avoids reading every file on every worker. The following is an example of an efficient sharding strategy within a complete pipeline:
d = tf.data.Dataset.list_files(pattern)
d = d.shard(num_workers, worker_index)
d = d.repeat(num_epochs)
d = d.shuffle(shuffle_buffer_size)
d = d.interleave(tf.data.TFRecordDataset,
                 cycle_length=num_readers, block_length=1)
d = d.map(parser_fn, num_parallel_calls=num_map_threads)

Args
num_shards A tf.int64 scalar tf.Tensor, representing the number of shards operating in parallel.
index A tf.int64 scalar tf.Tensor, representing the worker index.

Returns
Dataset A Dataset.

Raises
InvalidArgumentError if num_shards or index are illegal values.

shuffle

Randomly shuffles the elements of this dataset.

This dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required.

For instance, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer. Once an element is selected, its space in the buffer is replaced by the next (i.e. 1,001-st) element, maintaining the 1,000 element buffer.
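The buffer mechanism described above can be modeled in plain Python. This is a sketch of the algorithm as described, not TensorFlow's implementation, and the seed handling uses Python's random module. It makes the consequence visible: the element emitted at output position k always comes from the first buffer_size + k input elements, so a small buffer only shuffles locally, and buffer_size=1 does not shuffle at all.

```python
import random

def buffered_shuffle(elements, buffer_size, seed=None):
    """Yield `elements` in an order sampled through a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be positive")
    rng = random.Random(seed)
    it = iter(elements)
    buffer = []
    # Fill the buffer with the first buffer_size elements.
    for x in it:
        buffer.append(x)
        if len(buffer) == buffer_size:
            break
    # Emit a random buffered element and put the next input element in its slot.
    for x in it:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = x
    # Input exhausted: emit what is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer

print(list(buffered_shuffle(range(10), buffer_size=3, seed=42)))
```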

reshuffle_each_iteration controls whether the shuffle order should be different for each epoch. In TF 1.X, the idiomatic way to create epochs was through the repeat transformation:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
dataset = dataset.repeat(2)
# [1, 0, 2, 1, 2, 0]

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
dataset = dataset.repeat(2)
# [1, 0, 2, 1, 0, 2]

In TF 2.0, tf.data.Dataset objects are Python iterables, which makes it possible to also create epochs through Python iteration:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]
dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 0, 2]

Args
buffer_size A tf.int64 scalar tf.Tensor, representing the number of elements from this dataset from which the new dataset will sample.
seed (Optional.) A tf.int64 scalar tf.Tensor, representing the random seed that will be used to create the distribution. See tf.random.set_seed for behavior.
reshuffle_each_iteration (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is iterated over. (Defaults to True.)

Returns
Dataset A Dataset.

skip

Creates a Dataset that skips count elements from this dataset.

dataset = tf.data.Dataset.range(10)
dataset = dataset.skip(7)
list(dataset.as_numpy_iterator())
[7, 8, 9]

Args
count A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be skipped to form the new dataset. If count is greater than the size of this dataset, the new dataset will contain no elements. If count is -1, skips the entire dataset.

Returns
Dataset A Dataset.

snapshot

API to persist the output of the input dataset.

The snapshot API allows users to transparently persist the output of their preprocessing pipeline to disk, and materialize the pre-processed data on a different training run.

This API enables repeated preprocessing steps to be consolidated, and allows re-use of already processed data, trading off disk storage and network bandwidth for freeing up more valuable CPU resources and accelerator compute time.

https://github.com/tensorflow/community/blob/master/rfcs/20200107-tf-data-snapshot.md has detailed design documentation of this feature.

Users can specify various options to control the behavior of snapshot, including how snapshots are read from and written to by passing in user-defined functions to the reader_func and shard_func parameters.

shard_func is a user specified function that maps input elements to snapshot shards.

Users may want to specify this function to control how snapshot files should be written to disk. Below is an example of how a potential shard_func could be written.

dataset = ...
dataset = dataset.enumerate()
dataset = dataset.snapshot("/path/to/snapshot/dir",
    shard_func=lambda x, y: x % NUM_SHARDS, ...)
dataset = dataset.map(lambda x, y: y)

reader_func is a user specified function that accepts a single argument: a Dataset of Datasets, each representing a "split" of elements of the original dataset. The cardinality of the input dataset matches the number of shards specified in the shard_func (see above). The function should return a Dataset of elements of the original dataset.

Users may want to specify this function to control how snapshot files should be read from disk, including the amount of shuffling and parallelism.

Here is an example of a standard reader function a user can define. This function enables both dataset shuffling and parallel reading of datasets:

def user_reader_func(datasets):
  # shuffle the dataset splits
  datasets = datasets.shuffle(NUM_CORES)
  # read datasets in parallel and interleave their elements
  return datasets.interleave(lambda x: x, num_parallel_calls=tf.data.AUTOTUNE)

dataset = dataset.snapshot("/path/to/snapshot/dir",
    reader_func=user_reader_func)

By default, snapshot parallelizes reads by the number of cores available on the system, but will not attempt to shuffle the data.
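The roles of shard_func and reader_func can be illustrated with plain Python lists. This sketch is a simplified model, not how snapshot stores or reads files: write_shards and read_shards are hypothetical helpers, NUM_SHARDS is a hypothetical value, and "reading" shuffles the shard order and then interleaves the shards one element at a time, mirroring the user_reader_func example. All elements come back, possibly in a different order.

```python
import random

NUM_SHARDS = 3  # Hypothetical shard count for this sketch.

def write_shards(elements, shard_func, num_shards):
    """Group elements into lists according to shard_func(element)."""
    shards = [[] for _ in range(num_shards)]
    for element in elements:
        shards[shard_func(element)].append(element)
    return shards

def read_shards(shards, seed=None):
    """Shuffle the order of the shards, then interleave them one element at a time."""
    order = list(shards)
    random.Random(seed).shuffle(order)
    iterators = [iter(shard) for shard in order]
    while iterators:
        for it in list(iterators):
            try:
                yield next(it)
            except StopIteration:
                iterators.remove(it)

data = ["a", "b", "c", "d", "e", "f", "g"]
# Like dataset.enumerate(): pair each element with its index.
enumerated = list(enumerate(data))
# Like shard_func=lambda x, y: x % NUM_SHARDS.
shards = write_shards(enumerated, lambda pair: pair[0] % NUM_SHARDS, NUM_SHARDS)
# Like dataset.map(lambda x, y: y): drop the index again after reading.
restored = [value for _, value in read_shards(shards, seed=0)]
print(shards)
print(restored)
```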

Args
path Required. A directory to use for storing / loading the snapshot to / from.
compression Optional. The type of compression to apply to the snapshot written to disk. Supported options are GZIP, SNAPPY, AUTO or None. Defaults to AUTO, which attempts to pick an appropriate compression algorithm for the dataset.
reader_func Optional. A function to control how to read data from snapshot shards.
shard_func Optional. A function to control how to shard data when writing a snapshot.

Returns
A Dataset.

take

Creates a Dataset with at most count elements from this dataset.

dataset = tf.data.Dataset.range(10)
dataset = dataset.take(3)
list(dataset.as_numpy_iterator())
[0, 1, 2]

Args
count A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be taken to form the new dataset. If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.

Returns
Dataset A Dataset.

take_while

A transformation that stops dataset iteration based on a predicate.

dataset = tf.data.Dataset.range(10)
dataset = dataset.take_while(lambda x: x < 5)
list(dataset.as_numpy_iterator())
[0, 1, 2, 3, 4]
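take_while stops iteration at the first element for which the predicate is false, like Python's itertools.takewhile. That differs from filtering, which keeps checking later elements. A plain-Python comparison on a hypothetical list that is not sorted:

```python
import itertools

data = [0, 1, 2, 7, 3, 4]
# Stops at 7, even though 3 and 4 would satisfy the predicate again.
taken = list(itertools.takewhile(lambda x: x < 5, data))
# A filter keeps every element that satisfies the predicate.
filtered = [x for x in data if x < 5]
print(taken)     # [0, 1, 2]
print(filtered)  # [0, 1, 2, 3, 4]
```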

Args
predicate A function that maps a nested structure of tensors (having shapes and types defined by self.output_shapes and self.output_types) to a scalar tf.bool tensor.

Returns
A Dataset.

unbatch

Splits elements of a dataset into multiple elements.

For example, if elements of the dataset are shaped [B, a0, a1, ...], where B may vary for each input element, then for each element in the dataset, the unbatched dataset will contain B consecutive elements of shape [a0, a1, ...].

elements = [ [1, 2, 3], [1, 2], [1, 2, 3, 4] ]
dataset = tf.data.Dataset.from_generator(lambda: elements, tf.int64)
dataset = dataset.unbatch()
list(dataset.as_numpy_iterator())
[1, 2, 3, 1, 2, 1, 2, 3, 4]
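Unbatching removes one level of nesting, much as itertools.chain.from_iterable does for Python lists. This plain-Python analogy uses the same ragged input, plus a case where each element has shape [B, 2] so that each output is a row of length 2:

```python
import itertools

elements = [[1, 2, 3], [1, 2], [1, 2, 3, 4]]
# Each input element of length B contributes B consecutive outputs.
flat = list(itertools.chain.from_iterable(elements))
print(flat)  # [1, 2, 3, 1, 2, 1, 2, 3, 4]

# Elements of shape [B, 2] (B = 2, then B = 1) unbatch into rows of shape [2].
rows = list(itertools.chain.from_iterable([[[1, 2], [3, 4]], [[5, 6]]]))
print(rows)  # [[1, 2], [3, 4], [5, 6]]
```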

Returns
A Dataset.

unique

A transformation that discards duplicate elements of a Dataset.

Use this transformation to produce a dataset that contains one instance of each unique element in the input. For example:

dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
dataset = dataset.unique()
sorted(list(dataset.as_numpy_iterator()))
[1, 2, 37]
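The example above sorts the output, so it does not depend on the order in which unique emits elements. A plain-Python analogy uses dict.fromkeys; keeping the first occurrence of each value is this sketch's own behavior, not a documented property of unique:

```python
data = [1, 37, 2, 37, 2, 1]
# dict keys are unique and keep insertion (first-occurrence) order.
unique = list(dict.fromkeys(data))
print(unique)          # [1, 37, 2]
print(sorted(unique))  # [1, 2, 37]
```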

Returns
A Dataset.

window

Returns a dataset of "windows".

Each "window" is a dataset that contains a subset of elements of the input dataset. These are finite datasets containing size elements each (or possibly fewer if there are not enough input elements to fill the window and drop_remainder evaluates to False).

For example:

dataset = tf.data.Dataset.range(7).window(3)
for window in dataset:
  print(window)
<...Dataset shapes: (), types: tf.int64>
<...Dataset shapes: (), types: tf.int64>
<...Dataset shapes: (), types: tf.int64>

Since windows are datasets, they can be iterated over:

for window in dataset:
  print([item.numpy() for item in window])
[0, 1, 2]
[3, 4, 5]
[6]

Shift

The shift argument determines the number of input elements to shift between the start of each window. If windows and elements are both numbered starting at 0, the first element in window k will be element k * shift of the input dataset. In particular, the first element of the first window will always be the first element of the input dataset.

dataset = tf.data.Dataset.range(7).window(3, shift=1,
                                          drop_remainder=True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0, 1, 2]
[1, 2, 3]
[2, 3, 4]
[3, 4, 5]
[4, 5, 6]

Stride

The stride argument determines the stride between input elements within a window.

dataset = tf.data.Dataset.range(7).window(3, shift=1, stride=2,
                                          drop_remainder=True)
for window in dataset:
  print(list(window.as_numpy_iterator()))
[0, 2, 4]
[1, 3, 5]
[2, 4, 6]
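The shift and stride rules reduce to index arithmetic: window k holds the input elements at positions k * shift + j * stride for j = 0, ..., size - 1, as far as those positions exist. The following plain-Python sketch models windows as lists rather than nested datasets and reproduces the three examples above:

```python
def window(elements, size, shift=None, stride=1, drop_remainder=False):
    """Return windows as lists.

    Window k starts at position k * shift and takes every stride-th
    element, up to size elements. shift defaults to size.
    """
    shift = size if shift is None else shift
    if min(size, shift, stride) < 1:
        raise ValueError("size, shift and stride must be positive")
    elements = list(elements)
    windows = []
    for start in range(0, len(elements), shift):
        w = elements[start:start + size * stride:stride]
        if drop_remainder and len(w) < size:
            break  # Every later window would be short as well.
        windows.append(w)
    return windows

print(window(range(7), 3))
# [[0, 1, 2], [3, 4, 5], [6]]
print(window(range(7), 3, shift=1, drop_remainder=True))
# [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]]
print(window(range(7), 3, shift=1, stride=2, drop_remainder=True))
# [[0, 2, 4], [1, 3, 5], [2, 4, 6]]
```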

Nested elements

When the window transformation is applied to a dataset whose elements are nested structures, it produces a dataset where the elements have the same nested structure but each leaf is replaced by a window. In other words, the nesting is applied outside of the windows as opposed to inside of them.

The type signature is:

def window(
    self: Dataset[Nest[T]], ...
) -> Dataset[Nest[Dataset[T]]]

Applying window to a Dataset of tuples gives a tuple of windows:

dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3, 4, 5],
                                              [6, 7, 8, 9, 10]))
dataset = dataset.window(2)
windows = next(iter(dataset))
windows
(<...Dataset shapes: (), types: tf.int32>,
 <...Dataset shapes: (), types: tf.int32>)
def to_numpy(ds):
  return list(ds.as_numpy_iterator())

for windows in dataset:
  print(to_numpy(windows[0]), to_numpy(windows[1]))
[1, 2] [6, 7]
[3, 4] [8, 9]
[5] [10]

Applying window to a Dataset of dictionaries gives a dictionary of Datasets:

dataset = tf.data.Dataset.from_tensor_slices({'a': [1, 2, 3],
                                              'b': [4, 5, 6],
                                              'c': [7, 8, 9]})
dataset = dataset.window(2)
def to_numpy(ds):
  return list(ds.as_numpy_iterator())

for windows in dataset:
  print(tf.nest.map_structure(to_numpy, windows))
{'a': [1, 2], 'b': [4, 5], 'c': [7, 8]}
{'a': [3], 'b': [6], 'c': [9]}

Flatten a dataset of windows

The Dataset.flat_map and Dataset.interleave methods can be used to flatten a dataset of windows into a single dataset.

The argument to flat_map is a function that takes an element from the dataset and returns a Dataset. flat_map chains together the resulting datasets sequentially.

For example, to turn each window into a dense tensor:

size = 3
dataset = tf.data.Dataset.range(7).window(size, shift=1,
                                          drop_remainder=True)
batched = dataset.flat_map(lambda x: x.batch(size))
for batch in batched:
  print(batch.numpy())
[0 1 2]
[1 2 3]
[2 3 4]
[3 4 5]
[4 5 6]

Args
size A tf.int64 scalar tf.Tensor, representing the number of elements of the input dataset to combine into a window. Must be positive.
shift (Optional.) A tf.int64 scalar tf.Tensor, representing the number of input elements by which the window moves in each iteration. Defaults to size. Must be positive.
stride (Optional.) A tf.int64 scalar tf.Tensor, representing the stride of the input elements in the sliding window. Must be positive. The default value of 1 means "retain every input element".
drop_remainder (Optional.) A tf.bool scalar tf.Tensor, representing whether the last windows should be dropped if their size is smaller than size.

Returns
Dataset A Dataset of (nests of) windows. Each window is a finite dataset of flat elements.

with_options

Returns a new tf.data.Dataset with the given options set.

The options are "global" in the sense they apply to the entire dataset. If options are set multiple times, they are merged as long as different options do not use different non-default values.
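The merge rule described above can be modeled with plain dictionaries. In this sketch, None stands for "left at its default", the option names are hypothetical, and merge_options is a hypothetical helper rather than part of the tf.data API. Merging succeeds unless the same option holds two different non-default values, in which case it raises ValueError as described under Raises below:

```python
def merge_options(a, b):
    """Merge two option dicts in which None means "left at default".

    Raises ValueError if one option has two different non-default values.
    """
    merged = dict(a)
    for name, value in b.items():
        if value is None:
            continue  # b leaves this option at its default.
        current = merged.get(name)
        if current is not None and current != value:
            raise ValueError(
                f"option {name!r} set to both {current!r} and {value!r}")
        merged[name] = value
    return merged

# Hypothetical option names, for illustration only.
first = {"deterministic": False, "threadpool_size": None}
second = {"deterministic": None, "threadpool_size": 8}
print(merge_options(first, second))
# {'deterministic': False, 'threadpool_size': 8}
```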

ds = tf.data.Dataset.range(5)
ds = ds.interleave(lambda x: tf.data.Dataset.range(5),
                   cycle_length=3,
                   num_parallel_calls=3)
options = tf.data.Options()
# This will make the interleave order non-deterministic.
options.experimental_deterministic = False
ds = ds.with_options(options)

Args
options A tf.data.Options that identifies the options to use.

Returns
Dataset A Dataset with the given options.

Raises
ValueError when an option is set more than once to a non-default value.

zip

Creates a Dataset by zipping together the given datasets.

This method has similar semantics to the built-in zip() function in Python, with the main difference being that the datasets argument can be a (nested) structure of Dataset objects. The supported nesting mechanisms are documented in the dataset structure section of the tf.data guide.

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a = tf.data.Dataset.range(1, 4)  # ==> [ 1, 2, 3 ]
b = tf.data.Dataset.range(4, 7)  # ==> [ 4, 5, 6 ]
ds = tf.data.Dataset.zip((a, b))
list(ds.as_numpy_iterator())
[(1, 4), (2, 5), (3, 6)]
ds = tf.data.Dataset.zip((b, a))
list(ds.as_numpy_iterator())
[(4, 1), (5, 2), (6, 3)]

# The `datasets` argument may contain an arbitrary number of datasets.
c = tf.data.Dataset.range(7, 13).batch(2)  # ==> [ [7, 8],
                                           #       [9, 10],
                                           #       [11, 12] ]
ds = tf.data.Dataset.zip((a, b, c))
for element in ds.as_numpy_iterator():
  print(element)
(1, 4, array([7, 8]))
(2, 5, array([ 9, 10]))
(3, 6, array([11, 12]))

# The number of elements in the resulting dataset is the same as
# the size of the smallest dataset in `datasets`.
d = tf.data.Dataset.range(13, 15)  # ==> [ 13, 14 ]
ds = tf.data.Dataset.zip((a, d))
list(ds.as_numpy_iterator())
[(1, 13), (2, 14)]

Args
datasets A (nested) structure of datasets.

Returns
Dataset A Dataset.

__bool__

__iter__

Creates an iterator for elements of this dataset.

The returned iterator implements the Python Iterator protocol.

Returns
A tf.data.Iterator for the elements of this dataset.

Raises
RuntimeError If not inside of tf.function and not executing eagerly.

__len__

Returns the length of the dataset if it is known and finite.

This method requires that you are running in eager mode, and that the length of the dataset is known and non-infinite. When the length may be unknown or infinite, or if you are running in graph mode, use tf.data.Dataset.cardinality instead.

Returns
An integer representing the length of the dataset.

Raises
RuntimeError If the dataset length is unknown or infinite, or if eager execution is not enabled.

__nonzero__
