ביצועים טובים יותר עם ה- tf.data API

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

סקירה כללית

GPUs ו-TPUs יכולים להפחית באופן קיצוני את הזמן הדרוש לביצוע שלב אימון בודד. השגת ביצועי שיא דורשת צינור קלט יעיל המספק נתונים לשלב הבא לפני שהשלב הנוכחי יסתיים. ה-API של tf.data עוזר לבנות צינורות קלט גמישים ויעילים. מסמך זה מדגים כיצד להשתמש ב-API של tf.data לבניית צינורות קלט TensorFlow בעלי ביצועים גבוהים.

לפני שתמשיך, עיין במדריך בניית צינורות קלט של TensorFlow כדי ללמוד כיצד להשתמש בממשק ה-API של tf.data .

אֶמְצָעִי

להכין

import tensorflow as tf

import time

במהלך המדריך הזה, תעבור על מערך נתונים ותמדוד את הביצועים. יצירת מדדי ביצועים שניתנים לשחזור יכולה להיות קשה. גורמים שונים המשפיעים על יכולת השחזור כוללים:

  • עומס המעבד הנוכחי
  • תעבורת הרשת
  • מנגנונים מורכבים, כגון מטמון

כדי לקבל אמת מידה שניתן לשחזר, תבנה דוגמה מלאכותית.

מערך הנתונים

התחל עם הגדרת מחלקה היורשת מ- tf.data.Dataset בשם ArtificialDataset . מערך הנתונים הזה:

  • מייצר דגימות num_samples (ברירת המחדל היא 3)
  • שינה למשך זמן מה לפני הפריט הראשון כדי לדמות פתיחת קובץ
  • ישן במשך זמן מה לפני ייצור כל פריט כדי לדמות קריאת נתונים מקובץ
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)

            yield (sample_idx,)

    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
            args=(num_samples,)
        )

מערך נתונים זה דומה ל- tf.data.Dataset.range אחד, ומוסיף השהייה קבועה בתחילת ובין כל דגימה.

לולאת האימון

לאחר מכן, כתוב לולאת אימון דמה המודדת כמה זמן לוקח לחזור על מערך נתונים. זמן האימון מדומה.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)

מטב את הביצועים

כדי להציג כיצד ניתן לייעל את הביצועים, תשפר את הביצועים של ArtificialDataset .

הגישה הנאיבית

התחל עם צינור נאיבי ללא טריקים, חוזר על מערך הנתונים כפי שהוא.

benchmark(ArtificialDataset())
Execution time: 0.26497629899995445

מתחת למכסה המנוע, כך בילה את זמן הביצוע שלך:

עלילת זמן ביצוע נתונים - שיטה נאיבית

העלילה מראה שביצוע שלב אימון כרוך ב:

  • פתיחת קובץ אם הוא עדיין לא נפתח
  • שליפת הזנת נתונים מהקובץ
  • שימוש בנתונים לאימון

עם זאת, ביישום סינכרוני נאיבי כמו כאן, בזמן שהצינור שלך מביא את הנתונים, המודל שלך לא פעיל. לעומת זאת, בזמן שהמודל שלך מתאמן, צינור הקלט עומד בטל. זמן צעדי האימון הוא אפוא סכום זמני הפתיחה, הקריאה והאימון.

הסעיפים הבאים מתבססים על צינור קלט זה, וממחישים שיטות עבודה מומלצות לתכנון צינורות קלט TensorFlow בעלי ביצועים.

אחזור מראש

השליפה מראש חופפת את העיבוד המקדים וביצוע המודל של שלב אימון. בזמן שהמודל מבצע שלבי אימון s , צינור הקלט קורא את הנתונים עבור שלב s+1 . פעולה זו מפחיתה את זמן הצעד למקסימום (בניגוד לסכום) של האימון ואת הזמן שלוקח לחלץ את הנתונים.

ה-API של tf.data מספק את הטרנספורמציה של tf.data.Dataset.prefetch . זה יכול לשמש כדי לנתק את הזמן שבו הנתונים מופקים מהזמן שבו הנתונים נצרכים. בפרט, הטרנספורמציה משתמשת בשרשור רקע ובמאגר פנימי כדי לשלוף מראש אלמנטים ממערך הקלט לפני הזמן שהם מתבקשים. מספר האלמנטים לשליפה מראש צריך להיות שווה (או אולי גדול יותר) ממספר המנות הנצרכות בשלב אימון בודד. אתה יכול לכוונן ערך זה באופן ידני, או להגדיר אותו ל- tf.data.AUTOTUNE , אשר ינחה את זמן הריצה של tf.data לכוון את הערך באופן דינמי בזמן הריצה.

שימו לב שהטרנספורמציה של אחזור מראש מספק הטבות בכל פעם שיש הזדמנות לחפוף את עבודתו של "מפיק" לעבודה של "צרכן".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.AUTOTUNE)
)
Execution time: 0.21731788600027357

עלילת זמן ביצוע נתונים - שיטת אחזור מראש

כעת, כפי שמראה עלילת זמן ביצוע הנתונים, בזמן ששלב האימון פועל עבור מדגם 0, צינור הקלט קורא את הנתונים עבור מדגם 1, וכן הלאה.

חילוץ נתונים מקביל

בסביבה אמיתית, נתוני הקלט עשויים להיות מאוחסנים מרחוק (לדוגמה, ב-Google Cloud Storage או HDFS). צינור של מערך נתונים שעובד היטב בעת קריאת נתונים מקומית עלול להיווצר צוואר בקבוק ב-I/O בעת קריאת נתונים מרחוק, בגלל ההבדלים הבאים בין אחסון מקומי למרוחק:

  • זמן עד-בייט ראשון : קריאת הביט הראשון של קובץ מאחסון מרוחק עלולה לקחת יותר סדרי גודל מאשר מאחסון מקומי.
  • תפוקת קריאה : בעוד שאחסון מרוחק מציע בדרך כלל רוחב פס מצטבר גדול, קריאת קובץ בודד עשויה לנצל רק חלק קטן מרוחב הפס הזה.

בנוסף, ברגע שהבתים הגולמיים נטענים לזיכרון, ייתכן שיהיה צורך גם לבצע דה-סריאליזציה ו/או פענוח הנתונים (למשל protobuf ), מה שדורש חישוב נוסף. תקורה זו קיימת ללא קשר לשאלה אם הנתונים מאוחסנים באופן מקומי או מרחוק, אך עלולה להיות גרועה יותר במקרה המרוחק אם הנתונים אינם מאוחסנים מראש ביעילות.

כדי למתן את ההשפעה של תקורה למיצוי הנתונים השונים, ניתן להשתמש בטרנספורמציה tf.data.Dataset.interleave כדי להקביל את שלב טעינת הנתונים, תוך שזירה של התוכן של מערכי נתונים אחרים (כגון קוראי קבצי נתונים). ניתן לציין את מספר מערכי הנתונים לחפיפה על ידי הארגומנט cycle_length , בעוד שרמת ההקבלה יכולה להיות מוגדרת על ידי הארגומנט num_parallel_calls . בדומה לטרנספורמציה של אחזור prefetch , הטרנספורמציה של tf.data.AUTOTUNE interleave שתאציל את ההחלטה באיזו רמת מקביליות להשתמש בזמן הריצה tf.data .

שזירה רציפה

ארגומנטי ברירת המחדל של הטרנספורמציה tf.data.Dataset.interleave הופכים אותו לשילוב של דוגמאות בודדות משני מערכי נתונים ברצף.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(lambda _: ArtificialDataset())
)
Execution time: 0.4987426460002098

עלילת זמן ביצוע נתונים - שזירה רציפה

עלילת זמן ביצוע נתונים זו מאפשרת להציג את התנהגות הטרנספורמציה של interleave , להביא דוגמאות לחלופין משני מערכי הנתונים הזמינים. עם זאת, לא מעורב כאן שיפור ביצועים.

שזירה מקבילה

כעת, השתמש בארגומנט num_parallel_calls של טרנספורמציה של interleave . זה טוען מספר מערכי נתונים במקביל, ומצמצם את זמן ההמתנה לפתיחת הקבצים.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        lambda _: ArtificialDataset(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.283668874000341

עלילת זמן ביצוע נתונים - שיטת שזירה מקבילה

הפעם, כפי שמראה עלילת זמן ביצוע הנתונים, הקריאה של שני מערכי הנתונים מתבצעת במקביל, מה שמפחית את זמן עיבוד הנתונים הגלובלי.

טרנספורמציה מקבילה של נתונים

בעת הכנת נתונים, ייתכן שיהיה צורך לעבד מראש רכיבי קלט. לשם כך, ה-API של tf.data מציע את הטרנספורמציה tf.data.Dataset.map , אשר מחילה פונקציה המוגדרת על ידי משתמש על כל רכיב של מערך הקלט. מכיוון שרכיבי קלט אינם תלויים זה בזה, ניתן לבצע במקביל את העיבוד המקדים על פני מספר ליבות CPU. כדי לאפשר זאת, בדומה לטרנספורמציות ה- prefetch ו- interleave , טרנספורמציה map מספקת את הארגומנט num_parallel_calls כדי לציין את רמת ההקבלה.

בחירת הערך הטוב ביותר עבור הארגומנט num_parallel_calls תלויה בחומרה שלך, במאפיינים של נתוני האימון שלך (כגון הגודל והצורה שלו), העלות של פונקציית המפה שלך ואיזה עיבוד אחר מתרחש ב-CPU בו-זמנית. היוריסטיקה פשוטה היא להשתמש במספר ליבות המעבד הזמינות. עם זאת, באשר לטרנספורמציה של אחזור ו- interleave , הטרנספורמציה של map תומכת ב- prefetch tf.data.AUTOTUNE את ההחלטה באיזו רמת מקביליות להשתמש בזמן הריצה tf.data .

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

מיפוי רציף

התחל על ידי שימוש בטרנספורמציה של map ללא מקביליות כדוגמה בסיסית.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.4505277170001136

עלילת זמן ביצוע נתונים - שיטת מיפוי רצף

באשר לגישה הנאיבית , כאן, כפי שמראה העלילה, זמני הפתיחה, הקריאה, העיבוד המקדים (מיפוי) וההכשרה מסתכמים יחד לאיטרציה אחת.

מיפוי מקביל

כעת, השתמש באותה פונקציית עיבוד מקדים אך החל אותה במקביל על מספר דוגמאות.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.AUTOTUNE
    )
)
Execution time: 0.2839677860001757

זמן ביצוע נתונים - מיפוי מקביל

כפי שמדגימה עלילת הנתונים, שלבי העיבוד המקדים חופפים, ומצמצמים את הזמן הכולל עבור איטרציה בודדת.

שמירה במטמון

הטרנספורמציה tf.data.Dataset.cache יכולה לשמור במטמון מערך נתונים, בזיכרון או באחסון מקומי. זה יחסוך כמה פעולות (כמו פתיחת קבצים וקריאת נתונים) מביצוע במהלך כל תקופה.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.3848854380003104

זמן ביצוע נתונים - שיטת הנתונים במטמון

כאן, עלילת זמן ביצוע הנתונים מראה שכאשר אתה מאחסן במטמון מערך נתונים, הטרנספורמציות לפני cache (כמו פתיחת הקובץ וקריאת הנתונים) מבוצעות רק בתקופה הראשונה. העידנים הבאים יעשו שימוש חוזר בנתונים שנשמרו במטמון על ידי טרנספורמציה של cache .

אם הפונקציה המוגדרת על ידי המשתמש המועברת לטרנספורמציה של map היא יקרה, החל את המרת cache לאחר המרת map כל עוד מערך הנתונים שהתקבל עדיין יכול להתאים לזיכרון או לאחסון מקומי. אם הפונקציה המוגדרת על ידי המשתמש מגדילה את המקום הדרוש לאחסון מערך הנתונים מעבר לקיבולת המטמון, החל אותו לאחר שינוי cache או שקול לעבד מראש את הנתונים שלך לפני עבודת ההדרכה שלך כדי להפחית את השימוש במשאבים.

מיפוי וקטוריזציה

הפעלת פונקציה מוגדרת על ידי משתמש המועברת לטרנספורמציה של map כוללת תקורה הקשורה לתזמון וביצוע הפונקציה המוגדרת על ידי המשתמש. הפעל את הפונקציה המוגדרת על ידי המשתמש (כלומר, תפעל על אצווה של קלט בבת אחת) והחל את השינוי batch לפני המרת map .

כדי להמחיש את הנוהג הטוב הזה, מערך הנתונים המלאכותי שלך אינו מתאים. העיכוב בתזמון הוא בסביבות 10 מיקרו-שניות (10e-6 שניות), הרבה פחות מעשרות האלפיות השניות בשימוש ב- ArtificialDataset , ולכן קשה לראות את השפעתו.

עבור דוגמה זו, השתמש בפונקציית base tf.data.Dataset.range ופשט את לולאת האימון לצורתה הפשוטה ביותר.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)

def increment(x):
    return x+1

מיפוי סקלרי

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.2712608739998359

זמן ביצוע נתונים - שיטת מפה סקלרית

העלילה למעלה ממחישה מה קורה (עם פחות דוגמאות) בשיטת המיפוי הסקלרי. זה מראה שהפונקציה הממופת מיושמת עבור כל מדגם. למרות שפונקציה זו מהירה מאוד, יש לה תקורה מסוימת שמשפיעה על ביצועי הזמן.

מיפוי וקטורי

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.02737950600021577

זמן ביצוע נתונים - שיטת מפה וקטורית

הפעם, הפונקציה הממופת נקראת פעם אחת והיא חלה על אצווה של דגימה. כפי שמראה עלילת זמן ביצוע הנתונים, בעוד שהפונקציה עשויה לקחת יותר זמן לביצוע, התקורה מופיעה פעם אחת בלבד, מה שמשפר את ביצועי הזמן הכוללים.

צמצום טביעת הרגל של הזיכרון

מספר טרנספורמציות, כולל interleave , prefetch shuffle , שומרים על מאגר פנימי של אלמנטים. אם הפונקציה המוגדרת על ידי המשתמש שעברה לטרנספורמציה של map משנה את גודל האלמנטים, אזי סדר הטרנספורמציה של המפה והטרנספורמציות שרכיבי חוצץ משפיעים על השימוש בזיכרון. באופן כללי, בחר את הסדר שמביא לטביעת זיכרון נמוכה יותר, אלא אם כן רצוי סדר אחר לביצועים.

שמירה במטמון חישובים חלקיים

מומלץ לאחסן את מערך הנתונים לאחר טרנספורמציה של map למעט אם השינוי הזה הופך את הנתונים לגדולים מכדי להכנס לזיכרון. ניתן להשיג פשרה אם ניתן לפצל את הפונקציה הממופת שלך לשני חלקים: חלק שצורך זמן וחלק שצורך זיכרון. במקרה זה, אתה יכול לשרשר את הטרנספורמציות שלך כמו להלן:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

בדרך זו, החלק שגוזל זמן מבוצע רק במהלך התקופה הראשונה, ואתה נמנע משימוש רב מדי בשטח מטמון.

סיכום שיטות עבודה מומלצות

להלן סיכום של שיטות העבודה המומלצות לתכנון צינורות קלט TensorFlow בעלי ביצועים:

שחזור הדמויות

כדי להעמיק tf.data.Dataset API, אתה יכול לשחק עם הצינורות שלך. להלן הקוד המשמש לציור התמונות ממדריך זה. זו יכולה להיות נקודת התחלה טובה, להראות כמה דרכים לעקיפת קשיים נפוצים כגון:

  • שחזור זמן ביצוע
  • פונקציות ממופות ביצוע להוט
  • טרנספורמציה של interleave ניתנת להתקשרות
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

מערך הנתונים

בדומה ל- ArtificialDataset , אתה יכול לבנות מערך נתונים המחזיר את הזמן שהושקע בכל שלב.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))

    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset

    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])

        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter

        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter

            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered


    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

מערך נתונים זה מספק דוגמאות של צורה [[2, 1], [2, 2], [2, 3]] ושל סוג [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . כל דוגמה היא:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

איפה:

  • Open Read הם מזהי צעדים
  • t0 הוא חותמת הזמן כאשר השלב המתאים התחיל
  • d הוא הזמן המושקע בשלב המתאים
  • i הוא אינדקס המופע
  • e הוא אינדקס העידן (מספר הפעמים שמערך הנתונים עבר איטרציה)
  • s הוא מדד המדגם

לולאת האיטרציה

הפוך את לולאת האיטרציה למעט יותר מסובכת כדי לצבור את כל התזמונים. זה יעבוד רק עם מערכי נתונים המייצרים דוגמאות כמפורט לעיל.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)

    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)

            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter

            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)

        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)

    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

שיטת התכנון

לבסוף, הגדר פונקציה המסוגלת לשרטט ציר זמן בהינתן הערכים המוחזרים על ידי הפונקציה timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()

    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)

    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)

    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")

        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]

        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

השתמש בעטיפות עבור פונקציה ממופה

כדי להפעיל פונקציה ממופה בהקשר להוט, עליך לעטוף אותם בתוך קריאת tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

השוואת צנרת

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

תמים

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_types is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
WARNING:tensorflow:From /tmp/ipykernel_23983/64197174.py:36: calling DatasetV2.from_generator (from tensorflow.python.data.ops.dataset_ops) with output_shapes is deprecated and will be removed in a future version.
Instructions for updating:
Use output_signature instead
Execution time: 13.13538893499981

אופטימיזציה

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.723691489999965
draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png