Consejos de rendimiento

Este documento proporciona consejos de rendimiento específicos de TensorFlow Datasets (TFDS). Tenga en cuenta que TFDS proporciona conjuntos de datos como objetos tf.data.Dataset , por lo que los consejos de la guía tf.data aún se aplican.

Conjuntos de datos de referencia

Utilice tfds.benchmark(ds) para comparar cualquier objeto tf.data.Dataset .

Asegúrese de indicar el batch_size= para normalizar los resultados (por ejemplo, 100 iter/seg -> 3200 ex/seg). Esto funciona con cualquier iterable (por ejemplo, tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Conjuntos de datos pequeños (menos de 1 GB)

Todos los conjuntos de datos TFDS almacenan los datos en el disco en formato TFRecord . Para conjuntos de datos pequeños (por ejemplo, MNIST, CIFAR-10/-100), la lectura de .tfrecord puede suponer una sobrecarga significativa.

A medida que esos conjuntos de datos caben en la memoria, es posible mejorar significativamente el rendimiento almacenando en caché o precargando el conjunto de datos. Tenga en cuenta que TFDS almacena en caché automáticamente pequeños conjuntos de datos (la siguiente sección tiene los detalles).

Almacenamiento en caché del conjunto de datos

A continuación se muestra un ejemplo de una canalización de datos que almacena en caché explícitamente el conjunto de datos después de normalizar las imágenes.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Al iterar sobre este conjunto de datos, la segunda iteración será mucho más rápida que la primera gracias al almacenamiento en caché.

Almacenamiento en caché automático

De forma predeterminada, TFDS almacena en caché automáticamente (con ds.cache() ) conjuntos de datos que satisfacen las siguientes restricciones:

  • El tamaño total del conjunto de datos (todas las divisiones) está definido y es < 250 MiB
  • shuffle_files está deshabilitado o solo se lee un fragmento

Es posible desactivar el almacenamiento en caché automático pasando try_autocaching=False a tfds.ReadConfig en tfds.load . Eche un vistazo a la documentación del catálogo de conjuntos de datos para ver si un conjunto de datos específico utilizará el caché automático.

Cargando los datos completos como un solo tensor

Si su conjunto de datos cabe en la memoria, también puede cargar el conjunto de datos completo como una única matriz Tensor o NumPy. Es posible hacerlo configurando batch_size=-1 para agrupar todos los ejemplos en un solo tf.Tensor . Luego use tfds.as_numpy para la conversión de tf.Tensor a np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandes conjuntos de datos

Los conjuntos de datos grandes están fragmentados (divididos en varios archivos) y normalmente no caben en la memoria, por lo que no deben almacenarse en caché.

Mezcla y entrenamiento

Durante el entrenamiento, es importante mezclar bien los datos; los datos mal mezclados pueden dar lugar a una menor precisión del entrenamiento.

Además de usar ds.shuffle para mezclar registros, también debe configurar shuffle_files=True para obtener un buen comportamiento de mezcla para conjuntos de datos más grandes que están divididos en varios archivos. De lo contrario, las épocas leerán los fragmentos en el mismo orden y, por lo tanto, los datos no serán realmente aleatorios.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Además, cuando shuffle_files=True , TFDS deshabilita options.deterministic , lo que puede aumentar ligeramente el rendimiento. Para obtener una mezcla aleatoria determinista, es posible desactivar esta función con tfds.ReadConfig : ya sea configurando read_config.shuffle_seed o sobrescribiendo read_config.options.deterministic .

Divide automáticamente tus datos entre trabajadores (TF)

Al capacitar a varios trabajadores, puede usar el argumento input_context de tfds.ReadConfig , de modo que cada trabajador leerá un subconjunto de datos.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Esto es complementario a la API subsplit. Primero, se aplica la API de subdivisión: train[:50%] se convierte en una lista de archivos para leer. Luego, se aplica una operación ds.shard() en esos archivos. Por ejemplo, cuando se usa train[:50%] con num_input_pipelines=2 , cada uno de los 2 trabajadores leerá 1/4 de los datos.

Cuando shuffle_files=True , los archivos se mezclan dentro de un trabajador, pero no entre los trabajadores. Cada trabajador leerá el mismo subconjunto de archivos entre épocas.

Divide automáticamente tus datos entre trabajadores (Jax)

Con Jax, puede utilizar la API tfds.split_for_jax_process o tfds.even_splits para distribuir sus datos entre los trabajadores. Consulte la guía de API dividida .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process es un alias simple para:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Decodificación de imágenes más rápida

De forma predeterminada, TFDS decodifica imágenes automáticamente. Sin embargo, hay casos en los que puede resultar más eficaz omitir la decodificación de la imagen con tfds.decode.SkipDecoding y aplicar manualmente la operación tf.io.decode_image :

El código de ambos ejemplos está disponible en la guía de decodificación .

Saltar funciones no utilizadas

Si solo está utilizando un subconjunto de funciones, es posible omitir por completo algunas funciones. Si su conjunto de datos tiene muchas funciones no utilizadas, no decodificarlas puede mejorar significativamente el rendimiento. Consulte https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features

¡tf.data usa toda mi RAM!

Si tiene una RAM limitada o si está cargando muchos conjuntos de datos en paralelo mientras usa tf.data , aquí hay algunas opciones que pueden ayudar:

Anular el tamaño del búfer

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

Esto anula el buffer_size pasado a TFRecordDataset (o equivalente): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args

Utilice tf.data.Dataset.with_options para detener comportamientos mágicos

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)