日付を保存! Google I / Oが5月18日から20日に戻ってきます今すぐ登録
このページは Cloud Translation API によって翻訳されました。
Switch to English

tf.dataAPIによるパフォーマンスの向上

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロード

概要概要

GPUとTPUは、単一のトレーニングステップの実行に必要な時間を大幅に短縮できます。ピークパフォーマンスを達成するには、現在のステップが完了する前に次のステップのデータを配信する効率的な入力パイプラインが必要です。 tf.data APIは、柔軟で効率的な入力パイプラインの構築に役立ちます。このドキュメントでは、 tf.data APIを使用して、パフォーマンスの高いTensorFlow入力パイプラインを構築する方法を示します。

続行する前に、 Build TensorFlow入力パイプラインガイドを確認して、 tf.data使用方法を確認してください。

リソース

セットアップ

import tensorflow as tf

import time

このガイド全体を通して、データセット全体を反復処理し、パフォーマンスを測定します。再現性のあるパフォーマンスベンチマークを作成するのは難しい場合があります。再現性に影響を与えるさまざまな要因は次のとおりです。

  • 現在のCPU負荷
  • ネットワークトラフィック
  • キャッシュなどの複雑なメカニズム

再現可能なベンチマークを取得するには、人工的な例を作成します。

データセット

ArtificialDatasetと呼ばれるtf.data.Datasetから継承するクラスを定義することからtf.data.Datasetます。このデータセット:

  • num_samplesサンプルを生成します(デフォルトは3)
  • ファイルを開くことをシミュレートするために、最初のアイテムの前にしばらくスリープします
  • ファイルからのデータの読み取りをシミュレートするために、各アイテムを生成する前にしばらくスリープします
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

このデータセットはtf.data.Dataset.rangeデータセットに似ていtf.data.Dataset.range 、各サンプルの開始時とその間に固定の遅延が追加されています。

トレーニングループ

次に、データセットを反復処理するのにかかる時間を測定するダミーのトレーニングループを記述します。トレーニング時間はシミュレートされます。

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

パフォーマンスを最適化する

パフォーマンスを最適化する方法を示すために、 ArtificialDatasetパフォーマンスを向上させます。

素朴なアプローチ

トリックを使用せずに単純なパイプラインから開始し、データセットをそのまま繰り返します。

benchmark(ArtificialDataset())
Execution time: 0.2541472299999441

内部的には、次のように実行時間が費やされました。

データ実行時間プロット-素朴な方法

プロットは、トレーニングステップの実行には以下が含まれることを示しています。

  • まだ開いていない場合はファイルを開く
  • ファイルからのデータエントリの取得
  • トレーニングのためのデータの使用

ただし、ここのような単純な同期実装では、パイプラインがデータをフェッチしている間、モデルはアイドル状態になっています。逆に、モデルがトレーニングしている間、入力パイプラインはアイドル状態になっています。したがって、トレーニングステップ時間は、開封時間、読書時間、トレーニング時間の合計です。

次のセクションでは、この入力パイプラインに基づいて、パフォーマンスの高いTensorFlow入力パイプラインを設計するためのベストプラクティスを示します。

プリフェッチ

プリフェッチは、トレーニングステップの前処理とモデル実行と重複します。モデルがトレーニングステップs実行している間、入力パイプラインはステップs+1データを読み取っています。そうすることで、ステップ時間がトレーニングの最大値(合計ではなく)に短縮され、データの抽出にかかる時間が短縮されます。

tf.data APIは、 tf.data.Dataset.prefetch変換を提供します。これは、データが生成される時間とデータが消費される時間を分離するために使用できます。特に、変換では、バックグラウンドスレッドと内部バッファーを使用して、要求される前に入力データセットから要素をプリフェッチします。プリフェッチする要素の数は、単一のトレーニングステップで消費されるバッチの数と同じ(またはそれ以上)である必要があります。この値を手動で調整するか、 tf.data.AUTOTUNEに設定すると、実行時に値を動的に調整するようにtf.dataランタイムにtf.dataれます。

プリフェッチ変換は、「プロデューサー」の作業と「コンシューマー」の作業をオーバーラップする機会があるときはいつでも利点を提供することに注意してください。

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.20805208699994182

データ実行時間プロット-プリフェッチ方法

ここで、データ実行時間のプロットが示すように、トレーニングステップがサンプル0に対して実行されている間、入力パイプラインはサンプル1のデータを読み取っています。

データ抽出の並列化

実際の設定では、入力データはリモートで保存される場合があります(たとえば、Google Cloud StorageまたはHDFSに)。ローカルストレージとリモートストレージの違いにより、ローカルでデータを読み取るときに適切に機能するデータセットパイプラインは、リモートでデータを読み取るときにI / Oでボトルネックになる可能性があります。

  • 最初のバイトまでの時間:リモートストレージからファイルの最初のバイトを読み取るには、ローカルストレージからの読み取りよりも桁違いに時間がかかる場合があります。
  • 読み取りスループット:リモートストレージは通常、大きな総帯域幅を提供しますが、単一のファイルを読み取ると、この帯域幅のごく一部しか利用できない場合があります。

さらに、生のバイトがメモリにロードされると、データを逆シリアル化および/または復号化する必要がある場合があり( protobufなど)、追加の計算が必要になります。このオーバーヘッドは、データがローカルに保存されているかリモートに保存されているかに関係なく存在しますが、データが効果的にプリフェッチされない場合、リモートの場合はさらに悪化する可能性があります。

さまざまなデータ抽出オーバーヘッドの影響を軽減するために、 tf.data.Dataset.interleaveトランスフォーメーションを使用して、データロードステップを並列化し、他のデータセット(データファイルリーダーなど)のコンテンツをインターリーブできます。オーバーラップするデータセットの数はcycle_length引数で指定でき、並列処理のレベルはnum_parallel_calls引数で指定できます。 prefetch変換と同様に、 interleave変換はtf.data.AUTOTUNEサポートします。これにより、使用する並列処理のレベルに関する決定がtf.dataランタイムにtf.dataます。

シーケンシャルインターリーブ

tf.data.Dataset.interleave変換のデフォルトの引数により、2つのデータセットからの単一のサンプルが順番にインターリーブされます。

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4883518669998921

データ実行時間プロット-シーケンシャルインターリーブ

このデータ実行時間プロットにより、 interleave変換の動作を示し、使用可能な2つのデータセットからサンプルを交互にフェッチできます。ただし、ここではパフォーマンスの向上は含まれていません。

並列インターリーブ

ここで、 interleave変換のnum_parallel_calls引数を使用します。これにより、複数のデータセットが並行して読み込まれ、ファイルが開かれるのを待つ時間が短縮されます。

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.26920967700016263

データ実行時間プロット-並列インターリーブ方式

今回は、データ実行時間のプロットが示すように、2つのデータセットの読み取りが並列化され、グローバルデータ処理時間が短縮されます。

データ変換の並列化

データを準備するとき、入力要素を前処理する必要がある場合があります。この目的のために、 tf.data APIはtf.data.Dataset.map変換を提供します。これは、入力データセットの各要素にユーザー定義関数を適用します。入力要素は互いに独立しているため、前処理を複数のCPUコア間で並列化できます。これを可能にするために、 prefetchおよびinterleave変換と同様に、 map変換は、並列処理のレベルを指定するnum_parallel_calls引数を提供します。

num_parallel_calls引数に最適な値を選択するかどうかは、ハードウェア、トレーニングデータの特性(サイズや形状など)、マップ関数のコスト、およびCPUで同時に発生している他の処理によって異なります。単純なヒューリスティックは、使用可能なCPUコアの数を使用することです。ただし、 prefetchおよびinterleave変換に関しては、 map変換はtf.data.AUTOTUNEをサポートします。これにより、使用する並列処理のレベルに関する決定がtf.dataランタイムにtf.dataます。

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

シーケンシャルマッピング

ベースラインの例として、並列処理なしのmap変換を使用することから始めます。

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4379127629999857

データ実行時間プロット-シーケンシャルマッピング方法

素朴なアプローチに関しては、ここでは、プロットが示すように、開く、読み取る、前処理(マッピング)、およびトレーニングの各ステップに費やされる時間が1回の反復で合計されます。

並列マッピング

ここで、同じ前処理関数を使用しますが、複数のサンプルに並行して適用します。

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2747970279999663

データ実行時間-並列マッピング

データプロットが示すように、前処理ステップが重複し、1回の反復の全体的な時間が短縮されます。

キャッシング

tf.data.Dataset.cacheトランスフォーメーションは、メモリまたはローカルストレージのいずれかにデータセットをキャッシュできます。これにより、一部の操作(ファイルのオープンやデータの読み取りなど)が各エポック中に実行されるのを防ぐことができます。

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3715158390000397

データ実行時間-キャッシュされたデータセットメソッド

ここで、データ実行時間のプロットは、データセットをキャッシュすると、 cache前の変換(ファイルのオープンやデータの読み取りなど)が最初のエポックでのみ実行されることを示しています。次のエポックは、 cache変換によってcacheされたデータを再利用します。

map変換に渡されるユーザー定義関数のコストが高い場合は、結果のデータセットがメモリまたはローカルストレージに収まる限り、 map変換の後にcache変換を適用します。ユーザー定義関数によって、データセットを保存するために必要なスペースがキャッシュ容量を超えて増加する場合は、 cache変換後にデータセットを適用するか、トレーニングジョブの前にデータを前処理してリソース使用量を削減することを検討してください。

マッピングのベクトル化

map変換に渡されたユーザー定義関数を呼び出すと、ユーザー定義関数のスケジューリングと実行に関連するオーバーヘッドが発生します。ユーザー定義関数をベクトル化し(つまり、入力のバッチに対して一度に操作するようにします)、 map変換の前にbatch変換を適用します。

このグッドプラクティスを説明するために、人工データセットは適切ではありません。スケジューリングの遅延は約10マイクロ秒(10e-6秒)であり、 ArtificialDatasetで使用される数十ミリ秒よりはるかに短いため、その影響はわかりにくいです。

この例では、基本のtf.data.Dataset.range関数を使用して、トレーニングループを最も単純な形式に単純化します。

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

スカラーマッピング

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.9082538790000854

データ実行時間-スカラーマップ方式

上のプロットは、スカラーマッピング方法を使用して(より少ないサンプルで)何が起こっているかを示しています。これは、マップされた関数が各サンプルに適用されていることを示しています。この関数は非常に高速ですが、時間のパフォーマンスに影響を与えるオーバーヘッドがあります。

ベクトル化されたマッピング

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.03624614399996062

データ実行時間-ベクトル化されたマップメソッド

今回は、マップされた関数が1回呼び出され、サンプルのバッチに適用されます。データ実行時間のプロットが示すように、関数の実行にはさらに時間がかかる可能性がありますが、オーバーヘッドは1回だけ表示されるため、全体的な時間パフォーマンスが向上します。

メモリフットプリントの削減

interleaveprefetchshuffleなどの多くの変換は、要素の内部バッファーを維持します。 map変換に渡されたユーザー定義関数が要素のサイズを変更する場合、マップ変換の順序と要素をバッファリングする変換がメモリ使用量に影響します。一般に、パフォーマンスのために異なる順序が望ましい場合を除いて、メモリフットプリントが低くなる順序を選択します。

部分計算のキャッシュ

map変換後にデータセットをキャッシュすることをお勧めします。ただし、この変換によってデータが大きくなりすぎてメモリに収まらない場合を除きます。マップされた関数を、時間のかかる部分とメモリを消費する部分の2つの部分に分割できる場合、トレードオフを実現できます。この場合、以下のように変換を連鎖させることができます。

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

このように、時間のかかる部分は最初のエポック中にのみ実行され、キャッシュスペースを使いすぎないようにします。

ベストプラクティスの概要

パフォーマンスの高いTensorFlow入力パイプラインを設計するためのベストプラクティスの概要は次のとおりです。

フィギュアの再現

tf.data.Dataset APIの理解を深めるために、独自のパイプラインをtf.data.Datasetことができます。以下は、このガイドの画像をプロットするために使用されるコードです。これは、次のような一般的な問題のいくつかの回避策を示す、良い出発点になる可能性があります。

  • 実行時間の再現性
  • マップされた関数は実行に熱心です
  • 呼び出し可能なinterleave変換
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

データセット

ArtificialDatasetと同様に、各ステップで費やされた時間を返すデータセットを構築できます。

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

このデータセットは、形状[[2, 1], [2, 2], [2, 3]] [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] [[2, 1], [2, 2], [2, 3]]およびタイプ[tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] 。各サンプルは次のとおりです。

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

どこ:

  • OpenReadはステップ識別子です
  • t0は、対応するステップが開始されたときのタイムスタンプです。
  • dは、対応するステップで費やされた時間です
  • iはインスタンスインデックスです
  • eはエポックインデックス(データセットが反復された回数)
  • sはサンプルインデックスです

反復ループ

反復ループをもう少し複雑にして、すべてのタイミングを集約します。これは、上記のようにサンプルを生成するデータセットでのみ機能します。

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

プロット方法

最後に、 timelined_benchmark関数によって返される値を指定して、タイムラインをプロットできる関数を定義します。

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

マップされた関数にラッパーを使用する

マップされた関数を熱心なコンテキストで実行するには、それらをtf.py_function呼び出し内にラップする必要があります。

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

パイプラインの比較

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

ナイーブ

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From <ipython-input-1-c85330a00c6e>:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From <ipython-input-1-c85330a00c6e>:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 12.445692234000035

最適化

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.326935971000012
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png