Xây dựng API tác vụ của riêng bạn

Thư viện tác vụ TensorFlow Lite cung cấp các API gốc/Android/iOS dựng sẵn trên cùng cơ sở hạ tầng trừu tượng hóa TensorFlow. Bạn có thể mở rộng cơ sở hạ tầng API tác vụ để xây dựng các API tùy chỉnh nếu mô hình của bạn không được thư viện Tác vụ hiện có hỗ trợ.

Tổng quan

Cơ sở hạ tầng API tác vụ có cấu trúc hai lớp: lớp C++ dưới cùng đóng gói thời gian chạy TFLite gốc và lớp Java/ObjC trên cùng giao tiếp với lớp C++ thông qua JNI hoặc trình bao bọc gốc.

Việc triển khai tất cả logic TensorFlow chỉ trong C++ giúp giảm thiểu chi phí, tối đa hóa hiệu suất suy luận và đơn giản hóa quy trình làm việc tổng thể trên các nền tảng.

Để tạo một lớp Nhiệm vụ, hãy mở rộng BaseTaskApi để cung cấp logic chuyển đổi giữa giao diện mô hình TFLite và giao diện API nhiệm vụ, sau đó sử dụng các tiện ích Java/ObjC để tạo các API tương ứng. Với tất cả thông tin chi tiết về TensorFlow bị ẩn, bạn có thể triển khai mô hình TFLite trong ứng dụng của mình mà không cần bất kỳ kiến ​​thức nào về máy học.

TensorFlow Lite cung cấp một số API dựng sẵn cho hầu hết các tác vụ Vision và NLP phổ biến. Bạn có thể xây dựng API của riêng mình cho các tác vụ khác bằng cơ sở hạ tầng API tác vụ.

prebuild_task_apis
Hình 1. API tác vụ dựng sẵn

Xây dựng API của riêng bạn với cơ sở hạ tầng API tác vụ

API C++

Tất cả các chi tiết TFLite được triển khai trong API gốc. Tạo một đối tượng API bằng cách sử dụng một trong các hàm xuất xưởng và nhận kết quả mô hình bằng cách gọi các hàm được xác định trong giao diện.

Sử dụng mẫu

Đây là ví dụ sử dụng C++ BertQuestionAnswerer cho MobileBert .

  char kBertModelPath[] = "path/to/model.tflite";
  // Create the API from a model file
  std::unique_ptr<BertQuestionAnswerer> question_answerer =
      BertQuestionAnswerer::CreateFromFile(kBertModelPath);

  char kContext[] = ...; // context of a question to be answered
  char kQuestion[] = ...; // question to be answered
  // ask a question
  std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
  // answers[0].text is the best answer

Xây dựng API

bản địa_task_api
Hình 2. API tác vụ gốc

Để xây dựng một đối tượng API, bạn phải cung cấp thông tin sau bằng cách mở rộng BaseTaskApi

  • Xác định I/O API - API của bạn sẽ hiển thị đầu vào/đầu ra tương tự trên các nền tảng khác nhau. ví dụ: BertQuestionAnswerer lấy hai chuỗi (std::string& context, std::string& question) làm đầu vào và xuất ra một vectơ câu trả lời và xác suất có thể có dưới dạng std::vector<QaAnswer> . Điều này được thực hiện bằng cách chỉ định các loại tương ứng trong tham số mẫu của BaseTaskApi . Với các tham số mẫu được chỉ định, hàm BaseTaskApi::Infer sẽ có loại đầu vào/đầu ra chính xác. Hàm này có thể được các ứng dụng khách API gọi trực tiếp, nhưng cách tốt nhất là bọc nó bên trong một hàm dành riêng cho mô hình, trong trường hợp này là BertQuestionAnswerer::Answer .

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Model specific function delegating calls to BaseTaskApi::Infer
      std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
        return Infer(context, question).value();
      }
    }
    
  • Cung cấp logic chuyển đổi giữa API I/O và tensor đầu vào/đầu ra của mô hình - Với các loại đầu vào và đầu ra được chỉ định, các lớp con cũng cần triển khai các hàm được gõ BaseTaskApi::PreprocessBaseTaskApi::Postprocess . Hai hàm này cung cấp đầu vàođầu ra từ TFLite FlatBuffer . Lớp con chịu trách nhiệm gán các giá trị từ API I/O cho các tensor I/O. Xem ví dụ triển khai đầy đủ trong BertQuestionAnswerer .

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Convert API input into tensors
      absl::Status BertQuestionAnswerer::Preprocess(
        const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
        const std::string& context, const std::string& query // InputType of the API
      ) {
        // Perform tokenization on input strings
        ...
        // Populate IDs, Masks and SegmentIDs to corresponding input tensors
        PopulateTensor(input_ids, input_tensors[0]);
        PopulateTensor(input_mask, input_tensors[1]);
        PopulateTensor(segment_ids, input_tensors[2]);
        return absl::OkStatus();
      }
    
      // Convert output tensors into API output
      StatusOr<std::vector<QaAnswer>> // OutputType
      BertQuestionAnswerer::Postprocess(
        const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model
      ) {
        // Get start/end logits of prediction result from output tensors
        std::vector<float> end_logits;
        std::vector<float> start_logits;
        // output_tensors[0]: end_logits FLOAT[1, 384]
        PopulateVector(output_tensors[0], &end_logits);
        // output_tensors[1]: start_logits FLOAT[1, 384]
        PopulateVector(output_tensors[1], &start_logits);
        ...
        std::vector<QaAnswer::Pos> orig_results;
        // Look up the indices from vocabulary file and build results
        ...
        return orig_results;
      }
    }
    
  • Tạo các hàm xuất xưởng của API - Cần có tệp mô hình và OpResolver để khởi tạo tflite::Interpreter . TaskAPIFactory cung cấp các chức năng tiện ích để tạo phiên bản BaseTaskApi.

    Bạn cũng phải cung cấp mọi tệp được liên kết với mô hình. ví dụ: BertQuestionAnswerer cũng có thể có một tệp bổ sung cho từ vựng của mã thông báo của nó.

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Factory function to create the API instance
      StatusOr<std::unique_ptr<QuestionAnswerer>>
      BertQuestionAnswerer::CreateBertQuestionAnswerer(
          const std::string& path_to_model, // model to passed to TaskApiFactory
          const std::string& path_to_vocab  // additional model specific files
      ) {
        // Creates an API object by calling one of the utils from TaskAPIFactory
        std::unique_ptr<BertQuestionAnswerer> api_to_init;
        ASSIGN_OR_RETURN(
            api_to_init,
            core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
                path_to_model,
                absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
                kNumLiteThreads));
    
        // Perform additional model specific initializations
        // In this case building a vocabulary vector from the vocab file.
        api_to_init->InitializeVocab(path_to_vocab);
        return api_to_init;
      }
    }
    

API Android

Tạo API Android bằng cách xác định giao diện Java/Kotlin và ủy quyền logic cho lớp C++ thông qua JNI. API Android yêu cầu API gốc phải được xây dựng trước.

Sử dụng mẫu

Đây là một ví dụ sử dụng Java BertQuestionAnswerer cho MobileBert .

  String BERT_MODEL_FILE = "path/to/model.tflite";
  String VOCAB_FILE = "path/to/vocab.txt";
  // Create the API from a model file and vocabulary file
    BertQuestionAnswerer bertQuestionAnswerer =
        BertQuestionAnswerer.createBertQuestionAnswerer(
            ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

  String CONTEXT = ...; // context of a question to be answered
  String QUESTION = ...; // question to be answered
  // ask a question
  List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
  // answers.get(0).text is the best answer

Xây dựng API

android_task_api
Hình 3. API tác vụ Android

Tương tự như API gốc, để xây dựng một đối tượng API, máy khách cần cung cấp thông tin sau bằng cách mở rộng BaseTaskApi , cung cấp khả năng xử lý JNI cho tất cả các API tác vụ Java.

  • Xác định API I/O - Điều này thường phản ánh các giao diện gốc. ví dụ: BertQuestionAnswerer lấy (String context, String question) làm đầu vào và đầu ra List<QaAnswer> . Việc triển khai gọi một hàm gốc riêng có chữ ký tương tự, ngoại trừ nó có một tham số bổ sung long nativeHandle , là con trỏ được trả về từ C++.

    class BertQuestionAnswerer extends BaseTaskApi {
      public List<QaAnswer> answer(String context, String question) {
        return answerNative(getNativeHandle(), context, question);
      }
    
      private static native List<QaAnswer> answerNative(
                                            long nativeHandle, // C++ pointer
                                            String context, String question // API I/O
                                           );
    
    }
    
  • Tạo các hàm xuất xưởng của API - Điều này cũng phản ánh các hàm gốc của nhà máy, ngoại trừ các hàm xuất xưởng của Android cũng cần lấy Context để truy cập tệp. Việc triển khai gọi một trong các tiện ích trong TaskJniUtils để xây dựng đối tượng API C++ tương ứng và chuyển con trỏ của nó tới hàm tạo BaseTaskApi .

      class BertQuestionAnswerer extends BaseTaskApi {
        private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
                                                  "bert_question_answerer_jni";
    
        // Extending super constructor by providing the
        // native handle(pointer of corresponding C++ API object)
        private BertQuestionAnswerer(long nativeHandle) {
          super(nativeHandle);
        }
    
        public static BertQuestionAnswerer createBertQuestionAnswerer(
                                            Context context, // Accessing Android files
                                            String pathToModel, String pathToVocab) {
          return new BertQuestionAnswerer(
              // The util first try loads the JNI module with name
              // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
              // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
              // is called with the buffer for a C++ API object pointer
              TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
                  context,
                  BertQuestionAnswerer::initJniWithBertByteBuffers,
                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
                  pathToModel,
                  pathToVocab));
        }
    
        // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
        // returns C++ API object pointer casted to long
        private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
    
      }
    
  • Triển khai mô-đun JNI cho các hàm gốc - Tất cả các phương thức gốc Java đều được triển khai bằng cách gọi hàm gốc tương ứng từ mô-đun JNI. Các hàm xuất xưởng sẽ tạo một đối tượng API gốc và trả về con trỏ của nó dưới dạng kiểu dài cho Java. Trong các lệnh gọi tới API Java sau này, con trỏ kiểu dài được chuyển trở lại JNI và chuyển trở lại đối tượng API gốc. Sau đó, các kết quả API gốc được chuyển đổi trở lại kết quả Java.

    Ví dụ: đây là cách triển khai bert_question_answerer_jni .

      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
      extern "C" JNIEXPORT jlong JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
          JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
        // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
        absl::string_view model =
            GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
    
        // Creates the native API object
        absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
            BertQuestionAnswerer::CreateFromBuffer(
                model.data(), model.size());
        if (status.ok()) {
          // converts the object pointer to jlong and return to Java.
          return reinterpret_cast<jlong>(status->release());
        } else {
          return kInvalidPointer;
        }
      }
    
      // Implements BertQuestionAnswerer::answerNative
      extern "C" JNIEXPORT jobject JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
      JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
      // Convert long to native API object pointer
      QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
    
      // Calls the native API
      std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
                                             JStringToString(env, question));
    
      // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>)
      jclass qa_answer_class =
        env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
      jmethodID qa_answer_ctor =
        env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
      return ConvertVectorToArrayList<QaAnswer>(
        env, results,
        [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
          jstring text = env->NewStringUTF(ans.text.data());
          jobject qa_answer =
              env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                             ans.pos.end, ans.pos.logit);
          env->DeleteLocalRef(text);
          return qa_answer;
        });
      }
    
      // Implements BaseTaskApi::deinitJni by delete the native object
      extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
          JNIEnv* env, jobject thiz, jlong native_handle) {
        delete reinterpret_cast<QuestionAnswerer*>(native_handle);
      }
    

API iOS

Tạo API iOS bằng cách gói đối tượng API gốc vào đối tượng API ObjC. Đối tượng API đã tạo có thể được sử dụng trong ObjC hoặc Swift. API iOS yêu cầu API gốc phải được xây dựng trước.

Sử dụng mẫu

Dưới đây là ví dụ sử dụng ObjC TFLBertQuestionAnswerer cho MobileBert trong Swift.

  static let mobileBertModelPath = "path/to/model.tflite";
  // Create the API from a model file and vocabulary file
  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
      modelPath: mobileBertModelPath)

  static let context = ...; // context of a question to be answered
  static let question = ...; // question to be answered
  // ask a question
  let answers = mobileBertAnswerer.answer(
      context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
  // answers.[0].text is the best answer

Xây dựng API

ios_task_api
Hình 4. API tác vụ iOS

API iOS là một trình bao bọc ObjC đơn giản bên trên API gốc. Xây dựng API bằng cách thực hiện theo các bước bên dưới:

  • Xác định trình bao bọc ObjC - Xác định lớp ObjC và ủy quyền triển khai cho đối tượng API gốc tương ứng. Lưu ý rằng các phần phụ thuộc gốc chỉ có thể xuất hiện trong tệp .mm do Swift không thể tương tác với C++.

    • tập tin .h
      @interface TFLBertQuestionAnswerer : NSObject
    
      // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                                    vocabPath:(NSString*)vocabPath
          NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));
    
      // Delegate calls to the native BertQuestionAnswerer::Answer
      - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                         question:(NSString*)question
          NS_SWIFT_NAME(answer(context:question:));
    }
    
    • tập tin .mm
      using BertQuestionAnswererCPP = ::tflite::task::text::BertQuestionAnswerer;
    
      @implementation TFLBertQuestionAnswerer {
        // define an iVar for the native API object
        std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
      }
    
      // Initialize the native API object
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
                                              vocabPath:(NSString *)vocabPath {
        absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
            BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                                MakeString(vocabPath));
        _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
        return [[TFLBertQuestionAnswerer alloc]
            initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
      }
    
      // Calls the native API and converts C++ results into ObjC results
      - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
        std::vector<QaAnswerCPP> results =
          _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question));
        return [self arrayFromVector:results];
      }
    }