Migliori prestazioni con l'API tf.data

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

Panoramica

GPU e TPU possono ridurre radicalmente il tempo necessario per eseguire una singola fase di addestramento. Il raggiungimento delle massime prestazioni richiede una pipeline di input efficiente che fornisca i dati per il passaggio successivo prima che il passaggio corrente sia terminato. Il tf.data API aiuta a costruire oleodotti di ingresso flessibili ed efficienti. Questo documento viene illustrato come utilizzare il tf.data API per costruire altamente performanti tubazioni di ingresso tensorflow.

Prima di continuare, controllare le tubazioni di ingresso Corporatura tensorflow guida per imparare come utilizzare il tf.data API.

risorse

Impostare

import tensorflow as tf

import time

In questa guida si itera su un set di dati e si misurano le prestazioni. Fare benchmark delle prestazioni riproducibili può essere difficile. Diversi fattori che influenzano la riproducibilità includono:

  • L'attuale carico della CPU
  • Il traffico di rete
  • Meccanismi complessi, come la cache

Per ottenere un benchmark riproducibile, costruirai un esempio artificiale.

Il set di dati

Inizia con la definizione di una classe che eredita da tf.data.Dataset chiamato ArtificialDataset . Questo set di dati:

  • Genera num_samples campioni (di default è 3)
  • Dorme per un po' di tempo prima del primo elemento per simulare l'apertura di un file
  • Dorme per un po' di tempo prima di produrre ogni elemento per simulare la lettura dei dati da un file
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

Questo insieme di dati è simile a tf.data.Dataset.range uno, aggiungendo un ritardo fisso ad inizio e tra ogni campione.

Il ciclo di allenamento

Quindi, scrivi un ciclo di addestramento fittizio che misuri il tempo necessario per eseguire l'iterazione su un set di dati. Il tempo di allenamento è simulato.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Ottimizza le prestazioni

Per esporre come le prestazioni possono essere ottimizzate, potrete migliorare le prestazioni del ArtificialDataset .

L'approccio ingenuo

Inizia con una pipeline ingenua senza trucchi, iterando sul set di dati così com'è.

benchmark(ArtificialDataset())
Execution time: 0.2588925869999912

Sotto il cofano, ecco come è stato speso il tempo di esecuzione:

Grafico del tempo di esecuzione dei dati: un metodo ingenuo

Il grafico mostra che l'esecuzione di una fase di addestramento comporta:

  • Aprire un file se non è stato ancora aperto
  • Recupero di una voce di dati dal file
  • Utilizzo dei dati per l'allenamento

Tuttavia, in un'implementazione sincrona ingenua come qui, mentre la tua pipeline sta recuperando i dati, il tuo modello è inattivo. Al contrario, durante l'addestramento del modello, la pipeline di input è inattiva. Il tempo della fase di allenamento è quindi la somma dei tempi di apertura, lettura e allenamento.

Le sezioni successive si basano su questa pipeline di input, illustrando le procedure consigliate per la progettazione di pipeline di input TensorFlow ad alte prestazioni.

Precaricamento

Il precaricamento si sovrappone alla preelaborazione e all'esecuzione del modello di una fase di addestramento. Mentre il modello è in esecuzione di formazione passo s , la pipeline di ingresso sta leggendo i dati per passo s+1 . In questo modo si riduce il tempo di passaggio al massimo (in contrapposizione alla somma) dell'allenamento e al tempo necessario per estrarre i dati.

Il tf.data API fornisce il tf.data.Dataset.prefetch trasformazione. Può essere utilizzato per disaccoppiare il momento in cui i dati vengono prodotti dal momento in cui i dati vengono consumati. In particolare, la trasformazione utilizza un thread in background e un buffer interno per preletturare gli elementi dal set di dati di input prima che vengano richiesti. Il numero di elementi da precaricare dovrebbe essere uguale (o possibilmente maggiore) del numero di batch consumati da una singola fase di addestramento. Si potrebbe o sintonizzare manualmente questo valore, o impostarlo su tf.data.AUTOTUNE , che chiederà il tf.data runtime per regolare il valore in modo dinamico in fase di esecuzione.

Nota che la trasformazione prefetch offre vantaggi ogni volta che c'è l'opportunità di sovrapporre il lavoro di un "produttore" con il lavoro di un "consumatore".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21371877499996117

Grafico del tempo di esecuzione dei dati - metodo di prelettura

Ora, come mostra il grafico del tempo di esecuzione dei dati, mentre il passaggio di training è in esecuzione per il campione 0, la pipeline di input legge i dati per il campione 1 e così via.

Parallelizzazione dell'estrazione dei dati

In un ambiente reale, i dati di input possono essere archiviati in remoto (ad esempio, su Google Cloud Storage o HDFS). Una pipeline di set di dati che funziona bene durante la lettura dei dati in locale potrebbe subire un collo di bottiglia sull'I/O durante la lettura dei dati in remoto a causa delle seguenti differenze tra l'archiviazione locale e remota:

  • Time-to-primo byte: Leggendo il primo byte di un file dall'archivio remoto può prendere ordini di grandezza più di dall'archivio locale.
  • Leggi il throughput: Mentre storage remoto offre tipicamente grande larghezza di banda aggregata, la lettura di un singolo file potrebbe solo essere in grado di utilizzare una piccola frazione di questa larghezza di banda.

Inoltre, una volta che i byte non elaborati vengono caricati in memoria, può anche essere necessario deserializzare e / o decodificare i dati (es protobuf ), che richiede ulteriori calcoli. Questo sovraccarico è presente indipendentemente dal fatto che i dati siano archiviati localmente o in remoto, ma può essere peggiore nel caso remoto se i dati non vengono preletturati in modo efficace.

Per attenuare l'impatto delle diverse spese di estrazione dei dati, il tf.data.Dataset.interleave trasformazione può essere utilizzato per parallelizzare la fase di caricamento di dati, interlacciatura il contenuto di altri insiemi di dati (ad esempio lettori di file di dati). Il numero di gruppi di dati a sovrapposizione può essere specificato dal cycle_length argomento, mentre il livello di parallelismo può essere specificato dal num_parallel_calls argomento. Simile al prefetch trasformazione, la interleave trasformazione supporta tf.data.AUTOTUNE , che delegare la decisione circa quale livello di parallelismo da utilizzare per la tf.data runtime.

Interfoglio sequenziale

I parametri predefiniti dei tf.data.Dataset.interleave trasformazione rendono interleave campioni singoli da due insiemi di dati in sequenza.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4863386449999325

Grafico del tempo di esecuzione dei dati - interfoliazione sequenziale

Questa trama tempo di esecuzione dei dati permette di esporre il comportamento del interleave trasformazione, recupero campioni alternativamente dai due set di dati disponibili. Tuttavia, qui non è coinvolto alcun miglioramento delle prestazioni.

Interfoglio parallelo

Ora, utilizzare la num_parallel_calls argomento del interleave trasformazione. Questo carica più set di dati in parallelo, riducendo il tempo di attesa per l'apertura dei file.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2785680570000295

Grafico del tempo di esecuzione dei dati - metodo di interleave parallelo

Questa volta, come mostra il grafico del tempo di esecuzione dei dati, la lettura dei due set di dati viene parallelizzata, riducendo il tempo di elaborazione dei dati globali.

Parallelizzazione della trasformazione dei dati

Durante la preparazione dei dati, potrebbe essere necessario pre-elaborare gli elementi di input. A tal fine, le tf.data offerte API il tf.data.Dataset.map trasformazione, che applica una funzione definita dall'utente per ciascun elemento del set di dati di ingresso. Poiché gli elementi di input sono indipendenti l'uno dall'altro, la pre-elaborazione può essere parallelizzata su più core della CPU. Per rendere questo possibile, in modo simile alle prefetch e interleave trasformazioni, la map di trasformazione prevede la num_parallel_calls argomento per specificare il livello di parallelismo.

La scelta del miglior valore per il num_parallel_calls argomento dipende dal vostro hardware, le caratteristiche dei dati e la formazione (come ad esempio la sua dimensione e la forma), il costo della vostra funzione di carta, e ciò che per altre lavorazioni che sta accadendo sulla CPU, allo stesso tempo. Una semplice euristica consiste nell'utilizzare il numero di core CPU disponibili. Tuttavia, come per il prefetch e interleave trasformazione, la map trasformazione supporta tf.data.AUTOTUNE che delegare la decisione circa quale livello di parallelismo da utilizzare per la tf.data runtime.

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Mappatura sequenziale

Iniziare con la map di trasformazione, senza il parallelismo come un esempio della linea di base.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.44197996000002604

Grafico del tempo di esecuzione dei dati - metodo di mappatura sequenziale

Per quanto riguarda l' approccio ingenuo , qui, come mostra il diagramma, i tempi spesi per l'apertura, la lettura, la pre-elaborazione (mappatura) e fasi di formazione sommare insieme per una singola iterazione.

Mappatura parallela

Ora, usa la stessa funzione di pre-elaborazione ma applicala in parallelo su più campioni.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2774802769999951

Tempo di esecuzione dei dati - mappatura parallela

Come dimostra il grafico dei dati, le fasi di pre-elaborazione si sovrappongono, riducendo il tempo complessivo per una singola iterazione.

memorizzazione nella cache

Il tf.data.Dataset.cache trasformazione può memorizzare nella cache un set di dati, in memoria o su storage locale. Ciò salverà alcune operazioni (come l'apertura di file e la lettura dei dati) dall'esecuzione durante ogni epoca.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.37056952400007503

Tempo di esecuzione dei dati - metodo del set di dati memorizzato nella cache

Qui, il tempo di esecuzione dei dati mostra trama che quando si memorizzare nella cache un set di dati, le trasformazioni prima della cache uno (come l'apertura di file e lettura dei dati) vengono eseguite solo durante la prima epoca. Le prossime epoche potranno riutilizzare i dati memorizzati nella cache dalla cache trasformazione.

Se la funzione definita dall'utente passato nella map trasformazione è costoso, applicare la cache trasformazione dopo la map trasformazione finché il set di dati risultante può ancora entrare nella memoria o memorizzazione locale. Se la funzione definita dall'utente aumenta lo spazio necessario per memorizzare il set di dati al di là della capacità di cache, o applicarlo dopo la cache trasformazione o considerare pre-trattamento dei tuoi dati prima che il lavoro di formazione per ridurre l'utilizzo delle risorse.

Mappatura vettorizzazione

Chiamando una funzione definita dall'utente passato nella map trasformazione è generali legati alla programmazione e l'esecuzione della funzione definita dall'utente. Vettorizza la funzione definita dall'utente (cioè, farlo funzionare su un lotto di ingressi alla volta) e applicare il batch trasformazione prima della map trasformazione.

Per illustrare questa buona pratica, il tuo set di dati artificiale non è adatto. Il ritardo di programmazione è di circa 10 microsecondi (10e-6 secondi), molto meno delle decine di millisecondi utilizzati nella ArtificialDataset , e quindi il suo impatto è difficile da vedere.

Per questo esempio, utilizzare la base tf.data.Dataset.range funzione e semplificare il ciclo di formazione alla sua forma più semplice.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Mappatura scalare

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.23767012700000123

Tempo di esecuzione dei dati - metodo della mappa scalare

Il grafico sopra illustra cosa sta succedendo (con meno campioni) utilizzando il metodo di mappatura scalare. Mostra che la funzione mappata viene applicata per ogni campione. Sebbene questa funzione sia molto veloce, ha un sovraccarico che influisce sulle prestazioni temporali.

Mappatura vettorizzata

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.023475749999988693

Tempo di esecuzione dei dati - metodo della mappa vettorizzata

Questa volta, la funzione mappata viene chiamata una volta e si applica a un batch di campioni. Come mostra il grafico del tempo di esecuzione dei dati, mentre la funzione potrebbe richiedere più tempo per l'esecuzione, l'overhead viene visualizzato solo una volta, migliorando le prestazioni complessive del tempo.

Riduzione dell'impronta di memoria

Un certo numero di trasformazioni, compresi interleave , prefetch e shuffle , mantenere un buffer interno di elementi. Se la funzione definita dall'utente passato nella map trasformazione modifica le dimensioni degli elementi, allora l'ordinamento della mappa di trasformazione e le trasformazioni che elementi respingenti colpisce l'utilizzo della memoria. In generale, scegliere l'ordine che comporta un footprint di memoria inferiore, a meno che non sia desiderabile un ordinamento diverso per le prestazioni.

Memorizzazione nella cache di calcoli parziali

Si consiglia di memorizzare nella cache il set di dati dopo la map di trasformazione a meno che questa trasformazione rende i dati troppo grande per entrare nella memoria. È possibile ottenere un compromesso se la funzione mappata può essere divisa in due parti: una che richiede tempo e una parte che richiede memoria. In questo caso, puoi concatenare le tue trasformazioni come di seguito:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

In questo modo, la parte che richiede tempo viene eseguita solo durante la prima epoca ed eviti di utilizzare troppo spazio nella cache.

Riepilogo delle migliori pratiche

Di seguito è riportato un riepilogo delle best practice per la progettazione di pipeline di input TensorFlow ad alte prestazioni:

Riproducendo le figure

Per andare più in profondità nella tf.data.Dataset comprensione API, si può giocare con i propri oleodotti. Di seguito è riportato il codice utilizzato per tracciare le immagini di questa guida. Può essere un buon punto di partenza, mostrando alcune soluzioni alternative per difficoltà comuni come:

  • Riproducibilità del tempo di esecuzione
  • Esecuzione impaziente di funzioni mappate
  • interleave trasformazione richiamabile
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Il set di dati

Simile al ArtificialDataset si può costruire un set di dati di restituire il tempo trascorso in ciascuna fase.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Questo insieme di dati fornisce esempi di forma [[2, 1], [2, 2], [2, 3]] e del tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Ogni campione è:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

In cui si:

  • Open e Read sono passi identificatori
  • t0 è il timestamp quando il passo corrispondente iniziata
  • d è il tempo trascorso nella fase corrispondente
  • i è l'indice dell'istanza
  • e è l'indice dell'epoca (numero di volte che il set di dati è stata iterata)
  • s è l'indice del campione

Il ciclo di iterazione

Rendi il ciclo di iterazione un po' più complicato per aggregare tutti i tempi. Funziona solo con set di dati che generano campioni come descritto sopra.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

Il metodo di plottaggio

Infine, definire una funzione in grado di tracciare una linea temporale determinato i valori restituiti dalla timelined_benchmark funzione.

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Usa i wrapper per la funzione mappata

Per eseguire la funzione mappato in un contesto ansioso, si deve avvolgerli all'interno di un tf.py_function chiamata.

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Confronto delle condutture

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Ingenuo

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From /tmp/ipykernel_31283/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From /tmp/ipykernel_31283/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 12.536449214999948

Ottimizzato

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.391495143999919
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png