ML Topluluk Günü 9 Kasım! TensorFlow, JAX güncellemeler için bize katılın ve daha fazla bilgi edinin

TensorFlow GraphDef tabanlı Modelleri TensorFlow.js'ye İçe Aktarma

TensorFlow GraphDef tabanlı modeller (genellikle Python API aracılığıyla oluşturulur) aşağıdaki biçimlerden birinde kaydedilebilir:

  1. TensorFlow SavedModel
  2. Dondurulmuş Model
  3. Tensorflow Hub modülü

Yukarıdaki formatları tüm tarafından dönüştürülebilir TensorFlow.js dönüştürücü çıkarım için TensorFlow.js doğrudan yüklenebilir bir biçime.

(Not: TensorFlow, oturum paketi biçimini kullanımdan kaldırmıştır, lütfen modellerinizi SavedModel biçimine taşıyın.)

Gereksinimler

Dönüştürme prosedürü bir Python ortamı gerektirir; Kullandığınız bir izole birini tutmak isteyebilirsiniz pipenv veya VIRTUALENV . Dönüştürücüyü kurmak için aşağıdaki komutu çalıştırın:

 pip install tensorflowjs

Bir TensorFlow modelini TensorFlow.js'ye içe aktarmak iki adımlı bir işlemdir. Önce mevcut bir modeli TensorFlow.js web formatına dönüştürün ve ardından onu TensorFlow.js'ye yükleyin.

Adım 1. Mevcut bir TensorFlow modelini TensorFlow.js web formatına dönüştürün

pip paketi tarafından sağlanan dönüştürücü komut dosyasını çalıştırın:

Kullanım: SavedModel örneği:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Dondurulmuş model örneği:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Tensorflow Hub modülü örneği:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Konumsal Argümanlar Tanım
input_path Kaydedilen model dizininin tam yolu, oturum paketi dizini, donmuş model dosyası veya TensorFlow Hub modül tanıtıcısı veya yolu.
output_path Tüm çıktı yapılarının yolu.
Seçenekler Tanım
--input_format Girdi modeli formatı, SavedModel için tf_saved_model, frozen model için tf_frozen_model, oturum paketi için tf_session_bundle, TensorFlow Hub modülü için tf_hub ve Keras HDF5 için keras kullanın.
--output_node_names Çıkış düğümlerinin virgülle ayrılmış adları.
--saved_model_tags Yalnızca SavedModel dönüştürme için geçerlidir, Yüklenecek MetaGraphDef Etiketleri, virgülle ayrılmış biçimde. Varsayılan için serve .
--signature_name Yalnızca TensorFlow Hub modül dönüşümü için geçerlidir, imza yüklenir. Varsayılan için default . Bkz https://www.tensorflow.org/hub/common_signatures/

Ayrıntılı bir yardım mesajı almak için aşağıdaki komutu kullanın:

tensorflowjs_converter --help

Dönüştürücü tarafından oluşturulan dosyalar

Yukarıdaki dönüştürme komut dosyası iki tür dosya üretir:

  • model.json (veri akışı grafik ve ağırlık tezahür)
  • group1-shard\*of\* (ikili ağırlık dosyalarının toplanması)

Örneğin, MobileNet v2'nin dönüştürülmesinden elde edilen çıktı:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Adım 2: Tarayıcıda yükleme ve çalıştırma

  1. tfjs-converter npm paketini kurun

yarn add @tensorflow/tfjs veya npm install @tensorflow/tfjs

  1. Somutlaştırın FrozenModel sınıfını ve çalıştırmak çıkarım.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Bizim göz atın MobileNet demo .

loadGraphModel API ek bir kabul LoadOptions isteğiyle birlikte kimlik bilgilerini veya özel başlıklar göndermek için kullanılabilir parametre. Bakınız loadGraphModel () belgelerine daha fazla ayrıntı için.

Desteklenen işlemler

Şu anda TensorFlow.js, sınırlı sayıda TensorFlow işlemini desteklemektedir. Modeliniz desteklenmeyen op kullanıyorsa, tensorflowjs_converter komut başarısız ve Modelinizdeki desteklenmeyen ops bir listesini yazdırır. Bir dosya Lütfen sorunu bize desteği gerek hangi ops bildirmek için her op için.

Sadece ağırlıkların yüklenmesi

Yalnızca ağırlıkları yüklemeyi tercih ederseniz, aşağıdaki kod parçasını kullanabilirsiniz.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");