Importando un modelo de Keras a TensorFlow.js

Los modelos de Keras (generalmente creados a través de la API de Python) se pueden guardar en uno de varios formatos . El formato de "modelo completo" se puede convertir al formato de capas de TensorFlow.js, que se puede cargar directamente en TensorFlow.js para inferencia o para capacitación adicional.

El formato de capas objetivo de TensorFlow.js es un directorio que contiene un archivo model.json y un conjunto de archivos de peso fragmentados en formato binario. El archivo model.json contiene la topología del modelo (también conocida como "arquitectura" o "gráfico": una descripción de las capas y cómo están conectadas) y un manifiesto de los archivos de peso.

Requisitos

El procedimiento de conversión requiere un entorno Python; es posible que desee mantener uno aislado usando pipenv o virtualenv . Para instalar el convertidor, use pip install tensorflowjs .

Importar un modelo de Keras a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo Keras existente al formato de capas TF.js y luego cárguelo en TensorFlow.js.

Paso 1. Convierta un modelo Keras existente al formato de capas TF.js

Los modelos de Keras generalmente se guardan a través model.save(filepath) , que produce un solo archivo HDF5 (.h5) que contiene tanto la topología del modelo como los pesos. Para convertir un archivo de este tipo al formato de capas TF.js, ejecute el siguiente comando, donde path/to/my_model.h5 es el archivo Keras .h5 de origen y path/to/tfjs_target_dir es el directorio de salida de destino para los archivos TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: use la API de Python para exportar directamente al formato de capas TF.js

Si tiene un modelo de Keras en Python, puede exportarlo directamente al formato de capas de TensorFlow.js de la siguiente manera:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Paso 2: Cargue el modelo en TensorFlow.js

Utilice un servidor web para servir los archivos de modelo convertidos que generó en el paso 1. Tenga en cuenta que es posible que deba configurar su servidor para permitir el uso compartido de recursos de origen cruzado (CORS) , a fin de poder obtener los archivos en JavaScript.

Luego carga el modelo en TensorFlow.js proporcionando la URL del archivo model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Ahora el modelo está listo para la inferencia, la evaluación o el reentrenamiento. Por ejemplo, el modelo cargado se puede usar inmediatamente para hacer una predicción:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Muchos de los ejemplos de TensorFlow.js adoptan este enfoque, utilizando modelos previamente entrenados que se han convertido y alojado en Google Cloud Storage.

Tenga en cuenta que se refiere a todo el modelo con el nombre de archivo model.json . loadModel(...) obtiene model.json y luego realiza solicitudes HTTP(S) adicionales para obtener los archivos de peso fragmentados a los que se hace referencia en el manifiesto de peso de model.json . Este enfoque permite que el navegador almacene en caché todos estos archivos (y quizás servidores de almacenamiento en caché adicionales en Internet), porque el model.json y los fragmentos de peso son cada uno más pequeños que el límite de tamaño de archivo de caché típico. Por lo tanto, es probable que un modelo se cargue más rápido en ocasiones posteriores.

Funciones admitidas

TensorFlow.js Layers actualmente solo admite modelos de Keras que usan construcciones estándar de Keras. Los modelos que utilizan operaciones o capas no admitidas, por ejemplo, capas personalizadas, capas de Lambda, pérdidas personalizadas o métricas personalizadas, no se pueden importar automáticamente porque dependen del código de Python que no se puede traducir de manera confiable a JavaScript.