Keras-Modelle (normalerweise über die Python-API erstellt) können in einem von mehreren Formaten gespeichert werden. Das Format "Gesamtes Modell" kann in das Ebenenformat TensorFlow.js konvertiert werden, das zur Inferenz oder zur weiteren Schulung direkt in TensorFlow.js geladen werden kann.
Das Ziel-TensorFlow.js-Layer-Format ist ein Verzeichnis, das eine model.json
Datei und eine Reihe von Sharded-Weight-Dateien im Binärformat enthält. Die Datei model.json
enthält sowohl die model.json
(auch bekannt als "Architektur" oder "Grafik": eine Beschreibung der Ebenen und wie sie verbunden sind) als auch ein Manifest der Gewichtungsdateien.
Bedarf
Das Konvertierungsverfahren erfordert eine Python-Umgebung. Möglicherweise möchten Sie eine isolierte Datei mit pipenv oder virtualenv beibehalten . Verwenden Sie zum Installieren des Konverters pip install tensorflowjs
.
Das Importieren eines Keras-Modells in TensorFlow.js erfolgt in zwei Schritten. Konvertieren Sie zunächst ein vorhandenes Keras-Modell in das Layer-Format TF.js und laden Sie es dann in TensorFlow.js.
Schritt 1. Konvertieren Sie ein vorhandenes Keras-Modell in das TF.js-Ebenenformat
Keras-Modelle werden normalerweise über model.save(filepath)
, wodurch eine einzelne HDF5-Datei (.h5) erstellt wird, die sowohl die Modelltopologie als auch die Gewichte enthält. Führen Sie den folgenden Befehl aus, um eine solche Datei in das TF.js-Layer-Format path/to/my_model.h5
ist path/to/my_model.h5
die Keras .h5-Quelldatei und path/to/tfjs_target_dir
das Zielausgabeverzeichnis für die TF.js-Dateien:
# bash
tensorflowjs_converter --input_format keras \
path/to/my_model.h5 \
path/to/tfjs_target_dir
Alternative: Verwenden Sie die Python-API, um direkt in das Layer-Format von TF.j zu exportieren
Wenn Sie ein Keras-Modell in Python haben, können Sie es wie folgt direkt in das Layer-Format TensorFlow.js exportieren:
# Python
import tensorflowjs as tfjs
def train(...):
model = keras.models.Sequential() # for example
...
model.compile(...)
model.fit(...)
tfjs.converters.save_keras_model(model, tfjs_target_dir)
Schritt 2: Laden Sie das Modell in TensorFlow.js
Verwenden Sie einen Webserver, um die in Schritt 1 generierten konvertierten Modelldateien bereitzustellen. Beachten Sie, dass Sie Ihren Server möglicherweise so konfigurieren müssen, dass CORS (Cross-Origin Resource Sharing) zulässig ist , damit die Dateien in JavaScript abgerufen werden können.
Laden Sie dann das Modell in TensorFlow.js, indem Sie die URL zur Datei model.json angeben:
// JavaScript
import * as tf from '@tensorflow/tfjs';
const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');
Jetzt kann das Modell geschlossen, bewertet oder neu trainiert werden. Zum Beispiel kann das geladene Modell sofort verwendet werden, um eine Vorhersage zu treffen:
// JavaScript
const example = tf.fromPixels(webcamElement); // for example
const prediction = model.predict(example);
Viele der TensorFlow.js-Beispiele verwenden diesen Ansatz und verwenden vorab trainierte Modelle, die konvertiert und in Google Cloud Storage gehostet wurden.
Beachten Sie, dass Sie mit dem Dateinamen model.json
auf das gesamte Modell model.json
. loadModel(...)
model.json
und stellt dann zusätzliche HTTP (S) -Anforderungen, um die Shard-Weight-Dateien model.json
weight-Manifest model.json
. Mit diesem Ansatz können alle diese Dateien vom Browser (und möglicherweise von zusätzlichen Caching-Servern im Internet) zwischengespeichert werden, da model.json
und die Gewichtungsshards jeweils kleiner als die typische Größenbeschränkung für Cache-Dateien sind. Daher wird ein Modell bei späteren Gelegenheiten wahrscheinlich schneller geladen.
Unterstützte Funktionen
TensorFlow.js Layer unterstützen derzeit nur Keras-Modelle, die Standard-Keras-Konstrukte verwenden. Modelle, die nicht unterstützte Operationen oder Ebenen verwenden, z. B. benutzerdefinierte Ebenen, Lambda-Ebenen, benutzerdefinierte Verluste oder benutzerdefinierte Metriken, können nicht automatisch importiert werden, da sie von Python-Code abhängen, der nicht zuverlässig in JavaScript übersetzt werden kann.