הצטרף ל-TensorFlow ב-Google I/O, 11-12 במאי הירשם עכשיו

מחסומי אימון

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

הביטוי "שמירת מודל TensorFlow" אומר בדרך כלל אחד משני דברים:

  1. מחסומים, OR
  2. SavedModel.

נקודות ביקורת לוכדות את הערך המדויק של כל הפרמטרים ( tf.Variable objects) המשמשים את המודל. נקודות ביקורת אינן מכילות כל תיאור של החישוב שהוגדר על ידי המודל ולכן הן בדרך כלל שימושיות רק כאשר קוד מקור שישתמש בערכי הפרמטרים השמורים זמין.

פורמט SavedModel לעומת זאת כולל תיאור סדרתי של החישוב שהוגדר על ידי המודל בנוסף לערכי הפרמטרים (נקודת ביקורת). מודלים בפורמט זה אינם תלויים בקוד המקור שיצר את המודל. לפיכך הם מתאימים לפריסה באמצעות TensorFlow Serving, TensorFlow Lite, TensorFlow.js, או תוכניות בשפות תכנות אחרות (ה-C, C++, Java, Go, Rust, C# וכו'. API של TensorFlow).

מדריך זה מכסה ממשקי API לכתיבה וקריאה של מחסומים.

להכין

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

שמירת ממשקי API להדרכה של tf.keras

עיין במדריך tf.keras בנושא שמירה ושחזור.

tf.keras.Model.save_weights שומר נקודת ביקורת של TensorFlow.

net.save_weights('easy_checkpoint')

כתיבת מחסומים

המצב המתמשך של מודל TensorFlow מאוחסן באובייקטים tf.Variable . ניתן לבנות אותם ישירות, אך לרוב נוצרים באמצעות ממשקי API ברמה גבוהה כמו tf.keras.layers או tf.keras.Model .

הדרך הקלה ביותר לנהל משתנים היא על ידי הצמדתם לאובייקטים של Python, ואז הפניה לאובייקטים הללו.

תת-מחלקות של tf.train.Checkpoint , tf.keras.layers.Layer ו- tf.keras.Model באופן אוטומטי אחר משתנים המוקצים לתכונות שלהם. הדוגמה הבאה בונה מודל ליניארי פשוט, ולאחר מכן כותבת נקודות ביקורת המכילות ערכים עבור כל משתני המודל.

אתה יכול בקלות לשמור נקודת ביקורת מודל עם Model.save_weights .

בדיקה ידנית

להכין

כדי לעזור להדגים את כל התכונות של tf.train.Checkpoint , הגדר מערך צעצוע ושלב אופטימיזציה:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

צור את אובייקטי המחסום

השתמש באובייקט tf.train.Checkpoint כדי ליצור נקודת ביקורת באופן ידני, כאשר האובייקטים שברצונך למחסום מוגדרים כמאפיינים על האובייקט.

tf.train.CheckpointManager יכול להיות מועיל גם לניהול מחסומים מרובים.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

הרכבת ונקודת ביקורת של הדגם

לולאת האימון הבאה יוצרת מופע של המודל ושל אופטימיזציה, ואז אוספת אותם לאובייקט tf.train.Checkpoint . זה קורא לשלב האימון בלולאה על כל אצווה של נתונים, ומדי פעם כותב נקודות ביקורת לדיסק.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

שחזר והמשך אימון

לאחר מחזור ההכשרה הראשון אתה יכול לעבור מודל ומנהל חדש, אבל להמשיך את האימונים בדיוק מהנקודה בה הפסקתם:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

האובייקט tf.train.CheckpointManager מוחק נקודות ביקורת ישנות. מעל זה מוגדר לשמור רק את שלושת המחסומים העדכניים ביותר.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

נתיבים אלה, למשל './tf_ckpts/ckpt-10' , אינם קבצים בדיסק. במקום זאת הם קידומות לקובץ index וקובץ נתונים אחד או יותר המכילים את ערכי המשתנים. קידומות אלה מקובצות יחד בקובץ נקודת checkpoint בודד ( './tf_ckpts/checkpoint' ) שבו ה- CheckpointManager שומר את מצבו.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

מכניקת טעינה

TensorFlow מתאים משתנים לערכי נקודת ביקורת על ידי חציית גרף מכוון עם קצוות בעלי שם, החל מהאובייקט הנטען. שמות קצה מגיעים בדרך כלל משמות תכונות באובייקטים, למשל ה- "l1" ב- self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint משתמש בשמות הארגומנטים של מילות המפתח שלו, כמו ב"שלב "step" ב- tf.train.Checkpoint(step=...) .

גרף התלות מהדוגמה למעלה נראה כך:

הדמיה של גרף התלות עבור לולאת האימון לדוגמה

כלי האופטימיזציה באדום, משתנים רגילים הם בכחול, ומשתני משבצת המיטוב הם בכתום. הצמתים האחרים - למשל, המייצגים את tf.train.Checkpoint - הם בצבע שחור.

משתני משבצות הם חלק ממצב האופטימיזציה, אך נוצרים עבור משתנה ספציפי. לדוגמה, קצוות ה- 'm' שלמעלה תואמים למומנטום, שאותו מייעל Adam עוקב אחר כל משתנה. משתני משבצת נשמרים בנקודת ביקורת רק אם המשתנה והמייעל יישמרו שניהם, ובכך הקצוות המקווקוים.

קריאה restore באובייקט tf.train.Checkpoint בתור את השחזורים המבוקשים, משחזר ערכי משתנים ברגע שיש נתיב תואם מאובייקט Checkpoint . לדוגמה, אתה יכול לטעון רק את ההטיה מהמודל שהגדרת למעלה על ידי שחזור נתיב אחד אליו דרך הרשת והשכבה.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

גרף התלות של האובייקטים החדשים הללו הוא תת-גרף קטן בהרבה של המחסום הגדול יותר שכתבת למעלה. הוא כולל רק את ההטיה ומונה שמירה ש- tf.train.Checkpoint משתמש בהם כדי למספר נקודות ביקורת.

הדמיה של תת-גרף עבור משתנה ההטיה

restore מחזיר אובייקט סטטוס, שיש לו קביעות אופציונליות. כל האובייקטים שנוצרו ב- Checkpoint החדש שוחזרו, אז status.assert_existing_objects_matched עובר.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

ישנם אובייקטים רבים בנקודת הבידוק שלא התאימו, כולל ליבת השכבה והמשתנים של האופטימיזציה. status.assert_consumed עובר רק אם המחסום והתוכנית תואמים במדויק, ותזרוק כאן חריגה.

שחזורים דחויים

אובייקטי Layer ב-TensorFlow עשויים לדחות את יצירת המשתנים לקריאה הראשונה שלהם, כאשר צורות קלט זמינות. לדוגמה, הצורה של ליבת שכבה Dense תלויה הן בצורות הקלט והן בצורות הפלט של השכבה, ולכן צורת הפלט הנדרשת כארגומנט הבנאי אינה מספיק מידע כדי ליצור את המשתנה בפני עצמו. מכיוון שקריאה Layer קוראת גם את ערך המשתנה, שחזור חייב לקרות בין יצירת המשתנה לשימוש הראשון בו.

כדי לתמוך בביטוי זה, tf.train.Checkpoint דוחה שחזורים שעדיין אין להם משתנה תואם.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

בדיקה ידנית של מחסומים

tf.train.load_checkpoint מחזיר CheckpointReader שנותן גישה ברמה נמוכה יותר לתוכן המחסום. הוא מכיל מיפויים מהמפתח של כל משתנה, לצורה ול-dtype עבור כל משתנה בנקודת הבידוק. המפתח של משתנה הוא נתיב האובייקט שלו, כמו בגרפים המוצגים למעלה.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

אז אם אתה מעוניין בערך של net.l1.kernel אתה יכול לקבל את הערך עם הקוד הבא:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

זה גם מספק שיטת get_tensor המאפשרת לך לבדוק את הערך של משתנה:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

מעקב אחר אובייקטים

מחסומים שומרים ומשחזרים את הערכים של אובייקטי tf.Variable על ידי "מעקב" אחר כל משתנה או אובייקט בר מעקב באחת מהתכונות שלו. בעת ביצוע שמירה, משתנים נאספים באופן רקורסיבי מכל האובייקטים הנגישים למעקב.

כמו בהקצאות תכונות ישירות כמו self.l1 = tf.keras.layers.Dense(5) , הקצאת רשימות ומילון לתכונות תעקוב אחר תוכנן.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

ייתכן שתבחין באובייקטי מעטפת עבור רשימות ומילונים. עטיפות אלו הן גרסאות הניתנות לבדיקה של מבני הנתונים הבסיסיים. בדיוק כמו הטעינה המבוססת על תכונה, עטיפות אלה משחזרות ערך של משתנה ברגע שהוא מתווסף למיכל.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

אובייקטים שניתנים למעקב כוללים tf.train.Checkpoint , tf.Module המשנה שלו (למשל keras.layers.Layer ו- keras.Model ), ומכולות Python מזוהות:

  • dict (and collections.OrderedDict )
  • list
  • tuple (ו- collections.namedtuple , typing.NamedTuple )

סוגי מיכל אחרים אינם נתמכים , כולל:

  • collections.defaultdict
  • set

מתעלמים מכל שאר האובייקטים של Python, כולל:

  • int
  • string
  • float

סיכום

אובייקטי TensorFlow מספקים מנגנון אוטומטי קל לשמירה ושחזור הערכים של המשתנים שבהם הם משתמשים.