Boucle d'entraînement

Lors de la formation d'un modèle d'apprentissage automatique, il est courant d'avoir une boucle dans laquelle les données de formation sont ingérées (ou générées), des lots exécutés via un modèle, des gradients obtenus et le modèle mis à jour via un optimiseur. Bien que vous puissiez écrire votre propre boucle de formation pour chaque application de formation, Swift pour TensorFlow fournit une abstraction expérimentale de la boucle de formation qui peut simplifier ce processus.

Le module TrainingLoop au sein du référentiel de modèles contient la version actuelle de cette boucle de formation expérimentale généralisée. Il est structuré de manière à s'intégrer à des wrappers d'ensembles de données conformes à l'API Epochs pour une ingestion facile des données et à automatiser l'interaction des modèles, des ensembles de données et des optimiseurs avec les backends d'accélérateurs pour obtenir des performances optimales. Une personnalisation poussée du processus de formation peut être obtenue grâce à l'utilisation de rappels.

La plupart des exemples basés sur des images dans le référentiel de modèles ont été convertis pour utiliser cette abstraction de boucle de formation, ainsi que les exemples de formation de modèles de texte supervisés. Cependant, la boucle de formation n’est peut-être pas appropriée dans sa conception actuelle pour tous les modèles d’apprentissage automatique.

La mise en œuvre de Swift pour la boucle de formation généralisée de TensorFlow est fortement influencée par Learner de fastai . Pour en savoir plus sur leur conception, veuillez vous référer à « fastai : A Layered API for Deep Learning » et à la présentation de Sylvain Gugger « Fast.ai - An infinitely personnalisable training loop » .

Usage

L'exemple ResNet-CIFAR10 fournit une bonne démonstration de la façon d'utiliser cette boucle de formation dans la pratique. Tout d'abord, importez le module :

import TrainingLoop

puis choisissez un backend accélérateur en configurant un Device . Dans ce cas, nous sélectionnerons le backend basé sur X10 XLA et utiliserons le premier accélérateur disponible :

let device = Device.defaultXLA

L'étape suivante consiste à configurer l'ensemble de données, le modèle et l'optimiseur à utiliser avec votre boucle d'entraînement :

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

puis configurez la boucle de formation :

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

La boucle de formation suppose que l'ensemble de données que vous utilisez est conforme à l'API Epochs et vous permet de spécifier les divisions de l'ensemble de données à utiliser pour la formation et la validation. Toute fonction de perte peut être utilisée une fois placée dans un wrapper compatible, tel que softmaxCrossEntropy est ici .

Les mesures actuelles qui peuvent être capturées incluent :

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

Enfin, pour effectuer une formation, vous appelez les éléments suivants :

try! trainingLoop.fit(&model, epochs: 10, on: device)

Cela entraînera le modèle pendant 10 époques à l'aide du backend accélérateur que nous avons spécifié. Les statistiques seront affichées pendant l'entraînement sur la console à l'aide d'une invite animée.

Rappels

La personnalisation de cette boucle de formation généralisée se fait via l'utilisation de rappels. Ces rappels peuvent être accrochés à différents points de la boucle.

Plusieurs rappels intégrés fournissent des fonctionnalités qui peuvent être ajoutées à n'importe quelle boucle de formation. Ceux-ci inclus:

  • Journalisation des statistiques dans des fichiers CSV (valeurs séparées par des virgules)
  • Ajustement du taux d'apprentissage selon un calendrier personnalisé
  • Surveillance et représentation graphique des progrès de la formation via TensorBoard

En plus de cela, vous pouvez créer vos propres rappels personnalisés pour ajouter une gamme de fonctionnalités supplémentaires à une boucle de formation standard.

Journalisation CSV

La classe CSVLogger encapsule un rappel qui écrira les statistiques d'entraînement dans un format de valeurs séparées par des virgules dans un fichier de votre choix. Ce fichier commencera par des colonnes intitulées epoch , batch et toutes les métriques que vous avez activées dans votre boucle d'entraînement. Une ligne sera alors écrite pour chaque lot, avec les valeurs actuelles de ces colonnes.

Pour ajouter la journalisation CSV à votre boucle d'entraînement, ajoutez quelque chose comme ce qui suit à un tableau de rappels fournis au paramètre callbacks: pour votre TrainingLoop :

try! CSVLogger(path: "file.csv").log

À titre d'exemple, l' exemple LeNet-MNIST l'utilise dans sa boucle de formation.

Barèmes de taux d'apprentissage

Il est courant, lors de la formation d'un modèle, de modifier le taux d'apprentissage fourni à un optimiseur pendant le processus de formation. Cela peut être aussi simple qu'une diminution linéaire dans le temps, ou aussi complexe que des cycles de réchauffement et de déclin décrits par des fonctions compliquées.

Le rappel learningRateScheduler fournit le moyen de décrire des programmes de taux d'apprentissage composés de différents segments, chacun avec sa propre forme distincte. Ceci est accompli en définissant un LearningRateSchedule composé de ScheduleSegment qui ont chacun une Shape définie par une fonction, un taux d'apprentissage initial et un taux d'apprentissage final.

Par exemple, l' échantillon BERT-CoLA utilise une augmentation linéaire du taux d'apprentissage pendant une période d'échauffement et une diminution linéaire par la suite. Pour ce faire, le rappel du planning de taux d'apprentissage est défini comme suit :

learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)

Les deux ScheduleSegment définissent un taux d'apprentissage qui commence à 0 et augmente linéairement jusqu'à peakLearningRate sur une série de 10 étapes discrètes, puis commence au taux d'apprentissage final de l'étape précédente et diminue linéairement jusqu'à 0 à la fin du processus de formation.

Intégration TensorBoard

TensorBoard est un outil de visualisation puissant permettant de surveiller l'entraînement du modèle, d'analyser l'entraînement une fois terminé ou de comparer les exécutions d'entraînement. Swift pour TensorFlow prend en charge la visualisation TensorBoard grâce à l'utilisation du module TensorBoard dans le référentiel de modèles, qui fournit des rappels qui enregistrent les métriques de formation.

L'exemple GPT2-WikiText2 illustre comment ajouter la journalisation TensorBoard à la formation de votre modèle. Tout d’abord, importez le module TensorBoard . Ensuite, c'est aussi simple que d'ajouter tensorBoardStatisticsLogger() aux rappels de votre TrainingLoop callbacks: array.

Par défaut, chaque exécution d'entraînement sera enregistrée dans un répertoire run/tensorboard/stats . Pour afficher cela dans Tensorboard, exécutez

tensorboard --logdir ./run/tensorboard/stats

et TensorBoard doit démarrer un serveur local sur lequel vous pouvez afficher vos métriques de formation. Les résultats de formation et de validation doivent être affichés séparément, et chaque exécution possède un horodatage unique pour permettre une comparaison facile entre plusieurs exécutions du même modèle.

La conception de l'intégration Swift pour TensorFlow TensorBoard a été inspirée par tensorboardX . Les rappels TensorBoard créent directement les tampons de protocole d'événement et de résumé appropriés et les écrivent dans un fichier journal pendant la formation.

Rappels personnalisés

En plus des rappels intégrés décrits ci-dessus, vous avez la possibilité de personnaliser la fonction des boucles d'entraînement en créant vos propres rappels. Ces rappels sont des fonctions qui ont une signature similaire à celle-ci :

func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}

La boucle d'entraînement et l'état associé sont transmis comme premier paramètre. La partie actuelle de la boucle à laquelle le rappel répond est fournie via event . L'événement de boucle d'entraînement a l'un des états suivants, chacun correspondant à un point différent du cycle de vie de la boucle :

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

Votre fonction de rappel peut choisir d'activer sa logique sur n'importe quelle combinaison des états ci-dessus, ce qui permet d'extraire des données ou de contrôler la boucle d'entraînement de plusieurs manières.