1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <jni.h>
17 
18 #include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
19 #include "tensorflow_lite_support/cc/utils/jni_utils.h"
20 
21 namespace {
22 
23 using ::tflite::support::utils::ConvertVectorToArrayList;
24 using ::tflite::support::utils::GetMappedFileBuffer;
25 using ::tflite::support::utils::JStringToString;
26 using ::tflite::task::text::qa::BertQuestionAnswerer;
27 using ::tflite::task::text::qa::QaAnswer;
28 using ::tflite::task::text::qa::QuestionAnswerer;
29 
30 constexpr int kInvalidPointer = 0;
31 
32 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(JNIEnv * env,jobject thiz,jlong native_handle)33 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
34     JNIEnv* env, jobject thiz, jlong native_handle) {
35   delete reinterpret_cast<QuestionAnswerer*>(native_handle);
36 }
37 
38 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers(JNIEnv * env,jclass thiz,jobjectArray model_buffers)39 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers(
40     JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
41   absl::string_view model_with_metadata =
42       GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
43 
44   tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
45       BertQuestionAnswerer::CreateFromBuffer(
46           model_with_metadata.data(), model_with_metadata.size());
47   if (status.ok()) {
48     return reinterpret_cast<jlong>(status->release());
49   } else {
50     return kInvalidPointer;
51   }
52 }
53 
54 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(JNIEnv * env,jclass thiz,jint fd)55 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
56     JNIEnv* env, jclass thiz, jint fd) {
57   tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
58       BertQuestionAnswerer::CreateFromFd(fd);
59   if (status.ok()) {
60     return reinterpret_cast<jlong>(status->release());
61   } else {
62     return kInvalidPointer;
63   }
64 }
65 
66 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(JNIEnv * env,jclass thiz,jobjectArray model_buffers)67 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
68     JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
69   absl::string_view model =
70       GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
71   absl::string_view vocab =
72       GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
73 
74   tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
75       BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
76           model.data(), model.size(), vocab.data(), vocab.size());
77   if (status.ok()) {
78     return reinterpret_cast<jlong>(status->release());
79   } else {
80     return kInvalidPointer;
81   }
82 }
83 
84 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(JNIEnv * env,jclass thiz,jobjectArray model_buffers)85 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(
86     JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
87   absl::string_view model =
88       GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
89   absl::string_view sp_model =
90       GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
91 
92   tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
93       BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
94           model.data(), model.size(), sp_model.data(), sp_model.size());
95   if (status.ok()) {
96     return reinterpret_cast<jlong>(status->release());
97   } else {
98     return kInvalidPointer;
99   }
100 }
101 
102 extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(JNIEnv * env,jclass thiz,jlong native_handle,jstring context,jstring question)103 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
104     JNIEnv* env, jclass thiz, jlong native_handle, jstring context,
105     jstring question) {
106   auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
107 
108   std::vector<QaAnswer> results = question_answerer->Answer(
109       JStringToString(env, context), JStringToString(env, question));
110   jclass qa_answer_class =
111       env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
112   jmethodID qa_answer_ctor =
113       env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
114 
115   return ConvertVectorToArrayList<QaAnswer>(
116       env, results,
117       [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
118         jstring text = env->NewStringUTF(ans.text.data());
119         jobject qa_answer =
120             env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
121                            ans.pos.end, ans.pos.logit);
122         env->DeleteLocalRef(text);
123         return qa_answer;
124       });
125 }
126 
127 }  // namespace
128