Consejos de rendimiento

Este documento proporciona consejos de rendimiento específicos de TensorFlow Datasets (TFDS). Tenga en cuenta que TFDS proporciona conjuntos de datos como objetos tf.data.Dataset , por lo que se siguen aplicando los consejos de la guía tf.data .

Conjuntos de datos de referencia

Utilice tfds.benchmark(ds) para comparar cualquier objeto tf.data.Dataset .

Asegúrese de indicar el tamaño del batch_size= para normalizar los resultados (por ejemplo, 100 iter/sec -> 3200 ex/sec). Esto funciona con cualquier iterable (por ejemplo tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Conjuntos de datos pequeños (menos de 1 GB)

Todos los conjuntos de datos TFDS almacenan los datos en el disco en formato TFRecord . Para conjuntos de datos pequeños (p. ej., MNIST, CIFAR-10/-100), la lectura de .tfrecord puede agregar una sobrecarga significativa.

Como esos conjuntos de datos caben en la memoria, es posible mejorar significativamente el rendimiento almacenando en caché o precargando el conjunto de datos. Tenga en cuenta que TFDS almacena automáticamente en caché pequeños conjuntos de datos (la siguiente sección tiene los detalles).

Almacenamiento en caché del conjunto de datos

Este es un ejemplo de una canalización de datos que almacena explícitamente en caché el conjunto de datos después de normalizar las imágenes.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Al iterar sobre este conjunto de datos, la segunda iteración será mucho más rápida que la primera gracias al almacenamiento en caché.

Almacenamiento en caché automático

De forma predeterminada, TFDS almacena automáticamente en caché (con ds.cache() ) conjuntos de datos que cumplen las siguientes restricciones:

  • El tamaño total del conjunto de datos (todas las divisiones) está definido y < 250 MiB
  • shuffle_files está deshabilitado o solo se lee un único fragmento

Es posible desactivar el almacenamiento automático en caché pasando try_autocaching=False a tfds.ReadConfig en tfds.load . Eche un vistazo a la documentación del catálogo de conjuntos de datos para ver si un conjunto de datos específico utilizará el almacenamiento en caché automático.

Cargando los datos completos como un solo tensor

Si su conjunto de datos cabe en la memoria, también puede cargar el conjunto de datos completo como una sola matriz Tensor o NumPy. Es posible hacerlo configurando batch_size=-1 para agrupar todos los ejemplos en un solo tf.Tensor . Luego use tfds.as_numpy para la conversión de tf.Tensor a np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandes conjuntos de datos

Los conjuntos de datos grandes están fragmentados (divididos en varios archivos) y, por lo general, no caben en la memoria, por lo que no deben almacenarse en caché.

Barajar y entrenar

Durante el entrenamiento, es importante mezclar bien los datos; los datos mal mezclados pueden resultar en una menor precisión del entrenamiento.

Además de usar ds.shuffle para mezclar registros, también debe establecer shuffle_files=True para obtener un buen comportamiento de mezcla para conjuntos de datos más grandes que se fragmentan en varios archivos. De lo contrario, las épocas leerán los fragmentos en el mismo orden y, por lo tanto, los datos no serán realmente aleatorios.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Además, cuando shuffle_files=True , TFDS desactiva options.deterministic , lo que puede aumentar ligeramente el rendimiento. Para obtener una reproducción aleatoria determinista, es posible optar por no participar en esta función con tfds.ReadConfig : configurando read_config.shuffle_seed o sobrescribiendo read_config.options.deterministic .

Fragmentación automática de sus datos entre trabajadores (TF)

Al capacitar a varios trabajadores, puede usar el argumento input_context de tfds.ReadConfig , de modo que cada trabajador lea un subconjunto de los datos.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Esto es complementario a la API de subdivisión. Primero, se aplica la API subdividida: train[:50%] se convierte en una lista de archivos para leer. Luego, se aplica una ds.shard() en esos archivos. Por ejemplo, al usar train[:50%] con num_input_pipelines=2 , cada uno de los 2 trabajadores leerá 1/4 de los datos.

Cuando shuffle_files=True , los archivos se mezclan dentro de un trabajador, pero no entre trabajadores. Cada trabajador leerá el mismo subconjunto de archivos entre épocas.

Fragmentación automática de sus datos entre trabajadores (Jax)

Con Jax, puede usar la API tfds.split_for_jax_process o tfds.even_splits para distribuir sus datos entre los trabajadores. Consulte la guía de API dividida .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process es un alias simple para:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Decodificación de imágenes más rápida

De forma predeterminada, TFDS decodifica automáticamente las imágenes. Sin embargo, hay casos en los que puede ser más eficaz omitir la decodificación de la imagen con tfds.decode.SkipDecoding y aplicar manualmente la tf.io.decode_image :

El código de ambos ejemplos está disponible en la guía de decodificación .

Omitir funciones no utilizadas

Si solo está utilizando un subconjunto de las funciones, es posible omitir por completo algunas funciones. Si su conjunto de datos tiene muchas funciones sin usar, no decodificarlas puede mejorar significativamente el rendimiento. Consulte https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features