xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/text_encoder3s.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/text_encoder3s.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <memory>
20*993b0882SAndroid Build Coastguard Worker #include <vector>
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/strings/stringpiece.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/encoder_common.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/text_encoder_config_generated.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/tokenfree/byte_encoder.h"
27*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/flatbuffers.h"
28*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/flexbuffers.h"
29*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/kernel_util.h"
30*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
31*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h"
32*993b0882SAndroid Build Coastguard Worker 
33*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
34*993b0882SAndroid Build Coastguard Worker namespace {
35*993b0882SAndroid Build Coastguard Worker 
// Positions of the op's inputs within node->inputs.
constexpr int kInputTextInd = 0;   // String tensor holding the texts to encode.
constexpr int kTextLengthInd = 1;  // Its first int32 is the number of strings.
constexpr int kMaxLengthInd = 2;   // Its first element is the max output length.
constexpr int kInputAttrInd = 3;   // First attribute input; the rest follow it.

// Positions of the op's outputs within node->outputs.
constexpr int kOutputEncodedInd = 0;   // Encoded values (int32).
constexpr int kOutputPositionInd = 1;  // Per-string positions (int32).
constexpr int kOutputLengthsInd = 2;   // Number of encoded values written.
constexpr int kOutputAttrInd = 3;      // First attribute output; rest follow.
48*993b0882SAndroid Build Coastguard Worker 
49*993b0882SAndroid Build Coastguard Worker // Initializes text encoder object from serialized parameters.
Initialize(TfLiteContext * context,const char * buffer,size_t length)50*993b0882SAndroid Build Coastguard Worker void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
51*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ByteEncoder> encoder(new ByteEncoder());
52*993b0882SAndroid Build Coastguard Worker   return encoder.release();
53*993b0882SAndroid Build Coastguard Worker }
54*993b0882SAndroid Build Coastguard Worker 
Free(TfLiteContext * context,void * buffer)55*993b0882SAndroid Build Coastguard Worker void Free(TfLiteContext* context, void* buffer) {
56*993b0882SAndroid Build Coastguard Worker   delete reinterpret_cast<ByteEncoder*>(buffer);
57*993b0882SAndroid Build Coastguard Worker }
58*993b0882SAndroid Build Coastguard Worker 
59*993b0882SAndroid Build Coastguard Worker namespace {
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)60*993b0882SAndroid Build Coastguard Worker TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
61*993b0882SAndroid Build Coastguard Worker                                  int max_output_length) {
62*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_encoded =
63*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputEncodedInd]];
64*993b0882SAndroid Build Coastguard Worker 
65*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_OK(
66*993b0882SAndroid Build Coastguard Worker       context, context->ResizeTensor(
67*993b0882SAndroid Build Coastguard Worker                    context, &output_encoded,
68*993b0882SAndroid Build Coastguard Worker                    CreateIntArray({kEncoderBatchSize, max_output_length})));
69*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_positions =
70*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputPositionInd]];
71*993b0882SAndroid Build Coastguard Worker 
72*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_OK(
73*993b0882SAndroid Build Coastguard Worker       context, context->ResizeTensor(
74*993b0882SAndroid Build Coastguard Worker                    context, &output_positions,
75*993b0882SAndroid Build Coastguard Worker                    CreateIntArray({kEncoderBatchSize, max_output_length})));
76*993b0882SAndroid Build Coastguard Worker 
77*993b0882SAndroid Build Coastguard Worker   const int num_output_attrs = node->outputs->size - kOutputAttrInd;
78*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_output_attrs; ++i) {
79*993b0882SAndroid Build Coastguard Worker     TfLiteTensor& output =
80*993b0882SAndroid Build Coastguard Worker         context->tensors[node->outputs->data[kOutputAttrInd + i]];
81*993b0882SAndroid Build Coastguard Worker     TF_LITE_ENSURE_OK(
82*993b0882SAndroid Build Coastguard Worker         context, context->ResizeTensor(
83*993b0882SAndroid Build Coastguard Worker                      context, &output,
84*993b0882SAndroid Build Coastguard Worker                      CreateIntArray({kEncoderBatchSize, max_output_length})));
85*993b0882SAndroid Build Coastguard Worker   }
86*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
87*993b0882SAndroid Build Coastguard Worker }
88*993b0882SAndroid Build Coastguard Worker }  // namespace
89*993b0882SAndroid Build Coastguard Worker 
Prepare(TfLiteContext * context,TfLiteNode * node)90*993b0882SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91*993b0882SAndroid Build Coastguard Worker   // Check that the batch dimension is kEncoderBatchSize.
92*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& input_text =
93*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputTextInd]];
94*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
95*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
96*993b0882SAndroid Build Coastguard Worker 
97*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_lengths =
98*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputLengthsInd]];
99*993b0882SAndroid Build Coastguard Worker 
100*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_encoded =
101*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputEncodedInd]];
102*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_positions =
103*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputPositionInd]];
104*993b0882SAndroid Build Coastguard Worker   output_encoded.type = kTfLiteInt32;
105*993b0882SAndroid Build Coastguard Worker   output_positions.type = kTfLiteInt32;
106*993b0882SAndroid Build Coastguard Worker   output_lengths.type = kTfLiteInt32;
107*993b0882SAndroid Build Coastguard Worker 
108*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_OK(context,
109*993b0882SAndroid Build Coastguard Worker                     context->ResizeTensor(context, &output_lengths,
110*993b0882SAndroid Build Coastguard Worker                                           CreateIntArray({kEncoderBatchSize})));
111*993b0882SAndroid Build Coastguard Worker 
112*993b0882SAndroid Build Coastguard Worker   // Check that there are enough outputs for attributes.
113*993b0882SAndroid Build Coastguard Worker   const int num_output_attrs = node->outputs->size - kOutputAttrInd;
114*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
115*993b0882SAndroid Build Coastguard Worker                     num_output_attrs);
116*993b0882SAndroid Build Coastguard Worker 
117*993b0882SAndroid Build Coastguard Worker   // Copy attribute types from input to output tensors.
118*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_output_attrs; ++i) {
119*993b0882SAndroid Build Coastguard Worker     TfLiteTensor& input =
120*993b0882SAndroid Build Coastguard Worker         context->tensors[node->inputs->data[kInputAttrInd + i]];
121*993b0882SAndroid Build Coastguard Worker     TfLiteTensor& output =
122*993b0882SAndroid Build Coastguard Worker         context->tensors[node->outputs->data[kOutputAttrInd + i]];
123*993b0882SAndroid Build Coastguard Worker     output.type = input.type;
124*993b0882SAndroid Build Coastguard Worker   }
125*993b0882SAndroid Build Coastguard Worker 
126*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& output_length =
127*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kMaxLengthInd]];
128*993b0882SAndroid Build Coastguard Worker 
129*993b0882SAndroid Build Coastguard Worker   if (tflite::IsConstantTensor(&output_length)) {
130*993b0882SAndroid Build Coastguard Worker     return ResizeOutputTensors(context, node, output_length.data.i64[0]);
131*993b0882SAndroid Build Coastguard Worker   } else {
132*993b0882SAndroid Build Coastguard Worker     tflite::SetTensorToDynamic(&output_encoded);
133*993b0882SAndroid Build Coastguard Worker     tflite::SetTensorToDynamic(&output_positions);
134*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < num_output_attrs; ++i) {
135*993b0882SAndroid Build Coastguard Worker       TfLiteTensor& output_attr =
136*993b0882SAndroid Build Coastguard Worker           context->tensors[node->outputs->data[kOutputAttrInd + i]];
137*993b0882SAndroid Build Coastguard Worker       tflite::SetTensorToDynamic(&output_attr);
138*993b0882SAndroid Build Coastguard Worker     }
139*993b0882SAndroid Build Coastguard Worker   }
140*993b0882SAndroid Build Coastguard Worker 
141*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
142*993b0882SAndroid Build Coastguard Worker }
143*993b0882SAndroid Build Coastguard Worker 
Eval(TfLiteContext * context,TfLiteNode * node)144*993b0882SAndroid Build Coastguard Worker TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
145*993b0882SAndroid Build Coastguard Worker   if (node->user_data == nullptr) {
146*993b0882SAndroid Build Coastguard Worker     return kTfLiteError;
147*993b0882SAndroid Build Coastguard Worker   }
148*993b0882SAndroid Build Coastguard Worker   auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
149*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& input_text =
150*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kInputTextInd]];
151*993b0882SAndroid Build Coastguard Worker   const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
152*993b0882SAndroid Build Coastguard Worker   const int num_strings =
153*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
154*993b0882SAndroid Build Coastguard Worker 
155*993b0882SAndroid Build Coastguard Worker   // Check that the number of strings is not bigger than the input tensor size.
156*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
157*993b0882SAndroid Build Coastguard Worker 
158*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_encoded =
159*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputEncodedInd]];
160*993b0882SAndroid Build Coastguard Worker   if (tflite::IsDynamicTensor(&output_encoded)) {
161*993b0882SAndroid Build Coastguard Worker     const TfLiteTensor& output_length =
162*993b0882SAndroid Build Coastguard Worker         context->tensors[node->inputs->data[kMaxLengthInd]];
163*993b0882SAndroid Build Coastguard Worker     TF_LITE_ENSURE_OK(
164*993b0882SAndroid Build Coastguard Worker         context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
165*993b0882SAndroid Build Coastguard Worker   }
166*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_positions =
167*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputPositionInd]];
168*993b0882SAndroid Build Coastguard Worker 
169*993b0882SAndroid Build Coastguard Worker   std::vector<int> encoded_total;
170*993b0882SAndroid Build Coastguard Worker   std::vector<int> encoded_positions;
171*993b0882SAndroid Build Coastguard Worker   std::vector<int> encoded_offsets;
172*993b0882SAndroid Build Coastguard Worker   encoded_offsets.reserve(num_strings);
173*993b0882SAndroid Build Coastguard Worker   const int max_output_length = output_encoded.dims->data[1];
174*993b0882SAndroid Build Coastguard Worker   const int max_encoded_position = max_output_length;
175*993b0882SAndroid Build Coastguard Worker 
176*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_strings; ++i) {
177*993b0882SAndroid Build Coastguard Worker     const auto& strref = tflite::GetString(&input_text, i);
178*993b0882SAndroid Build Coastguard Worker     std::vector<int64_t> encoded;
179*993b0882SAndroid Build Coastguard Worker     text_encoder->Encode(
180*993b0882SAndroid Build Coastguard Worker         libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
181*993b0882SAndroid Build Coastguard Worker     encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
182*993b0882SAndroid Build Coastguard Worker     encoded_offsets.push_back(encoded_total.size());
183*993b0882SAndroid Build Coastguard Worker     for (int i = 0; i < encoded.size(); ++i) {
184*993b0882SAndroid Build Coastguard Worker       encoded_positions.push_back(std::min(i, max_encoded_position - 1));
185*993b0882SAndroid Build Coastguard Worker     }
186*993b0882SAndroid Build Coastguard Worker   }
187*993b0882SAndroid Build Coastguard Worker 
188*993b0882SAndroid Build Coastguard Worker   // Copy encoding to output tensor.
189*993b0882SAndroid Build Coastguard Worker   const int start_offset =
190*993b0882SAndroid Build Coastguard Worker       std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
191*993b0882SAndroid Build Coastguard Worker   int output_offset = 0;
192*993b0882SAndroid Build Coastguard Worker   int32_t* output_buffer = output_encoded.data.i32;
193*993b0882SAndroid Build Coastguard Worker   int32_t* output_positions_buffer = output_positions.data.i32;
194*993b0882SAndroid Build Coastguard Worker   for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
195*993b0882SAndroid Build Coastguard Worker     output_buffer[output_offset] = encoded_total[i];
196*993b0882SAndroid Build Coastguard Worker     output_positions_buffer[output_offset] = encoded_positions[i];
197*993b0882SAndroid Build Coastguard Worker   }
198*993b0882SAndroid Build Coastguard Worker 
199*993b0882SAndroid Build Coastguard Worker   // Save output encoded length.
200*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_lengths =
201*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[kOutputLengthsInd]];
202*993b0882SAndroid Build Coastguard Worker   output_lengths.data.i32[0] = output_offset;
203*993b0882SAndroid Build Coastguard Worker 
204*993b0882SAndroid Build Coastguard Worker   // Do padding.
205*993b0882SAndroid Build Coastguard Worker   for (; output_offset < max_output_length; ++output_offset) {
206*993b0882SAndroid Build Coastguard Worker     output_buffer[output_offset] = 0;
207*993b0882SAndroid Build Coastguard Worker     output_positions_buffer[output_offset] = 0;
208*993b0882SAndroid Build Coastguard Worker   }
209*993b0882SAndroid Build Coastguard Worker 
210*993b0882SAndroid Build Coastguard Worker   // Process attributes, all checks of sizes and types are done in Prepare.
211*993b0882SAndroid Build Coastguard Worker   const int num_output_attrs = node->outputs->size - kOutputAttrInd;
212*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
213*993b0882SAndroid Build Coastguard Worker                     num_output_attrs);
214*993b0882SAndroid Build Coastguard Worker   for (int i = 0; i < num_output_attrs; ++i) {
215*993b0882SAndroid Build Coastguard Worker     TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
216*993b0882SAndroid Build Coastguard Worker         context->tensors[node->inputs->data[kInputAttrInd + i]],
217*993b0882SAndroid Build Coastguard Worker         encoded_offsets, start_offset, context,
218*993b0882SAndroid Build Coastguard Worker         &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
219*993b0882SAndroid Build Coastguard Worker     if (attr_status != kTfLiteOk) {
220*993b0882SAndroid Build Coastguard Worker       return attr_status;
221*993b0882SAndroid Build Coastguard Worker     }
222*993b0882SAndroid Build Coastguard Worker   }
223*993b0882SAndroid Build Coastguard Worker 
224*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
225*993b0882SAndroid Build Coastguard Worker }
226*993b0882SAndroid Build Coastguard Worker 
227*993b0882SAndroid Build Coastguard Worker }  // namespace
228*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
229*993b0882SAndroid Build Coastguard Worker 
230*993b0882SAndroid Build Coastguard Worker namespace tflite {
231*993b0882SAndroid Build Coastguard Worker namespace ops {
232*993b0882SAndroid Build Coastguard Worker namespace custom {
233*993b0882SAndroid Build Coastguard Worker 
Register_TEXT_ENCODER3S()234*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TEXT_ENCODER3S() {
235*993b0882SAndroid Build Coastguard Worker   static TfLiteRegistration registration = {
236*993b0882SAndroid Build Coastguard Worker       libtextclassifier3::Initialize, libtextclassifier3::Free,
237*993b0882SAndroid Build Coastguard Worker       libtextclassifier3::Prepare, libtextclassifier3::Eval};
238*993b0882SAndroid Build Coastguard Worker   return &registration;
239*993b0882SAndroid Build Coastguard Worker }
240*993b0882SAndroid Build Coastguard Worker 
241*993b0882SAndroid Build Coastguard Worker }  // namespace custom
242*993b0882SAndroid Build Coastguard Worker }  // namespace ops
243*993b0882SAndroid Build Coastguard Worker }  // namespace tflite
244