Integrate BERT natural language classifier

The Task Library BertNLClassifier API is very similar to the NLClassifier API, which classifies input text into different categories. The difference is that this API is specially tailored for BERT-related models that require Wordpiece or Sentencepiece tokenization outside the TFLite model.

Key features of the BertNLClassifier API

  • Takes a single string as input, performs classification with the string and outputs <Label, Score> pairs as classification results.

  • Performs out-of-graph Wordpiece or Sentencepiece tokenization on input text.
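To make "out-of-graph" tokenization concrete, here is a minimal Java sketch of greedy longest-match-first WordPiece splitting. This is the general technique behind a Wordpiece tokenizer, but it is not the Task Library's implementation. The class name WordpieceSketch and the tiny vocabulary are hypothetical. Pre-splitting on whitespace only is also a simplification: real BERT tokenizers also separate punctuation and normalize text. Sentencepiece uses a different algorithm and is not covered here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Set;

/** Hypothetical, simplified WordPiece tokenizer (greedy longest-match-first). */
final class WordpieceSketch {
    private final Set<String> vocab;
    private final int maxCharsPerWord;

    WordpieceSketch(Set<String> vocab, int maxCharsPerWord) {
        this.vocab = vocab;
        this.maxCharsPerWord = maxCharsPerWord;
    }

    /** Lowercases, splits on whitespace only, then splits each word into vocabulary pieces. */
    List<String> tokenize(String text) {
        List<String> output = new ArrayList<>();
        for (String word : text.toLowerCase(Locale.ROOT).trim().split("\\s+")) {
            if (word.isEmpty()) continue; // blank input yields no tokens
            output.addAll(tokenizeWord(word));
        }
        return output;
    }

    private List<String> tokenizeWord(String word) {
        if (word.length() > maxCharsPerWord) return List.of("[UNK]");
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            int end = word.length();
            String match = null;
            // Try the longest remaining substring first and shrink it until it is in the vocab.
            while (start < end) {
                String candidate = word.substring(start, end);
                if (start > 0) candidate = "##" + candidate; // "##" marks a continuation piece
                if (vocab.contains(candidate)) {
                    match = candidate;
                    break;
                }
                end--;
            }
            // If no piece matches at this position, the whole word becomes the unknown token.
            if (match == null) return List.of("[UNK]");
            pieces.add(match);
            start = end;
        }
        return pieces;
    }

    public static void main(String[] args) {
        WordpieceSketch tokenizer =
            new WordpieceSketch(Set.of("a", "charm", "##ing", "journey"), 100);
        System.out.println(tokenizer.tokenize("a charming journey")); // [a, charm, ##ing, journey]
    }
}
```

With this vocabulary, "charming" has no exact entry. It is therefore split into "charm" and the continuation piece "##ing". A word whose characters cannot be covered by vocabulary pieces maps to "[UNK]".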

Supported BertNLClassifier models

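The following models are compatible with the BertNLClassifier API.

  • BERT models created by TensorFlow Lite Model Maker for text classification.

  • Custom models that meet the model compatibility requirements.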

Run inference in Java

Step 1: Import Gradle dependency and other settings

Copy the .tflite model file to the assets directory of the Android module where the model will be run. Specify that the file should not be compressed, and add the TensorFlow Lite library to the module’s build.gradle file:

android {
    // Other settings

    // Specify that the tflite file should not be compressed in the app APK
    aaptOptions {
        noCompress "tflite"
    }

}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.0.0-nightly'
}

Step 2: Run inference using the API

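import java.util.List;
import org.tensorflow.lite.support.label.Category;
import org.tensorflow.lite.task.text.nlclassifier.BertNLClassifier;

// Initialization: load the model from the module's assets (throws IOException on failure)
BertNLClassifier classifier = BertNLClassifier.createFromFile(context, modelFile);

// Run inference: returns one Category (label and score) per output class
List<Category> results = classifier.classify(input);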

See the source code for more details.

Run inference in C++

// Initialization
std::unique_ptr<BertNLClassifier> classifier = BertNLClassifier::CreateFromFile(model_path).value();

// Run inference
std::vector<core::Category> categories = classifier->Classify(kInput);

See the source code for more details.

Example results

Here is an example of the classification results of movie reviews using the MobileBert model from Model Maker.

Input: "it's a charming and often affecting journey"

Output:

category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'
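Each result is a label paired with a score, so picking the predicted class is an argmax over the list. The Java sketch below prints results in the same layout as the example output and selects the top label. LabelScore is a hypothetical stand-in for the Task Library's Category class, so the snippet runs without the library. With the real API you would read each Category's getLabel() and getScore() instead.

```java
import java.util.List;
import java.util.Locale;

/** Hypothetical helpers for presenting <Label, Score> classification results. */
final class ResultFormatter {
    /** Hypothetical stand-in for the Task Library Category (label + score). */
    static final class LabelScore {
        final String label;
        final float score;

        LabelScore(String label, float score) {
            this.label = label;
            this.score = score;
        }
    }

    /** Formats each result as "category[i]: 'label' : 'score'" with five decimals. */
    static String format(List<LabelScore> results) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < results.size(); i++) {
            LabelScore c = results.get(i);
            sb.append(String.format(Locale.ROOT, "category[%d]: '%s' : '%.5f'\n", i, c.label, c.score));
        }
        return sb.toString();
    }

    /** Returns the label with the highest score (argmax); the list must be non-empty. */
    static String topLabel(List<LabelScore> results) {
        if (results.isEmpty()) throw new IllegalArgumentException("no results");
        LabelScore best = results.get(0);
        for (LabelScore c : results) {
            if (c.score > best.score) best = c;
        }
        return best.label;
    }

    public static void main(String[] args) {
        List<LabelScore> results = List.of(
            new LabelScore("negative", 0.00006f),
            new LabelScore("positive", 0.99994f));
        System.out.print(format(results));
        System.out.println("top: " + topLabel(results));
    }
}
```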

Try out the simple CLI demo tool for BertNLClassifier with your own model and test data.

Model compatibility requirements

The BertNLClassifier API expects a TFLite model with mandatory TFLite Model Metadata.

The metadata should meet the following requirements:

  • input_process_units for Wordpiece/Sentencepiece Tokenizer

  • 3 input tensors with names "ids", "mask" and "segment_ids" for the output of the tokenizer

  • 1 output tensor of type float32, with an optionally attached label file. If a label file is attached, it should be a plain text file with one label per line, and the number of labels should match the number of categories that the model outputs.
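The three input tensors carry the tokenizer output. The Java sketch below shows one common way to lay out BERT-style inputs for a single sentence. It assumes the following conventions:

  • The token IDs are wrapped as [CLS] ... [SEP], then truncated and zero-padded to a fixed maxSeqLen.

  • mask is 1 for real tokens and 0 for padding.

  • segment_ids is all zeros.

The class name, the special-token IDs and the sample token IDs are hypothetical. In practice the IDs come from the model's vocabulary, and the sequence length is typically fixed by the model's input shape. The sketch also checks the label-file rule: one label per line, with as many labels as model outputs.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical helper: packs token IDs into "ids", "mask" and "segment_ids" arrays. */
final class BertInputSketch {
    final int[] ids;
    final int[] mask;
    final int[] segmentIds;

    private BertInputSketch(int[] ids, int[] mask, int[] segmentIds) {
        this.ids = ids;
        this.mask = mask;
        this.segmentIds = segmentIds;
    }

    /** Wraps tokens as [CLS] tokens [SEP], truncates to maxSeqLen and pads the rest with 0. */
    static BertInputSketch pack(List<Integer> tokenIds, int clsId, int sepId, int maxSeqLen) {
        if (maxSeqLen < 2) throw new IllegalArgumentException("maxSeqLen must fit [CLS] and [SEP]");
        int[] ids = new int[maxSeqLen];
        int[] mask = new int[maxSeqLen];
        int[] segmentIds = new int[maxSeqLen]; // a single sentence stays in segment 0
        int n = Math.min(tokenIds.size(), maxSeqLen - 2); // leave room for [CLS] and [SEP]
        int pos = 0;
        ids[pos++] = clsId;
        for (int i = 0; i < n; i++) ids[pos++] = tokenIds.get(i);
        ids[pos++] = sepId;
        Arrays.fill(mask, 0, pos, 1); // 1 for real tokens, 0 for padding
        return new BertInputSketch(ids, mask, segmentIds);
    }

    /** Splits a label file into one label per line and checks the count against the outputs. */
    static List<String> parseLabels(String labelFileContents, int numOutputs) {
        List<String> labels = labelFileContents.lines().collect(Collectors.toList());
        if (labels.size() != numOutputs) {
            throw new IllegalArgumentException(
                "label file has " + labels.size() + " labels but model outputs " + numOutputs);
        }
        return labels;
    }

    public static void main(String[] args) {
        // Hypothetical token IDs and special-token IDs, for illustration only.
        BertInputSketch packed = pack(List.of(2009, 1005, 1055), 101, 102, 8);
        System.out.println(Arrays.toString(packed.ids));        // [101, 2009, 1005, 1055, 102, 0, 0, 0]
        System.out.println(Arrays.toString(packed.mask));       // [1, 1, 1, 1, 1, 0, 0, 0]
        System.out.println(Arrays.toString(packed.segmentIds)); // [0, 0, 0, 0, 0, 0, 0, 0]
        System.out.println(parseLabels("negative\npositive\n", 2)); // [negative, positive]
    }
}
```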