Gunakan model terlatih

Dalam tutorial ini Anda akan menjelajahi contoh aplikasi web yang mendemonstrasikan pembelajaran transfer menggunakan TensorFlow.js Layers API. Contoh ini memuat model yang telah dilatih sebelumnya dan kemudian melatih ulang model tersebut di browser.

Model ini telah dilatih sebelumnya dengan Python pada digit 0-4 dari kumpulan data klasifikasi digit MNIST . Pelatihan ulang (atau transfer pembelajaran) di browser menggunakan angka 5-9. Contoh tersebut menunjukkan bahwa beberapa lapisan pertama model yang telah dilatih sebelumnya dapat digunakan untuk mengekstrak fitur dari data baru selama pembelajaran transfer, sehingga memungkinkan pelatihan yang lebih cepat pada data baru.

Contoh aplikasi untuk tutorial ini tersedia online , jadi Anda tidak perlu mengunduh kode apa pun atau menyiapkan lingkungan pengembangan. Jika Anda ingin menjalankan kode secara lokal, selesaikan langkah opsional di Jalankan contoh secara lokal . Jika Anda tidak ingin menyiapkan lingkungan pengembangan, Anda dapat melompat ke Jelajahi contohnya .

Kode contoh tersedia di GitHub .

(Opsional) Jalankan contoh secara lokal

Prasyarat

Untuk menjalankan aplikasi contoh secara lokal, Anda perlu menginstal yang berikut ini di lingkungan pengembangan Anda:

Instal dan jalankan aplikasi contoh

  1. Kloning atau unduh repositori tfjs-examples .
  2. Ubah ke direktori mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Instal dependensi:

    yarn
    
  4. Mulai server pengembangan:

    yarn run watch
    

Jelajahi contohnya

Buka aplikasi contoh . (Atau, jika Anda menjalankan contoh secara lokal, kunjungi http://localhost:1234 di browser Anda.)

Anda akan melihat halaman berjudul MNIST CNN Transfer Learning . Ikuti petunjuk untuk mencoba aplikasinya.

Berikut beberapa hal yang bisa dicoba:

  • Bereksperimenlah dengan berbagai mode pelatihan dan bandingkan kerugian dan akurasi.
  • Pilih contoh bitmap yang berbeda dan periksa probabilitas klasifikasi. Perhatikan bahwa angka-angka dalam setiap contoh bitmap adalah nilai integer skala abu-abu yang mewakili piksel dari suatu gambar.
  • Edit nilai integer bitmap dan lihat bagaimana perubahan mempengaruhi probabilitas klasifikasi.

Jelajahi kodenya

Contoh aplikasi web memuat model yang telah dilatih sebelumnya pada subkumpulan kumpulan data MNIST. Pra-pelatihan didefinisikan dalam program Python: mnist_transfer_cnn.py . Program Python berada di luar cakupan tutorial ini, tetapi ada baiknya melihat jika Anda ingin melihat contoh konversi model .

File index.js berisi sebagian besar kode pelatihan untuk demo. Saat index.js berjalan di browser, fungsi pengaturan, setupMnistTransferCNN , membuat instance dan menginisialisasi MnistTransferCNNPredictor , yang merangkum rutinitas pelatihan ulang dan prediksi.

Metode inisialisasi, MnistTransferCNNPredictor.init , memuat model, memuat data pelatihan ulang, dan membuat data pengujian. Inilah baris yang memuat model:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Jika Anda melihat definisi loader.loadHostedPretrainedModel , Anda akan melihat bahwa ia mengembalikan hasil panggilan ke tf.loadLayersModel . Ini adalah TensorFlow.js API untuk memuat model yang terdiri dari objek Lapisan.

Logika pelatihan ulang didefinisikan dalam MnistTransferCNNPredictor.retrainModel . Jika pengguna telah memilih Bekukan lapisan fitur sebagai mode pelatihan, 7 lapisan pertama model dasar akan dibekukan, dan hanya 5 lapisan terakhir yang dilatih pada data baru. Jika pengguna telah memilih Inisialisasi ulang bobot , semua bobot akan disetel ulang, dan aplikasi secara efektif melatih model dari awal.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Model tersebut kemudian dikompilasi , dan kemudian dilatih pada data pengujian menggunakan model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Untuk mempelajari lebih lanjut tentang parameter model.fit() , lihat dokumentasi API .

Setelah dilatih pada dataset baru (digit 5-9), model dapat digunakan untuk membuat prediksi. Metode MnistTransferCNNPredictor.predict melakukan ini menggunakan model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Perhatikan penggunaan tf.tidy , yang membantu mencegah kebocoran memori.

Belajarlah lagi

Tutorial ini telah mengeksplorasi contoh aplikasi yang melakukan pembelajaran transfer di browser menggunakan TensorFlow.js. Lihat sumber daya di bawah untuk mempelajari lebih lanjut tentang model terlatih dan pembelajaran transfer.

TensorFlow.js

Inti TensorFlow