Wytrenuj model przy użyciu pracownika sieci,Wytrenuj model przy użyciu pracownika sieci

W tym samouczku poznasz przykładową aplikację internetową, która wykorzystuje proces roboczy sieci Web do uczenia rekurencyjnej sieci neuronowej (RNN) w zakresie dodawania liczb całkowitych. Przykładowa aplikacja nie definiuje jawnie operatora dodawania. Zamiast tego trenuje RNN przy użyciu przykładowych sum.

Oczywiście nie jest to najskuteczniejszy sposób dodania dwóch liczb całkowitych! Jednak samouczek demonstruje ważną technikę w web ML: jak wykonywać długotrwałe obliczenia bez blokowania głównego wątku, który obsługuje logikę interfejsu użytkownika.

Przykładowa aplikacja do tego samouczka jest dostępna online , więc nie musisz pobierać żadnego kodu ani konfigurować środowiska programistycznego. Jeśli chcesz uruchomić kod lokalnie, wykonaj opcjonalne kroki w temacie Lokalne uruchamianie przykładu . Jeśli nie chcesz konfigurować środowiska programistycznego, możesz przejść do sekcji Poznaj przykład .

Przykładowy kod jest dostępny na GitHubie .

(Opcjonalnie) Uruchom przykład lokalnie

Warunki wstępne

Aby uruchomić przykładową aplikację lokalnie, w środowisku programistycznym muszą być zainstalowane następujące elementy:

Zainstaluj i uruchom przykładową aplikację

  1. Sklonuj lub pobierz repozytorium tfjs-examples .
  2. Przejdź do katalogu addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Zainstaluj zależności:

    yarn
    
  4. Uruchom serwer deweloperski:

    yarn run watch
    

Przeanalizuj przykład

Otwórz przykładową aplikację . (Lub, jeśli uruchamiasz przykład lokalnie, przejdź do http://localhost:1234 w swojej przeglądarce.)

Powinieneś zobaczyć stronę zatytułowaną TensorFlow.js: Addition RNN . Postępuj zgodnie z instrukcjami, aby wypróbować aplikację.

Korzystając z formularza internetowego, możesz zaktualizować niektóre parametry używane do uczenia modelu, w tym następujące:

  • Cyfry : maksymalna liczba cyfr w terminach, które mają zostać dodane.
  • Rozmiar szkolenia : liczba przykładów szkoleniowych do wygenerowania.
  • Typ RNN : Jeden z SimpleRNN , GRU lub LSTM .
  • RNN Rozmiar warstwy ukrytej : Wymiar przestrzeni wyjściowej (musi być dodatnią liczbą całkowitą).
  • Rozmiar partii : Liczba próbek na aktualizację gradientu.
  • Iteracje pociągu : liczba powtórzeń uczenia modelu poprzez wywołanie model.fit()
  • Liczba przykładów testowych : liczba przykładowych ciągów znaków (na przykład 27+41 ) do wygenerowania.

Spróbuj wytrenować model z różnymi parametrami i sprawdź, czy możesz poprawić dokładność przewidywań dla różnych zestawów cyfr. Zwróć także uwagę, jak różne parametry wpływają na czas dopasowania modelu.

Poznaj kod

Przykładowa aplikacja demonstruje niektóre parametry, które można skonfigurować na potrzeby szkolenia RNN. Pokazuje także użycie procesu roboczego sieciowego do uczenia modelu poza głównym wątkiem. Procesy robocze sieci Web są ważne w procesie uczenia maszynowego w sieci Web, ponieważ umożliwiają uruchamianie kosztownych obliczeniowo zadań szkoleniowych w wątku w tle, unikając w ten sposób problemów z wydajnością, które mogą mieć wpływ na użytkownika w głównym wątku. Wątki główny i roboczy komunikują się ze sobą poprzez zdarzenia komunikatów.

Aby dowiedzieć się więcej na temat procesów roboczych sieci Web, zobacz Interfejs API procesów roboczych sieci Web i Korzystanie z procesów roboczych sieci Web .

Głównym modułem przykładowej aplikacji jest index.js . Skrypt index.js tworzy proces roboczy sieci WWW , który uruchamia moduł worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js składa się głównie z jednej funkcji, runAdditionRNNDemo , która obsługuje przesyłanie formularza, przetwarza dane z formularza, przekazuje dane z formularza do pracownika, czeka, aż pracownik wytrenuje model i zwróci wyniki, a następnie wyświetla wyniki na stronie .

Aby wysłać dane formularza do pracownika, skrypt wywołuje na pracowniku postMessage :

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

Pracownik nasłuchuje tego komunikatu i przekazuje dane z formularza do funkcji, które przygotowują dane i rozpoczynają szkolenie:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Podczas szkolenia pracownik może wysłać dwa różne typy komunikatów, jeden z isPredict ustawioną na true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

a drugi z isPredict ustawionym na false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Gdy wątek interfejsu użytkownika ( index.js ) obsługuje zdarzenia komunikatów, sprawdza flagę isPredict , aby określić kształt danych zwracanych przez proces roboczy. Jeśli isPredict ma wartość true, dane powinny reprezentować prognozę, a skrypt aktualizuje stronę za pomocą tfjs-vis . Jeśli isPredict ma wartość false, skrypt uruchamia blok kodu , który zakłada, że ​​dane reprezentują przykłady. Zawija dane w formacie HTML i wstawia kod HTML na stronę.

Co dalej

W tym samouczku przedstawiono przykład użycia procesu roboczego sieci Web, aby uniknąć blokowania wątku interfejsu użytkownika podczas długotrwałego procesu szkoleniowego. Aby dowiedzieć się więcej na temat korzyści wynikających z wykonywania kosztownych obliczeń w wątku w tle, zobacz Używanie procesów roboczych sieci Web do uruchamiania JavaScript z głównego wątku przeglądarki .

Aby dowiedzieć się więcej na temat uczenia modelu TensorFlow.js, zobacz Modele szkoleniowe .