Esta página foi traduzida pela API Cloud Translation.
Switch to English

Conjuntos de dados TensorFlow

O TFDS fornece uma coleção de conjuntos de dados prontos para uso para uso com o TensorFlow, Jax e outras estruturas de Machine Learning.

Ele lida com o download e a preparação dos dados de maneira determinística e com a construção de um tf.data.Dataset (ou np.array ).

Ver em TensorFlow.org Executar no Google Colab Ver fonte no GitHub

Instalação

O TFDS existe em dois pacotes:

  • tensorflow-datasets : a versão estável, lançada a cada poucos meses.
  • tfds-nightly : lançado todos os dias, contém as últimas versões dos conjuntos de dados.

Para instalar:

 pip install tensorflow-datasets
 

Essa tfds-nightly usa tfds-nightly e TF 2.

pip install -q tensorflow>=2 tfds-nightly matplotlib
 import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds
 

Encontre conjuntos de dados disponíveis

Todos os construtores de conjuntos de dados são subclasses de tfds.core.DatasetBuilder . Para obter a lista de construtores disponíveis, use tfds.list_builders() ou consulte nosso catálogo .

 tfds.list_builders()
 
['abstract_reasoning',
 'aeslc',
 'aflw2k3d',
 'amazon_us_reviews',
 'anli',
 'arc',
 'bair_robot_pushing_small',
 'beans',
 'big_patent',
 'bigearthnet',
 'billsum',
 'binarized_mnist',
 'binary_alpha_digits',
 'blimp',
 'c4',
 'caltech101',
 'caltech_birds2010',
 'caltech_birds2011',
 'cars196',
 'cassava',
 'cats_vs_dogs',
 'celeb_a',
 'celeb_a_hq',
 'cfq',
 'chexpert',
 'cifar10',
 'cifar100',
 'cifar10_1',
 'cifar10_corrupted',
 'citrus_leaves',
 'cityscapes',
 'civil_comments',
 'clevr',
 'cmaterdb',
 'cnn_dailymail',
 'coco',
 'coil100',
 'colorectal_histology',
 'colorectal_histology_large',
 'common_voice',
 'cos_e',
 'cosmos_qa',
 'covid19sum',
 'crema_d',
 'curated_breast_imaging_ddsm',
 'cycle_gan',
 'deep_weeds',
 'definite_pronoun_resolution',
 'dementiabank',
 'diabetic_retinopathy_detection',
 'div2k',
 'dmlab',
 'downsampled_imagenet',
 'dsprites',
 'dtd',
 'duke_ultrasound',
 'emnist',
 'eraser_multi_rc',
 'esnli',
 'eurosat',
 'fashion_mnist',
 'flic',
 'flores',
 'food101',
 'forest_fires',
 'gap',
 'geirhos_conflict_stimuli',
 'german_credit_numeric',
 'gigaword',
 'glue',
 'groove',
 'higgs',
 'horses_or_humans',
 'i_naturalist2017',
 'image_label_folder',
 'imagenet2012',
 'imagenet2012_corrupted',
 'imagenet2012_subset',
 'imagenet_resized',
 'imagenette',
 'imagewang',
 'imdb_reviews',
 'irc_disentanglement',
 'iris',
 'kitti',
 'kmnist',
 'lfw',
 'librispeech',
 'librispeech_lm',
 'libritts',
 'ljspeech',
 'lm1b',
 'lost_and_found',
 'lsun',
 'malaria',
 'math_dataset',
 'mctaco',
 'mnist',
 'mnist_corrupted',
 'movie_rationales',
 'moving_mnist',
 'multi_news',
 'multi_nli',
 'multi_nli_mismatch',
 'natural_questions',
 'newsroom',
 'nsynth',
 'nyu_depth_v2',
 'omniglot',
 'open_images_challenge2019_detection',
 'open_images_v4',
 'opinion_abstracts',
 'opinosis',
 'oxford_flowers102',
 'oxford_iiit_pet',
 'para_crawl',
 'patch_camelyon',
 'pet_finder',
 'pg19',
 'places365_small',
 'plant_leaves',
 'plant_village',
 'plantae_k',
 'qa4mre',
 'quickdraw_bitmap',
 'reddit',
 'reddit_tifu',
 'resisc45',
 'robonet',
 'rock_paper_scissors',
 'rock_you',
 'samsum',
 'savee',
 'scan',
 'scene_parse150',
 'scicite',
 'scientific_papers',
 'shapes3d',
 'smallnorb',
 'snli',
 'so2sat',
 'speech_commands',
 'squad',
 'stanford_dogs',
 'stanford_online_products',
 'starcraft_video',
 'stl10',
 'sun397',
 'super_glue',
 'svhn_cropped',
 'ted_hrlr_translate',
 'ted_multi_translate',
 'tedlium',
 'tf_flowers',
 'the300w_lp',
 'tiny_shakespeare',
 'titanic',
 'trivia_qa',
 'uc_merced',
 'ucf101',
 'vgg_face2',
 'visual_domain_decathlon',
 'voc',
 'voxceleb',
 'voxforge',
 'waymo_open_dataset',
 'web_questions',
 'wider_face',
 'wiki40b',
 'wikihow',
 'wikipedia',
 'winogrande',
 'wmt14_translate',
 'wmt15_translate',
 'wmt16_translate',
 'wmt17_translate',
 'wmt18_translate',
 'wmt19_translate',
 'wmt_t2t_translate',
 'wmt_translate',
 'xnli',
 'xsum',
 'yelp_polarity_reviews']

Carregar um conjunto de dados

A maneira mais fácil de carregar um conjunto de dados é tfds.load . Será:

  1. Faça o download dos dados e salve-os como arquivos tfrecord .
  2. Carregue o tfrecord e crie o tf.data.Dataset .
 ds = tfds.load('mnist', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)
 
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.


Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

Alguns argumentos comuns:

  • split= : Qual divisão para ler (por exemplo, 'train' , ['train', 'test'] , 'train[80%:]' , ...). Veja nosso guia de API dividido .
  • shuffle_files= : controla se os arquivos devem ser embaralhados entre cada época (o TFDS armazena grandes conjuntos de dados em vários arquivos menores).
  • data_dir= : local onde o conjunto de dados é salvo (o padrão é ~/tensorflow_datasets/ )
  • with_info=True : retorna o tfds.core.DatasetInfo contém metadados do conjunto de dados
  • download=False : desativar o download

tfds.load é um invólucro fino em torno do tfds.core.DatasetBuilder . Você pode obter a mesma saída usando a API tfds.core.DatasetBuilder :

 builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
ds = builder.as_dataset(split='train', shuffle_files=True)
print(ds)
 
<_OptionsDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

Iterar sobre um conjunto de dados

Como ditado

Por padrão, o tf.data.Dataset objeto contém um dict de tf.Tensor s:

 ds = tfds.load('mnist', split='train')
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
  print(list(example.keys()))
  image = example["image"]
  label = example["label"]
  print(image.shape, label)
 
['image', 'label']
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

Como tupla

Usando as_supervised=True , você pode obter uma tupla (features, label) para conjuntos de dados supervisionados.

 ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in ds:  # example is (image, label)
  print(image.shape, label)
 
(28, 28, 1) tf.Tensor(4, shape=(), dtype=int64)

Tão entorpecido

Usa tfds.as_numpy para converter:

 ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.take(1)

for image, label in tfds.as_numpy(ds):
  print(type(image), type(label), label)
 
<class 'numpy.ndarray'> <class 'numpy.int64'> 4

Como lote tf.Tensor

Usando batch_size=-1 , você pode carregar o conjunto de dados completo em um único lote.

tfds.load retornará um dict ( tuple com as_supervised=True ) de tf.Tensor ( np.array com tfds.as_numpy ).

Cuidado para que seu conjunto de dados caiba na memória e que todos os exemplos tenham a mesma forma.

 image, label = tfds.as_numpy(tfds.load(
    'mnist',
    split='test', 
    batch_size=-1, 
    as_supervised=True,
))

print(type(image), image.shape)
 
<class 'numpy.ndarray'> (10000, 28, 28, 1)

Crie um pipeline de ponta a ponta

Para ir além, você pode procurar:

  • Nosso exemplo completo do Keras para ver um pipeline de treinamento completo (com lotes, embaralhamento, ...).
  • Nosso guia de desempenho para melhorar a velocidade dos seus dutos.

Visualize um conjunto de dados

Visualize conjuntos de dados com tfds.show_examples (apenas conjuntos de dados de imagem são suportados agora):

 ds, info = tfds.load('mnist', split='train', with_info=True)

fig = tfds.show_examples(ds, info)
 

png

Acesse os metadados do conjunto de dados

Todos os construtores incluem um objeto tfds.core.DatasetInfo contém os metadados do conjunto de dados.

Pode ser acessado através de:

 ds, info = tfds.load('mnist', with_info=True)
 
 builder = tfds.builder('mnist')
info = builder.info
 

As informações do conjunto de dados contêm informações adicionais sobre o conjunto de dados (versão, citação, página inicial, descrição, ...).

 print(info)
 
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)


Metadados de recursos (nomes de rótulos, formato da imagem, ...)

Acesse tfds.features.FeatureDict :

 info.features
 
FeaturesDict({
    'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
    'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
})

Número de classes, nomes de rótulos:

 print(info.features["label"].num_classes)
print(info.features["label"].names)
print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
print(info.features["label"].str2int('7'))
 
10
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
7
7

Formas, tipos:

 print(info.features.shape)
print(info.features.dtype)
print(info.features['image'].shape)
print(info.features['image'].dtype)
 
{'image': (28, 28, 1), 'label': ()}
{'image': tf.uint8, 'label': tf.int64}
(28, 28, 1)
<dtype: 'uint8'>

Metadados de divisão (por exemplo, nomes de divisão, número de exemplos, ...)

Acesse o tfds.core.SplitDict :

 print(info.splits)
 
{'test': <tfds.core.SplitInfo num_examples=10000>, 'train': <tfds.core.SplitInfo num_examples=60000>}

Divisões disponíveis:

 print(list(info.splits.keys()))
 
['test', 'train']

Obtenha informações sobre a divisão individual:

 print(info.splits['train'].num_examples)
print(info.splits['train'].filenames)
print(info.splits['train'].num_shards)
 
60000
['mnist-train.tfrecord-00000-of-00001']
1

Também funciona com a API do subsplit:

 print(info.splits['train[15%:75%]'].num_examples)
print(info.splits['train[15%:75%]'].file_instructions)
 
36000
[FileInstruction(filename='mnist-train.tfrecord-00000-of-00001', skip=9000, take=36000, num_examples=36000)]

Citação

Se você estiver usando tensorflow-datasets para um artigo, inclua a seguinte citação, além de qualquer citação específica para os conjuntos de dados usados ​​(que pode ser encontrada no catálogo de conjuntos de dados ).

 @misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets} },
}