Latih model menggunakan pekerja web

Dalam tutorial ini, Anda akan menjelajahi contoh aplikasi web yang menggunakan pekerja web untuk melatih Jaringan Syaraf Berulang (RNN) untuk melakukan penjumlahan bilangan bulat. Aplikasi contoh tidak secara eksplisit mendefinisikan operator penambahan. Sebaliknya, ia melatih RNN menggunakan jumlah contoh.

Tentu saja, ini bukan cara paling efisien untuk menjumlahkan dua bilangan bulat! Namun tutorial ini menunjukkan teknik penting dalam web ML: cara melakukan komputasi yang berjalan lama tanpa memblokir thread utama, yang menangani logika UI.

Contoh aplikasi untuk tutorial ini tersedia online , jadi Anda tidak perlu mengunduh kode apa pun atau menyiapkan lingkungan pengembangan. Jika Anda ingin menjalankan kode secara lokal, selesaikan langkah opsional di Jalankan contoh secara lokal . Jika Anda tidak ingin menyiapkan lingkungan pengembangan, Anda dapat melompat ke Jelajahi contohnya .

Kode contoh tersedia di GitHub .

(Opsional) Jalankan contoh secara lokal

Prasyarat

Untuk menjalankan aplikasi contoh secara lokal, Anda perlu menginstal yang berikut ini di lingkungan pengembangan Anda:

Instal dan jalankan aplikasi contoh

  1. Kloning atau unduh repositori tfjs-examples .
  2. Ubah ke direktori addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Instal dependensi:

    yarn
    
  4. Mulai server pengembangan:

    yarn run watch
    

Jelajahi contohnya

Buka aplikasi contoh . (Atau, jika Anda menjalankan contoh secara lokal, kunjungi http://localhost:1234 di browser Anda.)

Anda akan melihat halaman berjudul TensorFlow.js: Addition RNN . Ikuti petunjuk untuk mencoba aplikasinya.

Dengan menggunakan formulir web, Anda bisa memperbarui beberapa parameter yang digunakan untuk melatih model, termasuk yang berikut:

  • Digit : Jumlah maksimum digit dalam istilah yang akan ditambahkan.
  • Ukuran Pelatihan : Jumlah contoh pelatihan yang akan dihasilkan.
  • Jenis RNN : Salah satu dari SimpleRNN , GRU , atau LSTM .
  • RNN Hidden Layer Size : Dimensi ruang keluaran (harus bilangan bulat positif).
  • Ukuran Batch : Jumlah sampel per pembaruan gradien.
  • Latih Iterasi : Berapa kali melatih model dengan memanggil model.fit()
  • # contoh pengujian : Jumlah contoh string (misalnya, 27+41 ) yang akan dihasilkan.

Coba latih model dengan parameter berbeda, dan lihat apakah Anda dapat meningkatkan akurasi prediksi untuk berbagai kumpulan digit. Perhatikan juga bagaimana waktu kesesuaian model dipengaruhi oleh parameter yang berbeda.

Jelajahi kodenya

Contoh aplikasi menunjukkan beberapa parameter yang dapat Anda konfigurasi untuk melatih RNN. Hal ini juga menunjukkan penggunaan pekerja web untuk melatih model dari thread utama. Pekerja web penting dalam ML web karena memungkinkan Anda menjalankan tugas pelatihan yang mahal secara komputasi di thread latar belakang, sehingga menghindari potensi masalah performa yang berdampak pada pengguna di thread utama. Thread utama dan thread pekerja berkomunikasi satu sama lain melalui peristiwa pesan.

Untuk mempelajari selengkapnya tentang pekerja web, lihat Web Workers API dan Menggunakan Web Workers .

Modul utama untuk aplikasi contoh adalah index.js . Skrip index.js membuat pekerja web yang menjalankan modul worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js sebagian besar terdiri dari satu fungsi, runAdditionRNNDemo , yang menangani pengiriman formulir, memproses data formulir, meneruskan data formulir ke pekerja, menunggu pekerja melatih model dan mengembalikan hasilnya, lalu menampilkan hasilnya di halaman .

Untuk mengirim data formulir ke pekerja, skrip memanggil postMessage pada pekerja:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

Pekerja mendengarkan pesan ini dan meneruskan data formulir ke fungsi yang menyiapkan data dan memulai pelatihan:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Selama pelatihan, pekerja dapat mengirim dua jenis pesan berbeda, satu dengan isPredict yang disetel ke true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

dan yang lainnya dengan isPredict disetel ke false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Saat thread UI ( index.js ) menangani kejadian pesan, thread ini akan memeriksa flag isPredict untuk menentukan bentuk data yang dikembalikan dari pekerja. Jika isPredict benar, data harus mewakili prediksi, dan skrip memperbarui halaman menggunakan tfjs-vis . Jika isPredict salah, skrip menjalankan blok kode yang mengasumsikan bahwa data mewakili contoh. Itu membungkus data dalam HTML dan memasukkan HTML ke dalam halaman.

Apa berikutnya

Tutorial ini telah memberikan contoh penggunaan pekerja web untuk menghindari pemblokiran thread UI dengan proses pelatihan yang berjalan lama. Untuk mempelajari lebih lanjut tentang manfaat melakukan komputasi yang mahal pada thread latar belakang, lihat Menggunakan pekerja web untuk menjalankan JavaScript dari thread utama browser .

Untuk mempelajari lebih lanjut cara melatih model TensorFlow.js, lihat Model pelatihan .