Utiliser un modèle pré-formé

Dans ce didacticiel, vous explorerez un exemple d'application Web qui illustre l'apprentissage par transfert à l'aide de l'API TensorFlow.js Layers. L'exemple charge un modèle pré-entraîné, puis recycle le modèle dans le navigateur.

Le modèle a été pré-entraîné en Python sur les chiffres 0 à 4 de l' ensemble de données de classification des chiffres MNIST . Le recyclage (ou transfert d'apprentissage) dans le navigateur utilise les chiffres 5 à 9. L'exemple montre que les premières couches d'un modèle pré-entraîné peuvent être utilisées pour extraire des fonctionnalités de nouvelles données lors de l'apprentissage par transfert, permettant ainsi une formation plus rapide sur les nouvelles données.

L'exemple d'application de ce didacticiel est disponible en ligne . Vous n'avez donc pas besoin de télécharger de code ni de configurer un environnement de développement. Si vous souhaitez exécuter le code localement, suivez les étapes facultatives dans Exécuter l'exemple localement . Si vous ne souhaitez pas configurer d'environnement de développement, vous pouvez passer à Explorer l'exemple .

L'exemple de code est disponible sur GitHub .

(Facultatif) Exécutez l'exemple localement

Conditions préalables

Pour exécuter l'exemple d'application localement, vous devez installer les éléments suivants dans votre environnement de développement :

Installez et exécutez l'exemple d'application

  1. Clonez ou téléchargez le référentiel tfjs-examples .
  2. Accédez au répertoire mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Installer les dépendances :

    yarn
    
  4. Démarrez le serveur de développement :

    yarn run watch
    

Explorez l'exemple

Ouvrez l'exemple d'application . (Ou, si vous exécutez l'exemple localement, accédez à http://localhost:1234 dans votre navigateur.)

Vous devriez voir une page intitulée MNIST CNN Transfer Learning . Suivez les instructions pour essayer l'application.

Voici quelques choses à essayer :

  • Expérimentez avec les différents modes d'entraînement et comparez la perte et la précision.
  • Sélectionnez différents exemples de bitmap et inspectez les probabilités de classification. Notez que les nombres dans chaque exemple de bitmap sont des valeurs entières en niveaux de gris représentant les pixels d'une image.
  • Modifiez les valeurs entières bitmap et voyez comment les modifications affectent les probabilités de classification.

Explorez le code

L'exemple d'application Web charge un modèle qui a été pré-entraîné sur un sous-ensemble de l'ensemble de données MNIST. La pré-formation est définie dans un programme Python : mnist_transfer_cnn.py . Le programme Python est hors du cadre de ce didacticiel, mais il vaut la peine d'y jeter un coup d'œil si vous souhaitez voir un exemple de conversion de modèle .

Le fichier index.js contient la plupart du code de formation pour la démo. Lorsque index.js s'exécute dans le navigateur, une fonction de configuration, setupMnistTransferCNN , instancie et initialise MnistTransferCNNPredictor , qui encapsule les routines de recyclage et de prédiction.

La méthode d'initialisation, MnistTransferCNNPredictor.init , charge un modèle, charge les données de recyclage et crée des données de test. Voici la ligne qui charge le modèle :

this.model = await loader.loadHostedPretrainedModel(urls.model);

Si vous regardez la définition de loader.loadHostedPretrainedModel , vous verrez qu'elle renvoie le résultat d'un appel à tf.loadLayersModel . Il s'agit de l'API TensorFlow.js permettant de charger un modèle composé d'objets Layer.

La logique de recyclage est définie dans MnistTransferCNNPredictor.retrainModel . Si l'utilisateur a sélectionné Geler les couches d'entités comme mode d'entraînement, les 7 premières couches du modèle de base sont gelées et seules les 5 dernières couches sont entraînées sur de nouvelles données. Si l'utilisateur a sélectionné Réinitialiser les poids , tous les poids sont réinitialisés et l'application entraîne efficacement le modèle à partir de zéro.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Le modèle est ensuite compilé , puis entraîné sur les données de test à l'aide de model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Pour en savoir plus sur les paramètres model.fit() , consultez la documentation de l'API .

Après avoir été formé sur le nouvel ensemble de données (chiffres 5 à 9), le modèle peut être utilisé pour faire des prédictions. La méthode MnistTransferCNNPredictor.predict fait cela en utilisant model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Notez l'utilisation de tf.tidy , qui permet d'éviter les fuites de mémoire.

Apprendre encore plus

Ce didacticiel a exploré un exemple d'application qui effectue un apprentissage par transfert dans le navigateur à l'aide de TensorFlow.js. Consultez les ressources ci-dessous pour en savoir plus sur les modèles pré-entraînés et l'apprentissage par transfert.

TensorFlow.js

Noyau TensorFlow