Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Performa lebih baik dengan API tf.data

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Gambaran

GPU dan TPU dapat secara radikal mengurangi waktu yang dibutuhkan untuk menjalankan satu langkah pelatihan. Mencapai kinerja puncak membutuhkan pipa masukan efisien yang mengirimkan data untuk langkah berikutnya sebelum langkah saat ini selesai. API tf.data membantu membangun pipeline input yang fleksibel dan efisien. Dokumen ini menunjukkan cara menggunakan tf.data API untuk membuat pipeline masukan TensorFlow yang berperforma tinggi.

Sebelum melanjutkan, baca panduan " Membuat pipeline input TensorFlow ", untuk mempelajari cara menggunakan API tf.data .

Sumber daya

Mempersiapkan

import tensorflow as tf

import time

Dalam panduan ini, Anda akan mengulang di seluruh set data dan mengukur performanya. Sulit untuk membuat tolok ukur kinerja yang dapat direproduksi, karena ada berbagai faktor yang memengaruhinya:

  • beban CPU saat ini,
  • lalu lintas jaringan,
  • mekanisme kompleks seperti cache, dll.

Karenanya, untuk memberikan tolok ukur yang dapat direproduksi, buat contoh buatan.

Dataset

Tentukan kelas yang mewarisi dari tf.data.Dataset disebut ArtificialDataset . Kumpulan data ini:

  • menghasilkan sampel num_samples (standarnya 3)
  • tidur untuk beberapa waktu sebelum item pertama untuk mensimulasikan membuka file
  • tertidur selama beberapa waktu sebelum memproduksi setiap item untuk mensimulasikan pembacaan data dari sebuah file
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)
            
            yield (sample_idx,)
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=tf.dtypes.int64,
            output_shapes=(1,),
            args=(num_samples,)
        )

tf.data.Dataset.range data ini mirip dengan kumpulan data tf.data.Dataset.range , menambahkan penundaan tetap di awal dan di antara setiap sampel.

Lingkaran pelatihan

Tulis loop pelatihan dummy yang mengukur berapa lama waktu yang dibutuhkan untuk melakukan iterasi atas set data. Waktu pelatihan disimulasikan.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    tf.print("Execution time:", time.perf_counter() - start_time)

Optimalkan kinerja

Untuk menunjukkan bagaimana performa bisa dioptimalkan, Anda akan meningkatkan performa dari ArtificialDataset .

Pendekatan naif

Mulailah dengan pipeline naif tanpa trik, lakukan iterasi atas kumpulan data apa adanya.

benchmark(ArtificialDataset())
Execution time: 0.2530532629998561

Di balik terpal, beginilah waktu eksekusi Anda:

Naif

Anda dapat melihat bahwa melakukan langkah pelatihan melibatkan:

  • membuka file jika belum dibuka,
  • mengambil entri data dari file,
  • menggunakan data untuk pelatihan.

Namun, dalam implementasi sinkron yang naif seperti di sini, saat pipeline Anda mengambil data, model Anda menganggur. Sebaliknya, saat model Anda sedang dilatih, pipeline input tidak digunakan. Dengan demikian, waktu langkah pelatihan adalah jumlah dari semua, waktu pembukaan, membaca dan pelatihan.

Bagian selanjutnya mengembangkan pipeline input ini, yang menggambarkan praktik terbaik untuk mendesain pipeline input TensorFlow yang berperforma baik.

Mengambil lebih dulu

Prapengambilan tumpang tindih dengan pra-pemrosesan dan eksekusi model dari langkah pelatihan. Sementara model mengeksekusi langkah pelatihan s , pipa input membaca data untuk langkah s+1 . Melakukannya mengurangi waktu langkah ke maksimum (sebagai lawan dari jumlah) pelatihan dan waktu yang diperlukan untuk mengekstrak data.

API tf.data menyediakan transformasi tf.data.Dataset.prefetch . Ini dapat digunakan untuk memisahkan waktu saat data diproduksi dari saat data digunakan. Secara khusus, transformasi menggunakan thread latar belakang dan buffer internal untuk mengambil lebih dulu elemen dari dataset input sebelum mereka diminta. Jumlah elemen yang akan diambil sebelumnya harus sama dengan (atau mungkin lebih besar dari) jumlah batch yang dikonsumsi oleh satu langkah pelatihan. Anda dapat menyetel nilai ini secara manual, atau menyetelnya ke tf.data.experimental.AUTOTUNE yang akan meminta waktu proses tf.data menyetel nilai secara dinamis pada waktu proses.

Perhatikan bahwa transformasi prefetch memberikan manfaat setiap kali ada peluang untuk tumpang tindih antara pekerjaan "produsen" dengan pekerjaan "konsumen".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Execution time: 0.20858672200006367

Sudah diambil sebelumnya

Kali ini Anda dapat melihat bahwa saat langkah pelatihan dijalankan untuk sampel 0, pipeline input sedang membaca data untuk sampel 1, dan seterusnya.

Paralelisasi ekstraksi data

Dalam pengaturan dunia nyata, data masukan dapat disimpan dari jarak jauh (misalnya, GCS atau HDFS). Pipeline set data yang berfungsi dengan baik saat membaca data secara lokal mungkin mengalami hambatan di I / O saat membaca data dari jarak jauh karena perbedaan berikut antara penyimpanan lokal dan jarak jauh:

  • Waktu-ke-byte pertama: Membaca byte pertama dari sebuah file dari penyimpanan jauh dapat memakan waktu lipat lebih lama daripada dari penyimpanan lokal.
  • Throughput baca: Meskipun penyimpanan jarak jauh biasanya menawarkan bandwidth agregat yang besar, membaca satu file mungkin hanya dapat memanfaatkan sebagian kecil bandwidth ini.

Selain itu, setelah byte mentah dimuat ke dalam memori, mungkin juga diperlukan deserialisasi dan / atau dekripsi data (misalnya protobuf ), yang memerlukan komputasi tambahan. Overhead ini muncul terlepas dari apakah data disimpan secara lokal atau jarak jauh, tetapi bisa lebih buruk dalam kasus jarak jauh jika data tidak diambil secara efektif.

Untuk mengurangi dampak dari berbagai overhead ekstraksi data, transformasi tf.data.Dataset.interleave dapat digunakan untuk memparalelkan langkah pemuatan data, tf.data.Dataset.interleave data lain (seperti pembaca file data). Jumlah cycle_length data yang tumpang tindih dapat ditentukan dengan argumen cycle_length , sedangkan tingkat paralelisme dapat ditentukan dengan argumen num_parallel_calls . Mirip dengan transformasi prefetch , transformasi interleave mendukung tf.data.experimental.AUTOTUNE yang akan mendelegasikan keputusan tentang tingkat paralelisme apa yang akan digunakan ke runtime tf.data .

Interleave berurutan

Argumen default transformasi tf.data.Dataset.interleave membuatnya interleave sampel tunggal dari dua tf.data.Dataset.interleave data secara berurutan.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(ArtificialDataset)
)
Execution time: 0.2373930549999841

Interleave berurutan

Plot ini memungkinkan untuk menunjukkan perilaku transformasi interleave , mengambil sampel sebagai alternatif dari dua dataset yang tersedia. Namun, tidak ada peningkatan kinerja yang terlibat di sini.

Interleave paralel

Sekarang gunakan argumen num_parallel_calls dari transformasi interleave . Ini memuat banyak set data secara paralel, mengurangi waktu menunggu file dibuka.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        ArtificialDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
Execution time: 0.1730301249999684

Interleave paralel

Kali ini, pembacaan kedua dataset diparalelkan, sehingga mengurangi waktu pemrosesan data global.

Paralelisasi transformasi data

Saat menyiapkan data, elemen masukan mungkin perlu diproses sebelumnya. Untuk tujuan ini, API tf.data menawarkan transformasi tf.data.Dataset.map , yang menerapkan fungsi yang ditentukan pengguna ke setiap elemen tf.data.Dataset.map data masukan. Karena elemen input tidak bergantung satu sama lain, pra-pemrosesan dapat diparalelkan di beberapa inti CPU. Untuk memungkinkan hal ini, sama seperti transformasi prefetch dan interleave , transformasi map menyediakan argumen num_parallel_calls untuk menentukan tingkat paralelisme.

Memilih nilai terbaik untuk argumen num_parallel_calls bergantung pada perangkat keras Anda, karakteristik data pelatihan Anda (seperti ukuran dan bentuknya), biaya fungsi peta Anda, dan pemrosesan lain apa yang terjadi pada CPU pada waktu yang sama. Heuristik sederhana adalah dengan menggunakan jumlah inti CPU yang tersedia. Namun, untuk transformasi prefetch dan interleave , transformasi map mendukung tf.data.experimental.AUTOTUNE yang akan mendelegasikan keputusan tentang tingkat paralelisme apa yang akan digunakan untuk runtime tf.data .

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Pemetaan berurutan

Mulailah dengan menggunakan transformasi map tanpa paralelisme sebagai contoh dasar.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.43913738300011573

Pemetaan berurutan

Sedangkan untuk pendekatan naif , di sini waktu yang dihabiskan untuk membuka, membaca, pra-pemrosesan (pemetaan), dan langkah-langkah pelatihan dijumlahkan untuk satu iterasi.

Pemetaan paralel

Sekarang, gunakan fungsi pra-pemrosesan yang sama tetapi terapkan secara paralel pada beberapa sampel.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
Execution time: 0.2730358689998411

Pemetaan paralel

Sekarang, Anda dapat melihat di plot bahwa langkah-langkah pra-pemrosesan tumpang tindih, mengurangi waktu keseluruhan untuk satu iterasi.

Caching

Transformasi tf.data.Dataset.cache bisa tf.data.Dataset.cache -cache tf.data.Dataset.cache data, baik di memori atau di penyimpanan lokal. Ini akan menghemat beberapa operasi (seperti pembukaan file dan pembacaan data) agar tidak dijalankan selama setiap periode.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.36568501300007483

Set data dalam cache

Saat Anda menyimpan set data ke cache, transformasi sebelum cache (seperti pembukaan file dan pembacaan data) dijalankan hanya selama epoch pertama. Epoch berikutnya akan menggunakan kembali data yang di-cache oleh transformasi cache .

Jika fungsi yang ditentukan pengguna yang diteruskan ke transformasi map mahal, terapkan transformasi cache setelah transformasi map selama dataset yang dihasilkan masih bisa masuk ke dalam memori atau penyimpanan lokal. Jika fungsi yang ditentukan pengguna meningkatkan ruang yang diperlukan untuk menyimpan set data di luar kapasitas cache, terapkan setelah transformasi cache atau pertimbangkan pemrosesan awal data Anda sebelum tugas pelatihan Anda untuk mengurangi penggunaan sumber daya.

Pemetaan vektorisasi

Memanggil fungsi yang ditentukan pengguna yang diteruskan ke transformasi map memiliki overhead terkait dengan penjadwalan dan pelaksanaan fungsi yang ditentukan pengguna. Kami merekomendasikan memvektorisasi fungsi yang ditentukan pengguna (yaitu, membuatnya beroperasi pada sejumlah input sekaligus) dan menerapkan transformasi batch sebelum transformasi map .

Untuk mengilustrasikan praktik yang baik ini, kumpulan data buatan Anda tidak sesuai. Penundaan penjadwalan adalah sekitar 10 mikrodetik (10e-6 detik), jauh lebih kecil dari puluhan milidetik yang digunakan dalam Set Data ArtificialDataset , dan dengan demikian dampaknya sulit dilihat.

Untuk contoh ini, gunakan fungsi dasar tf.data.Dataset.range dan sederhanakan loop pelatihan ke bentuk yang paling sederhana.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)
    
def increment(x):
    return x+1

Pemetaan skalar

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.8861004689999845

Peta skalar

Plot di atas menggambarkan apa yang sedang terjadi (dengan lebih sedikit sampel). Anda dapat melihat bahwa fungsi yang dipetakan diterapkan untuk setiap sampel. Meskipun fungsi ini sangat cepat, ia memiliki beberapa overhead yang mempengaruhi kinerja waktu.

Pemetaan vektor

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.032729552000091644

Peta vektor

Kali ini, fungsi yang dipetakan dipanggil sekali dan berlaku untuk sekumpulan sampel. Meskipun fungsi tersebut membutuhkan lebih banyak waktu untuk dijalankan, overhead hanya muncul sekali, meningkatkan kinerja waktu secara keseluruhan.

Mengurangi jejak memori

Sejumlah transformasi, termasuk interleave , prefetch , dan shuffle , mempertahankan buffer internal elemen. Jika fungsi yang ditentukan pengguna diteruskan ke transformasi map mengubah ukuran elemen, maka urutan transformasi peta dan transformasi yang elemen penyangga mempengaruhi penggunaan memori. Secara umum, kami menganjurkan memilih urutan yang menghasilkan footprint memori yang lebih rendah, kecuali jika urutan yang berbeda diinginkan untuk performa.

Caching komputasi parsial

Direkomendasikan untuk menyimpan dataset setelah transformasi map kecuali jika transformasi ini membuat data terlalu besar untuk muat dalam memori. Pertukaran dapat dicapai jika fungsi yang dipetakan dapat dibagi menjadi dua bagian: satu bagian yang memakan waktu dan bagian yang memakan memori. Dalam hal ini, Anda dapat merangkai transformasi Anda seperti di bawah ini:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

Dengan cara ini, bagian yang memakan waktu hanya dijalankan selama epoch pertama, dan Anda menghindari penggunaan terlalu banyak ruang cache.

Ringkasan praktik terbaik

Berikut adalah ringkasan praktik terbaik untuk mendesain pipeline input TensorFlow yang berperforma baik:

Mereproduksi angka

Untuk lebih mendalami pemahaman API tf.data.Dataset , Anda bisa bermain-main dengan pipeline Anda sendiri. Di bawah ini adalah kode yang digunakan untuk memplot gambar dari panduan ini. Ini bisa menjadi titik awal yang baik, menunjukkan beberapa solusi untuk kesulitan umum seperti:

  • Reproduksibilitas waktu eksekusi;
  • Fungsi yang dipetakan ingin dieksekusi;
  • transformasi interleave callable.
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Dataset

Mirip dengan ArtificialDataset Anda dapat membuat set data yang mengembalikan waktu yang dihabiskan di setiap langkah.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
    
    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset
    
    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
        
        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter
            
            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered
            
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Kumpulan data ini memberikan contoh bentuk [[2, 1], [2, 2], [2, 3]] dan jenis [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Setiap sampel adalah:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Dimana:

  • Open dan Read adalah pengenal langkah
  • t0 adalah stempel waktu saat langkah terkait dimulai
  • d adalah waktu yang dihabiskan untuk langkah yang sesuai
  • i adalah indeks instance
  • e adalah indeks epoch (berapa kali dataset diiterasi)
  • s adalah indeks sampel

Loop iterasi

Jadikan loop iterasi sedikit lebih rumit untuk menggabungkan semua pengaturan waktu. Ini hanya akan berfungsi dengan kumpulan data yang menghasilkan sampel seperti yang dijelaskan di atas.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
    
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)
            
            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter
            
            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
        
        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)
    
    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

Metode plot

Terakhir, tentukan fungsi yang bisa memplot garis waktu dengan nilai yang dikembalikan oleh fungsi timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()
    
    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
    
    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)
    
    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")
        
        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]
        
        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Gunakan pembungkus untuk fungsi yang dipetakan

Untuk menjalankan fungsi yang dipetakan dalam konteks yang menarik, Anda harus membungkusnya di dalam panggilan tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Perbandingan saluran pipa

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Naif

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
Execution time: 12.436093607999965

Dioptimalkan

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.experimental.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.303204500999982

draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png