Importer un modèle TensorFlow dans TensorFlow.js

Les modèles basés sur TensorFlow GraphDef (généralement créés via l'API Python) peuvent être enregistrés dans l'un des formats suivants :

  1. Modèle enregistré TensorFlow
  2. Modèle congelé
  3. Module Tensorflow Hub

Tous les formats ci-dessus peuvent être convertis par le convertisseur TensorFlow.js en un format qui peut être chargé directement dans TensorFlow.js pour inférence.

(Remarque : TensorFlow a rendu obsolète le format de groupe de sessions. Veuillez migrer vos modèles vers le format SavedModel.)

Exigences

La procédure de conversion nécessite un environnement Python ; vous souhaiterez peut-être en conserver un isolé en utilisant pipenv ou virtualenv .

Pour installer le convertisseur, exécutez la commande suivante :

 pip install tensorflowjs

L'importation d'un modèle TensorFlow dans TensorFlow.js est un processus en deux étapes. Commencez par convertir un modèle existant au format Web TensorFlow.js, puis chargez-le dans TensorFlow.js.

Étape 1. Convertir un modèle TensorFlow existant au format Web TensorFlow.js

Exécutez le script de conversion fourni par le package pip :

Exemple de modèle enregistré :

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Exemple de modèle gelé :

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Exemple de module Tensorflow Hub :

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Arguments de position Description
input_path Chemin complet du répertoire du modèle enregistré, du répertoire du bundle de session, du fichier de modèle gelé ou du handle ou du chemin du module TensorFlow Hub.
output_path Chemin de tous les artefacts de sortie.
Possibilités Description
--input_format Le format du modèle d’entrée. Utilisez tf_saved_model pour SavedModel, tf_frozen_model pour le modèle gelé, tf_session_bundle pour le bundle de session, tf_hub pour le module TensorFlow Hub et keras pour Keras HDF5.
--output_node_names Les noms des nœuds de sortie, séparés par des virgules.
--saved_model_tags Applicable uniquement à la conversion SavedModel. Balises du MetaGraphDef à charger, au format séparé par des virgules. Par défaut, serve .
--signature_name Applicable uniquement à la conversion du module TensorFlow Hub, signature à charger. La valeur par default . Voir https://www.tensorflow.org/hub/common_signatures/

Utilisez la commande suivante pour obtenir un message d'aide détaillé :

tensorflowjs_converter --help

Fichiers générés par le convertisseur

Le script de conversion ci-dessus produit deux types de fichiers :

  • model.json : Le graphique de flux de données et le manifeste de poids
  • group1-shard\*of\* : Une collection de fichiers de poids binaires

Par exemple, voici le résultat de la conversion de MobileNet v2 :

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Étape 2 : chargement et exécution dans le navigateur

  1. Installez le package npm tfjs-converter :

yarn add @tensorflow/tfjs ou npm install @tensorflow/tfjs

  1. Instanciez la classe FrozenModel et exécutez l'inférence.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Découvrez la démo MobileNet .

L'API loadGraphModel accepte un paramètre LoadOptions supplémentaire, qui peut être utilisé pour envoyer des informations d'identification ou des en-têtes personnalisés avec la demande. Pour plus de détails, consultez la documentation loadGraphModel() .

Opérations prises en charge

Actuellement, TensorFlow.js prend en charge un ensemble limité d'opérations TensorFlow. Si votre modèle utilise une opération non prise en charge, le script tensorflowjs_converter échouera et imprimera une liste des opérations non prises en charge dans votre modèle. Veuillez signaler un problème pour chaque opération afin de nous indiquer les opérations pour lesquelles vous avez besoin d'assistance.

Chargement des poids uniquement

Si vous préférez charger uniquement les poids, vous pouvez utiliser l'extrait de code suivant :

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...