ML Community Day คือวันที่ 9 พฤศจิกายน! ร่วมกับเราสำหรับการปรับปรุงจาก TensorFlow, JAX และอื่น ๆ เรียนรู้เพิ่มเติม

ด่านฝึก

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

วลี "การบันทึกโมเดล TensorFlow" โดยทั่วไปหมายถึงหนึ่งในสองสิ่ง:

  1. จุดตรวจ OR
  2. โมเดลที่บันทึกไว้

จุดตรวจจับค่าที่แน่นอนของพารามิเตอร์ทั้งหมด ( tf.Variable วัตถุ) โดยใช้รูปแบบ จุดตรวจไม่มีคำอธิบายใดๆ ของการคำนวณที่กำหนดโดยโมเดล และโดยทั่วไปแล้วจะมีประโยชน์ก็ต่อเมื่อมีซอร์สโค้ดที่จะใช้ค่าพารามิเตอร์ที่บันทึกไว้เท่านั้น

ในทางกลับกัน รูปแบบ SavedModel มีคำอธิบายต่อเนื่องของการคำนวณที่กำหนดโดยโมเดล นอกเหนือจากค่าพารามิเตอร์ (จุดตรวจสอบ) โมเดลในรูปแบบนี้ไม่ขึ้นกับซอร์สโค้ดที่สร้างโมเดล ดังนั้นจึงเหมาะสำหรับการปรับใช้ผ่าน TensorFlow Serving, TensorFlow Lite, TensorFlow.js หรือโปรแกรมในภาษาการเขียนโปรแกรมอื่นๆ (C, C++, Java, Go, Rust, C# เป็นต้น TensorFlow API)

คู่มือนี้ครอบคลุมถึง API สำหรับการเขียนและการอ่านจุดตรวจ

ติดตั้ง

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

ออมทรัพย์จาก tf.keras API ของการฝึกอบรม

ดู tf.keras ให้คำแนะนำเกี่ยวกับการบันทึกและเรียกคืน

tf.keras.Model.save_weights บันทึกด่าน TensorFlow

net.save_weights('easy_checkpoint')

การเขียนจุดตรวจ

รัฐถาวรของรุ่น TensorFlow จะถูกจัดเก็บใน tf.Variable วัตถุ เหล่านี้สามารถสร้างโดยตรง แต่มักจะถูกสร้างขึ้นผ่าน API ที่ระดับสูงเช่น tf.keras.layers หรือ tf.keras.Model

วิธีที่ง่ายที่สุดในการจัดการตัวแปรคือการแนบไปกับอ็อบเจ็กต์ Python แล้วอ้างอิงอ็อบเจ็กต์เหล่านั้น

subclasses ของ tf.train.Checkpoint , tf.keras.layers.Layer และ tf.keras.Model โดยอัตโนมัติติดตามตัวแปรที่ได้รับมอบหมายให้คุณลักษณะของพวกเขา ตัวอย่างต่อไปนี้สร้างโมเดลเชิงเส้นอย่างง่าย จากนั้นจึงเขียนจุดตรวจสอบซึ่งมีค่าสำหรับตัวแปรทั้งหมดของโมเดล

คุณสามารถบันทึกแบบจำลองด่านกับ Model.save_weights

จุดตรวจด้วยตนเอง

ติดตั้ง

เพื่อช่วยในการแสดงให้เห็นถึงคุณลักษณะทั้งหมดของ tf.train.Checkpoint กำหนดชุดข้อมูลของเล่นและขั้นตอนการเพิ่มประสิทธิภาพ:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

สร้างวัตถุจุดตรวจ

ใช้ tf.train.Checkpoint วัตถุด้วยตนเองสร้างด่านที่วัตถุที่คุณต้องการที่จะมีการตั้งด่านเป็นแอตทริบิวต์บนวัตถุ

tf.train.CheckpointManager ยังสามารถเป็นประโยชน์สำหรับการจัดการหลายด่าน

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

ฝึกและจุดตรวจโมเดล

ห่วงการฝึกอบรมต่อไปนี้สร้างอินสแตนซ์ของรูปแบบและเพิ่มประสิทธิภาพแล้วรวบรวมไว้ใน tf.train.Checkpoint วัตถุ โดยจะเรียกขั้นตอนการฝึกวนเป็นชุดของข้อมูลแต่ละชุด และเขียนจุดตรวจสอบลงดิสก์เป็นระยะ

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 29.77
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.18
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 16.62
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.16
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.09

ฟื้นฟูและฝึกต่อ

หลังจากรอบการฝึกครั้งแรก คุณสามารถส่งต่อโมเดลและผู้จัดการคนใหม่ได้ แต่รับการฝึกอบรมตรงจุดที่คุณค้างไว้:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.33
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.90
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.62
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.27
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.22

tf.train.CheckpointManager วัตถุลบด่านเก่า ด้านบนมีการกำหนดค่าให้เก็บเฉพาะจุดตรวจล่าสุดสามจุดเท่านั้น

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

เส้นทางเหล่านี้เช่น './tf_ckpts/ckpt-10' จะไม่ได้ไฟล์บนดิสก์ แต่พวกเขามีคำนำหน้าสำหรับ index ไฟล์และหนึ่งหรือไฟล์ข้อมูลมากขึ้นซึ่งประกอบด้วยค่าตัวแปร คำนำหน้าเหล่านี้จะรวมกลุ่มกันในครั้งเดียว checkpoint ไฟล์ ( './tf_ckpts/checkpoint' ) ที่ CheckpointManager บันทึกของรัฐ

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

กำลังโหลดกลศาสตร์

TensorFlow จะจับคู่ตัวแปรกับค่าจุดตรวจสอบโดยการสำรวจกราฟที่มีชื่อขอบ โดยเริ่มจากวัตถุที่กำลังโหลด ชื่อขอบมักจะมาจากชื่อแอตทริบิวต์ในวัตถุตัวอย่างเช่น "l1" ใน self.l1 = tf.keras.layers.Dense(5) tf.train.Checkpoint ใช้ชื่อโต้แย้งคำหลักเช่นเดียวกับใน "step" ใน tf.train.Checkpoint(step=...) )

กราฟการพึ่งพาจากตัวอย่างด้านบนมีลักษณะดังนี้:

การแสดงกราฟการพึ่งพาสำหรับตัวอย่างลูปการฝึก

เครื่องมือเพิ่มประสิทธิภาพจะเป็นสีแดง ตัวแปรปกติจะเป็นสีน้ำเงิน และตัวแปรช่องเครื่องมือเพิ่มประสิทธิภาพจะเป็นสีส้ม โหนดสำหรับอื่น ๆ ตัวอย่างเช่นตัวแทน tf.train.Checkpoint สรรพในสีดำ

ตัวแปรสล็อตเป็นส่วนหนึ่งของสถานะของตัวเพิ่มประสิทธิภาพ แต่ถูกสร้างขึ้นสำหรับตัวแปรเฉพาะ ตัวอย่างเช่น 'm' ขอบข้างต้นสอดคล้องกับโมเมนตัมที่แทร็คอดัมเพิ่มประสิทธิภาพสำหรับแต่ละตัวแปร ตัวแปรสล็อตจะถูกบันทึกในจุดตรวจสอบหากทั้งตัวแปรและตัวเพิ่มประสิทธิภาพจะถูกบันทึก ดังนั้นขอบที่เป็นเส้นประ

โทร restore ใน tf.train.Checkpoint คิววัตถุบูรณะขอคืนค่าตัวแปรเร็วที่สุดเท่าที่มีการจับคู่เส้นทางจาก Checkpoint วัตถุ ตัวอย่างเช่น คุณสามารถโหลดเฉพาะความเอนเอียงจากโมเดลที่คุณกำหนดไว้ด้านบนโดยสร้างเส้นทางเดียวไปยังโมเดลดังกล่าวผ่านเครือข่ายและเลเยอร์

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[1.9851578 3.6375327 2.9331083 3.8130412 4.778274 ]

กราฟการพึ่งพาสำหรับออบเจ็กต์ใหม่เหล่านี้เป็นกราฟย่อยที่เล็กกว่ามากของจุดตรวจที่ใหญ่กว่าที่คุณเขียนไว้ด้านบน มันมีเพียงอคติและประหยัดนับที่ tf.train.Checkpoint ใช้ในการจุดตรวจจำนวน

การสร้างภาพกราฟย่อยสำหรับตัวแปรอคติ

restore ผลตอบแทนวัตถุสถานะซึ่งมีการยืนยันตัวเลือก ทั้งหมดของวัตถุที่สร้างขึ้นใหม่ใน Checkpoint ได้รับการบูรณะเพื่อ status.assert_existing_objects_matched ผ่าน

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f027807e050>

มีอ็อบเจ็กต์จำนวนมากในจุดตรวจที่ไม่ตรงกัน รวมถึงเคอร์เนลของเลเยอร์และตัวแปรของออปติไมเซอร์ status.assert_consumed เพียงผ่านหากด่านและโปรแกรมที่ตรงและจะโยนข้อยกเว้นที่นี่

การบูรณะล่าช้า

Layer วัตถุใน TensorFlow อาจล่าช้าการสร้างตัวแปรที่จะสายแรกของพวกเขาเมื่อรูปทรงการป้อนข้อมูลที่มีอยู่ ตัวอย่างเช่นรูปทรงของที่ Dense เคอร์เนลชั้นขึ้นอยู่กับทั้งสองชั้นของอินพุตและเอาต์พุตรูปร่างและรูปทรงเพื่อการส่งออกต้องเป็นอาร์กิวเมนต์คอนสตรัคไม่ได้เป็นข้อมูลเพียงพอที่จะสร้างตัวแปรในตัวเอง ตั้งแต่การเรียก Layer ยังอ่านค่าของตัวแปรที่เรียกคืนจะต้องเกิดขึ้นระหว่างการสร้างตัวแปรและใช้งานครั้งแรก

เพื่อสนับสนุนสำนวนนี้ tf.train.Checkpoint คิวคืนซึ่งยังไม่ได้ตัวแปรที่ตรงกัน

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.6800494 4.607369  4.8321466 4.816245  4.8435326]]

ตรวจสอบจุดตรวจด้วยตนเอง

tf.train.load_checkpoint ส่งกลับ CheckpointReader ที่ช่วยให้เข้าถึงระดับที่ต่ำกว่าเนื้อหาด่าน ประกอบด้วยการแมปจากคีย์ของตัวแปรแต่ละตัว ไปยังรูปร่างและ dtype ของตัวแปรแต่ละตัวในจุดตรวจสอบ คีย์ของตัวแปรคือเส้นทางของออบเจ็กต์ เช่นเดียวกับในกราฟที่แสดงด้านบน

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

ดังนั้นหากคุณกำลังสนใจในค่าของ net.l1.kernel คุณจะได้รับความคุ้มค่าที่มีรหัสต่อไปนี้:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

นอกจากนี้ยังมี get_tensor วิธีช่วยให้คุณสามารถตรวจสอบค่าของตัวแปรนี้:

reader.get_tensor(key)
array([[4.6800494, 4.607369 , 4.8321466, 4.816245 , 4.8435326]],
      dtype=float32)

การติดตามรายการและพจนานุกรม

เช่นเดียวกับที่ได้รับมอบหมายโดยตรงแอตทริบิวต์เช่น self.l1 = tf.keras.layers.Dense(5) กำหนดรายชื่อและพจนานุกรมแอตทริบิวต์จะติดตามเนื้อหาของพวกเขา

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

คุณอาจสังเกตเห็นออบเจ็กต์ของแรปเปอร์สำหรับรายการและพจนานุกรม Wrapper เหล่านี้เป็นเวอร์ชันที่ตรวจสอบได้ของโครงสร้างข้อมูลพื้นฐาน เช่นเดียวกับการโหลดตามแอตทริบิวต์ Wrapper เหล่านี้จะคืนค่าของตัวแปรทันทีที่เพิ่มลงในคอนเทนเนอร์

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

ติดตามเดียวกันถูกนำไปใช้โดยอัตโนมัติเพื่อ subclasses ของ tf.keras.Model และอาจจะใช้สำหรับตัวอย่างในการติดตามรายการของชั้น

สรุป

ออบเจ็กต์ TensorFlow ให้กลไกอัตโนมัติที่ง่ายสำหรับการบันทึกและกู้คืนค่าของตัวแปรที่ใช้