![]() | ![]() | ![]() | ![]() |
הביטוי "שמירת מודל TensorFlow" פירוש בדרך כלל אחד משני דברים:
- מחסומים, או
- SavedModel.
מחסומים לוכדים את הערך המדויק של כל הפרמטרים ( tf.Variable
אובייקטים tf.Variable
) המשמשים מודל. מחסומים אינם מכילים תיאור כלשהו של החישוב שהוגדר על ידי המודל ולכן הם בדרך כלל שימושיים רק כאשר קוד המקור שישתמש בערכי הפרמטר שנשמר זמין.
הפורמט של SavedModel לעומת זאת כולל תיאור סדרתי של החישוב שהוגדר על ידי המודל בנוסף לערכי הפרמטרים (נקודת ביקורת). מודלים בפורמט זה אינם תלויים בקוד המקור שיצר את המודל. לכן הם מתאימים לפריסה באמצעות TensorFlow Serving, TensorFlow Lite, TensorFlow.js, או תוכניות בשפות תכנות אחרות (C, C ++, Java, Go, Rust, C # וכו 'TensorFlow API).
מדריך זה מכסה ממשקי API לכתיבה וקריאה של נקודות ביקורת.
להכין
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
חוסך tf.keras
API להכשרה של tf.keras
עיין במדריך tf.keras
לשמירה ושחזור.
tf.keras.Model.save_weights
חוסך מחסום TensorFlow.
net.save_weights('easy_checkpoint')
כתיבת מחסומים
המצב המתמשך של מודל TensorFlow נשמר באובייקטים tf.Variable
. tf.Variable
. אלה יכולים להיבנות ישירות, אך לרוב נוצרים באמצעות ממשקי API ברמה גבוהה כמוtf.keras.layers
או tf.keras.Model
.
הדרך הקלה ביותר לנהל משתנים היא על ידי הצמדתם לאובייקטים של פייתון, ואז הפניה לאובייקטים אלה.
tf.train.Checkpoint
משנה של tf.train.Checkpoint
, tf.keras.layers.Layer
ו- tf.keras.Model
אוטומטית אחר משתנים שהוקצו לתכונות שלהם. הדוגמה הבאה בונה מודל לינארי פשוט, ואז כותבת נקודות ביקורת המכילות ערכים לכל המשתנים של המודל.
אתה יכול בקלות לשמור נקודת ביקורת באמצעות Model.save_weights
.
מחסום ידני
להכין
כדי לעזור להפגין את כל התכונות של tf.train.Checkpoint
, הגדר מערך צעצועים וצעד אופטימיזציה:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
צור את אובייקטי המחסום
השתמש באובייקט tf.train.Checkpoint
כדי ליצור באופן ידני נקודת ביקורת, כאשר האובייקטים שברצונך למחסום מוגדרים כתכונות באובייקט.
tf.train.CheckpointManager
יכול גם להיות מועיל לניהול נקודות מחסום מרובות.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
התאמן ובדוק את הדגם
לולאת האימון הבאה יוצרת מופע של המודל ושל מייעל, ואז אוספת אותם לאובייקט tf.train.Checkpoint
. זה מכנה את שלב האימון בלולאה על כל אצווה של נתונים, וכותב מעת לעת מחסומים לדיסק.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 30.42 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 23.83 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 17.27 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 10.81 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 4.74
לשחזר ולהמשיך באימונים
לאחר מחזור האימונים הראשון תוכלו להעביר מודל ומנהל חדשים, אך להעלות אימונים בדיוק במקום בו הפסקתם:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 0.85 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 0.87 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.71 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.46 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.21
האובייקט tf.train.CheckpointManager
מוחק מחסומים ישנים. מעל זה מוגדר לשמור רק את שלוש המחסומים האחרונים.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
נתיבים אלה, למשל './tf_ckpts/ckpt-10'
, אינם קבצים בדיסק. במקום זאת הם קידומות לקובץ index
ולקובץ נתונים אחד או יותר המכילים את הערכים המשתנים. קידומות אלה מקובצות יחד בקובץ checkpoint
בודד ( './tf_ckpts/checkpoint'
) כאשר CheckpointManager
שומר את מצבו.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
מכניקת העמסה
TensorFlow מתאים משתנים לערכים נקודתיים על ידי חציית גרף מכוון עם קצוות בשם, החל מהאובייקט שנטען. שמות קצה מגיעים בדרך כלל משמות תכונות באובייקטים, למשל "l1"
ב- self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
משתמש בשמות ארגומנט מילות המפתח שלה, כמו tf.train.Checkpoint(step=...)
"step"
ב- tf.train.Checkpoint(step=...)
.
גרף התלות מהדוגמה שלמעלה נראה כך:
האופטימיזציה היא באדום, משתנים רגילים הם בכחול, ומשתני חריץ האופטימיזציה הם בצבע כתום. הצמתים האחרים - למשל, המייצגים את tf.train.Checkpoint
- נמצאים בשחור.
משתני חריץ הם חלק ממצב האופטימיזציה, אך נוצרים עבור משתנה ספציפי. לדוגמה 'm'
הקצוות 'm'
שלמעלה תואמים את המומנטום, שעוקב אחר האופטימיזציה של אדם עבור כל משתנה. משתני חריץ נשמרים במחסום רק אם המשתנה והמייעל יישמרו, וכך הקצוות המקווקו.
קריאה restore
באובייקט tf.train.Checkpoint
בתורים לשחזורים המבוקשים, ומשחזרת ערכים משתנים ברגע שיש נתיב תואם מאובייקט Checkpoint
. לדוגמה, תוכל לטעון רק את ההטיה מהמודל שהגדרת לעיל על ידי שחזור נתיב אחד אליו דרך הרשת והשכבה.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.831489 3.7156947 2.5892444 3.8669944 4.749503 ]
גרף התלות של האובייקטים החדשים הללו הוא תצלום קטן בהרבה של נקודת הביקורת הגדולה יותר שכתבת לעיל. הוא כולל רק את ההטיה tf.train.Checkpoint
שמירה ש- tf.train.Checkpoint
משתמש בו tf.train.Checkpoint
נקודות ביקורת.
restore
מחזיר אובייקט סטטוס, שיש לו קביעות אופציונליות. כל האובייקטים שנוצרו Checkpoint
החדש שוחזרו, כך status.assert_existing_objects_matched
עובר.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1644447b70>
במחסום ישנם אובייקטים רבים שאינם תואמים, כולל ליבת השכבה ומשתני האופטימיזציה. status.assert_consumed
עובר רק אם נקודת הבידוק והתוכנית תואמות במדויק, והיה יוצא חריג כאן.
שחזור עיכוב
אובייקטים של Layer
ב- TensorFlow עשויים לעכב את יצירת המשתנים לשיחה הראשונה שלהם, כאשר צורות קלט זמינות. לדוגמא צורת ליבת שכבה Dense
תלויה הן בצורת הקלט והן בצורת הפלט, ולכן צורת הפלט הנדרשת כטיעון קונסטרוקטור אינה מספיק מידע כדי ליצור את המשתנה בפני עצמו. מכיוון שקריאה Layer
קוראת גם את ערך המשתנה, שחזור חייב להתרחש בין יצירת המשתנה לשימושו הראשון.
כדי לתמוך tf.train.Checkpoint
זה, משחזר tf.train.Checkpoint
תורים שעדיין אין להם משתנה תואם.
delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5719748 4.6099544 4.931875 4.836442 4.8496275]]
בדיקה ידנית של מחסומים
tf.train.load_checkpoint
מחזיר CheckpointReader
שנותן גישה ברמה נמוכה יותר לתוכן המחסום. הוא מכיל מיפויים ממפתח כל vartiable, לצורה ולסוג ה- dt עבור כל משתנה במחסום. המפתח של משתנה הוא נתיב האובייקט שלו, כמו בתרשימים המוצגים לעיל.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
אז אם אתה מעוניין בערך net.l1.kernel
אתה יכול לקבל את הערך באמצעות הקוד הבא:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
זה גם מספק שיטת get_tensor
המאפשרת לך לבדוק את הערך של המשתנה:
reader.get_tensor(key)
array([[4.5719748, 4.6099544, 4.931875 , 4.836442 , 4.8496275]], dtype=float32)
מעקב אחר רשימות ומילונים
כמו במשימות תכונות ישירות כמו self.l1 = tf.keras.layers.Dense(5)
, הקצאת רשימות ומילונים לתכונות תעקוב אחר תוכנם.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
יתכן שתבחין באובייטי עטיפה לרשימות ומילונים. עטיפות אלה הן גרסאות הניתנות לבדיקה של מבני הנתונים הבסיסיים. בדיוק כמו הטעינה המבוססת על מאפיין, עטיפות אלה משחזרות את ערך המשתנה ברגע שהוא מתווסף למיכל.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
אותו מעקב מוחל באופן אוטומטי על מחלקות משנה של tf.keras.Model
, ויכול לשמש למשל למעקב אחר רשימות שכבות.
סיכום
אובייקטים של TensorFlow מספקים מנגנון אוטומטי קל לשמירה ושחזור ערכי המשתנים בהם הם משתמשים.