This tutorial shows how to load and preprocess an image dataset in three ways:
- First, you will use high-level Keras preprocessing utilities (such as tf.keras.utils.image_dataset_from_directory) and layers (such as tf.keras.layers.Rescaling) to read a directory of images on disk.
- Next, you will write your own input pipeline from scratch using tf.data.
- Finally, you will download a dataset from the large catalog available in TensorFlow Datasets.
```python
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
```
Download the flowers dataset
This tutorial uses a dataset of several thousand photos of flowers. The flowers dataset contains five sub-directories, one per class:
```
flower_photos/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/
```
```python
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
archive = tf.keras.utils.get_file(origin=dataset_url, extract=True)
data_dir = pathlib.Path(archive).with_suffix('')
```
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228813984/228813984 [==============================] - 1s 0us/step
After downloading (218MB), you should now have a copy of the flower photos available. There are 3,670 total images:
```python
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
```
Each directory contains images of that type of flower. Here are some roses:
```python
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))
```

```python
PIL.Image.open(str(roses[1]))
```
Load data using a Keras utility
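Let's load these images off disk using the helpful tf.keras.utils.image_dataset_from_directory utility. It turns a directory of images on disk into a tf.data.Dataset in just a couple of lines of code.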
Create a dataset
Define some parameters for the loader:
```python
batch_size = 32
img_height = 180
img_width = 180
```
It's good practice to use a validation split when developing your model. You will use 80% of the images for training and 20% for validation.
```python
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
```
Found 3670 files belonging to 5 classes. Using 2936 files for training.
```python
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
```
Found 3670 files belonging to 5 classes. Using 734 files for validation.
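Both calls pass the same seed=123, and that matters: roughly speaking, each call shuffles the full file list with that seed and then carves off its own subset, so only matching seeds guarantee that the training and validation subsets don't overlap (recent Keras versions raise an error if you combine validation_split with shuffling but pass no seed). The pure-Python sketch below illustrates this shuffle-then-split idea with hypothetical file names; split_files is a hypothetical helper, not Keras's actual implementation.

```python
import random


def split_files(files, validation_split, subset, seed):
    """Shuffle `files` with `seed`, then return the requested subset.

    Hypothetical helper illustrating shuffle-then-split; not Keras's code.
    """
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # same seed -> same permutation
    num_val = int(validation_split * len(shuffled))
    cut = len(shuffled) - num_val
    if subset == "training":
        return shuffled[:cut]
    if subset == "validation":
        return shuffled[cut:]
    raise ValueError(f"unknown subset: {subset!r}")


# Hypothetical stand-ins for the 3,670 flower image paths.
files = [f"img_{i:04d}.jpg" for i in range(3670)]

train = split_files(files, 0.2, "training", seed=123)
val = split_files(files, 0.2, "validation", seed=123)
print(len(train), len(val))        # 2936 734
print(len(set(train) & set(val)))  # 0: matching seeds give disjoint subsets

# With a different seed, the "validation" subset comes from another permutation.
val_other_seed = split_files(files, 0.2, "validation", seed=7)
print(len(set(train) & set(val_other_seed)) > 0)  # True (almost surely): overlap
```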
You can find the class names in the class_names attribute on these datasets.
```python
class_names = train_ds.class_names
print(class_names)
```
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
Visualize the data
Here are the first nine images from the training dataset.
```python
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
```
You can train a model using these datasets by passing them to model.fit (shown later in this tutorial). If you like, you can also manually iterate over the dataset and retrieve batches of images:
```python
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
```
(32, 180, 180, 3) (32,)
image_batch is a tensor of shape (32, 180, 180, 3). This is a batch of 32 images of shape 180x180x3 (the last dimension holds the RGB color channels). labels_batch is a tensor of shape (32,) holding the corresponding labels for the 32 images.
You can call .numpy() on either of these tensors to convert them to a numpy.ndarray.
Standardize the data
The RGB channel values are in the [0, 255] range. This is not ideal for a neural network; in general you should seek to make your input values small.
Here, you will standardize values to be in the [0, 1] range by using tf.keras.layers.Rescaling:
normalization_layer = tf.keras.layers.Rescaling(1./255)
There are two ways to use this layer. You can apply it to the dataset by calling Dataset.map:
```python
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixel values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))
```
Or, you can include the layer inside your model definition to simplify deployment. You will use the second approach here.
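Whichever approach you choose, the arithmetic is the same: tf.keras.layers.Rescaling(scale, offset=0.0) multiplies each input by scale and then adds offset. The NumPy sketch below reproduces that arithmetic outside TensorFlow; rescale is a hypothetical helper name used only for illustration.

```python
import numpy as np


def rescale(x, scale, offset=0.0):
    """Hypothetical helper mirroring Rescaling's arithmetic: x * scale + offset."""
    # Cast to float first so uint8 pixels are not multiplied as integers.
    return np.asarray(x, dtype=np.float32) * scale + offset


pixels = np.array([0, 64, 128, 255], dtype=np.uint8)

# Rescaling(1./255): maps [0, 255] to [0, 1]
print(rescale(pixels, 1.0 / 255))

# Rescaling(1./127.5, offset=-1): maps [0, 255] to [-1, 1]
print(rescale(pixels, 1.0 / 127.5, offset=-1.0))
```

The second variant is useful when a model expects inputs centered on zero rather than in [0, 1].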
Configure the dataset for performance
Let's make sure to use buffered prefetching so you can yield data from disk without having I/O become blocking. These are two important methods you should use when loading data:
- Dataset.cache keeps the images in memory after they're loaded off disk during the first epoch. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache.
- Dataset.prefetch overlaps data preprocessing and model execution while training.
Interested readers can learn more about both methods, as well as how to cache data to disk in the Prefetching section of the Better performance with the tf.data API guide.
```python
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
```
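To see why prefetching helps, here is a toy, pure-Python analogue of Dataset.prefetch: a background thread fills a bounded queue while the consumer works on the previous item, so simulated I/O and a simulated training step overlap. The prefetch and slow_loader helpers and the sleep timings are hypothetical; tf.data's real implementation is different and far more capable.

```python
import queue
import threading
import time

_END = object()  # sentinel marking the end of the input stream


def prefetch(iterable, buffer_size=2):
    """Toy analogue of Dataset.prefetch: a background thread stays up to
    `buffer_size` items ahead of the consumer (error handling omitted)."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(_END)

    # daemon=True so an abandoned producer never keeps the interpreter alive
    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            return
        yield item


def slow_loader(n, delay=0.01):
    """Hypothetical stand-in for reading and decoding n images from disk."""
    for i in range(n):
        time.sleep(delay)  # simulated I/O latency
        yield i


if __name__ == "__main__":
    for prefetching in (False, True):
        source = slow_loader(20)
        items = prefetch(source, buffer_size=4) if prefetching else source
        start = time.perf_counter()
        for _ in items:
            time.sleep(0.01)  # simulated training step
        print(f"prefetch={prefetching}: {time.perf_counter() - start:.2f}s")
```

On a typical machine the prefetched loop should take roughly half as long, because loading the next item no longer waits for the current step to finish.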
Train a model
For completeness, this section shows how to train a simple model using the datasets you have just prepared.
The Sequential model consists of three convolution blocks (tf.keras.layers.Conv2D) with a max pooling layer (tf.keras.layers.MaxPooling2D) in each of them. On top of them sits a fully-connected layer (tf.keras.layers.Dense) with 128 units, activated by a ReLU activation function ('relu'). The model starts with the Rescaling layer from the previous section, so it standardizes its own inputs. This model has not been tuned in any way; the goal is to show you the mechanics using the datasets you just created. To learn more about image classification, visit the Image classification tutorial.
```python
num_classes = 5

model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1./255),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(num_classes)
])
```
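With Keras's defaults (Conv2D uses padding='valid' and stride 1; MaxPooling2D uses a 2x2 window with stride 2 and 'valid' padding), you can work out by hand how the 180x180 input shrinks through the three blocks. The Rescaling layer does not change the shape. The helper names below are hypothetical; the arithmetic is the standard output-size formula for valid padding.

```python
def conv_out(size, kernel=3, stride=1):
    """Output length of a 'valid' convolution along one spatial axis."""
    return (size - kernel) // stride + 1


def pool_out(size, pool=2):
    """Output length of 'valid' max pooling whose stride equals the pool size."""
    return (size - pool) // pool + 1


size = 180
for block in range(1, 4):
    after_conv = conv_out(size)
    size = pool_out(after_conv)
    print(f"block {block}: conv -> {after_conv}x{after_conv}, pool -> {size}x{size}")

filters = 32
flat = size * size * filters
print("Flatten output length:", flat)
# block 1: conv -> 178x178, pool -> 89x89
# block 2: conv -> 87x87, pool -> 43x43
# block 3: conv -> 41x41, pool -> 20x20
# Flatten output length: 12800
```

So the Dense(128) layer receives 12,800 features per image, which gives it 12,800 × 128 + 128 = 1,638,528 trainable parameters, by far the largest layer in this model.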
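Choose the tf.keras.optimizers.Adam optimizer and the tf.keras.losses.SparseCategoricalCrossentropy loss function. Because the final Dense layer has no activation and outputs raw logits, the loss is created with from_logits=True. To view training and validation accuracy for each training epoch, pass the metrics argument to Model.compile: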
```python
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
```
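To make the loss concrete: sparse categorical cross-entropy takes integer labels and, with from_logits=True, computes the mean of -log(softmax(logits)[label]) over the batch. The NumPy sketch below is a reference illustration of that quantity, not Keras's implementation; sparse_ce_from_logits is a hypothetical name.

```python
import numpy as np


def sparse_ce_from_logits(logits, labels):
    """Mean sparse categorical cross-entropy computed from raw logits.

    Hypothetical NumPy illustration of the quantity the loss reports.
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    # Numerically stable log-softmax: subtract each row's max before exp.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick each example's log-probability for its integer label, then average.
    picked = log_probs[np.arange(len(labels)), labels]
    return float(-picked.mean())


# Equal logits over the 5 flower classes -> uniform prediction -> loss = ln(5).
print(sparse_ce_from_logits(np.zeros((2, 5)), [0, 3]))
print(np.log(5))
```

A model that predicts all five classes with equal probability therefore scores about ln(5) ≈ 1.609, a useful reference point when reading the loss values printed during training.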
```python
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3
)
```
Epoch 1/3
92/92 [==============================] - 18s 180ms/step - loss: 1.3220 - accuracy: 0.4322 - val_loss: 1.0815 - val_accuracy: 0.5777
Epoch 2/3
92/92 [==============================] - 16s 174ms/step - loss: 1.0369 - accuracy: 0.5882 - val_loss: 0.9789 - val_accuracy: 0.6281
Epoch 3/3
92/92 [==============================] - 16s 173ms/step - loss: 0.8626 - accuracy: 0.6580 - val_loss: 0.9311 - val_accuracy: 0.6403
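The 92/92 in the progress bar is the number of batches per epoch. Since the final partial batch is kept rather than dropped, it is a ceiling division of the training-set size by the batch size:

```python
import math

train_images = 2936  # from "Using 2936 files for training."
batch_size = 32

steps_per_epoch = math.ceil(train_images / batch_size)
last_batch = train_images - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)  # 92 24
```

The last batch of each epoch therefore holds 24 images instead of 32.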
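With only three epochs of training, the gap between training and validation accuracy is still small, but as training continues the validation accuracy tends to lag behind the training accuracy, a sign that the model is overfitting. You can learn more about overfitting and how to reduce it in the Overfit and underfit tutorial.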
Using tf.data for finer control
For finer-grained control, you can write your own input pipeline using tf.data. This section shows how to do just that, beginning with the file paths from the TGZ file you downloaded earlier.
```python
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
```
```python
for f in list_ds.take(5):
    print(f.numpy())
```
b'/home/kbuilder/.keras/datasets/flower_photos/roses/11944957684_2cc806276e.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/6908789145_814d448bb1_n.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/200011914_93f57ed68b.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/3386988684_bc5a66005e.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/12562723334_a2e0a9e3c8_n.jpg'
The tree structure of the files can be used to compile a class_names list:
```python
class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
print(class_names)
```
['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
Split the dataset into training and validation sets:
```python
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
```
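Because list_ds was shuffled once with reshuffle_each_iteration=False, take and skip see the same fixed order every time the dataset is iterated, so the two splits stay disjoint across epochs. The pure-Python sketch below illustrates the difference; it uses itertools.islice as a stand-in for Dataset.take and Dataset.skip, and the file names are hypothetical.

```python
import itertools
import random

files = [f"img_{i:04d}.jpg" for i in range(3670)]  # hypothetical file names
val_size = int(len(files) * 0.2)                    # 734


def take(seq, n):
    """Stand-in for Dataset.take: the first n elements."""
    return list(itertools.islice(seq, n))


def skip(seq, n):
    """Stand-in for Dataset.skip: everything after the first n elements."""
    return list(itertools.islice(seq, n, None))


# reshuffle_each_iteration=False: one fixed permutation shared by both splits.
fixed_order = random.Random(0).sample(files, len(files))
val = take(fixed_order, val_size)
train = skip(fixed_order, val_size)
print(len(train), len(val), len(set(train) & set(val)))  # 2936 734 0

# reshuffle_each_iteration=True: every pass sees a new permutation, so the
# validation pass and the training pass draw from different orders.
val_pass = take(random.Random(1).sample(files, len(files)), val_size)
train_pass = skip(random.Random(2).sample(files, len(files)), val_size)
print(len(set(train_pass) & set(val_pass)) > 0)  # True (almost surely): leakage
```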
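You can print the length of each dataset with tf.data.Dataset.cardinality: train_ds.cardinality().numpy() returns 2936 and val_ds.cardinality().numpy() returns 734.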
Write a short function that converts a file path to an (img, label) pair:
```python
def get_label(file_path):
    # Convert the path to a list of path components
    parts = tf.strings.split(file_path, os.path.sep)
    # The second to last is the class-directory
    one_hot = parts[-2] == class_names
    # Integer encode the label
    return tf.argmax(one_hot)
```
```python
def decode_img(img):
    # Convert the compressed string to a 3D uint8 tensor
    img = tf.io.decode_jpeg(img, channels=3)
    # Resize the image to the desired size (tf.image.resize returns float32)
    return tf.image.resize(img, [img_height, img_width])
```
```python
def process_path(file_path):
    label = get_label(file_path)
    # Load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    img = decode_img(img)
    return img, label
```
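The labeling logic in get_label does not depend on TensorFlow: split the path, compare the parent directory name against class_names to get a boolean one-hot vector, and take its argmax. Here is a NumPy sketch of the same steps; the get_label_np name and the sample paths are hypothetical.

```python
import os

import numpy as np

class_names = np.array(['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'])


def get_label_np(file_path):
    """Hypothetical NumPy version of get_label: parent directory -> class index."""
    parts = file_path.split(os.path.sep)
    one_hot = parts[-2] == class_names  # boolean array, True at the matching class
    return int(np.argmax(one_hot))


path = os.path.join("flower_photos", "roses", "12562723334_a2e0a9e3c8_n.jpg")
print(get_label_np(path))  # 2
```

One caveat of the argmax trick: if the directory name matches no class, the one-hot vector is all False and argmax silently returns 0 (daisy) instead of raising an error.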
Use Dataset.map to create a dataset of image, label pairs:
```python
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
```
```python
for image, label in train_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())
```
Image shape: (180, 180, 3) Label: 4
Configure dataset for performance
To train a model with this dataset you will want the data:
- To be well shuffled.
- To be batched.
- Batches to be available as soon as possible.
```python
def configure_for_performance(ds):
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
```
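Dataset.shuffle does not shuffle the whole dataset at once: it fills a buffer of buffer_size elements, randomly samples from that buffer, and replaces each sampled element with the next one from the input. The pure-Python generator below is a sketch of that documented behavior (not tf.data's code). It shows that with a buffer of 1000 over 2,936 training images, each epoch is only a partial reshuffle of the (already shuffled) file order, whereas the earlier list_ds.shuffle(image_count, ...) call, whose buffer covers every file, gives a full shuffle. buffered_shuffle is a hypothetical name.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Sketch of buffer-based shuffling in the style of Dataset.shuffle."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]        # emit a random buffered element...
        buffer[idx] = item       # ...and put the new element in its slot
    rng.shuffle(buffer)          # input exhausted: drain the rest in random order
    yield from buffer


order = list(buffered_shuffle(range(2936), buffer_size=1000, seed=0))
print(sorted(order) == list(range(2936)))  # True: a permutation, nothing lost
print(order[0] < 1000)                     # True: first output is from the first 1000 inputs
```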
Visualize the data
You can visualize this dataset similarly to the one you created previously:
```python
image_batch, label_batch = next(iter(train_ds))

plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].numpy().astype("uint8"))
    label = label_batch[i]
    plt.title(class_names[label])
    plt.axis("off")
```
Continue training the model
You have now manually built a tf.data.Dataset similar to the one created by tf.keras.utils.image_dataset_from_directory above. You can continue training the model with it. As before, you will train for just a few epochs to keep the running time short.
```python
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3
)
```
Epoch 1/3
92/92 [==============================] - 17s 179ms/step - loss: 0.7499 - accuracy: 0.7248 - val_loss: 0.7527 - val_accuracy: 0.7139
Epoch 2/3
92/92 [==============================] - 16s 172ms/step - loss: 0.5483 - accuracy: 0.8045 - val_loss: 0.6978 - val_accuracy: 0.7357
Epoch 3/3
92/92 [==============================] - 16s 171ms/step - loss: 0.3818 - accuracy: 0.8604 - val_loss: 0.9073 - val_accuracy: 0.7003
Using TensorFlow Datasets
As you have previously loaded the Flowers dataset off disk, let's now import it with TensorFlow Datasets.
Download the Flowers dataset using TensorFlow Datasets:
```python
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)
```
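The split strings slice the single train split of tf_flowers by percentage: the first 80% for training, the next 10% for validation, and the last 10% for test. Assuming that split holds the same 3,670 images as the archive used above, every boundary falls on a whole number of examples, so the sizes can be worked out directly. The helper below is hypothetical and only does the arithmetic; TFDS applies its own rounding rule when a boundary is not a whole number.

```python
def percent_split_sizes(num_examples, boundaries):
    """Sizes of consecutive percentage slices, e.g. boundaries [0, 80, 90, 100].

    Hypothetical helper; assumes every boundary maps to a whole number of
    examples, so no rounding rule is needed.
    """
    cut_points = [num_examples * b // 100 for b in boundaries]
    return [hi - lo for lo, hi in zip(cut_points, cut_points[1:])]


print(percent_split_sizes(3670, [0, 80, 90, 100]))  # [2936, 367, 367]
```

The 2,936-image training slice is the same size as the 80% training subsets built earlier in this tutorial.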
The flowers dataset has five classes:
```python
num_classes = metadata.features['label'].num_classes
print(num_classes)
```
Retrieve an image from the dataset:
```python
get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
```
As before, remember to batch, shuffle, and configure the training, validation, and test sets for performance:
```python
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
test_ds = configure_for_performance(test_ds)
```
You can find a complete example of working with the Flowers dataset and TensorFlow Datasets by visiting the Data augmentation tutorial.
This tutorial showed two ways of loading images off disk. First, you learned how to load and preprocess an image dataset using Keras preprocessing layers and utilities. Next, you learned how to write an input pipeline from scratch using tf.data. Finally, you learned how to download a dataset from TensorFlow Datasets.
For your next steps: