किसी मॉडल की स्थिति को सहेजने और पुनर्स्थापित करने की क्षमता कई अनुप्रयोगों के लिए महत्वपूर्ण है, जैसे ट्रांसफर लर्निंग में या पूर्व-प्रशिक्षित मॉडल का उपयोग करके अनुमान लगाने के लिए। किसी मॉडल के मापदंडों (वजन, पूर्वाग्रह, आदि) को चेकपॉइंट फ़ाइल या निर्देशिका में सहेजना इसे पूरा करने का एक तरीका है।
यह मॉड्यूल TensorFlow v2 प्रारूप चौकियों को लोड करने और सहेजने के लिए एक उच्च-स्तरीय इंटरफ़ेस प्रदान करता है, साथ ही निचले स्तर के घटक भी प्रदान करता है जो इस फ़ाइल प्रारूप में लिखते और पढ़ते हैं।
सरल मॉडल लोड करना और सहेजना
Checkpointable
प्रोटोकॉल के अनुरूप, कई सरल मॉडलों को बिना किसी अतिरिक्त कोड के चेकपॉइंट पर क्रमबद्ध किया जा सकता है:
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
और फिर उसी चेकपॉइंट का उपयोग करके पढ़ा जा सकता है:
try model.readCheckpoint(from: directory, name: "LeNet")
मॉडल लोडिंग और सेविंग के लिए यह डिफ़ॉल्ट कार्यान्वयन मॉडल में प्रत्येक टेंसर के लिए पथ-आधारित नामकरण योजना का उपयोग करेगा जो मॉडल संरचनाओं के भीतर गुणों के नाम पर आधारित है। उदाहरण के लिए, LeNet-5 मॉडल में पहले कनवल्शन के भीतर वजन और पूर्वाग्रह क्रमशः conv1/filter
और conv1/bias
नाम से सहेजे जाएंगे। लोड करते समय, चेकपॉइंट रीडर इन नामों वाले टेंसर की खोज करेगा।
मॉडल लोडिंग और सेविंग को अनुकूलित करना
यदि आप इस पर अधिक नियंत्रण रखना चाहते हैं कि कौन से टेंसर सहेजे और लोड किए गए हैं, या उन टेंसरों का नामकरण किया गया है, तो Checkpointable
प्रोटोकॉल अनुकूलन के कुछ बिंदु प्रदान करता है।
कुछ प्रकारों पर गुणों को अनदेखा करने के लिए, आप अपने मॉडल पर ignoredTensorPaths
का कार्यान्वयन प्रदान कर सकते हैं जो Type.property
के रूप में स्ट्रिंग्स का एक सेट लौटाता है। उदाहरण के लिए, प्रत्येक अटेंशन लेयर पर scale
प्रॉपर्टी को अनदेखा करने के लिए, आप ["Attention.scale"]
वापस कर सकते हैं।
डिफ़ॉल्ट रूप से, किसी मॉडल में प्रत्येक गहरे स्तर को अलग करने के लिए फ़ॉरवर्ड स्लैश का उपयोग किया जाता है। इसे आपके मॉडल पर checkpointSeparator
लागू करके और इस सेपरेटर के लिए उपयोग करने के लिए एक नई स्ट्रिंग प्रदान करके अनुकूलित किया जा सकता है।
अंत में, टेंसर नामकरण में अनुकूलन की सबसे बड़ी डिग्री के लिए, आप tensorNameMap
कार्यान्वित कर सकते हैं और एक फ़ंक्शन प्रदान कर सकते हैं जो मॉडल में टेंसर के लिए उत्पन्न डिफ़ॉल्ट स्ट्रिंग नाम से चेकपॉइंट में वांछित स्ट्रिंग नाम पर मैप करता है। आमतौर पर, इसका उपयोग अन्य रूपरेखाओं के साथ उत्पन्न चौकियों के साथ इंटरऑपरेट करने के लिए किया जाएगा, जिनमें से प्रत्येक की अपनी नामकरण परंपराएं और मॉडल संरचनाएं हैं। एक कस्टम मैपिंग फ़ंक्शन इन टेंसरों को कैसे नामित किया जाता है, इसके लिए अनुकूलन की सबसे बड़ी डिग्री देता है।
कुछ मानक सहायक फ़ंक्शन प्रदान किए जाते हैं, जैसे डिफ़ॉल्ट CheckpointWriter.identityMap
(जो चेकपॉइंट के लिए स्वचालित रूप से जेनरेट किए गए टेंसर पथ नाम का उपयोग करता है), या CheckpointWriter.lookupMap(table:)
फ़ंक्शन, जो एक शब्दकोश से मैपिंग बना सकता है।
कस्टम मैपिंग कैसे पूरी की जा सकती है, इसके उदाहरण के लिए, कृपया GPT-2 मॉडल देखें, जो OpenAI की चौकियों के लिए उपयोग की जाने वाली सटीक नामकरण योजना से मेल खाने के लिए मैपिंग फ़ंक्शन का उपयोग करता है।
चेकपॉइंटरीडर और चेकपॉइंटराइटर घटक
चेकपॉइंट लेखन के लिए, Checkpointable
प्रोटोकॉल द्वारा प्रदान किया गया एक्सटेंशन एक मॉडल के गुणों पर पुनरावृत्ति करने के लिए प्रतिबिंब और कीपाथ का उपयोग करता है और एक शब्दकोश उत्पन्न करता है जो स्ट्रिंग टेंसर पथ को टेंसर मानों पर मैप करता है। यह शब्दकोश एक अंतर्निहित CheckpointWriter
को एक निर्देशिका के साथ प्रदान किया जाता है जिसमें चेकपॉइंट लिखना होता है। वह CheckpointWriter
उस शब्दकोश से ऑन-डिस्क चेकपॉइंट उत्पन्न करने का कार्य संभालता है।
इस प्रक्रिया का उलटा भाग रीडिंग है, जहां CheckpointReader
ऑन-डिस्क चेकपॉइंट निर्देशिका का स्थान दिया जाता है। फिर यह उस चेकपॉइंट से पढ़ता है और एक शब्दकोश बनाता है जो चेकपॉइंट के भीतर टेंसरों के नामों को उनके सहेजे गए मानों के साथ मैप करता है। इस शब्दकोश का उपयोग किसी मॉडल में मौजूदा टेंसर को इस शब्दकोश के टेंसर से बदलने के लिए किया जाता है।
लोडिंग और सेविंग दोनों के लिए, Checkpointable
प्रोटोकॉल ऊपर वर्णित मैपिंग फ़ंक्शन का उपयोग करके टेंसर के स्ट्रिंग पथ को संबंधित ऑन-डिस्क टेंसर नामों पर मैप करता है।
यदि Checkpointable
प्रोटोकॉल में आवश्यक कार्यक्षमता का अभाव है, या चेकपॉइंट लोडिंग और सेविंग प्रक्रिया पर अधिक नियंत्रण वांछित है, तो CheckpointReader
और CheckpointWriter
कक्षाओं का उपयोग स्वयं किया जा सकता है।
TensorFlow v2 चेकपॉइंट प्रारूप
TensorFlow v2 चेकपॉइंट प्रारूप, जैसा कि इस हेडर में संक्षेप में वर्णित है, TensorFlow मॉडल चेकपॉइंट्स के लिए दूसरी पीढ़ी का प्रारूप है। यह दूसरी पीढ़ी का प्रारूप 2016 के अंत से उपयोग में है, और इसमें v1 चेकपॉइंट प्रारूप में कई सुधार हैं। मॉडल मापदंडों को सहेजने के लिए TensorFlow SavedModels अपने भीतर v2 चेकपॉइंट का उपयोग करते हैं।
TensorFlow v2 चेकपॉइंट में निम्न जैसी संरचना वाली एक निर्देशिका होती है:
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
जहां पहली फ़ाइल चेकपॉइंट के लिए मेटाडेटा संग्रहीत करती है और शेष फ़ाइलें मॉडल के लिए क्रमबद्ध पैरामीटर रखने वाले बाइनरी शार्ड हैं।
इंडेक्स मेटाडेटा फ़ाइल में शार्ड्स में निहित सभी क्रमबद्ध टेंसरों के प्रकार, आकार, स्थान और स्ट्रिंग नाम शामिल हैं। वह इंडेक्स फ़ाइल चेकपॉइंट का सबसे संरचनात्मक रूप से जटिल हिस्सा है, और tensorflow::table
पर आधारित है, जो स्वयं एसएसटेबल/लेवलडीबी पर आधारित है। यह इंडेक्स फ़ाइल कुंजी-मूल्य जोड़े की एक श्रृंखला से बनी है, जहां कुंजी स्ट्रिंग हैं और मान प्रोटोकॉल बफ़र्स हैं। तारों को क्रमबद्ध और उपसर्ग-संपीड़ित किया जाता है। उदाहरण के लिए: यदि पहली प्रविष्टि conv1/weight
है और अगली conv1/bias
, तो दूसरी प्रविष्टि केवल bias
भाग का उपयोग करती है।
यह समग्र सूचकांक फ़ाइल कभी-कभी स्नैपी संपीड़न का उपयोग करके संपीड़ित होती है। SnappyDecompression.swift
फ़ाइल एक संपीड़ित डेटा इंस्टेंस से Snappy डीकंप्रेसन का मूल स्विफ्ट कार्यान्वयन प्रदान करती है।
इंडेक्स हेडर मेटाडेटा और टेंसर मेटाडेटा को प्रोटोकॉल बफ़र्स के रूप में एन्कोड किया गया है और सीधे स्विफ्ट प्रोटोबफ़ के माध्यम से एन्कोड / डिकोड किया गया है।
CheckpointIndexReader
और CheckpointIndexWriter
क्लास इन इंडेक्स फ़ाइलों को ओवररचिंग CheckpointReader
और CheckpointWriter
क्लास के हिस्से के रूप में लोड करने और सहेजने का काम संभालते हैं। उत्तरार्द्ध सूचकांक फ़ाइलों का उपयोग यह निर्धारित करने के लिए आधार के रूप में करते हैं कि संरचनात्मक रूप से सरल बाइनरी शार्ड में क्या पढ़ना और लिखना है जिसमें टेंसर डेटा होता है।
किसी मॉडल की स्थिति को सहेजने और पुनर्स्थापित करने की क्षमता कई अनुप्रयोगों के लिए महत्वपूर्ण है, जैसे ट्रांसफर लर्निंग में या पूर्व-प्रशिक्षित मॉडल का उपयोग करके अनुमान लगाने के लिए। किसी मॉडल के मापदंडों (वजन, पूर्वाग्रह, आदि) को चेकपॉइंट फ़ाइल या निर्देशिका में सहेजना इसे पूरा करने का एक तरीका है।
यह मॉड्यूल TensorFlow v2 प्रारूप चौकियों को लोड करने और सहेजने के लिए एक उच्च-स्तरीय इंटरफ़ेस प्रदान करता है, साथ ही निचले स्तर के घटक भी प्रदान करता है जो इस फ़ाइल प्रारूप में लिखते और पढ़ते हैं।
सरल मॉडल लोड करना और सहेजना
Checkpointable
प्रोटोकॉल के अनुरूप, कई सरल मॉडलों को बिना किसी अतिरिक्त कोड के चेकपॉइंट पर क्रमबद्ध किया जा सकता है:
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
और फिर उसी चेकपॉइंट का उपयोग करके पढ़ा जा सकता है:
try model.readCheckpoint(from: directory, name: "LeNet")
मॉडल लोडिंग और सेविंग के लिए यह डिफ़ॉल्ट कार्यान्वयन मॉडल में प्रत्येक टेंसर के लिए पथ-आधारित नामकरण योजना का उपयोग करेगा जो मॉडल संरचनाओं के भीतर गुणों के नाम पर आधारित है। उदाहरण के लिए, LeNet-5 मॉडल में पहले कनवल्शन के भीतर वजन और पूर्वाग्रह क्रमशः conv1/filter
और conv1/bias
नाम से सहेजे जाएंगे। लोड करते समय, चेकपॉइंट रीडर इन नामों वाले टेंसर की खोज करेगा।
मॉडल लोडिंग और सेविंग को अनुकूलित करना
यदि आप इस पर अधिक नियंत्रण रखना चाहते हैं कि कौन से टेंसर सहेजे और लोड किए गए हैं, या उन टेंसरों का नामकरण किया गया है, तो Checkpointable
प्रोटोकॉल अनुकूलन के कुछ बिंदु प्रदान करता है।
कुछ प्रकारों पर गुणों को अनदेखा करने के लिए, आप अपने मॉडल पर ignoredTensorPaths
का कार्यान्वयन प्रदान कर सकते हैं जो Type.property
के रूप में स्ट्रिंग्स का एक सेट लौटाता है। उदाहरण के लिए, प्रत्येक अटेंशन लेयर पर scale
प्रॉपर्टी को अनदेखा करने के लिए, आप ["Attention.scale"]
वापस कर सकते हैं।
डिफ़ॉल्ट रूप से, किसी मॉडल में प्रत्येक गहरे स्तर को अलग करने के लिए फ़ॉरवर्ड स्लैश का उपयोग किया जाता है। इसे आपके मॉडल पर checkpointSeparator
लागू करके और इस सेपरेटर के लिए उपयोग करने के लिए एक नई स्ट्रिंग प्रदान करके अनुकूलित किया जा सकता है।
अंत में, टेंसर नामकरण में अनुकूलन की सबसे बड़ी डिग्री के लिए, आप tensorNameMap
कार्यान्वित कर सकते हैं और एक फ़ंक्शन प्रदान कर सकते हैं जो मॉडल में टेंसर के लिए उत्पन्न डिफ़ॉल्ट स्ट्रिंग नाम से चेकपॉइंट में वांछित स्ट्रिंग नाम पर मैप करता है। आमतौर पर, इसका उपयोग अन्य रूपरेखाओं के साथ उत्पन्न चौकियों के साथ इंटरऑपरेट करने के लिए किया जाएगा, जिनमें से प्रत्येक की अपनी नामकरण परंपराएं और मॉडल संरचनाएं हैं। एक कस्टम मैपिंग फ़ंक्शन इन टेंसरों को कैसे नामित किया जाता है, इसके लिए अनुकूलन की सबसे बड़ी डिग्री देता है।
कुछ मानक सहायक फ़ंक्शन प्रदान किए जाते हैं, जैसे डिफ़ॉल्ट CheckpointWriter.identityMap
(जो चेकपॉइंट के लिए स्वचालित रूप से जेनरेट किए गए टेंसर पथ नाम का उपयोग करता है), या CheckpointWriter.lookupMap(table:)
फ़ंक्शन, जो एक शब्दकोश से मैपिंग बना सकता है।
कस्टम मैपिंग कैसे पूरी की जा सकती है, इसके उदाहरण के लिए, कृपया GPT-2 मॉडल देखें, जो OpenAI की चौकियों के लिए उपयोग की जाने वाली सटीक नामकरण योजना से मेल खाने के लिए मैपिंग फ़ंक्शन का उपयोग करता है।
चेकपॉइंटरीडर और चेकपॉइंटराइटर घटक
चेकपॉइंट लेखन के लिए, Checkpointable
प्रोटोकॉल द्वारा प्रदान किया गया एक्सटेंशन एक मॉडल के गुणों पर पुनरावृत्ति करने के लिए प्रतिबिंब और कीपाथ का उपयोग करता है और एक शब्दकोश उत्पन्न करता है जो स्ट्रिंग टेंसर पथ को टेंसर मानों पर मैप करता है। यह शब्दकोश एक अंतर्निहित CheckpointWriter
को एक निर्देशिका के साथ प्रदान किया जाता है जिसमें चेकपॉइंट लिखना होता है। वह CheckpointWriter
उस शब्दकोश से ऑन-डिस्क चेकपॉइंट उत्पन्न करने का कार्य संभालता है।
इस प्रक्रिया का उलटा भाग रीडिंग है, जहां CheckpointReader
ऑन-डिस्क चेकपॉइंट निर्देशिका का स्थान दिया जाता है। फिर यह उस चेकपॉइंट से पढ़ता है और एक शब्दकोश बनाता है जो चेकपॉइंट के भीतर टेंसरों के नामों को उनके सहेजे गए मानों के साथ मैप करता है। इस शब्दकोश का उपयोग किसी मॉडल में मौजूदा टेंसर को इस शब्दकोश के टेंसर से बदलने के लिए किया जाता है।
लोडिंग और सेविंग दोनों के लिए, Checkpointable
प्रोटोकॉल ऊपर वर्णित मैपिंग फ़ंक्शन का उपयोग करके टेंसर के स्ट्रिंग पथ को संबंधित ऑन-डिस्क टेंसर नामों पर मैप करता है।
यदि Checkpointable
प्रोटोकॉल में आवश्यक कार्यक्षमता का अभाव है, या चेकपॉइंट लोडिंग और सेविंग प्रक्रिया पर अधिक नियंत्रण वांछित है, तो CheckpointReader
और CheckpointWriter
कक्षाओं का उपयोग स्वयं किया जा सकता है।
TensorFlow v2 चेकपॉइंट प्रारूप
TensorFlow v2 चेकपॉइंट प्रारूप, जैसा कि इस हेडर में संक्षेप में वर्णित है, TensorFlow मॉडल चेकपॉइंट्स के लिए दूसरी पीढ़ी का प्रारूप है। यह दूसरी पीढ़ी का प्रारूप 2016 के अंत से उपयोग में है, और इसमें v1 चेकपॉइंट प्रारूप में कई सुधार हैं। मॉडल मापदंडों को सहेजने के लिए TensorFlow SavedModels अपने भीतर v2 चेकपॉइंट का उपयोग करते हैं।
TensorFlow v2 चेकपॉइंट में निम्न जैसी संरचना वाली एक निर्देशिका होती है:
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
जहां पहली फ़ाइल चेकपॉइंट के लिए मेटाडेटा संग्रहीत करती है और शेष फ़ाइलें मॉडल के लिए क्रमबद्ध पैरामीटर रखने वाले बाइनरी शार्ड हैं।
इंडेक्स मेटाडेटा फ़ाइल में शार्ड्स में निहित सभी क्रमबद्ध टेंसरों के प्रकार, आकार, स्थान और स्ट्रिंग नाम शामिल हैं। वह इंडेक्स फ़ाइल चेकपॉइंट का सबसे संरचनात्मक रूप से जटिल हिस्सा है, और tensorflow::table
पर आधारित है, जो स्वयं एसएसटेबल/लेवलडीबी पर आधारित है। यह इंडेक्स फ़ाइल कुंजी-मूल्य जोड़े की एक श्रृंखला से बनी है, जहां कुंजी स्ट्रिंग हैं और मान प्रोटोकॉल बफ़र्स हैं। तारों को क्रमबद्ध और उपसर्ग-संपीड़ित किया जाता है। उदाहरण के लिए: यदि पहली प्रविष्टि conv1/weight
है और अगली conv1/bias
, तो दूसरी प्रविष्टि केवल bias
भाग का उपयोग करती है।
यह समग्र सूचकांक फ़ाइल कभी-कभी स्नैपी संपीड़न का उपयोग करके संपीड़ित होती है। SnappyDecompression.swift
फ़ाइल एक संपीड़ित डेटा इंस्टेंस से Snappy डीकंप्रेसन का मूल स्विफ्ट कार्यान्वयन प्रदान करती है।
इंडेक्स हेडर मेटाडेटा और टेंसर मेटाडेटा को प्रोटोकॉल बफ़र्स के रूप में एन्कोड किया गया है और सीधे स्विफ्ट प्रोटोबफ़ के माध्यम से एन्कोड / डिकोड किया गया है।
CheckpointIndexReader
और CheckpointIndexWriter
क्लास इन इंडेक्स फ़ाइलों को ओवररचिंग CheckpointReader
और CheckpointWriter
क्लास के हिस्से के रूप में लोड करने और सहेजने का काम संभालते हैं। उत्तरार्द्ध सूचकांक फ़ाइलों का उपयोग यह निर्धारित करने के लिए आधार के रूप में करते हैं कि संरचनात्मक रूप से सरल बाइनरी शार्ड में क्या पढ़ना और लिखना है जिसमें टेंसर डेटा होता है।