构建您自己的任务 API

TensorFlow Lite Task Library 在抽象 TensorFlow 的相同基础架构上提供了预构建的原生/Android/iOS API。如果现有的 Task 库不支持您的模型,您可以扩展 Task API 基础架构来构建自定义 API。

概述

Task API 基础架构有两层结构:底部的 C++ 层封装了原生的 TFLite 运行时,顶部的 Java/ObjC 层通过 JNI 或原生封装容器与 C++ 层通信。

只用 C++ 实现所有 TensorFlow 逻辑,能够最大限度地降低成本,最大限度地提高推断性能,并简化跨平台的整体工作流。

要创建 Task 类,请扩展 BaseTaskApi 来提供 TFLite 模型接口和 Task API 接口之间的转换逻辑,然后使用 Java/ObjC 实用工具来创建相应的 API。由于隐藏了所有的 TensorFlow 细节,您可以在没有任何机器学习知识的情况下在应用中部署 TFLite 模型。

TensorFlow Lite 为最热门的视觉和 NLP 任务提供了一些预构建的 API。您可以使用 Task API 基础架构为其他任务构建自己的 API。

![prebuilt_task_apis](images/prebuilt_task_apis.svg)
图 1. 预构建的 Task API
使用 Task API 基础架构构建您自己的 API C++ API

所有 TFLite 详细信息都在原生 API 中实现。使用一个工厂函数创建 API 对象,并通过调用在接口中定义的函数获取模型结果。

示例用法

下面是一个将 C++ BertQuestionAnswerer 用于 MobileBert 的示例。

  char kBertModelPath[] = "path/to/model.tflite";
  // Create the API from a model file
  std::unique_ptr<BertQuestionAnswerer> question_answerer =
      BertQuestionAnswerer::CreateFromFile(kBertModelPath);

  char kContext[] = ...; // context of a question to be answered
  char kQuestion[] = ...; // question to be answered
  // ask a question
  std::vector<QaAnswer> answers = question_answerer.Answer(kContext, kQuestion);
  // answers[0].text is the best answer
构建 API
![native_task_api](images/native_task_api.svg)
图 2. 原生 Task API

要构建 API 对象,您必须通过扩展 BaseTaskApi 提供以下信息

  • 确定 API I/O - 您的 API 应在不同的平台之间公开相似的输入/输出。例如,BertQuestionAnswerer 将两个字符串 (std::string& context, std::string& question) 作为输入,并以 std::vector<QaAnswer> 形式输出可能的回答和概率的矢量。为此,您需要在 BaseTaskApi 的[模板参数](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/base_task_api.h?q="template <class OutputType, class... InputTypes>")中指定对应类型。指定模板参数后,[BaseTaskApi::Infer](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/core/base_task_api.h?q="Infer(InputTypes... args)") 函数将具有正确的输入/输出类型。此函数可以由 API 客户端直接调用,但是最好将其封装在模型特定的函数中,在本例中为 BertQuestionAnswerer::Answer

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Model specific function delegating calls to BaseTaskApi::Infer
      std::vector<QaAnswer> Answer(const std::string& context, const std::string& question) {
        return Infer(context, question).value();
      }
    }
    
  • 提供 API I/O 与模型输入/输出张量之间的转换逻辑 - 指定输入和输出类型后,子类还需要实现类型函数 BaseTaskApi::PreprocessBaseTaskApi::Postprocess。这两个函数从 TFLite FlatBuffer 提供输入输出。子类负责将值从 API I/O 分配给 I/O 张量。请参阅 BertQuestionAnswerer 中的完整实现示例。

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Convert API input into tensors
      absl::Status BertQuestionAnswerer::Preprocess(
        const std::vector<TfLiteTensor*>& input_tensors, // input tensors of the model
        const std::string& context, const std::string& query // InputType of the API
      ) {
        // Perform tokenization on input strings
        ...
        // Populate IDs, Masks and SegmentIDs to corresponding input tensors
        PopulateTensor(input_ids, input_tensors[0]);
        PopulateTensor(input_mask, input_tensors[1]);
        PopulateTensor(segment_ids, input_tensors[2]);
        return absl::OkStatus();
      }
    
      // Convert output tensors into API output
      StatusOr<std::vector<QaAnswer>> // OutputType
      BertQuestionAnswerer::Postprocess(
        const std::vector<const TfLiteTensor*>& output_tensors, // output tensors of the model
      ) {
        // Get start/end logits of prediction result from output tensors
        std::vector<float> end_logits;
        std::vector<float> start_logits;
        // output_tensors[0]: end_logits FLOAT[1, 384]
        PopulateVector(output_tensors[0], &end_logits);
        // output_tensors[1]: start_logits FLOAT[1, 384]
        PopulateVector(output_tensors[1], &start_logits);
        ...
        std::vector<QaAnswer::Pos> orig_results;
        // Look up the indices from vocabulary file and build results
        ...
        return orig_results;
      }
    }
    
  • 创建 API 的工厂函数 - 初始化 tflite::Interpreter 需要模型文件和 OpResolverTaskAPIFactory 提供了创建 BaseTaskApi 实例的效用函数。

    注:默认情况下,TaskAPIFactory 提供了 BuiltInOpResolver。如果您的模型需要自定义运算或一部分内置内置运算,您可以通过创建一个 MutableOpResolver 来注册它们。

    您还必须提供与模型关联的任何文件。例如,BertQuestionAnswerer 也可以为其分词器的词汇额外使用一个文件。

    class BertQuestionAnswerer : public BaseTaskApi<
                                  std::vector<QaAnswer>, // OutputType
                                  const std::string&, const std::string& // InputTypes
                                  > {
      // Factory function to create the API instance
      StatusOr<std::unique_ptr<QuestionAnswerer>>
      BertQuestionAnswerer::CreateBertQuestionAnswerer(
          const std::string& path_to_model, // model to passed to TaskApiFactory
          const std::string& path_to_vocab  // additional model specific files
      ) {
        // Creates an API object by calling one of the utils from TaskAPIFactory
        std::unique_ptr<BertQuestionAnswerer> api_to_init;
        ASSIGN_OR_RETURN(
            api_to_init,
            core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
                path_to_model,
                absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
                kNumLiteThreads));
    
        // Perform additional model specific initializations
        // In this case building a vocabulary vector from the vocab file.
        api_to_init->InitializeVocab(path_to_vocab);
        return api_to_init;
      }
    }
    
Android API

定义 Java/Kotlin 接口并通过 JNI 将逻辑委托给 C++ 层以创建 Android API。Android API 要求先构建原生 API。

示例用法

下面是一个将 Java BertQuestionAnswerer 用于 MobileBert 的示例。

  String BERT_MODEL_FILE = "path/to/model.tflite";
  String VOCAB_FILE = "path/to/vocab.txt";
  // Create the API from a model file and vocabulary file
    BertQuestionAnswerer bertQuestionAnswerer =
        BertQuestionAnswerer.createBertQuestionAnswerer(
            ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE, VOCAB_FILE);

  String CONTEXT = ...; // context of a question to be answered
  String QUESTION = ...; // question to be answered
  // ask a question
  List<QaAnswer> answers = bertQuestionAnswerer.answer(CONTEXT, QUESTION);
  // answers.get(0).text is the best answer
构建 API
![android_task_api](images/android_task_api.svg)
图 3. Android Task API

与原生 API 类似,要构建 API 对象,客户端也需要通过扩展 BaseTaskApi 提供以下信息,后者为所有 Java Task API 提供了 JNI 处理。

  • 确定 API I/O - 这通常可以镜像原生接口,例如 BertQuestionAnswerer(String context, String question) 作为输入,并输出 List<QaAnswer>。该实现调用了一个具有类似签名的私有原生函数,但该函数具有一个额外的参数 long nativeHandle,这是从 C++ 返回的指针。

    class BertQuestionAnswerer extends BaseTaskApi {
      public List<QaAnswer> answer(String context, String question) {
        return answerNative(getNativeHandle(), context, question);
      }
    
      private static native List<QaAnswer> answerNative(
                                            long nativeHandle, // C++ pointer
                                            String context, String question // API I/O
                                           );
    
    }
    
  • 创建 API 的工厂函数 - 这也会镜像原生工厂函数,但获取文件访问的 Context 也需要 Android 工厂函数。该实现调用了 TaskJniUtils 中的一个实用工具来构建对应的 C++ API 对象并将其指针传递给 BaseTaskApi 构造函数。

      class BertQuestionAnswerer extends BaseTaskApi {
        private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME =
                                                  "bert_question_answerer_jni";
    
        // Extending super constructor by providing the
        // native handle(pointer of corresponding C++ API object)
        private BertQuestionAnswerer(long nativeHandle) {
          super(nativeHandle);
        }
    
        public static BertQuestionAnswerer createBertQuestionAnswerer(
                                            Context context, // Accessing Android files
                                            String pathToModel, String pathToVocab) {
          return new BertQuestionAnswerer(
              // The util first try loads the JNI module with name
              // BERT_QUESTION_ANSWERER_NATIVE_LIBNAME, then opens two files,
              // converts them into ByteBuffer, finally ::initJniWithBertByteBuffers
              // is called with the buffer for a C++ API object pointer
              TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
                  context,
                  BertQuestionAnswerer::initJniWithBertByteBuffers,
                  BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
                  pathToModel,
                  pathToVocab));
        }
    
        // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
        // returns C++ API object pointer casted to long
        private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
    
      }
    
  • 为原生函数实现 JNI 模块 - 所有 Java 原生方法均通过从 JNI 模块调用对应的原生函数来实现。工厂函数将创建一个原生 API 对象并将其指针作为长类型返回给 Java。在对 Java API 的后续调用中,长类型指针将被传回 JNI 并转换回原生 API 对象。随后,原生 API 结果会被转换回 Java 结果。

    例如,以下是 bert_question_answerer_jni 的实现方式。

      // Implements BertQuestionAnswerer::initJniWithBertByteBuffers
      extern "C" JNIEXPORT jlong JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
          JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
        // Convert Java ByteBuffer object into a buffer that can be read by native factory functions
        absl::string_view model =
            GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
    
        // Creates the native API object
        absl::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
            BertQuestionAnswerer::CreateFromBuffer(
                model.data(), model.size());
        if (status.ok()) {
          // converts the object pointer to jlong and return to Java.
          return reinterpret_cast<jlong>(status->release());
        } else {
          return kInvalidPointer;
        }
      }
    
      // Implements BertQuestionAnswerer::answerNative
      extern "C" JNIEXPORT jobject JNICALL
      Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
      JNIEnv* env, jclass thiz, jlong native_handle, jstring context, jstring question) {
      // Convert long to native API object pointer
      QuestionAnswerer* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
    
      // Calls the native API
      std::vector<QaAnswer> results = question_answerer->Answer(JStringToString(env, context),
                                             JStringToString(env, question));
    
      // Converts native result(std::vector<QaAnswer>) to Java result(List<QaAnswerer>)
      jclass qa_answer_class =
        env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
      jmethodID qa_answer_ctor =
        env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
      return ConvertVectorToArrayList<QaAnswer>(
        env, results,
        [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
          jstring text = env->NewStringUTF(ans.text.data());
          jobject qa_answer =
              env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
                             ans.pos.end, ans.pos.logit);
          env->DeleteLocalRef(text);
          return qa_answer;
        });
      }
    
      // Implements BaseTaskApi::deinitJni by delete the native object
      extern "C" JNIEXPORT void JNICALL Java_task_core_BaseTaskApi_deinitJni(
          JNIEnv* env, jobject thiz, jlong native_handle) {
        delete reinterpret_cast<QuestionAnswerer*>(native_handle);
      }
    
iOS API

通过将原生 API 对象封装到 ObjC API 对象中来创建 iOS API。创建的 API 对象可以用于 ObjC 或 Swift。iOS API 要求先构建原生 API。

Sample usage

下面是一个在 Swift 中将 ObjC TFLBertQuestionAnswerer 用于 MobileBert 的示例。

  static let mobileBertModelPath = "path/to/model.tflite";
  // Create the API from a model file and vocabulary file
  let mobileBertAnswerer = TFLBertQuestionAnswerer.mobilebertQuestionAnswerer(
      modelPath: mobileBertModelPath)

  static let context = ...; // context of a question to be answered
  static let question = ...; // question to be answered
  // ask a question
  let answers = mobileBertAnswerer.answer(
      context: TFLBertQuestionAnswererTest.context, question: TFLBertQuestionAnswererTest.question)
  // answers.[0].text is the best answer
构建 API
![ios_task_api](images/ios_task_api.svg)
图 4. iOS Task API

iOS API 是一个基于原生 API 的简单 ObjC 封装容器。按照以下步骤操作来构建 API:

  • 定义 ObjC 封装容器 - 定义一个 ObjC 类并将实现委托给对应的原生 API 对象。请注意,由于 Swift 无法与 C++ 互操作,原生依赖项仅可以出现在 .mm 文件中。

    • .h 文件
      @interface TFLBertQuestionAnswerer : NSObject
    
      // Delegate calls to the native BertQuestionAnswerer::CreateBertQuestionAnswerer
    
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString*)modelPath
                                                    vocabPath:(NSString*)vocabPath
          NS_SWIFT_NAME(mobilebertQuestionAnswerer(modelPath:vocabPath:));
    
      // Delegate calls to the native BertQuestionAnswerer::Answer
    
      - (NSArray<TFLQAAnswer*>*)answerWithContext:(NSString*)context
                                         question:(NSString*)question
          NS_SWIFT_NAME(answer(context:question:));
    }
    
    • .mm 文件
      using BertQuestionAnswererCPP = ::tflite::task::text::qa::BertQuestionAnswerer;
    
      @implementation TFLBertQuestionAnswerer {
        // define an iVar for the native API object
        std::unique_ptr<QuestionAnswererCPP> _bertQuestionAnswerwer;
      }
    
      // Initialize the native API object
    
      + (instancetype)mobilebertQuestionAnswererWithModelPath:(NSString *)modelPath
                                              vocabPath:(NSString *)vocabPath {
        absl::StatusOr<std::unique_ptr<QuestionAnswererCPP>> cQuestionAnswerer =
            BertQuestionAnswererCPP::CreateBertQuestionAnswerer(MakeString(modelPath),
                                                                MakeString(vocabPath));
        _GTMDevAssert(cQuestionAnswerer.ok(), @"Failed to create BertQuestionAnswerer");
        return [[TFLBertQuestionAnswerer alloc]
            initWithQuestionAnswerer:std::move(cQuestionAnswerer.value())];
      }
    
      // Calls the native API and converts C++ results into ObjC results
    
      - (NSArray<TFLQAAnswer *> *)answerWithContext:(NSString *)context question:(NSString *)question {
        std::vector<QaAnswerCPP> results =
          _bertQuestionAnswerwer->Answer(MakeString(context), MakeString(question));
        return [self arrayFromVector:results];
      }
    }