Önceden eğitilmiş bir model kullanın

Bu eğitimde TensorFlow.js Katmanlar API'sini kullanarak aktarım öğrenimini gösteren örnek bir web uygulamasını keşfedeceksiniz. Örnek, önceden eğitilmiş bir modeli yükler ve ardından modeli tarayıcıda yeniden eğitir.

Model , MNIST basamak sınıflandırma veri kümesinin 0-4 basamakları üzerinde Python'da önceden eğitilmiştir. Tarayıcıdaki yeniden eğitim (veya transfer öğrenimi) 5-9 rakamlarını kullanır. Örnek, önceden eğitilmiş bir modelin ilk birkaç katmanının, transfer öğrenimi sırasında yeni verilerden özellikler çıkarmak için kullanılabileceğini ve böylece yeni veriler üzerinde daha hızlı eğitim sağlanabileceğini göstermektedir.

Bu eğitimin örnek uygulaması çevrimiçi olarak mevcuttur , dolayısıyla herhangi bir kod indirmenize veya bir geliştirme ortamı kurmanıza gerek yoktur. Kodu yerel olarak çalıştırmak istiyorsanız Örneği yerel olarak çalıştırma bölümündeki isteğe bağlı adımları tamamlayın. Bir geliştirme ortamı ayarlamak istemiyorsanız Örneği keşfetme bölümüne geçebilirsiniz.

Örnek kod GitHub'da mevcuttur.

(İsteğe bağlı) Örneği yerel olarak çalıştırın

Önkoşullar

Örnek uygulamayı yerel olarak çalıştırmak için geliştirme ortamınızda aşağıdakilerin yüklü olması gerekir:

Örnek uygulamayı yükleyin ve çalıştırın

  1. tfjs-examples deposunu klonlayın veya indirin.
  2. mnist-transfer-cnn dizinine geçin:

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Bağımlılıkları yükleyin:

    yarn
    
  4. Geliştirme sunucusunu başlatın:

    yarn run watch
    

Örneği keşfedin

Örnek uygulamayı açın . (Ya da örneği yerel olarak çalıştırıyorsanız tarayıcınızda http://localhost:1234 adresine gidin.)

MNIST CNN Transfer Learning başlıklı bir sayfa görmelisiniz. Uygulamayı denemek için talimatları izleyin.

İşte deneyebileceğiniz birkaç şey:

  • Farklı eğitim modlarıyla denemeler yapın ve kayıp ile doğruluğu karşılaştırın.
  • Farklı bitmap örnekleri seçin ve sınıflandırma olasılıklarını inceleyin. Her bitmap örneğindeki sayıların, bir görüntüdeki pikselleri temsil eden gri tonlamalı tam sayı değerleri olduğunu unutmayın.
  • Bitmap tamsayı değerlerini düzenleyin ve değişikliklerin sınıflandırma olasılıklarını nasıl etkilediğini görün.

Kodu keşfedin

Örnek web uygulaması, MNIST veri kümesinin bir alt kümesinde önceden eğitilmiş bir modeli yükler. Ön eğitim bir Python programında tanımlanır: mnist_transfer_cnn.py . Python programı bu eğitimin kapsamı dışındadır, ancak bir model dönüştürme örneği görmek istiyorsanız, buna bakmaya değer.

index.js dosyası demo için eğitim kodunun çoğunu içerir. index.js tarayıcıda çalıştığında, setupMnistTransferCNN bir kurulum işlevi, yeniden eğitim ve tahmin rutinlerini kapsayan MnistTransferCNNPredictor örneğini oluşturur ve başlatır.

Başlatma yöntemi MnistTransferCNNPredictor.init bir model yükler, yeniden eğitim verilerini yükler ve test verileri oluşturur. İşte modeli yükleyen satır :

this.model = await loader.loadHostedPretrainedModel(urls.model);

loader.loadHostedPretrainedModel tanımına bakarsanız, tf.loadLayersModel çağrısının sonucunu döndürdüğünü görürsünüz. Bu, Katman nesnelerinden oluşan bir modeli yüklemek için kullanılan TensorFlow.js API'sidir.

Yeniden eğitim mantığı MnistTransferCNNPredictor.retrainModel tanımlanmıştır. Kullanıcı eğitim modu olarak Özellik katmanlarını dondur'u seçtiyse temel modelin ilk 7 katmanı dondurulur ve yalnızca son 5 katman yeni veriler üzerinde eğitilir. Kullanıcı Ağırlıkları yeniden başlat seçeneğini seçtiyse tüm ağırlıklar sıfırlanır ve uygulama, modeli etkili bir şekilde sıfırdan eğitir.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Model daha sonra derlenir ve model.fit() kullanılarak test verileri üzerinde eğitilir :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

model.fit() parametreleri hakkında daha fazla bilgi edinmek için API belgelerine bakın.

Yeni veri seti (5-9 arasındaki rakamlar) üzerinde eğitildikten sonra model, tahminlerde bulunmak için kullanılabilir. MnistTransferCNNPredictor.predict yöntemi bunu model.predict() kullanarak yapar:

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Bellek sızıntılarını önlemeye yardımcı olan tf.tidy kullanımına dikkat edin.

Daha fazla bilgi edin

Bu eğitimde TensorFlow.js kullanarak tarayıcıda aktarım öğrenimi gerçekleştiren örnek bir uygulamayı inceledik. Önceden eğitilmiş modeller ve transfer öğrenimi hakkında daha fazla bilgi edinmek için aşağıdaki kaynaklara göz atın.

TensorFlow.js

TensorFlow Çekirdeği