TFDS y determinismo

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno

Este documento explica:

  • Las garantías TFDS sobre el determinismo
  • ¿En qué orden TFDS lee los ejemplos?
  • Varias advertencias y trampas

Configuración

Conjuntos de datos

Se necesita algo de contexto para comprender cómo TFDS lee los datos.

Durante la generación, TFDS escriben los datos originales en estandarizados .tfrecord archivos. Para grandes conjuntos de datos, múltiples .tfrecord archivos se crean, cada uno con múltiples ejemplos. Llamamos a cada .tfrecord presentar una esquirla.

Esta guía utiliza imagenet que tiene 1024 fragmentos:

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

Encontrar los ID de ejemplos de conjuntos de datos

Puede pasar a la siguiente sección si solo desea conocer el determinismo.

Cada ejemplo conjunto de datos se identifica de forma única por un id (por ejemplo, 'imagenet2012-train.tfrecord-01023-of-01024__32' ). Puede recuperar este id pasando read_config.add_tfds_id = True que agregará un 'tfds_id' llave en el dict de la tf.data.Dataset .

En este tutorial, definimos una pequeña utilidad que imprimirá los identificadores de ejemplo del conjunto de datos (convertido en entero para que sea más legible por humanos):

Determinismo al leer

En esta sección se explica garantía deterministim de tfds.load .

Con shuffle_files=False (predeterminado)

Por TFDS predeterminados producen ejemplos determinista ( shuffle_files=False )

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

Para obtener un rendimiento, TFDS leer varios fragmentos al mismo tiempo utilizando tf.data.Dataset.interleave . Vemos en este ejemplo que TFDS cambiar al fragmento 2 después de leer 16 ejemplos ( ..., 14, 15, 1251, 1252, ... ). Más sobre entrelazado a continuación.

Del mismo modo, la API subplit también es determinista:

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

Si usted está entrenando para más de una época, la configuración anterior no se recomienda como todas las épocas leerán los fragmentos en el mismo orden (por lo que el azar se limita a los ds = ds.shuffle(buffer) el tamaño de búfer).

Con shuffle_files=True

Con shuffle_files=True , los fragmentos se barajan para cada época, por lo que la lectura no es determinista más.

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

Consulte la receta a continuación para obtener una mezcla de archivos determinista.

Advertencia sobre el determinismo: intercalar argumentos

Cambiando read_config.interleave_cycle_length , read_config.interleave_block_length cambiará el orden ejemplos.

TFDS se basa en tf.data.Dataset.interleave sólo para cargar un par de fragmentos a la vez, mejorar el rendimiento y reducir el uso de memoria.

Solo se garantiza que el orden del ejemplo sea el mismo para un valor fijo de argumentos entrelazados. Ver doc intercalación de entender lo cycle_length y block_length corresponden también.

  • cycle_length=16 , block_length=16 (valor predeterminado, igual al anterior):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3 , block_length=2 :
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

En el segundo ejemplo, se ve que el conjunto de datos leer 2 ( block_length=2 ) ejemplos en un fragmento, a continuación, cambiar al siguiente fragmento. Cada 2 * 3 ( cycle_length=3 ) ejemplos, que se remonta a la primera fragmento ( shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,... ).

Subdivisión y orden de ejemplo

Cada ejemplo tiene un id 0, 1, ..., num_examples-1 . El API subsplit seleccionar una rebanada de ejemplos (por ejemplo, train[:x] seleccione 0, 1, ..., x-1 ).

Sin embargo, dentro de la subdivisión, los ejemplos no se leen en orden creciente de identificación (debido a fragmentos e intercalación).

Más específicamente, ds.take(x) y split='train[:x]' no son equivalentes!

Esto se puede ver fácilmente en el ejemplo de intercalación anterior donde los ejemplos provienen de diferentes fragmentos.

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

Después de los 16 ejemplos (block_length), .take(25) conmuta al siguiente fragmento mientras train[:25] seguir leyendo ejemplos en los de la primera fragmento.

Recetas

Obtenga una reproducción aleatoria de archivos determinista

Hay 2 formas de tener una mezcla determinista:

  1. Ajuste de la shuffle_seed . Nota: Esto requiere cambiar la semilla en cada época; de lo contrario, los fragmentos se leerán en el mismo orden entre las épocas.
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. Usando experimental_interleave_sort_fn : Esto le da un control total sobre la que se leen fragmentos y en qué orden, en lugar de depender de ds.shuffle orden.
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

Obtenga una canalización prioritaria determinista

Éste es más complicado. No existe una solución fácil y satisfactoria.

  1. Sin ds.shuffle y con barajado determinista, en teoría, debería ser posible para contar los ejemplos que se han leído y deducir que los ejemplos han sido leídos dentro en cada fragmento (como una función de cycle_length , block_length y el orden fragmento). A continuación, el skip , take para cada fragmento se podría inyectar a través de experimental_interleave_sort_fn .

  2. Con ds.shuffle lo más probable es imposible sin reproducir la tubería formación completa. Sería necesario salvar el ds.shuffle estado tapón deducir qué ejemplos han sido leídos. Ejemplos podrían ser no continua (por ejemplo shard5_ex2 , shard5_ex4 leer pero no shard5_ex3 ).

  3. Con ds.shuffle , una forma sería guardar todos shards_ids / example_ids leen (deducen de tfds_id ), y luego deducir las instrucciones de archivo de eso.

El caso más simple de 1. es tener .skip(x).take(y) partido train[x:x+y] partido. Requiere:

  • Set cycle_length=1 (de modo fragmentos se leen secuencialmente)
  • Set shuffle_files=False
  • No utilice ds.shuffle

Solo debe usarse en un conjunto de datos enorme donde el entrenamiento es solo de 1 época. Los ejemplos se leerían en el orden aleatorio predeterminado.

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

Encuentre qué fragmentos / ejemplos se leen para una subplit determinada

Con la tfds.core.DatasetInfo , tiene acceso directo a las instrucciones de lectura.

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]