¡Google I/O es una envoltura! Póngase al día con las sesiones de TensorFlow Ver sesiones

Importación de modelos basados ​​en GraphDef de TensorFlow en TensorFlow.js

Los modelos basados ​​en TensorFlow GraphDef (normalmente creados a través de la API de Python) se pueden guardar en uno de los siguientes formatos:

  1. TensorFlow SavedModel
  2. Modelo congelado
  3. Módulo Tensorflow Hub

Todos los formatos anteriores pueden ser convertidos por el convertidor de TensorFlow.js en un formato que se puede cargar directamente en TensorFlow.js para la inferencia.

(Nota: TensorFlow ha desaprobado el formato de paquete de sesión, migre sus modelos al formato de modelo guardado).

Requisitos

El procedimiento de conversión requiere un entorno Python; es posible que desee mantener un aislado usando uno pipenv o virtualenv . Para instalar el convertidor, ejecute el siguiente comando:

 pip install tensorflowjs

Importar un modelo de TensorFlow a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo existente al formato web TensorFlow.js y luego cárguelo en TensorFlow.js.

Paso 1. Convierta un modelo de TensorFlow existente al formato web TensorFlow.js

Ejecute el script de conversión proporcionado por el paquete pip:

Uso: ejemplo de modelo guardado:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Ejemplo de modelo congelado:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Ejemplo de módulo de Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumentos posicionales Descripción
input_path Ruta completa del directorio del modelo guardado, del directorio del paquete de sesión, del archivo del modelo congelado o del identificador o ruta del módulo de TensorFlow Hub.
output_path Ruta de todos los artefactos de salida.
Opciones Descripción
--input_format El formato del modelo de entrada, use tf_saved_model para SavedModel, tf_f
--output_node_names Los nombres de los nodos de salida, separados por comas.
--saved_model_tags Solo aplicable a la conversión SavedModel, Etiquetas de MetaGraphDef para cargar, en formato separado por comas. Por defecto para serve .
--signature_name Solo se aplica a la conversión del módulo de TensorFlow Hub, firma para cargar. Por defecto es default . ver https://www.tensorflow.org/hub/common_signatures/

Utilice el siguiente comando para obtener un mensaje de ayuda detallado:

tensorflowjs_converter --help

Archivos generados por convertidor

El script de conversión anterior produce dos tipos de archivos:

  • model.json (el gráfico de flujo de datos y se manifiestan peso)
  • group1-shard\*of\* (colección de archivos binarios de peso)

Por ejemplo, aquí está el resultado de la conversión de MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Paso 2: carga y ejecución en el navegador

  1. Instale el paquete npm tfjs-converter

yarn add @tensorflow/tfjs o npm install @tensorflow/tfjs

  1. Una instancia de la clase FrozenModel y la inferencia de gestión.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Visita nuestra demostración MobileNet .

El loadGraphModel API acepta un adicionales LoadOptions parámetro, que puede ser utilizado para enviar las credenciales o cabeceras personalizadas junto con la solicitud. Por favor, vea la documentación loadGraphModel () para más detalles.

Operaciones admitidas

Actualmente, TensorFlow.js admite un conjunto limitado de operaciones de TensorFlow. Si su modelo utiliza un artículo no compatible, el tensorflowjs_converter escritura fallará e imprimir una lista de las operaciones no compatible en el modelo. Por favor, presentar un tema para cada op para hacernos saber lo que ops necesita soporte para.

Cargando solo los pesos

Si prefiere cargar solo los pesos, puede utilizar el siguiente fragmento de código.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");