
Load NumPy Data with tf.data


This tutorial provides an example of loading data from NumPy arrays into a tf.data.Dataset.

This example loads the MNIST dataset from a .npz file, but the same approach works for NumPy arrays from any source.

Setup

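# __future__ imports must come before any other statement.
from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version is an IPython magic that only exists in Colab;
  # remove this block when running as a plain Python script.
  %tensorflow_version 2.x
except Exception:
  pass
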
import numpy as np
import tensorflow as tf

Load from .npz file

DATA_URL = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'

path = tf.keras.utils.get_file('mnist.npz', DATA_URL)
with np.load(path) as data:
  train_examples = data['x_train']
  train_labels = data['y_train']
  test_examples = data['x_test']
  test_labels = data['y_test']
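To make the shape of a .npz archive concrete, here is a minimal sketch that writes two small arrays to an in-memory archive and reads them back the same way. The array contents and sizes are hypothetical stand-ins, not the real MNIST data; only the key names `x_train` and `y_train` are taken from the code above.

```python
import io

import numpy as np

# Hypothetical stand-in data: 6 "images" of 28x28 pixels with integer labels.
x_train = np.zeros((6, 28, 28), dtype=np.uint8)
y_train = np.arange(6, dtype=np.uint8)

# Write both arrays to an in-memory .npz archive, keyed by keyword name.
buffer = io.BytesIO()
np.savez(buffer, x_train=x_train, y_train=y_train)
buffer.seek(0)

# np.load on a .npz archive returns a dict-like object; indexing by key
# reads that array. The with-block closes the archive afterwards.
with np.load(buffer) as data:
  keys = sorted(data.files)
  loaded_examples = data['x_train']
  loaded_labels = data['y_train']

print(keys)
print(loaded_examples.shape, loaded_examples.dtype)
print(loaded_labels.shape, loaded_labels.dtype)
```

Checking `data.files`, shapes, and dtypes like this before building a dataset is a cheap way to confirm that examples and labels line up along the first axis.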

Load NumPy arrays with tf.data.Dataset

Assuming you have an array of examples and a corresponding array of labels, pass the two arrays as a tuple into tf.data.Dataset.from_tensor_slices to create a tf.data.Dataset.

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))
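As a rough illustration of what from_tensor_slices produces, the sketch below builds a dataset from two tiny arrays and iterates over it in eager mode. The arrays are hypothetical stand-ins; the point is that the tuple is sliced along the first axis, so each element is one (example, label) pair and both arrays need the same length along that axis.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in arrays: 5 examples of shape (28, 28) and 5 labels.
examples = np.zeros((5, 28, 28), dtype=np.uint8)
labels = np.array([3, 1, 4, 1, 5], dtype=np.uint8)

dataset = tf.data.Dataset.from_tensor_slices((examples, labels))

# Iterating in eager mode yields one (example, label) pair per row.
pairs = [(example.numpy().shape, int(label.numpy()))
         for example, label in dataset]

print(len(pairs))
print(pairs[0])
```

Without shuffling, the elements come back in the same order as the rows of the input arrays.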

Use the datasets

Shuffle and batch the datasets

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
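The effect of batch on the number of steps can be checked on a toy dataset. The sketch below uses 10 hypothetical examples and a batch size of 4, so the last batch is smaller than the rest. By the same arithmetic, and assuming the usual 60,000 training and 10,000 test examples in MNIST, a batch size of 64 gives ceil(60000 / 64) = 938 training batches and ceil(10000 / 64) = 157 test batches. Note also that a shuffle buffer smaller than the dataset only mixes elements within a sliding window; a buffer at least as large as the dataset is needed for a full uniform shuffle.

```python
import math

import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 10 examples, labelled 0..9.
examples = np.arange(10, dtype=np.float32).reshape(10, 1)
labels = np.arange(10, dtype=np.int64)

dataset = tf.data.Dataset.from_tensor_slices((examples, labels))
dataset = dataset.shuffle(buffer_size=4).batch(4)

batch_sizes = []
seen_labels = []
for batch_examples, batch_labels in dataset:
  batch_sizes.append(batch_labels.numpy().shape[0])
  seen_labels.extend(batch_labels.numpy().tolist())

print(batch_sizes)                 # sizes of each batch, last one is partial
print(math.ceil(10 / 4))           # expected number of batches
print(sorted(seen_labels))         # every label appears exactly once
```

Shuffling changes the order of elements but not how many there are, so every label still appears exactly once per pass over the dataset.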

Build and train a model

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_dataset, epochs=10)

Epoch 1/10
938/938 [==============================] - 4s 5ms/step - loss: 9.2718 - sparse_categorical_accuracy: 0.4223
Epoch 2/10
938/938 [==============================] - 3s 4ms/step - loss: 8.5060 - sparse_categorical_accuracy: 0.4710
Epoch 3/10
938/938 [==============================] - 3s 4ms/step - loss: 8.4545 - sparse_categorical_accuracy: 0.4743
Epoch 4/10
938/938 [==============================] - 4s 4ms/step - loss: 8.4016 - sparse_categorical_accuracy: 0.4778
Epoch 5/10
938/938 [==============================] - 3s 4ms/step - loss: 7.7754 - sparse_categorical_accuracy: 0.5164
Epoch 6/10
938/938 [==============================] - 3s 4ms/step - loss: 7.0343 - sparse_categorical_accuracy: 0.5622
Epoch 7/10
938/938 [==============================] - 3s 4ms/step - loss: 6.9545 - sparse_categorical_accuracy: 0.5673
Epoch 8/10
938/938 [==============================] - 4s 4ms/step - loss: 6.9153 - sparse_categorical_accuracy: 0.5701
Epoch 9/10
938/938 [==============================] - 3s 4ms/step - loss: 6.8574 - sparse_categorical_accuracy: 0.5737
Epoch 10/10
938/938 [==============================] - 4s 4ms/step - loss: 6.8339 - sparse_categorical_accuracy: 0.5752

<tensorflow.python.keras.callbacks.History at 0x7f3175bcfc50>
model.evaluate(test_dataset)
157/157 [==============================] - 1s 3ms/step - loss: 6.8574 - sparse_categorical_accuracy: 0.5742

[6.857351382067249, 0.5742]
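The run above ends at roughly 0.57 accuracy, which is low for MNIST. One possible contributor, offered here as an assumption rather than a diagnosis of this run, is that the pixel values are fed to the model as raw integers in the range 0 to 255. A common preprocessing step is to scale inputs to the range 0 to 1 with Dataset.map before batching. The sketch below shows that step on hypothetical stand-in arrays; the helper name `scale` is made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for MNIST: 4 uint8 "images" with pixel value 255.
examples = np.full((4, 28, 28), 255, dtype=np.uint8)
labels = np.array([0, 1, 2, 3], dtype=np.uint8)

def scale(image, label):
  # Cast to float32 and rescale pixel values from [0, 255] to [0, 1].
  image = tf.cast(image, tf.float32) / 255.0
  return image, label

dataset = tf.data.Dataset.from_tensor_slices((examples, labels))
dataset = dataset.map(scale).batch(2)

first_images, first_labels = next(iter(dataset))
print(first_images.dtype)
print(float(tf.reduce_max(first_images)))
```

If this step is used, the same map call would need to be applied to both the training and the test dataset so that the model sees inputs on the same scale in fit and evaluate.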