Dicas de desempenho

Este documento fornece dicas de desempenho específicas do TFDS. Note-se que TFDS fornece conjuntos de dados como tf.data.Dataset s, assim que o conselho do tf.data guia ainda se aplica.

Conjuntos de dados de referência

Use tfds.benchmark(ds) para aferir qualquer tf.data.Dataset objeto.

Certifique-se de indicar o batch_size= para normalizar os resultados (por exemplo, 100 iter / seg -> 3200 ex / seg). Isso funciona com qualquer iterable (por exemplo tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Conjuntos de dados pequenos (<GB)

Todos os conjuntos de dados TFDS armazenar os dados no disco no TFRecord formato. Para os pequenos conjuntos de dados (por exemplo Mnist, Cifar, ...), a leitura de .tfrecord pode adicionar uma sobrecarga significativa.

Como esses conjuntos de dados cabem na memória, é possível melhorar significativamente o desempenho armazenando em cache ou pré-carregando o conjunto de dados. Observe que o TFDS armazena automaticamente em cache pequenos conjuntos de dados (consulte a próxima seção para obter detalhes).

Cache do conjunto de dados

Aqui está um exemplo de um pipeline de dados que armazena em cache explicitamente o conjunto de dados após normalizar as imagens.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Ao iterar sobre este conjunto de dados, a segunda iteração será muito mais rápida do que a primeira graças ao armazenamento em cache.

Cache automático

Por padrão, TFDS auto-caches (com ds.cache() ) conjuntos de dados que preencham as seguintes restrições:

  • O tamanho total do conjunto de dados (todas as divisões) é definido e <250 MiB
  • shuffle_files está desativado, ou apenas um único fragmento é lido

É possível optar por sair de auto-caching passando try_autocaching=False para tfds.ReadConfig em tfds.load . Dê uma olhada na documentação do catálogo do conjunto de dados para ver se um conjunto de dados específico usará o cache automático.

Carregando os dados completos como um único Tensor

Se seu conjunto de dados couber na memória, você também pode carregar o conjunto de dados completo como um único Tensor ou array NumPy. É possível fazê-lo por definição batch_size=-1 para lote em todos os exemplos um único tf.Tensor . Em seguida, use tfds.as_numpy para a conversão de tf.Tensor para np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandes conjuntos de dados

Grandes conjuntos de dados são fragmentados (divididos em vários arquivos) e normalmente não cabem na memória, portanto não devem ser armazenados em cache.

Embaralhar e treinar

Durante o treinamento, é importante misturar bem os dados; dados mal embaralhados podem resultar em menor precisão de treinamento.

Além de usar ds.shuffle para baralhar registros, você também deve definir shuffle_files=True para obter um bom comportamento baralhar para conjuntos de dados maiores que são fragmentados em vários arquivos. Caso contrário, as épocas lerão os fragmentos na mesma ordem e, portanto, os dados não serão verdadeiramente randomizados.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Além disso, quando shuffle_files=True , TFDS desactiva options.experimental_deterministic , o que pode dar um ligeiro aumento de desempenho. Para obter baralhar determinista, é possível opt-out desta função com tfds.ReadConfig : quer através da criação read_config.shuffle_seed ou substituir read_config.options.experimental_deterministic .

Auto-shard seus dados entre os trabalhadores (TF)

Quando o treinamento em vários trabalhadores, você pode usar o input_context argumento de tfds.ReadConfig , de modo que cada trabalhador irá ler um subconjunto dos dados.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Isso é complementar à API subsplit. Primeiro a API subplit é aplicada ( train[:50%] é convertido em uma lista de arquivos para ler), então um ds.shard() op é aplicada sobre esses arquivos. Exemplo: quando se utiliza train[:50%] com num_input_pipelines=2 , cada um dos dois trabalhador lerá 1/4 dos dados.

Quando shuffle_files=True , os arquivos são embaralhadas dentro de um trabalhador, mas não entre os trabalhadores. Cada trabalhador lerá o mesmo subconjunto de arquivos entre as épocas.

Auto-shard seus dados entre os trabalhadores (Jax)

Com Jax, você pode usar o tfds.even_splits API para distribuir seus dados através de trabalhadores. Veja o guia de separação API .

splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
# The current `process_index` load only `1 / process_count` of the data.
ds = tfds.load('my_dataset', split=splits[jax.process_index()])

Descodificação de imagem mais rápida

Por padrão, o TFDS decodifica automaticamente as imagens. No entanto, há casos em que pode ser mais alto desempenho para saltar a imagem decodificação com tfds.decode.SkipDecoding e manualmente aplicar o tf.io.decode_image op:

  • Ao filtrar exemplos (com ds.filter ), para descodificar imagens após exemplos foram filtrados.
  • Quando cortar imagens, para usar o fundido tf.image.decode_and_crop_jpeg op.

O código para ambos os exemplos está disponível no guia de descodificação .

Pular recursos não utilizados

Se você estiver usando apenas um subconjunto dos recursos, é possível ignorar totalmente alguns recursos. Se o seu conjunto de dados tiver muitos recursos não utilizados, não decodificá-los pode melhorar significativamente o desempenho. veja https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features