Integra il classificatore del linguaggio naturale BERT

L'API BertNLClassifier della Task Library è molto simile a NLClassifier che classifica il testo di input in diverse categorie, tranne per il fatto che questa API è appositamente studiata per i modelli correlati a Bert che richiedono tokenizzazioni Wordpiece e Sentencepiece al di fuori del modello TFLite.

Caratteristiche principali dell'API BertNLClassifier

  • Prende una singola stringa come input, esegue la classificazione con la stringa e genera output coppie come risultati della classificazione.

  • Esegue tokenizzazioni di Wordpiece o Sentencepiece fuori dal grafico sul testo di input.

Modelli BertNLClassifier supportati

I seguenti modelli sono compatibili con l'API BertNLClassifier .

Esegui l'inferenza in Java

Passaggio 1: importa la dipendenza Gradle e altre impostazioni

Copia il file del modello .tflite nella directory asset del modulo Android in cui verrà eseguito il modello. Specifica che il file non deve essere compresso e aggiungi la libreria TensorFlow Lite al file build.gradle del modulo:

android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency (NNAPI is included)
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.4.4'
}

Passaggio 2: esegui l'inferenza utilizzando l'API

// Initialization
BertNLClassifierOptions options =
    BertNLClassifierOptions.builder()
        .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
        .build();
BertNLClassifier classifier =
    BertNLClassifier.createFromFileAndOptions(context, modelFile, options);

// Run inference
List<Category> results = classifier.classify(input);

Vedi il codice sorgente per maggiori dettagli.

Esegui l'inferenza in Swift

Passaggio 1: importa CocoaPods

Aggiungi il pod TensorFlowLiteTaskText in Podfile

target 'MySwiftAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskText', '~> 0.4.4'
end

Passaggio 2: esegui l'inferenza utilizzando l'API

// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
      modelPath: bertModelPath)

// Run inference
let categories = bertNLClassifier.classify(text: input)

Vedi il codice sorgente per maggiori dettagli.

Esegui l'inferenza in C++

// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromOptions(options).value();

// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);

Vedi il codice sorgente per maggiori dettagli.

Esegui l'inferenza in Python

Passaggio 1: installa il pacchetto pip

pip install tflite-support

Passaggio 2: utilizzo del modello

# Imports
from tflite_support.task import text

# Initialization
classifier = text.BertNLClassifier.create_from_file(model_path)

# Run inference
text_classification_result = classifier.classify(text)

Vedere il codice sorgente per ulteriori opzioni per configurare BertNLClassifier .

Risultati di esempio

Ecco un esempio dei risultati della classificazione delle recensioni di film utilizzando il modello MobileBert di Model Maker.

Input: "è un viaggio affascinante e spesso toccante"

Produzione:

category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'

Prova il semplice strumento demo CLI per BertNLClassifier con il tuo modello e i dati di test.

Requisiti di compatibilità del modello

L'API BetNLClassifier prevede un modello TFLite con metadati del modello TFLite obbligatori.

I metadati devono soddisfare i seguenti requisiti:

  • input_process_units per Tokenizer Wordpiece/Sentencepiece

  • 3 tensori di input con nomi "ids", "mask" e "segment_ids" per l'output del tokenizzatore

  • 1 tensore di output di tipo float32, con un file di etichetta allegato facoltativamente. Se è allegato un file di etichette, il file dovrebbe essere un file di testo semplice con un'etichetta per riga e il numero di etichette dovrebbe corrispondere al numero di categorie come output del modello.