¡El Día de la Comunidad de ML es el 9 de noviembre! Únase a nosotros para recibir actualizaciones de TensorFlow, JAX, y más Más información

Consejos de rendimiento

Este documento proporciona consejos de rendimiento específicos de TFDS. Tenga en cuenta que TFDS proporciona conjuntos de datos como tf.data.Dataset s, por lo que el asesoramiento de la tf.data guía sigue siendo válida.

Conjuntos de datos de referencia

Uso tfds.benchmark(ds) a cualquier punto de referencia tf.data.Dataset objeto.

Asegúrese de indicar el batch_size= para normalizar los resultados (por ejemplo, 100 iter / seg -> 3200 EX / seg). Esto funciona con cualquier iterable (por ejemplo tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Pequeños conjuntos de datos (<GB)

Todos los conjuntos de datos TFDS almacenan los datos en el disco en el TFRecord formato. Para los pequeños conjuntos de datos (por ejemplo Mnist, Cifar, ...), la lectura de .tfrecord puede agregar una sobrecarga significativa.

A medida que esos conjuntos de datos encajan en la memoria, es posible mejorar significativamente el rendimiento almacenando en caché o precargando el conjunto de datos. Tenga en cuenta que TFDS almacena automáticamente en caché pequeños conjuntos de datos (consulte la siguiente sección para obtener más detalles).

Almacenamiento en caché del conjunto de datos

A continuación, se muestra un ejemplo de una canalización de datos que almacena en caché explícitamente el conjunto de datos después de normalizar las imágenes.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Al iterar sobre este conjunto de datos, la segunda iteración será mucho más rápida que la primera gracias al almacenamiento en caché.

Almacenamiento en caché automático

Por defecto, los TFD auto-cachés (con ds.cache() ) conjuntos de datos que satisfacen las siguientes limitaciones:

  • El tamaño total del conjunto de datos (todas las divisiones) está definido y es <250 MiB
  • shuffle_files está desactivado o se lee únicamente un solo fragmento

Es posible optar por auto-almacenamiento en caché al pasar try_autocaching=False a tfds.ReadConfig en tfds.load . Eche un vistazo a la documentación del catálogo de conjuntos de datos para ver si un conjunto de datos específico utilizará la caché automática.

Cargando los datos completos como un solo tensor

Si su conjunto de datos cabe en la memoria, también puede cargar el conjunto de datos completo como una única matriz Tensor o NumPy. Es posible hacerlo por el ajuste batch_size=-1 a lotes todos los ejemplos en un solo tf.Tensor . A continuación, utilice tfds.as_numpy para la conversión de tf.Tensor a np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandes conjuntos de datos

Los conjuntos de datos grandes están fragmentados (divididos en varios archivos) y, por lo general, no caben en la memoria, por lo que no deben almacenarse en caché.

Barajar y entrenar

Durante el entrenamiento, es importante mezclar bien los datos; Los datos mal mezclados pueden resultar en una menor precisión del entrenamiento.

Además de utilizar ds.shuffle a barajar los registros, también se debe establecer shuffle_files=True para obtener un buen comportamiento arrastrando los pies para grandes conjuntos de datos que están fragmentados en varios archivos. De lo contrario, las épocas leerán los fragmentos en el mismo orden, por lo que los datos no serán realmente aleatorios.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Además, cuando shuffle_files=True , TFDS desactiva options.experimental_deterministic , lo que puede dar un ligero aumento de rendimiento. Para obtener barajar determinista, es posible darse de baja de esta función con tfds.ReadConfig : ya sea mediante el establecimiento de read_config.shuffle_seed o sobrescribir read_config.options.experimental_deterministic .

Divida automáticamente sus datos entre trabajadores (TF)

Cuando el entrenamiento de varios trabajadores, se puede utilizar el input_context argumento de tfds.ReadConfig , por lo que cada trabajador leerá un subconjunto de los datos.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Esto es complementario a la API subsplit. En primer lugar se aplica la API subplit ( train[:50%] se convierte en una lista de archivos para leer), entonces un ds.shard() op se aplica en esos archivos. Ejemplo: cuando se utiliza train[:50%] con num_input_pipelines=2 , cada uno de los 2 trabajador leer un cuarto de los datos.

Cuando shuffle_files=True , los archivos se barajan dentro de uno de los trabajadores, pero no a través de los trabajadores. Cada trabajador leerá el mismo subconjunto de archivos entre épocas.

Divida automáticamente sus datos entre trabajadores (Jax)

Con Jax, se puede utilizar el tfds.even_splits API para distribuir sus datos a través de los trabajadores. Consulte la guía API de división .

splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
# The current `process_index` load only `1 / process_count` of the data.
ds = tfds.load('my_dataset', split=splits[jax.process_index()])

Decodificación de imágenes más rápida

Por defecto, TFDS decodifica automáticamente las imágenes. Sin embargo, hay casos en los que puede ser más performante para saltar la imagen se decodifica con tfds.decode.SkipDecoding y manualmente aplicar el tf.io.decode_image op:

  • Al filtrar ejemplos (con ds.filter ), para decodificar las imágenes después de los ejemplos han sido filtrados.
  • Cuando recortar imágenes, usar el fundido tf.image.decode_and_crop_jpeg op.

El código para ambos ejemplos está disponible en la guía de decodificación .

Omitir funciones no utilizadas

Si solo está utilizando un subconjunto de las funciones, es posible omitir algunas funciones por completo. Si su conjunto de datos tiene muchas funciones sin usar, no decodificarlas puede mejorar significativamente el rendimiento. ver https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features