1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <jni.h>
17
18 #include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
19 #include "tensorflow_lite_support/cc/utils/jni_utils.h"
20
21 namespace {
22
23 using ::tflite::support::utils::ConvertVectorToArrayList;
24 using ::tflite::support::utils::GetMappedFileBuffer;
25 using ::tflite::support::utils::JStringToString;
26 using ::tflite::task::text::qa::BertQuestionAnswerer;
27 using ::tflite::task::text::qa::QaAnswer;
28 using ::tflite::task::text::qa::QuestionAnswerer;
29
// Handle value the init* entry points in this file return when the native
// QuestionAnswerer could not be created.
constexpr int kInvalidPointer{0};
31
32 extern "C" JNIEXPORT void JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(JNIEnv * env,jobject thiz,jlong native_handle)33 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
34 JNIEnv* env, jobject thiz, jlong native_handle) {
35 delete reinterpret_cast<QuestionAnswerer*>(native_handle);
36 }
37
38 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers(JNIEnv * env,jclass thiz,jobjectArray model_buffers)39 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers(
40 JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
41 absl::string_view model_with_metadata =
42 GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
43
44 tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
45 BertQuestionAnswerer::CreateFromBuffer(
46 model_with_metadata.data(), model_with_metadata.size());
47 if (status.ok()) {
48 return reinterpret_cast<jlong>(status->release());
49 } else {
50 return kInvalidPointer;
51 }
52 }
53
54 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(JNIEnv * env,jclass thiz,jint fd)55 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
56 JNIEnv* env, jclass thiz, jint fd) {
57 tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
58 BertQuestionAnswerer::CreateFromFd(fd);
59 if (status.ok()) {
60 return reinterpret_cast<jlong>(status->release());
61 } else {
62 return kInvalidPointer;
63 }
64 }
65
66 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(JNIEnv * env,jclass thiz,jobjectArray model_buffers)67 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBuffers(
68 JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
69 absl::string_view model =
70 GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
71 absl::string_view vocab =
72 GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
73
74 tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
75 BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
76 model.data(), model.size(), vocab.data(), vocab.size());
77 if (status.ok()) {
78 return reinterpret_cast<jlong>(status->release());
79 } else {
80 return kInvalidPointer;
81 }
82 }
83
84 extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(JNIEnv * env,jclass thiz,jobjectArray model_buffers)85 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByteBuffers(
86 JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
87 absl::string_view model =
88 GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
89 absl::string_view sp_model =
90 GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
91
92 tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
93 BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
94 model.data(), model.size(), sp_model.data(), sp_model.size());
95 if (status.ok()) {
96 return reinterpret_cast<jlong>(status->release());
97 } else {
98 return kInvalidPointer;
99 }
100 }
101
102 extern "C" JNIEXPORT jobject JNICALL
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(JNIEnv * env,jclass thiz,jlong native_handle,jstring context,jstring question)103 Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
104 JNIEnv* env, jclass thiz, jlong native_handle, jstring context,
105 jstring question) {
106 auto* question_answerer = reinterpret_cast<QuestionAnswerer*>(native_handle);
107
108 std::vector<QaAnswer> results = question_answerer->Answer(
109 JStringToString(env, context), JStringToString(env, question));
110 jclass qa_answer_class =
111 env->FindClass("org/tensorflow/lite/task/text/qa/QaAnswer");
112 jmethodID qa_answer_ctor =
113 env->GetMethodID(qa_answer_class, "<init>", "(Ljava/lang/String;IIF)V");
114
115 return ConvertVectorToArrayList<QaAnswer>(
116 env, results,
117 [env, qa_answer_class, qa_answer_ctor](const QaAnswer& ans) {
118 jstring text = env->NewStringUTF(ans.text.data());
119 jobject qa_answer =
120 env->NewObject(qa_answer_class, qa_answer_ctor, text, ans.pos.start,
121 ans.pos.end, ans.pos.logit);
122 env->DeleteLocalRef(text);
123 return qa_answer;
124 });
125 }
126
127 } // namespace
128