
Load images with tf.data


This tutorial provides a simple example of how to load an image dataset using tf.data.

The dataset used in this example is distributed as directories of images, with one class of image per directory.

Setup

from __future__ import absolute_import, division, print_function, unicode_literals

!pip install -q tensorflow==2.0.0-beta1
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE

Download and inspect the dataset

Retrieve the images

Before you start any training, you will need a set of images to teach the network about the new classes you want to recognize. An archive of creative-commons licensed flower photos is available to use initially:

import pathlib
data_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_root = pathlib.Path(data_root_orig)
print(data_root)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 5s 0us/step
/home/kbuilder/.keras/datasets/flower_photos

After downloading 218MB, you should now have a copy of the flower photos available:

for item in data_root.iterdir():
  print(item)
/home/kbuilder/.keras/datasets/flower_photos/LICENSE.txt
/home/kbuilder/.keras/datasets/flower_photos/daisy
/home/kbuilder/.keras/datasets/flower_photos/roses
/home/kbuilder/.keras/datasets/flower_photos/sunflowers
/home/kbuilder/.keras/datasets/flower_photos/dandelion
/home/kbuilder/.keras/datasets/flower_photos/tulips
import random
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)

image_count = len(all_image_paths)
image_count
3670
all_image_paths[:10]
['/home/kbuilder/.keras/datasets/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/tulips/3002863623_cd83d6e634.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/sunflowers/3062794421_295f8c2c4e.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/5600240736_4a90c10579_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/roses/3422228549_f147d6e642.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/16970837587_4a9d8500d7.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/14306875733_61d71c64c0_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/dandelion/155646858_9a8b5e8fc8.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/tulips/9030467406_05e93ff171_n.jpg',
 '/home/kbuilder/.keras/datasets/flower_photos/roses/537207677_f96a0507bb.jpg']

Inspect the images

Now let's have a quick look at a couple of the images, so you know what you are dealing with:

import os
attributions = (data_root/"LICENSE.txt").open(encoding='utf-8').readlines()[4:]
attributions = [line.split(' CC-BY') for line in attributions]
attributions = dict(attributions)
import IPython.display as display

def caption_image(image_path):
    image_rel = pathlib.Path(image_path).relative_to(data_root)
    return "Image (CC BY 2.0) " + ' - '.join(attributions[str(image_rel)].split(' - ')[:-1])
for n in range(3):
  image_path = random.choice(all_image_paths)
  display.display(display.Image(image_path))
  print(caption_image(image_path))
  print()

[Three sample flower images with captions: Image (CC BY 2.0) by John Liu; Image (CC BY 2.0) by OliBac; Image (CC BY 2.0) by Ron Cogswell]

Determine the label for each image

List the available labels:

label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
label_names
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

Assign an index to each label:

label_to_index = dict((name, index) for index, name in enumerate(label_names))
label_to_index
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

Create a list of every file, and its label index:

all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]

print("First 10 labels indices: ", all_image_labels[:10])
First 10 labels indices:  [3, 4, 3, 1, 2, 1, 1, 1, 4, 2]
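The label lookup relies only on the name of each file's parent directory. A minimal pure-Python sketch of the same mapping and its inverse, run on a few paths copied (relative, without the download prefix) from the listing above, needs no TensorFlow at all:

```python
import pathlib

# Class names, sorted as in label_names above.
label_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
label_to_index = {name: index for index, name in enumerate(label_names)}

# Relative paths taken from the listing of all_image_paths above.
paths = [
    'flower_photos/sunflowers/9448615838_04078d09bf_n.jpg',
    'flower_photos/tulips/3002863623_cd83d6e634.jpg',
    'flower_photos/dandelion/5600240736_4a90c10579_n.jpg',
    'flower_photos/roses/3422228549_f147d6e642.jpg',
]

# The label is the name of the directory that contains the file.
labels = [label_to_index[pathlib.PurePath(p).parent.name] for p in paths]

# Inverting the mapping recovers the class name from an index.
index_to_label = dict(enumerate(label_names))
names = [index_to_label[i] for i in labels]

print(labels)  # [3, 4, 1, 2]
print(names)   # ['sunflowers', 'tulips', 'dandelion', 'roses']
```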

Load and format the images

TensorFlow includes all the tools you need to load and process images:

img_path = all_image_paths[0]
img_path
'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/9448615838_04078d09bf_n.jpg'

Here is the raw data:

img_raw = tf.io.read_file(img_path)
print(repr(img_raw)[:100]+"...")
<tf.Tensor: id=1, shape=(), dtype=string, numpy=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x...

Decode it into an image tensor:

img_tensor = tf.image.decode_image(img_raw)

print(img_tensor.shape)
print(img_tensor.dtype)
(274, 320, 3)
<dtype: 'uint8'>

Resize it for your model:

img_final = tf.image.resize(img_tensor, [192, 192])
img_final = img_final/255.0
print(img_final.shape)
print(img_final.numpy().min())
print(img_final.numpy().max())
(192, 192, 3)
0.0
1.0

Wrap these up in simple functions for later use.

def preprocess_image(image):
  image = tf.image.decode_jpeg(image, channels=3)
  image = tf.image.resize(image, [192, 192])
  image /= 255.0  # normalize to [0,1] range

  return image

def load_and_preprocess_image(path):
  image = tf.io.read_file(path)
  return preprocess_image(image)
import matplotlib.pyplot as plt

image_path = all_image_paths[0]
label = all_image_labels[0]

plt.imshow(load_and_preprocess_image(image_path))
plt.grid(False)
plt.xlabel(caption_image(image_path))
plt.title(label_names[label].title())
plt.show()

Build a tf.data.Dataset

A dataset of images

The easiest way to build a tf.data.Dataset is using the from_tensor_slices method.

Slicing the array of strings results in a dataset of strings:

path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

The shapes and types describe the content of each item in the dataset. In this case, each item is a scalar binary string:

print(path_ds)
<TensorSliceDataset shapes: (), types: tf.string>

Now create a new dataset that loads and formats images on the fly by mapping load_and_preprocess_image over the dataset of paths.

image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

plt.figure(figsize=(8,8))
for n, image in enumerate(image_ds.take(4)):
  plt.subplot(2,2,n+1)
  plt.imshow(image)
  plt.grid(False)
  plt.xticks([])
  plt.yticks([])
  plt.xlabel(caption_image(all_image_paths[n]))
plt.show()


A dataset of (image, label) pairs

Using the same from_tensor_slices method you can build a dataset of labels:

label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
for label in label_ds.take(10):
  print(label_names[label.numpy()])
sunflowers
tulips
sunflowers
dandelion
roses
dandelion
dandelion
dandelion
tulips
roses

Since the datasets are in the same order, you can zip them together to get a dataset of (image, label) pairs:

image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

The new dataset's shapes and types are tuples of shapes and types as well, describing each field:

print(image_label_ds)
<ZipDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int64)>
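
Note: When you have arrays like all_image_labels and all_image_paths, an alternative to tf.data.Dataset.zip is to slice the pair of arrays together:

ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))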

# The tuples are unpacked into the positional arguments of the mapped function
def load_and_preprocess_from_path_label(path, label):
  return load_and_preprocess_image(path), label

image_label_ds = ds.map(load_and_preprocess_from_path_label)
image_label_ds
<MapDataset shapes: ((192, 192, 3), ()), types: (tf.float32, tf.int32)>
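Both constructions give the same (image, label) pairing only because the paths and labels share one order. A pure-Python analogy (plain lists standing in for datasets; tensor_slices is a hypothetical helper, not a TensorFlow call) shows that zipping two sliced sequences and slicing a tuple of sequences produce identical pairs, and that reordering one side silently mislabels elements:

```python
# Illustrative stand-ins for all_image_paths and all_image_labels.
paths = ['a/sunflowers/1.jpg', 'b/tulips/2.jpg', 'c/roses/3.jpg']
labels = [3, 4, 2]

def tensor_slices(structure):
    """Rough analogy of Dataset.from_tensor_slices: slice along the first axis.

    A tuple of equal-length sequences yields one tuple per index.
    """
    if isinstance(structure, tuple):
        return list(zip(*structure))
    return list(structure)

# Approach 1: two separate "datasets", zipped together.
zipped = list(zip(tensor_slices(paths), tensor_slices(labels)))

# Approach 2: slice the pair of arrays in one step.
sliced = tensor_slices((paths, labels))

assert zipped == sliced

# Reordering only one side (e.g. shuffling it independently) breaks the pairing.
mismatched = list(zip(tensor_slices(paths), tensor_slices(labels[::-1])))
print(mismatched)
```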

Basic methods for training

To train a model with this dataset you will want the data:

  • To be well shuffled.
  • To be batched.
  • To repeat forever.
  • Batches to be available as soon as possible.

These features can be added easily using the tf.data API.

BATCH_SIZE = 32

# Setting a shuffle buffer size as large as the dataset ensures that the data is
# completely shuffled.
ds = image_label_ds.shuffle(buffer_size=image_count)
ds = ds.repeat()
ds = ds.batch(BATCH_SIZE)
# `prefetch` lets the dataset fetch batches in the background while the model is training.
ds = ds.prefetch(buffer_size=AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
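prefetch overlaps producing the next batches with consuming the current one. A minimal pure-Python analogy, using a background thread that fills a bounded queue.Queue (prefetch here is a hypothetical helper, not tf.data's implementation), shows the idea: the consumer receives batches in their original order while up to buffer_size of them are prepared ahead of time.

```python
import queue
import threading
import time

_DONE = object()  # marks the end of the input

def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, producing up to `buffer_size` ahead on a thread.

    Simplifications in this sketch: exceptions raised by the producer are not
    forwarded to the consumer, and abandoning the generator early leaves the
    daemon thread blocked on a full queue.
    """
    q = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in iterable:
                q.put(item)  # blocks while the buffer is full
        finally:
            q.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is _DONE:
            return
        yield item

def slow_batches(n):
    """Stand-in for an expensive input pipeline (e.g. decoding images)."""
    for i in range(n):
        time.sleep(0.01)
        yield [i] * 4

result = list(prefetch(slow_batches(5), buffer_size=2))
print(result)
```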

There are a few things to note here:

  1. The order is important.

    • A .shuffle after a .repeat would shuffle items across epoch boundaries (some items will be seen twice before others are seen at all).
    • A .shuffle after a .batch would shuffle the order of the batches, but not shuffle the items across batches.
  2. You use a buffer_size the same size as the dataset for a full shuffle. Up to the dataset size, large values provide better randomization, but use more memory.

  3. The shuffle buffer is filled before any elements are pulled from it. So a large buffer_size may cause a delay when your Dataset is starting.

  4. The shuffled dataset doesn't report the end of a dataset until the shuffle buffer is completely empty. When .repeat restarts the Dataset, it has to wait again for the shuffle buffer to fill.

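This last point can be addressed by using the tf.data.Dataset.apply method with the fused tf.data.experimental.shuffle_and_repeat function. Note that, as the warning in the output shows, this function is deprecated in TensorFlow 2.0 in favor of calling .shuffle followed by .repeat, which static tf.data optimizations fuse automatically: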

ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)
ds
WARNING: Logging before flag parsing goes to stderr.
W0713 00:53:03.145668 139726670800640 deprecation.py:323] From <ipython-input-31-4dc713bd4d84>:2: shuffle_and_repeat (from tensorflow.python.data.experimental.ops.shuffle_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.shuffle(buffer_size, seed)` followed by `tf.data.Dataset.repeat(count)`. Static tf.data optimizations will take care of using the fused implementation.

<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
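Point 1 of the notes above (shuffle before repeat) can be checked with a small simulation. The sketch below approximates a buffered shuffle in plain Python (an illustration, not tf.data's implementation) and compares shuffle-then-repeat with repeat-then-shuffle over several seeds: in the first ordering every block of len(data) elements is exactly one epoch, while in the second, elements from the next epoch leak into earlier blocks.

```python
import random

def buffered_shuffle(items, buffer_size, rng):
    """Approximate a buffered shuffle: keep `buffer_size` items, emit a random one."""
    buffer, out = [], []
    for item in items:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            out.append(buffer.pop(rng.randrange(len(buffer))))
    rng.shuffle(buffer)  # drain whatever is left once the input ends
    return out + buffer

def epochs_are_intact(stream, epoch_size):
    """True if every consecutive block of `epoch_size` items is one full epoch."""
    return all(sorted(stream[i:i + epoch_size]) == list(range(epoch_size))
               for i in range(0, len(stream), epoch_size))

data = list(range(10))
n_epochs = 3

def shuffle_then_repeat(seed):
    rng = random.Random(seed)
    return [x for _ in range(n_epochs)
            for x in buffered_shuffle(data, len(data), rng)]

def repeat_then_shuffle(seed):
    rng = random.Random(seed)
    return buffered_shuffle(data * n_epochs, len(data), rng)

seeds = range(20)
print(all(epochs_are_intact(shuffle_then_repeat(s), len(data)) for s in seeds))
print(any(not epochs_are_intact(repeat_then_shuffle(s), len(data)) for s in seeds))
```

With a buffer as large as the dataset, shuffling before repeating yields a full permutation per epoch, which is what the pipeline above relies on.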

Pipe the dataset to a model

Fetch a copy of MobileNet v2 from tf.keras.applications.

This will be used for a simple transfer learning example.

Set the MobileNet weights to be non-trainable:

mobile_net = tf.keras.applications.MobileNetV2(input_shape=(192, 192, 3), include_top=False)
mobile_net.trainable=False
Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_192_no_top.h5
9412608/9406464 [==============================] - 1s 0us/step

This model expects its input to be normalized to the [-1,1] range:

help(tf.keras.applications.mobilenet_v2.preprocess_input)
...
This function applies the "Inception" preprocessing which converts
the RGB values from [0, 255] to [-1, 1]
...

Before you pass the input to the MobileNet model, you need to convert it from the [0,1] range to [-1,1]:

def change_range(image,label):
  return 2*image-1, label

keras_ds = ds.map(change_range)
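The docstring says preprocess_input maps [0, 255] to [-1, 1]; the affine map that sends 0 to -1 and 255 to 1 is x / 127.5 - 1. Since preprocess_image already divides by 255, change_range only has to apply 2x - 1 to values in [0, 1]. A quick plain-Python check (the scalar helper names are hypothetical) that the two routes agree:

```python
import math

def to_unit_range(pixel):
    """[0, 255] -> [0, 1], as preprocess_image does with image /= 255.0."""
    return pixel / 255.0

def change_range_scalar(x):
    """[0, 1] -> [-1, 1], the same affine map as change_range."""
    return 2 * x - 1

def direct_range(pixel):
    """[0, 255] -> [-1, 1] in one step: the map sending 0 to -1 and 255 to 1."""
    return pixel / 127.5 - 1.0

# Both routes give the same value for a spread of pixel intensities.
for pixel in (0, 1, 64, 127.5, 200, 255):
    assert math.isclose(change_range_scalar(to_unit_range(pixel)),
                        direct_range(pixel), abs_tol=1e-12)
```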

MobileNet returns a 6x6 spatial grid of features for each image.

Pass it a batch of images to see:

# The dataset may take a few seconds to start, as it fills its shuffle buffer.
image_batch, label_batch = next(iter(keras_ds))
feature_map_batch = mobile_net(image_batch)
print(feature_map_batch.shape)
(32, 6, 6, 1280)
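The 6x6 grid follows from the input size. Assuming MobileNetV2 without its top reduces spatial resolution by a total stride of 32 (five stride-2 stages), a 192x192 input yields 192 / 32 = 6 positions per side, with the 1280 channels seen in the output above. A quick check of that arithmetic, including the more familiar 224x224 case (feature_grid is a hypothetical helper):

```python
def feature_grid(input_size, total_stride=32, channels=1280):
    """Output shape (H, W, C) of the backbone under the stated stride assumption."""
    # This sketch only handles sizes divisible by the stride, which covers the
    # standard MobileNetV2 input sizes used here.
    if input_size % total_stride:
        raise ValueError("input_size must be a multiple of total_stride")
    side = input_size // total_stride
    return (side, side, channels)

batch_size = 32
print((batch_size,) + feature_grid(192))  # (32, 6, 6, 1280)
print((batch_size,) + feature_grid(224))  # (32, 7, 7, 1280)
```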

Build a model wrapped around MobileNet and use tf.keras.layers.GlobalAveragePooling2D to average over those space dimensions before the output tf.keras.layers.Dense layer:

model = tf.keras.Sequential([
  mobile_net,
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dense(len(label_names), activation = 'softmax')])
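GlobalAveragePooling2D collapses each image's 6x6x1280 feature map into a single 1280-value vector by averaging every channel over the spatial positions, so the Dense layer always sees a fixed-size input. A plain-Python sketch of that reduction for one image (no batch axis), on a toy 2x2x3 map:

```python
def global_average_pool_2d(feature_map):
    """Average an H x W x C nested list over H and W, giving a length-C vector."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                sums[c] += value
    return [s / (height * width) for s in sums]

# A tiny 2 x 2 x 3 "feature map" in place of MobileNet's 6 x 6 x 1280 output.
fmap = [[[1.0, 0.0, 2.0], [3.0, 0.0, 2.0]],
        [[5.0, 4.0, 2.0], [7.0, 0.0, 2.0]]]
print(global_average_pool_2d(fmap))  # [4.0, 1.0, 2.0]
```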

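Now it produces outputs of the expected shape. Because the final layer uses a softmax activation, each row is a probability distribution over the five classes rather than raw logits:

prob_batch = model(image_batch).numpy()

print("min probability:", prob_batch.min())
print("max probability:", prob_batch.max())
print()

print("Shape:", prob_batch.shape)
min probability: 0.012025281
max probability: 0.6102859

Shape: (32, 5)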

Compile the model to describe the training procedure:

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=["accuracy"])

There are 2 trainable variables - the Dense weights and bias:

len(model.trainable_variables)
2
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_192 (Model) (None, 6, 6, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
=================================================================
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________
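The parameter counts in the summary can be checked by hand: a Dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases, the pooling layer has none, and only the Dense layer is trainable here.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per unit."""
    return n_in * n_out + n_out

backbone_params = 2_257_984          # mobilenetv2_1.00_192, from the summary (frozen)
head_params = dense_params(1280, 5)  # pooled features -> 5 flower classes

print(head_params)                    # 6405
print(backbone_params + head_params)  # 2264389
```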

You are ready to train the model.

Note that for demonstration purposes you will only run 3 steps per epoch, but normally you would specify the real number of steps, as defined below, before passing it to model.fit():

steps_per_epoch=tf.math.ceil(len(all_image_paths)/BATCH_SIZE).numpy()
steps_per_epoch
115.0
model.fit(keras_ds, epochs=1, steps_per_epoch=3)
W0713 00:53:26.495867 139726670800640 deprecation.py:323] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

3/3 [==============================] - 10s 3s/step - loss: 1.9158 - accuracy: 0.1771

<tensorflow.python.keras.callbacks.History at 0x7f142c0413c8>
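The steps_per_epoch value is ordinary ceiling division, which plain Python reproduces with math.ceil. divmod also shows that, without .repeat(), the last of the 115 batches would hold only 22 images; because this pipeline repeats before batching, batches simply run across epoch boundaries instead. The same arithmetic gives the 231 steps that the timing helper in the next section uses by default:

```python
import math

image_count = 3670  # images found in flower_photos above
batch_size = 32     # BATCH_SIZE

steps_per_epoch = math.ceil(image_count / batch_size)
full_batches, last_batch = divmod(image_count, batch_size)
default_timeit_steps = 2 * steps_per_epoch + 1

print(steps_per_epoch)           # 115
print(full_batches, last_batch)  # 114 22
print(default_timeit_steps)      # 231
```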

Performance

The simple pipeline used above reads each file individually, on each epoch. This is fine for local training on CPU, but may not be sufficient for GPU training and is totally inappropriate for any sort of distributed training.

To investigate, first build a simple function to check the performance of your datasets:

import time
default_timeit_steps = 2*steps_per_epoch+1

def timeit(ds, steps=default_timeit_steps):
  overall_start = time.time()
  # Fetch a single batch to prime the pipeline (fill the shuffle buffer),
  # before starting the timer
  it = iter(ds.take(steps+1))
  next(it)

  start = time.time()
  for i,(images,labels) in enumerate(it):
    if i%10 == 0:
      print('.',end='')
  print()
  end = time.time()

  duration = end-start
  print("{} batches: {} s".format(steps, duration))
  print("{:0.5f} Images/s".format(BATCH_SIZE*steps/duration))
  print("Total time: {}s".format(end-overall_start))
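The timeit helper above uses time.time. If you want to reuse the measurement outside TensorFlow, the same warm-up-then-time approach can be sketched over any Python iterable using time.perf_counter, a monotonic, high-resolution clock intended for interval timing. The function name measure_throughput and the fake pipeline are hypothetical:

```python
import itertools
import time

def measure_throughput(batches, steps, batch_size, warmup=1):
    """Time `steps` batches from any iterable after `warmup` untimed batches.

    Returns (seconds, images_per_second).
    """
    it = iter(batches)
    for _ in itertools.islice(it, warmup):
        pass  # prime the pipeline (e.g. fill a shuffle buffer) before timing
    start = time.perf_counter()
    timed = sum(1 for _ in itertools.islice(it, steps))
    duration = time.perf_counter() - start
    if timed < steps:
        raise ValueError(f"iterable ran out after {timed} timed batches")
    rate = batch_size * steps / duration if duration > 0 else float("inf")
    return duration, rate

def fake_batches():
    """Stand-in pipeline that takes about 2 ms to produce each batch."""
    while True:
        time.sleep(0.002)
        yield None

seconds, rate = measure_throughput(fake_batches(), steps=20, batch_size=32)
print(f"{seconds:.3f} s, {rate:.1f} images/s")
```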

The performance of the current dataset is:

ds = image_label_ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 15.309242010116577 s
482.84559 Images/s
Total time: 22.376640558242798s

Cache

Use tf.data.Dataset.cache to easily cache calculations across epochs. This is very efficient, especially when the data fits in memory.

Here the images are cached after being preprocessed (decoded and resized):

ds = image_label_ds.cache()
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 0.7133467197418213 s
10362.42236 Images/s
Total time: 8.366860389709473s

One disadvantage of using an in-memory cache is that the cache must be rebuilt on each run, giving the same startup delay each time the dataset is started:

timeit(ds)
........................
231.0 batches: 0.8372540473937988 s
8828.86147 Images/s
Total time: 8.99988865852356s
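Both effects, fast later epochs and a cache that starts empty on every new run, follow from keeping results in memory. A rough pure-Python analogy of Dataset.cache() without a filename (the class InMemoryCache is hypothetical, not tf.data's implementation): the first complete pass computes and stores each result, later passes replay the stored list, and a new object starts from scratch.

```python
class InMemoryCache:
    """Replay an expensive per-element map from memory after the first full pass."""

    def __init__(self, source, fn):
        self._source = source
        self._fn = fn
        self._cache = None

    def __iter__(self):
        if self._cache is not None:
            yield from self._cache  # later epochs: no recomputation
            return
        results = []
        for item in self._source:
            result = self._fn(item)
            results.append(result)
            yield result
        self._cache = results  # committed only after a complete pass

calls = []

def expensive_preprocess(x):
    calls.append(x)  # record every real computation
    return x * x

cached = InMemoryCache(range(5), expensive_preprocess)
epoch1 = list(cached)
epoch2 = list(cached)
print(epoch1 == epoch2, len(calls))  # True 5
```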

If the data doesn't fit in memory, use a cache file:

ds = image_label_ds.cache(filename='./cache.tf-data')
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(1)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int32)>
timeit(ds)
........................
231.0 batches: 3.6042563915252686 s
2050.90848 Images/s
Total time: 13.559855222702026s

The cache file also has the advantage that it can be used to quickly restart the dataset without rebuilding the cache. Note how much faster it is the second time:

timeit(ds)
........................
231.0 batches: 3.0372560024261475 s
2433.77575 Images/s
Total time: 4.762259244918823s
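The file cache survives between runs because the results live on disk rather than in the process. A rough pure-Python analogy of Dataset.cache(filename=...) follows; the helper cached_to_file is hypothetical, and pickle is used only for brevity (it is not the format tf.data writes, and you should only load pickle files you trust). Results are streamed one record at a time, so the data never has to fit in memory, and the file is published only after a complete pass.

```python
import os
import pickle
import tempfile

def cached_to_file(source, fn, filename):
    """Yield fn(item) for each item, reusing `filename` if a complete cache exists."""
    if os.path.exists(filename):
        with open(filename, "rb") as f:
            while True:
                try:
                    result = pickle.load(f)
                except EOFError:
                    return
                yield result
    else:
        tmp = filename + ".tmp"
        with open(tmp, "wb") as f:
            for item in source:
                result = fn(item)
                pickle.dump(result, f)
                yield result
        os.replace(tmp, filename)  # publish the cache only once it is complete

calls = []

def expensive_preprocess(x):
    calls.append(x)  # record every real computation
    return x * x

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "cache.pkl")
    first = list(cached_to_file(range(5), expensive_preprocess, path))
    second = list(cached_to_file(range(5), expensive_preprocess, path))

print(first == second, len(calls))  # True 5
```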

TFRecord File

Raw image data

TFRecord files are a simple format to store a sequence of binary blobs. By packing multiple examples into the same file, TensorFlow is able to read multiple examples at once, which is especially important for performance when using a remote storage service such as GCS.
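As far as I know, each record in a TFRecord file is framed as an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the data bytes, and a 4-byte masked CRC-32C of the data. The pure-Python sketch below writes and reads that framing for illustration only; treat the format details as an assumption, and use tf.data.experimental.TFRecordWriter and tf.data.TFRecordDataset (as in the steps that follow) for real files. The helper names are hypothetical.

```python
import io
import struct

def crc32c(data):
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 & -(crc & 1))
    return crc ^ 0xFFFFFFFF

def masked_crc(data):
    """Rotate the CRC right by 15 bits and add a constant (assumed TFRecord masking)."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def write_records(f, records):
    for data in records:
        length = struct.pack("<Q", len(data))
        f.write(length)
        f.write(struct.pack("<I", masked_crc(length)))
        f.write(data)
        f.write(struct.pack("<I", masked_crc(data)))

def read_records(f):
    while True:
        header = f.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", f.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        data = f.read(length)
        (data_crc,) = struct.unpack("<I", f.read(4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupt record data")
        yield data

buf = io.BytesIO()
write_records(buf, [b"jpeg bytes 1", b"", b"jpeg bytes 2"])
buf.seek(0)
print(list(read_records(buf)))
```

Packing many small records into one sequential stream like this is what lets a reader fetch many examples per request instead of opening one file per image.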

First, build a TFRecord file from the raw image data:

image_ds = tf.data.Dataset.from_tensor_slices(all_image_paths).map(tf.io.read_file)
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(image_ds)

Next, build a dataset that reads from the TFRecord file and decodes/reformats the images using the preprocess_image function you defined earlier:

image_ds = tf.data.TFRecordDataset('images.tfrec').map(preprocess_image)

Zip that dataset with the labels dataset you defined earlier to get the expected (image,label) pairs:

ds = tf.data.Dataset.zip((image_ds, label_ds))
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int64)>
timeit(ds)
........................
231.0 batches: 15.116378784179688 s
489.00601 Images/s
Total time: 22.38015913963318s

This is slower than the cache version because you have not cached the preprocessing.

Serialized Tensors

To save some preprocessing to the TFRecord file, first make a dataset of the processed images, as before:

paths_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = paths_ds.map(load_and_preprocess_image)
image_ds
<MapDataset shapes: (192, 192, 3), types: tf.float32>

Now, instead of a dataset of .jpeg strings, you have a dataset of tensors.

To serialize this to a TFRecord file you first convert the dataset of tensors to a dataset of strings:

ds = image_ds.map(tf.io.serialize_tensor)
ds
<MapDataset shapes: (), types: tf.string>
tfrec = tf.data.experimental.TFRecordWriter('images.tfrec')
tfrec.write(ds)

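With the preprocessing cached, data can be loaded from the TFRecord file quite efficiently. Just remember to deserialize each tensor before using it. The tf.reshape call restores the static shape (192, 192, 3), which tf.io.parse_tensor cannot infer on its own: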

ds = tf.data.TFRecordDataset('images.tfrec')

def parse(x):
  result = tf.io.parse_tensor(x, out_type=tf.float32)
  result = tf.reshape(result, [192, 192, 3])
  return result

ds = ds.map(parse, num_parallel_calls=AUTOTUNE)
ds
<ParallelMapDataset shapes: (192, 192, 3), types: tf.float32>

Now, add the labels and apply the same standard operations, as before:

ds = tf.data.Dataset.zip((ds, label_ds))
ds = ds.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
ds = ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
ds
<PrefetchDataset shapes: ((None, 192, 192, 3), (None,)), types: (tf.float32, tf.int64)>
timeit(ds)
........................
231.0 batches: 2.657438039779663 s
2781.62647 Images/s
Total time: 3.732140064239502s