Cette page a été traduite par l'API Cloud Translation.
Switch to English

Meilleures performances avec l'API tf.data

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le cahier

Aperçu

Les GPU et les TPU peuvent réduire considérablement le temps nécessaire pour exécuter une seule étape d'entraînement. Pour atteindre des performances optimales, il faut un pipeline d'entrée efficace qui fournit des données pour l'étape suivante avant la fin de l'étape en cours. L'API tf.data permet de créer des pipelines d'entrée flexibles et efficaces. Ce document montre comment utiliser l'API tf.data pour créer des pipelines d'entrée TensorFlow hautement performants.

Avant de continuer, lisez le guide « Créer des pipelines d'entrée TensorFlow » pour savoir comment utiliser l'API tf.data .

Ressources

Installer

 import tensorflow as tf

import time
 

Tout au long de ce guide, vous parcourrez un ensemble de données et mesurerez les performances. Faire des benchmarks de performance reproductibles peut être difficile, différents facteurs ayant un impact:

  • la charge actuelle du processeur,
  • le trafic réseau,
  • mécanismes complexes comme le cache, etc.

Par conséquent, pour fournir un repère reproductible, construisez un exemple artificiel.

Le jeu de données

Définissez une classe héritant de tf.data.Dataset appelée ArtificialDataset . Cet ensemble de données:

  • génère num_samples échantillons (la valeur par défaut est 3)
  • dort pendant un certain temps avant le premier élément pour simuler l'ouverture d'un fichier
  • dort pendant un certain temps avant de produire chaque élément pour simuler la lecture des données d'un fichier
 class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)
            
            yield (sample_idx,)
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=tf.dtypes.int64,
            output_shapes=(1,),
            args=(num_samples,)
        )
 

Cet ensemble de données est similaire à celui de tf.data.Dataset.range , ajoutant un délai fixe au début et entre chaque échantillon.

La boucle d'entraînement

Écrivez une boucle d'apprentissage factice qui mesure le temps nécessaire pour itérer sur un ensemble de données. Le temps de formation est simulé.

 def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    tf.print("Execution time:", time.perf_counter() - start_time)
 

Optimiser les performances

Pour montrer comment les performances peuvent être optimisées, vous améliorerez les performances de ArtificialDataset .

L'approche naïve

Commencez avec un pipeline naïf n'utilisant aucune astuce, itérant sur l'ensemble de données tel quel.

 benchmark(ArtificialDataset())
 
Execution time: 0.2346214259999897

Sous le capot, voici comment s'est passé votre temps d'exécution:

Naïve

Vous pouvez voir que l'exécution d'une étape d'entraînement implique:

  • ouvrir un fichier s'il n'a pas encore été ouvert,
  • récupérer une entrée de données dans le fichier,
  • utiliser les données pour la formation.

Cependant, dans une implémentation synchrone naïve comme ici, pendant que votre pipeline récupère les données, votre modèle est inactif. À l'inverse, pendant que votre modèle s'entraîne, le pipeline d'entrée est inactif. Le temps de pas de formation est donc la somme de tous les temps d'ouverture, de lecture et de formation.

Les sections suivantes s'appuient sur ce pipeline d'entrée, illustrant les meilleures pratiques pour la conception de pipelines d'entrée TensorFlow performants.

Prérécupération

La prélecture chevauche le prétraitement et l'exécution du modèle d'une étape d'apprentissage. Pendant que le modèle exécute l'étape d'apprentissage s , le pipeline d'entrée lit les données pour l'étape s+1 . Cela réduit le temps de pas au maximum (par opposition à la somme) de la formation et le temps nécessaire pour extraire les données.

L'API tf.data fournit la transformation tf.data.Dataset.prefetch . Il peut être utilisé pour découpler le moment où les données sont produites du moment où les données sont consommées. En particulier, la transformation utilise un thread d'arrière-plan et un tampon interne pour pré-extraire les éléments de l'ensemble de données d'entrée avant le moment où ils sont demandés. Le nombre d'éléments à pré-lire doit être égal (ou éventuellement supérieur) au nombre de lots consommés par une seule étape d'apprentissage. Vous pouvez soit régler manuellement cette valeur, soit la définir sur tf.data.experimental.AUTOTUNE ce qui tf.data runtime tf.data à régler la valeur de manière dynamique lors de l'exécution.

Notez que la transformation de prélecture offre des avantages chaque fois qu'il y a une opportunité de chevaucher le travail d'un «producteur» avec le travail d'un «consommateur».

 benchmark(
    ArtificialDataset()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
 
Execution time: 0.19066840700008925

Prérécupéré

Cette fois, vous pouvez voir que pendant que l'étape d'entraînement est en cours d'exécution pour l'échantillon 0, le pipeline d'entrée lit les données de l'échantillon 1, et ainsi de suite.

Paralléliser l'extraction de données

Dans un environnement réel, les données d'entrée peuvent être stockées à distance (par exemple, GCS ou HDFS). Un pipeline d'ensemble de données qui fonctionne bien lors de la lecture de données localement peut devenir un goulot d'étranglement sur les E / S lors de la lecture de données à distance en raison des différences suivantes entre le stockage local et distant:

  • Time-to-first-byte: La lecture du premier octet d'un fichier à partir d'un stockage distant peut prendre des ordres de grandeur plus longs qu'à partir du stockage local.
  • Débit de lecture: alors que le stockage distant offre généralement une large bande passante globale, la lecture d'un seul fichier peut n'utiliser qu'une petite fraction de cette bande passante.

De plus, une fois les octets bruts chargés en mémoire, il peut également être nécessaire de désérialiser et / ou décrypter les données (par exemple protobuf ), ce qui nécessite des calculs supplémentaires. Cette surcharge est présente indépendamment du fait que les données soient stockées localement ou à distance, mais peut être pire dans le cas distant si les données ne sont pas extraites efficacement.

Pour atténuer l'impact des divers frais généraux d'extraction de données, la transformation tf.data.Dataset.interleave peut être utilisée pour paralléliser l'étape de chargement des données, en entrelaçant le contenu d'autres ensembles de données (tels que les lecteurs de fichiers de données). Le nombre d'ensembles de données à chevaucher peut être spécifié par l'argument cycle_length , tandis que le niveau de parallélisme peut être spécifié par l'argument num_parallel_calls . Semblable à la transformation de prefetch , la transformation d' interleave prend en charge tf.data.experimental.AUTOTUNE qui déléguera la décision sur le niveau de parallélisme à utiliser au runtime tf.data .

Entrelacement séquentiel

Les arguments par défaut de la transformation tf.data.Dataset.interleave permettent d'entrelacer des échantillons uniques de deux ensembles de données de manière séquentielle.

 benchmark(
    tf.data.Dataset.range(2)
    .interleave(ArtificialDataset)
)
 
Execution time: 0.19681244399998832

Entrelacement séquentiel

Ce tracé permet de montrer le comportement de la transformation d' interleave , en récupérant des échantillons alternativement à partir des deux jeux de données disponibles. Cependant, aucune amélioration des performances n'est impliquée ici.

Entrelacement parallèle

Utilisez maintenant l'argument num_parallel_calls de la transformation d' interleave . Cela charge plusieurs ensembles de données en parallèle, ce qui réduit le temps d'attente pour l'ouverture des fichiers.

 benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        ArtificialDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
 
Execution time: 0.1402252690000978

Entrelacement parallèle

Cette fois, la lecture des deux ensembles de données est parallélisée, réduisant le temps de traitement global des données.

Paralléliser la transformation des données

Lors de la préparation des données, les éléments d'entrée peuvent devoir être prétraités. À cette fin, l'API tf.data propose la transformation tf.data.Dataset.map , qui applique une fonction définie par l'utilisateur à chaque élément de l'ensemble de données d'entrée. Les éléments d'entrée étant indépendants les uns des autres, le prétraitement peut être parallélisé sur plusieurs cœurs de processeur. Pour rendre cela possible, de la même manière que pour les transformations de prefetch et d' interleave , la transformation de map fournit l'argument num_parallel_calls pour spécifier le niveau de parallélisme.

Le choix de la meilleure valeur pour l'argument num_parallel_calls dépend de votre matériel, des caractéristiques de vos données d'entraînement (telles que sa taille et sa forme), du coût de votre fonction de carte et des autres traitements en cours sur le processeur en même temps. Une heuristique simple consiste à utiliser le nombre de cœurs CPU disponibles. Cependant, comme pour la transformation de prefetch et d' interleave , la transformation de map prend en charge tf.data.experimental.AUTOTUNE qui déléguera la décision sur le niveau de parallélisme à utiliser au runtime tf.data .

 def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s
 

Cartographie séquentielle

Commencez par utiliser la transformation de map sans parallélisme comme exemple de base.

 benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
 
Execution time: 0.41994816599992646

Cartographie séquentielle

Quant à l' approche naïve , ici les temps consacrés aux étapes d'ouverture, de lecture, de prétraitement (cartographie) et d'apprentissage se cumulent pour une seule itération.

Cartographie parallèle

Maintenant, utilisez la même fonction de prétraitement mais appliquez-la en parallèle sur plusieurs échantillons.

 benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
 
Execution time: 0.25377997199996116

Cartographie parallèle

Maintenant, vous pouvez voir sur le tracé que les étapes de prétraitement se chevauchent, ce qui réduit le temps global pour une seule itération.

Mise en cache

La transformation tf.data.Dataset.cache peut mettre en cache un ensemble de données, soit en mémoire, soit sur le stockage local. Cela évitera l'exécution de certaines opérations (comme l'ouverture de fichier et la lecture de données) à chaque époque.

 benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
 
Execution time: 0.36316757900010543

Ensemble de données mis en cache

Lorsque vous mettez en cache un ensemble de données, les transformations avant celle du cache (comme l'ouverture du fichier et la lecture des données) ne sont exécutées qu'à la première époque. Les époques suivantes réutiliseront les données mises en cache par la transformation de cache .

Si la fonction définie par l'utilisateur transmise à la transformation de map est coûteuse, appliquez la transformation de cache après la transformation de map tant que l'ensemble de données résultant peut toujours tenir dans la mémoire ou le stockage local. Si la fonction définie par l'utilisateur augmente l'espace requis pour stocker l'ensemble de données au-delà de la capacité du cache, appliquez-la après la transformation du cache ou envisagez de prétraiter vos données avant votre tâche d'entraînement pour réduire l'utilisation des ressources.

Cartographie de vectorisation

L'appel d'une fonction définie par l'utilisateur transmise à la transformation de map entraîne une surcharge liée à la planification et à l'exécution de la fonction définie par l'utilisateur. Nous vous recommandons de vectoriser la fonction définie par l'utilisateur (c'est-à-dire de la faire fonctionner sur un lot d'entrées à la fois) et d'appliquer la transformation par batch avant la transformation de map .

Pour illustrer cette bonne pratique, votre ensemble de données artificielles ne convient pas. Le délai de planification est d'environ 10 microsecondes (10e-6 secondes), bien moins que les dizaines de millisecondes utilisées dans ArtificialDataset , et son impact est donc difficile à voir.

Pour cet exemple, utilisez la fonction de base tf.data.Dataset.range et simplifiez la boucle d'apprentissage dans sa forme la plus simple.

 fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)
    
def increment(x):
    return x+1
 

Cartographie scalaire

 fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
 
Execution time: 0.9026268119999941

Carte scalaire

Le graphique ci-dessus illustre ce qui se passe (avec moins d'échantillons). Vous pouvez voir que la fonction mappée est appliquée à chaque échantillon. Bien que cette fonction soit très rapide, elle a des frais généraux qui ont un impact sur les performances temporelles.

Cartographie vectorisée

 fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
 
Execution time: 0.03311353300000519

Carte vectorisée

Cette fois, la fonction mappée est appelée une fois et s'applique à un lot d'échantillons. Bien que la fonction puisse prendre plus de temps à s'exécuter, la surcharge n'apparaît qu'une seule fois, améliorant ainsi les performances globales de temps.

Réduire l'empreinte mémoire

Un certain nombre de transformations, y compris l' interleave , la prefetch et la prefetch shuffle , maintiennent un tampon interne d'éléments. Si la fonction définie par l'utilisateur transmise à la transformation de map modifie la taille des éléments, l'ordre de la transformation de carte et les transformations qui tamponnent les éléments affectent l'utilisation de la mémoire. En général, nous vous recommandons de choisir l'ordre qui entraîne une moindre empreinte mémoire, sauf si un ordre différent est souhaitable pour les performances.

Mise en cache des calculs partiels

Il est recommandé de mettre en cache l'ensemble de données après la transformation de la map sauf si cette transformation rend les données trop volumineuses pour tenir en mémoire. Un compromis peut être obtenu si votre fonction mappée peut être divisée en deux parties: une qui prend du temps et une autre qui consomme de la mémoire. Dans ce cas, vous pouvez chaîner vos transformations comme ci-dessous:

 dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)
 

De cette façon, la partie qui prend du temps n'est exécutée que pendant la première époque et vous évitez d'utiliser trop d'espace de cache.

Résumé des meilleures pratiques

Voici un résumé des meilleures pratiques pour la conception de pipelines d'entrée TensorFlow performants:

Reproduire les figures

Pour approfondir la tf.data.Dataset API tf.data.Dataset , vous pouvez jouer avec vos propres pipelines. Vous trouverez ci-dessous le code utilisé pour tracer les images de ce guide. Cela peut être un bon point de départ, montrant quelques solutions de contournement pour les difficultés courantes telles que:

  • Reproductibilité du temps d'exécution;
  • Fonctions mappées exécution impatiente;
  • transformation d' interleave appelable.
 import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
 

Le jeu de données

Similaire à ArtificialDataset vous pouvez créer un ensemble de données renvoyant le temps passé à chaque étape.

 class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
    
    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset
    
    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
        
        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter
            
            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered
            
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )
 

Cet ensemble de données fournit des échantillons de forme [[2, 1], [2, 2], [2, 3]] et de type [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Chaque échantillon est:

 (
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)
 

Où:

  • Open et Read sont des identificateurs d'étapes
  • t0 est l'horodatage du début de l'étape correspondante
  • d est le temps passé à l'étape correspondante
  • i est l'index de l'instance
  • e est l'indice d'époque (nombre de fois que l'ensemble de données a été itéré)
  • s est l'index de l'échantillon

La boucle d'itération

Rendez la boucle d'itération un peu plus compliquée pour agréger tous les horaires. Cela ne fonctionnera qu'avec des ensembles de données générant des échantillons comme détaillé ci-dessus.

 def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
    
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)
            
            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter
            
            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
        
        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)
    
    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}
 

La méthode de traçage

Enfin, définissez une fonction capable de tracer une chronologie compte tenu des valeurs renvoyées par la fonction timelined_benchmark .

 def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()
    
    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
    
    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)
    
    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")
        
        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]
        
        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")
 

Utiliser des wrappers pour la fonction mappée

Pour exécuter une fonction mappée dans un contexte impatient, vous devez les envelopper dans un appel tf.py_function .

 def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper
 

Comparaison des pipelines

 _batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)
 

Naïve

 @map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
 
Execution time: 12.461227312999995

Optimisé

 @map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.experimental.AUTOTUNE
    )
    .unbatch(),
    5
)
 
Execution time: 6.333680681000033

 draw_timeline(naive_timeline, "Naive", 15)
 

png

 draw_timeline(optimized_timeline, "Optimized", 15)
 

png