Importazione di un modello Keras in TensorFlow.js

I modelli Keras (in genere creati tramite l'API Python) possono essere salvati in uno dei diversi formati . Il formato "modello intero" può essere convertito nel formato TensorFlow.js Layers, che può essere caricato direttamente in TensorFlow.js per l'inferenza o per ulteriore formazione.

Il formato TensorFlow.js Layers di destinazione è una directory contenente un file model.json e una serie di file di peso partizionati in formato binario. Il file model.json contiene sia la topologia del modello (nota anche come "architettura" o "grafico": una descrizione dei livelli e il modo in cui sono collegati) sia un manifest dei file di peso.

Requisiti

La procedura di conversione richiede un ambiente Python; potresti voler mantenerne uno isolato usando pipenv o virtualenv . Per installare il convertitore, utilizzare pip install tensorflowjs .

L'importazione di un modello Keras in TensorFlow.js è un processo in due passaggi. Innanzitutto, converti un modello Keras esistente nel formato TF.js Layers, quindi caricalo in TensorFlow.js.

Passaggio 1. Converti un modello Keras esistente nel formato TF.js Layers

I modelli Keras vengono solitamente salvati tramite model.save(filepath) , che produce un singolo file HDF5 (.h5) contenente sia la topologia del modello che i pesi. Per convertire un file di questo tipo nel formato TF.js Layers, esegui il comando seguente, dove path/to/my_model.h5 è il file Keras .h5 di origine e path/to/tfjs_target_dir è la directory di output di destinazione per i file TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: utilizzare l'API Python per esportare direttamente nel formato TF.js Layers

Se disponi di un modello Keras in Python, puoi esportarlo direttamente nel formato TensorFlow.js Layers come segue:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Passaggio 2: carica il modello in TensorFlow.js

Utilizza un server Web per servire i file del modello convertito generati nel passaggio 1. Tieni presente che potrebbe essere necessario configurare il server per consentire la condivisione delle risorse tra origini (CORS) , per consentire il recupero dei file in JavaScript.

Quindi carica il modello in TensorFlow.js fornendo l'URL al file model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Ora il modello è pronto per l'inferenza, la valutazione o il riaddestramento. Ad esempio, il modello caricato può essere immediatamente utilizzato per fare una previsione:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Molti degli esempi di TensorFlow.js adottano questo approccio, utilizzando modelli preaddestrati che sono stati convertiti e ospitati su Google Cloud Storage.

Tieni presente che fai riferimento all'intero modello utilizzando il nome file model.json . loadModel(...) recupera model.json e quindi effettua ulteriori richieste HTTP(S) per ottenere i file di peso partizionati a cui si fa riferimento nel manifest del peso model.json . Questo approccio consente a tutti questi file di essere memorizzati nella cache dal browser (e forse da ulteriori server di memorizzazione nella cache su Internet), poiché model.json e i frammenti di peso sono ciascuno più piccoli del limite di dimensione tipico del file di cache. Pertanto è probabile che un modello venga caricato più rapidamente nelle occasioni successive.

Funzionalità supportate

TensorFlow.js Layers attualmente supporta solo i modelli Keras che utilizzano costrutti Keras standard. I modelli che utilizzano operazioni o livelli non supportati, ad esempio livelli personalizzati, livelli Lambda, perdite personalizzate o parametri personalizzati, non possono essere importati automaticamente perché dipendono da codice Python che non può essere tradotto in modo affidabile in JavaScript.