เข้าร่วม Women in ML Symposium ในวันที่ 7 ธันวาคม ลงทะเบียนตอนนี้

ด่านฝึก

จัดทุกอย่างให้เป็นระเบียบอยู่เสมอด้วยคอลเล็กชัน บันทึกและจัดหมวดหมู่เนื้อหาตามค่ากำหนดของคุณ

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

วลี "การบันทึกโมเดล TensorFlow" โดยทั่วไปหมายถึงหนึ่งในสองสิ่ง:

  1. จุดตรวจ OR
  2. โมเดลที่บันทึกไว้

จุดตรวจจับค่าที่แน่นอนของพารามิเตอร์ทั้งหมด ( tf.Variable อบเจ็กต์) ที่ใช้โดยโมเดล จุดตรวจไม่มีคำอธิบายใดๆ ของการคำนวณที่กำหนดโดยโมเดล และโดยทั่วไปแล้วจะมีประโยชน์ก็ต่อเมื่อมีซอร์สโค้ดที่จะใช้ค่าพารามิเตอร์ที่บันทึกไว้เท่านั้น

ในทางกลับกัน รูปแบบ SavedModel รวมคำอธิบายต่อเนื่องของการคำนวณที่กำหนดโดยโมเดล นอกเหนือจากค่าพารามิเตอร์ (จุดตรวจสอบ) โมเดลในรูปแบบนี้ไม่ขึ้นกับซอร์สโค้ดที่สร้างโมเดล ดังนั้นจึงเหมาะสำหรับการปรับใช้ผ่าน TensorFlow Serving, TensorFlow Lite, TensorFlow.js หรือโปรแกรมในภาษาการเขียนโปรแกรมอื่นๆ (C, C++, Java, Go, Rust, C# เป็นต้น TensorFlow API)

คู่มือนี้ครอบคลุม API สำหรับการเขียนและการอ่านจุดตรวจ

ติดตั้ง

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

กำลังบันทึกจาก API การฝึกอบรม tf.keras

ดูคู่มือ tf.keras เกี่ยวกับการบันทึกและการกู้คืน

tf.keras.Model.save_weights บันทึกจุดตรวจ TensorFlow

net.save_weights('easy_checkpoint')

เขียนด่าน

สถานะคงอยู่ของโมเดล TensorFlow ถูกเก็บไว้ใน tf.Variable สิ่งเหล่านี้สามารถสร้างได้โดยตรง แต่มักจะสร้างผ่าน API ระดับสูง เช่น tf.keras.layers หรือ tf.keras.Model

วิธีที่ง่ายที่สุดในการจัดการตัวแปรคือการแนบไปกับอ็อบเจ็กต์ Python แล้วอ้างอิงอ็อบเจ็กต์เหล่านั้น

คลาสย่อยของ tf.train.Checkpoint , tf.keras.layers.Layer และ tf.keras.Model ติดตามตัวแปรที่กำหนดให้กับแอตทริบิวต์โดยอัตโนมัติ ตัวอย่างต่อไปนี้สร้างโมเดลเชิงเส้นอย่างง่าย จากนั้นจึงเขียนจุดตรวจสอบซึ่งมีค่าสำหรับตัวแปรทั้งหมดของโมเดล

คุณสามารถบันทึก model-checkpoint ได้อย่างง่ายดายด้วย Model.save_weights

จุดตรวจด้วยตนเอง

ติดตั้ง

เพื่อช่วยสาธิตคุณลักษณะทั้งหมดของ tf.train.Checkpoint ให้กำหนดชุดข้อมูลของเล่นและขั้นตอนการเพิ่มประสิทธิภาพ:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

สร้างวัตถุจุดตรวจ

ใช้อ็อบเจ็กต์ tf.train.Checkpoint เพื่อสร้างจุดตรวจด้วยตนเอง โดยที่ออบเจ็กต์ที่คุณต้องการตรวจสอบถูกตั้งค่าเป็นแอตทริบิวต์บนวัตถุ

tf.train.CheckpointManager ยังมีประโยชน์สำหรับการจัดการจุดตรวจหลายจุด

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

ฝึกและจุดตรวจโมเดล

ลูปการฝึกต่อไปนี้จะสร้างอินสแตนซ์ของโมเดลและเครื่องมือเพิ่มประสิทธิภาพ จากนั้นรวบรวมเป็นอ็อบเจ็กต์ tf.train.Checkpoint โดยเรียกขั้นตอนการฝึกวนเป็นชุดของข้อมูลแต่ละชุด และเขียนจุดตรวจสอบลงดิสก์เป็นระยะ

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

ฟื้นฟูและฝึกต่อ

หลังจากรอบการฝึกครั้งแรก คุณสามารถส่งต่อโมเดลและผู้จัดการคนใหม่ได้ แต่รับการฝึกอบรมตรงจุดที่คุณค้างไว้:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

ออบเจ็กต์ tf.train.CheckpointManager ลบจุดตรวจเก่า ด้านบนมีการกำหนดค่าให้เก็บเฉพาะจุดตรวจล่าสุดสามจุดเท่านั้น

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

เส้นทางเหล่านี้ เช่น './tf_ckpts/ckpt-10' ไม่ใช่ไฟล์บนดิสก์ แต่เป็นคำนำหน้าสำหรับไฟล์ index และไฟล์ข้อมูลอย่างน้อยหนึ่งไฟล์ที่มีค่าตัวแปร คำนำหน้าเหล่านี้จัดกลุ่มไว้ด้วยกันในไฟล์ checkpoint เดียว ( './tf_ckpts/checkpoint' ) โดยที่ CheckpointManager จะบันทึกสถานะ

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

กำลังโหลดกลศาสตร์

TensorFlow จะจับคู่ตัวแปรกับค่าจุดตรวจสอบโดยการสำรวจกราฟที่มีชื่อขอบ โดยเริ่มจากวัตถุที่กำลังโหลด ชื่อขอบมักจะมาจากชื่อแอตทริบิวต์ในวัตถุ เช่น "l1" ใน self.l1 = tf.keras.layers.Dense(5) tf.train.Checkpoint ใช้ชื่ออาร์กิวเมนต์ของคีย์เวิร์ด เช่นเดียวกับใน "step" ใน tf.train.Checkpoint(step=...)

กราฟการพึ่งพาจากตัวอย่างด้านบนมีลักษณะดังนี้:

การแสดงภาพกราฟการพึ่งพาสำหรับลูปการฝึกตัวอย่าง

เครื่องมือเพิ่มประสิทธิภาพจะเป็นสีแดง ตัวแปรปกติจะเป็นสีน้ำเงิน และตัวแปรช่องเครื่องมือเพิ่มประสิทธิภาพจะเป็นสีส้ม โหนดอื่นๆ เช่น แทน tf.train.Checkpoint เป็นสีดำ

ตัวแปรสล็อตเป็นส่วนหนึ่งของสถานะของตัวเพิ่มประสิทธิภาพ แต่ถูกสร้างขึ้นสำหรับตัวแปรเฉพาะ ตัวอย่างเช่น ขอบ 'm' ด้านบนสอดคล้องกับโมเมนตัม ซึ่งเครื่องมือเพิ่มประสิทธิภาพ Adam ติดตามสำหรับแต่ละตัวแปร ตัวแปรสล็อตจะถูกบันทึกในจุดตรวจสอบหากทั้งตัวแปรและตัวเพิ่มประสิทธิภาพจะถูกบันทึก ดังนั้นขอบที่เป็นเส้นประ

การเรียกการ restore บนอ็อบเจ็กต์ tf.train.Checkpoint จะเข้าคิวการคืนค่าที่ร้องขอ เรียกคืนค่าตัวแปรทันทีที่มีพาธที่ตรงกันจากออบเจ็กต์ Checkpoint ตัวอย่างเช่น คุณสามารถโหลดเฉพาะความเอนเอียงจากโมเดลที่คุณกำหนดไว้ด้านบนโดยสร้างเส้นทางเดียวไปยังโมเดลดังกล่าวผ่านเครือข่ายและเลเยอร์

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

กราฟการพึ่งพาสำหรับออบเจ็กต์ใหม่เหล่านี้เป็นกราฟย่อยที่เล็กกว่ามากของจุดตรวจที่ใหญ่กว่าที่คุณเขียนไว้ด้านบน รวมเฉพาะอคติและตัวนับบันทึกที่ tf.train.Checkpoint ใช้เพื่อนับจุดตรวจ

การแสดงภาพกราฟย่อยสำหรับตัวแปรอคติ

restore ส่งคืนวัตถุสถานะซึ่งมีการยืนยันที่เป็นทางเลือก ออบเจ็กต์ทั้งหมดที่สร้างขึ้นใน Checkpoint ใหม่ได้รับการฟื้นฟูแล้ว ดังนั้น status.assert_existing_objects_matched ผ่าน

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

มีอ็อบเจ็กต์จำนวนมากในจุดตรวจที่ไม่ตรงกัน รวมถึงเคอร์เนลของเลเยอร์และตัวแปรของออปติไมเซอร์ status.assert_consumed ผ่านก็ต่อเมื่อจุดตรวจและโปรแกรมตรงกันทุกประการ และจะโยนข้อยกเว้นที่นี่

การบูรณะที่รอการตัดบัญชี

Layer อ็อบเจ็กต์ใน TensorFlow อาจเลื่อนการสร้างตัวแปรไปเป็นการโทรครั้งแรก เมื่อรูปร่างอินพุตพร้อมใช้งาน ตัวอย่างเช่น รูปร่างของเคอร์เนลของเลเยอร์ Dense ขึ้นอยู่กับทั้งรูปร่างอินพุตและเอาต์พุตของเลเยอร์ ดังนั้นรูปร่างเอาต์พุตที่ต้องการในฐานะอาร์กิวเมนต์ตัวสร้างจึงไม่มีข้อมูลเพียงพอที่จะสร้างตัวแปรได้ด้วยตัวเอง เนื่องจากการเรียกใช้ Layer ยังอ่านค่าของตัวแปรด้วย การคืนค่าจะต้องเกิดขึ้นระหว่างการสร้างตัวแปรและการใช้งานครั้งแรก

เพื่อรองรับสำนวนนี้ tf.train.Checkpoint defers restores ซึ่งยังไม่มีตัวแปรที่ตรงกัน

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]
ตัวยึดตำแหน่ง22

ตรวจสอบจุดตรวจด้วยตนเอง

tf.train.load_checkpoint ส่งคืน CheckpointReader ที่ให้การเข้าถึงเนื้อหาด่านที่ต่ำกว่า ประกอบด้วยการแมปจากคีย์ของตัวแปรแต่ละตัว ไปยังรูปร่างและ dtype ของตัวแปรแต่ละตัวในจุดตรวจสอบ คีย์ของตัวแปรคือเส้นทางของออบเจ็กต์ เช่นเดียวกับในกราฟที่แสดงด้านบน

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

ดังนั้น หากคุณสนใจในมูลค่าของ net.l1.kernel คุณสามารถรับค่าโดยใช้รหัสต่อไปนี้:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

นอกจากนี้ยังมีวิธี get_tensor ที่ให้คุณตรวจสอบค่าของตัวแปรได้:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

การติดตามวัตถุ

จุดตรวจสอบจะบันทึกและกู้คืนค่าของ tf.Variable โดย "ติดตาม" ตัวแปรหรืออ็อบเจ็กต์ที่ติดตามได้ซึ่งตั้งค่าไว้ในแอตทริบิวต์ใดแอตทริบิวต์หนึ่ง เมื่อดำเนินการบันทึก ตัวแปรจะถูกรวบรวมแบบเรียกซ้ำจากออบเจ็กต์ที่ติดตามที่เข้าถึงได้ทั้งหมด

เช่นเดียวกับการกำหนดแอตทริบิวต์โดยตรง เช่น self.l1 = tf.keras.layers.Dense(5) การกำหนดรายการและพจนานุกรมให้กับแอตทริบิวต์จะติดตามเนื้อหา

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

คุณอาจสังเกตเห็นออบเจ็กต์ของแรปเปอร์สำหรับรายการและพจนานุกรม Wrapper เหล่านี้เป็นเวอร์ชันที่ตรวจสอบได้ของโครงสร้างข้อมูลพื้นฐาน เช่นเดียวกับการโหลดตามแอตทริบิวต์ Wrapper เหล่านี้จะคืนค่าของตัวแปรทันทีที่เพิ่มลงในคอนเทนเนอร์

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

ออบเจ็กต์ที่ติดตามได้รวมถึง tf.train.Checkpoint , tf.Module และคลาสย่อย (เช่น keras.layers.Layer และ keras.Model ) และคอนเทนเนอร์ Python ที่รู้จัก:

  • dict (และ collections.OrderedDict )
  • list
  • tuple (และ collections.namedtuple typing.NamedTuple )

ไม่รองรับ คอนเทนเนอร์ประเภทอื่นๆ ซึ่งรวมถึง:

  • collections.defaultdict
  • set

วัตถุ Python อื่น ๆ ทั้งหมดจะ ถูกละเว้น รวมถึง:

  • int
  • string
  • float

สรุป

ออบเจ็กต์ TensorFlow ให้กลไกอัตโนมัติที่ง่ายดายสำหรับการบันทึกและกู้คืนค่าของตัวแปรที่ใช้