RSVP für Ihr lokales TensorFlow Everywhere-Event noch heute!
Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Importieren eines Keras-Modells in TensorFlow.js

Keras-Modelle (normalerweise über die Python-API erstellt) können in einem von mehreren Formaten gespeichert werden. Das Format "Gesamtes Modell" kann in das Ebenenformat TensorFlow.js konvertiert werden, das zur Inferenz oder zur weiteren Schulung direkt in TensorFlow.js geladen werden kann.

Das Ziel-TensorFlow.js-Layer-Format ist ein Verzeichnis, das eine model.json Datei und eine Reihe von Sharded-Weight-Dateien im Binärformat enthält. Die Datei model.json enthält sowohl die model.json (auch bekannt als "Architektur" oder "Grafik": eine Beschreibung der Ebenen und wie sie verbunden sind) als auch ein Manifest der Gewichtungsdateien.

Bedarf

Das Konvertierungsverfahren erfordert eine Python-Umgebung. Möglicherweise möchten Sie eine isolierte Datei mit pipenv oder virtualenv beibehalten . Verwenden Sie zum Installieren des Konverters pip install tensorflowjs .

Das Importieren eines Keras-Modells in TensorFlow.js erfolgt in zwei Schritten. Konvertieren Sie zunächst ein vorhandenes Keras-Modell in das Layer-Format TF.js und laden Sie es dann in TensorFlow.js.

Schritt 1. Konvertieren Sie ein vorhandenes Keras-Modell in das TF.js-Ebenenformat

Keras-Modelle werden normalerweise über model.save(filepath) , wodurch eine einzelne HDF5-Datei (.h5) erstellt wird, die sowohl die Modelltopologie als auch die Gewichte enthält. Führen Sie den folgenden Befehl aus, um eine solche Datei in das TF.js-Layer-Format path/to/my_model.h5 ist path/to/my_model.h5 die Keras .h5-Quelldatei und path/to/tfjs_target_dir das Zielausgabeverzeichnis für die TF.js-Dateien:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternative: Verwenden Sie die Python-API, um direkt in das Layer-Format von TF.j zu exportieren

Wenn Sie ein Keras-Modell in Python haben, können Sie es wie folgt direkt in das Layer-Format TensorFlow.js exportieren:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Schritt 2: Laden Sie das Modell in TensorFlow.js

Verwenden Sie einen Webserver, um die in Schritt 1 generierten konvertierten Modelldateien bereitzustellen. Beachten Sie, dass Sie Ihren Server möglicherweise so konfigurieren müssen, dass CORS (Cross-Origin Resource Sharing) zulässig ist , damit die Dateien in JavaScript abgerufen werden können.

Laden Sie dann das Modell in TensorFlow.js, indem Sie die URL zur Datei model.json angeben:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Jetzt kann das Modell geschlossen, bewertet oder neu trainiert werden. Zum Beispiel kann das geladene Modell sofort verwendet werden, um eine Vorhersage zu treffen:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Viele der TensorFlow.js-Beispiele verwenden diesen Ansatz und verwenden vorab trainierte Modelle, die konvertiert und in Google Cloud Storage gehostet wurden.

Beachten Sie, dass Sie mit dem Dateinamen model.json auf das gesamte Modell model.json . loadModel(...) model.json und stellt dann zusätzliche HTTP (S) -Anforderungen, um die Shard-Weight-Dateien model.json weight-Manifest model.json . Mit diesem Ansatz können alle diese Dateien vom Browser (und möglicherweise von zusätzlichen Caching-Servern im Internet) zwischengespeichert werden, da model.json und die Gewichtungsshards jeweils kleiner als die typische Größenbeschränkung für Cache-Dateien sind. Daher wird ein Modell bei späteren Gelegenheiten wahrscheinlich schneller geladen.

Unterstützte Funktionen

TensorFlow.js Layer unterstützen derzeit nur Keras-Modelle, die Standard-Keras-Konstrukte verwenden. Modelle, die nicht unterstützte Operationen oder Ebenen verwenden, z. B. benutzerdefinierte Ebenen, Lambda-Ebenen, benutzerdefinierte Verluste oder benutzerdefinierte Metriken, können nicht automatisch importiert werden, da sie von Python-Code abhängen, der nicht zuverlässig in JavaScript übersetzt werden kann.