Importando um modelo Keras para TensorFlow.js

Os modelos Keras (normalmente criados por meio da API Python) podem ser salvos em um dos vários formatos . O formato "modelo inteiro" pode ser convertido para o formato TensorFlow.js Layers, que pode ser carregado diretamente no TensorFlow.js para inferência ou treinamento adicional.

O formato de camadas TensorFlow.js de destino é um diretório que contém um arquivo model.json e um conjunto de arquivos de peso fragmentados em formato binário. O arquivo model.json contém a topologia do modelo (também conhecida como "arquitetura" ou "gráfico": uma descrição das camadas e como elas estão conectadas) e um manifesto dos arquivos de peso.

Requisitos

O procedimento de conversão requer um ambiente Python; você pode querer manter um isolado usando pipenv ou virtualenv . Para instalar o conversor, use pip install tensorflowjs .

A importação de um modelo Keras para o TensorFlow.js é um processo de duas etapas. Primeiro, converta um modelo Keras existente para o formato TF.js Layers e, em seguida, carregue-o no TensorFlow.js.

Etapa 1. Converter um modelo Keras existente para o formato TF.js Layers

Os modelos Keras geralmente são salvos via model.save(filepath) , que produz um único arquivo HDF5 (.h5) contendo a topologia do modelo e os pesos. Para converter esse arquivo para o formato TF.js Layers, execute o seguinte comando, onde path/to/my_model.h5 é o arquivo Keras .h5 de origem e path/to/tfjs_target_dir é o diretório de saída de destino para os arquivos TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: use a API Python para exportar diretamente para o formato TF.js Layers

Se você tiver um modelo Keras em Python, poderá exportá-lo diretamente para o formato TensorFlow.js Layers da seguinte maneira:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Etapa 2: carregar o modelo no TensorFlow.js

Use um servidor web para servir os arquivos de modelo convertidos gerados na Etapa 1. Observe que pode ser necessário configurar seu servidor para permitir o Cross-Origin Resource Sharing (CORS) , para permitir a busca dos arquivos em JavaScript.

Em seguida, carregue o modelo no TensorFlow.js fornecendo o URL para o arquivo model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Agora o modelo está pronto para inferência, avaliação ou retreinamento. Por exemplo, o modelo carregado pode ser usado imediatamente para fazer uma previsão:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Muitos dos exemplos do TensorFlow.js adotam essa abordagem, usando modelos pré-treinados que foram convertidos e hospedados no Google Cloud Storage.

Observe que você se refere ao modelo inteiro usando o nome de arquivo model.json . loadModel(...) busca model.json e, em seguida, faz solicitações HTTP(S) adicionais para obter os arquivos de peso fragmentados referenciados no manifesto de peso model.json . Essa abordagem permite que todos esses arquivos sejam armazenados em cache pelo navegador (e talvez por servidores de cache adicionais na Internet), porque o model.json e os fragmentos de peso são menores que o limite típico de tamanho do arquivo de cache. Portanto, é provável que um modelo carregue mais rapidamente em ocasiões subsequentes.

Recursos suportados

Atualmente, as camadas TensorFlow.js oferecem suporte apenas a modelos Keras que usam construções Keras padrão. Modelos que usam operações ou camadas não suportadas (por exemplo, camadas personalizadas, camadas Lambda, perdas personalizadas ou métricas personalizadas) não podem ser importados automaticamente porque dependem de código Python que não pode ser traduzido de forma confiável em JavaScript.