
Load NumPy data


This tutorial provides an example of loading data from NumPy arrays into a tf.data.Dataset.

This example loads the MNIST dataset from a .npz file. However, the source of the NumPy arrays is not important: the same approach works for any arrays you already have in memory.
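As a minimal, self-contained sketch (the array names and values here are illustrative toys, not part of the MNIST example below), any pair of in-memory NumPy arrays can be wrapped the same way:

import numpy as np
import tensorflow as tf

# Toy data: 8 examples with 4 features each, plus one integer label per example.
features = np.random.rand(8, 4).astype(np.float32)
labels = np.random.randint(0, 2, size=(8,))

toy_dataset = tf.data.Dataset.from_tensor_slices((features, labels))
for feature_row, label in toy_dataset.take(2):
  print(feature_row.numpy(), label.numpy())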

Setup

import numpy as np
import tensorflow as tf
2022-12-14 03:34:37.948644: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 03:34:37.948743: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 03:34:37.948753: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Load from .npz file

DATA_URL = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'

path = tf.keras.utils.get_file('mnist.npz', DATA_URL)
with np.load(path) as data:
  train_examples = data['x_train']
  train_labels = data['y_train']
  test_examples = data['x_test']
  test_labels = data['y_test']
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
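Before building a dataset it can help to confirm what was loaded. This optional sanity check is not part of the original notebook; the shapes shown in the comments are what the MNIST .npz file is expected to contain (60,000 training and 10,000 test images of 28x28 pixels).

# Optional sanity check on the loaded arrays.
print(train_examples.shape, train_examples.dtype)  # expected: (60000, 28, 28) uint8
print(train_labels.shape, train_labels.dtype)      # expected: (60000,) uint8
print(test_examples.shape, test_labels.shape)      # expected: (10000, 28, 28) (10000,)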

Load NumPy arrays with tf.data.Dataset

Assuming you have an array of examples and a corresponding array of labels, pass the two arrays as a tuple into tf.data.Dataset.from_tensor_slices to create a tf.data.Dataset.

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))
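To see what each dataset element looks like, you can inspect element_spec. Roughly, each element should be an (image, label) pair with a 28x28 image tensor and a scalar label:

# Each dataset element is an (image, label) pair.
print(train_dataset.element_spec)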

Use the datasets

Shuffle and batch the datasets

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
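Note that shuffle draws samples from a buffer of SHUFFLE_BUFFER_SIZE elements, so a buffer of 100 only shuffles locally; for a perfect shuffle the buffer would need to be at least as large as the dataset. As a quick optional check (not in the original notebook), you can pull one batch and look at its shape:

# Pull a single batch to verify the batched shapes.
for images, labels in train_dataset.take(1):
  print(images.shape, labels.shape)  # expected: (64, 28, 28) (64,)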

Build and train a model

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),  # flatten each 28x28 image into a 784-element vector
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)  # one logit per digit class
])

model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['sparse_categorical_accuracy'])
model.fit(train_dataset, epochs=10)
Epoch 1/10
938/938 [==============================] - 3s 2ms/step - loss: 2.9153 - sparse_categorical_accuracy: 0.8698
Epoch 2/10
938/938 [==============================] - 2s 2ms/step - loss: 0.4571 - sparse_categorical_accuracy: 0.9285
Epoch 3/10
938/938 [==============================] - 2s 2ms/step - loss: 0.3365 - sparse_categorical_accuracy: 0.9463
Epoch 4/10
938/938 [==============================] - 2s 2ms/step - loss: 0.2928 - sparse_categorical_accuracy: 0.9539
Epoch 5/10
938/938 [==============================] - 2s 2ms/step - loss: 0.2468 - sparse_categorical_accuracy: 0.9608
Epoch 6/10
938/938 [==============================] - 2s 2ms/step - loss: 0.2194 - sparse_categorical_accuracy: 0.9655
Epoch 7/10
938/938 [==============================] - 2s 2ms/step - loss: 0.2102 - sparse_categorical_accuracy: 0.9681
Epoch 8/10
938/938 [==============================] - 2s 2ms/step - loss: 0.1931 - sparse_categorical_accuracy: 0.9712
Epoch 9/10
938/938 [==============================] - 2s 2ms/step - loss: 0.1778 - sparse_categorical_accuracy: 0.9730
Epoch 10/10
938/938 [==============================] - 2s 2ms/step - loss: 0.1702 - sparse_categorical_accuracy: 0.9756
<keras.callbacks.History at 0x7f0e980debb0>
model.evaluate(test_dataset)
157/157 [==============================] - 0s 2ms/step - loss: 0.5646 - sparse_categorical_accuracy: 0.9550
[0.5645962357521057, 0.9549999833106995]
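As an optional follow-up sketch (not part of the original notebook), the trained model can be used for inference. model.predict returns logits here, so take the argmax over the last axis to get class predictions; since test_dataset was not shuffled, the first batch lines up with the first rows of test_labels.

# Predict on the first test batch and compare against the true labels.
logits = model.predict(test_dataset.take(1))
predicted = tf.argmax(logits, axis=-1).numpy()
print(predicted[:10])
print(test_labels[:10])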