Importe um modelo do TensorFlow para o TensorFlow.js

Os modelos baseados em GraphDef do TensorFlow (normalmente criados por meio da API Python) podem ser salvos em um dos seguintes formatos:

  1. TensorFlow SavedModel
  2. Modelo Congelado
  3. Módulo Tensorflow Hub

Todos os formatos acima podem ser convertidos pelo conversor TensorFlow.js em um formato que pode ser carregado diretamente no TensorFlow.js para inferência.

(Observação: o TensorFlow suspendeu o uso do formato de pacote de sessão. Migre seus modelos para o formato SavedModel.)

Requisitos

O procedimento de conversão requer um ambiente Python; você pode querer manter um isolado usando pipenv ou virtualenv .

Para instalar o conversor, execute o seguinte comando:

 pip install tensorflowjs

A importação de um modelo do TensorFlow para o TensorFlow.js é um processo de duas etapas. Primeiro, converta um modelo existente no formato da Web do TensorFlow.js e carregue-o no TensorFlow.js.

Etapa 1. Converter um modelo existente do TensorFlow para o formato da Web do TensorFlow.js

Execute o script do conversor fornecido pelo pacote pip:

Exemplo de SavedModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Exemplo de modelo congelado:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Exemplo de módulo Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumentos Posicionais Descrição
input_path Caminho completo do diretório de modelo salvo, diretório de pacote de sessão, arquivo de modelo congelado ou identificador ou caminho do módulo TensorFlow Hub.
output_path Caminho para todos os artefatos de saída.
Opções Descrição
--input_format O formato do modelo de entrada. Use tf_saved_model para SavedModel, tf_frozen_model para modelo congelado, tf_session_bundle para pacote de sessão, tf_hub para módulo TensorFlow Hub e keras para Keras HDF5.
--output_node_names Os nomes dos nós de saída, separados por vírgulas.
--saved_model_tags Aplicável apenas à conversão de SavedModel. Tags do MetaGraphDef a carregar, separados por vírgula. Padrões para serve .
--signature_name Aplicável apenas à conversão do módulo TensorFlow Hub, assinatura para carregar. Padrões para default . Consulte https://www.tensorflow.org/hub/common_signatures/

Use o seguinte comando para obter uma mensagem de ajuda detalhada:

tensorflowjs_converter --help

Arquivos gerados pelo conversor

O script de conversão acima produz dois tipos de arquivos:

  • model.json : o gráfico de fluxo de dados e o manifesto de peso
  • group1-shard\*of\* : Uma coleção de arquivos binários de peso

Por exemplo, aqui está o resultado da conversão do MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Etapa 2: Carregando e executando no navegador

  1. Instale o pacote tfjs-converter npm:

yarn add @tensorflow/tfjs ou npm install @tensorflow/tfjs

  1. Crie uma instância da classe FrozenModel e execute a inferência.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Confira a demonstração do MobileNet .

A API loadGraphModel aceita um parâmetro LoadOptions adicional, que pode ser usado para enviar credenciais ou cabeçalhos personalizados junto com a solicitação. Para obter detalhes, consulte a documentação loadGraphModel() .

Operações suportadas

Atualmente, o TensorFlow.js oferece suporte a um conjunto limitado de operações do TensorFlow. Se seu modelo usar uma operação sem suporte, o script tensorflowjs_converter falhará e imprimirá uma lista das operações sem suporte em seu modelo. Registre um problema para cada operação para nos informar para quais operações você precisa de suporte.

Carregando apenas os pesos

Se preferir carregar apenas os pesos, você pode usar o seguinte trecho de código:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...