Catat tanggalnya! Google I / O mengembalikan 18-20 Mei Daftar sekarang
Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Performa lebih baik dengan API tf.data

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Gambaran

GPU dan TPU dapat secara drastis mengurangi waktu yang dibutuhkan untuk menjalankan satu langkah pelatihan. Mencapai kinerja puncak membutuhkan pipa masukan yang efisien yang mengirimkan data untuk langkah berikutnya sebelum langkah saat ini selesai. API tf.data membantu membangun pipeline input yang fleksibel dan efisien. Dokumen ini menunjukkan cara menggunakan tf.data API untuk membuat pipeline input TensorFlow yang berperforma tinggi.

Sebelum melanjutkan, periksa panduan pipeline input Build TensorFlow untuk mempelajari cara menggunakan API tf.data .

Sumber daya

Mempersiapkan

import tensorflow as tf

import time

Dalam panduan ini, Anda akan mengulang di seluruh set data dan mengukur performanya. Membuat tolok ukur kinerja yang dapat direproduksi bisa jadi sulit. Faktor berbeda yang mempengaruhi reproduktifitas meliputi:

  • Beban CPU saat ini
  • Lalu lintas jaringan
  • Mekanisme yang kompleks, seperti cache

Untuk mendapatkan tolok ukur yang dapat direproduksi, Anda akan membuat contoh buatan.

Dataset

Mulailah dengan mendefinisikan kelas yang mewarisi daritf.data.Dataset disebut ArtificialDataset . Kumpulan data ini:

  • Menghasilkan sampel num_samples (default adalah 3)
  • Tidur selama beberapa waktu sebelum item pertama untuk mensimulasikan membuka file
  • Tidur selama beberapa waktu sebelum memproduksi setiap item untuk mensimulasikan pembacaan data dari file
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

tf.data.Dataset.range data ini mirip dengan yang tf.data.Dataset.range , menambahkan penundaan tetap di awal dan di antara setiap sampel.

Lingkaran pelatihan

Selanjutnya, tulis loop pelatihan dummy yang mengukur berapa lama waktu yang dibutuhkan untuk melakukan iterasi atas set data. Waktu pelatihan disimulasikan.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

Optimalkan kinerja

Untuk menunjukkan bagaimana performa dapat dioptimalkan, Anda akan meningkatkan performa dari ArtificialDataset .

Pendekatan naif

Mulailah dengan pipeline naif tanpa trik, lakukan iterasi atas kumpulan data apa adanya.

benchmark(ArtificialDataset())
Execution time: 0.2541472299999441

Di bawah tenda, beginilah waktu eksekusi Anda:

Plot waktu eksekusi data - metode yang naif

Plot menunjukkan bahwa melakukan langkah pelatihan melibatkan:

  • Membuka file jika belum dibuka
  • Mengambil entri data dari file
  • Menggunakan data untuk pelatihan

Namun, dalam implementasi sinkron yang naif seperti di sini, saat pipeline Anda mengambil data, model Anda menganggur. Sebaliknya, saat model Anda sedang dilatih, pipeline input tidak digunakan. Dengan demikian, waktu langkah pelatihan adalah jumlah waktu pembukaan, membaca, dan pelatihan.

Bagian selanjutnya mengembangkan pipeline input ini, yang menggambarkan praktik terbaik untuk mendesain pipeline input TensorFlow yang berperforma baik.

Mengambil lebih dulu

Prapengambilan tumpang tindih dengan pra-pemrosesan dan eksekusi model dari langkah pelatihan. Sementara model mengeksekusi langkah pelatihan s , pipa input membaca data untuk langkah s+1 . Melakukannya akan mengurangi waktu langkah ke maksimum (sebagai lawan dari jumlah) pelatihan dan waktu yang dibutuhkan untuk mengekstrak data.

API tf.data menyediakan transformasi tf.data.Dataset.prefetch . Ini dapat digunakan untuk memisahkan waktu saat data dihasilkan dari saat data digunakan. Secara khusus, transformasi menggunakan thread latar belakang dan buffer internal untuk mengambil elemen dari dataset input sebelum mereka diminta. Jumlah elemen yang akan diambil sebelumnya harus sama dengan (atau mungkin lebih besar dari) jumlah batch yang dikonsumsi oleh satu langkah pelatihan. Anda dapat menyetel nilai ini secara manual, atau menyetelnya ke tf.data.AUTOTUNE , yang akan meminta waktu proses tf.data menyetel nilai secara dinamis pada waktu proses.

Perhatikan bahwa transformasi prefetch memberikan manfaat setiap kali ada peluang untuk tumpang tindih antara pekerjaan "produsen" dengan pekerjaan "konsumen".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.20805208699994182

Plot waktu eksekusi data - metode prefetching

Sekarang, seperti yang ditunjukkan plot waktu eksekusi data, saat langkah pelatihan dijalankan untuk sampel 0, pipeline input sedang membaca data untuk sampel 1, dan seterusnya.

Paralelisasi ekstraksi data

Dalam setelan dunia nyata, data masukan dapat disimpan dari jarak jauh (misalnya, di Google Cloud Storage atau HDFS). Pipeline set data yang berfungsi dengan baik saat membaca data secara lokal mungkin mengalami hambatan di I / O saat membaca data dari jarak jauh karena perbedaan berikut antara penyimpanan lokal dan jarak jauh:

  • Waktu-ke-byte pertama : Membaca byte pertama dari sebuah file dari penyimpanan jarak jauh dapat memakan waktu lebih lama daripada dari penyimpanan lokal.
  • Throughput baca : Meskipun penyimpanan jarak jauh biasanya menawarkan bandwidth agregat yang besar, membaca satu file mungkin hanya dapat menggunakan sebagian kecil dari bandwidth ini.

Selain itu, setelah byte mentah dimuat ke dalam memori, mungkin juga diperlukan deserialisasi dan / atau dekripsi data (misalnya protobuf ), yang memerlukan komputasi tambahan. Overhead ini muncul terlepas dari apakah data disimpan secara lokal atau jarak jauh, tetapi bisa lebih buruk dalam kasus jarak jauh jika data tidak diambil secara efektif.

Untuk mengurangi dampak dari berbagai overhead ekstraksi data, transformasi tf.data.Dataset.interleave dapat digunakan untuk memparalelkan langkah pemuatan data, tf.data.Dataset.interleave data lain (seperti pembaca file data). Jumlah cycle_length data yang tumpang tindih dapat ditentukan dengan argumen cycle_length , sedangkan tingkat paralelisme dapat ditentukan dengan argumen num_parallel_calls . Mirip dengan transformasi prefetch , transformasi interleave mendukung tf.data.AUTOTUNE , yang akan mendelegasikan keputusan tentang tingkat paralelisme yang akan digunakan untuk runtime tf.data .

Interleave berurutan

Argumen default dari transformasi tf.data.Dataset.interleave membuatnya interleave sampel tunggal dari dua tf.data.Dataset.interleave data secara berurutan.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4883518669998921

Plot waktu eksekusi data - interleave berurutan

Plot waktu eksekusi data ini memungkinkan untuk menunjukkan perilaku transformasi interleave , mengambil sampel sebagai alternatif dari dua set data yang tersedia. Namun, tidak ada peningkatan kinerja yang terlibat di sini.

Interleave paralel

Sekarang, gunakan argumen num_parallel_calls dari transformasi interleave . Ini memuat banyak set data secara paralel, mengurangi waktu menunggu file dibuka.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.26920967700016263

Plot waktu eksekusi data - metode interleave paralel

Kali ini, seperti yang ditunjukkan plot waktu eksekusi data, pembacaan kedua dataset diparalelkan, mengurangi waktu pemrosesan data global.

Memparalelkan transformasi data

Saat menyiapkan data, elemen masukan mungkin perlu diproses sebelumnya. Untuk tujuan ini, API tf.data menawarkan transformasi tf.data.Dataset.map , yang menerapkan fungsi yang ditentukan pengguna ke setiap elemen tf.data.Dataset.map data masukan. Karena elemen input tidak bergantung satu sama lain, pra-pemrosesan dapat diparalelkan di beberapa inti CPU. Untuk memungkinkan hal ini, mirip dengan transformasi prefetch dan interleave , transformasi map menyediakan argumen num_parallel_calls untuk menentukan tingkat paralelisme.

Memilih nilai terbaik untuk argumen num_parallel_calls bergantung pada perangkat keras Anda, karakteristik data pelatihan Anda (seperti ukuran dan bentuknya), biaya fungsi peta Anda, dan pemrosesan lain apa yang terjadi pada CPU pada waktu yang sama. Heuristik sederhana adalah menggunakan jumlah inti CPU yang tersedia. Namun, untuk transformasi prefetch dan interleave , transformasi map mendukung tf.data.AUTOTUNE yang akan mendelegasikan keputusan tentang tingkat paralelisme yang akan digunakan untuk runtime tf.data .

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

Pemetaan berurutan

Mulailah dengan menggunakan transformasi map tanpa paralelisme sebagai contoh dasar.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4379127629999857

Plot waktu eksekusi data - metode pemetaan sekuensial

Adapun pendekatan naif , di sini, seperti yang ditunjukkan plot, waktu yang dihabiskan untuk membuka, membaca, pra-pemrosesan (pemetaan) dan langkah-langkah pelatihan dijumlahkan untuk satu iterasi.

Pemetaan paralel

Sekarang, gunakan fungsi pra-pemrosesan yang sama tetapi terapkan secara paralel pada beberapa sampel.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2747970279999663

Waktu eksekusi data - pemetaan paralel

Seperti yang ditunjukkan plot data, langkah-langkah pra-pemrosesan tumpang tindih, mengurangi waktu keseluruhan untuk satu iterasi.

Caching

Transformasi tf.data.Dataset.cache dapat tf.data.Dataset.cache -cache tf.data.Dataset.cache data, baik di memori atau di penyimpanan lokal. Ini akan menghemat beberapa operasi (seperti pembukaan file dan pembacaan data) agar tidak dijalankan selama setiap periode.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3715158390000397

Waktu eksekusi data - metode set data yang di-cache

Di sini, plot waktu eksekusi data menunjukkan bahwa ketika Anda meng-cache set data, transformasi sebelum cache (seperti pembukaan file dan pembacaan data) dijalankan hanya selama epoch pertama. Epoch berikutnya akan menggunakan kembali data yang di-cache oleh transformasi cache .

Jika fungsi yang ditentukan pengguna yang diteruskan ke transformasi map mahal, terapkan transformasi cache setelah transformasi map selama dataset yang dihasilkan masih bisa masuk ke dalam memori atau penyimpanan lokal. Jika fungsi yang ditentukan pengguna meningkatkan ruang yang diperlukan untuk menyimpan set data di luar kapasitas cache, terapkan setelah transformasi cache atau pertimbangkan pemrosesan awal data Anda sebelum tugas pelatihan Anda untuk mengurangi penggunaan sumber daya.

Pemetaan vektorisasi

Memanggil fungsi yang ditentukan pengguna yang diteruskan ke transformasi map memiliki overhead terkait dengan penjadwalan dan pelaksanaan fungsi yang ditentukan pengguna. Lakukan vektorisasi fungsi yang ditentukan pengguna (yaitu, operasikan pada batch input sekaligus) dan terapkan transformasi batch sebelum transformasi map .

Untuk menggambarkan praktik yang baik ini, kumpulan data buatan Anda tidak sesuai. Penundaan penjadwalan adalah sekitar 10 mikrodetik (10e-6 detik), jauh lebih kecil dari puluhan milidetik yang digunakan dalam ArtificialDataset , dan dengan demikian dampaknya sulit dilihat.

Untuk contoh ini, gunakan fungsi dasar tf.data.Dataset.range dan sederhanakan loop pelatihan ke bentuk yang paling sederhana.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

Pemetaan skalar

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.9082538790000854

Waktu eksekusi data - metode peta skalar

Plot di atas menggambarkan apa yang sedang terjadi (dengan lebih sedikit sampel) menggunakan metode pemetaan skalar. Ini menunjukkan bahwa fungsi yang dipetakan diterapkan untuk setiap sampel. Meskipun fungsi ini sangat cepat, ada beberapa overhead yang mempengaruhi kinerja waktu.

Pemetaan vektor

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.03624614399996062

Waktu eksekusi data - metode peta vektor

Kali ini, fungsi yang dipetakan dipanggil sekali dan berlaku untuk sekumpulan sampel. Seperti yang ditunjukkan plot waktu eksekusi data, sementara fungsi dapat membutuhkan lebih banyak waktu untuk dieksekusi, overhead hanya muncul sekali, meningkatkan kinerja waktu secara keseluruhan.

Mengurangi jejak memori

Sejumlah transformasi, termasuk interleave , prefetch , dan shuffle , mempertahankan buffer internal elemen. Jika fungsi yang ditentukan pengguna diteruskan ke transformasi map mengubah ukuran elemen, maka urutan transformasi peta dan transformasi yang elemen penyangga mempengaruhi penggunaan memori. Secara umum, pilih urutan yang menghasilkan footprint memori yang lebih rendah, kecuali jika urutan yang berbeda diinginkan untuk performa.

Caching komputasi parsial

Direkomendasikan untuk menyimpan set data ke dalam cache setelah transformasi map kecuali jika transformasi ini membuat data terlalu besar untuk dimasukkan ke dalam memori. Pertukaran dapat dicapai jika fungsi yang dipetakan dapat dibagi menjadi dua bagian: satu bagian yang memakan waktu dan bagian yang memakan memori. Dalam hal ini, Anda dapat merangkai transformasi Anda seperti di bawah ini:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

Dengan cara ini, bagian yang memakan waktu hanya dijalankan selama epoch pertama, dan Anda menghindari penggunaan terlalu banyak ruang cache.

Ringkasan praktik terbaik

Berikut adalah ringkasan praktik terbaik untuk mendesain pipeline input TensorFlow yang berperforma baik:

Mereproduksi angka-angka itu

Untuk lebih mendalami pemahaman APItf.data.Dataset , Anda bisa bermain-main dengan pipeline Anda sendiri. Di bawah ini adalah kode yang digunakan untuk memplot gambar dari panduan ini. Ini bisa menjadi titik awal yang baik, menunjukkan beberapa solusi untuk kesulitan umum seperti:

  • Reproduksibilitas waktu eksekusi
  • Fungsi yang dipetakan ingin dieksekusi
  • transformasi interleave callable
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

Dataset

Mirip dengan ArtificialDataset Anda dapat membuat set data yang mengembalikan waktu yang dihabiskan di setiap langkah.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

Kumpulan data ini memberikan contoh bentuk [[2, 1], [2, 2], [2, 3]] dan jenis [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . Setiap sampel adalah:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

Dimana:

  • Open dan Read adalah pengenal langkah
  • t0 adalah stempel waktu saat langkah terkait dimulai
  • d adalah waktu yang dihabiskan untuk langkah yang sesuai
  • i adalah indeks instance
  • e adalah indeks epoch (berapa kali dataset diiterasi)
  • s adalah indeks sampel

Loop iterasi

Buat loop iterasi sedikit lebih rumit untuk menggabungkan semua pengaturan waktu. Ini hanya akan berfungsi dengan kumpulan data yang menghasilkan sampel seperti yang dijelaskan di atas.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

Metode plot

Terakhir, tentukan fungsi yang dapat memplot garis waktu dengan nilai yang dikembalikan oleh fungsi timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

Gunakan pembungkus untuk fungsi yang dipetakan

Untuk menjalankan fungsi yang dipetakan dalam konteks yang menarik, Anda harus membungkusnya di dalam panggilan tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

Perbandingan saluran pipa

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

Naif

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From <ipython-input-1-c85330a00c6e>:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From <ipython-input-1-c85330a00c6e>:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 12.445692234000035

Dioptimalkan

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.326935971000012
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png