Usa un modelo pre-entrenado

En este tutorial, explorará una aplicación web de ejemplo que demuestra el aprendizaje por transferencia utilizando la API de capas TensorFlow.js. El ejemplo carga un modelo previamente entrenado y luego vuelve a entrenar el modelo en el navegador.

El modelo ha sido entrenado previamente en Python en los dígitos 0-4 del conjunto de datos de clasificación de dígitos MNIST . El reentrenamiento (o transferencia de aprendizaje) en el navegador utiliza los dígitos del 5 al 9. El ejemplo muestra que las primeras capas de un modelo previamente entrenado se pueden utilizar para extraer características de nuevos datos durante el aprendizaje por transferencia, lo que permite un entrenamiento más rápido con los nuevos datos.

La aplicación de ejemplo para este tutorial está disponible en línea , por lo que no necesita descargar ningún código ni configurar un entorno de desarrollo. Si desea ejecutar el código localmente, complete los pasos opcionales en Ejecutar el ejemplo localmente . Si no desea configurar un entorno de desarrollo, puede pasar a Explorar el ejemplo .

El código de ejemplo está disponible en GitHub .

(Opcional) Ejecute el ejemplo localmente

Requisitos previos

Para ejecutar la aplicación de ejemplo localmente, necesita tener instalado lo siguiente en su entorno de desarrollo:

Instalar y ejecutar la aplicación de ejemplo

  1. Clona o descarga el repositorio tfjs-examples .
  2. Cambie al directorio mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Instalar dependencias:

    yarn
    
  4. Inicie el servidor de desarrollo:

    yarn run watch
    

Explora el ejemplo

Abra la aplicación de ejemplo . (O, si está ejecutando el ejemplo localmente, vaya a http://localhost:1234 en su navegador).

Debería ver una página titulada MNIST CNN Transfer Learning . Siga las instrucciones para probar la aplicación.

Aquí hay algunas cosas que puede probar:

  • Experimente con los diferentes modos de entrenamiento y compare la pérdida y la precisión.
  • Seleccione diferentes ejemplos de mapas de bits e inspeccione las probabilidades de clasificación. Tenga en cuenta que los números en cada ejemplo de mapa de bits son valores enteros en escala de grises que representan píxeles de una imagen.
  • Edite los valores enteros del mapa de bits y vea cómo los cambios afectan las probabilidades de clasificación.

Explora el código

La aplicación web de ejemplo carga un modelo que ha sido entrenado previamente en un subconjunto del conjunto de datos MNIST. El entrenamiento previo se define en un programa Python: mnist_transfer_cnn.py . El programa Python está fuera del alcance de este tutorial, pero vale la pena verlo si desea ver un ejemplo de conversión de modelo .

El archivo index.js contiene la mayor parte del código de entrenamiento para la demostración. Cuando index.js se ejecuta en el navegador, una función de configuración, setupMnistTransferCNN , crea una instancia e inicializa MnistTransferCNNPredictor , que encapsula las rutinas de reentrenamiento y predicción.

El método de inicialización, MnistTransferCNNPredictor.init , carga un modelo, carga datos de reentrenamiento y crea datos de prueba. Aquí está la línea que carga el modelo:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Si observa la definición de loader.loadHostedPretrainedModel , verá que devuelve el resultado de una llamada a tf.loadLayersModel . Esta es la API de TensorFlow.js para cargar un modelo compuesto por objetos Layer.

La lógica de reentrenamiento se define en MnistTransferCNNPredictor.retrainModel . Si el usuario ha seleccionado Congelar capas de entidades como modo de entrenamiento, las primeras 7 capas del modelo base se congelan y solo las últimas 5 capas se entrenan con datos nuevos. Si el usuario ha seleccionado Reinicializar pesos , todos los pesos se restablecen y la aplicación entrena eficazmente el modelo desde cero.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Luego se compila el modelo y luego se entrena con los datos de prueba usando model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Para obtener más información sobre los parámetros model.fit() , consulte la documentación de la API .

Después de entrenarse con el nuevo conjunto de datos (dígitos 5 a 9), el modelo se puede utilizar para hacer predicciones. El método MnistTransferCNNPredictor.predict hace esto usando model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Tenga en cuenta el uso de tf.tidy , que ayuda a prevenir pérdidas de memoria.

Aprende más

En este tutorial se ha explorado una aplicación de ejemplo que realiza el aprendizaje por transferencia en el navegador utilizando TensorFlow.js. Consulte los recursos a continuación para obtener más información sobre modelos previamente entrenados y transferencia de aprendizaje.

TensorFlow.js

Núcleo TensorFlow