Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Mejor rendimiento con la API tf.data

Ver en TensorFlow.org Ejecutar en Google Colab Ver código fuente en GitHub Descargar cuaderno

Visión general

Las GPU y TPU pueden reducir radicalmente el tiempo requerido para ejecutar un solo paso de entrenamiento. Lograr el máximo rendimiento requiere una canalización de entrada eficiente que entregue datos para el siguiente paso antes de que el paso actual haya finalizado. La API tf.data ayuda a construir canalizaciones de entrada flexibles y eficientes. Este documento muestra cómo usar la API tf.data para construir canalizaciones de entrada de TensorFlow de alto rendimiento.

Antes de continuar, lea la guía " Construir canales de entrada de TensorFlow " para aprender a usar la API tf.data .

Recursos

Preparar

 import tensorflow as tf

import time
 

A lo largo de esta guía, iterará en un conjunto de datos y medirá el rendimiento. Hacer puntos de referencia de rendimiento reproducibles puede ser difícil, diferentes factores lo impactan:

  • la carga actual de la CPU,
  • el tráfico de red,
  • mecanismos complejos como caché, etc.

Por lo tanto, para proporcionar un punto de referencia reproducible, cree un ejemplo artificial.

El conjunto de datos

Defina una clase que herede de tf.data.Dataset llamada ArtificialDataset . Este conjunto de datos:

  • genera muestras de num_samples (el valor predeterminado es 3)
  • duerme un tiempo antes del primer elemento para simular la apertura de un archivo
  • duerme durante algún tiempo antes de producir cada elemento para simular la lectura de datos de un archivo
 class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)
            
            yield (sample_idx,)
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=tf.dtypes.int64,
            output_shapes=(1,),
            args=(num_samples,)
        )
 

Este conjunto de datos es similar al tf.data.Dataset.range one, agregando un retraso fijo al comienzo y entre cada muestra.

El circuito de entrenamiento

Escriba un bucle de entrenamiento ficticio que mida cuánto tiempo lleva iterar sobre un conjunto de datos. El tiempo de entrenamiento es simulado.

 def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    tf.print("Execution time:", time.perf_counter() - start_time)
 

Optimizar el rendimiento

Para mostrar cómo se puede optimizar el rendimiento, mejorará el rendimiento del ArtificialDataset .

El enfoque ingenuo

Comience con una canalización ingenua sin trucos, iterando sobre el conjunto de datos tal como está.

 benchmark(ArtificialDataset())
 
Execution time: 0.2346214259999897

Debajo del capó, así es como se pasó el tiempo de ejecución:

Ingenuo

Puede ver que realizar un paso de entrenamiento implica:

  • abrir un archivo si aún no se ha abierto,
  • recuperar una entrada de datos del archivo,
  • utilizando los datos para el entrenamiento.

Sin embargo, en una implementación síncrona ingenua como aquí, mientras su canalización está recuperando los datos, su modelo está inactivo. Por el contrario, mientras su modelo está entrenando, la tubería de entrada está inactiva. El tiempo del paso de entrenamiento es, por lo tanto, la suma de todos, el tiempo de apertura, lectura y entrenamiento.

Las siguientes secciones se basan en esta tubería de entrada, ilustrando las mejores prácticas para diseñar tuberías de entrada de TensorFlow de alto rendimiento.

Captación previa

La captación previa solapa el preprocesamiento y la ejecución del modelo de un paso de capacitación. Mientras el modelo está ejecutando los pasos de entrenamiento s , la canalización de entrada está leyendo los datos para el paso s+1 . Hacerlo reduce el tiempo de paso al máximo (en oposición a la suma) del entrenamiento y el tiempo que lleva extraer los datos.

La API tf.data proporciona la transformación tf.data.Dataset.prefetch . Se puede usar para desacoplar el momento en que se producen los datos del momento en que se consumen los datos. En particular, la transformación utiliza un subproceso en segundo plano y un búfer interno para captar previamente elementos del conjunto de datos de entrada antes de que se soliciten. El número de elementos para pretratar debe ser igual (o posiblemente mayor que) al número de lotes consumidos por un solo paso de entrenamiento. Puede ajustar manualmente este valor o establecerlo en tf.data.experimental.AUTOTUNE que hará que el tiempo de ejecución tf.data ajuste el valor dinámicamente en tiempo de ejecución.

Tenga en cuenta que la transformación de captación previa proporciona beneficios cada vez que existe la oportunidad de superponer el trabajo de un "productor" con el trabajo de un "consumidor".

 benchmark(
    ArtificialDataset()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
 
Execution time: 0.19066840700008925

Captado previamente

Esta vez puede ver que mientras se ejecuta el paso de entrenamiento para la muestra 0, la canalización de entrada está leyendo los datos para la muestra 1, y así sucesivamente.

Extracción de datos en paralelo

En una configuración del mundo real, los datos de entrada pueden almacenarse de forma remota (por ejemplo, GCS o HDFS). Una tubería de conjunto de datos que funciona bien al leer datos localmente puede tener un cuello de botella en E / S al leer datos de forma remota debido a las siguientes diferencias entre el almacenamiento local y remoto:

  • Tiempo hasta el primer byte: leer el primer byte de un archivo desde el almacenamiento remoto puede tomar órdenes de magnitud más largas que desde el almacenamiento local.
  • Rendimiento de lectura: si bien el almacenamiento remoto generalmente ofrece un gran ancho de banda agregado, la lectura de un solo archivo solo puede utilizar una pequeña fracción de este ancho de banda.

Además, una vez que los bytes sin procesar se cargan en la memoria, también puede ser necesario deserializar y / o descifrar los datos (por ejemplo, protobuf ), lo que requiere un cálculo adicional. Esta sobrecarga está presente independientemente de si los datos se almacenan localmente o remotamente, pero puede ser peor en el caso remoto si los datos no se obtienen previamente de manera efectiva.

Para mitigar el impacto de los diversos gastos generales de extracción de datos, la transformación tf.data.Dataset.interleave se puede utilizar para paralelizar el paso de carga de datos, intercalando el contenido de otros conjuntos de datos (como lectores de archivos de datos). El número de conjuntos de datos a superponer puede especificarse mediante el argumento cycle_length , mientras que el nivel de paralelismo puede especificarse mediante el argumento num_parallel_calls . Al igual que en la prefetch transformación, la interleave de transformación apoya tf.data.experimental.AUTOTUNE que delegar la decisión acerca de qué nivel de paralelismo con el uso de la tf.data tiempo de ejecución.

Intercalación secuencial

Los argumentos predeterminados de la transformación tf.data.Dataset.interleave hacen que intercalar muestras individuales de dos conjuntos de datos secuencialmente.

 benchmark(
    tf.data.Dataset.range(2)
    .interleave(ArtificialDataset)
)
 
Execution time: 0.19681244399998832

Intercalación secuencial

Este gráfico permite exhibir el comportamiento de la transformación interleave , obteniendo muestras alternativamente de los dos conjuntos de datos disponibles. Sin embargo, aquí no hay ninguna mejora en el rendimiento.

Intercalado paralelo

Ahora use el argumento num_parallel_calls de la transformación interleave . Esto carga múltiples conjuntos de datos en paralelo, lo que reduce el tiempo de espera para que se abran los archivos.

 benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        ArtificialDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
 
Execution time: 0.1402252690000978

Intercalado paralelo

Esta vez, la lectura de los dos conjuntos de datos es paralela, lo que reduce el tiempo de procesamiento de datos global.

Paralelización de la transformación de datos

Al preparar los datos, es posible que los elementos de entrada deban procesarse previamente. Para este fin, la API tf.data ofrece la transformación tf.data.Dataset.map , que aplica una función definida por el usuario a cada elemento del conjunto de datos de entrada. Debido a que los elementos de entrada son independientes entre sí, el preprocesamiento puede ser paralelo a través de múltiples núcleos de CPU. Para que esto sea posible, de manera similar a las prefetch e interleave transformaciones, el map transformación proporciona la num_parallel_calls argumento para especificar el nivel de paralelismo.

Elegir el mejor valor para el argumento num_parallel_calls depende de su hardware, las características de sus datos de entrenamiento (como su tamaño y forma), el costo de su función de mapa y qué otro procesamiento está sucediendo en la CPU al mismo tiempo. Una heurística simple es usar el número de núcleos de CPU disponibles. Sin embargo, en cuanto a la prefetch y la interleave la transformación, el map transformación apoya tf.data.experimental.AUTOTUNE que delegar la decisión acerca de qué nivel de paralelismo con el uso de la tf.data tiempo de ejecución.

 def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s
 

Mapeo secuencial

Comience usando la transformación del map sin paralelismo como ejemplo de línea de base.

 benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
 
Execution time: 0.41994816599992646

Mapeo secuencial

En cuanto al enfoque ingenuo , aquí los tiempos dedicados a la apertura, la lectura, el preprocesamiento (mapeo) y los pasos de entrenamiento se suman para una sola iteración.

Mapeo paralelo

Ahora, use la misma función de preprocesamiento pero aplíquela en paralelo en varias muestras.

 benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
 
Execution time: 0.25377997199996116

Mapeo paralelo

Ahora, puede ver en la gráfica que los pasos de preprocesamiento se superponen, reduciendo el tiempo total para una sola iteración.

Almacenamiento en caché

La transformación tf.data.Dataset.cache puede almacenar en caché un conjunto de datos, ya sea en la memoria o en el almacenamiento local. Esto evitará que algunas operaciones (como la apertura de archivos y la lectura de datos) se ejecuten durante cada época.

 benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
 
Execution time: 0.36316757900010543

Conjunto de datos en caché

Cuando almacena en caché un conjunto de datos, las transformaciones anteriores a la de cache (como la apertura de archivos y la lectura de datos) se ejecutan solo durante la primera época. Las próximas épocas reutilizarán los datos almacenados en caché por la transformación de cache .

Si la función definida por el usuario que se pasa a la transformación del map es costosa, aplique la transformación de cache después de la transformación del map , siempre y cuando el conjunto de datos resultante pueda caber en la memoria o en el almacenamiento local. Si la función definida por el usuario aumenta el espacio requerido para almacenar el conjunto de datos más allá de la capacidad de caché, aplíquelo después de la transformación de cache o considere preprocesar sus datos antes de su trabajo de capacitación para reducir el uso de recursos.

Mapeo de vectorización

Invocar una función definida por el usuario que se pasa a la transformación del map tiene una sobrecarga relacionada con la programación y ejecución de la función definida por el usuario. Recomendamos vectorizar la función definida por el usuario (es decir, hacerla funcionar sobre un lote de entradas a la vez) y aplicar la transformación por batch antes de la transformación del map .

Para ilustrar esta buena práctica, su conjunto de datos artificial no es adecuado. El retraso de la programación es de alrededor de 10 microsegundos (10e-6 segundos), mucho menos que las decenas de milisegundos utilizados en el ArtificialDataset y, por lo tanto, su impacto es difícil de ver.

Para este ejemplo, use la función base tf.data.Dataset.range y simplifique el ciclo de entrenamiento a su forma más simple.

 fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)
    
def increment(x):
    return x+1
 

Mapeo escalar

 fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
 
Execution time: 0.9026268119999941

Mapa escalar

La gráfica anterior ilustra lo que está sucediendo (con menos muestras). Puede ver que la función asignada se aplica para cada muestra. Si bien esta función es muy rápida, tiene algunos gastos generales que afectan el rendimiento del tiempo.

Mapeo vectorizado

 fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
 
Execution time: 0.03311353300000519

Mapa vectorizado

Esta vez, la función asignada se llama una vez y se aplica a un lote de muestra. Si bien la función puede tardar más tiempo en ejecutarse, la sobrecarga aparece solo una vez, mejorando el rendimiento general del tiempo.

Reducción de la huella de memoria

Una serie de transformaciones, incluidas interleave , prefetch y shuffle , mantienen un búfer interno de elementos. Si la función definida por el usuario que se pasa a la transformación del map cambia el tamaño de los elementos, entonces el orden de la transformación del mapa y las transformaciones que los elementos del búfer afectan el uso de la memoria. En general, recomendamos elegir el orden que resulte en una menor huella de memoria, a menos que sea deseable un orden diferente para el rendimiento.

Almacenamiento en caché de cálculos parciales

Se recomienda almacenar en caché el conjunto de datos después de la transformación del map excepto si esta transformación hace que los datos sean demasiado grandes para caber en la memoria. Se puede lograr una compensación si la función asignada se puede dividir en dos partes: una que consume mucho tiempo y una parte que consume memoria. En este caso, puede encadenar sus transformaciones como a continuación:

 dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)
 

De esta manera, la parte que consume mucho tiempo solo se ejecuta durante la primera época, y evita usar demasiado espacio de caché.

Resumen de mejores prácticas

Aquí hay un resumen de las mejores prácticas para diseñar canalizaciones de entrada de TensorFlow de alto rendimiento:

Reproduciendo las figuras

Para profundizar en la tf.data.Dataset API tf.data.Dataset , puede jugar con sus propios canales. A continuación se muestra el código utilizado para trazar las imágenes de esta guía. Puede ser un buen punto de partida, mostrando algunas soluciones para dificultades comunes como:

  • Tiempo de ejecución reproducibilidad;
  • Funciones mapeadas ejecución ansiosa;
  • transformación interleave invocable.
 import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
 

El conjunto de datos

De manera similar al ArtificialDataset , puede crear un conjunto de datos que devuelva el tiempo empleado en cada paso.

 class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
    
    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset
    
    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
        
        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter
            
            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered
            
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )
 

Este conjunto de datos proporciona muestras de forma [[2, 1], [2, 2], [2, 3]] y del tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Cada muestra es:

 (
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)
 

Dónde:

  • Open y Read son identificadores de pasos
  • t0 es la marca de tiempo cuando comenzó el paso correspondiente
  • d es el tiempo empleado en el paso correspondiente
  • i es el índice de instancia
  • e es el índice de época (número de veces que se ha iterado el conjunto de datos)
  • s es el índice de muestra

El ciclo de iteración

Haga que el ciclo de iteración sea un poco más complicado para agregar todos los tiempos. Esto solo funcionará con conjuntos de datos que generen muestras como se detalla anteriormente.

 def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
    
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)
            
            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter
            
            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
        
        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)
    
    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}
 

El método de trazado

Finalmente, defina una función capaz de trazar una línea de tiempo dados los valores devueltos por la función timelined_benchmark .

 def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()
    
    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
    
    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)
    
    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")
        
        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]
        
        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")
 

Usar envoltorios para la función asignada

Para ejecutar la función asignada en un contexto entusiasta, debe envolverlos dentro de una llamada tf.py_function .

 def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper
 

Comparación de tuberías

 _batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)
 

Ingenuo

 @map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
 
Execution time: 12.461227312999995

Optimizado

 @map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.experimental.AUTOTUNE
    )
    .unbatch(),
    5
)
 
Execution time: 6.333680681000033

 draw_timeline(naive_timeline, "Naive", 15)
 

png

 draw_timeline(optimized_timeline, "Optimized", 15)
 

png