Importa un modello TensorFlow in TensorFlow.js

I modelli basati su TensorFlow GraphDef (in genere creati tramite l'API Python) possono essere salvati in uno dei seguenti formati:

  1. Modello salvato di TensorFlow
  2. Modello congelato
  3. Modulo Tensorflow Hub

Tutti i formati di cui sopra possono essere convertiti dal convertitore TensorFlow.js in un formato che può essere caricato direttamente in TensorFlow.js per l'inferenza.

(Nota: TensorFlow ha deprecato il formato del bundle di sessione. Migra i tuoi modelli al formato SavedModel.)

Requisiti

La procedura di conversione richiede un ambiente Python; potresti voler mantenerne uno isolato usando pipenv o virtualenv .

Per installare il convertitore, eseguire il comando seguente:

 pip install tensorflowjs

L'importazione di un modello TensorFlow in TensorFlow.js è un processo in due passaggi. Innanzitutto, converti un modello esistente nel formato web TensorFlow.js, quindi caricalo in TensorFlow.js.

Passaggio 1. Converti un modello TensorFlow esistente nel formato web TensorFlow.js

Esegui lo script del convertitore fornito dal pacchetto pip:

Esempio di modello salvato:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Esempio di modello congelato:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Esempio di modulo Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argomenti posizionali Descrizione
input_path Percorso completo della directory del modello salvato, della directory del bundle di sessione, del file del modello congelato o dell'handle o del percorso del modulo TensorFlow Hub.
output_path Percorso per tutti gli artefatti di output.
Opzioni Descrizione
--input_format Il formato del modello di input. Utilizza tf_saved_model per SavedModel, tf_frozen_model per il modello congelato, tf_session_bundle per il bundle di sessione, tf_hub per il modulo TensorFlow Hub e keras per Keras HDF5.
--output_node_names I nomi dei nodi di output, separati da virgole.
--saved_model_tags Applicabile solo alla conversione SavedModel. Tag del MetaGraphDef da caricare, in formato separato da virgole. Per impostazione predefinita viene serve .
--signature_name Applicabile solo alla conversione del modulo TensorFlow Hub, firma da caricare. Impostazioni predefinite per default . Vedere https://www.tensorflow.org/hub/common_signatures/

Utilizzare il seguente comando per ottenere un messaggio di aiuto dettagliato:

tensorflowjs_converter --help

File generati dal convertitore

Lo script di conversione sopra produce due tipi di file:

  • model.json : il grafico del flusso di dati e il manifest del peso
  • group1-shard\*of\* : una raccolta di file di peso binari

Ad esempio, ecco l'output della conversione di MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Passaggio 2: caricamento ed esecuzione nel browser

  1. Installa il pacchetto tfjs-converter npm:

yarn add @tensorflow/tfjs o npm install @tensorflow/tfjs

  1. Crea un'istanza della classe FrozenModel ed esegui l'inferenza.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Guarda la demo di MobileNet .

L'API loadGraphModel accetta un parametro LoadOptions aggiuntivo, che può essere utilizzato per inviare credenziali o intestazioni personalizzate insieme alla richiesta. Per i dettagli, consultare la documentazione loadGraphModel() .

Operazioni supportate

Attualmente TensorFlow.js supporta un set limitato di operazioni TensorFlow. Se il tuo modello utilizza un'operazione non supportata, lo script tensorflowjs_converter fallirà e stamperà un elenco delle operazioni non supportate nel tuo modello. Invia un problema per ciascuna operazione per farci sapere per quali operazioni hai bisogno di supporto.

Caricamento solo dei pesi

Se preferisci caricare solo i pesi, puoi utilizzare il seguente snippet di codice:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...