Treine um modelo usando um web worker

Neste tutorial, você explorará um aplicativo da Web de exemplo que usa um trabalhador da Web para treinar uma Rede Neural Recorrente (RNN) para fazer a adição de números inteiros. O aplicativo de exemplo não define explicitamente o operador de adição. Em vez disso, ele treina o RNN usando somas de exemplo.

Claro, esta não é a maneira mais eficiente de somar dois números inteiros! Mas o tutorial demonstra uma técnica importante no ML da Web: como executar cálculos de longa duração sem bloquear o thread principal, que lida com a lógica da interface do usuário.

O aplicativo de exemplo para este tutorial está disponível online , então você não precisa baixar nenhum código ou configurar um ambiente de desenvolvimento. Se quiser executar o código localmente, conclua as etapas opcionais em Executar o exemplo localmente . Se você não deseja configurar um ambiente de desenvolvimento, pode pular para Explorar o exemplo .

O código de exemplo está disponível no GitHub .

(Opcional) Execute o exemplo localmente

Pré-requisitos

Para executar o aplicativo de exemplo localmente, você precisa do seguinte instalado em seu ambiente de desenvolvimento:

Instale e execute o aplicativo de exemplo

  1. Clone ou baixe o repositório tfjs-examples .
  2. Mude para o diretório addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Instalar dependências:

    yarn
    
  4. Inicie o servidor de desenvolvimento:

    yarn run watch
    

Explorar o exemplo

Abra o aplicativo de exemplo . (Ou, se estiver executando o exemplo localmente, acesse http://localhost:1234 em seu navegador.)

Você deve ver uma página intitulada TensorFlow.js: Addition RNN . Siga as instruções para experimentar o aplicativo.

Usando o formulário da web, você pode atualizar alguns dos parâmetros usados ​​para treinar o modelo, incluindo o seguinte:

  • Dígitos : O número máximo de dígitos nos termos a serem adicionados.
  • Tamanho do treinamento : o número de exemplos de treinamento a serem gerados.
  • Tipo de RNN : Um entre SimpleRNN , GRU ou LSTM .
  • RNN Hidden Layer Size : Dimensionalidade do espaço de saída (deve ser um número inteiro positivo).
  • Tamanho do lote : número de amostras por atualização de gradiente.
  • Treinar iterações : número de vezes para treinar o modelo invocando model.fit()
  • Nº de exemplos de teste : Número de strings de exemplo (por exemplo, 27+41 ) a serem geradas.

Tente treinar o modelo com diferentes parâmetros e veja se você pode melhorar a precisão das previsões para vários conjuntos de dígitos. Observe também como o tempo de ajuste do modelo é afetado por diferentes parâmetros.

Explorar o código

O aplicativo de exemplo demonstra alguns dos parâmetros que você pode configurar para treinar um RNN. Ele também demonstra o uso de um web worker para treinar um modelo fora do thread principal. Web workers são importantes em web ML porque permitem que você execute tarefas de treinamento computacionalmente caras em um thread em segundo plano, evitando problemas de desempenho que podem afetar o usuário no thread principal. Os threads principal e de trabalho se comunicam entre si por meio de eventos de mensagem.

Para saber mais sobre Web workers, consulte Web workers API e Usando Web workers .

O módulo principal do aplicativo de exemplo é index.js . O script index.js cria um web worker que executa o módulo worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js é amplamente composto por uma única função, runAdditionRNNDemo , que manipula o envio do formulário, processa os dados do formulário, passa os dados do formulário para o trabalhador, espera que o trabalhador treine o modelo e retorne os resultados e, em seguida, exiba os resultados na página .

Para enviar os dados do formulário ao trabalhador, o script invoca postMessage no trabalhador:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

O trabalhador escuta esta mensagem e passa os dados do formulário para funções que preparam os dados e iniciam o treinamento:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Durante o treinamento, o trabalhador pode enviar dois tipos de mensagens diferentes, uma com isPredict definido como true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

e o outro com isPredict definido como false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Quando o thread de IU ( index.js ) lida com eventos de mensagem, ele verifica o sinalizador isPredict para determinar a forma dos dados retornados do trabalhador. Se isPredict for true, os dados devem representar uma previsão e o script atualiza a página usando tfjs-vis . Se isPredict for false, o script executará um bloco de código que assume que os dados representam exemplos. Ele envolve os dados em HTML e insere o HTML na página.

Qual é o próximo

Este tutorial forneceu um exemplo de uso de um web worker para evitar o bloqueio do thread de interface do usuário com um processo de treinamento de longa duração. Para saber mais sobre os benefícios de fazer cálculos caros em um thread em segundo plano, consulte Use web workers para executar JavaScript fora do thread principal do navegador .

Para saber mais sobre como treinar um modelo do TensorFlow.js, consulte Modelos de treinamento .