หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

จุดตรวจการฝึกอบรม

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

วลี "การบันทึกแบบจำลอง TensorFlow" โดยทั่วไปหมายถึงหนึ่งในสองสิ่ง:

  1. จุดตรวจหรือ
  2. บันทึกไว้

จุดตรวจจับค่าที่แน่นอนของพารามิเตอร์ทั้งหมด ( tf.Variable objects) ที่โมเดลใช้ จุดตรวจไม่มีคำอธิบายใด ๆ ของการคำนวณที่กำหนดโดยโมเดลดังนั้นโดยทั่วไปจะมีประโยชน์ก็ต่อเมื่อมีซอร์สโค้ดที่จะใช้ค่าพารามิเตอร์ที่บันทึกไว้เท่านั้น

ในทางกลับกันรูปแบบที่บันทึกไว้มีคำอธิบายแบบอนุกรมของการคำนวณที่กำหนดโดยโมเดลนอกเหนือจากค่าพารามิเตอร์ (จุดตรวจสอบ) โมเดลในรูปแบบนี้ไม่ขึ้นอยู่กับซอร์สโค้ดที่สร้างโมเดล จึงเหมาะสำหรับการปรับใช้ผ่าน TensorFlow Serving, TensorFlow Lite, TensorFlow.js หรือโปรแกรมในภาษาโปรแกรมอื่น ๆ (C, C ++, Java, Go, Rust, C # เป็นต้น TensorFlow APIs)

คู่มือนี้ครอบคลุม API สำหรับการเขียนและการอ่านจุดตรวจ

ติดตั้ง

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

บันทึกจาก API การฝึกอบรม tf.keras

ดูคู่มือ tf.keras เกี่ยวกับการบันทึกและการกู้คืน

tf.keras.Model.save_weights บันทึกจุดตรวจ TensorFlow

net.save_weights('easy_checkpoint')

การเขียนจุดตรวจ

สถานะถาวรของโมเดล TensorFlow ถูกเก็บไว้ใน tf.Variable สิ่งเหล่านี้สามารถสร้างได้โดยตรง แต่มักสร้างผ่าน API ระดับสูงเช่น tf.keras.layers หรือ tf.keras.Model

วิธีที่ง่ายที่สุดในการจัดการตัวแปรคือการแนบเข้ากับวัตถุ Python จากนั้นอ้างอิงวัตถุเหล่านั้น

คลาสย่อยของ tf.train.Checkpoint , tf.keras.layers.Layer และ tf.keras.Model จะติดตามตัวแปรที่กำหนดให้กับแอ็ตทริบิวต์โดยอัตโนมัติ ตัวอย่างต่อไปนี้สร้างโมเดลเชิงเส้นอย่างง่ายจากนั้นเขียนจุดตรวจที่มีค่าสำหรับตัวแปรทั้งหมดของโมเดล

คุณสามารถบันทึกจุดตรวจโมเดลได้อย่างง่ายดายด้วย Model.save_weights

การตรวจด้วยตนเอง

ติดตั้ง

เพื่อช่วยสาธิตคุณสมบัติทั้งหมดของ tf.train.Checkpoint กำหนดชุดข้อมูลของเล่นและขั้นตอนการปรับให้เหมาะสม:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

สร้างวัตถุจุดตรวจ

ในการสร้างจุดตรวจด้วยตนเองคุณจะต้องมีวัตถุ tf.train.Checkpoint โดยที่วัตถุที่คุณต้องการตรวจสอบจะถูกตั้งค่าเป็นแอตทริบิวต์บนวัตถุ

tf.train.CheckpointManager ยังมีประโยชน์สำหรับการจัดการจุดตรวจหลายจุด

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

ฝึกและตรวจสอบโมเดล

ลูปการฝึกอบรมต่อไปนี้จะสร้างอินสแตนซ์ของแบบจำลองและของเครื่องมือเพิ่มประสิทธิภาพจากนั้นรวบรวมไว้ในวัตถุ tf.train.Checkpoint เรียกขั้นตอนการฝึกเป็นวงในแต่ละชุดข้อมูลและเขียนจุดตรวจลงในดิสก์เป็นระยะ

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 32.40
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 25.82
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 19.26
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 12.77
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 6.48

ฟื้นฟูและดำเนินการฝึกอบรมต่อไป

หลังจากครั้งแรกคุณสามารถผ่านโมเดลและผู้จัดการใหม่ได้ แต่การฝึกอบรมการรับสินค้าตรงจุดที่คุณทำค้างไว้:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.85
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.88
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.44
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.41
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.25

วัตถุ tf.train.CheckpointManager ลบจุดตรวจเก่า ด้านบนมีการกำหนดค่าให้เก็บเฉพาะสามด่านล่าสุด

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

เส้นทางเหล่านี้เช่น './tf_ckpts/ckpt-10' ไม่ใช่ไฟล์บนดิสก์ แต่เป็นคำนำหน้าสำหรับไฟล์ index และไฟล์ข้อมูลอย่างน้อยหนึ่งไฟล์ที่มีค่าตัวแปร คำนำหน้าเหล่านี้ถูกจัดกลุ่มเข้าด้วยกันในไฟล์ checkpoint เดียว ( './tf_ckpts/checkpoint' ) โดยที่ CheckpointManager บันทึกสถานะ

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

กำลังโหลดกลศาสตร์

TensorFlow จับคู่ตัวแปรกับค่าที่ตรวจสอบโดยการข้ามกราฟกำกับด้วยขอบที่มีชื่อโดยเริ่มจากวัตถุที่กำลังโหลด ชื่อ Edge มักมาจากชื่อแอตทริบิวต์ในออบเจ็กต์ตัวอย่างเช่น "l1" ใน self.l1 = tf.keras.layers.Dense(5) tf.train.Checkpoint ใช้ชื่ออาร์กิวเมนต์คำหลักเช่นเดียวกับใน "step" ใน tf.train.Checkpoint(step=...)

กราฟการอ้างอิงจากตัวอย่างด้านบนมีลักษณะดังนี้:

การแสดงกราฟการอ้างอิงสำหรับลูปการฝึกอบรมตัวอย่าง

ด้วยเครื่องมือเพิ่มประสิทธิภาพเป็นสีแดงตัวแปรปกติเป็นสีน้ำเงินและตัวแปรสล็อตเครื่องมือเพิ่มประสิทธิภาพเป็นสีส้ม โหนดอื่น ๆ เช่นแทน tf.train.Checkpoint เป็นสีดำ

ตัวแปรสล็อตเป็นส่วนหนึ่งของสถานะของเครื่องมือเพิ่มประสิทธิภาพ แต่ถูกสร้างขึ้นสำหรับตัวแปรเฉพาะ ตัวอย่างเช่นขอบ 'm' ด้านบนตรงกับโมเมนตัมซึ่ง Adam Optimizer ติดตามตัวแปรแต่ละตัว ตัวแปรสล็อตจะถูกบันทึกเฉพาะในจุดตรวจสอบหากตัวแปรและเครื่องมือเพิ่มประสิทธิภาพจะถูกบันทึกทั้งคู่ดังนั้นขอบประ

การเรียก restore() บน tf.train.Checkpoint วัตถุ tf.train.Checkpoint จะจัดคิวการคืนค่าที่ร้องขอโดยเรียกคืนค่าตัวแปรทันทีที่มีเส้นทางที่ตรงกันจากวัตถุ Checkpoint ตัวอย่างเช่นคุณสามารถโหลดเฉพาะอคติจากแบบจำลองที่เรากำหนดไว้ข้างต้นโดยการสร้างเส้นทางขึ้นมาใหม่ผ่านเครือข่ายและเลเยอร์

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[3.4461102 3.030825  4.4315968 3.5077076 4.7258596]

กราฟการอ้างอิงสำหรับวัตถุใหม่เหล่านี้เป็นย่อหน้าย่อยที่เล็กกว่าของจุดตรวจขนาดใหญ่ที่เราเขียนไว้ข้างต้น มีเฉพาะอคติและตัวนับบันทึกที่ tf.train.Checkpoint ใช้เพื่อตรวจนับจำนวน

การแสดงภาพของกราฟย่อยสำหรับตัวแปร bias

restore() ส่งคืนออบเจ็กต์สถานะซึ่งมีการยืนยันเพิ่มเติม อ็อบเจ็กต์ทั้งหมดที่สร้างใน Checkpoint ใหม่ได้รับการกู้คืนแล้วดังนั้น status.assert_existing_objects_matched() ผ่านไป

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fec144bd080>

มีออบเจ็กต์มากมายในจุดตรวจที่ยังไม่ตรงกันรวมถึงเคอร์เนลของเลเยอร์และตัวแปรของเครื่องมือเพิ่มประสิทธิภาพ status.assert_consumed() จะผ่านก็ต่อเมื่อจุดตรวจและโปรแกรมตรงกันทั้งหมดและจะมีข้อยกเว้นที่นี่

การบูรณะล่าช้า

Layer อบเจ็กต์ใน TensorFlow อาจทำให้การสร้างตัวแปรในการเรียกครั้งแรกล่าช้าเมื่อมีรูปทรงอินพุต ตัวอย่างเช่นรูปร่างของเคอร์เนลของเลเยอร์ Dense ขึ้นอยู่กับทั้งอินพุตและรูปร่างเอาต์พุตของเลเยอร์ดังนั้นรูปร่างเอาต์พุตที่ต้องการเป็นอาร์กิวเมนต์ตัวสร้างจึงมีข้อมูลไม่เพียงพอที่จะสร้างตัวแปรด้วยตัวเอง เนื่องจากการเรียก Layer จะอ่านค่าของตัวแปรการคืนค่าจะต้องเกิดขึ้นระหว่างการสร้างตัวแปรและการใช้งานครั้งแรก

เพื่อรองรับสำนวนนี้ tf.train.Checkpoint จะคืนค่าซึ่งยังไม่มีตัวแปรที่ตรงกัน

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.4598393 4.677273  4.655946  4.926899  4.79748  ]]

ตรวจสอบจุดตรวจด้วยตนเอง

tf.train.list_variables แสดงรายการคีย์จุดตรวจและรูปร่างของตัวแปรในจุดตรวจ ปุ่มจุดตรวจคือเส้นทางในกราฟที่แสดงด้านบน

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

การติดตามรายการและพจนานุกรม

เช่นเดียวกับการกำหนดคุณสมบัติโดยตรงเช่น self.l1 = tf.keras.layers.Dense(5) กำหนดรายการและพจนานุกรมให้กับแอตทริบิวต์จะติดตามเนื้อหา

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

คุณอาจสังเกตเห็นวัตถุ Wrapper สำหรับรายการและพจนานุกรม Wrapper เหล่านี้เป็นเวอร์ชันที่สามารถตรวจสอบได้ของโครงสร้างข้อมูลพื้นฐาน เช่นเดียวกับการโหลดตามแอตทริบิวต์ Wrapper เหล่านี้จะกู้คืนค่าของตัวแปรทันทีที่เพิ่มลงในคอนเทนเนอร์

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

การติดตามเดียวกันจะนำไปใช้กับคลาสย่อยของ tf.keras.Model โดยอัตโนมัติและอาจใช้เป็นตัวอย่างเพื่อติดตามรายการของเลเยอร์

สรุป

วัตถุ TensorFlow เป็นกลไกอัตโนมัติที่ใช้งานง่ายสำหรับการบันทึกและเรียกคืนค่าของตัวแปรที่ใช้