TFDS proporciona una colección de conjuntos de datos listos para usar para usar con TensorFlow, Jax y otros marcos de aprendizaje automático.
Maneja la descarga y preparación de los datos de forma determinista y construye untf.data.Dataset
(o np.array
).
![]() | ![]() | ![]() | ![]() |
Instalación
TFDS existe en dos paquetes:
-
pip install tensorflow-datasets
: la versión estable, lanzada cada pocos meses. -
pip install tfds-nightly
: Lanzado todos los días, contiene las últimas versiones de los conjuntos de datos.
Este colab usa tfds-nightly
:
pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
Encuentra conjuntos de datos disponibles
Todos los creadores de conjuntos de datos son subclase de tfds.core.DatasetBuilder
. Para obtener la lista de constructores disponibles, use tfds.list_builders()
o consulte nuestro catálogo .
tfds.list_builders()
['abstract_reasoning', 'accentdb', 'aeslc', 'aflw2k3d', 'ag_news_subset', 'ai2_arc', 'ai2_arc_with_ir', 'amazon_us_reviews', 'anli', 'arc', 'bair_robot_pushing_small', 'bccd', 'beans', 'big_patent', 'bigearthnet', 'billsum', 'binarized_mnist', 'binary_alpha_digits', 'blimp', 'bool_q', 'c4', 'caltech101', 'caltech_birds2010', 'caltech_birds2011', 'cars196', 'cassava', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'cfq', 'cherry_blossoms', 'chexpert', 'cifar10', 'cifar100', 'cifar10_1', 'cifar10_corrupted', 'citrus_leaves', 'cityscapes', 'civil_comments', 'clevr', 'clic', 'clinc_oos', 'cmaterdb', 'cnn_dailymail', 'coco', 'coco_captions', 'coil100', 'colorectal_histology', 'colorectal_histology_large', 'common_voice', 'coqa', 'cos_e', 'cosmos_qa', 'covid19sum', 'crema_d', 'curated_breast_imaging_ddsm', 'cycle_gan', 'dart', 'davis', 'deep_weeds', 'definite_pronoun_resolution', 'dementiabank', 'diabetic_retinopathy_detection', 'div2k', 'dmlab', 'downsampled_imagenet', 'drop', 'dsprites', 'dtd', 'duke_ultrasound', 'e2e_cleaned', 'efron_morris75', 'emnist', 'eraser_multi_rc', 'esnli', 'eurosat', 'fashion_mnist', 'flic', 'flores', 'food101', 'forest_fires', 'fuss', 'gap', 'geirhos_conflict_stimuli', 'genomics_ood', 'german_credit_numeric', 'gigaword', 'glue', 'goemotions', 'gpt3', 'gref', 'groove', 'gtzan', 'gtzan_music_speech', 'hellaswag', 'higgs', 'horses_or_humans', 'howell', 'i_naturalist2017', 'imagenet2012', 'imagenet2012_corrupted', 'imagenet2012_real', 'imagenet2012_subset', 'imagenet_a', 'imagenet_r', 'imagenet_resized', 'imagenet_v2', 'imagenette', 'imagewang', 'imdb_reviews', 'irc_disentanglement', 'iris', 'kitti', 'kmnist', 'lambada', 'lfw', 'librispeech', 'librispeech_lm', 'libritts', 'ljspeech', 'lm1b', 'lost_and_found', 'lsun', 'lvis', 'malaria', 'math_dataset', 'mctaco', 'mlqa', 'mnist', 'mnist_corrupted', 'movie_lens', 'movie_rationales', 'movielens', 'moving_mnist', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'natural_questions', 'natural_questions_open', 'newsroom', 'nsynth', 'nyu_depth_v2', 'omniglot', 'open_images_challenge2019_detection', 'open_images_v4', 'openbookqa', 'opinion_abstracts', 'opinosis', 'opus', 'oxford_flowers102', 'oxford_iiit_pet', 'para_crawl', 'patch_camelyon', 'paws_wiki', 'paws_x_wiki', 'pet_finder', 'pg19', 'piqa', 'places365_small', 'plant_leaves', 'plant_village', 'plantae_k', 'qa4mre', 'qasc', 'quac', 'quickdraw_bitmap', 'race', 'radon', 'reddit', 'reddit_disentanglement', 'reddit_tifu', 'resisc45', 'robonet', 'rock_paper_scissors', 'rock_you', 's3o4d', 'salient_span_wikipedia', 'samsum', 'savee', 'scan', 'scene_parse150', 'scicite', 'scientific_papers', 'sentiment140', 'shapes3d', 'siscore', 'smallnorb', 'snli', 'so2sat', 'speech_commands', 'spoken_digit', 'squad', 'stanford_dogs', 'stanford_online_products', 'star_cfq', 'starcraft_video', 'stl10', 'story_cloze', 'sun397', 'super_glue', 'svhn_cropped', 'tao', 'ted_hrlr_translate', 'ted_multi_translate', 'tedlium', 'tf_flowers', 'the300w_lp', 'tiny_shakespeare', 'titanic', 'trec', 'trivia_qa', 'tydi_qa', 'uc_merced', 'ucf101', 'vctk', 'vgg_face2', 'visual_domain_decathlon', 'voc', 'voxceleb', 'voxforge', 'waymo_open_dataset', 'web_nlg', 'web_questions', 'wider_face', 'wiki40b', 'wiki_bio', 'wiki_table_questions', 'wiki_table_text', 'wikihow', 'wikipedia', 'wikipedia_toxicity_subtypes', 'wine_quality', 'winogrande', 'wmt13_translate', 'wmt14_translate', 'wmt15_translate', 'wmt16_translate', 'wmt17_translate', 'wmt18_translate', 'wmt19_translate', 'wmt_t2t_translate', 'wmt_translate', 'wordnet', 'wsc273', 'xnli', 'xquad', 'xsum', 'xtreme_pawsx', 'xtreme_xnli', 'yelp_polarity_reviews', 'yes_no', 'youtube_vis']
Cargar un conjunto de datos
tfds.load
La forma más sencilla de cargar un conjunto de datos es tfds.load
. Va a:
- Descargue los datos y guárdelos como archivos
tfrecord
. - Cargue el
tfrecord
y cree eltf.data.Dataset
.
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
Algunos argumentos comunes:
-
split=
: Qué división leer (por ejemplo,'train'
,['train', 'test']
,'train[80%:]'
, ...) Consulte nuestra guía de API dividida . -
shuffle_files=
: Controla si se mezclan los archivos entre cada época (TFDS almacena grandes conjuntos de datos en varios archivos más pequeños). -
data_dir=
: Ubicación donde se guarda el conjunto de datos (por defecto es~/tensorflow_datasets/
) -
with_info=True
: Devuelve eltfds.core.DatasetInfo
contiene los metadatos del conjunto de datos -
download=False
: deshabilitar la descarga
tfds.builder
tfds.load
es un contenedor delgado alrededor de tfds.core.DatasetBuilder
. Puede obtener el mismo resultado utilizando la API tfds.core.DatasetBuilder
:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
tfds build
CLI
Si desea generar un conjunto de datos específico, puede usar la línea de comando tfds
. Por ejemplo:
tfds build mnist
Consulte el documento para conocer las banderas disponibles.
Iterar sobre un conjunto de datos
Como dict
De forma predeterminada, el objetotf.data.Dataset
contiene un dict
de tf.Tensor
s:
ds = tfds.load('mnist', split='train')
ds = ds.take(1) # Only take a single example
for example in ds: # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
print(list(example.keys()))
image = example["image"]
label = example["label"]
print(image.shape, label)
['image', 'label'] (28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
Para conocer la estructura y los nombres de las claves de dict
, consulte la documentación del conjunto de datos en nuestro catálogo . Por ejemplo: documentación mnist .
Como tupla ( as_supervised=True
)
Al usar as_supervised=True
, puede obtener una tupla (features, label)
lugar de conjuntos de datos supervisados.
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in ds: # example is (image, label)
print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
Como numpy ( tfds.as_numpy
)
Utiliza tfds.as_numpy
para convertir:
-
tf.Tensor
->np.array
tf.data.Dataset
->tf.data.Dataset
Iterator[Tree[np.array]]
(ElTree
puede ser unDict
,Tuple
anidado arbitrario)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in tfds.as_numpy(ds):
print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4
Como tf.Tensor por lotes ( batch_size=-1
)
Al usar batch_size=-1
, puede cargar el conjunto de datos completo en un solo lote.
Esto se puede combinar con as_supervised=True
y tfds.as_numpy
para obtener los datos como (np.array, np.array)
:
image, label = tfds.as_numpy(tfds.load(
'mnist',
split='test',
batch_size=-1,
as_supervised=True,
))
print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)
Tenga cuidado de que su conjunto de datos pueda caber en la memoria y de que todos los ejemplos tengan la misma forma.
Compare sus conjuntos de datos
Evaluación comparativa de un conjunto de datos es un simple tfds.benchmark
llamada en cualquier iterable (por ejemplotf.data.Dataset
, tfds.as_numpy
, ...).
ds = tfds.load('mnist', split='train')
ds = ds.batch(32).prefetch(1)
tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32) # Second epoch much faster due to auto-caching
************ Summary ************ Examples/sec (First included) 47889.92 ex/sec (total: 60000 ex, 1.25 sec) Examples/sec (First only) 110.24 ex/sec (total: 32 ex, 0.29 sec) Examples/sec (First excluded) 62298.08 ex/sec (total: 59968 ex, 0.96 sec) ************ Summary ************ Examples/sec (First included) 290380.50 ex/sec (total: 60000 ex, 0.21 sec) Examples/sec (First only) 2506.57 ex/sec (total: 32 ex, 0.01 sec) Examples/sec (First excluded) 309338.21 ex/sec (total: 59968 ex, 0.19 sec)
- No olvide normalizar los resultados por tamaño de lote con
batch_size=
kwarg. - En resumen, el primer lote de preparación se separa de los demás para capturar el tiempo de configuración adicional de
tf.data.Dataset
(por ejemplo, inicialización de búferes, ...). - Observe cómo la segunda iteración es mucho más rápida debido al almacenamiento en caché automático de TFDS .
-
tfds.benchmark
devuelve untfds.core.BenchmarkResult
que se puede inspeccionar para un análisis más detallado.
Construya una canalización de un extremo a otro
Para ir más lejos, puedes mirar:
- Nuestro ejemplo de Keras de extremo a extremo para ver una canalización de capacitación completa (con lotes, barajado, ...).
- Nuestra guía de rendimiento para mejorar la velocidad de sus canalizaciones (consejo: use
tfds.benchmark(ds)
para comparar sus conjuntos de datos).
Visualización
tfds.as_dataframe
tf.data.Dataset
objetostf.data.Dataset
se pueden convertir a pandas.DataFrame
con tfds.as_dataframe
para visualizarlos en Colab .
- Agregue
tfds.core.DatasetInfo
como segundo argumento detfds.as_dataframe
para visualizar imágenes, audio, textos, videos, ... - Utilice
ds.take(x)
para mostrar solo los primerosx
ejemplos.pandas.DataFrame
cargará el conjunto de datos completo en la memoria y puede ser muy costoso de mostrar.
ds, info = tfds.load('mnist', split='train', with_info=True)
tfds.as_dataframe(ds.take(4), info)
tfds.show_examples
tfds.show_examples
devuelve un matplotlib.figure.Figure
(ahora solo se admiten conjuntos de datos de imágenes):
ds, info = tfds.load('mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, info)
Acceder a los metadatos del conjunto de datos
Todos los constructores incluyen un objeto tfds.core.DatasetInfo
que contiene los metadatos del conjunto de datos.
Se puede acceder a través de:
- La API
tfds.load
:
ds, info = tfds.load('mnist', with_info=True)
- La API
tfds.core.DatasetBuilder
:
builder = tfds.builder('mnist')
info = builder.info
La información del conjunto de datos contiene información adicional sobre el conjunto de datos (versión, cita, página de inicio, descripción, ...).
print(info)
tfds.core.DatasetInfo( name='mnist', full_name='mnist/3.0.1', description=""" The MNIST database of handwritten digits. """, homepage='http://yann.lecun.com/exdb/mnist/', data_path='gs://tensorflow-datasets/datasets/mnist/3.0.1', download_size=11.06 MiB, dataset_size=21.00 MiB, features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), supervised_keys=('image', 'label'), splits={ 'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>, }, citation="""@article{lecun2010mnist, title={MNIST handwritten digit database}, author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist}, volume={2}, year={2010} }""", )
Incluye metadatos (nombres de etiquetas, forma de la imagen, ...)
Acceda a tfds.features.FeatureDict
:
info.features
FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), })
Número de clases, nombres de etiquetas:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7)) # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 7 7
Formas, dtipos:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()} {'image': tf.uint8, 'label': tf.int64} (28, 28, 1) <dtype: 'uint8'>
Metadatos divididos (por ejemplo, nombres divididos, número de ejemplos, ...)
Acceda al tfds.core.SplitDict
:
print(info.splits)
{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}
Divisiones disponibles:
print(list(info.splits.keys()))
['test', 'train']
Obtenga información sobre la división individual:
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000 ['mnist-train.tfrecord-00000-of-00001'] 1
También funciona con la API subsplit:
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000 [FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]
Solución de problemas
Descarga manual (si falla la descarga)
Si la descarga falla por algún motivo (por ejemplo, sin conexión, ...). Siempre puede descargar manualmente los datos usted mismo y colocarlos en manual_dir
(el valor predeterminado es ~/tensorflow_datasets/download/manual/
.
Para saber qué URL descargar, busque en:
Para nuevos conjuntos de datos (implementados como carpeta):
tensorflow_datasets/
<type>/<dataset_name>/checksums.tsv
. Por ejemplo:tensorflow_datasets/text/bool_q/checksums.tsv
.Puede encontrar la ubicación de la fuente del conjunto de datos en nuestro catálogo .
Para conjuntos de datos antiguos:
tensorflow_datasets/url_checksums/<dataset_name>.txt
Arreglando NonMatchingChecksumError
TFDS garantiza el determinismo validando las sumas de comprobación de las URL descargadas. Si se NonMatchingChecksumError
, podría indicar:
- El sitio web puede estar inactivo (por ejemplo
503 status code
). Comprueba la URL. - Para las URL de Google Drive, inténtelo de nuevo más tarde, ya que Drive a veces rechaza las descargas cuando demasiadas personas acceden a la misma URL. Ver error
- Es posible que se hayan actualizado los archivos de conjuntos de datos originales. En este caso, se debe actualizar el generador de conjuntos de datos TFDS. Abra un nuevo problema de Github o PR:
- Registre las nuevas sumas de comprobación con
tfds build --register_checksums
- Actualice eventualmente el código de generación del conjunto de datos.
- Actualizar la
VERSION
conjunto de datos - Actualice el conjunto de datos
RELEASE_NOTES
: ¿Qué provocó que cambiaran las sumas de comprobación? ¿Han cambiado algunos ejemplos? - Asegúrese de que el conjunto de datos aún se pueda construir.
- Envíanos un PR
- Registre las nuevas sumas de comprobación con
Citación
Si está utilizando tensorflow-datasets
para un artículo, incluya la siguiente cita, además de cualquier cita específica de los conjuntos de datos utilizados (que se puede encontrar en el catálogo de conjuntos de datos ).
@misc{TFDS,
title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
howpublished = {\url{https://www.tensorflow.org/datasets} },
}