Importation de modèles basés sur TensorFlow GraphDef dans TensorFlow.js

Les modèles basés sur TensorFlow GraphDef (généralement créés via l'API Python) peuvent être enregistrés dans l'un des formats suivants :

  1. tensorflow SavedModel
  2. Modèle congelé
  3. Module Tensorflow Hub

Tous les formats ci - dessus peuvent être convertis par le convertisseur de TensorFlow.js dans un format qui peut être chargé directement dans TensorFlow.js pour l' inférence.

(Remarque : TensorFlow a abandonné le format de groupe de sessions. Veuillez migrer vos modèles vers le format SavedModel.)

Exigences

La procédure de conversion nécessite un environnement Python ; vous pouvez garder un cas isolé en utilisant pipenv ou virtualenv . Pour installer le convertisseur, exécutez la commande suivante :

 pip install tensorflowjs

L'importation d'un modèle TensorFlow dans TensorFlow.js est un processus en deux étapes. Tout d'abord, convertissez un modèle existant au format Web TensorFlow.js, puis chargez-le dans TensorFlow.js.

Étape 1. Convertissez un modèle TensorFlow existant au format Web TensorFlow.js

Exécutez le script de conversion fourni par le package pip :

Utilisation : Exemple de modèle enregistré :

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Exemple de modèle congelé :

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Exemple de module Tensorflow Hub :

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Arguments positionnels La description
input_path Chemin complet du répertoire du modèle enregistré, du répertoire de l'ensemble de sessions, du fichier de modèle gelé ou du descripteur ou du chemin du module TensorFlow Hub.
output_path Chemin de tous les artefacts de sortie.
Options La description
--input_format Le format du modèle d'entrée, utilisez tf_saved_model pour SavedModel, tf_frozen_model pour le modèle figé, tf_session_bundle pour le bundle de session, tf_hub pour le module TensorFlow Hub et keras pour Keras HDF5.
--output_node_names Les noms des nœuds de sortie, séparés par des virgules.
--saved_model_tags Uniquement applicable à la conversion SavedModel, balises du MetaGraphDef à charger, au format séparé par des virgules. La valeur par défaut pour serve .
--signature_name Applicable uniquement à la conversion du module TensorFlow Hub, signature à charger. Par défaut default . voir https://www.tensorflow.org/hub/common_signatures/

Utilisez la commande suivante pour obtenir un message d'aide détaillé :

tensorflowjs_converter --help

Fichiers générés par le convertisseur

Le script de conversion ci-dessus produit deux types de fichiers :

  • model.json (le graphe de flux de données et manifeste en poids)
  • group1-shard\*of\* (collection de fichiers binaires poids)

Par exemple, voici le résultat de la conversion de MobileNet v2 :

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Étape 2 : Chargement et exécution dans le navigateur

  1. Installez le package npm tfjs-converter

yarn add @tensorflow/tfjs ou npm install @tensorflow/tfjs

  1. Instancier la classe FrozenModel et l' inférence exécuter.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Consultez notre MobileNet démo .

L' loadGraphModel API accepte un montant supplémentaire LoadOptions paramètre qui peut être utilisé pour envoyer des informations d' identification ou les en- têtes personnalisés ainsi que la demande. S'il vous plaît voir la documentation loadGraphModel () pour plus de détails.

Opérations prises en charge

Actuellement, TensorFlow.js prend en charge un ensemble limité d'opérations TensorFlow. Si votre modèle utilise un op non pris en charge, le tensorflowjs_converter script échouera et imprimer une liste des opérations non pris en charge dans votre modèle. S'il vous plaît déposer une question pour chaque op pour nous faire savoir que vous ops besoin de soutien pour.

Chargement des poids uniquement

Si vous préférez charger uniquement les poids, vous pouvez utiliser l'extrait de code suivant.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");