Pętla treningowa

Podczas trenowania modelu uczenia maszynowego często stosuje się pętlę, w której dane szkoleniowe są pozyskiwane (lub generowane), partie przebiegają przez model, uzyskiwane są gradienty, a model jest aktualizowany za pomocą optymalizatora. Chociaż możesz napisać własną pętlę treningową dla każdej aplikacji szkoleniowej, Swift dla TensorFlow zapewnia eksperymentalną abstrakcję pętli szkoleniowej, która może uprościć ten proces.

Moduł TrainingLoop w repozytorium modeli zawiera aktualną wersję tej eksperymentalnej uogólnionej pętli szkoleniowej. Jest skonstruowany w taki sposób, aby integrować się z opakowaniami zestawów danych zgodnymi z interfejsem API Epochs w celu łatwego pozyskiwania danych oraz automatyzować interakcję modeli, zbiorów danych i optymalizatorów z backendami akceleratorów w celu osiągnięcia optymalnej wydajności. Duże dostosowanie procesu szkoleniowego można osiągnąć poprzez wykorzystanie wywołań zwrotnych.

Większość przykładów opartych na obrazach w repozytorium modeli została przekonwertowana w celu wykorzystania tej abstrakcji pętli szkoleniowej, a także przykładów szkolenia modelu tekstu nadzorowanego. Jednak pętla szkoleniowa może nie być odpowiednia w swoim obecnym projekcie dla wszystkich modeli uczenia maszynowego.

Na implementację uogólnionej pętli szkoleniowej Swift dla TensorFlow duży wpływ ma moduł Learner firmy Fastai . Więcej informacji na temat ich projektu można znaleźć w artykule „fastai: warstwowe API do głębokiego uczenia się” i prezentacji Sylvaina Guggera „Fast.ai – nieskończenie konfigurowalna pętla szkoleniowa” .

Stosowanie

Przykład ResNet-CIFAR10 stanowi dobrą demonstrację wykorzystania tej pętli szkoleniowej w praktyce. Najpierw zaimportuj moduł:

import TrainingLoop

następnie wybierz zaplecze akceleratora, konfigurując Device . W tym przypadku wybierzemy backend oparty na X10 XLA i skorzystamy z pierwszego dostępnego akceleratora:

let device = Device.defaultXLA

Następnym krokiem jest skonfigurowanie zbioru danych, modelu i optymalizatora do użycia w pętli szkoleniowej:

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

a następnie skonfiguruj pętlę treningową:

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

Pętla szkoleniowa zakłada, że ​​używany zestaw danych jest zgodny z interfejsem API Epochs i umożliwia określenie, które podziały w zestawie danych mają być używane do szkolenia i sprawdzania poprawności. Można użyć dowolnej funkcji utraty po umieszczeniu jej w kompatybilnym opakowaniu, takim jak softmaxCrossEntropy .

Aktualne metryki, które można przechwycić, obejmują:

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

Na koniec, aby przeprowadzić szkolenie, wywołujesz:

try! trainingLoop.fit(&model, epochs: 10, on: device)

Spowoduje to wytrenowanie modelu przez 10 epok przy użyciu określonego przez nas zaplecza akceleratora. Statystyki będą wyświetlane podczas szkolenia na konsoli za pomocą animowanego komunikatu.

Oddzwonienia

Dostosowywanie tej uogólnionej pętli szkoleniowej odbywa się poprzez użycie wywołań zwrotnych. Te wywołania zwrotne można podłączyć do różnych punktów pętli.

Kilka wbudowanych wywołań zwrotnych zapewnia funkcjonalność, którą można dodać do dowolnej pętli szkoleniowej. Obejmują one:

  • Rejestrowanie statystyk w plikach CSV
  • Dostosowanie tempa nauki według niestandardowego harmonogramu
  • Monitorowanie i wykresowanie postępów szkolenia za pośrednictwem TensorBoard

Oprócz tego możesz tworzyć własne niestandardowe wywołania zwrotne, aby dodać szereg dodatkowych funkcji do standardowej pętli szkoleniowej.

Rejestrowanie CSV

Klasa CSVLogger hermetyzuje wywołanie zwrotne, które zapisze statystyki szkoleniowe w formacie wartości oddzielonych przecinkami do wybranego pliku. Ten plik zacznie się od kolumn oznaczonych epoch ”, batch ” i wszelkich metryk, które włączyłeś w pętli szkoleniowej. Następnie dla każdej partii zostanie zapisany jeden wiersz z bieżącymi wartościami tych kolumn.

Aby dodać rejestrowanie CSV do swojej pętli treningowej, dodaj coś podobnego do poniższej tablicy wywołań zwrotnych dostarczanych do callbacks: parametr dla Twojej TrainingLoop :

try! CSVLogger(path: "file.csv").log

Na przykład próbka LeNet-MNIST wykorzystuje to w swojej pętli szkoleniowej.

Harmonogramy nauczania

Często podczas uczenia modelu zmienia się szybkość uczenia przekazywana optymalizatorowi podczas procesu uczenia. Może to być tak proste, jak liniowy spadek w czasie, lub tak złożone, jak cykle nagrzewania i spadku opisane przez skomplikowane funkcje.

Wywołanie zwrotne learningRateScheduler umożliwia opisanie harmonogramów szybkości uczenia się składających się z różnych segmentów, z których każdy ma swój własny, odrębny kształt. Osiąga się to poprzez zdefiniowanie LearningRateSchedule składającego się z ScheduleSegment s, z których każdy ma Shape zdefiniowany przez funkcję, początkową szybkość uczenia się i końcową szybkość uczenia się.

Na przykład próbka BERT-CoLA wykorzystuje liniowy wzrost szybkości uczenia się w okresie rozgrzewki, a następnie liniowy spadek. W tym celu wywołanie zwrotne harmonogramu szybkości uczenia się definiuje się w następujący sposób:

learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)

Dwa ScheduleSegment definiują szybkość uczenia się, która zaczyna się od 0 i wzrasta liniowo do wartości peakLearningRate w serii 10 odrębnych kroków, następnie rozpoczyna się od końcowej szybkości uczenia się z poprzedniego kroku i maleje liniowo do 0 pod koniec procesu uczenia.

Integracja z TensorBoardem

TensorBoard to potężne narzędzie do wizualizacji umożliwiające monitorowanie treningu modeli, analizowanie treningu po jego ukończeniu lub porównywanie przebiegów treningowych. Swift dla TensorFlow obsługuje wizualizację TensorBoard poprzez wykorzystanie modułu TensorBoard w repozytorium modeli, który zapewnia wywołania zwrotne rejestrujące metryki szkoleniowe.

Przykład GPT2-WikiText2 ilustruje, jak dodać rejestrowanie TensorBoard do szkolenia modelu. Najpierw zaimportuj moduł TensorBoard . Wtedy jest to tak proste, jak dodanie tensorBoardStatisticsLogger() do wywołań zwrotnych TrainingLoop callbacks: array.

Domyślnie będzie to rejestrować każdy przebieg treningu w katalogu run/tensorboard/stats . Aby wyświetlić to w Tensorboard, uruchom

tensorboard --logdir ./run/tensorboard/stats

a TensorBoard powinien uruchomić lokalny serwer, na którym można przeglądać wskaźniki treningu. Wyniki uczenia i walidacji powinny być wyświetlane osobno, a każdy przebieg ma unikalną sygnaturę czasową, aby umożliwić łatwe porównanie wielu przebiegów tego samego modelu.

Projekt integracji Swift dla TensorFlow TensorBoard został zainspirowany tensorboardX . Wywołania zwrotne TensorBoard bezpośrednio tworzą odpowiednie bufory protokołu zdarzeń i podsumowań oraz zapisują je w pliku dziennika podczas szkolenia.

Niestandardowe wywołania zwrotne

Oprócz wbudowanych wywołań zwrotnych opisanych powyżej, masz możliwość dostosowania funkcji pętli treningowych, tworząc własne wywołania zwrotne. Te wywołania zwrotne to funkcje, które mają podpis podobny do następującego:

func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}

Pętla treningowa i powiązany stan są przekazywane jako pierwszy parametr. Bieżąca część pętli, na którą odpowiada wywołanie zwrotne, jest udostępniana za pośrednictwem event . Zdarzenie pętli szkoleniowej ma jeden z następujących stanów, każdy odpowiadający innemu punktowi w cyklu życia pętli:

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

Twoja funkcja wywołania zwrotnego może aktywować swoją logikę w dowolnej kombinacji powyższych stanów, co pozwala na wydobywanie danych z pętli szkoleniowej lub kontrolowanie jej w inny sposób na wiele sposobów.