
Load images


This tutorial provides a simple example of how to load an image dataset using tf.data.

The dataset used in this example is distributed as directories of images, with one class of image per directory.
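
For example, a dataset laid out this way looks roughly like the tree below (the image file names are just placeholders):

flower_photos/
  daisy/
    example_1.jpg
    example_2.jpg
    ...
  dandelion/
  roses/
  sunflowers/
  tulips/
  LICENSE.txt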

Setup

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE
import IPython.display as display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
tf.__version__
'2.1.0-dev20191007'

Retrieve the images

Before you start any training, you will need a set of images to teach the network about the new classes you want to recognize. You can use an archive of Creative Commons-licensed flower photos from Google.

import pathlib
data_dir = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)
data_dir = pathlib.Path(data_dir)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 5s 0us/step

After downloading (218 MB), you should now have a copy of the flower photos available.

The directory contains 5 sub-directories, one per class:

image_count = len(list(data_dir.glob('*/*.jpg')))
image_count
3670
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
CLASS_NAMES
array(['daisy', 'roses', 'dandelion', 'sunflowers', 'tulips'],
      dtype='<U10')

Each directory contains images of that type of flower. Here are some roses:

roses = list(data_dir.glob('roses/*'))

for image_path in roses[:3]:
    display.display(Image.open(str(image_path)))


Load using keras.preprocessing

A simple way to load images is to use tf.keras.preprocessing.

# The 1./255 is to convert from uint8 to float32 in range [0,1].
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
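
ImageDataGenerator can also apply random augmentations on the fly. As an aside, here is a sketch with a few common options (the values shown are illustrative, not tuned):

# Illustrative only: rescaling plus a few random augmentations.
augmented_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,     # random rotations of up to 20 degrees
    zoom_range=0.1,        # random zooms of up to 10%
    horizontal_flip=True)  # random left-right flips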

Define some parameters for the loader:

BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
STEPS_PER_EPOCH = np.ceil(image_count/BATCH_SIZE)
train_data_gen = image_generator.flow_from_directory(directory=str(data_dir),
                                                     batch_size=BATCH_SIZE,
                                                     shuffle=True,
                                                     target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                     classes = list(CLASS_NAMES))
Found 3670 images belonging to 5 classes.

Inspect a batch (note that flow_from_directory yields one-hot encoded label vectors by default):

def show_batch(image_batch, label_batch):
  plt.figure(figsize=(10,10))
  for n in range(25):
      ax = plt.subplot(5,5,n+1)
      plt.imshow(image_batch[n])
      plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
      plt.axis('off')
image_batch, label_batch = next(train_data_gen)
show_batch(image_batch, label_batch)


Load using tf.data

The above keras.preprocessing method is convenient, but has three downsides:

  1. It's slow. See the performance section below.
  2. It lacks fine-grained control.
  3. It is not well integrated with the rest of TensorFlow.

To load the files as a tf.data.Dataset, first create a dataset of the file paths:

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/roses/3494252600_29f26e3ff0_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/3512879565_88dd8fc269_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/16155980245_6ab8d7b888.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/7465850028_cdfaae235a_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/14943194730_f48b4d4547_n.jpg'
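
Note that tf.data.Dataset.list_files shuffles the file names by default, which is why the paths above come out in no particular order. If you need a deterministic order (for example, to build a reproducible train/validation split), disable it:

# Keep the file order deterministic; shuffle later in the pipeline instead.
ordered_list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)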

Write a short pure-TensorFlow function that converts a file path to an (image_data, label) pair:

def get_label(file_path):
  # convert the path to a list of path components
  parts = tf.strings.split(file_path, '/')
  # The second to last is the class-directory
  return parts[-2] == CLASS_NAMES
def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size.
  return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])
def process_path(file_path):
  label = get_label(file_path)
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

Use Dataset.map to create a dataset of (image, label) pairs:

# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
for image, label in labeled_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())
Image shape:  (224, 224, 3)
Label:  [False False False  True False]
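
The label here is a boolean one-hot vector over CLASS_NAMES. If you need an integer class index instead (for example, for a sparse categorical loss), a minimal conversion sketch:

# Convert the boolean one-hot vector into an integer class index.
index_ds = labeled_ds.map(
    lambda image, label: (image, tf.argmax(tf.cast(label, tf.int32))))

for image, label in index_ds.take(1):
  print("Label index: ", label.numpy())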

Basic methods for training

To train a model with this dataset you will want the data:

  • To be well shuffled.
  • To be batched.
  • Batches to be available as soon as possible.

These features can be easily added using the tf.data API.

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
  # This is a small dataset, only load it once, and keep it in memory.
  # use `.cache(filename)` to cache preprocessing work for datasets that don't
  # fit in memory.
  if cache:
    if isinstance(cache, str):
      ds = ds.cache(cache)
    else:
      ds = ds.cache()

  ds = ds.shuffle(buffer_size=shuffle_buffer_size)

  # Repeat forever
  ds = ds.repeat()

  ds = ds.batch(BATCH_SIZE)

  # `prefetch` lets the dataset fetch batches in the background while the model
  # is training.
  ds = ds.prefetch(buffer_size=AUTOTUNE)

  return ds
train_ds = prepare_for_training(labeled_ds)

image_batch, label_batch = next(iter(train_ds))
show_batch(image_batch.numpy(), label_batch.numpy())
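
The resulting train_ds plugs straight into Keras. Because the dataset repeats forever, pass steps_per_epoch (computed earlier) so Keras knows where each epoch ends. A minimal sketch, with a placeholder model that exists only to show the wiring:

# Placeholder model; a real one would use convolutional layers.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    tf.keras.layers.Dense(len(CLASS_NAMES), activation='softmax')])
model.compile(optimizer='adam', loss='categorical_crossentropy')

# The labels are boolean one-hot vectors, so cast them to floats for the loss.
model.fit(train_ds.map(lambda x, y: (x, tf.cast(y, tf.float32))),
          epochs=1, steps_per_epoch=int(STEPS_PER_EPOCH))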


Performance

To compare the two approaches, first define a function that times how quickly a dataset yields batches:

import time
default_timeit_steps = 1000

def timeit(ds, steps=default_timeit_steps):
  start = time.time()
  it = iter(ds)
  for i in range(steps):
    batch = next(it)
    if i%10 == 0:
      print('.',end='')
  print()
  end = time.time()

  duration = end-start
  print("{} batches: {} s".format(steps, duration))
  print("{:0.5f} Images/s".format(BATCH_SIZE*steps/duration))

Let's compare the speed of the two input pipelines:

# `keras.preprocessing`
timeit(train_data_gen)
....................................................................................................
1000 batches: 93.29608130455017 s
342.99404 Images/s
# `tf.data`
timeit(train_ds)
....................................................................................................
1000 batches: 7.579981565475464 s
4221.64615 Images/s

A large part of the performance gain comes from the use of .cache: everything before the cache runs only once and is replayed from memory, while everything after it (such as shuffling) still runs each epoch.

uncached_ds = prepare_for_training(labeled_ds, cache=False)
timeit(uncached_ds)
....................................................................................................
1000 batches: 25.684147119522095 s
1245.90472 Images/s

If the dataset doesn't fit in memory, use a cache file to maintain some of the advantages:

filecache_ds = prepare_for_training(labeled_ds, cache="./flowers.tfcache")
timeit(filecache_ds)
....................................................................................................
1000 batches: 18.914026498794556 s
1691.86609 Images/s
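
One caveat with a file cache: the cache persists between runs, and tf.data will reuse it even if you later change the preprocessing steps upstream. If you modify the pipeline, delete the old cache files first. A cleanup sketch, assuming the cache files share the prefix passed to .cache (exact shard names can vary by TensorFlow version):

# Remove stale cache files so the pipeline re-executes from scratch.
for p in pathlib.Path('.').glob('flowers.tfcache*'):
  p.unlink()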