Melhor desempenho com a API tf.data

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Visão geral

GPUs e TPUs podem reduzir radicalmente o tempo necessário para executar uma única etapa de treinamento. Alcançar o desempenho máximo requer um pipeline de entrada eficiente que fornece dados para a próxima etapa antes que a etapa atual seja concluída. A API tf.data ajuda a construir pipelines de entrada flexíveis e eficientes. Este documento demonstra como usar a API tf.data para criar pipelines de entrada do TensorFlow de alto desempenho.

Antes de continuar, verifique o guia Build TensorFlow input pipelines para aprender a usar a API tf.data .

Recursos

Configurar

import tensorflow as tf

import time

Ao longo deste guia, você irá iterar em um conjunto de dados e medir o desempenho. Fazer benchmarks de desempenho reproduzíveis pode ser difícil. Diferentes fatores que afetam a reprodutibilidade incluem:

  • A carga atual da CPU
  • O tráfego da rede
  • Mecanismos complexos, como cache

Para obter um benchmark reproduzível, você construirá um exemplo artificial.

O conjunto de dados

Comece definindo uma classe herdada detf.data.Dataset chamada ArtificialDataset . Este conjunto de dados:

  • Gera amostras num_samples (o padrão é 3)
  • Dorme por algum tempo antes do primeiro item para simular a abertura de um arquivo
  • Dorme por algum tempo antes de produzir cada item para simular a leitura de dados de um arquivo
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

Este conjunto de dados é semelhante ao tf.data.Dataset.range , adicionando um atraso fixo no início e entre cada amostra.

O ciclo de treinamento

Em seguida, escreva um loop de treinamento fictício que mede quanto tempo leva para iterar em um conjunto de dados. O tempo de treinamento é simulado.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Otimize o desempenho

Para mostrar como o desempenho pode ser otimizado, você melhorará o desempenho do ArtificialDataset .

A abordagem ingênua

Comece com um pipeline ingênuo, sem truques, iterando o conjunto de dados no estado em que se encontra.

benchmark(ArtificialDataset())
Execution time: 0.2541472299999441

Nos bastidores, é assim que seu tempo de execução foi gasto:

Gráfico de tempo de execução de dados - um método ingênuo

O gráfico mostra que a realização de uma etapa de treinamento envolve:

  • Abrindo um arquivo se ele ainda não foi aberto
  • Buscando uma entrada de dados do arquivo
  • Usando os dados para treinamento

No entanto, em uma implementação síncrona ingênua como esta, enquanto o pipeline busca os dados, o modelo fica ocioso. Por outro lado, enquanto seu modelo está treinando, o pipeline de entrada fica ocioso. O tempo da etapa de treinamento é, portanto, a soma dos tempos de abertura, leitura e treinamento.

As próximas seções se baseiam neste pipeline de entrada, ilustrando as práticas recomendadas para projetar pipelines de entrada do TensorFlow de alto desempenho.

Pré-busca

A pré-busca se sobrepõe ao pré-processamento e à execução do modelo de uma etapa de treinamento. Enquanto o modelo está executando as etapas de treinamento s , o pipeline de entrada está lendo os dados para a etapa s+1 . Isso reduz o tempo da etapa ao máximo (em oposição à soma) do treinamento e o tempo que leva para extrair os dados.

A API tf.data fornece a transformação tf.data.Dataset.prefetch . Ele pode ser usado para separar o momento em que os dados são produzidos do momento em que os dados são consumidos. Em particular, a transformação usa um thread de segundo plano e um buffer interno para pré-buscar elementos do conjunto de dados de entrada antes do momento em que são solicitados. O número de elementos para pré-busca deve ser igual (ou possivelmente maior do que) o número de lotes consumidos por uma única etapa de treinamento. Você pode ajustar manualmente esse valor ou configurá-lo para tf.data.AUTOTUNE , o que solicitará que o tempo de execução tf.data ajuste o valor dinamicamente no tempo de execução.

Observe que a transformação de pré-busca fornece benefícios sempre que houver uma oportunidade de sobrepor o trabalho de um "produtor" ao trabalho de um "consumidor".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.20805208699994182

Gráfico de tempo de execução de dados - método de pré-busca

Agora, como mostra o gráfico de tempo de execução de dados, enquanto a etapa de treinamento está em execução para a amostra 0, o pipeline de entrada está lendo os dados para a amostra 1 e assim por diante.

Paralelizando a extração de dados

Em uma configuração do mundo real, os dados de entrada podem ser armazenados remotamente (por exemplo, no Google Cloud Storage ou HDFS). Um pipeline de conjunto de dados que funciona bem ao ler dados localmente pode se tornar um gargalo de E / S ao ler dados remotamente devido às seguintes diferenças entre armazenamento local e remoto:

  • Tempo para o primeiro byte : Ler o primeiro byte de um arquivo do armazenamento remoto pode levar muito mais tempo do que no armazenamento local.
  • Taxa de transferência de leitura : embora o armazenamento remoto geralmente ofereça grande largura de banda agregada, a leitura de um único arquivo pode utilizar apenas uma pequena fração dessa largura de banda.

Além disso, uma vez que os bytes brutos são carregados na memória, também pode ser necessário desserializar e / ou descriptografar os dados (por exemplo, protobuf ), o que requer computação adicional. Essa sobrecarga está presente independentemente de os dados serem armazenados local ou remotamente, mas pode ser pior no caso remoto se os dados não forem pré-buscados de forma eficaz.

Para mitigar o impacto das várias sobrecargas de extração de dados, a transformação tf.data.Dataset.interleave pode ser usada para paralelizar a etapa de carregamento de dados, intercalando o conteúdo de outros conjuntos de dados (como leitores de arquivos de dados). O número de conjuntos de dados a sobrepor pode ser especificado pelo argumento cycle_length , enquanto o nível de paralelismo pode ser especificado pelo argumento num_parallel_calls . Semelhante à transformação de prefetch - prefetch , a transformação de interleave oferece suporte a tf.data.AUTOTUNE , que delegará a decisão sobre o nível de paralelismo a ser usado para o tempo de execução tf.data .

Intercalação sequencial

Os argumentos padrão da transformação tf.data.Dataset.interleave fazem com que ela intercale amostras únicas de dois conjuntos de dados sequencialmente.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4883518669998921

Gráfico de tempo de execução de dados - intercalação sequencial

Este gráfico de tempo de execução de dados permite exibir o comportamento da transformação interleave , buscando amostras alternativamente dos dois conjuntos de dados disponíveis. No entanto, nenhuma melhoria de desempenho está envolvida aqui.

Intercalação paralela

Agora, use o argumento num_parallel_calls da transformação de interleave . Isso carrega vários conjuntos de dados em paralelo, reduzindo o tempo de espera para os arquivos serem abertos.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.26920967700016263

Gráfico de tempo de execução de dados - método de intercalação paralela

Desta vez, como mostra o gráfico do tempo de execução dos dados, a leitura dos dois conjuntos de dados é paralelizada, reduzindo o tempo global de processamento dos dados.

Paralelizando a transformação de dados

Ao preparar dados, os elementos de entrada podem precisar ser pré-processados. Para esse fim, a API tf.data oferece a transformação tf.data.Dataset.map , que aplica uma função definida pelo usuário a cada elemento do conjunto de dados de entrada. Como os elementos de entrada são independentes uns dos outros, o pré-processamento pode ser paralelizado em vários núcleos de CPU. Para tornar isso possível, de maneira semelhante às transformações de prefetch - prefetch e interleave , a transformação do map fornece o argumento num_parallel_calls para especificar o nível de paralelismo.

A escolha do melhor valor para o argumento num_parallel_calls depende do seu hardware, das características dos dados de treinamento (como tamanho e forma), do custo da função de mapa e de que outro processamento está acontecendo na CPU ao mesmo tempo. Uma heurística simples é usar o número de núcleos de CPU disponíveis. No entanto, quanto à transformação de prefetch - prefetch e interleave , a transformação do map suporta tf.data.AUTOTUNE que delegará a decisão sobre qual nível de paralelismo usar para o tempo de execução tf.data .

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Mapeamento sequencial

Comece usando a transformação do map sem paralelismo como um exemplo de linha de base.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4379127629999857

Gráfico de tempo de execução de dados - método de mapeamento sequencial

Quanto à abordagem ingênua , aqui, como mostra o gráfico, os tempos gastos para as etapas de abertura, leitura, pré-processamento (mapeamento) e treinamento somam-se para uma única iteração.

Mapeamento paralelo

Agora, use a mesma função de pré-processamento, mas aplique-a em paralelo em várias amostras.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2747970279999663

Tempo de execução de dados - mapeamento paralelo

Como o gráfico de dados demonstra, as etapas de pré-processamento se sobrepõem, reduzindo o tempo geral para uma única iteração.

Cache

A transformação tf.data.Dataset.cache pode armazenar em cache um conjunto de dados, na memória ou no armazenamento local. Isso evitará que algumas operações (como abertura de arquivo e leitura de dados) sejam executadas durante cada época.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3715158390000397

Tempo de execução de dados - método de conjunto de dados em cache

Aqui, o gráfico de tempo de execução de dados mostra que quando você armazena em cache um conjunto de dados, as transformações antes do conjunto de cache (como a abertura de arquivo e leitura de dados) são executadas apenas durante a primeira época. As próximas épocas reutilizarão os dados armazenados em cache pela transformação do cache .

Se a função definida pelo usuário passada para a transformação do map for cara, aplique a transformação do cache após a transformação do map , desde que o conjunto de dados resultante ainda possa caber na memória ou armazenamento local. Se a função definida pelo usuário aumentar o espaço necessário para armazenar o conjunto de dados além da capacidade do cache, aplique-o após a transformação do cache ou considere o pré-processamento de seus dados antes do trabalho de treinamento para reduzir o uso de recursos.

Mapeamento de vetorização

Invocar uma função definida pelo usuário passada para a transformação do map tem sobrecarga relacionada ao agendamento e execução da função definida pelo usuário. Vectorize a função definida pelo usuário (ou seja, faça com que ela opere em um lote de entradas de uma vez) e aplique a transformação de batch antes da transformação do map .

Para ilustrar essa boa prática, seu conjunto de dados artificial não é adequado. O atraso de agendamento é de cerca de 10 microssegundos (10e-6 segundos), muito menos do que as dezenas de milissegundos usados ​​no ArtificialDataset e, portanto, seu impacto é difícil de ver.

Para este exemplo, use a função base tf.data.Dataset.range e simplifique o loop de treinamento em sua forma mais simples.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Mapeamento escalar

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.9082538790000854

Tempo de execução de dados - método de mapa escalar

O gráfico acima ilustra o que está acontecendo (com menos amostras) usando o método de mapeamento escalar. Mostra que a função mapeada é aplicada a cada amostra. Embora essa função seja muito rápida, ela tem alguma sobrecarga que afeta o desempenho do tempo.

Mapeamento vetorizado

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.03624614399996062

Tempo de execução de dados - método de mapa vetorizado

Desta vez, a função mapeada é chamada uma vez e se aplica a um lote de amostra. Como mostra o gráfico de tempo de execução de dados, embora a função possa levar mais tempo para ser executada, a sobrecarga aparece apenas uma vez, melhorando o desempenho geral do tempo.

Reduzindo o consumo de memória

Diversas transformações, incluindo interleave , prefetch e shuffle , mantêm um buffer interno de elementos. Se a função definida pelo usuário passada para a transformação do map alterar o tamanho dos elementos, a ordem da transformação do mapa e as transformações que armazenam os elementos do buffer afetam o uso da memória. Em geral, escolha a ordem que resulta em menor consumo de memória, a menos que uma ordem diferente seja desejável para o desempenho.

Cache de cálculos parciais

Recomenda-se armazenar em cache o conjunto de dados após a transformação do map exceto se essa transformação tornar os dados muito grandes para caber na memória. Uma compensação pode ser alcançada se sua função mapeada puder ser dividida em duas partes: uma que consome muito tempo e outra que consome memória. Nesse caso, você pode encadear suas transformações como a seguir:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

Dessa forma, a parte demorada só é executada durante a primeira época e você evita usar muito espaço de cache.

Resumo das melhores práticas

Aqui está um resumo das práticas recomendadas para projetar pipelines de entrada do TensorFlow de alto desempenho:

Reproduzindo as figuras

Para se aprofundar no entendimento da APItf.data.Dataset , você pode brincar com seus próprios pipelines. Abaixo está o código usado para plotar as imagens deste guia. Pode ser um bom ponto de partida, mostrando algumas soluções alternativas para dificuldades comuns, como:

  • Reprodutibilidade do tempo de execução
  • Execução rápida de funções mapeadas
  • transformação de interleave chamável
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

O conjunto de dados

Semelhante ao ArtificialDataset você pode construir um conjunto de dados retornando o tempo gasto em cada etapa.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Este conjunto de dados fornece amostras de forma [[2, 1], [2, 2], [2, 3]] e do tipo [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Cada amostra é:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Onde:

  • Open e Read são identificadores de etapas
  • t0 é o carimbo de data / hora quando a etapa correspondente começou
  • d é o tempo gasto na etapa correspondente
  • i é o índice de instância
  • e é o índice de época (número de vezes que o conjunto de dados foi iterado)
  • s é o índice da amostra

O loop de iteração

Torne o loop de iteração um pouco mais complicado para agregar todos os tempos. Isso só funcionará com conjuntos de dados gerando amostras conforme detalhado acima.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

O método de plotagem

Finalmente, defina uma função capaz de traçar uma linha do tempo dados os valores retornados pela função timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Use wrappers para função mapeada

Para executar a função mapeada em um contexto inicial, você deve envolvê-los dentro de uma chamada tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Comparação de pipelines

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Ingênuo

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From <ipython-input-1-c85330a00c6e>:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From <ipython-input-1-c85330a00c6e>:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 12.445692234000035

Otimizado

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.326935971000012
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png