ترجمت واجهة Cloud Translation API‏ هذه الصفحة.
Switch to English

نقاط تفتيش التدريب

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

عادةً ما تعني عبارة "حفظ نموذج TensorFlow" أحد أمرين:

  1. نقاط التفتيش ، أو
  2. تم الحفظ

تلتقط نقاط tf.Variable القيمة الدقيقة لجميع المعلمات (الكائنات tf.Variable ) المستخدمة بواسطة النموذج. لا تحتوي نقاط التحقق على أي وصف للحساب الذي يحدده النموذج ، وبالتالي فهي مفيدة فقط عندما يتوفر كود المصدر الذي سيستخدم قيم المعلمات المحفوظة.

من ناحية أخرى ، يتضمن تنسيق SavedModel وصفًا متسلسلًا للحساب المحدد بواسطة النموذج بالإضافة إلى قيم المعلمات (نقطة التحقق). النماذج في هذا التنسيق مستقلة عن الكود المصدري الذي أنشأ النموذج. وبالتالي فهي مناسبة للنشر عبر TensorFlow Serving أو TensorFlow Lite أو TensorFlow.js أو البرامج بلغات البرمجة الأخرى (C ، C ++ ، Java ، Go ، Rust ، C # إلخ. TensorFlow APIs).

يغطي هذا الدليل واجهات برمجة التطبيقات (APIs) لكتابة وقراءة نقاط التفتيش

اقامة

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

الحفظ من واجهات برمجة تطبيقات التدريب tf.keras

راجع دليل tf.keras حول الحفظ والاستعادة.

tf.keras.Model.save_weights ينقذ نقطة تفتيش TensorFlow.

net.save_weights('easy_checkpoint')

كتابة نقاط التفتيش

يتم تخزين الحالة المستمرة لنموذج tf.Variable في كائنات tf.Variable . يمكن إنشاء هذه بشكل مباشر ، ولكن غالبًا ما يتم إنشاؤها من خلال واجهات برمجة تطبيقات عالية المستوى مثل tf.keras.layers أو tf.keras.Model .

أسهل طريقة لإدارة المتغيرات هي إرفاقها بكائنات بايثون ، ثم الرجوع إلى تلك الكائنات.

tf.train.Checkpoint الفرعية لـ tf.train.Checkpoint و tf.keras.layers.Layer و tf.keras.Model المتغيرات المعينة tf.keras.Model تلقائيًا. يوضح المثال التالي نموذجًا خطيًا بسيطًا ، ثم يكتب نقاط التحقق التي تحتوي على قيم لجميع متغيرات النموذج.

يمكنك بسهولة حفظ نقطة Model.save_weights النموذج باستخدام Model.save_weights

التفتيش اليدوي

اقامة

للمساعدة في توضيح جميع ميزات tf.train.Checkpoint حدد مجموعة بيانات اللعبة وخطوة التحسين:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

إنشاء كائنات نقطة التفتيش

لإنشاء نقطة تفتيش يدويًا ، ستحتاج إلى كائن tf.train.Checkpoint . حيث يتم تعيين الكائنات التي تريد تحديدها كسمات على الكائن.

يمكن أن يكون tf.train.CheckpointManager أيضًا لإدارة نقاط التفتيش المتعددة.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

تدريب وفحص النموذج

تُنشئ حلقة التدريب التالية tf.train.Checkpoint ، ثم تجمعهم في كائن tf.train.Checkpoint . يستدعي خطوة التدريب في حلقة على كل دفعة من البيانات ، ويكتب بشكل دوري نقاط التفتيش على القرص.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 28.15
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 21.56
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 15.00
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 8.52
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 3.25

استعادة ومواصلة التدريب

بعد الأول يمكنك اجتياز نموذج ومدير جديدين ، لكن التدريب على الالتحاق بالضبط من حيث توقفت:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.76
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.65
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.51
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

يحذف الكائن tf.train.CheckpointManager نقاط التحقق القديمة. أعلاه تم تكوينه للاحتفاظ بنقاط التفتيش الثلاثة الأخيرة فقط.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

هذه المسارات ، مثل './tf_ckpts/ckpt-10' ، ليست ملفات على القرص. بدلاً من ذلك ، فهي بادئات لملف index وملف بيانات واحد أو أكثر يحتوي على القيم المتغيرة. يتم تجميع هذه البادئات معًا في ملف checkpoint واحد ( './tf_ckpts/checkpoint' ) حيث يحفظ CheckpointManager حالته.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

ميكانيكا التحميل

يطابق TensorFlow المتغيرات مع القيم المحددة من خلال اجتياز الرسم البياني الموجه ذي الحواف المسماة ، بدءًا من الكائن الذي يتم تحميله. تأتي أسماء self.l1 = tf.keras.layers.Dense(5) عادةً من أسماء السمات في الكائنات ، على سبيل المثال "l1" في self.l1 = tf.keras.layers.Dense(5) . يستخدم tf.train.Checkpoint أسماء وسيطات الكلمات الرئيسية الخاصة به ، كما في "step" في tf.train.Checkpoint(step=...) .

يبدو الرسم البياني للتبعية من المثال أعلاه كما يلي:

تصور الرسم البياني للتبعية لمثال حلقة التدريب

مع وجود المحسن باللون الأحمر ، والمتغيرات العادية باللون الأزرق ، ومتغيرات فتحة المحسن باللون البرتقالي. العقد الأخرى ، على سبيل المثال التي تمثل tf.train.Checkpoint ، تكون سوداء.

تعد متغيرات الفتحة جزءًا من حالة المحسن ، ولكن يتم إنشاؤها لمتغير معين. على سبيل المثال 'm' تتوافق حواف 'm' أعلاه مع الزخم ، الذي يتتبعه مُحسِّن آدم لكل متغير. يتم حفظ متغيرات الفتحة فقط في نقطة فحص إذا تم حفظ المتغير والمحسن ، وبالتالي الحواف المتقطعة.

يؤدي استدعاء restore() على tf.train.Checkpoint قوائم انتظار عمليات الاستعادة المطلوبة ، واستعادة القيم المتغيرة بمجرد وجود مسار مطابق من كائن Checkpoint . على سبيل المثال ، يمكننا تحميل التحيز فقط من النموذج الذي حددناه أعلاه من خلال إعادة بناء مسار واحد إليه من خلال الشبكة والطبقة.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[3.5548685 2.8931093 2.3509905 3.5525272 4.017799 ]

الرسم البياني للتبعية لهذه الكائنات الجديدة هو رسم بياني فرعي أصغر بكثير لنقطة التفتيش الأكبر التي كتبناها أعلاه. يتضمن فقط التحيز وعداد الحفظ الذي يستخدمه tf.train.Checkpoint نقاط التفتيش.

تصور الرسم البياني الفرعي لمتغير التحيز

تُرجع restore() كائن الحالة ، الذي يحتوي على تأكيدات اختيارية. تمت استعادة جميع الكائنات التي أنشأناها في Checkpoint الجديد ، لذا فإن status.assert_existing_objects_matched() يمر.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fea0c3c3860>

هناك العديد من الكائنات في نقطة التحقق غير المتطابقة ، بما في ذلك نواة الطبقة ومتغيرات المحسن. status.assert_consumed() يمر فقط إذا كانت نقطة التفتيش والبرنامج status.assert_consumed() تمامًا ، وسوف status.assert_consumed() استثناء هنا.

الترميمات المتأخرة

قد تؤخر كائنات Layer في TensorFlow إنشاء متغيرات لاستدعائها الأول ، عندما تكون أشكال الإدخال متاحة. على سبيل المثال ، يعتمد شكل نواة الطبقة Dense على كل من أشكال الإدخال والإخراج للطبقة ، وبالتالي فإن شكل الإخراج المطلوب كوسيطة مُنشئ ليس معلومات كافية لإنشاء المتغير بمفرده. نظرًا لأن استدعاء Layer يقرأ أيضًا قيمة المتغير ، يجب أن تحدث استعادة بين إنشاء المتغير واستخدامه لأول مرة.

لدعم هذا المصطلح ، يستعيد tf.train.Checkpoint قوائم الانتظار التي لا تحتوي على متغير مطابق بعد.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.453001  4.6668463 4.9372597 4.90143   4.9549575]]

التفتيش اليدوي على نقاط التفتيش

يسرد tf.train.list_variables مفاتيح نقاط التحقق وأشكال المتغيرات في نقطة التحقق. مفاتيح نقاط التفتيش هي مسارات في الرسم البياني المعروض أعلاه.

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

قائمة وتتبع القاموس

كما هو الحال مع تعيينات السمات المباشرة مثل self.l1 = tf.keras.layers.Dense(5) ، فإن تخصيص القوائم والقواميس للسمات سيؤدي إلى تتبع محتوياتها.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

قد تلاحظ كائنات مجمعة للقوائم والقواميس. هذه الأغلفة هي إصدارات يمكن التحقق منها من هياكل البيانات الأساسية. تمامًا مثل التحميل المستند إلى السمة ، تستعيد هذه الأغلفة قيمة المتغير بمجرد إضافته إلى الحاوية.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

يتم تطبيق نفس التتبع تلقائيًا على الفئات الفرعية من tf.keras.Model ، ويمكن استخدامه على سبيل المثال لتتبع قوائم الطبقات.

حفظ نقاط التفتيش القائمة على الكائن باستخدام المقدر

انظر دليل المقدر .

يقوم المقدّرون افتراضيًا بحفظ نقاط التحقق بأسماء متغيرة بدلاً من الرسم البياني للكائن الموضح في الأقسام السابقة. tf.train.Checkpoint سوف يقبل نقاط التحقق المستندة إلى الاسم ، لكن أسماء المتغيرات قد تتغير عند نقل أجزاء من نموذج خارج model_fn الخاص model_fn . يسهل حفظ نقاط التحقق المستندة إلى الكائن من تدريب نموذج داخل مقدر ثم استخدامه خارج أحد النماذج.

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.388644, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 34.98601.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fea648fbf60>

tf.train.Checkpoint يمكن بعد ذلك تحميل نقاط التحقق الخاصة model_dir من model_dir الخاص model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

ملخص

توفر كائنات TensorFlow آلية تلقائية سهلة لحفظ واستعادة قيم المتغيرات التي تستخدمها.