Importe um modelo do TensorFlow para o TensorFlow.js

Os modelos baseados em TensorFlow GraphDef (normalmente criados por meio da API Python) podem ser salvos em um dos seguintes formatos:

  1. Modelo salvo do TensorFlow
  2. Modelo Congelado
  3. Módulo Tensorflow Hub

Todos os formatos acima podem ser convertidos pelo conversor TensorFlow.js em um formato que pode ser carregado diretamente no TensorFlow.js para inferência.

(Observação: o TensorFlow descontinuou o formato do pacote de sessão. Migre seus modelos para o formato SavedModel.)

Requisitos

O procedimento de conversão requer um ambiente Python; você pode querer manter um isolado usando pipenv ou virtualenv .

Para instalar o conversor, execute o seguinte comando:

 pip install tensorflowjs

A importação de um modelo do TensorFlow para o TensorFlow.js é um processo de duas etapas. Primeiro, converta um modelo existente para o formato Web TensorFlow.js e, em seguida, carregue-o no TensorFlow.js.

Etapa 1. Converter um modelo TensorFlow existente para o formato Web TensorFlow.js

Execute o script do conversor fornecido pelo pacote pip:

Exemplo de SaveModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Exemplo de modelo congelado:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Exemplo de módulo Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumentos posicionais Descrição
input_path Caminho completo do diretório do modelo salvo, diretório do pacote de sessão, arquivo de modelo congelado ou identificador ou caminho do módulo TensorFlow Hub.
output_path Caminho para todos os artefatos de saída.
Opções Descrição
--input_format O formato do modelo de entrada. Use tf_saved_model para SavedModel, tf_frozen_model para modelo congelado, tf_session_bundle para pacote de sessão, tf_hub para módulo TensorFlow Hub e keras para Keras HDF5.
--output_node_names Os nomes dos nós de saída, separados por vírgulas.
--saved_model_tags Aplicável apenas à conversão SavedModel. Tags do MetaGraphDef a serem carregadas, em formato separado por vírgula. O padrão é serve .
--signature_name Aplicável apenas à conversão do módulo TensorFlow Hub, assinatura para carregamento. O padrão é default . Consulte https://www.tensorflow.org/hub/common_signatures/

Use o seguinte comando para obter uma mensagem de ajuda detalhada:

tensorflowjs_converter --help

Arquivos gerados pelo conversor

O script de conversão acima produz dois tipos de arquivos:

  • model.json : o gráfico de fluxo de dados e o manifesto de peso
  • group1-shard\*of\* : Uma coleção de arquivos de peso binário

Por exemplo, aqui está o resultado da conversão do MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Etapa 2: Carregar e executar no navegador

  1. Instale o pacote npm do tfjs-converter:

yarn add @tensorflow/tfjs ou npm install @tensorflow/tfjs

  1. Instancie a classe FrozenModel e execute a inferência.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Confira a demonstração do MobileNet .

A API loadGraphModel aceita um parâmetro LoadOptions adicional, que pode ser usado para enviar credenciais ou cabeçalhos personalizados junto com a solicitação. Para obter detalhes, consulte a documentação loadGraphModel() .

Operações suportadas

Atualmente, o TensorFlow.js oferece suporte a um conjunto limitado de operações do TensorFlow. Se o seu modelo usar uma operação não suportada, o script tensorflowjs_converter falhará e imprimirá uma lista das operações não suportadas no seu modelo. Registre um problema para cada operação para nos informar para quais operações você precisa de suporte.

Carregando apenas os pesos

Se preferir carregar apenas os pesos, você pode usar o seguinte trecho de código:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...