De meilleures performances avec l'API tf.data

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Aperçu

Les GPU et les TPU peuvent réduire considérablement le temps nécessaire pour exécuter une seule étape de formation. L'obtention de performances optimales nécessite un pipeline d'entrée efficace qui fournit des données pour l'étape suivante avant la fin de l'étape en cours. L'API tf.data permet de créer des pipelines d'entrée flexibles et efficaces. Ce document montre comment utiliser l'API tf.data pour créer des pipelines d'entrée TensorFlow hautement performants.

Avant de continuer, consultez le guide Build TensorFlow input pipelines pour savoir comment utiliser l'API tf.data .

Ressources

Installer

import tensorflow as tf

import time

Tout au long de ce guide, vous allez parcourir un ensemble de données et mesurer les performances. Faire des benchmarks de performance reproductibles peut être difficile. Les différents facteurs affectant la reproductibilité comprennent :

  • La charge CPU actuelle
  • Le trafic réseau
  • Mécanismes complexes, tels que le cache

Pour obtenir un benchmark reproductible, vous allez construire un exemple artificiel.

Le jeu de données

Commencez par définir une classe héritant de tf.data.Dataset appelée ArtificialDataset . Cet ensemble de données :

  • Génère num_samples échantillons (la valeur par défaut est 3)
  • Dort pendant un certain temps avant le premier élément pour simuler l'ouverture d'un fichier
  • Dort pendant un certain temps avant de produire chaque élément pour simuler la lecture de données à partir d'un fichier
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

Ce jeu de données est similaire à celui de tf.data.Dataset.range , ajoutant un délai fixe au début et entre chaque échantillon.

La boucle d'entraînement

Ensuite, écrivez une boucle d'entraînement factice qui mesure le temps nécessaire pour itérer sur un ensemble de données. Le temps de formation est simulé.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Optimiser les performances

Pour montrer comment les performances peuvent être optimisées, vous allez améliorer les performances de ArtificialDataset .

L'approche naïve

Commencez avec un pipeline naïf sans aucune astuce, itérant sur l'ensemble de données tel quel.

benchmark(ArtificialDataset())
Execution time: 0.26497629899995445

Sous le capot, voici comment votre temps d'exécution a été dépensé :

Tracé du temps d'exécution des données - une méthode naïve

Le graphique montre que l'exécution d'une étape d'entraînement implique :

  • Ouvrir un fichier s'il n'a pas encore été ouvert
  • Récupération d'une entrée de données à partir du fichier
  • Utilisation des données pour la formation

Cependant, dans une implémentation synchrone naïve comme ici, pendant que votre pipeline récupère les données, votre modèle est inactif. Inversement, pendant que votre modèle s'entraîne, le pipeline d'entrée est inactif. Le temps de pas d'apprentissage est donc la somme des temps d'ouverture, de lecture et d'apprentissage.

Les sections suivantes s'appuient sur ce pipeline d'entrée, illustrant les bonnes pratiques pour la conception de pipelines d'entrée TensorFlow performants.

Prélecture

La prélecture chevauche le prétraitement et l'exécution du modèle d'une étape d'apprentissage. Pendant que le modèle exécute l'étape d'apprentissage s , le pipeline d'entrée lit les données pour l'étape s+1 . Cela réduit le temps de pas au maximum (par opposition à la somme) de la formation et le temps nécessaire pour extraire les données.

L'API tf.data fournit la transformation tf.data.Dataset.prefetch . Il peut être utilisé pour découpler le moment où les données sont produites du moment où les données sont consommées. En particulier, la transformation utilise un thread d'arrière-plan et un tampon interne pour préextraire les éléments de l'ensemble de données d'entrée avant le moment où ils sont demandés. Le nombre d'éléments à prérécupérer doit être égal (ou éventuellement supérieur) au nombre de lots consommés par une seule étape d'apprentissage. Vous pouvez soit ajuster manuellement cette valeur, soit la définir sur tf.data.AUTOTUNE , ce qui demandera à l'environnement d'exécution tf.data d'ajuster dynamiquement la valeur au moment de l'exécution.

Notez que la transformation de prélecture offre des avantages chaque fois qu'il y a une opportunité de chevaucher le travail d'un "producteur" avec le travail d'un "consommateur".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357

Tracé du temps d'exécution des données - méthode de prélecture

Maintenant, comme le montre le tracé du temps d'exécution des données, pendant que l'étape d'apprentissage s'exécute pour l'échantillon 0, le pipeline d'entrée lit les données pour l'échantillon 1, et ainsi de suite.

Parallélisation de l'extraction de données

Dans un environnement réel, les données d'entrée peuvent être stockées à distance (par exemple, sur Google Cloud Storage ou HDFS). Un pipeline d'ensemble de données qui fonctionne bien lors de la lecture de données localement peut présenter un goulot d'étranglement sur les E/S lors de la lecture de données à distance en raison des différences suivantes entre le stockage local et le stockage distant :

  • Time-to-first-byte : la lecture du premier octet d'un fichier à partir d'un stockage distant peut prendre des ordres de grandeur plus longs qu'à partir d'un stockage local.
  • Débit de lecture : alors que le stockage distant offre généralement une large bande passante agrégée, la lecture d'un seul fichier peut n'utiliser qu'une petite fraction de cette bande passante.

De plus, une fois les octets bruts chargés en mémoire, il peut également être nécessaire de désérialiser et/ou de déchiffrer les données (par exemple protobuf ), ce qui nécessite des calculs supplémentaires. Cette surcharge est présente indépendamment du fait que les données soient stockées localement ou à distance, mais peut être pire dans le cas distant si les données ne sont pas prélues efficacement.

Pour atténuer l'impact des divers frais généraux d'extraction de données, la transformation tf.data.Dataset.interleave peut être utilisée pour paralléliser l'étape de chargement des données, en entrelaçant le contenu d'autres ensembles de données (tels que les lecteurs de fichiers de données). Le nombre d'ensembles de données à chevaucher peut être spécifié par l'argument cycle_length , tandis que le niveau de parallélisme peut être spécifié par l'argument num_parallel_calls . Semblable à la transformation de prefetch , la transformation d' interleave prend en charge tf.data.AUTOTUNE , qui délègue la décision sur le niveau de parallélisme à utiliser au runtime tf.data .

Entrelacement séquentiel

Les arguments par défaut de la transformation tf.data.Dataset.interleave lui permettent d'entrelacer séquentiellement des échantillons uniques de deux ensembles de données.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098

Tracé du temps d'exécution des données - entrelacement séquentiel

Ce tracé du temps d'exécution des données permet d'exposer le comportement de la transformation interleave , en récupérant alternativement des échantillons à partir des deux ensembles de données disponibles. Cependant, aucune amélioration des performances n'est impliquée ici.

Entrelacement parallèle

Maintenant, utilisez l'argument num_parallel_calls de la transformation interleave . Cela charge plusieurs ensembles de données en parallèle, ce qui réduit le temps d'attente pour l'ouverture des fichiers.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.283668874000341

Tracé du temps d'exécution des données - méthode d'entrelacement parallèle

Cette fois, comme le montre le tracé du temps d'exécution des données, la lecture des deux jeux de données est parallélisée, ce qui réduit le temps de traitement global des données.

Paralléliser la transformation des données

Lors de la préparation des données, les éléments d'entrée peuvent devoir être prétraités. À cette fin, l'API tf.data propose la transformation tf.data.Dataset.map , qui applique une fonction définie par l'utilisateur à chaque élément de l'ensemble de données d'entrée. Étant donné que les éléments d'entrée sont indépendants les uns des autres, le prétraitement peut être parallélisé sur plusieurs cœurs de processeur. Pour rendre cela possible, de la même manière que les transformations prefetch et interleave , la transformation map fournit l'argument num_parallel_calls pour spécifier le niveau de parallélisme.

Le choix de la meilleure valeur pour l'argument num_parallel_calls dépend de votre matériel, des caractéristiques de vos données d'entraînement (telles que leur taille et leur forme), du coût de votre fonction de carte et des autres traitements qui se produisent sur le CPU en même temps. Une heuristique simple consiste à utiliser le nombre de cœurs de processeur disponibles. Cependant, comme pour la transformation de prefetch et d' interleave , la transformation de map prend en charge tf.data.AUTOTUNE qui déléguera la décision sur le niveau de parallélisme à utiliser au runtime tf.data .

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Cartographie séquentielle

Commencez par utiliser la transformation de map sans parallélisme comme exemple de référence.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4505277170001136

Tracé du temps d'exécution des données - méthode de mappage séquentiel

Quant à l' approche naïve , ici, comme le montre l'intrigue, les temps passés pour les étapes d'ouverture, de lecture, de pré-traitement (cartographie) et d'apprentissage s'additionnent pour une seule itération.

Cartographie parallèle

Maintenant, utilisez la même fonction de prétraitement mais appliquez-la en parallèle sur plusieurs échantillons.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2839677860001757

Temps d'exécution des données - mappage parallèle

Comme le montre le diagramme de données, les étapes de prétraitement se chevauchent, ce qui réduit le temps global pour une seule itération.

Mise en cache

La transformation tf.data.Dataset.cache peut mettre en cache un jeu de données, soit en mémoire, soit sur le stockage local. Cela évitera que certaines opérations (telles que l'ouverture de fichiers et la lecture de données) ne soient exécutées à chaque époque.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3848854380003104

Temps d'exécution des données - méthode du jeu de données mis en cache

Ici, le tracé du temps d'exécution des données montre que lorsque vous mettez en cache un ensemble de données, les transformations avant celle du cache (comme l'ouverture du fichier et la lecture des données) ne sont exécutées que pendant la première époque. Les époques suivantes réutiliseront les données mises en cache par la transformation du cache .

Si la fonction définie par l'utilisateur transmise à la transformation de map est coûteuse, appliquez la transformation de cache après la transformation de map tant que l'ensemble de données résultant peut toujours tenir dans la mémoire ou le stockage local. Si la fonction définie par l'utilisateur augmente l'espace requis pour stocker l'ensemble de données au-delà de la capacité du cache, appliquez-la après la transformation du cache ou envisagez de prétraiter vos données avant votre tâche d'entraînement afin de réduire l'utilisation des ressources.

Cartographie de vectorisation

L'appel d'une fonction définie par l'utilisateur transmise à la transformation de map entraîne une surcharge liée à la planification et à l'exécution de la fonction définie par l'utilisateur. Vectorisez la fonction définie par l'utilisateur (c'est-à-dire faites-la fonctionner sur un lot d'entrées à la fois) et appliquez la transformation batch avant la transformation de la map .

Pour illustrer cette bonne pratique, votre jeu de données artificiel n'est pas adapté. Le délai de planification est d'environ 10 microsecondes (10e-6 secondes), bien inférieur aux dizaines de millisecondes utilisées dans le ArtificialDataset , et son impact est donc difficile à voir.

Pour cet exemple, utilisez la fonction de base tf.data.Dataset.range et simplifiez la boucle de formation dans sa forme la plus simple.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Cartographie scalaire

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.2712608739998359

Temps d'exécution des données - méthode de carte scalaire

Le graphique ci-dessus illustre ce qui se passe (avec moins d'échantillons) en utilisant la méthode de cartographie scalaire. Il montre que la fonction mappée est appliquée pour chaque échantillon. Bien que cette fonction soit très rapide, elle a une surcharge qui a un impact sur les performances temporelles.

Cartographie vectorisée

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.02737950600021577

Temps d'exécution des données - méthode de la carte vectorisée

Cette fois, la fonction mappée est appelée une fois et s'applique à un lot d'échantillon. Comme le montre le tracé du temps d'exécution des données, bien que la fonction puisse prendre plus de temps à s'exécuter, la surcharge n'apparaît qu'une seule fois, ce qui améliore les performances globales en matière de temps.

Réduction de l'empreinte mémoire

Un certain nombre de transformations, notamment interleave , prefetch et shuffle , maintiennent un tampon interne d'éléments. Si la fonction définie par l'utilisateur transmise à la transformation de map modifie la taille des éléments, l'ordre de la transformation de carte et des transformations qui mettent les éléments en mémoire tampon affecte l'utilisation de la mémoire. En général, choisissez l'ordre qui réduit l'empreinte mémoire, à moins qu'un ordre différent ne soit souhaitable pour les performances.

Mise en cache des calculs partiels

Il est recommandé de mettre en cache le jeu de données après la transformation de la map , sauf si cette transformation rend les données trop volumineuses pour tenir en mémoire. Un compromis peut être atteint si votre fonction mappée peut être divisée en deux parties : une partie qui prend du temps et une partie qui consomme de la mémoire. Dans ce cas, vous pouvez enchaîner vos transformations comme ci-dessous :

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

De cette façon, la partie chronophage n'est exécutée que pendant la première époque, et vous évitez d'utiliser trop d'espace de cache.

Résumé des meilleures pratiques

Voici un résumé des bonnes pratiques pour concevoir des pipelines d'entrée TensorFlow performants :

Reproduisant les chiffres

Pour approfondir la compréhension de l'API tf.data.Dataset , vous pouvez jouer avec vos propres pipelines. Vous trouverez ci-dessous le code utilisé pour tracer les images de ce guide. Cela peut être un bon point de départ, montrant quelques solutions de contournement pour les difficultés courantes telles que :

  • Reproductibilité du temps d'exécution
  • Fonctions mappées exécutées avec impatience
  • transformation interleave appelable
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Le jeu de données

Semblable au ArtificialDataset , vous pouvez créer un ensemble de données renvoyant le temps passé à chaque étape.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Cet ensemble de données fournit des échantillons de forme [[2, 1], [2, 2], [2, 3]] et de type [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Chaque échantillon est :

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Où:

  • Open et Read sont des identifiants d'étapes
  • t0 est l'horodatage du démarrage de l'étape correspondante
  • d est le temps passé dans l'étape correspondante
  • i est l'indice d'instance
  • e est l'indice d'époque (nombre de fois où l'ensemble de données a été itéré)
  • s est l'indice d'échantillon

La boucle d'itération

Rendez la boucle d'itération un peu plus compliquée pour agréger tous les timings. Cela ne fonctionnera qu'avec des ensembles de données générant des échantillons comme indiqué ci-dessus.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

La méthode de traçage

Enfin, définissez une fonction capable de tracer une chronologie en fonction des valeurs renvoyées par la fonction timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Utiliser des wrappers pour la fonction mappée

Pour exécuter une fonction mappée dans un contexte impatient, vous devez les envelopper dans un appel tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Comparaison des pipelines

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Naïve

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 13.13538893499981

Optimisé

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png