¡El Día de la Comunidad de ML es el 9 de noviembre! Únase a nosotros para recibir actualizaciones de TensorFlow, JAX, y más Más información

Mejor rendimiento con la API tf.data

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Visión general

Las GPU y TPU pueden reducir radicalmente el tiempo necesario para ejecutar un solo paso de entrenamiento. Lograr el máximo rendimiento requiere una canalización de entrada eficiente que proporcione datos para el siguiente paso antes de que finalice el paso actual. El tf.data API ayuda a construir las tuberías de entrada flexibles y eficientes. En este documento se muestra cómo utilizar el tf.data API para construir tuberías de entrada TensorFlow gran rendimiento.

Antes de continuar, compruebe las tuberías de entrada de generación TensorFlow guía para aprender cómo utilizar el tf.data API.

Recursos

Configuración

import tensorflow as tf

import time

A lo largo de esta guía, iterará en un conjunto de datos y medirá el rendimiento. Hacer puntos de referencia de rendimiento reproducibles puede resultar complicado. Los diferentes factores que afectan la reproducibilidad incluyen:

  • La carga actual de la CPU
  • El tráfico de la red
  • Mecanismos complejos, como caché

Para obtener un punto de referencia reproducible, creará un ejemplo artificial.

El conjunto de datos

Comenzar con la definición de una clase que hereda de tf.data.Dataset llamada ArtificialDataset . Este conjunto de datos:

  • Genera num_samples muestras (por defecto es 3)
  • Duerme durante un tiempo antes del primer elemento para simular la apertura de un archivo.
  • Duerme durante un tiempo antes de producir cada elemento para simular la lectura de datos de un archivo.
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

Este conjunto de datos es similar a la tf.data.Dataset.range uno, añadiendo un retardo fijo en el comienzo de y en el medio de cada muestra.

El ciclo de entrenamiento

A continuación, escriba un bucle de entrenamiento ficticio que mida cuánto tiempo lleva iterar sobre un conjunto de datos. Se simula el tiempo de entrenamiento.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Optimiza el rendimiento

Para mostrar cómo se puede optimizar el rendimiento, mejorará el rendimiento de la ArtificialDataset .

El enfoque ingenuo

Comience con una canalización ingenua sin trucos, iterando sobre el conjunto de datos tal cual.

benchmark(ArtificialDataset())
Execution time: 0.26497629899995445

Debajo del capó, así es como se gastó el tiempo de ejecución:

Gráfico de tiempo de ejecución de datos: un método ingenuo

La trama muestra que realizar un paso de entrenamiento implica:

  • Abrir un archivo si aún no se ha abierto
  • Obteniendo una entrada de datos del archivo
  • Usar los datos para entrenar

Sin embargo, en una implementación sincrónica ingenua como aquí, mientras su canalización está obteniendo los datos, su modelo está inactivo. Por el contrario, mientras su modelo se está entrenando, la canalización de entrada permanece inactiva. El tiempo del paso de formación es, por tanto, la suma de los tiempos de apertura, lectura y formación.

Las siguientes secciones se basan en esta canalización de entrada e ilustran las prácticas recomendadas para diseñar canalizaciones de entrada de TensorFlow con rendimiento.

Precarga

La captación previa se superpone al procesamiento previo y la ejecución del modelo de un paso de formación. Mientras que el modelo está ejecutando el entrenamiento paso s , la tubería de entrada está leyendo los datos para la etapa s+1 . Al hacerlo, se reduce el tiempo de paso al máximo (en lugar de la suma) del entrenamiento y el tiempo que se tarda en extraer los datos.

El tf.data API proporciona la tf.data.Dataset.prefetch transformación. Se puede utilizar para desacoplar el momento en que se producen los datos del momento en que se consumen los datos. En particular, la transformación utiliza un subproceso en segundo plano y un búfer interno para obtener elementos del conjunto de datos de entrada antes de que se soliciten. El número de elementos para captar previamente debe ser igual (o posiblemente mayor) al número de lotes consumidos por un solo paso de entrenamiento. Usted podría sintonizar manualmente este valor, o se establece en tf.data.AUTOTUNE , que pedirá al tf.data tiempo de ejecución para ajustar el valor dinámicamente en tiempo de ejecución.

Tenga en cuenta que la transformación de captación previa proporciona beneficios cada vez que existe la oportunidad de superponer el trabajo de un "productor" con el trabajo de un "consumidor".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357

Gráfico de tiempo de ejecución de datos: método de captura previa

Ahora, como muestra el gráfico de tiempo de ejecución de datos, mientras se ejecuta el paso de entrenamiento para la muestra 0, la canalización de entrada lee los datos para la muestra 1, y así sucesivamente.

Paralelizar la extracción de datos

En un entorno del mundo real, los datos de entrada se pueden almacenar de forma remota (por ejemplo, en Google Cloud Storage o HDFS). Una canalización de conjuntos de datos que funciona bien al leer datos de forma local puede convertirse en un cuello de botella en la E / S al leer datos de forma remota debido a las siguientes diferencias entre el almacenamiento local y remoto:

  • Tiempo de salida al primer byte: Lectura del primer byte de un archivo de almacenamiento remoto puede recibir órdenes de magnitud mayor que desde el almacenamiento local.
  • Leer rendimiento: Aunque el almacenamiento remoto normalmente ofrece gran ancho de banda agregado, la lectura de un único archivo sólo podría ser capaz de utilizar una pequeña fracción de este ancho de banda.

Además, una vez los bytes primas se cargan en memoria, puede también ser necesario deserializar y / o descifrar los datos (por ejemplo protobuf ), que requiere cálculo adicional. Esta sobrecarga está presente independientemente de si los datos se almacenan localmente o de forma remota, pero puede ser peor en el caso remoto si los datos no se obtienen previamente de manera efectiva.

Para mitigar el impacto de los diversos gastos generales de extracción de datos, la tf.data.Dataset.interleave transformación se puede utilizar para paralelizar la etapa de carga de datos, la intercalación el contenido de otros conjuntos de datos (tales como lectores de archivos de datos). El número de conjuntos de datos a la superposición puede ser especificado por el cycle_length argumento, mientras que el nivel de paralelismo puede ser especificado por el num_parallel_calls argumento. Al igual que en la prefetch transformación, la interleave de transformación apoya tf.data.AUTOTUNE , que delegar la decisión acerca de qué nivel de paralelismo con el uso de la tf.data tiempo de ejecución.

Entrelazado secuencial

Los argumentos predeterminados de la tf.data.Dataset.interleave transformación hacen intercalar muestras individuales a partir de dos conjuntos de datos de forma secuencial.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098

Gráfico de tiempo de ejecución de datos: intercalación secuencial

Este gráfico de tiempo de ejecución de datos permite a exhibir el comportamiento de la interleave la transformación, ir a buscar muestras alternativamente de los dos conjuntos de datos disponibles. Sin embargo, aquí no se trata de una mejora del rendimiento.

Entrelazado paralelo

Ahora, utilice el num_parallel_calls argumento de la interleave transformación. Esto carga múltiples conjuntos de datos en paralelo, lo que reduce el tiempo de espera para que se abran los archivos.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.283668874000341

Gráfico de tiempo de ejecución de datos: método de intercalación paralela

Esta vez, como muestra el gráfico de tiempo de ejecución de datos, la lectura de los dos conjuntos de datos se paraleliza, lo que reduce el tiempo de procesamiento de datos global.

Paralelizar la transformación de datos

Al preparar los datos, es posible que los elementos de entrada deban procesarse previamente. Para este fin, las tf.data ofertas API la tf.data.Dataset.map de transformación, que se aplica una función definida por el usuario a cada elemento del conjunto de datos de entrada. Debido a que los elementos de entrada son independientes entre sí, el preprocesamiento se puede paralelizar en varios núcleos de CPU. Para que esto sea posible, de manera similar a las prefetch e interleave transformaciones, el map transformación proporciona la num_parallel_calls argumento para especificar el nivel de paralelismo.

Elegir el mejor valor para el num_parallel_calls argumento depende del hardware, las características de los datos de capacitación (como su tamaño y forma), el costo de su función de mapa, y lo otro proceso está ocurriendo en la CPU al mismo tiempo. Una heurística simple es utilizar la cantidad de núcleos de CPU disponibles. Sin embargo, en cuanto a la prefetch y la interleave la transformación, el map transformación apoya tf.data.AUTOTUNE que delegar la decisión acerca de qué nivel de paralelismo con el uso de la tf.data tiempo de ejecución.

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Mapeo secuencial

Comience usando el map transformación sin paralelismo como un ejemplo de referencia.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4505277170001136

Gráfico de tiempo de ejecución de datos: método de mapeo secuencial

En cuanto al enfoque ingenuo , aquí, como muestra de la trama, los tiempos pasaron para la apertura, lectura, pre-procesamiento (mapeo) y los pasos de formación resumir juntos por una sola iteración.

Mapeo paralelo

Ahora, utilice la misma función de preprocesamiento pero aplíquela en paralelo en varias muestras.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2839677860001757

Tiempo de ejecución de datos: mapeo paralelo

Como demuestra el gráfico de datos, los pasos de preprocesamiento se superponen, lo que reduce el tiempo total para una sola iteración.

Almacenamiento en caché

El tf.data.Dataset.cache transformación puede almacenar en caché un conjunto de datos, ya sea en la memoria o en el almacenamiento local. Esto evitará que algunas operaciones (como la apertura de archivos y la lectura de datos) se ejecuten durante cada época.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3848854380003104

Tiempo de ejecución de datos: método de conjunto de datos en caché

Aquí, los datos de tiempo de ejecución gráfico muestra que cuando se hace una caché de un conjunto de datos, las transformaciones antes de la cache uno (como la apertura de archivos y datos de lectura) se ejecutan sólo durante la primera época. Los próximos épocas reutilizarán los datos en caché por el cache transformación.

Si la función definida por el usuario pasado en el map transformación es caro, aplicar el cache transformación después de que el map de transformación, siempre que el conjunto de datos resultante todavía puede caber en la memoria o almacenamiento local. Si la función definida por el usuario aumenta el espacio necesario para almacenar el conjunto de datos más allá de la capacidad de la caché, o bien aplicarlo después de la cache transformación o considerar pre-procesamiento de los datos antes de su puesto de trabajo para reducir el uso de recursos.

Vectorización de mapeo

Invocar una función definida por el usuario pasado en el map de transformación overhead ha relacionado con la programación y la ejecución de la función definida por el usuario. Vectorizar la función definida por el usuario (es decir, tiene que operar sobre un lote de entradas a la vez) y aplicar el batch de transformación antes de que el map de transformación.

Para ilustrar esta buena práctica, su conjunto de datos artificiales no es adecuado. El retardo de programación es de alrededor de 10 microsegundos (10e-6 segundos), mucho menos que las decenas de milisegundos utilizados en la ArtificialDataset , y por tanto su impacto es difícil de ver.

Para este ejemplo, utilice la base tf.data.Dataset.range función y simplificar el circuito de entrenamiento a su forma más simple.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Mapeo escalar

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.2712608739998359

Tiempo de ejecución de datos: método de mapa escalar

La gráfica de arriba ilustra lo que está sucediendo (con menos muestras) usando el método de mapeo escalar. Muestra que la función mapeada se aplica para cada muestra. Si bien esta función es muy rápida, tiene algunos gastos generales que afectan el rendimiento del tiempo.

Mapeo vectorizado

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.02737950600021577

Tiempo de ejecución de datos: método de mapa vectorizado

Esta vez, la función asignada se llama una vez y se aplica a un lote de muestra. Como muestra el gráfico de tiempo de ejecución de datos, si bien la función puede tardar más en ejecutarse, la sobrecarga aparece solo una vez, lo que mejora el rendimiento general del tiempo.

Reducir la huella de memoria

Un número de transformaciones, incluyendo la interleave , prefetch , y shuffle , mantener un buffer interno de elementos. Si la función definida por el usuario pasado en el map de transformación cambia el tamaño de los elementos, a continuación, el orden de la transformación mapa y las transformaciones que los elementos de amortiguación afecta el uso de memoria. En general, elija el orden que resulte en una menor huella de memoria, a menos que sea deseable un orden diferente para el rendimiento.

Almacenamiento en caché de cálculos parciales

Se recomienda almacenar en caché el conjunto de datos después de que el map de transformación, excepto si esta transformación hace que los datos demasiado grande para caber en la memoria. Se puede lograr una compensación si su función mapeada se puede dividir en dos partes: una que consume mucho tiempo y otra que consume memoria. En este caso, puede encadenar sus transformaciones como se muestra a continuación:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

De esta manera, la parte que consume mucho tiempo solo se ejecuta durante la primera época y evita usar demasiado espacio de caché.

Resumen de las mejores prácticas

A continuación, se muestra un resumen de las mejores prácticas para diseñar canalizaciones de entrada de TensorFlow con rendimiento:

Reproduciendo las figuras

Para profundizar en el tf.data.Dataset la comprensión de la API, se puede jugar con sus propias tuberías. A continuación se muestra el código utilizado para trazar las imágenes de esta guía. Puede ser un buen punto de partida, mostrando algunas soluciones para dificultades comunes como:

  • Reproducibilidad del tiempo de ejecución
  • Ejecución ávida de funciones mapeadas
  • interleave exigible transformación
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

El conjunto de datos

Al igual que en el ArtificialDataset se puede construir un conjunto de datos devolver el tiempo empleado en cada paso.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Este conjunto de datos proporciona muestras de la forma [[2, 1], [2, 2], [2, 3]] y de tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Cada muestra es:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Dónde:

  • Open y Read son identificadores pasos
  • t0 es la marca de tiempo cuando el paso correspondiente comenzó
  • d es el tiempo pasado en el paso correspondiente
  • i es el índice de instancia
  • e es el índice de época (número de veces el conjunto de datos se ha iterado)
  • s es el índice de la muestra

El ciclo de iteración

Haga que el ciclo de iteración sea un poco más complicado para agregar todos los tiempos. Esto solo funcionará con conjuntos de datos que generen muestras como se detalla anteriormente.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

El método de trazado

Por último, definir una función capaz de trazar una línea de tiempo dado los valores devueltos por la timelined_benchmark función.

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Utilice envoltorios para la función mapeada

Para ejecutar la función asignada en un contexto ansiosos, hay que envolverlos en un tf.py_function llamada.

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Comparación de tuberías

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Ingenuo

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 13.13538893499981

Optimizado

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png