웹 작업자를 사용하여 모델 교육

이 튜토리얼에서는 정수 덧셈을 수행하도록 RNN( Recurrent Neural Network )을 교육하기 위해 웹 워커를 사용하는 웹 애플리케이션의 예를 살펴보겠습니다. 예제 앱은 더하기 연산자를 명시적으로 정의하지 않습니다. 대신 예제 합계를 사용하여 RNN을 훈련합니다.

물론 이것은 두 개의 정수를 더하는 가장 효율적인 방법은 아닙니다! 하지만 튜토리얼에서는 웹 ML의 중요한 기술, 즉 UI 로직을 처리하는 메인 스레드를 차단하지 않고 장기 실행 계산을 수행하는 방법을 보여줍니다.

이 튜토리얼의 예제 애플리케이션은 온라인으로 제공 되므로 코드를 다운로드하거나 개발 환경을 설정할 필요가 없습니다. 코드를 로컬에서 실행하려면 로컬에서 예제 실행의 선택적 단계를 완료하세요. 개발 환경을 설정하지 않으려면 예제 탐색 으로 건너뛸 수 있습니다.

예제 코드는 GitHub 에서 사용할 수 있습니다.

(선택 사항) 로컬에서 예제 실행

전제 조건

예제 앱을 로컬에서 실행하려면 개발 환경에 다음이 설치되어 있어야 합니다.

예제 앱 설치 및 실행

  1. tfjs-examples 저장소를 복제하거나 다운로드하세요.
  2. addition-rnn-webworker 디렉터리로 변경합니다.

    cd tfjs-examples/addition-rnn-webworker
    
  3. 종속성을 설치합니다.

    yarn
    
  4. 개발 서버를 시작합니다.

    yarn run watch
    

예제 살펴보기

예제 앱을 엽니다 . (또는 예제를 로컬에서 실행하는 경우 브라우저에서 http://localhost:1234 로 이동합니다.)

TensorFlow.js: Addition RNN 이라는 제목의 페이지가 표시됩니다. 지침에 따라 앱을 사용해 보세요.

웹 양식을 사용하면 다음을 포함하여 모델 학습에 사용되는 일부 매개변수를 업데이트할 수 있습니다.

  • Digits : 추가할 용어의 최대 자릿수입니다.
  • Training Size : 생성할 훈련 예제의 수입니다.
  • RNN 유형 : SimpleRNN , GRU 또는 LSTM 중 하나입니다.
  • RNN Hidden Layer Size : 출력 공간의 차원입니다(양의 정수여야 함).
  • 배치 크기 : 그라데이션 업데이트당 샘플 수입니다.
  • Train Iterations : model.fit() 호출하여 모델을 교육하는 횟수입니다.
  • # of test example : 생성할 예제 문자열의 수(예: 27+41 )입니다.

다양한 매개변수를 사용하여 모델을 훈련하고 다양한 숫자 집합에 대한 예측 정확도를 향상시킬 수 있는지 확인하세요. 또한 모델 적합 시간이 다양한 매개변수에 의해 어떻게 영향을 받는지 확인하세요.

코드 살펴보기

예제 앱은 RNN 교육을 위해 구성할 수 있는 일부 매개변수를 보여줍니다. 또한 웹 작업자를 사용하여 메인 스레드에서 모델을 훈련시키는 방법도 보여줍니다. 웹 작업자는 백그라운드 스레드에서 계산 비용이 많이 드는 교육 작업을 실행하여 기본 스레드에서 잠재적으로 사용자에게 영향을 미치는 성능 문제를 방지할 수 있기 때문에 웹 ML에서 중요합니다. 기본 스레드와 작업자 스레드는 메시지 이벤트를 통해 서로 통신합니다.

웹 작업자에 대해 자세히 알아보려면 웹 작업자 API웹 작업자 사용을 참조하세요.

예제 앱의 기본 모듈은 index.js 입니다. index.js 스크립트는 worker.js 모듈을 실행하는 웹 작업자를 생성합니다 .

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js 양식 제출을 처리하고, 양식 데이터를 처리하고, 양식 데이터를 작업자에게 전달하고, 작업자가 모델을 훈련하고 결과를 반환할 때까지 기다린 후 페이지에 결과를 표시하는 단일 함수인 runAdditionRNNDemo 로 크게 구성됩니다. .

양식 데이터를 작업자에게 보내기 위해 스크립트는 작업자에서 postMessage 호출합니다 .

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

작업자는 이 메시지를 수신 하고 데이터를 준비하고 훈련을 시작하는 함수에 양식 데이터를 전달합니다.

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

훈련 중에 작업자는 isPredict true 로 설정된 두 가지 메시지 유형을 보낼 수 있습니다.

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

다른 하나는 isPredict false 로 설정되어 있습니다.

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

UI 스레드( index.js )는 메시지 이벤트를 처리할 때 isPredict 플래그를 확인하여 작업자에서 반환된 데이터의 형태를 결정합니다. isPredict 가 true인 경우 데이터는 예측을 나타내야 하며 스크립트는 tfjs-vis 사용하여 페이지를 업데이트합니다 . isPredict 가 false인 경우 스크립트는 데이터가 예제를 나타낸다고 가정하는 코드 블록을 실행합니다. 데이터를 HTML로 래핑하고 HTML을 페이지에 삽입합니다.

무엇 향후 계획

이 튜토리얼에서는 장기 실행 학습 프로세스로 인해 UI 스레드가 차단되는 것을 방지하기 위해 웹 작업자를 사용하는 예를 제공했습니다. 백그라운드 스레드에서 비용이 많이 드는 계산을 수행할 때의 이점에 대해 자세히 알아보려면 웹 작업자를 사용하여 브라우저의 기본 스레드에서 JavaScript 실행을 참조 하세요.

TensorFlow.js 모델 학습에 대해 자세히 알아보려면 모델 학습을 참조하세요.