Google I/O est terminé ! Suivez les sessions TensorFlow Afficher les sessions

Intégrer le classificateur de langage naturel BERT

L'API de la bibliothèque de tâches BertNLClassifier est très similaire au NLClassifier qui classe le texte d'entrée dans différentes catégories, sauf que cette API est spécialement conçue pour les modèles liés à Bert qui nécessitent des tokenisations Wordpiece et Sentencepiece en dehors du modèle TFLite.

Fonctionnalités clés de l'API BertNLClassifier

  • Prend une seule chaîne en entrée, effectue la classification avec la chaîne et les sorties paires comme résultats de classification.

  • Effectue des tokenisations Wordpiece ou Sentencepiece hors graphique sur le texte d'entrée.

Modèles BertNLClassifier pris en charge

Les modèles suivants sont compatibles avec l'API BertNLClassifier .

Exécuter l'inférence en Java

Étape 1 : Importer la dépendance Gradle et d'autres paramètres

Copiez le fichier de modèle .tflite dans le répertoire assets du module Android où le modèle sera exécuté. Spécifiez que le fichier ne doit pas être compressé et ajoutez la bibliothèque TensorFlow Lite au fichier build.gradle du module :

android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency (NNAPI is included)
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.3.0'
}

Étape 2 : Exécuter l'inférence à l'aide de l'API

// Initialization
BertNLClassifierOptions options =
    BertNLClassifierOptions.builder()
        .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
        .build();
BertNLClassifier classifier =
    BertNLClassifier.createFromFileAndOptions(context, modelFile, options);

// Run inference
List<Category> results = classifier.classify(input);

Voir le code source pour plus de détails.

Exécuter l'inférence dans Swift

Étape 1 : Importer des CocoaPods

Ajouter le pod TensorFlowLiteTaskText dans Podfile

target 'MySwiftAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskText', '~> 0.2.0'
end

Étape 2 : Exécuter l'inférence à l'aide de l'API

// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
      modelPath: bertModelPath)

// Run inference
let categories = bertNLClassifier.classify(text: input)

Voir le code source pour plus de détails.

Exécuter l'inférence en C++

// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_file);
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromOptions(options).value();

// Run inference
std::vector<core::Category> categories = classifier->Classify(kInput);

Voir le code source pour plus de détails.

Exemples de résultats

Voici un exemple des résultats de classification des critiques de films utilisant le modèle MobileBert de Model Maker.

Entrée : "c'est un voyage charmant et souvent émouvant"

Production:

category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'

Essayez l' outil de démonstration CLI simple pour BertNLClassifier avec votre propre modèle et vos données de test.

Exigences de compatibilité des modèles

L'API BetNLClassifier attend un modèle TFLite avec des métadonnées de modèle TFLite obligatoires.

Les métadonnées doivent répondre aux exigences suivantes :

  • input_process_units pour Wordpiece/Sentencepiece Tokenizer

  • 3 tenseurs d'entrée avec les noms "ids", "mask" et "segment_ids" pour la sortie du tokenizer

  • 1 tenseur de sortie de type float32, avec éventuellement un fichier label attaché. Si un fichier d'étiquettes est joint, le fichier doit être un fichier texte brut avec une étiquette par ligne et le nombre d'étiquettes doit correspondre au nombre de catégories en sortie du modèle.