דוגמניות מחסומים

היכולת לשמור ולשחזר את מצבו של מודל חיונית עבור מספר יישומים, כגון למידת העברה או לביצוע הסקה באמצעות מודלים שהוכשרו מראש. שמירת הפרמטרים של מודל (משקלים, הטיות וכו') בקובץ או בספריה של נקודת ביקורת היא אחת הדרכים להשיג זאת.

מודול זה מספק ממשק ברמה גבוהה לטעינה ושמירת נקודות ביקורת בפורמט TensorFlow v2 , כמו גם רכיבים ברמה נמוכה יותר שכותבים וקוראים מפורמט קובץ זה.

טעינה ושמירה של דגמים פשוטים

על ידי התאמה לפרוטוקול Checkpointable , דגמים פשוטים רבים יכולים להיות מסודרים לנקודות ביקורת ללא כל קוד נוסף:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

ואז ניתן לקרוא את אותו מחסום באמצעות:

try model.readCheckpoint(from: directory, name: "LeNet")

יישום ברירת מחדל זה לטעינת מודל ולשמירה ישתמש בסכימת שמות מבוססת נתיב עבור כל טנזור במודל המבוססת על שמות המאפיינים בתוך מבני המודל. לדוגמה, המשקולות וההטיות בתוך הקונבולציה הראשונה במודל LeNet-5 יישמרו עם השמות conv1/filter ו- conv1/bias , בהתאמה. בעת הטעינה, קורא המחסום יחפש טנזורים עם שמות אלו.

התאמה אישית של טעינת דגם וחיסכון

אם ברצונך לקבל שליטה רבה יותר על אילו טנסורים נשמרים וטעונים, או על שמותיהם של אותם טנסורים, פרוטוקול Checkpointable מציע כמה נקודות של התאמה אישית.

כדי להתעלם ממאפיינים בסוגים מסוימים, אתה יכול לספק יישום של ignoredTensorPaths במודל שלך שמחזיר קבוצה של מחרוזות בצורה של Type.property . לדוגמה, כדי להתעלם מהמאפיין scale בכל שכבת Attention, תוכל להחזיר ["Attention.scale"] .

כברירת מחדל, נטוי קדימה משמש להפרדה בין כל רמה עמוקה יותר במודל. ניתן להתאים זאת על ידי הטמעת checkpointSeparator במודל שלך ומתן מחרוזת חדשה לשימוש עבור המפריד הזה.

לבסוף, למידת ההתאמה הגבוהה ביותר במתן שמות טנסור, אתה יכול ליישם tensorNameMap ולספק פונקציה שממפה משם המחרוזת המוגדרת כברירת מחדל שנוצרת עבור טנזור במודל לשם מחרוזת רצויה בנקודת הבידוק. לרוב, זה ישמש לפעילות הדדית עם מחסומים שנוצרו עם מסגרות אחרות, שלכל אחת מהן יש מוסכמות שמות ומבני מודל משלה. פונקציית מיפוי מותאמת אישית נותנת את הדרגה הגבוהה ביותר של התאמה אישית לאופן השם של הטנזורים הללו.

מסופקות חלק מפונקציות העזר הסטנדרטיות, כמו ברירת המחדל CheckpointWriter.identityMap (שפשוט משתמשת בשם נתיב הטנזור שנוצר אוטומטית עבור נקודות ביקורת), או הפונקציה CheckpointWriter.lookupMap(table:) , שיכולה לבנות מיפוי מתוך מילון.

לדוגמא כיצד ניתן לבצע מיפוי מותאם אישית, אנא עיין במודל GPT-2 , המשתמש בפונקציית מיפוי כדי להתאים את ערכת השמות המדויקת המשמשת לנקודות הבידוק של OpenAI.

הרכיבים CheckpointReader ו-CheckpointWriter

עבור כתיבת נקודות ביקורת, ההרחבה המסופקת על ידי פרוטוקול Checkpointable משתמשת בהשתקפות ובנתיבי מפתח כדי לחזור על מאפייני המודל וליצור מילון הממפה נתיבי טנסור של מחרוזת לערכי Tensor. מילון זה מסופק ל- CheckpointWriter הבסיסי, יחד עם ספרייה שבה ניתן לכתוב את המחסום. אותו CheckpointWriter מטפל במשימה של יצירת נקודת הבידוק בדיסק מאותו מילון.

ההפך של תהליך זה הוא קריאה, כאשר CheckpointReader מקבל את המיקום של ספריית נקודות ביקורת בדיסק. לאחר מכן הוא קורא מאותו מחסום ויוצר מילון שממפה את שמות הטנזורים בתוך המחסום עם הערכים השמורים שלהם. מילון זה משמש להחלפת הטנזורים הנוכחיים במודל באלו שבמילון זה.

הן לטעינה והן לשמירה, פרוטוקול Checkpointable ממפה את נתיבי המחרוזת לטנזורים לשמות טנסורים מתאימים בדיסק באמצעות פונקציית המיפוי שתוארה לעיל.

אם הפרוטוקול Checkpointable חסר פונקציונליות נחוצה, או אם יש צורך בשליטה רבה יותר על תהליך הטעינה והשמירה של נקודות המחסום, ניתן להשתמש במחלקות CheckpointReader ו- CheckpointWriter בעצמן.

פורמט נקודת הבידוק של TensorFlow v2

פורמט נקודות המחסום TensorFlow v2, כפי שמתואר בקצרה בכותרת זו , הוא פורמט הדור השני לנקודות ביקורת של מודל TensorFlow. פורמט הדור השני הזה נמצא בשימוש מאז סוף 2016, ויש לו מספר שיפורים ביחס לפורמט v1 checkpoint. TensorFlow SavedModels משתמש בנקודות ביקורת v2 בתוכם כדי לשמור פרמטרים של מודל.

נקודת ביקורת TensorFlow v2 מורכבת מספריה עם מבנה כמו הבא:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

כאשר הקובץ הראשון מאחסן את המטא-נתונים עבור המחסום והקבצים הנותרים הם רסיסים בינאריים המכילים את הפרמטרים המסודרים עבור המודל.

קובץ המטא נתונים של האינדקס מכיל את הסוגים, הגדלים, המיקומים ושמות המחרוזות של כל הטנזורים המסודרים הכלולים ברסיסים. קובץ האינדקס הזה הוא החלק המורכב ביותר מבחינה מבנית של המחסום, והוא מבוסס על tensorflow::table , שהוא עצמו מבוסס על SSTable / LevelDB. קובץ אינדקס זה מורכב מסדרה של זוגות מפתח-ערך, כאשר המפתחות הם מחרוזות והערכים הם מאגרי פרוטוקול. המחרוזות ממוינות ודחוסות קידומת. לדוגמה: אם הערך הראשון הוא conv1/weight וה- conv1/bias הבא, הערך השני משתמש רק בחלק bias .

קובץ האינדקס הכולל הזה נדחס לפעמים באמצעות דחיסה Snappy . הקובץ SnappyDecompression.swift מספק יישום מקורי של Swift של פירוק Snappy ממופע נתונים דחוסים.

המטא-נתונים של כותרות האינדקס והמטא-נתונים של הטנסור מקודדים כמאגרי פרוטוקול ומקודדים/מפענחים ישירות באמצעות Swift Protobuf .

המחלקות CheckpointIndexReader ו- CheckpointIndexWriter מטפלות בטעינה ושמירה של קבצי אינדקס אלה כחלק ממחלקות CheckpointReader ו- CheckpointWriter הכוללות. האחרונים משתמשים בקובצי האינדקס כבסיס לקביעה ממה לקרוא ולכתוב לרסיסים הבינאריים הפשוטים יותר מבחינה מבנית המכילים את נתוני הטנזור.