Los modelos basados en TensorFlow GraphDef (normalmente creados a través de la API de Python) se pueden guardar en uno de los siguientes formatos:
- TensorFlow SavedModel
- Modelo congelado
- Módulo Tensorflow Hub
Todos los formatos anteriores pueden ser convertidos por el convertidor de TensorFlow.js en un formato que se puede cargar directamente en TensorFlow.js para la inferencia.
(Nota: TensorFlow ha desaprobado el formato de paquete de sesión, migre sus modelos al formato de modelo guardado).
Requisitos
El procedimiento de conversión requiere un entorno Python; es posible que desee mantener un aislado usando uno pipenv o virtualenv . Para instalar el convertidor, ejecute el siguiente comando:
pip install tensorflowjs
Importar un modelo de TensorFlow a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo existente al formato web TensorFlow.js y luego cárguelo en TensorFlow.js.
Paso 1. Convierta un modelo de TensorFlow existente al formato web TensorFlow.js
Ejecute el script de conversión proporcionado por el paquete pip:
Uso: ejemplo de modelo guardado:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Ejemplo de modelo congelado:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Ejemplo de módulo de Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Argumentos posicionales | Descripción |
---|---|
input_path | Ruta completa del directorio del modelo guardado, del directorio del paquete de sesión, del archivo del modelo congelado o del identificador o ruta del módulo de TensorFlow Hub. |
output_path | Ruta de todos los artefactos de salida. |
Opciones | Descripción |
---|---|
--input_format | El formato del modelo de entrada, use tf_saved_model para SavedModel, tf_f |
--output_node_names | Los nombres de los nodos de salida, separados por comas. |
--saved_model_tags | Solo aplicable a la conversión SavedModel, Etiquetas de MetaGraphDef para cargar, en formato separado por comas. Por defecto para serve . |
--signature_name | Solo se aplica a la conversión del módulo de TensorFlow Hub, firma para cargar. Por defecto es default . ver https://www.tensorflow.org/hub/common_signatures/ |
Utilice el siguiente comando para obtener un mensaje de ayuda detallado:
tensorflowjs_converter --help
Archivos generados por convertidor
El script de conversión anterior produce dos tipos de archivos:
-
model.json
(el gráfico de flujo de datos y se manifiestan peso) -
group1-shard\*of\*
(colección de archivos binarios de peso)
Por ejemplo, aquí está el resultado de la conversión de MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Paso 2: carga y ejecución en el navegador
- Instale el paquete npm tfjs-converter
yarn add @tensorflow/tfjs
o npm install @tensorflow/tfjs
- Una instancia de la clase FrozenModel y la inferencia de gestión.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Visita nuestra demostración MobileNet .
El loadGraphModel
API acepta un adicionales LoadOptions
parámetro, que puede ser utilizado para enviar las credenciales o cabeceras personalizadas junto con la solicitud. Por favor, vea la documentación loadGraphModel () para más detalles.
Operaciones admitidas
Actualmente, TensorFlow.js admite un conjunto limitado de operaciones de TensorFlow. Si su modelo utiliza un artículo no compatible, el tensorflowjs_converter
escritura fallará e imprimir una lista de las operaciones no compatible en el modelo. Por favor, presentar un tema para cada op para hacernos saber lo que ops necesita soporte para.
Cargando solo los pesos
Si prefiere cargar solo los pesos, puede utilizar el siguiente fragmento de código.
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");