Treine um modelo usando um web worker

Neste tutorial, você explorará um exemplo de aplicativo Web que usa um web trabalhador para treinar uma Rede Neural Recorrente (RNN) para fazer adição de números inteiros. O aplicativo de exemplo não define explicitamente o operador de adição. Em vez disso, ele treina o RNN usando somas de exemplo.

Claro, esta não é a maneira mais eficiente de somar dois números inteiros! Mas o tutorial demonstra uma técnica importante em ML da web: como realizar cálculos de longa duração sem bloquear o thread principal, que lida com a lógica da UI.

O aplicativo de exemplo deste tutorial está disponível online , portanto você não precisa baixar nenhum código ou configurar um ambiente de desenvolvimento. Se quiser executar o código localmente, conclua as etapas opcionais em Executar o exemplo localmente . Se não quiser configurar um ambiente de desenvolvimento, você pode pular para Explorar o exemplo .

O código de exemplo está disponível no GitHub .

(Opcional) Execute o exemplo localmente

Pré-requisitos

Para executar o aplicativo de exemplo localmente, você precisa do seguinte instalado em seu ambiente de desenvolvimento:

Instale e execute o aplicativo de exemplo

  1. Clone ou baixe o repositório tfjs-examples .
  2. Mude para o diretório addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Instale dependências:

    yarn
    
  4. Inicie o servidor de desenvolvimento:

    yarn run watch
    

Explore o exemplo

Abra o aplicativo de exemplo . (Ou, se você estiver executando o exemplo localmente, acesse http://localhost:1234 no seu navegador.)

Você deverá ver uma página intitulada TensorFlow.js: Addition RNN . Siga as instruções para experimentar o aplicativo.

Usando o formulário web, você pode atualizar alguns dos parâmetros usados ​​para treinar o modelo, incluindo o seguinte:

  • Dígitos : O número máximo de dígitos nos termos a serem adicionados.
  • Tamanho do treinamento : o número de exemplos de treinamento a serem gerados.
  • Tipo RNN : SimpleRNN , GRU ou LSTM .
  • Tamanho da camada oculta RNN : Dimensionalidade do espaço de saída (deve ser um número inteiro positivo).
  • Tamanho do lote : número de amostras por atualização de gradiente.
  • Treinar iterações : número de vezes para treinar o modelo invocando model.fit()
  • Nº de exemplos de teste : número de strings de exemplo (por exemplo, 27+41 ) a serem geradas.

Tente treinar o modelo com parâmetros diferentes e veja se você consegue melhorar a precisão das previsões para vários conjuntos de dígitos. Observe também como o tempo de ajuste do modelo é afetado por diferentes parâmetros.

Explorar o código

O aplicativo de exemplo demonstra alguns dos parâmetros que você pode configurar para treinar um RNN. Também demonstra o uso de um web trabalhador para treinar um modelo fora do thread principal. Os web workers são importantes no web ML porque permitem executar tarefas de treinamento computacionalmente caras em um thread em segundo plano, evitando assim problemas de desempenho que podem afetar o usuário no thread principal. Os threads principal e de trabalho se comunicam entre si por meio de eventos de mensagem.

Para saber mais sobre web workers, consulte API de Web Workers e Uso de Web Workers .

O módulo principal do aplicativo de exemplo é index.js . O script index.js cria um web trabalhador que executa o módulo worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js é amplamente composto por uma única função, runAdditionRNNDemo , que lida com o envio do formulário, processa os dados do formulário, passa os dados do formulário para o trabalhador, espera que o trabalhador treine o modelo e retorne os resultados e, em seguida, exiba os resultados na página .

Para enviar os dados do formulário ao trabalhador, o script invoca postMessage no trabalhador:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

O trabalhador escuta esta mensagem e passa os dados do formulário para funções que preparam os dados e iniciam o treinamento:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Durante o treinamento, o trabalhador pode enviar dois tipos de mensagens diferentes, um com isPredict definido como true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

e o outro com isPredict definido como false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Quando o thread de UI ( index.js ) manipula eventos de mensagem, ele verifica o sinalizador isPredict para determinar a forma dos dados retornados do trabalhador. Se isPredict for verdadeiro, os dados deverão representar uma previsão e o script atualizará a página usando tfjs-vis . Se isPredict for falso, o script executa um bloco de código que assume que os dados representam exemplos. Ele agrupa os dados em HTML e insere o HTML na página.

Qual é o próximo

Este tutorial forneceu um exemplo de uso de um web trabalhador para evitar o bloqueio do thread de UI com um processo de treinamento de longa duração. Para saber mais sobre os benefícios de fazer cálculos caros em um thread em segundo plano, consulte Usar web workers para executar JavaScript no thread principal do navegador .

Para saber mais sobre como treinar um modelo do TensorFlow.js, consulte Treinamento de modelos .