Melhor desempenho com a API tf.data

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Visão geral

GPUs e TPUs podem reduzir radicalmente o tempo necessário para executar uma única etapa de treinamento. Alcançar o desempenho máximo requer um pipeline de entrada eficiente que fornece dados para a próxima etapa antes que a etapa atual seja concluída. O tf.data API ajuda a construir condutas de entrada flexíveis e eficientes. Este documento demonstra como usar o tf.data API para construir dutos de entrada TensorFlow com alto desempenho.

Antes de continuar, verifique os dutos de entrada Desenvolver TensorFlow guia para aprender como usar o tf.data API.

Recursos

Configurar

import tensorflow as tf

import time

Ao longo deste guia, você irá iterar em um conjunto de dados e medir o desempenho. Fazer benchmarks de desempenho reproduzíveis pode ser difícil. Diferentes fatores que afetam a reprodutibilidade incluem:

  • A carga atual da CPU
  • O tráfego da rede
  • Mecanismos complexos, como cache

Para obter um benchmark reproduzível, você construirá um exemplo artificial.

O conjunto de dados

Comece com a definição de uma classe herdada de tf.data.Dataset chamado ArtificialDataset . Este conjunto de dados:

  • Gera num_samples amostras (padrão é 3)
  • Dorme por algum tempo antes do primeiro item para simular a abertura de um arquivo
  • Dorme por algum tempo antes de produzir cada item para simular a leitura de dados de um arquivo
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

Este conjunto de dados é semelhante ao tf.data.Dataset.range um, adicionando um atraso fixo no início e no meio de cada amostra.

O ciclo de treinamento

Em seguida, escreva um loop de treinamento fictício que mede quanto tempo leva para iterar em um conjunto de dados. O tempo de treinamento é simulado.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Otimize o desempenho

Para exibir como o desempenho pode ser otimizado, você irá melhorar o desempenho do ArtificialDataset .

A abordagem ingênua

Comece com um pipeline ingênuo, sem truques, iterando o conjunto de dados no estado em que se encontra.

benchmark(ArtificialDataset())
Execution time: 0.26497629899995445

Nos bastidores, é assim que seu tempo de execução foi gasto:

Gráfico de tempo de execução de dados - um método ingênuo

O gráfico mostra que a realização de uma etapa de treinamento envolve:

  • Abrir um arquivo se ainda não tiver sido aberto
  • Buscando uma entrada de dados do arquivo
  • Usando os dados para treinamento

No entanto, em uma implementação síncrona ingênua como esta, enquanto o pipeline busca os dados, o modelo fica ocioso. Por outro lado, enquanto seu modelo está treinando, o pipeline de entrada fica ocioso. O tempo da etapa de treinamento é, portanto, a soma dos tempos de abertura, leitura e treinamento.

As próximas seções se baseiam neste pipeline de entrada, ilustrando as práticas recomendadas para projetar pipelines de entrada do TensorFlow de alto desempenho.

Pré-busca

A pré-busca se sobrepõe ao pré-processamento e à execução do modelo de uma etapa de treinamento. Enquanto o modelo está executando treinamento passo s , o pipeline de entrada é a leitura dos dados para a etapa s+1 . Isso reduz o tempo da etapa ao máximo (em oposição à soma) do treinamento e o tempo que leva para extrair os dados.

O tf.data API fornece a tf.data.Dataset.prefetch transformação. Ele pode ser usado para separar o momento em que os dados são produzidos do momento em que os dados são consumidos. Em particular, a transformação usa um thread de segundo plano e um buffer interno para pré-buscar elementos do conjunto de dados de entrada antes do momento em que são solicitados. O número de elementos para pré-busca deve ser igual (ou possivelmente maior que) ao número de lotes consumidos por uma única etapa de treinamento. Você poderia sintonizar manualmente esse valor, ou configurá-lo para tf.data.AUTOTUNE , que pedirá ao tf.data runtime para ajustar o valor dinamicamente durante a execução.

Observe que a transformação de pré-busca fornece benefícios sempre que houver uma oportunidade de sobrepor o trabalho de um "produtor" ao trabalho de um "consumidor".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357

Gráfico de tempo de execução de dados - método de pré-busca

Agora, como mostra o gráfico de tempo de execução de dados, enquanto a etapa de treinamento está em execução para a amostra 0, o pipeline de entrada está lendo os dados para a amostra 1 e assim por diante.

Paralelizando a extração de dados

Em uma configuração do mundo real, os dados de entrada podem ser armazenados remotamente (por exemplo, no Google Cloud Storage ou HDFS). Um pipeline de conjunto de dados que funciona bem ao ler dados localmente pode se tornar um gargalo de E / S ao ler dados remotamente devido às seguintes diferenças entre armazenamento local e remoto:

  • Time-to-primeiro-byte: Lendo o primeiro byte de um arquivo do armazenamento remoto pode tomar ordens de magnitude maior do que a partir do armazenamento local.
  • Leia rendimento: Enquanto armazenamento remoto normalmente oferece grande largura de banda agregada, lendo um único arquivo só pode ser capaz de utilizar uma pequena fração desta largura de banda.

Além disso, uma vez que os bytes brutos são carregados na memória, pode também ser necessária para anular a serialização e / ou descriptografar os dados (por exemplo Protobuf ), que exige a computação adicional. Essa sobrecarga está presente independentemente de os dados serem armazenados local ou remotamente, mas pode ser pior no caso remoto se os dados não forem pré-buscados de forma eficaz.

Para atenuar o impacto das várias despesas gerais de extracção de dados, o tf.data.Dataset.interleave transformação pode ser usado para paralelizar a etapa de carregamento dos dados, a intercalação o conteúdo de outros conjuntos de dados (tais como leitores de ficheiros de dados). O número de conjuntos de dados de sobreposição pode ser especificado pelo cycle_length argumento, enquanto que o nível de paralelismo pode ser especificado pelo num_parallel_calls argumento. Semelhante ao prefetch transformação, a interleave de transformação suporta tf.data.AUTOTUNE , que irá delegar a decisão sobre o nível de paralelismo para uso ao tf.data tempo de execução.

Intercalação sequencial

Os argumentos padrão do tf.data.Dataset.interleave transformação torná-lo intercalar amostras únicas de dois conjuntos de dados seqüencialmente.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098

Gráfico de tempo de execução de dados - intercalação sequencial

Esta trama de tempo de execução de dados permite a exibir o comportamento do interleave transformação, buscar amostras, alternativamente, a partir dos dois conjuntos de dados disponíveis. No entanto, nenhuma melhoria de desempenho está envolvida aqui.

Intercalação paralela

Agora, use o num_parallel_calls argumento do interleave transformação. Isso carrega vários conjuntos de dados em paralelo, reduzindo o tempo de espera para os arquivos serem abertos.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.283668874000341

Gráfico de tempo de execução de dados - método de intercalação paralela

Desta vez, como mostra o gráfico do tempo de execução dos dados, a leitura dos dois conjuntos de dados é paralelizada, reduzindo o tempo global de processamento dos dados.

Paralelizando a transformação de dados

Ao preparar dados, os elementos de entrada podem precisar ser pré-processados. Para este fim, os tf.data API ofertas a tf.data.Dataset.map de transformação, o qual se aplica uma função definida pelo utilizador para cada elemento do conjunto de dados de entrada. Como os elementos de entrada são independentes uns dos outros, o pré-processamento pode ser paralelizado em vários núcleos de CPU. Para tornar isso possível, semelhante aos prefetch e interleave transformações, o map de transformação fornece a num_parallel_calls argumento para especificar o nível de paralelismo.

Escolher o melhor valor para o num_parallel_calls argumento depende do seu hardware, características de seus dados de treinamento (tais como seu tamanho e forma), o custo de sua função de mapa, eo que outro processamento está acontecendo na CPU ao mesmo tempo. Uma heurística simples é usar o número de núcleos de CPU disponíveis. No entanto, como para a prefetch e interleave transformação, o map de transformação suporta tf.data.AUTOTUNE que irá delegar a decisão sobre o nível de paralelismo para uso ao tf.data tempo de execução.

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Mapeamento sequencial

Comece usando o map transformação sem paralelismo como um exemplo de referência.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4505277170001136

Gráfico de tempo de execução de dados - método de mapeamento sequencial

Como para a abordagem ingénua , aqui, como mostra a trama, o tempo gasto para a abertura, a leitura, o pré-processamento (mapeamento) e as etapas de formação resumir juntos por uma única iteração.

Mapeamento paralelo

Agora, use a mesma função de pré-processamento, mas aplique-a em paralelo em várias amostras.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2839677860001757

Tempo de execução de dados - mapeamento paralelo

Como o gráfico de dados demonstra, as etapas de pré-processamento se sobrepõem, reduzindo o tempo geral para uma única iteração.

Cache

O tf.data.Dataset.cache transformação pode armazenar em cache um conjunto de dados, seja na memória ou no armazenamento local. Isso evitará que algumas operações (como abertura de arquivo e leitura de dados) sejam executadas durante cada época.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3848854380003104

Tempo de execução de dados - método de conjunto de dados em cache

Aqui, a execução de dados em tempo gráfico mostra que quando você armazenar em cache um conjunto de dados, as transformações antes do cache um (como a abertura de arquivos e dados de leitura) são executados apenas durante a primeira época. As próximas épocas irá reutilizar os dados armazenados em cache pelo cache transformação.

Se a função definida pelo usuário passado para o map de transformação é caro, aplicar o cache transformação após o map de transformação, desde que o conjunto de dados resultante ainda pode caber na memória ou armazenamento local. Se a função definida pelo utilizador aumenta o espaço necessário para armazenar o conjunto de dados além da capacidade de cache, ou aplicá-lo após o cache transformação ou considerar pré-processamento de seus dados antes de seu trabalho de treinamento para reduzir o uso de recursos.

Mapeamento de vetorização

Invocando uma função definida pelo utilizador passou para o map transformação tem sobrecarga relacionada com a programação e executar a função definida pelo utilizador. Vetorizar a função definida pelo utilizador (isto é, tem que operam ao longo de um lote de entradas de uma só vez) e aplicar o batch transformação antes do map transformação.

Para ilustrar essa boa prática, seu conjunto de dados artificial não é adequado. O atraso de agendamento é de cerca de 10 microssegundos (10e-6 segundos), muito menos do que as dezenas de milissegundos utilizados na ArtificialDataset , e, portanto, o seu impacto é difícil de ver.

Para este exemplo, utilizar a base tf.data.Dataset.range função e simplificar o circuito de formação para a sua forma mais simples.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Mapeamento escalar

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.2712608739998359

Tempo de execução de dados - método de mapa escalar

O gráfico acima ilustra o que está acontecendo (com menos amostras) usando o método de mapeamento escalar. Mostra que a função mapeada é aplicada a cada amostra. Embora essa função seja muito rápida, ela tem alguma sobrecarga que afeta o desempenho do tempo.

Mapeamento vetorizado

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.02737950600021577

Tempo de execução de dados - método de mapa vetorizado

Desta vez, a função mapeada é chamada uma vez e se aplica a um lote de amostra. Como mostra o gráfico de tempo de execução de dados, embora a função possa levar mais tempo para ser executada, a sobrecarga aparece apenas uma vez, melhorando o desempenho geral do tempo.

Reduzindo o consumo de memória

Um certo número de transformações, incluindo interleave , prefetch , e shuffle , manter uma reserva interna de elementos. Se a função definida pelo utilizador passou para o map transformação altera o tamanho dos elementos, em seguida, a ordenação do mapa de transformação e as transformações que elementos de amortecimento afecta o uso de memória. Em geral, escolha a ordem que resulta em menor consumo de memória, a menos que uma ordem diferente seja desejável para o desempenho.

Cache de cálculos parciais

É recomendado para armazenar em cache o conjunto de dados após o map transformação salvo se essa transformação torna os dados muito grande para caber na memória. Uma compensação pode ser alcançada se sua função mapeada puder ser dividida em duas partes: uma que consome muito tempo e outra que consome memória. Nesse caso, você pode encadear suas transformações como a seguir:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

Dessa forma, a parte demorada só é executada durante a primeira época e você evita usar muito espaço de cache.

Resumo das melhores práticas

Aqui está um resumo das práticas recomendadas para projetar pipelines de entrada do TensorFlow de alto desempenho:

Reproduzindo as figuras

Para ir mais fundo na tf.data.Dataset compreensão API, você pode jogar com seus próprios oleodutos. Abaixo está o código usado para plotar as imagens deste guia. Pode ser um bom ponto de partida, mostrando algumas soluções alternativas para dificuldades comuns, como:

  • Reprodutibilidade do tempo de execução
  • Execução rápida de funções mapeadas
  • interleave exigível transformação
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

O conjunto de dados

Semelhante ao ArtificialDataset você pode construir um conjunto de dados retornando o tempo gasto em cada etapa.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Este conjunto de dados fornece amostras de forma [[2, 1], [2, 2], [2, 3]] e do tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Cada amostra é:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Onde:

  • Open e Read são passos identificadores
  • t0 é o carimbo, quando a etapa correspondente começou
  • d é o tempo gasto na etapa correspondente
  • i é o índice instância
  • e é o índice época (número de vezes que o conjunto de dados foi iterada)
  • s é o índice da amostra de

O loop de iteração

Torne o loop de iteração um pouco mais complicado para agregar todos os tempos. Isso só funcionará com conjuntos de dados gerando amostras conforme detalhado acima.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

O método de plotagem

Finalmente, definir uma função capaz de traçar uma linha de tempo dado os valores devolvidos pelo timelined_benchmark função.

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Use wrappers para função mapeada

Para executar a função mapeado em um contexto ansioso, você tem que envolvê-los dentro de um tf.py_function chamada.

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Comparação de pipelines

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Ingênuo

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 13.13538893499981

Otimizado

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png