نقاط تفتيش التدريب

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

عادةً ما تعني عبارة "حفظ نموذج TensorFlow" أحد أمرين:

  1. نقاط التفتيش ، أو
  2. تم الحفظ

نقاط التفتيش التقاط القيمة الدقيقة لجميع المعلمات ( tf.Variable الكائنات) التي يستخدمها نموذج. لا تحتوي نقاط التحقق على أي وصف للحساب المحدد بواسطة النموذج ، وبالتالي فهي مفيدة فقط عندما يتوفر كود المصدر الذي سيستخدم قيم المعلمات المحفوظة.

من ناحية أخرى ، يتضمن تنسيق SavedModel وصفًا متسلسلًا للحساب المحدد بواسطة النموذج بالإضافة إلى قيم المعلمات (نقطة التحقق). النماذج في هذا التنسيق مستقلة عن الكود المصدري الذي أنشأ النموذج. وبالتالي فهي مناسبة للنشر عبر TensorFlow Serving أو TensorFlow Lite أو TensorFlow.js أو البرامج بلغات البرمجة الأخرى (C ، C ++ ، Java ، Go ، Rust ، C # إلخ. TensorFlow APIs).

يغطي هذا الدليل واجهات برمجة التطبيقات (APIs) لكتابة وقراءة نقاط التفتيش.

اقامة

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

إنقاذ من tf.keras واجهات برمجة التطبيقات التدريب

رؤية tf.keras دليل على إنقاذ وترميم.

tf.keras.Model.save_weights ينقذ نقطة تفتيش TensorFlow.

net.save_weights('easy_checkpoint')

كتابة نقاط التفتيش

يتم تخزين حالة المستمرة لنموذج TensorFlow في tf.Variable الكائنات. هذه يمكن بناؤها بشكل مباشر، ولكن غالبا ما تنشأ من خلال واجهات برمجة التطبيقات عالية المستوى مثل tf.keras.layers أو tf.keras.Model .

أسهل طريقة لإدارة المتغيرات هي إرفاقها بكائنات بايثون ، ثم الرجوع إلى تلك الكائنات.

فرعية من tf.train.Checkpoint ، tf.keras.layers.Layer ، و tf.keras.Model تلقائيا تتبع المتغيرات المخصصة لصفاتهم. يُنشئ المثال التالي نموذجًا خطيًا بسيطًا ، ثم يكتب نقاط التحقق التي تحتوي على قيم لجميع متغيرات النموذج.

يمكنك حفظ بسهولة نموذجا نقطة تفتيش مع Model.save_weights .

التفتيش اليدوي

اقامة

للمساعدة في توضيح كل ملامح tf.train.Checkpoint ، وتحديد مجموعة بيانات لعبة وخطوة الأمثل:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

إنشاء كائنات نقطة التفتيش

استخدام tf.train.Checkpoint الكائن يدويا إنشاء نقطة تفتيش، حيث يتم تعيين الكائنات التي تريد تفتيش كما السمات على الكائن.

A tf.train.CheckpointManager يمكن أيضا أن تكون مفيدة لإدارة العديد من نقاط التفتيش.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

تدريب وفحص النموذج

الحلقة التدريبية التالية يخلق مثيل من نموذج ومحسن، ثم تجمع لهم في tf.train.Checkpoint الكائن. يستدعي خطوة التدريب في حلقة على كل دفعة من البيانات ، ويكتب بشكل دوري نقاط التفتيش على القرص.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.67
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.09
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.53
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.10
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.39

استعادة ومواصلة التدريب

بعد الدورة التدريبية الأولى ، يمكنك اجتياز نموذج ومدير جديدين ، ولكن يمكنك متابعة التدريب من حيث توقفت تمامًا:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 0.64
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.17
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.69
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.19

و tf.train.CheckpointManager الكائن حذف نقاط التفتيش القديمة. تم تكوينه أعلاه للاحتفاظ بأحدث ثلاث نقاط تفتيش فقط.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

هذه المسارات، على سبيل المثال './tf_ckpts/ckpt-10' ، ليست ملفات على القرص. وبدلا من ذلك فهي البادئات ل index الملف وملفات البيانات واحد أو أكثر التي تحتوي على قيم متغير. يتم تجميع هذه البادئات معا في واحد checkpoint ملف ( './tf_ckpts/checkpoint' ) حيث CheckpointManager ينقذ حالته.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

ميكانيكا التحميل

يطابق TensorFlow المتغيرات مع القيم المحددة من خلال اجتياز الرسم البياني الموجه ذي الحواف المسماة ، بدءًا من الكائن الذي يتم تحميله. تأتي أسماء حافة عادة من أسماء السمة في الأشياء، على سبيل المثال "l1" في self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint يستخدم أسماء حجة رئيسية لها، كما هو الحال في "step" في tf.train.Checkpoint(step=...) .

يبدو الرسم البياني للتبعية من المثال أعلاه كما يلي:

تصور الرسم البياني للتبعية لمثال حلقة التدريب

المحسن باللون الأحمر ، والمتغيرات العادية باللون الأزرق ، ومتغيرات فتحة المحسن باللون البرتقالي. العقد لغيره سبيل المثال، تمثل tf.train.Checkpoint -هل باللون الأسود.

تعد متغيرات الفتحة جزءًا من حالة المحسن ، ولكن يتم إنشاؤها لمتغير معين. على سبيل المثال 'm' حواف فوق تتوافق مع الزخم الذي المسارات محسن آدم لكل متغير. يتم حفظ متغيرات الفتحة في نقطة فحص فقط إذا تم حفظ المتغير والمحسن ، وبالتالي الحواف المتقطعة.

داعيا restore على tf.train.Checkpoint طوابير الكائن الترميم المطلوب، واستعادة قيم متغير في أقرب وقت هناك مسار مطابق من Checkpoint الكائن. على سبيل المثال ، يمكنك تحميل التحيز فقط من النموذج الذي حددته أعلاه عن طريق إعادة بناء مسار واحد إليه عبر الشبكة والطبقة.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.3119967 2.088805  3.9098527 3.9504364 4.7226586]

الرسم البياني للتبعية لهذه الكائنات الجديدة هو رسم بياني فرعي أصغر بكثير لنقطة التحقق الأكبر التي كتبتها أعلاه. إنها لا تحتوي إلا على التحيز وحفظ المضادة التي tf.train.Checkpoint يستخدم لنقاط التفتيش العدد.

تصور الرسم البياني الفرعي لمتغير التحيز

restore عوائد كائن المركز، الذي لديه تأكيدات اختيارية. كافة الكائنات التي تم إنشاؤها في جديد Checkpoint تم تجديده، لذلك status.assert_existing_objects_matched يمر.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f967c028e10>

هناك العديد من الكائنات في نقطة التحقق غير المتطابقة ، بما في ذلك نواة الطبقة ومتغيرات المحسن. status.assert_consumed يمر إلا إذا كان تفتيش والمباراة البرنامج بالضبط، وسوف بطرح استثناء هنا.

الترميمات المتأخرة

Layer الكائنات في TensorFlow قد تؤخر إنشاء المتغيرات لدعوتهم الأولى، عندما تكون الأشكال المدخلات المتاحة. على سبيل المثال على شكل Dense نواة طبقة يعتمد على كل من الطبقة المدخلات والمخرجات الأشكال، وذلك على الشكل الناتج المطلوب كوسيطة منشئ ليس ما يكفي من المعلومات لإنشاء متغير من تلقاء نفسها. منذ استدعاء Layer أيضا يقرأ قيمة للمتغير، واستعادة يجب أن يحدث بين إنشاء متغير وأول استخدام لها.

لدعم هذا المصطلح، tf.train.Checkpoint قوائم الانتظار يعيد الذي لم يكن لديك متغير مطابقة.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.603108  4.814235  4.7161555 4.818163  4.8451676]]

التفتيش اليدوي على نقاط التفتيش

tf.train.load_checkpoint إرجاع CheckpointReader الذي يعطي أقل مستوى الوصول إلى محتويات نقطة تفتيش. يحتوي على تعيينات من مفتاح كل متغير ، إلى الشكل والنوع لكل متغير في نقطة التحقق. مفتاح المتغير هو مسار الكائن الخاص به ، كما هو الحال في الرسوم البيانية المعروضة أعلاه.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

لذلك إذا كنت ترغب في قيمة net.l1.kernel يمكنك الحصول على قيمة مع التعليمات البرمجية التالية:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

كما يوفر get_tensor طريقة مما يسمح لك لتفقد قيمة متغير:

reader.get_tensor(key)
array([[4.603108 , 4.814235 , 4.7161555, 4.818163 , 4.8451676]],
      dtype=float32)

قائمة وتتبع القاموس

كما هو الحال مع الاحالات المباشرة السمة مثل self.l1 = tf.keras.layers.Dense(5) ، وتكليف القوائم والقواميس إلى سمات سوف تتبع محتوياتها.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

قد تلاحظ كائنات مجمعة للقوائم والقواميس. هذه الأغلفة هي إصدارات يمكن التحقق منها من هياكل البيانات الأساسية. تمامًا مثل التحميل المستند إلى السمة ، تستعيد هذه الأغلفة قيمة المتغير بمجرد إضافته إلى الحاوية.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

يتم تطبيق تتبع نفس تلقائيا إلى أقسام فرعية من tf.keras.Model ، ويمكن استخدامها على سبيل المثال لتعقب قوائم طبقات.

ملخص

توفر كائنات TensorFlow آلية تلقائية سهلة لحفظ واستعادة قيم المتغيرات التي تستخدمها.