احفظ التاريخ! يعود مؤتمر Google I / O من 18 إلى 20 مايو. سجل الآن
ترجمت واجهة Cloud Translation API‏ هذه الصفحة.
Switch to English

نقاط تفتيش التدريب

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

عادةً ما تعني عبارة "حفظ نموذج TensorFlow" أحد أمرين:

  1. نقاط التفتيش ، أو
  2. تم الحفظ

تلتقط نقاط tf.Variable القيمة الدقيقة لجميع المعلمات (الكائنات tf.Variable ) المستخدمة بواسطة النموذج. لا تحتوي نقاط التحقق على أي وصف للحساب الذي يحدده النموذج ، وبالتالي فهي مفيدة فقط عندما يتوفر كود المصدر الذي سيستخدم قيم المعلمات المحفوظة.

من ناحية أخرى ، يتضمن تنسيق SavedModel وصفًا متسلسلًا للحساب المحدد بواسطة النموذج بالإضافة إلى قيم المعلمات (نقطة التحقق). النماذج في هذا التنسيق مستقلة عن الكود المصدري الذي أنشأ النموذج. وبالتالي فهي مناسبة للنشر عبر خدمة TensorFlow أو TensorFlow Lite أو TensorFlow.js أو البرامج بلغات البرمجة الأخرى (C ، C ++ ، Java ، Go ، Rust ، C # إلخ. TensorFlow APIs).

يغطي هذا الدليل واجهات برمجة التطبيقات (APIs) لكتابة وقراءة نقاط التفتيش.

يثبت

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

الحفظ من واجهات برمجة تطبيقات التدريب tf.keras

راجع دليل tf.keras حول الحفظ والاستعادة.

tf.keras.Model.save_weights ينقذ نقطة تفتيش TensorFlow.

net.save_weights('easy_checkpoint')

كتابة نقاط التفتيش

يتم تخزين الحالة المستمرة لنموذج tf.Variable في كائنات tf.Variable . يمكن إنشاء هذه بشكل مباشر ، ولكن غالبًا ما يتم إنشاؤها من خلال واجهات برمجة تطبيقات عالية المستوى مثلtf.keras.layers أو tf.keras.Model .

أسهل طريقة لإدارة المتغيرات هي إرفاقها بكائنات بايثون ، ثم الرجوع إلى تلك الكائنات.

tf.train.Checkpoint الفرعية لـ tf.train.Checkpoint و tf.keras.layers.Layer و tf.keras.Model المتغيرات المخصصة tf.keras.Model تلقائيًا. يُنشئ المثال التالي نموذجًا خطيًا بسيطًا ، ثم يكتب نقاط التحقق التي تحتوي على قيم لجميع متغيرات النموذج.

يمكنك بسهولة حفظ نقطة Model.save_weights النموذج باستخدام Model.save_weights .

التفتيش اليدوي

يثبت

للمساعدة في توضيح جميع ميزات tf.train.Checkpoint ، حدد مجموعة بيانات اللعبة وخطوة التحسين:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

إنشاء كائنات نقطة التفتيش

استخدم كائن tf.train.Checkpoint لإنشاء نقطة تحقق يدويًا ، حيث يتم تعيين الكائنات التي تريد tf.train.Checkpoint على الكائن.

يمكن أن يكون tf.train.CheckpointManager أيضًا في إدارة نقاط التفتيش المتعددة.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

تدريب وفحص النموذج

تُنشئ حلقة التدريب التالية tf.train.Checkpoint ، ثم tf.train.Checkpoint كائن tf.train.Checkpoint . يستدعي خطوة التدريب في حلقة على كل دفعة من البيانات ، ويكتب بشكل دوري نقاط التفتيش على القرص.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.42
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.83
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.27
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.81
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.74

استعادة ومواصلة التدريب

بعد الدورة التدريبية الأولى ، يمكنك اجتياز نموذج ومدير جديدين ، ولكن يمكنك متابعة التدريب من حيث توقفت تمامًا:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.87
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.71
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.46
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.21

يحذف كائن tf.train.CheckpointManager نقاط التحقق القديمة. أعلاه تم تكوينه للاحتفاظ بنقاط التفتيش الثلاثة الأخيرة فقط.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

هذه المسارات ، مثل './tf_ckpts/ckpt-10' ، ليست ملفات على القرص. بدلاً من ذلك ، فهي بادئات لملف index وواحد أو أكثر من ملفات البيانات التي تحتوي على القيم المتغيرة. يتم تجميع هذه البادئات معًا في ملف checkpoint './tf_ckpts/checkpoint' واحد ( './tf_ckpts/checkpoint' ) حيث يحفظ CheckpointManager حالته.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

ميكانيكا التحميل

يطابق TensorFlow المتغيرات مع القيم المحددة من خلال اجتياز الرسم البياني الموجه ذي الحواف المسماة ، بدءًا من الكائن الذي يتم تحميله. تأتي أسماء self.l1 = tf.keras.layers.Dense(5) عادةً من أسماء السمات في الكائنات ، على سبيل المثال "l1" في self.l1 = tf.keras.layers.Dense(5) . يستخدم tf.train.Checkpoint أسماء وسيطات الكلمات الرئيسية الخاصة به ، كما في "step" في tf.train.Checkpoint(step=...) .

يبدو الرسم البياني للتبعية من المثال أعلاه كما يلي:

تصور الرسم البياني للتبعية لمثال حلقة التدريب

المحسن باللون الأحمر ، والمتغيرات العادية باللون الأزرق ، ومتغيرات فتحة المحسن باللون البرتقالي. العقد الأخرى - على سبيل المثال ، التي تمثل tf.train.Checkpoint - هي باللون الأسود.

تعد متغيرات الفتحة جزءًا من حالة المحسن ، ولكن يتم إنشاؤها لمتغير معين. على سبيل المثال 'm' تتوافق حواف 'm' أعلاه مع الزخم ، الذي يتتبعه مُحسِّن آدم لكل متغير. يتم حفظ متغيرات الفتحة في نقطة فحص فقط إذا تم حفظ المتغير والمحسن ، وبالتالي الحواف المتقطعة.

يؤدي استدعاء restore على tf.train.Checkpoint قوائم انتظار عمليات الاستعادة المطلوبة ، واستعادة القيم المتغيرة بمجرد وجود مسار مطابق من كائن Checkpoint . على سبيل المثال ، يمكنك تحميل التحيز فقط من النموذج الذي حددته أعلاه عن طريق إعادة بناء مسار واحد إليه عبر الشبكة والطبقة.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.831489  3.7156947 2.5892444 3.8669944 4.749503 ]

الرسم البياني للتبعية لهذه الكائنات الجديدة هو رسم بياني فرعي أصغر بكثير لنقطة التحقق الأكبر التي كتبتها أعلاه. يتضمن فقط التحيز وعداد الحفظ الذي يستخدمه tf.train.Checkpoint نقاط التفتيش.

تصور الرسم البياني الفرعي لمتغير التحيز

restore كائن الحالة ، الذي يحتوي على تأكيدات اختيارية. تمت استعادة جميع الكائنات التي تم إنشاؤها في Checkpoint status.assert_existing_objects_matched الجديدة ، لذا فإن status.assert_existing_objects_matched يمر.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1644447b70>

هناك العديد من الكائنات في نقطة التحقق غير المتطابقة ، بما في ذلك نواة الطبقة ومتغيرات المحسن. status.assert_consumed يمر فقط إذا كانت نقطة التفتيش والبرنامج status.assert_consumed تمامًا ، وسوف status.assert_consumed استثناء هنا.

الترميمات المتأخرة

قد تؤخر كائنات Layer في TensorFlow إنشاء متغيرات لاستدعائها الأول ، عندما تكون أشكال الإدخال متاحة. على سبيل المثال ، يعتمد شكل نواة الطبقة Dense على كلٍ من أشكال الإدخال والإخراج للطبقة ، وبالتالي فإن شكل المخرجات المطلوب كوسيطة مُنشئ ليس معلومات كافية لإنشاء المتغير بمفرده. نظرًا لأن استدعاء Layer يقرأ أيضًا قيمة المتغير ، يجب أن تحدث استعادة بين إنشاء المتغير واستخدامه لأول مرة.

لدعم هذا المصطلح ، يستعيد tf.train.Checkpoint قوائم الانتظار التي لا تحتوي على متغير مطابق بعد.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5719748 4.6099544 4.931875  4.836442  4.8496275]]

التفتيش اليدوي على نقاط التفتيش

tf.train.load_checkpoint بإرجاع CheckpointReader الذي يوفر وصولاً منخفض المستوى لمحتويات نقطة التفتيش. يحتوي على تعيينات من كل مفتاح قابل للتغير ، إلى الشكل والنوع لكل متغير في نقطة التحقق. مفتاح المتغير هو مسار الكائن الخاص به ، كما هو الحال في الرسوم البيانية المعروضة أعلاه.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

لذلك إذا كنت مهتمًا بقيمة net.l1.kernel يمكنك الحصول على القيمة باستخدام الكود التالي:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

يوفر أيضًا طريقة get_tensor تسمح لك بفحص قيمة المتغير:

reader.get_tensor(key)
array([[4.5719748, 4.6099544, 4.931875 , 4.836442 , 4.8496275]],
      dtype=float32)

قائمة وتتبع القاموس

كما هو الحال مع تعيينات السمات المباشرة مثل self.l1 = tf.keras.layers.Dense(5) ، فإن تخصيص القوائم والقواميس للسمات سيؤدي إلى تتبع محتوياتها.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

قد تلاحظ كائنات مجمعة للقوائم والقواميس. هذه الأغلفة هي إصدارات يمكن التحقق منها من هياكل البيانات الأساسية. تمامًا مثل التحميل المستند إلى السمة ، تستعيد هذه الأغلفة قيمة المتغير بمجرد إضافته إلى الحاوية.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

يتم تطبيق نفس التتبع تلقائيًا على الفئات الفرعية من tf.keras.Model ، ويمكن استخدامه على سبيل المثال لتتبع قوائم الطبقات.

ملخص

توفر كائنات TensorFlow آلية تلقائية سهلة لحفظ واستعادة قيم المتغيرات التي تستخدمها.