The tf.data API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training. The pipeline for a text model might involve extracting symbols from raw text data, converting them to embedding identifiers with a lookup table, and batching together sequences of different lengths. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations.
The tf.data API introduces a tf.data.Dataset abstraction that represents a sequence of elements, in which each element consists of one or more components. For example, in an image pipeline, an element might be a single training example, with a pair of tensor components representing the image and its label.
There are two distinct ways to create a dataset:
A data source constructs a Dataset from data stored in memory or in one or more files.
A data transformation constructs a Dataset from one or more tf.data.Dataset objects.
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
np.set_printoptions(precision=4)
Basic mechanics
To create an input pipeline, you must start with a data source. For example, to construct a Dataset from data in memory, you can use tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices(). Alternatively, if your input data is stored in a file in the recommended TFRecord format, you can use tf.data.TFRecordDataset().
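The difference between the two in-memory constructors is worth seeing once: from_tensors treats its argument as a single element, while from_tensor_slices slices it along the first dimension. A minimal sketch, not part of the original example code:
t = tf.constant([[1, 2, 3], [4, 5, 6]])
ds_whole = tf.data.Dataset.from_tensors(t)         # one element of shape (2, 3)
ds_slices = tf.data.Dataset.from_tensor_slices(t)  # two elements of shape (3,)
print(ds_whole.element_spec)   # TensorSpec(shape=(2, 3), dtype=tf.int32, name=None)
print(ds_slices.element_spec)  # TensorSpec(shape=(3,), dtype=tf.int32, name=None)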
Once you have a Dataset object, you can transform it into a new Dataset by chaining method calls on the tf.data.Dataset object. For example, you can apply per-element transformations such as Dataset.map(), and multi-element transformations such as Dataset.batch(). See the documentation for tf.data.Dataset for a complete list of transformations.
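As a minimal sketch of such chaining (the dataset and values here are made up for illustration), a per-element map followed by a multi-element batch might look like:
ds = tf.data.Dataset.range(8)          # 0, 1, ..., 7
ds = ds.map(lambda x: x * 2).batch(4)  # per-element transform, then group into batches
for batch in ds:
  print(batch.numpy())                 # [0 2 4 6], then [ 8 10 12 14]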
The Dataset object is a Python iterable. This makes it possible to consume its elements using a for loop:
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<TensorSliceDataset shapes: (), types: tf.int32>
for elem in dataset:
  print(elem.numpy())
8 3 0 8 2 1
Or by explicitly creating a Python iterator using iter and consuming its elements using next:
it = iter(dataset)
print(next(it).numpy())
8
Alternatively, dataset elements can be consumed using the reduce transformation, which reduces all elements to produce a single result. The following example illustrates how to use the reduce transformation to compute the sum of a dataset of integers.
print(dataset.reduce(0, lambda state, value: state + value).numpy())
22
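The same pattern works for other aggregations; for instance, a small sketch (not in the original) that counts the elements by adding 1 per element instead of adding the value:
print(dataset.reduce(0, lambda state, _: state + 1).numpy())  # 6 elements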
Dataset structure
A dataset contains elements that each have the same (nested) structure, and the individual components of the structure can be of any type representable by tf.TypeSpec, including Tensor, SparseTensor, RaggedTensor, TensorArray, or Dataset.
The Dataset.element_spec property allows you to inspect the type of each element component. The property returns a nested structure of tf.TypeSpec objects, matching the structure of the element, which may be a single component, a tuple of components, or a nested tuple of components. For example:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random.uniform([4]),
tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None), (TensorSpec(shape=(), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))
dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor
The Dataset transformations support datasets of any structure. When using the Dataset.map() and Dataset.filter() transformations, which apply a function to each element, the element structure determines the arguments of the function:
dataset1 = tf.data.Dataset.from_tensor_slices(
tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))
dataset1
<TensorSliceDataset shapes: (10,), types: tf.int32>
for z in dataset1:
  print(z.numpy())
[9 6 2 1 1 3 1 9 5 1] [9 7 4 7 6 6 3 9 7 4] [3 2 9 7 6 9 1 6 8 8] [8 1 7 2 1 4 2 7 5 5]
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random.uniform([4]),
tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2
<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3
<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>
for a, (b, c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,) shapes: (10,), (), (100,) shapes: (10,), (), (100,) shapes: (10,), (), (100,)
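The same structure is what a mapped function receives: the top-level tuple is unpacked into separate arguments. A minimal sketch, not from the original guide, with names chosen for illustration:
# Each element of dataset3 is (a, (b, c)); map passes `a` and the inner pair separately.
shapes_ds = dataset3.map(lambda a, pair: (tf.shape(a), tf.shape(pair[1])))
for a_shape, c_shape in shapes_ds.take(1):
  print(a_shape.numpy(), c_shape.numpy())  # [10] [100]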
Reading input data
Consuming NumPy arrays
See Loading NumPy arrays for more examples.
If all of your input data fits in memory, the simplest way to create a Dataset from it is to convert it to tf.Tensor objects and use Dataset.from_tensor_slices().
train, test = tf.keras.datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz 32768/29515 [=================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz 26427392/26421880 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz 8192/5148 [===============================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz 4423680/4422102 [==============================] - 0s 0us/step
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>
Consuming Python generators
Another common data source that can easily be ingested as a tf.data.Dataset is the Python generator.
def count(stop):
  i = 0
  while i < stop:
    yield i
    i += 1

for n in count(5):
  print(n)
0 1 2 3 4
The Dataset.from_generator constructor converts the Python generator to a fully functional tf.data.Dataset.
The constructor takes a callable as input, not an iterator. This allows it to restart the generator when it reaches the end. It takes an optional args argument, which is passed as the callable's arguments.
The output_types argument is required because tf.data builds a tf.Graph internally, and graph edges require a tf.dtype.
ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes=())
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9] [10 11 12 13 14 15 16 17 18 19] [20 21 22 23 24 0 1 2 3 4] [ 5 6 7 8 9 10 11 12 13 14] [15 16 17 18 19 20 21 22 23 24] [0 1 2 3 4 5 6 7 8 9] [10 11 12 13 14 15 16 17 18 19] [20 21 22 23 24 0 1 2 3 4] [ 5 6 7 8 9 10 11 12 13 14] [15 16 17 18 19 20 21 22 23 24]
The output_shapes argument is not required but is highly recommended, as many TensorFlow operations do not support tensors with unknown rank. If the length of a particular axis is unknown or variable, set it as None in the output_shapes.
It's also important to note that the output_shapes and output_types follow the same nesting rules as other dataset methods.
Here is an example generator that demonstrates both aspects: it returns tuples of arrays, where the second array is a vector with unknown length.
def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1

for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [1.0292] 1 : [ 1.196 -0.0209 -0.7752 -0.4564 1.2394 -0.6198 1.1509 -0.5288] 2 : [-0.8607 -0.1837] 3 : [-2.4491 0.1419 -0.0917 1.2441 0.5747 0.6937] 4 : [3.0188 0.6504] 5 : [ 2.1593 -1.1982 0.203 1.4915 -0.3486 0.1988 -2.3991 0.7208 -0.6654] 6 : []
The first output is an int32 and the second is a float32. The first item is a scalar, shape (), and the second is a vector of unknown length, shape (None,).
ds_series = tf.data.Dataset.from_generator(
gen_series,
output_types=(tf.int32, tf.float32),
output_shapes=((), (None,)))
ds_series
<DatasetV1Adapter shapes: ((), (None,)), types: (tf.int32, tf.float32)>
Now it can be used like a regular tf.data.Dataset. Note that when batching a dataset with a variable shape, you need to use Dataset.padded_batch.
ds_series_batch = ds_series.shuffle(20).padded_batch(10, padded_shapes=([], [None]))
ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[18 3 4 5 6 24 2 10 15 16] [[ 0.1313 -1.1516 0.2933 1.7188 -0.0223 0. 0. 0. 0. ] [ 0.5759 0.6902 -1.0219 -1.3139 -1.0584 -0.5016 -0.2554 0. 0. ] [-0.5048 -1.3863 -0.2144 -0.3981 1.0549 0. 0. 0. 0. ] [ 0.0936 -0.913 0.1827 0.5312 0.6158 -2.1847 0.2045 -0.8826 -0.9115] [-0.5844 -1.1107 1.7099 0.0464 -0.5909 -1.5843 0.9954 -0.2046 -0.1894] [-0.0825 0. 0. 0. 0. 0. 0. 0. 0. ] [-0.3662 -0.74 -0.7329 0.6858 0.8259 0. 0. 0. 0. ] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. ] [-0.5926 -0.3781 -0.1833 -0.0628 0.7526 0.8645 0.1582 -1.0713 0. ] [-1.0442 -0.9337 0.254 0.1426 0.5776 0. 0. 0. 0. ]]
For a more realistic example, try wrapping preprocessing.image.ImageDataGenerator as a tf.data.Dataset.
First download the data:
flowers = tf.keras.utils.get_file(
'flower_photos',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228818944/228813984 [==============================] - 4s 0us/step
Create the image.ImageDataGenerator
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.
print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3) float32 (32, 5)
ds = tf.data.Dataset.from_generator(
img_gen.flow_from_directory, args=[flowers],
output_types=(tf.float32, tf.float32),
output_shapes=([32,256,256,3], [32,5])
)
ds
<DatasetV1Adapter shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, tf.float32)>
Consuming TFRecord data
See Loading TFRecords for an end-to-end example.
The tf.data API supports a variety of file formats so that you can process large datasets that do not fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline.
Here is an example using the test file from the French Street Name Signs (FSNS).
# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001 7905280/7904079 [==============================] - 0s 0us/step
The filenames argument to the TFRecordDataset initializer can either be a string, a list of strings, or a tf.Tensor of strings. Therefore if you have two sets of files for training and validation purposes, you can create a factory method that produces the dataset, taking filenames as an input argument:
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
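The factory method mentioned above might look like the following minimal sketch (the validation file list is a hypothetical placeholder):
def make_tfrecord_dataset(filenames):
  # `filenames` may be a string, a list of strings, or a tf.Tensor of strings.
  return tf.data.TFRecordDataset(filenames=filenames)

train_ds = make_tfrecord_dataset([fsns_test_file])      # stand-in for the training files
# val_ds = make_tfrecord_dataset(validation_filenames)  # hypothetical validation file list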
Many TensorFlow projects use serialized tf.train.Example records in their TFRecord files. These need to be decoded before they can be inspected:
raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())
parsed.features.feature['image/text']
bytes_list { value: "Rue Perreyon" }
Consuming text data
See Loading Text for an end-to-end example.
Many datasets are distributed as one or more text files. The tf.data.TextLineDataset provides an easy way to extract lines from one or more text files. Given one or more filenames, a TextLineDataset will produce one string-valued element per line of those files.
directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']
file_paths = [
tf.keras.utils.get_file(file_name, directory_url + file_name)
for file_name in file_names
]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt 819200/815980 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt 811008/809730 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt 811008/807992 [==============================] - 0s 0us/step
dataset = tf.data.TextLineDataset(file_paths)
Here are the first few lines of the first file:
for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;" b'His wrath pernicious, who ten thousand woes' b"Caused to Achaia's host, sent many a soul" b'Illustrious into Ades premature,' b'And Heroes gave (so stood the will of Jove)'
To alternate lines between files use Dataset.interleave. This makes it easier to shuffle files together. Here are the first, second and third lines from each translation:
files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)
for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;" b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse," b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought' b'His wrath pernicious, who ten thousand woes' b'The vengeance, deep and deadly; whence to Greece' b'countless ills upon the Achaeans. Many a brave soul did it send' b"Caused to Achaia's host, sent many a soul" b'Unnumbered ills arose; which many a soul' b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'
By default, a TextLineDataset yields every line of each file, which may not be desirable, for example if the file starts with a header line or contains comments. These lines can be removed using the Dataset.skip() or Dataset.filter() transformations. Here we skip the first line, then filter to find only survivors.
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv 32768/30874 [===============================] - 0s 0us/step
for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone' b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n' b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n' b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y' b'1,female,35.0,1,0,53.1,First,C,Southampton,n' b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y' b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n' b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n' b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n' b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)

for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n' b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y' b'1,female,35.0,1,0,53.1,First,C,Southampton,n' b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n' b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n' b'1,female,4.0,1,1,16.7,Third,G,Southampton,n' b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y' b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y' b'1,male,28.0,0,0,35.5,First,A,Southampton,y' b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'
Consuming CSV data
See Loading CSV Files and Loading Pandas DataFrames for more examples.
The CSV file format is a popular format for storing tabular data in plain text.
For example:
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
df.head()
If your data fits in memory, the same Dataset.from_tensor_slices method works on dictionaries, allowing this data to be easily imported:
titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))
for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))
'survived' : 0 'sex' : b'male' 'age' : 22.0 'n_siblings_spouses': 1 'parch' : 0 'fare' : 7.25 'class' : b'Third' 'deck' : b'unknown' 'embark_town' : b'Southampton' 'alone' : b'n'
A more scalable approach is to load from disk as necessary. The tf.data module provides methods to extract records from one or more CSV files that comply with RFC 4180.
The experimental.make_csv_dataset function is the high-level interface for reading sets of CSV files. It supports column type inference and many other features, like batching and shuffling, to make usage simple.
titanic_batches = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=4,
label_name="survived")
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/data/experimental/ops/readers.py:521: parallel_interleave (from tensorflow.python.data.experimental.ops.interleave_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.interleave(map_func, cycle_length, block_length, num_parallel_calls=tf.data.experimental.AUTOTUNE)` instead. If sloppy execution is desired, use `tf.data.Options.experimental_determinstic`. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/data/experimental/ops/readers.py:215: shuffle_and_repeat (from tensorflow.python.data.experimental.ops.shuffle_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))
'survived': [1 0 0 1] features: 'sex' : [b'female' b'female' b'male' b'male'] 'age' : [24. 28. 47. 25.] 'n_siblings_spouses': [1 0 0 1] 'parch' : [2 0 0 0] 'fare' : [65. 7.55 38.5 91.0792] 'class' : [b'Second' b'Third' b'First' b'First'] 'deck' : [b'unknown' b'unknown' b'E' b'B'] 'embark_town' : [b'Southampton' b'Southampton' b'Southampton' b'Cherbourg'] 'alone' : [b'n' b'y' b'y' b'n']
You can use the select_columns argument if you only need a subset of columns.
titanic_batches = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=4,
label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print(" {!r:20s}: {}".format(key, value))
'survived': [1 1 1 0] 'fare' : [15.9 55. 22.3583 7.25 ] 'class' : [b'Third' b'First' b'Third' b'Third']
There is also a lower-level experimental.CsvDataset class which provides finer-grained control. It does not support column type inference. Instead you must specify the type of each column.
titanic_types = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types, header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n'] [1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n'] [1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y'] [1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n'] [0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y'] [0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n'] [1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n'] [1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n'] [1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n'] [0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']
If some columns are empty, this low-level interface allows you to provide default values instead of column types.
%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv
# Creates a dataset that reads all of the records from the CSV file written above,
# whose four columns may have missing values.
record_defaults = [999, 999, 999, 999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset shapes: (4,), types: tf.int32>
for line in dataset:
  print(line.numpy())
[1 2 3 4] [999 2 3 4] [ 1 999 3 4] [ 1 2 999 4] [ 1 2 3 999] [999 999 999 999]
By default, a CsvDataset yields every column of every line of the file, which may not be desirable, for example if the file starts with a header line that should be ignored, or if some columns are not required in the input. These lines and fields can be removed with the header and select_cols arguments respectively.
# Creates a dataset that reads records from the CSV file, extracting only
# the second and fourth columns (select_cols indices are zero-based).
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset shapes: (2,), types: tf.int32>
for line in dataset:
  print(line.numpy())
[2 4] [2 4] [999 4] [2 4] [ 2 999] [999 999]
Consuming sets of files
There are many datasets distributed as a set of files, where each file is an example.
flowers_root = tf.keras.utils.get_file(
'flower_photos',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
untar=True)
flowers_root = pathlib.Path(flowers_root)
The root directory contains a directory for each class:
for item in flowers_root.glob("*"):
print(item.name)
daisy LICENSE.txt roses dandelion sunflowers tulips
The files in each class directory are examples:
list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))
for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/roses/6950609394_c53b8c6ac0_m.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/7132676187_7a4265b16f_n.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/daisy/5869147563_66fb88119d.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/14921668662_3ffc5b9db3_n.jpg' b'/home/kbuilder/.keras/datasets/flower_photos/daisy/4669117051_ce61e91b76.jpg'
We can read the data using the tf.io.read_file function and extract the label from the path, returning (image, label) pairs:
def process_path(file_path):
  label = tf.strings.split(file_path, '/')[-2]
  return tf.io.read_file(file_path), label
labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xe2\x0cXICC_PROFILE\x00\x01\x01\x00\x00\x0cHLino\x02\x10\x00\x00mntrRGB XYZ \x07\xce\x00\x02\x00\t\x00\x06\x001\x00\x00acspMSFT\x00\x00\x00\x00IEC sRGB\x00\x00\x00\x00\x00\x00' b'roses'
Batching dataset elements
Simple batching
The simplest form of batching stacks n consecutive elements of a dataset into a single element. The Dataset.batch() transformation does exactly this, with the same constraints as the tf.stack() operator, applied to each component of the elements: i.e. for each component i, all elements must have a tensor of the exact same shape.
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])] [array([4, 5, 6, 7]), array([-4, -5, -6, -7])] [array([ 8, 9, 10, 11]), array([ -8, -9, -10, -11])] [array([12, 13, 14, 15]), array([-12, -13, -14, -15])]
While tf.data tries to propagate shape information, the default settings of Dataset.batch result in an unknown batch size because the last batch may not be full. Note the Nones in the shape:
batched_dataset
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>
Use the drop_remainder argument to ignore that last batch, and get full shape propagation:
batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
<BatchDataset shapes: ((7,), (7,)), types: (tf.int64, tf.int64)>
Batching tensors with padding
The above recipe works for tensors that all have the same size. However, many models (e.g. sequence models) work with input data that can have varying size (e.g. sequences of different lengths). To handle this case, the Dataset.padded_batch() transformation enables you to batch tensors of different shape by specifying one or more dimensions in which they may be padded.
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))
for batch in dataset.take(2):
  print(batch.numpy())
  print()
[[0 0 0] [1 0 0] [2 2 0] [3 3 3]] [[4 4 4 4 0 0 0] [5 5 5 5 5 0 0] [6 6 6 6 6 6 0] [7 7 7 7 7 7 7]]
The Dataset.padded_batch() transformation allows you to set different padding for each dimension of each component, and it may be variable-length (signified by None in the example above) or constant-length. It is also possible to override the padding value, which defaults to 0.
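For instance, a minimal sketch (assuming the same dataset as above) that pads with -1 instead of the default 0 via the padding_values argument:
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,),
                               padding_values=tf.constant(-1, dtype=tf.int64))
for batch in dataset.take(1):
  print(batch.numpy())  # the padded positions now hold -1 rather than 0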
Training workflows
Processing multiple epochs
The tf.data API offers two main ways to process multiple epochs of the same data.
The simplest way to iterate over a dataset in multiple epochs is to use the Dataset.repeat() transformation. First we create a dataset of titanic data:
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')
Applying the Dataset.repeat() transformation with no arguments will repeat the input indefinitely.
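A quick sketch (not in the original) confirms this: without a take() to stop it, the loop below would never end.
for batch in titanic_lines.repeat().batch(128).take(3):
  print(batch.shape)  # three (128,) batches drawn from the endlessly repeated lines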
The Dataset.repeat transformation concatenates its arguments without signaling the end of one epoch and the beginning of the next epoch. Because of this a Dataset.batch applied after Dataset.repeat will yield batches that straddle epoch boundaries:
titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)
If you need clear epoch separation, put Dataset.batch before the repeat:
titanic_batches = titanic_lines.batch(128).repeat(3)
plot_batch_sizes(titanic_batches)
If you would like to perform a custom computation (e.g. to collect statistics) at the end of each epoch then it's simplest to restart the dataset iteration on each epoch:
epochs = 3
dataset = titanic_lines.batch(128)
for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,) (128,) (128,) (128,) (116,) End of epoch: 0 (128,) (128,) (128,) (128,) (116,) End of epoch: 1 (128,) (128,) (128,) (128,) (116,) End of epoch: 2
Randomly shuffling input data
The Dataset.shuffle() transformation maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
Add an index to the dataset so you can see the effect:
lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()
dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.string)>
Since the buffer_size is 100, and the batch size is 20, the first batch contains no elements with an index over 120.
n,line_batch = next(iter(dataset))
print(n.numpy())
[ 68 60 46 94 98 102 0 84 1 57 79 90 100 51 24 2 8 82 48 77]
As with Dataset.batch the order relative to Dataset.repeat matters.
Dataset.shuffle doesn't signal the end of an epoch until the shuffle buffer is empty. So a shuffle placed before a repeat will show every element of one epoch before moving to the next:
dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)
print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary: [613 557 583 464 497 554 289 589 623 438] [ 48 600 605 551 574 515 282 549 556 519] [500 546 421 612 594 430 614 76] [ 90 70 59 1 64 51 69 102 91 3] [ 54 105 49 95 32 50 110 61 67 47]
shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7fd0902fb470>
But a repeat before a shuffle mixes the epoch boundaries together:
dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)
print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary: [ 4 507 409 558 557 247 537 562 162 24] [524 28 612 21 499 622 452 387 16 588] [403 589 306 31 7 571 12 574 582 578] [496 15 43 54 498 466 20 268 56 526] [594 624 47 41 51 32 30 34 35 520] [618 36 532 1 617 57 368 58 580 605] [ 2 25 66 568 22 37 405 461 538 74] [ 6 42 10 40 29 11 13 97 606 80] [512 59 68 82 69 98 77 62 76 300] [ 18 73 87 113 3 100 70 14 49 614] [625 94 495 95 110 33 595 602 101 105] [ 9 0 107 86 112 60 122 527 133 109] [ 46 26 81 27 490 627 83 53 413 148] [ 64 141 142 38 129 517 154 143 131 63] [ 17 149 609 45 139 71 145 443 88 457]
repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7fd090342da0>
Preprocessing data
The Dataset.map(f) transformation produces a new dataset by applying a given function f to each element of the input dataset. It is based on the map() function that is commonly applied to lists (and other structures) in functional programming languages. The function f takes the tf.Tensor objects that represent a single element in the input, and returns the tf.Tensor objects that will represent a single element in the new dataset. Its implementation uses standard TensorFlow operations to transform one element into another.
This section covers common examples of how to use Dataset.map().
Decoding image data and resizing it
When training a neural network on real-world image data, it is often necessary to convert images of different sizes to a common size, so that they may be batched into a fixed size.
Rebuild the flower filenames dataset:
list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))
Write a function that manipulates the dataset elements.
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label
Test that it works.
file_path = next(iter(list_ds))
image, label = parse_image(file_path)
def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')
show(image, label)
Map it over the dataset.
images_ds = list_ds.map(parse_image)
for image, label in images_ds.take(2):
  show(image, label)
Applying arbitrary Python logic
For performance reasons, we encourage you to use TensorFlow operations for preprocessing your data whenever possible. However, it is sometimes useful to call external Python libraries when parsing your input data. You can use the tf.py_function() operation in a Dataset.map() transformation.
For example, if you want to apply a random rotation, the tf.image module only has tf.image.rot90, which is not very useful for image augmentation.
To demonstrate tf.py_function, try using the scipy.ndimage.rotate function instead:
import scipy.ndimage as ndimage
def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
To use this function with Dataset.map the same caveats apply as with Dataset.from_generator: you need to describe the return shapes and types when you apply the function:
def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)
for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Parsing tf.Example protocol buffer messages
Many input pipelines extract tf.train.Example protocol buffer messages from a TFRecord format. Each tf.train.Example record contains one or more "features", and the input pipeline typically converts these features into tensors.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
You can work with tf.train.Example protos outside of a tf.data.Dataset to understand the data:
raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())
feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])
raw_example = next(iter(dataset))
def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon' b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...
decoded = dataset.map(tf_parse)
decoded
<MapDataset shapes: ((), ()), types: (tf.string, tf.string)>
image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])
Time series windowing
For an end-to-end time series example see: Time series forecasting.
Time series data is often organized with the time axis intact.
Use a simple Dataset.range to demonstrate:
range_ds = tf.data.Dataset.range(100000)
Typically, models based on this sort of data will want a contiguous time slice.
The simplest approach would be to batch the data:
Using batch
batches = range_ds.batch(10, drop_remainder=True)
for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9] [10 11 12 13 14 15 16 17 18 19] [20 21 22 23 24 25 26 27 28 29] [30 31 32 33 34 35 36 37 38 39] [40 41 42 43 44 45 46 47 48 49]
Or to make dense predictions one step into the future, you might shift the features and labels by one step relative to each other:
def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8] => [1 2 3 4 5 6 7 8 9] [10 11 12 13 14 15 16 17 18] => [11 12 13 14 15 16 17 18 19] [20 21 22 23 24 25 26 27 28] => [21 22 23 24 25 26 27 28 29]
To predict a whole window instead of a fixed offset you can split the batches into two parts:
batches = range_ds.batch(15, drop_remainder=True)
def label_next_5_steps(batch):
  return (batch[:-5],  # Inputs: all but the last 5 steps
          batch[-5:])  # Labels: the last 5 steps

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9] => [10 11 12 13 14] [15 16 17 18 19 20 21 22 23 24] => [25 26 27 28 29] [30 31 32 33 34 35 36 37 38 39] => [40 41 42 43 44]
To allow some overlap between the features of one batch and the labels of another, use Dataset.zip:
feature_length = 10
label_length = 5
features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:-5])
predict_5_steps = tf.data.Dataset.zip((features, labels))
for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9] => [10 11 12 13 14] [10 11 12 13 14 15 16 17 18 19] => [20 21 22 23 24] [20 21 22 23 24 25 26 27 28 29] => [30 31 32 33 34]
Using window
While using Dataset.batch works, there are situations where you may need finer control. The Dataset.window method gives you complete control, but requires some care: it returns a Dataset of Datasets. See Dataset structure for details.
window_size = 5
windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset shapes: (), types: tf.int64> <_VariantDataset shapes: (), types: tf.int64> <_VariantDataset shapes: (), types: tf.int64> <_VariantDataset shapes: (), types: tf.int64> <_VariantDataset shapes: (), types: tf.int64>
The Dataset.flat_map method can take a dataset of datasets and flatten it into a single dataset:
for x in windows.flat_map(lambda x: x).take(30):
  print(x.numpy(), end=' ')
WARNING:tensorflow:Entity <function <lambda> at 0x7fd0747c3ea0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Failed to parse source code of <function <lambda> at 0x7fd0747c3ea0>, which Python reported as: for x in windows.flat_map(lambda x: x).take(30): If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement. WARNING: Entity <function <lambda> at 0x7fd0747c3ea0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: Failed to parse source code of <function <lambda> at 0x7fd0747c3ea0>, which Python reported as: for x in windows.flat_map(lambda x: x).take(30): If this is a lambda function, the error may be avoided by creating the lambda in a standalone statement. 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9
In nearly all cases, you will want to .batch the dataset first:
def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4] [1 2 3 4 5] [2 3 4 5 6] [3 4 5 6 7] [4 5 6 7 8]
Now, you can see that the shift argument controls how much each window moves over.
Putting this together you might write this function:
def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows
ds = make_window_dataset(range_ds, window_size=10, shift=5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0 3 6 9 12 15 18 21 24 27] [ 5 8 11 14 17 20 23 26 29 32] [10 13 16 19 22 25 28 31 34 37] [15 18 21 24 27 30 33 36 39 42] [20 23 26 29 32 35 38 41 44 47] [25 28 31 34 37 40 43 46 49 52] [30 33 36 39 42 45 48 51 54 57] [35 38 41 44 47 50 53 56 59 62] [40 43 46 49 52 55 58 61 64 67] [45 48 51 54 57 60 63 66 69 72]
Then it's easy to extract labels, as before:
dense_labels_ds = ds.map(dense_1_step)
for inputs, labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0 3 6 9 12 15 18 21 24] => [ 3 6 9 12 15 18 21 24 27] [ 5 8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32] [10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]
Resampling
When working with a dataset that is very class-imbalanced, you may want to resample the dataset. tf.data provides two methods to do this. The credit card fraud dataset is a good example of this sort of problem.
zip_path = tf.keras.utils.get_file(
origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
fname='creditcard.zip',
extract=True)
csv_path = zip_path.replace('.zip', '.csv')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip 69156864/69155632 [==============================] - 2s 0us/step
creditcard_ds = tf.data.experimental.make_csv_dataset(
csv_path, batch_size=1024, label_name="Class",
# Set the column types: 30 floats and an int.
column_defaults=[float()]*30+[int()])
Now, check the distribution of classes; it is highly skewed:
def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
initial_state={'class_0': 0, 'class_1': 0},
reduce_func = count)
counts = np.array([counts['class_0'].numpy(),
counts['class_1'].numpy()]).astype(np.float32)
fractions = counts/counts.sum()
print(fractions)
[0.9955 0.0045]
A common approach to training with an imbalanced dataset is to balance it. tf.data includes a few methods which enable this workflow:
Datasets sampling
One approach to resampling a dataset is to use sample_from_datasets. This is more applicable when you have a separate data.Dataset for each class.
Here, just use filter to generate them from the credit card fraud data:
negative_ds = (
creditcard_ds
.unbatch()
.filter(lambda features, label: label==0)
.repeat())
positive_ds = (
creditcard_ds
.unbatch()
.filter(lambda features, label: label==1)
.repeat())
WARNING:tensorflow:Entity <function <lambda> at 0x7fd0747437b8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: expected exactly one node node, found [] WARNING: Entity <function <lambda> at 0x7fd0747437b8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: expected exactly one node node, found [] WARNING:tensorflow:Entity <function <lambda> at 0x7fd074604598> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: expected exactly one node node, found [] WARNING: Entity <function <lambda> at 0x7fd074604598> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: expected exactly one node node, found []
for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]
To use tf.data.experimental.sample_from_datasets pass the datasets, and the weight for each:
balanced_ds = tf.data.experimental.sample_from_datasets(
[negative_ds, positive_ds], [0.5, 0.5]).batch(10)
Now the dataset produces examples of each class with 50/50 probability:
for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[1 0 0 1 1 1 1 0 0 0] [0 1 0 0 1 0 0 0 0 0] [1 1 1 0 0 0 0 1 0 0] [0 1 1 1 1 0 0 1 0 0] [0 1 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 0 0] [0 1 0 1 0 0 1 0 0 0] [0 0 0 0 1 0 0 1 0 0] [1 0 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 1 1]
Rejection resampling
One problem with the above experimental.sample_from_datasets approach is that it needs a separate tf.data.Dataset per class. Using Dataset.filter works, but results in all the data being loaded twice.
The data.experimental.rejection_resample function can be applied to a dataset to rebalance it, while only loading it once. Elements will be dropped from the dataset to achieve balance.
data.experimental.rejection_resample takes a class_func argument. This class_func is applied to each dataset element, and is used to determine which class an example belongs to for the purposes of balancing.
The elements of creditcard_ds are already (features, label) pairs. So the class_func just needs to return those labels:
def class_func(features, label):
  return label
The resampler also needs a target distribution, and optionally an initial distribution estimate:
resampler = tf.data.experimental.rejection_resample(
class_func, target_dist=[0.5, 0.5], initial_dist=fractions)
The resampler deals with individual examples, so you must unbatch the dataset before applying the resampler:
resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/data/experimental/ops/resampling.py:151: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20. Instructions for updating: Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:
The resampler creates (class, example) pairs from the output of the class_func. In this case, the example was already a (feature, label) pair, so use map to drop the extra copy of the labels:
balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)
Now the dataset produces examples of each class with 50/50 probability:
for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[1 1 0 1 1 0 0 1 1 0] [1 1 0 0 0 1 1 1 0 0] [0 0 1 1 1 1 0 1 1 1] [0 1 0 0 0 0 1 1 0 1] [1 1 0 0 0 1 0 1 1 0] [0 1 1 0 1 1 0 1 1 1] [0 1 1 0 1 0 1 1 1 1] [1 1 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1 1] [0 1 0 0 1 0 1 0 0 1]
Using high-level APIs
tf.keras
The tf.keras API simplifies many aspects of creating and executing machine learning models. Its .fit(), .evaluate(), and .predict() APIs support datasets as inputs. Here is a quick dataset and model setup:
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)
model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
Passing a dataset of (feature, label) pairs is all that's needed for Model.fit and Model.evaluate:
model.fit(fmnist_train_ds, epochs=2)
WARNING:tensorflow:Layer sequential is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx. If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2. To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor. Epoch 1/2 1875/1875 [==============================] - 6s 3ms/step - loss: 0.6054 - accuracy: 0.7944 Epoch 2/2 1875/1875 [==============================] - 5s 2ms/step - loss: 0.4616 - accuracy: 0.8429 <tensorflow.python.keras.callbacks.History at 0x7fd0743fc400>
If you pass an infinite dataset, for example by calling Dataset.repeat(), you just need to also pass the steps_per_epoch argument:
model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
Train for 20 steps Epoch 1/2 20/20 [==============================] - 0s 18ms/step - loss: 0.4281 - accuracy: 0.8609 Epoch 2/2 20/20 [==============================] - 0s 2ms/step - loss: 0.4131 - accuracy: 0.8703 <tensorflow.python.keras.callbacks.History at 0x7fd090092128>
For evaluation you can pass the number of evaluation steps:
loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4363 - accuracy: 0.8501 Loss : 0.4362662051320076 Accuracy : 0.85008335
For long datasets, set the number of steps to evaluate:
loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
10/10 [==============================] - 0s 4ms/step - loss: 0.4099 - accuracy: 0.8750 Loss : 0.4099284529685974 Accuracy : 0.875
The labels are not required when calling Model.predict.
predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
(320, 10)
But the labels are ignored if you do pass a dataset containing them:
result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
(320, 10)
tf.estimator
To use a Dataset in the input_fn of a tf.estimator.Estimator, simply return the Dataset from the input_fn and the framework will take care of consuming its elements for you. For example:
import tensorflow_datasets as tfds
def train_input_fn():
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
age = tf.feature_column.numeric_column('age')
import tempfile
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpaaj9jgyb', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd05016ce80>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/feature_column/feature_column_v2.py:518: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version. Instructions for updating: Please use `layer.add_weight` method instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:308: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.cast` instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/keras/optimizer_v2/ftrl.py:143: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpaaj9jgyb/model.ckpt. INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpaaj9jgyb/model.ckpt. INFO:tensorflow:Loss for final step: 0.55556.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2019-11-14T02:22:53Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpaaj9jgyb/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Finished evaluation at 2019-11-14-02:22:54 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.703125, accuracy_baseline = 0.584375, auc = 0.7589563, auc_precision_recall = 0.6723181, average_loss = 0.5924618, global_step = 100, label/mean = 0.415625, loss = 0.5924618, precision = 0.7638889, prediction/mean = 0.35414353, recall = 0.41353384 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpaaj9jgyb/model.ckpt-100 accuracy : 0.703125 accuracy_baseline : 0.584375 auc : 0.7589563 auc_precision_recall : 0.6723181 average_loss : 0.5924618 label/mean : 0.415625 loss : 0.5924618 precision : 0.7638889 prediction/mean : 0.35414353 recall : 0.41353384 global_step : 100
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpaaj9jgyb/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-1.3411] logistic : [0.2073] probabilities : [0.7927 0.2073] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1']