एक पूर्व प्रशिक्षित मॉडल का प्रयोग करें

इस ट्यूटोरियल में आप एक उदाहरण वेब एप्लिकेशन का पता लगाएंगे जो TensorFlow.js लेयर्स एपीआई का उपयोग करके ट्रांसफर लर्निंग को प्रदर्शित करता है। उदाहरण एक पूर्व-प्रशिक्षित मॉडल को लोड करता है और फिर ब्राउज़र में मॉडल को पुनः प्रशिक्षित करता है।

मॉडल को एमएनआईएसटी अंक वर्गीकरण डेटासेट के अंक 0-4 पर पायथन में पूर्व-प्रशिक्षित किया गया है। ब्राउज़र में पुनर्प्रशिक्षण (या स्थानांतरण शिक्षण) अंक 5-9 का उपयोग करता है। उदाहरण से पता चलता है कि पूर्व-प्रशिक्षित मॉडल की पहली कई परतों का उपयोग स्थानांतरण सीखने के दौरान नए डेटा से सुविधाओं को निकालने के लिए किया जा सकता है, इस प्रकार नए डेटा पर तेजी से प्रशिक्षण सक्षम हो सकता है।

इस ट्यूटोरियल के लिए उदाहरण एप्लिकेशन ऑनलाइन उपलब्ध है, इसलिए आपको कोई कोड डाउनलोड करने या विकास वातावरण स्थापित करने की आवश्यकता नहीं है। यदि आप कोड को स्थानीय रूप से चलाना चाहते हैं, तो स्थानीय रूप से उदाहरण चलाएँ में वैकल्पिक चरणों को पूरा करें। यदि आप एक विकास वातावरण स्थापित नहीं करना चाहते हैं, तो आप उदाहरण का अन्वेषण करना छोड़ सकते हैं।

उदाहरण कोड GitHub पर उपलब्ध है।

(वैकल्पिक) उदाहरण को स्थानीय रूप से चलाएँ

आवश्यक शर्तें

उदाहरण ऐप को स्थानीय रूप से चलाने के लिए, आपको अपने विकास परिवेश में निम्नलिखित इंस्टॉल करना होगा:

उदाहरण ऐप इंस्टॉल करें और चलाएं

  1. tfjs-examples भंडार को क्लोन करें या डाउनलोड करें।
  2. mnist-transfer-cnn निर्देशिका में बदलें:

    cd tfjs-examples/mnist-transfer-cnn
    
  3. निर्भरताएँ स्थापित करें:

    yarn
    
  4. विकास सर्वर प्रारंभ करें:

    yarn run watch
    

उदाहरण का अन्वेषण करें

उदाहरण ऐप खोलें . (या, यदि आप स्थानीय रूप से उदाहरण चला रहे हैं, तो अपने ब्राउज़र में http://localhost:1234 पर जाएं।)

आपको एमएनआईएसटी सीएनएन ट्रांसफर लर्निंग शीर्षक वाला एक पेज देखना चाहिए। ऐप को आज़माने के लिए निर्देशों का पालन करें।

यहां आज़माने लायक कुछ चीज़ें दी गई हैं:

  • विभिन्न प्रशिक्षण मोड के साथ प्रयोग करें और हानि और सटीकता की तुलना करें।
  • विभिन्न बिटमैप उदाहरण चुनें और वर्गीकरण संभावनाओं का निरीक्षण करें। ध्यान दें कि प्रत्येक बिटमैप उदाहरण में संख्याएँ एक छवि से पिक्सेल का प्रतिनिधित्व करने वाले ग्रेस्केल पूर्णांक मान हैं।
  • बिटमैप पूर्णांक मान संपादित करें और देखें कि परिवर्तन वर्गीकरण संभावनाओं को कैसे प्रभावित करते हैं।

कोड का अन्वेषण करें

उदाहरण वेब ऐप एक मॉडल को लोड करता है जिसे एमएनआईएसटी डेटासेट के सबसेट पर पूर्व-प्रशिक्षित किया गया है। पूर्व-प्रशिक्षण को पायथन प्रोग्राम में परिभाषित किया गया है: mnist_transfer_cnn.py । पायथन प्रोग्राम इस ट्यूटोरियल के दायरे से बाहर है, लेकिन यदि आप मॉडल रूपांतरण का एक उदाहरण देखना चाहते हैं तो यह देखने लायक है।

index.js फ़ाइल में डेमो के लिए अधिकांश प्रशिक्षण कोड शामिल हैं। जब index.js ब्राउज़र में चलता है, तो एक सेटअप फ़ंक्शन, setupMnistTransferCNN , MnistTransferCNNPredictor इंस्टेंट और इनिशियलाइज़ करता है, जो रीट्रेनिंग और भविष्यवाणी रूटीन को इनकैप्सुलेट करता है।

आरंभीकरण विधि, MnistTransferCNNPredictor.init , एक मॉडल लोड करती है, पुनः प्रशिक्षण डेटा लोड करती है, और परीक्षण डेटा बनाती है। यहां वह पंक्ति है जो मॉडल को लोड करती है:

this.model = await loader.loadHostedPretrainedModel(urls.model);

यदि आप loader.loadHostedPretrainedModel की परिभाषा को देखें, तो आप देखेंगे कि यह कॉल का परिणाम tf.loadLayersModel पर लौटाता है। लेयर ऑब्जेक्ट से बने मॉडल को लोड करने के लिए यह TensorFlow.js API है।

पुनर्प्रशिक्षण तर्क को MnistTransferCNNPredictor.retrainModel में परिभाषित किया गया है। यदि उपयोगकर्ता ने फ़्रीज़ फ़ीचर लेयर्स को प्रशिक्षण मोड के रूप में चुना है, तो बेस मॉडल की पहली 7 परतें फ़्रीज़ हो जाती हैं, और केवल अंतिम 5 परतें नए डेटा पर प्रशिक्षित होती हैं। यदि उपयोगकर्ता ने रीइनिशियलाइज़ वेट का चयन किया है, तो सभी वज़न रीसेट हो जाते हैं, और ऐप प्रभावी ढंग से मॉडल को स्क्रैच से प्रशिक्षित करता है।

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

फिर मॉडल संकलित किया जाता है, और फिर इसे model.fit() उपयोग करके परीक्षण डेटा पर प्रशिक्षित किया जाता है :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

model.fit() पैरामीटर के बारे में अधिक जानने के लिए, API दस्तावेज़ देखें।

नए डेटासेट (अंक 5-9) पर प्रशिक्षित होने के बाद, मॉडल का उपयोग भविष्यवाणियां करने के लिए किया जा सकता है। MnistTransferCNNPredictor.predict विधि model.predict() उपयोग करके ऐसा करती है:

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

tf.tidy के उपयोग पर ध्यान दें, जो मेमोरी लीक को रोकने में मदद करता है।

और अधिक जानें

इस ट्यूटोरियल में एक उदाहरण ऐप का पता लगाया गया है जो TensorFlow.js का उपयोग करके ब्राउज़र में ट्रांसफर लर्निंग करता है। पूर्व-प्रशिक्षित मॉडल और स्थानांतरण शिक्षण के बारे में अधिक जानने के लिए नीचे दिए गए संसाधनों को देखें।

TensorFlow.js

टेन्सरफ्लो कोर