Esta página foi traduzida pela API Cloud Translation.
Switch to English

Conjuntos de dados TensorFlow

TFDS fornece uma coleção de conjuntos de dados prontos para uso com TensorFlow, Jax e outras estruturas de aprendizado de máquina.

Ele lida com o download e a preparação dos dados deterministicamente e construindo umtf.data.Dataset (ou np.array ).

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub

Instalação

TFDS existe em dois pacotes:

  • pip install tensorflow-datasets : a versão estável, lançada a cada poucos meses.
  • pip install tfds-nightly : lançado todos os dias, contém as últimas versões dos conjuntos de dados.

Esta colab usa tfds-nightly :

pip install -q tfds-nightly tensorflow matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

Encontre conjuntos de dados disponíveis

Todos os construtores de conjuntos de dados são subclasses de tfds.core.DatasetBuilder . Para obter a lista de construtores disponíveis, use tfds.list_builders() ou consulte nosso catálogo .

tfds.list_builders()
['abstract_reasoning',
 'accentdb',
 'aeslc',
 'aflw2k3d',
 'ag_news_subset',
 'ai2_arc',
 'ai2_arc_with_ir',
 'amazon_us_reviews',
 'anli',
 'arc',
 'bair_robot_pushing_small',
 'bccd',
 'beans',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'blimp',
 'bool_q',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'cfq',
 'cherry_blossoms',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'clic',
 'clinc_oos',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coco_captions',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'common_voice',
 'coqa',
 'cos_e',
 'cosmos_qa',
 'covid19sum',
 'crema_d',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'dart',
 'davis',
 'deep_weeds',
 'definite_pronoun_resolution',
 'dementiabank',
 'diabetic_retinopathy_detection',
 'div2k',
 'dmlab',
 'downsampled_imagenet',
 'drop',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'e2e_cleaned',
 'emnist',
 'eraser_multi_rc',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'forest_fires',
 'fuss',
 'gap',
 'geirhos_conflict_stimuli',
 'genomics_ood',
 'german_credit_numeric',
 'gigaword',
 'glue',
 'goemotions',
 'gpt3',
 'gref',
 'groove',
 'gtzan',
 'gtzan_music_speech',
 'hellaswag',
 'higgs',
 'horses_or_humans',
 'howell',
 'i_naturalist2017',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet2012_real',
 'imagenet2012_subset',
 'imagenet_a',
 'imagenet_r',
 'imagenet_resized',
 'imagenet_v2',
 'imagenette',
 'imagewang',
 'imdb_reviews',
 'irc_disentanglement',
 'iris',
 'kitti',
 'kmnist',
 'lambada',
 'lfw',
 'librispeech',
 'librispeech_lm',
 'libritts',
 'ljspeech',
 'lm1b',
 'lost_and_found',
 'lsun',
 'lvis',
 'malaria',
 'math_dataset',
 'mctaco',
 'mlqa',
 'mnist',
 'mnist_corrupted',
 'movie_lens',
 'movie_rationales',
 'movielens',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'natural_questions',
 'natural_questions_open',
 'newsroom',
 'nsynth',
 'nyu_depth_v2',
 'omniglot',
 'open_images_challenge2019_detection',
 'open_images_v4',
 'openbookqa',
 'opinion_abstracts',
 'opinosis',
 'opus',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'paws_wiki',
 'paws_x_wiki',
 'pet_finder',
 'pg19',
 'piqa',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'qa4mre',
 'qasc',
 'quac',
 'quickdraw_bitmap',
 'race',
 'radon',
 'reddit',
 'reddit_disentanglement',
 'reddit_tifu',
 'resisc45',
 'robonet',
 'rock_paper_scissors',
 'rock_you',
 's3o4d',
 'salient_span_wikipedia',
 'samsum',
 'savee',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'sentiment140',
 'shapes3d',
 'siscore',
 'smallnorb',
 'snli',
 'so2sat',
 'speech_commands',
 'spoken_digit',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'stl10',
 'story_cloze',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tedlium',
 'tf_flowers',
 'the300w_lp',
 'tiny_shakespeare',
 'titanic',
 'trec',
 'trivia_qa',
 'tydi_qa',
 'uc_merced',
 'ucf101',
 'vctk',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'voxceleb',
 'voxforge',
 'waymo_open_dataset',
 'web_nlg',
 'web_questions',
 'wider_face',
 'wiki40b',
 'wiki_bio',
 'wiki_table_questions',
 'wiki_table_text',
 'wikihow',
 'wikipedia',
 'wikipedia_toxicity_subtypes',
 'wine_quality',
 'winogrande',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'wordnet',
 'wsc273',
 'xnli',
 'xquad',
 'xsum',
 'xtreme_pawsx',
 'xtreme_xnli',
 'yelp_polarity_reviews',
 'yes_no']

Carregar um conjunto de dados

tfds.load

A maneira mais fácil de carregar um conjunto de dados é tfds.load . Será:

  1. Baixe os dados e salve-os como arquivos tfrecord .
  2. Carregue o tfrecord e crie otf.data.Dataset .
ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

Alguns argumentos comuns:

  • split= : Qual divisão ler (por exemplo, 'train' , ['train', 'test'] , 'train[80%:]' , ...). Veja nosso guia de divisão de API .
  • shuffle_files= : controla se deve embaralhar os arquivos entre cada época (TFDS armazena grandes conjuntos de dados em vários arquivos menores).
  • data_dir= : Local onde o conjunto de dados é salvo (o padrão é ~/tensorflow_datasets/ )
  • with_info=True : Retorna o tfds.core.DatasetInfo contendo metadados do conjunto de dados
  • download=False : Desativar download

tfds.builder

tfds.load é um wrapper fino em torno de tfds.core.DatasetBuilder . Você pode obter a mesma saída usando a API tfds.core.DatasetBuilder :

builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

tfds build CLI

Se você deseja gerar um conjunto de dados específico, você pode usar a linha de comando tfds . Por exemplo:

tfds build mnist

Veja o documento para sinalizadores disponíveis.

Iterar sobre um conjunto de dados

Como dict

Por padrão, otf.data.Dataset objeto contém um dict de tf.Tensor s:

ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
['image', 'label']
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

Para descobrir os nomes e a estrutura das chaves dict , consulte a documentação do conjunto de dados em nosso catálogo . Por exemplo: documentação mnist .

Como tupla ( as_supervised=True )

Usando as_supervised=True , você pode obter uma tupla (features, label) vez de conjuntos de dados supervisionados.

ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

Tão numpy ( tfds.as_numpy )

Usa tfds.as_numpy para converter:

  • tf.Tensor -> np.array
  • tf.data.Dataset -> Iterator[Tree[np.array]] (A Tree pode ser Dict , Tuple aninhada arbitrariamente)
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
<class 'numpy.ndarray'> <class 'numpy.int64'> 4

Como tf.Tensor em lote ( batch_size=-1 )

Usando batch_size=-1 , você pode carregar o conjunto de dados completo em um único lote.

Isso pode ser combinado com as_supervised=True e tfds.as_numpy para obter os dados como (np.array, np.array) :

image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test',
    batch_size=-1,
    as_supervised=True,
))

print(type(image), image.shape)
<class 'numpy.ndarray'> (10000, 28, 28, 1)

Tenha cuidado para que seu conjunto de dados caiba na memória e que todos os exemplos tenham o mesmo formato.

Construir pipeline de ponta a ponta

Para ir mais longe, você pode olhar:

Visualização

tfds.as_dataframe

tf.data.Dataset objetostf.data.Dataset podem ser convertidos em pandas.DataFrame com tfds.as_dataframe para serem visualizados no Colab .

  • Adicione o tfds.core.DatasetInfo como segundo argumento de tfds.as_dataframe para visualizar imagens, áudio, textos, vídeos, ...
  • Use ds.take(x) para exibir apenas os primeiros x exemplos. pandas.DataFrame carregará o conjunto de dados completo na memória e pode ser muito caro para exibir.
ds, info = tfds.load('mnist', split='train', with_info=True)

tfds.as_dataframe(ds.take(4), info)

tfds.show_examples

tfds.show_examples retorna um matplotlib.figure.Figure (apenas conjuntos de dados de imagem são suportados agora):

ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)

png

Acesse os metadados do conjunto de dados

Todos os construtores incluem um objeto tfds.core.DatasetInfo contendo os metadados do conjunto de dados.

Ele pode ser acessado através de:

ds, info = tfds.load('mnist', with_info=True)
builder = tfds.builder('mnist')
info = builder.info

As informações do conjunto de dados contém informações adicionais sobre o conjunto de dados (versão, citação, página inicial, descrição, ...).

print(info)
tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='/home/kbuilder/tensorflow_datasets/mnist/3.0.1',
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
)

Metadados de recursos (nomes de rótulos, formato de imagem, ...)

Acesse tfds.features.FeatureDict :

info.features
FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})

Número de classes, nomes de rótulos:

print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
7
7

Formas, tipos d:

print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
{'image': (28, 28, 1), 'label': ()}
{'image': tf.uint8, 'label': tf.int64}
(28, 28, 1)
<dtype: 'uint8'>

Metadados de divisão (por exemplo, nomes de divisão, número de exemplos, ...)

Acesse o tfds.core.SplitDict :

print(info.splits)
{'test': <SplitInfo num_examples=10000, num_shards=1>, 'train': <SplitInfo num_examples=60000, num_shards=1>}

Divisões disponíveis:

print(list(info.splits.keys()))
['test', 'train']

Obtenha informações sobre a divisão individual:

print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
60000
['mnist-train.tfrecord-00000-of-00001']
1

Também funciona com a API subsplit:

print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
36000
[FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]

Solução de problemas

Download manual (se o download falhar)

Se o download falhar por algum motivo (por exemplo, offline, ...). Você sempre pode baixar manualmente os dados você mesmo e colocá-los no manual_dir (o padrão é ~/tensorflow_datasets/download/manual/ .

Para descobrir quais URLs baixar, verifique:

Corrigindo NonMatchingChecksumError

TFDS garante determinismo validando as somas de verificação de urls baixados. Se NonMatchingChecksumError for gerado, pode indicar:

  • O site pode estar fora do ar (por exemplo 503 status code ). Verifique o url.
  • Para URLs do Google Drive, tente novamente mais tarde, pois o Drive às vezes rejeita downloads quando muitas pessoas acessam o mesmo URL. Ver bug
  • Os arquivos de conjuntos de dados originais podem ter sido atualizados. Nesse caso, o construtor do conjunto de dados TFDS deve ser atualizado. Abra um novo problema no Github ou RP:
    • Registre as novas somas de verificação com tfds build --register_checksums
    • Eventualmente, atualize o código de geração do conjunto de dados.
    • Atualize o conjunto de dados VERSION
    • Atualize o conjunto de dados RELEASE_NOTES : O que causou a mudança nas somas de verificação? Alguns exemplos mudaram?
    • Certifique-se de que o conjunto de dados ainda possa ser criado.
    • Envie-nos um PR

Citação

Se você estiver usando tensorflow-datasets para um artigo, inclua a seguinte citação, além de qualquer citação específica para os datasets usados ​​(que pode ser encontrada no catálogo de dataset ).

@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets} },
}