TFDS fornece uma coleção de conjuntos de dados prontos para uso com TensorFlow, Jax e outras estruturas de aprendizado de máquina.
Ele lida com o download e a preparação dos dados deterministicamente e construindo umtf.data.Dataset
(ou np.array
).
![]() | ![]() | ![]() |
Instalação
TFDS existe em dois pacotes:
-
pip install tensorflow-datasets
: a versão estável, lançada a cada poucos meses. -
pip install tfds-nightly
: lançado todos os dias, contém as últimas versões dos conjuntos de dados.
Esta colab usa tfds-nightly
:
pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
Encontre conjuntos de dados disponíveis
Todos os construtores de conjuntos de dados são subclasses de tfds.core.DatasetBuilder
. Para obter a lista de construtores disponíveis, use tfds.list_builders()
ou consulte nosso catálogo .
tfds.list_builders()
['abstract_reasoning', 'accentdb', 'aeslc', 'aflw2k3d', 'ag_news_subset', 'ai2_arc', 'ai2_arc_with_ir', 'amazon_us_reviews', 'anli', 'arc', 'bair_robot_pushing_small', 'bccd', 'beans', 'big_patent', 'bigearthnet', 'billsum', 'binarized_mnist', 'binary_alpha_digits', 'blimp', 'bool_q', 'c4', 'caltech101', 'caltech_birds2010', 'caltech_birds2011', 'cars196', 'cassava', 'cats_vs_dogs', 'celeb_a', 'celeb_a_hq', 'cfq', 'cherry_blossoms', 'chexpert', 'cifar10', 'cifar100', 'cifar10_1', 'cifar10_corrupted', 'citrus_leaves', 'cityscapes', 'civil_comments', 'clevr', 'clic', 'clinc_oos', 'cmaterdb', 'cnn_dailymail', 'coco', 'coco_captions', 'coil100', 'colorectal_histology', 'colorectal_histology_large', 'common_voice', 'coqa', 'cos_e', 'cosmos_qa', 'covid19sum', 'crema_d', 'curated_breast_imaging_ddsm', 'cycle_gan', 'dart', 'davis', 'deep_weeds', 'definite_pronoun_resolution', 'dementiabank', 'diabetic_retinopathy_detection', 'div2k', 'dmlab', 'downsampled_imagenet', 'drop', 'dsprites', 'dtd', 'duke_ultrasound', 'e2e_cleaned', 'emnist', 'eraser_multi_rc', 'esnli', 'eurosat', 'fashion_mnist', 'flic', 'flores', 'food101', 'forest_fires', 'fuss', 'gap', 'geirhos_conflict_stimuli', 'genomics_ood', 'german_credit_numeric', 'gigaword', 'glue', 'goemotions', 'gpt3', 'gref', 'groove', 'gtzan', 'gtzan_music_speech', 'hellaswag', 'higgs', 'horses_or_humans', 'howell', 'i_naturalist2017', 'imagenet2012', 'imagenet2012_corrupted', 'imagenet2012_real', 'imagenet2012_subset', 'imagenet_a', 'imagenet_r', 'imagenet_resized', 'imagenet_v2', 'imagenette', 'imagewang', 'imdb_reviews', 'irc_disentanglement', 'iris', 'kitti', 'kmnist', 'lambada', 'lfw', 'librispeech', 'librispeech_lm', 'libritts', 'ljspeech', 'lm1b', 'lost_and_found', 'lsun', 'lvis', 'malaria', 'math_dataset', 'mctaco', 'mlqa', 'mnist', 'mnist_corrupted', 'movie_lens', 'movie_rationales', 'movielens', 'moving_mnist', 'multi_news', 'multi_nli', 'multi_nli_mismatch', 'natural_questions', 'natural_questions_open', 'newsroom', 'nsynth', 'nyu_depth_v2', 'omniglot', 'open_images_challenge2019_detection', 'open_images_v4', 'openbookqa', 'opinion_abstracts', 'opinosis', 'opus', 'oxford_flowers102', 'oxford_iiit_pet', 'para_crawl', 'patch_camelyon', 'paws_wiki', 'paws_x_wiki', 'pet_finder', 'pg19', 'piqa', 'places365_small', 'plant_leaves', 'plant_village', 'plantae_k', 'qa4mre', 'qasc', 'quac', 'quickdraw_bitmap', 'race', 'radon', 'reddit', 'reddit_disentanglement', 'reddit_tifu', 'resisc45', 'robonet', 'rock_paper_scissors', 'rock_you', 's3o4d', 'salient_span_wikipedia', 'samsum', 'savee', 'scan', 'scene_parse150', 'scicite', 'scientific_papers', 'sentiment140', 'shapes3d', 'siscore', 'smallnorb', 'snli', 'so2sat', 'speech_commands', 'spoken_digit', 'squad', 'stanford_dogs', 'stanford_online_products', 'starcraft_video', 'stl10', 'story_cloze', 'sun397', 'super_glue', 'svhn_cropped', 'ted_hrlr_translate', 'ted_multi_translate', 'tedlium', 'tf_flowers', 'the300w_lp', 'tiny_shakespeare', 'titanic', 'trec', 'trivia_qa', 'tydi_qa', 'uc_merced', 'ucf101', 'vctk', 'vgg_face2', 'visual_domain_decathlon', 'voc', 'voxceleb', 'voxforge', 'waymo_open_dataset', 'web_nlg', 'web_questions', 'wider_face', 'wiki40b', 'wiki_bio', 'wiki_table_questions', 'wiki_table_text', 'wikihow', 'wikipedia', 'wikipedia_toxicity_subtypes', 'wine_quality', 'winogrande', 'wmt14_translate', 'wmt15_translate', 'wmt16_translate', 'wmt17_translate', 'wmt18_translate', 'wmt19_translate', 'wmt_t2t_translate', 'wmt_translate', 'wordnet', 'wsc273', 'xnli', 'xquad', 'xsum', 'xtreme_pawsx', 'xtreme_xnli', 'yelp_polarity_reviews', 'yes_no']
Carregar um conjunto de dados
tfds.load
A maneira mais fácil de carregar um conjunto de dados é tfds.load
. Será:
- Baixe os dados e salve-os como arquivos
tfrecord
. - Carregue o
tfrecord
e crie otf.data.Dataset
.
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data. <_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
Alguns argumentos comuns:
-
split=
: Qual divisão ler (por exemplo,'train'
,['train', 'test']
,'train[80%:]'
, ...). Veja nosso guia de divisão de API . -
shuffle_files=
: controla se deve embaralhar os arquivos entre cada época (TFDS armazena grandes conjuntos de dados em vários arquivos menores). -
data_dir=
: Local onde o conjunto de dados é salvo (o padrão é~/tensorflow_datasets/
) -
with_info=True
: Retorna otfds.core.DatasetInfo
contendo metadados do conjunto de dados -
download=False
: Desativar download
tfds.builder
tfds.load
é um wrapper fino em torno de tfds.core.DatasetBuilder
. Você pode obter a mesma saída usando a API tfds.core.DatasetBuilder
:
builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>
tfds build
CLI
Se você deseja gerar um conjunto de dados específico, você pode usar a linha de comando tfds
. Por exemplo:
tfds build mnist
Veja o documento para sinalizadores disponíveis.
Iterar sobre um conjunto de dados
Como dict
Por padrão, otf.data.Dataset
objeto contém um dict
de tf.Tensor
s:
ds = tfds.load('mnist', split='train')
ds = ds.take(1) # Only take a single example
for example in ds: # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
print(list(example.keys()))
image = example["image"]
label = example["label"]
print(image.shape, label)
['image', 'label'] (28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
Para descobrir os nomes e a estrutura das chaves dict
, consulte a documentação do conjunto de dados em nosso catálogo . Por exemplo: documentação mnist .
Como tupla ( as_supervised=True
)
Usando as_supervised=True
, você pode obter uma tupla (features, label)
vez de conjuntos de dados supervisionados.
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in ds: # example is (image, label)
print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)
Tão numpy ( tfds.as_numpy
)
Usa tfds.as_numpy
para converter:
-
tf.Tensor
->np.array
tf.data.Dataset
->Iterator[Tree[np.array]]
(ATree
pode serDict
,Tuple
aninhada arbitrariamente)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)
for image, label in tfds.as_numpy(ds):
print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4
Como tf.Tensor em lote ( batch_size=-1
)
Usando batch_size=-1
, você pode carregar o conjunto de dados completo em um único lote.
Isso pode ser combinado com as_supervised=True
e tfds.as_numpy
para obter os dados como (np.array, np.array)
:
image, label = tfds.as_numpy(tfds.load(
'mnist',
split='test',
batch_size=-1,
as_supervised=True,
))
print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)
Tenha cuidado para que seu conjunto de dados caiba na memória e que todos os exemplos tenham o mesmo formato.
Construir pipeline de ponta a ponta
Para ir mais longe, você pode olhar:
- Nosso exemplo Keras ponta a ponta para ver um pipeline de treinamento completo (com batching, shuffling, ...).
- Nosso guia de desempenho para melhorar a velocidade de seus pipelines.
Visualização
tfds.as_dataframe
tf.data.Dataset
objetostf.data.Dataset
podem ser convertidos em pandas.DataFrame
com tfds.as_dataframe
para serem visualizados no Colab .
- Adicione o
tfds.core.DatasetInfo
como segundo argumento detfds.as_dataframe
para visualizar imagens, áudio, textos, vídeos, ... - Use
ds.take(x)
para exibir apenas os primeirosx
exemplos.pandas.DataFrame
carregará o conjunto de dados completo na memória e pode ser muito caro para exibir.
ds, info = tfds.load('mnist', split='train', with_info=True)
tfds.as_dataframe(ds.take(4), info)
tfds.show_examples
tfds.show_examples
retorna um matplotlib.figure.Figure
(apenas conjuntos de dados de imagem são suportados agora):
ds, info = tfds.load('mnist', split='train', with_info=True)
fig = tfds.show_examples(ds, info)
Acesse os metadados do conjunto de dados
Todos os construtores incluem um objeto tfds.core.DatasetInfo
contendo os metadados do conjunto de dados.
Ele pode ser acessado através de:
- A API
tfds.load
:
ds, info = tfds.load('mnist', with_info=True)
- A API
tfds.core.DatasetBuilder
:
builder = tfds.builder('mnist')
info = builder.info
As informações do conjunto de dados contém informações adicionais sobre o conjunto de dados (versão, citação, página inicial, descrição, ...).
print(info)
tfds.core.DatasetInfo( name='mnist', full_name='mnist/3.0.1', description=""" The MNIST database of handwritten digits. """, homepage='http://yann.lecun.com/exdb/mnist/', data_path='/home/kbuilder/tensorflow_datasets/mnist/3.0.1', download_size=11.06 MiB, dataset_size=21.00 MiB, features=FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), }), supervised_keys=('image', 'label'), splits={ 'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>, }, citation="""@article{lecun2010mnist, title={MNIST handwritten digit database}, author={LeCun, Yann and Cortes, Corinna and Burges, CJ}, journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist}, volume={2}, year={2010} }""", )
Metadados de recursos (nomes de rótulos, formato de imagem, ...)
Acesse tfds.features.FeatureDict
:
info.features
FeaturesDict({ 'image': Image(shape=(28, 28, 1), dtype=tf.uint8), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10), })
Número de classes, nomes de rótulos:
print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7)) # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10 ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 7 7
Formas, tipos d:
print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()} {'image': tf.uint8, 'label': tf.int64} (28, 28, 1) <dtype: 'uint8'>
Metadados de divisão (por exemplo, nomes de divisão, número de exemplos, ...)
Acesse o tfds.core.SplitDict
:
print(info.splits)
{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}
Divisões disponíveis:
print(list(info.splits.keys()))
['test', 'train']
Obtenha informações sobre a divisão individual:
print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000 ['mnist-train.tfrecord-00000-of-00001'] 1
Também funciona com a API subsplit:
print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000 [FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]
Solução de problemas
Download manual (se o download falhar)
Se o download falhar por algum motivo (por exemplo, offline, ...). Você sempre pode baixar manualmente os dados você mesmo e colocá-los no manual_dir
(o padrão é ~/tensorflow_datasets/download/manual/
.
Para descobrir quais URLs baixar, verifique:
Para novos conjuntos de dados (implementados como pasta):
tensorflow_datasets/
<type>/<dataset_name>/checksums.tsv
. Por exemplo:tensorflow_datasets/text/bool_q/checksums.tsv
.Você pode encontrar o local da fonte do conjunto de dados em nosso catálogo .
Para conjuntos de dados antigos:
tensorflow_datasets/url_checksums/<dataset_name>.txt
Corrigindo NonMatchingChecksumError
TFDS garante determinismo validando as somas de verificação de urls baixados. Se NonMatchingChecksumError
for gerado, pode indicar:
- O site pode estar fora do ar (por exemplo
503 status code
). Verifique o url. - Para URLs do Google Drive, tente novamente mais tarde, pois o Drive às vezes rejeita downloads quando muitas pessoas acessam o mesmo URL. Ver bug
- Os arquivos de conjuntos de dados originais podem ter sido atualizados. Nesse caso, o construtor do conjunto de dados TFDS deve ser atualizado. Abra um novo problema no Github ou RP:
- Registre as novas somas de verificação com
tfds build --register_checksums
- Eventualmente, atualize o código de geração do conjunto de dados.
- Atualize o conjunto de dados
VERSION
- Atualize o conjunto de dados
RELEASE_NOTES
: O que causou a mudança nas somas de verificação? Alguns exemplos mudaram? - Certifique-se de que o conjunto de dados ainda possa ser criado.
- Envie-nos um PR
- Registre as novas somas de verificação com
Citação
Se você estiver usando tensorflow-datasets
para um artigo, inclua a seguinte citação, além de qualquer citação específica para os datasets usados (que pode ser encontrada no catálogo de dataset ).
@misc{TFDS,
title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
howpublished = {\url{https://www.tensorflow.org/datasets} },
}