Importar un modelo de TensorFlow a TensorFlow.js

Los modelos basados ​​en TensorFlow GraphDef (generalmente creados a través de la API de Python) se pueden guardar en uno de los siguientes formatos:

  1. Modelo guardado de TensorFlow
  2. modelo congelado
  3. Módulo concentrador de Tensorflow

El convertidor TensorFlow.js puede convertir todos los formatos anteriores en un formato que se puede cargar directamente en TensorFlow.js para la inferencia.

(Nota: TensorFlow ha dejado de usar el formato de paquete de sesión. Migre sus modelos al formato de modelo guardado).

Requisitos

El procedimiento de conversión requiere un entorno Python; es posible que desee mantener uno aislado usando pipenv o virtualenv .

Para instalar el convertidor, ejecute el siguiente comando:

 pip install tensorflowjs

Importar un modelo de TensorFlow a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo existente al formato web TensorFlow.js y luego cárguelo en TensorFlow.js.

Paso 1. Convierta un modelo TensorFlow existente al formato web TensorFlow.js

Ejecute la secuencia de comandos del convertidor proporcionada por el paquete pip:

Ejemplo de modelo guardado:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Ejemplo de modelo congelado:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Ejemplo del módulo Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumentos posicionales Descripción
input_path Ruta completa del directorio del modelo guardado, el directorio del paquete de la sesión, el archivo del modelo congelado o el identificador o la ruta del módulo TensorFlow Hub.
output_path Ruta para todos los artefactos de salida.
Opciones Descripción
--input_format El formato del modelo de entrada. Usa tf_saved_model para el modelo guardado, tf_frozen_model para el modelo congelado, tf_session_bundle para el paquete de sesión, tf_hub para el módulo TensorFlow Hub y keras para Keras HDF5.
--output_node_names Los nombres de los nodos de salida, separados por comas.
--saved_model_tags Solo aplicable a la conversión de modelo guardado. Etiquetas del MetaGraphDef a cargar, en formato separado por comas. Valores predeterminados para serve .
--signature_name Solo se aplica a la conversión del módulo TensorFlow Hub, firma para cargar. Predeterminado a default . Consulte https://www.tensorflow.org/hub/common_signatures/

Use el siguiente comando para obtener un mensaje de ayuda detallado:

tensorflowjs_converter --help

Archivos generados por el convertidor

El script de conversión anterior produce dos tipos de archivos:

  • model.json : el gráfico de flujo de datos y el manifiesto de peso
  • group1-shard\*of\* : Una colección de archivos de peso binarios

Por ejemplo, aquí está el resultado de convertir MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Paso 2: Cargar y ejecutar en el navegador

  1. Instale el paquete tfjs-converter npm:

yarn add @tensorflow/tfjs o npm install @tensorflow/tfjs

  1. Cree una instancia de la clase FrozenModel y ejecute la inferencia.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Consulte la demostración de MobileNet .

La API loadGraphModel acepta un parámetro LoadOptions adicional, que se puede usar para enviar credenciales o encabezados personalizados junto con la solicitud. Para obtener más información, consulte la documentación de loadGraphModel() .

Operaciones admitidas

Actualmente, TensorFlow.js admite un conjunto limitado de operaciones de TensorFlow. Si su modelo usa una operación no admitida, el script tensorflowjs_converter fallará e imprimirá una lista de las operaciones no admitidas en su modelo. Presente un problema para cada operación para informarnos para qué operaciones necesita soporte.

Cargando los pesos solamente

Si prefiere cargar solo los pesos, puede usar el siguiente fragmento de código:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...