Dicas de desempenho

Este documento fornece dicas de desempenho específicas do TensorFlow Datasets (TFDS). Observe que o TFDS fornece conjuntos de dados como objetos tf.data.Dataset , portanto, o conselho do guia tf.data ainda se aplica.

Conjuntos de dados comparativos

Use tfds.benchmark(ds) para comparar qualquer objeto tf.data.Dataset .

Certifique-se de indicar batch_size= para normalizar os resultados (por exemplo, 100 iter/sec -> 3200 ex/sec). Isso funciona com qualquer iterável (por exemplo tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Conjuntos de dados pequenos (menos de 1 GB)

Todos os conjuntos de dados TFDS armazenam os dados em disco no formato TFRecord . Para conjuntos de dados pequenos (por exemplo, MNIST, CIFAR-10/-100), a leitura de .tfrecord pode adicionar uma sobrecarga significativa.

Como esses conjuntos de dados cabem na memória, é possível melhorar significativamente o desempenho armazenando em cache ou pré-carregando o conjunto de dados. Observe que o TFDS armazena em cache pequenos conjuntos de dados automaticamente (a seção a seguir contém os detalhes).

Como armazenar em cache o conjunto de dados

Aqui está um exemplo de um pipeline de dados que armazena em cache explicitamente o conjunto de dados após normalizar as imagens.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Ao iterar sobre esse conjunto de dados, a segunda iteração será muito mais rápida que a primeira, graças ao armazenamento em cache.

Cache automático

Por padrão, o TFDS armazena em cache automático (com ds.cache() ) conjuntos de dados que atendem às seguintes restrições:

  • O tamanho total do conjunto de dados (todas as divisões) é definido e < 250 MiB
  • shuffle_files está desabilitado ou apenas um único fragmento é lido

É possível desativar o cache automático passando try_autocaching=False para tfds.ReadConfig em tfds.load . Dê uma olhada na documentação do catálogo de conjuntos de dados para ver se um conjunto de dados específico usará o cache automático.

Carregando os dados completos como um único tensor

Se o seu conjunto de dados couber na memória, você também poderá carregar o conjunto de dados completo como um único array Tensor ou NumPy. É possível fazer isso definindo batch_size=-1 para agrupar todos os exemplos em um único tf.Tensor . Em seguida, use tfds.as_numpy para a conversão de tf.Tensor para np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandes conjuntos de dados

Grandes conjuntos de dados são fragmentados (divididos em vários arquivos) e normalmente não cabem na memória, portanto, não devem ser armazenados em cache.

Embaralhar e treinar

Durante o treinamento, é importante embaralhar bem os dados - dados mal embaralhados podem resultar em menor precisão de treinamento.

Além de usar ds.shuffle para embaralhar registros, você também deve definir shuffle_files=True para obter um bom comportamento de embaralhamento para conjuntos de dados maiores que são fragmentados em vários arquivos. Caso contrário, as épocas lerão os fragmentos na mesma ordem e, portanto, os dados não serão realmente aleatórios.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Além disso, quando shuffle_files=True , o TFDS desativa options.deterministic , o que pode aumentar um pouco o desempenho. Para obter embaralhamento determinístico, é possível desativar esse recurso com tfds.ReadConfig : configurando read_config.shuffle_seed ou sobrescrevendo read_config.options.deterministic .

Partilhe automaticamente seus dados entre trabalhadores (TF)

Ao treinar em vários trabalhadores, você pode usar o argumento input_context de tfds.ReadConfig , para que cada trabalhador leia um subconjunto dos dados.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Isso é complementar à API de subdivisão. Primeiro, a API de subdivisão é aplicada: train[:50%] é convertido em uma lista de arquivos a serem lidos. Em seguida, um op ds.shard() é aplicado a esses arquivos. Por exemplo, ao usar train[:50%] com num_input_pipelines=2 , cada um dos 2 trabalhadores lerá 1/4 dos dados.

Quando shuffle_files=True , os arquivos são embaralhados em um trabalhador, mas não entre trabalhadores. Cada trabalhador lerá o mesmo subconjunto de arquivos entre épocas.

Partilhe automaticamente seus dados entre trabalhadores (Jax)

Com o Jax, você pode usar a API tfds.split_for_jax_process ou tfds.even_splits para distribuir seus dados entre os trabalhadores. Consulte o guia da API dividida .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process é um alias simples para:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Decodificação de imagem mais rápida

Por padrão, o TFDS decodifica imagens automaticamente. No entanto, há casos em que pode ser mais eficiente pular a decodificação da imagem com tfds.decode.SkipDecoding e aplicar manualmente a tf.io.decode_image :

O código para ambos os exemplos está disponível no guia de decodificação .

Ignorar recursos não utilizados

Se você estiver usando apenas um subconjunto dos recursos, é possível pular totalmente alguns recursos. Se seu conjunto de dados tiver muitos recursos não utilizados, não decodificá-los pode melhorar significativamente o desempenho. Consulte https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features