דף זה תורגם על ידי Cloud Translation API.
Switch to English

ביצועים טובים יותר באמצעות ממשק ה- API של tf.data

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

סקירה כללית

GPUs ו- TPU יכולים להפחית באופן קיצוני את הזמן הדרוש לביצוע שלב אימון יחיד. השגת ביצועי שיא דורשת צינור קלט יעיל המספק נתונים לשלב הבא לפני סיום השלב הנוכחי. ממשק ה- API של tf.data מסייע בבניית צינורות קלט גמישים ויעילים. מסמך זה מדגים כיצד להשתמש בממשק ה- API של tf.data לבניית צינורות קלט TensorFlow בעלי ביצועים גבוהים.

לפני שתמשיך, קרא את המדריך " בניית צינורות קלט קלט TensorFlow ", כדי ללמוד כיצד להשתמש בממשק ה- API של tf.data .

אֶמְצָעִי

להכין

import tensorflow as tf

import time

במהלך המדריך הזה, תוכלו לחזור על פני מערך נתונים ולמדוד את הביצועים. ביצוע מדדי ביצוע לשחזור יכול להיות קשה, וגורמים שונים משפיעים עליו:

  • עומס המעבד הנוכחי,
  • תעבורת הרשת,
  • מנגנונים מורכבים כמו מטמון וכו '.

לפיכך, כדי לספק מדד לשחזור, בנה דוגמה מלאכותית.

מערך הנתונים

הגדר מחלקה tf.data.Dataset בירושה מ- tf.data.Dataset הנקרא ArtificialDataset . מערך נתונים זה:

  • יוצר דוגמאות num_samples (ברירת המחדל היא 3)
  • ישן זמן מה לפני שהפריט הראשון המדמה פתיחת קובץ
  • ישן זמן מה לפני שהוא מייצר כל פריט כדי לדמות קריאת נתונים מקובץ
class ArtificialDataset(tf.data.Dataset):
    def _generator(num_samples):
        # Opening the file
        time.sleep(0.03)
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            time.sleep(0.015)
            
            yield (sample_idx,)
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=tf.dtypes.int64,
            output_shapes=(1,),
            args=(num_samples,)
        )

מערך נתונים זה דומה tf.data.Dataset.range , ומוסיף עיכוב קבוע בהתחלה ובין כל דגימה.

לולאת האימונים

כתוב לולאת אימון דמה שמודדת כמה זמן לוקח לחזור על מערך נתונים. זמן האימון מדומה.

def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    tf.print("Execution time:", time.perf_counter() - start_time)

ייעל את הביצועים

כדי להציג כיצד ניתן לבצע אופטימיזציה של ביצועים, תשפר את ביצועי ה- ArtificialDataset .

הגישה הנאיבית

התחל בצינור נאיבי ללא טריקים, וחזור על מערך הנתונים כפי שהוא.

benchmark(ArtificialDataset())
Execution time: 0.2530532629998561

מתחת למכסה המנוע, כך בילית את זמן הביצוע שלך:

תמים

אתה יכול לראות שביצוע שלב הכשרה כולל:

  • פתיחת קובץ אם הוא עדיין לא נפתח,
  • אחזור הזנת נתונים מהקובץ,
  • שימוש בנתונים לאימון.

עם זאת, ביישום סינכרוני נאיבי כמו כאן, בזמן שהצינור שלך מביא את הנתונים, המודל שלך יושב בסרק. לעומת זאת, בזמן שהמודל שלך מתאמן, צינור הקלט יושב בסרק. זמן שלב האימון הוא אפוא סך כל זמן הפתיחה, הקריאה וההדרכה.

החלקים הבאים מתבססים על צינור קלט זה, וממחישים שיטות עבודה מומלצות לעיצוב צינורות קלט TensorFlow ביצועים.

איסוף מראש

איסוף מראש חופף את העיבוד המקדים וביצוע המודל של שלב אימונים. בזמן שהמודל מבצע את שלב האימון s , צינור הקלט קורא את הנתונים לשלב s+1 . פעולה זו מקטינה את זמן הצעד למקסימום (בניגוד לסכום) של האימון ואת הזמן שנדרש לחילוץ הנתונים.

ממשק ה- API של tf.data מספק את השינוי tf.data.Dataset.prefetch . ניתן להשתמש בו לקישור הזמן בו מיוצרים נתונים מרגע צריכת הנתונים. בפרט, הטרנספורמציה משתמשת בשרשור רקע ובמאגר פנימי כדי לשלוף רכיבים מראש ממערך הקלט לפני הזמן שהם מתבקשים. מספר האלמנטים שיש לשלוף מראש צריך להיות שווה למספר האצוות הנצרכות על ידי שלב אימון יחיד (או אולי גדול ממנו). באפשרותך לכוונן ערך זה באופן ידני, או להגדיר אותו ל- tf.data.experimental.AUTOTUNE אשר ינחה את זמן הריצה של tf.data לכוון את הערך באופן דינמי בזמן הריצה.

שים לב שהטרנספורמציה המובילה מראש מספקת יתרונות בכל עת שיש אפשרות לחפוף את עבודתו של "מפיק" לעבודה של "צרכן".

benchmark(
    ArtificialDataset()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
Execution time: 0.20858672200006367

הועבר מראש

הפעם אתה יכול לראות שבזמן שלב האימון פועל לדוגמא 0, צינור הקלט קורא את הנתונים לדוגמא 1 וכן הלאה.

מקביל למיצוי נתונים

בסביבה של ממש, נתוני הקלט עשויים להיות מאוחסנים מרחוק (למשל, GCS או HDFS). צינור מערך נתונים שעובד היטב בקריאת נתונים באופן מקומי עלול להפוך לצוואר בקבוק ב- I / O בעת קריאת נתונים מרחוק בגלל ההבדלים הבאים בין אחסון מקומי לרחוק:

  • זמן לבית ראשון: קריאת בתים ראשונים של קובץ מאחסון מרוחק יכולה לקחת סדרי גודל יותר זמן מהאחסון המקומי.
  • תפוקת קריאה: בעוד שאחסון מרוחק מציע בדרך כלל רוחב פס מצטבר גדול, קריאת קובץ יחיד עשויה לנצל רק חלק קטן מרוחב הפס הזה.

בנוסף, לאחר טעינת הבתים הגולמיים בזיכרון, יתכן שיהיה צורך לבטל את הערעור מחדש ו / או פענוח הנתונים (למשל פרוטובוף ), הדורש חישוב נוסף. תקורה זו קיימת ללא קשר לשאלה אם הנתונים מאוחסנים באופן מקומי או מרחוק, אך הם יכולים להיות גרועים יותר במקרה הרחוק אם הנתונים אינם מושלים מראש ביעילות.

כדי למתן את ההשפעה של תקורות מיצוי הנתונים השונות, ניתן להשתמש בטרנספורמציה tf.data.Dataset.interleave כדי להקביל את שלב טעינת הנתונים, ולשלב את התוכן של מערכי נתונים אחרים (כגון קוראי קבצי נתונים). ניתן לציין את מספר מערכי הנתונים לחפיפה על ידי הארגומנט cycle_length , ואילו את רמת ההקבלה ניתן לציין על ידי הארגומנט num_parallel_calls . בדומה prefetch הטרנספורמציה, את interleave טרנספורמציה תומך tf.data.experimental.AUTOTUNE אשר יאציל את ההחלטה לגבי מה רמת ההקבלה לשימוש על tf.data ריצה.

משתלב רצף

הטיעונים המוגדרים כברירת מחדל של טרנספורמציית tf.data.Dataset.interleave גורמים לו לשלב דוגמאות בודדות משני מערכי נתונים ברצף.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(ArtificialDataset)
)
Execution time: 0.2373930549999841

משתלב רצף

עלילה זו מאפשרת להציג את התנהגות הטרנספורמציה בין interleave , להביא דוגמאות לחלופין משני מערכי הנתונים הזמינים. עם זאת, לא מדובר כאן בשיפור הביצועים.

שידור מקביל

כעת השתמש בארגומנט num_parallel_calls של טרנספורמציית interleave . זה טוען מערכי נתונים מרובים במקביל, מה שמקטין את זמן ההמתנה לפתיחת הקבצים.

benchmark(
    tf.data.Dataset.range(2)
    .interleave(
        ArtificialDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
Execution time: 0.1730301249999684

שידור מקביל

הפעם מקבילה הקריאה של שני מערכי הנתונים ומקטינה את זמן עיבוד הנתונים הגלובלי.

מקביל לשינוי נתונים

בעת הכנת נתונים, ייתכן שיהיה צורך לעבד את אלמנטים הקלט. לשם כך, ממשק ה- API של tf.data מציע את השינוי tf.data.Dataset.map , אשר מחיל פונקציה המוגדרת על ידי המשתמש על כל רכיב במערך הקלט. מכיוון שאלמנטים קלטיים אינם תלויים זה בזה, ניתן להקביל את העיבוד המקדים על פני מספר ליבות מעבד. כדי לאפשר זאת, בדומה prefetch ו interleave טרנספורמציות, את map טרנספורמציה מספקת את num_parallel_calls הטיעון כדי לציין את רמת המקביליות.

בחירת הערך הטוב ביותר עבור הטיעון num_parallel_calls תלויה בחומרה שלך, במאפייני נתוני האימון שלך (כגון הגודל והצורה), העלות של פונקציית המפה שלך ואיזה עיבוד אחר מתרחש במעבד בו זמנית. היוריסטיקה פשוטה היא שימוש במספר ליבות המעבד הזמינות. עם זאת, באשר prefetch ו interleave הטרנספורמציה, את map טרנספורמציה תומך tf.data.experimental.AUTOTUNE אשר יאציל את ההחלטה לגבי מה רמת ההקבלה לשימוש על tf.data ריצה.

def mapped_function(s):
    # Do some hard pre-processing
    tf.py_function(lambda: time.sleep(0.03), [], ())
    return s

מיפוי רציף

התחל על ידי שימוש בשינוי map ללא מקבילות כדוגמה בסיסית.

benchmark(
    ArtificialDataset()
    .map(mapped_function)
)
Execution time: 0.43913738300011573

מיפוי רציף

באשר לגישה הנאיבית , כאן הזמן המושקע לפתיחה, קריאה, עיבוד מקדים (מיפוי) והדרכה מתמצה יחד לאיטרציה אחת.

מיפוי מקביל

כעת, השתמש באותה פונקציית עיבוד מקדים אך החל אותה במקביל על מספר דגימות.

benchmark(
    ArtificialDataset()
    .map(
        mapped_function,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
)
Execution time: 0.2730358689998411

מיפוי מקביל

כעת תוכלו לראות בעלילה כי שלבי העיבוד המקדים חופפים, מה שמקטין את הזמן הכולל לאיטרציה יחידה.

מטמון

השינוי tf.data.Dataset.cache יכול tf.data.Dataset.cache מטמון מערך נתונים, בזיכרון או באחסון מקומי. זה יחסוך פעולות מסוימות (כמו פתיחת קבצים וקריאת נתונים) מביצוע במהלך כל עידן.

benchmark(
    ArtificialDataset()
    .map(  # Apply time consuming operations before cache
        mapped_function
    ).cache(
    ),
    5
)
Execution time: 0.36568501300007483

מערך נתונים במטמון

כאשר אתה שומר מטמון מערך נתונים, השינויים לפני cache (כמו פתיחת הקובץ וקריאת הנתונים) מבוצעים רק בתקופה הראשונה. העידנים הבאים ישתמשו מחדש בנתונים שנשמרו במטמון על ידי שינוי cache .

אם הפונקציה שהוגדרה על ידי המשתמש שהועברה לשינוי map היא יקרה, החל את טרנספורמציית cache לאחר טרנספורמציית map כל עוד מערך הנתונים שהתקבל עדיין יכול להתאים לזיכרון או לאחסון מקומי. אם הפונקציה המוגדרת על ידי המשתמש מגדילה את השטח הנדרש לאחסון מערך הנתונים מעבר לקיבולת המטמון, החל אותו לאחר שינוי cache או שקול לעבד מראש את הנתונים שלך לפני עבודת האימון שלך כדי להפחית את השימוש במשאבים.

מיפוי וקטור

הפעלת פונקציה המוגדרת על ידי המשתמש שהועברה לשינוי map כוללת תקורה הקשורה לתזמון ולביצוע הפונקציה המוגדרת על ידי המשתמש. אנו ממליצים לבצע וקטור של הפונקציה המוגדרת על ידי המשתמש (כלומר, לפעול על פני מערך תשומות בבת אחת) ולהחיל את טרנספורמציית batch לפני טרנספורמציית map .

כדי להמחיש מנהג טוב זה, מערך הנתונים המלאכותי שלך אינו מתאים. עיכוב התזמון הוא סביב 10 מיקרו-שניות (10e-6 שניות), הרבה פחות מעשרות אלפיות השניות המשמשות ב- ArtificialDataset , ולכן קשה לראות את השפעתו.

לדוגמא זו, השתמש tf.data.Dataset.range הבסיסית tf.data.Dataset.range ופשט את לולאת האימונים לצורה הפשוטה ביותר.

fast_dataset = tf.data.Dataset.range(10000)

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for _ in tf.data.Dataset.range(num_epochs):
        for _ in dataset:
            pass
    tf.print("Execution time:", time.perf_counter() - start_time)
    
def increment(x):
    return x+1

מיפוי סקלרים

fast_benchmark(
    fast_dataset
    # Apply function one item at a time
    .map(increment)
    # Batch
    .batch(256)
)
Execution time: 0.8861004689999845

מפת סקלר

העלילה לעיל ממחישה את המתרחש (עם פחות דוגמאות). אתה יכול לראות שהפונקציה הממופה מוחלת על כל דוגמה. אמנם פונקציה זו מהירה מאוד, אך יש לה כמה תקורות שמשפיעות על ביצועי הזמן.

מיפוי וקטורי

fast_benchmark(
    fast_dataset
    .batch(256)
    # Apply function on a batch of items
    # The tf.Tensor.__add__ method already handle batches
    .map(increment)
)
Execution time: 0.032729552000091644

מפה וקטורית

הפעם, הפונקציה הממופה נקראת פעם אחת והיא חלה על קבוצה של דוגמה. בעוד שהפונקציה עשויה להימשך זמן רב יותר לביצוע, התקורה מופיעה פעם אחת בלבד, ומשפרת את ביצועי הזמן הכוללים.

צמצום טביעת הרגל של הזיכרון

מספר טרנספורמציות, כולל interleave פנימי, prefetch shuffle , שומרות על מאגר פנימי של אלמנטים. אם הפונקציה המוגדרת על ידי המשתמש המועברת לשינוי map משנה את גודל האלמנטים, אז סדר השינוי במפה והטרנספורמציות שרכיבי המאגר משפיעים על השימוש בזיכרון. באופן כללי, אנו ממליצים לבחור בסדר שמביא לטביעת רגל זיכרון נמוכה יותר, אלא אם כן רצוי להזמין שונה לביצועים.

אחסון חלקי במטמון

מומלץ לשמור את מערך הנתונים במטמון לאחר שינוי map למעט אם שינוי זה הופך את הנתונים לגדולים מכדי שיתאימו לזיכרון. ניתן להשיג פשרה אם ניתן לפצל את הפונקציה הממופה שלך לשני חלקים: זמן גוזל וחלק גוזל זיכרון. במקרה זה, אתה יכול לשרשר את השינויים שלך כמו להלן:

dataset.map(time_consuming_mapping).cache().map(memory_consuming_mapping)

בדרך זו, החלק שגוזל זמן מבוצע רק במהלך העידן הראשון, ואתה נמנע משימוש בשטח מטמון רב מדי.

סיכום שיטות עבודה מומלצות

הנה סיכום של שיטות העבודה המומלצות לעיצוב צינורות קלט TensorFlow ביצועים:

שכפול הדמויות

כדי להעמיק tf.data.Dataset ה- API של tf.data.Dataset , אתה יכול לשחק עם צינורות משלך. להלן הקוד המשמש לרישום התמונות ממדריך זה. זו יכולה להיות נקודת התחלה טובה, ולהציג כמה דרכים לעקיפת הבעיה עבור קשיים נפוצים כגון:

  • שחזור זמן ביצוע;
  • פונקציות ממופתות להוטות בביצוע;
  • interleave לשנות את השינוי בין השניים.
import itertools
from collections import defaultdict

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

מערך הנתונים

בדומה ל- ArtificialDataset ניתן לבנות מערך נתונים המחזיר את הזמן שבילה בכל שלב.

class TimeMeasuredDataset(tf.data.Dataset):
    # OUTPUT: (steps, timings, counters)
    OUTPUT_TYPES = (tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32)
    OUTPUT_SHAPES = ((2, 1), (2, 2), (2, 3))
    
    _INSTANCES_COUNTER = itertools.count()  # Number of datasets generated
    _EPOCHS_COUNTER = defaultdict(itertools.count)  # Number of epochs done for each dataset
    
    def _generator(instance_idx, num_samples):
        epoch_idx = next(TimeMeasuredDataset._EPOCHS_COUNTER[instance_idx])
        
        # Opening the file
        open_enter = time.perf_counter()
        time.sleep(0.03)
        open_elapsed = time.perf_counter() - open_enter
        
        for sample_idx in range(num_samples):
            # Reading data (line, record) from the file
            read_enter = time.perf_counter()
            time.sleep(0.015)
            read_elapsed = time.perf_counter() - read_enter
            
            yield (
                [("Open",), ("Read",)],
                [(open_enter, open_elapsed), (read_enter, read_elapsed)],
                [(instance_idx, epoch_idx, -1), (instance_idx, epoch_idx, sample_idx)]
            )
            open_enter, open_elapsed = -1., -1.  # Negative values will be filtered
            
    
    def __new__(cls, num_samples=3):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_types=cls.OUTPUT_TYPES,
            output_shapes=cls.OUTPUT_SHAPES,
            args=(next(cls._INSTANCES_COUNTER), num_samples)
        )

מערך נתונים זה מספק דוגמאות של צורה [[2, 1], [2, 2], [2, 3]] וסוג [tf.dtypes.string, tf.dtypes.float32, tf.dtypes.int32] . כל מדגם הוא:

(
  [("Open"), ("Read")],
  [(t0, d), (t0, d)],
  [(i, e, -1), (i, e, s)]
)

איפה:

  • Open Read הם מזהים צעדים
  • t0 הוא חותמת הזמן בה התחיל השלב המתאים
  • d הוא הזמן המושקע בשלב המתאים
  • i הוא מדד המופעים
  • e הוא אינדקס העידן (מספר הפעמים שעוברים איטרציה של מערך הנתונים)
  • s הוא מדד המדגם

לולאת האיטרציה

הפוך את לולאת האיטרציה למורכבת מעט יותר כדי לצבור את כל התזמונים. זה יעבוד רק עם מערכי נתונים שיוצרים דוגמאות כמפורט לעיל.

def timelined_benchmark(dataset, num_epochs=2):
    # Initialize accumulators
    steps_acc = tf.zeros([0, 1], dtype=tf.dtypes.string)
    times_acc = tf.zeros([0, 2], dtype=tf.dtypes.float32)
    values_acc = tf.zeros([0, 3], dtype=tf.dtypes.int32)
    
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_enter = time.perf_counter()
        for (steps, times, values) in dataset:
            # Record dataset preparation informations
            steps_acc = tf.concat((steps_acc, steps), axis=0)
            times_acc = tf.concat((times_acc, times), axis=0)
            values_acc = tf.concat((values_acc, values), axis=0)
            
            # Simulate training time
            train_enter = time.perf_counter()
            time.sleep(0.01)
            train_elapsed = time.perf_counter() - train_enter
            
            # Record training informations
            steps_acc = tf.concat((steps_acc, [["Train"]]), axis=0)
            times_acc = tf.concat((times_acc, [(train_enter, train_elapsed)]), axis=0)
            values_acc = tf.concat((values_acc, [values[-1]]), axis=0)
        
        epoch_elapsed = time.perf_counter() - epoch_enter
        # Record epoch informations
        steps_acc = tf.concat((steps_acc, [["Epoch"]]), axis=0)
        times_acc = tf.concat((times_acc, [(epoch_enter, epoch_elapsed)]), axis=0)
        values_acc = tf.concat((values_acc, [[-1, epoch_num, -1]]), axis=0)
        time.sleep(0.001)
    
    tf.print("Execution time:", time.perf_counter() - start_time)
    return {"steps": steps_acc, "times": times_acc, "values": values_acc}

שיטת העלילה

לבסוף, הגדירו פונקציה המסוגלת לתכנן ציר זמן בהתחשב בערכים שהוחזרו על ידי הפונקציה timelined_benchmark .

def draw_timeline(timeline, title, width=0.5, annotate=False, save=False):
    # Remove invalid entries (negative times, or empty steps) from the timelines
    invalid_mask = np.logical_and(timeline['times'] > 0, timeline['steps'] != b'')[:,0]
    steps = timeline['steps'][invalid_mask].numpy()
    times = timeline['times'][invalid_mask].numpy()
    values = timeline['values'][invalid_mask].numpy()
    
    # Get a set of different steps, ordered by the first time they are encountered
    step_ids, indices = np.stack(np.unique(steps, return_index=True))
    step_ids = step_ids[np.argsort(indices)]

    # Shift the starting time to 0 and compute the maximal time value
    min_time = times[:,0].min()
    times[:,0] = (times[:,0] - min_time)
    end = max(width, (times[:,0]+times[:,1]).max() + 0.01)
    
    cmap = mpl.cm.get_cmap("plasma")
    plt.close()
    fig, axs = plt.subplots(len(step_ids), sharex=True, gridspec_kw={'hspace': 0})
    fig.suptitle(title)
    fig.set_size_inches(17.0, len(step_ids))
    plt.xlim(-0.01, end)
    
    for i, step in enumerate(step_ids):
        step_name = step.decode()
        ax = axs[i]
        ax.set_ylabel(step_name)
        ax.set_ylim(0, 1)
        ax.set_yticks([])
        ax.set_xlabel("time (s)")
        ax.set_xticklabels([])
        ax.grid(which="both", axis="x", color="k", linestyle=":")
        
        # Get timings and annotation for the given step
        entries_mask = np.squeeze(steps==step)
        serie = np.unique(times[entries_mask], axis=0)
        annotations = values[entries_mask]
        
        ax.broken_barh(serie, (0, 1), color=cmap(i / len(step_ids)), linewidth=1, alpha=0.66)
        if annotate:
            for j, (start, width) in enumerate(serie):
                annotation = "\n".join([f"{l}: {v}" for l,v in zip(("i", "e", "s"), annotations[j])])
                ax.text(start + 0.001 + (0.001 * (j % 2)), 0.55 - (0.1 * (j % 2)), annotation,
                        horizontalalignment='left', verticalalignment='center')
    if save:
        plt.savefig(title.lower().translate(str.maketrans(" ", "_")) + ".svg")

השתמש בעטיפות לפונקציה ממופה

כדי להפעיל פונקציה ממופה בהקשר נלהב, עליך לעטוף אותם בתוך שיחת tf.py_function .

def map_decorator(func):
    def wrapper(steps, times, values):
        # Use a tf.py_function to prevent auto-graph from compiling the method
        return tf.py_function(
            func,
            inp=(steps, times, values),
            Tout=(steps.dtype, times.dtype, values.dtype)
        )
    return wrapper

השוואת צינורות

_batch_map_num_items = 50

def dataset_generator_fun(*args):
    return TimeMeasuredDataset(num_samples=_batch_map_num_items)

תמים

@map_decorator
def naive_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001)  # Time consuming step
    time.sleep(0.0001)  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, [["Map"]]), axis=0),
        tf.concat((times, [[map_enter, map_elapsed]]), axis=0),
        tf.concat((values, [values[-1]]), axis=0)
    )

naive_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .flat_map(dataset_generator_fun)
    .map(naive_map)
    .batch(_batch_map_num_items, drop_remainder=True)
    .unbatch(),
    5
)
Execution time: 12.436093607999965

מותאם

@map_decorator
def time_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.001 * values.shape[0])  # Time consuming step
    map_elapsed = time.perf_counter() - map_enter

    return (
        tf.concat((steps, tf.tile([[["1st map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


@map_decorator
def memory_consuming_map(steps, times, values):
    map_enter = time.perf_counter()
    time.sleep(0.0001 * values.shape[0])  # Memory consuming step
    map_elapsed = time.perf_counter() - map_enter

    # Use tf.tile to handle batch dimension
    return (
        tf.concat((steps, tf.tile([[["2nd map"]]], [steps.shape[0], 1, 1])), axis=1),
        tf.concat((times, tf.tile([[[map_enter, map_elapsed]]], [times.shape[0], 1, 1])), axis=1),
        tf.concat((values, tf.tile([[values[:][-1][0]]], [values.shape[0], 1, 1])), axis=1)
    )


optimized_timeline = timelined_benchmark(
    tf.data.Dataset.range(2)
    .interleave(  # Parallelize data reading
        dataset_generator_fun,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(  # Vectorize your mapped function
        _batch_map_num_items,
        drop_remainder=True)
    .map(  # Parallelize map transformation
        time_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .cache()  # Cache data
    .map(  # Reduce memory usage
        memory_consuming_map,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .prefetch(  # Overlap producer and consumer works
        tf.data.experimental.AUTOTUNE
    )
    .unbatch(),
    5
)
Execution time: 6.303204500999982

draw_timeline(naive_timeline, "Naive", 15)

png

draw_timeline(optimized_timeline, "Optimized", 15)

png