ความสามารถในการบันทึกและกู้คืนสถานะของแบบจำลองมีความสำคัญต่อแอปพลิเคชันจำนวนหนึ่ง เช่น ในการถ่ายโอนการเรียนรู้ หรือสำหรับการอนุมานโดยใช้แบบจำลองที่ได้รับการฝึกอบรมมาแล้ว การบันทึกพารามิเตอร์ของโมเดล (น้ำหนัก อคติ ฯลฯ) ในไฟล์จุดตรวจสอบหรือไดเร็กทอรีเป็นวิธีหนึ่งในการบรรลุเป้าหมายนี้
โมดูลนี้มีอินเทอร์เฟซระดับสูงสำหรับการโหลดและบันทึกจุดตรวจสอบ รูปแบบ TensorFlow v2 รวมถึงส่วนประกอบระดับล่างที่เขียนและอ่านจากรูปแบบไฟล์นี้
กำลังโหลดและบันทึกโมเดลอย่างง่าย
ด้วยการปฏิบัติตามโปรโตคอล Checkpointable
โมเดลง่ายๆ จำนวนมากจึงสามารถซีเรียลไลซ์ไปยังจุดตรวจสอบได้โดยไม่ต้องใช้รหัสเพิ่มเติม:
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
จากนั้นจุดตรวจสอบเดียวกันนั้นสามารถอ่านได้โดยใช้:
try model.readCheckpoint(from: directory, name: "LeNet")
การใช้งานเริ่มต้นสำหรับการโหลดและบันทึกโมเดลนี้จะใช้โครงร่างการตั้งชื่อตามเส้นทางสำหรับแต่ละเทนเซอร์ในโมเดลที่ขึ้นอยู่กับชื่อของคุณสมบัติภายในโครงสร้างของโมเดล ตัวอย่างเช่น น้ำหนักและอคติภายในการหมุนครั้งแรกใน โมเดล LeNet-5 จะถูกบันทึกด้วยชื่อ conv1/filter
และ conv1/bias
ตามลำดับ เมื่อโหลด เครื่องอ่านจุดตรวจจะค้นหาเทนเซอร์ด้วยชื่อเหล่านี้
การปรับแต่งการโหลดและการบันทึกโมเดล
หากคุณต้องการควบคุมเทนเซอร์ที่จะบันทึกและโหลดได้มากขึ้น หรือการตั้งชื่อเทนเซอร์เหล่านั้น โปรโตคอล Checkpointable
เสนอการปรับแต่งสองสามจุด
หากต้องการละเว้นคุณสมบัติของบางประเภท คุณสามารถจัดเตรียมการใช้งานของ ignoredTensorPaths
บนโมเดลของคุณที่ส่งคืนชุดสตริงในรูปแบบของ Type.property
ตัวอย่างเช่น หากต้องการละเว้นคุณสมบัติ scale
ในทุกเลเยอร์ Attention คุณสามารถส่งคืน ["Attention.scale"]
ได้
ตามค่าเริ่มต้น เครื่องหมายทับจะใช้เพื่อแยกแต่ละระดับที่ลึกกว่าในแบบจำลอง ซึ่งสามารถปรับแต่งได้โดยใช้ checkpointSeparator
กับโมเดลของคุณและระบุสตริงใหม่เพื่อใช้สำหรับตัวคั่นนี้
สุดท้ายนี้ เพื่อการปรับแต่งการตั้งชื่อเทนเซอร์ในระดับสูงสุด คุณสามารถใช้ tensorNameMap
และจัดเตรียมฟังก์ชันที่จับคู่จากชื่อสตริงเริ่มต้นที่สร้างขึ้นสำหรับเทนเซอร์ในโมเดลกับชื่อสตริงที่ต้องการในจุดตรวจสอบ โดยทั่วไป สิ่งนี้จะถูกใช้เพื่อทำงานร่วมกับจุดตรวจสอบที่สร้างด้วยเฟรมเวิร์กอื่นๆ ซึ่งแต่ละจุดมีแบบแผนการตั้งชื่อและโครงสร้างแบบจำลองของตัวเอง ฟังก์ชันการแมปแบบกำหนดเองช่วยให้ปรับแต่งวิธีการตั้งชื่อเทนเซอร์เหล่านี้ได้ในระดับสูงสุด
มีฟังก์ชันตัวช่วยมาตรฐานบางอย่างมาให้ เช่น CheckpointWriter.identityMap
เริ่มต้น (ซึ่งใช้ชื่อพาธเทนเซอร์ที่สร้างขึ้นโดยอัตโนมัติสำหรับจุดตรวจสอบ) หรือฟังก์ชัน CheckpointWriter.lookupMap(table:)
ซึ่งสามารถสร้างการแมปจากพจนานุกรมได้
สำหรับตัวอย่างวิธีการแมปแบบกำหนดเองให้สำเร็จ โปรดดู โมเดล GPT-2 ซึ่งใช้ฟังก์ชันการแมปเพื่อให้ตรงกับรูปแบบการตั้งชื่อที่ใช้สำหรับจุดตรวจสอบของ OpenAI
ส่วนประกอบ CheckpointReader และ CheckpointWriter
สำหรับการเขียนจุดตรวจ ส่วนขยายที่จัดทำโดยโปรโตคอล Checkpointable
จะใช้การสะท้อนและเส้นทางคีย์เพื่อวนซ้ำคุณสมบัติของแบบจำลอง และสร้างพจนานุกรมที่แมปเส้นทางสตริงเทนเซอร์กับค่าเทนเซอร์ พจนานุกรมนี้จัดทำขึ้นสำหรับ CheckpointWriter
พื้นฐาน พร้อมด้วยไดเร็กทอรีสำหรับเขียนจุดตรวจสอบ CheckpointWriter
นั้นจัดการงานสร้างจุดตรวจสอบบนดิสก์จากพจนานุกรมนั้น
สิ่งที่ตรงกันข้ามของกระบวนการนี้คือการอ่าน โดยที่ CheckpointReader
จะได้รับตำแหน่งของไดเร็กทอรีจุดตรวจสอบบนดิสก์ จากนั้นจะอ่านจากจุดตรวจนั้นและสร้างพจนานุกรมที่จับคู่ชื่อของเทนเซอร์ภายในจุดตรวจด้วยค่าที่บันทึกไว้ พจนานุกรมนี้ใช้เพื่อแทนที่เทนเซอร์ปัจจุบันในแบบจำลองด้วยพจนานุกรมในพจนานุกรมนี้
สำหรับทั้งการโหลดและการบันทึก โปรโตคอล Checkpointable
จะแมปพาธสตริงกับเทนเซอร์กับชื่อเทนเซอร์บนดิสก์ที่สอดคล้องกัน โดยใช้ฟังก์ชันการแมปที่อธิบายไว้ข้างต้น
หากโปรโตคอล Checkpointable
ขาดฟังก์ชันการทำงานที่จำเป็น หรือต้องการการควบคุมเพิ่มเติมเกี่ยวกับกระบวนการโหลดและบันทึกจุด CheckpointReader
และ CheckpointWriter
สามารถใช้งานได้ด้วยตัวเอง
รูปแบบจุดตรวจสอบ TensorFlow v2
รูปแบบจุดตรวจสอบ TensorFlow v2 ดังที่อธิบายไว้สั้นๆ ใน ส่วนหัวนี้ คือรูปแบบรุ่นที่สองสำหรับจุดตรวจสอบโมเดล TensorFlow รูปแบบรุ่นที่สองนี้มีการใช้งานมาตั้งแต่ปลายปี 2016 และมีการปรับปรุงหลายประการจากรูปแบบจุดตรวจสอบ v1 TensorFlow SavedModels ใช้จุดตรวจสอบ v2 ภายในเพื่อบันทึกพารามิเตอร์โมเดล
จุดตรวจสอบ TensorFlow v2 ประกอบด้วยไดเร็กทอรีที่มีโครงสร้างดังนี้:
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
โดยที่ไฟล์แรกเก็บข้อมูลเมตาสำหรับจุดตรวจสอบและไฟล์ที่เหลือเป็นไบนารีชาร์ดที่เก็บพารามิเตอร์ซีเรียลไลซ์สำหรับโมเดล
ไฟล์ข้อมูลเมตาดัชนีประกอบด้วยประเภท ขนาด ตำแหน่ง และชื่อสตริงของเทนเซอร์ที่ต่อเนื่องกันทั้งหมดที่มีอยู่ในชาร์ด ไฟล์ดัชนีนั้นเป็นส่วนที่มีโครงสร้างซับซ้อนที่สุดของจุดตรวจสอบ และอิงตาม tensorflow::table
ซึ่งตัวมันเองอิงตาม SSTable / LevelDB ไฟล์ดัชนีนี้ประกอบด้วยชุดของคู่คีย์-ค่า โดยที่คีย์เป็นสตริง และค่าเป็นบัฟเฟอร์โปรโตคอล สตริงจะถูกจัดเรียงและบีบอัดคำนำหน้า ตัวอย่างเช่น หากรายการแรกคือ conv1/weight
และ conv1/bias
ถัดไป รายการที่สองจะใช้เฉพาะส่วน bias
เท่านั้น
ไฟล์ดัชนีโดยรวมนี้บางครั้งถูกบีบอัดโดยใช้ Snappy Compression ไฟล์ SnappyDecompression.swift
นำเสนอการใช้งาน Swift แบบเนทีฟของการบีบอัด Snappy จากอินสแตนซ์ข้อมูลที่บีบอัด
ข้อมูลเมตาของส่วนหัวดัชนีและข้อมูลเมตาของเทนเซอร์ได้รับการเข้ารหัสเป็นบัฟเฟอร์โปรโตคอลและเข้ารหัส / ถอดรหัสโดยตรงผ่าน Swift Protobuf
คลาส CheckpointIndexReader
และ CheckpointIndexWriter
จัดการการโหลดและบันทึกไฟล์ดัชนีเหล่านี้โดยเป็นส่วนหนึ่งของคลาส CheckpointReader
และ CheckpointWriter
ที่ครอบคลุม อย่างหลังใช้ไฟล์ดัชนีเป็นพื้นฐานในการกำหนดว่าจะอ่านและเขียนอะไรลงในไบนารี่ชาร์ดที่มีโครงสร้างง่ายกว่าซึ่งมีข้อมูลเทนเซอร์