1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/text_encoder3s.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <memory>
20*993b0882SAndroid Build Coastguard Worker #include <vector>
21*993b0882SAndroid Build Coastguard Worker
22*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
23*993b0882SAndroid Build Coastguard Worker #include "utils/strings/stringpiece.h"
24*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/encoder_common.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/text_encoder_config_generated.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/tokenfree/byte_encoder.h"
27*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/flatbuffers.h"
28*993b0882SAndroid Build Coastguard Worker #include "flatbuffers/flexbuffers.h"
29*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/kernel_util.h"
30*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
31*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h"
32*993b0882SAndroid Build Coastguard Worker
33*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
34*993b0882SAndroid Build Coastguard Worker namespace {
35*993b0882SAndroid Build Coastguard Worker
// Positions of the op's tensors inside node->inputs.
constexpr int kInputTextInd = 0;   // String tensor holding the texts to encode.
constexpr int kTextLengthInd = 1;  // int32 scalar: number of strings to encode.
constexpr int kMaxLengthInd = 2;   // Maximum output length (read as int64).
constexpr int kInputAttrInd = 3;   // First of zero or more attribute inputs.

// Positions of the op's tensors inside node->outputs.
constexpr int kOutputEncodedInd = 0;   // int32 encoded ids.
constexpr int kOutputPositionInd = 1;  // int32 position of each id in its string.
constexpr int kOutputLengthsInd = 2;   // int32 count of ids actually written.
constexpr int kOutputAttrInd = 3;      // First of zero or more attribute outputs.
48*993b0882SAndroid Build Coastguard Worker
49*993b0882SAndroid Build Coastguard Worker // Initializes text encoder object from serialized parameters.
Initialize(TfLiteContext * context,const char * buffer,size_t length)50*993b0882SAndroid Build Coastguard Worker void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
51*993b0882SAndroid Build Coastguard Worker std::unique_ptr<ByteEncoder> encoder(new ByteEncoder());
52*993b0882SAndroid Build Coastguard Worker return encoder.release();
53*993b0882SAndroid Build Coastguard Worker }
54*993b0882SAndroid Build Coastguard Worker
Free(TfLiteContext * context,void * buffer)55*993b0882SAndroid Build Coastguard Worker void Free(TfLiteContext* context, void* buffer) {
56*993b0882SAndroid Build Coastguard Worker delete reinterpret_cast<ByteEncoder*>(buffer);
57*993b0882SAndroid Build Coastguard Worker }
58*993b0882SAndroid Build Coastguard Worker
59*993b0882SAndroid Build Coastguard Worker namespace {
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,int max_output_length)60*993b0882SAndroid Build Coastguard Worker TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
61*993b0882SAndroid Build Coastguard Worker int max_output_length) {
62*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_encoded =
63*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputEncodedInd]];
64*993b0882SAndroid Build Coastguard Worker
65*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
66*993b0882SAndroid Build Coastguard Worker context, context->ResizeTensor(
67*993b0882SAndroid Build Coastguard Worker context, &output_encoded,
68*993b0882SAndroid Build Coastguard Worker CreateIntArray({kEncoderBatchSize, max_output_length})));
69*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_positions =
70*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputPositionInd]];
71*993b0882SAndroid Build Coastguard Worker
72*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
73*993b0882SAndroid Build Coastguard Worker context, context->ResizeTensor(
74*993b0882SAndroid Build Coastguard Worker context, &output_positions,
75*993b0882SAndroid Build Coastguard Worker CreateIntArray({kEncoderBatchSize, max_output_length})));
76*993b0882SAndroid Build Coastguard Worker
77*993b0882SAndroid Build Coastguard Worker const int num_output_attrs = node->outputs->size - kOutputAttrInd;
78*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
79*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output =
80*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputAttrInd + i]];
81*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
82*993b0882SAndroid Build Coastguard Worker context, context->ResizeTensor(
83*993b0882SAndroid Build Coastguard Worker context, &output,
84*993b0882SAndroid Build Coastguard Worker CreateIntArray({kEncoderBatchSize, max_output_length})));
85*993b0882SAndroid Build Coastguard Worker }
86*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
87*993b0882SAndroid Build Coastguard Worker }
88*993b0882SAndroid Build Coastguard Worker } // namespace
89*993b0882SAndroid Build Coastguard Worker
Prepare(TfLiteContext * context,TfLiteNode * node)90*993b0882SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91*993b0882SAndroid Build Coastguard Worker // Check that the batch dimension is kEncoderBatchSize.
92*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& input_text =
93*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputTextInd]];
94*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
95*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
96*993b0882SAndroid Build Coastguard Worker
97*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_lengths =
98*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputLengthsInd]];
99*993b0882SAndroid Build Coastguard Worker
100*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_encoded =
101*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputEncodedInd]];
102*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_positions =
103*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputPositionInd]];
104*993b0882SAndroid Build Coastguard Worker output_encoded.type = kTfLiteInt32;
105*993b0882SAndroid Build Coastguard Worker output_positions.type = kTfLiteInt32;
106*993b0882SAndroid Build Coastguard Worker output_lengths.type = kTfLiteInt32;
107*993b0882SAndroid Build Coastguard Worker
108*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(context,
109*993b0882SAndroid Build Coastguard Worker context->ResizeTensor(context, &output_lengths,
110*993b0882SAndroid Build Coastguard Worker CreateIntArray({kEncoderBatchSize})));
111*993b0882SAndroid Build Coastguard Worker
112*993b0882SAndroid Build Coastguard Worker // Check that there are enough outputs for attributes.
113*993b0882SAndroid Build Coastguard Worker const int num_output_attrs = node->outputs->size - kOutputAttrInd;
114*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
115*993b0882SAndroid Build Coastguard Worker num_output_attrs);
116*993b0882SAndroid Build Coastguard Worker
117*993b0882SAndroid Build Coastguard Worker // Copy attribute types from input to output tensors.
118*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
119*993b0882SAndroid Build Coastguard Worker TfLiteTensor& input =
120*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputAttrInd + i]];
121*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output =
122*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputAttrInd + i]];
123*993b0882SAndroid Build Coastguard Worker output.type = input.type;
124*993b0882SAndroid Build Coastguard Worker }
125*993b0882SAndroid Build Coastguard Worker
126*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& output_length =
127*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kMaxLengthInd]];
128*993b0882SAndroid Build Coastguard Worker
129*993b0882SAndroid Build Coastguard Worker if (tflite::IsConstantTensor(&output_length)) {
130*993b0882SAndroid Build Coastguard Worker return ResizeOutputTensors(context, node, output_length.data.i64[0]);
131*993b0882SAndroid Build Coastguard Worker } else {
132*993b0882SAndroid Build Coastguard Worker tflite::SetTensorToDynamic(&output_encoded);
133*993b0882SAndroid Build Coastguard Worker tflite::SetTensorToDynamic(&output_positions);
134*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
135*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_attr =
136*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputAttrInd + i]];
137*993b0882SAndroid Build Coastguard Worker tflite::SetTensorToDynamic(&output_attr);
138*993b0882SAndroid Build Coastguard Worker }
139*993b0882SAndroid Build Coastguard Worker }
140*993b0882SAndroid Build Coastguard Worker
141*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
142*993b0882SAndroid Build Coastguard Worker }
143*993b0882SAndroid Build Coastguard Worker
Eval(TfLiteContext * context,TfLiteNode * node)144*993b0882SAndroid Build Coastguard Worker TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
145*993b0882SAndroid Build Coastguard Worker if (node->user_data == nullptr) {
146*993b0882SAndroid Build Coastguard Worker return kTfLiteError;
147*993b0882SAndroid Build Coastguard Worker }
148*993b0882SAndroid Build Coastguard Worker auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
149*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& input_text =
150*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputTextInd]];
151*993b0882SAndroid Build Coastguard Worker const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
152*993b0882SAndroid Build Coastguard Worker const int num_strings =
153*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
154*993b0882SAndroid Build Coastguard Worker
155*993b0882SAndroid Build Coastguard Worker // Check that the number of strings is not bigger than the input tensor size.
156*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
157*993b0882SAndroid Build Coastguard Worker
158*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_encoded =
159*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputEncodedInd]];
160*993b0882SAndroid Build Coastguard Worker if (tflite::IsDynamicTensor(&output_encoded)) {
161*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& output_length =
162*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kMaxLengthInd]];
163*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(
164*993b0882SAndroid Build Coastguard Worker context, ResizeOutputTensors(context, node, output_length.data.i64[0]));
165*993b0882SAndroid Build Coastguard Worker }
166*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_positions =
167*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputPositionInd]];
168*993b0882SAndroid Build Coastguard Worker
169*993b0882SAndroid Build Coastguard Worker std::vector<int> encoded_total;
170*993b0882SAndroid Build Coastguard Worker std::vector<int> encoded_positions;
171*993b0882SAndroid Build Coastguard Worker std::vector<int> encoded_offsets;
172*993b0882SAndroid Build Coastguard Worker encoded_offsets.reserve(num_strings);
173*993b0882SAndroid Build Coastguard Worker const int max_output_length = output_encoded.dims->data[1];
174*993b0882SAndroid Build Coastguard Worker const int max_encoded_position = max_output_length;
175*993b0882SAndroid Build Coastguard Worker
176*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_strings; ++i) {
177*993b0882SAndroid Build Coastguard Worker const auto& strref = tflite::GetString(&input_text, i);
178*993b0882SAndroid Build Coastguard Worker std::vector<int64_t> encoded;
179*993b0882SAndroid Build Coastguard Worker text_encoder->Encode(
180*993b0882SAndroid Build Coastguard Worker libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
181*993b0882SAndroid Build Coastguard Worker encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
182*993b0882SAndroid Build Coastguard Worker encoded_offsets.push_back(encoded_total.size());
183*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < encoded.size(); ++i) {
184*993b0882SAndroid Build Coastguard Worker encoded_positions.push_back(std::min(i, max_encoded_position - 1));
185*993b0882SAndroid Build Coastguard Worker }
186*993b0882SAndroid Build Coastguard Worker }
187*993b0882SAndroid Build Coastguard Worker
188*993b0882SAndroid Build Coastguard Worker // Copy encoding to output tensor.
189*993b0882SAndroid Build Coastguard Worker const int start_offset =
190*993b0882SAndroid Build Coastguard Worker std::max(0, static_cast<int>(encoded_total.size()) - max_output_length);
191*993b0882SAndroid Build Coastguard Worker int output_offset = 0;
192*993b0882SAndroid Build Coastguard Worker int32_t* output_buffer = output_encoded.data.i32;
193*993b0882SAndroid Build Coastguard Worker int32_t* output_positions_buffer = output_positions.data.i32;
194*993b0882SAndroid Build Coastguard Worker for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) {
195*993b0882SAndroid Build Coastguard Worker output_buffer[output_offset] = encoded_total[i];
196*993b0882SAndroid Build Coastguard Worker output_positions_buffer[output_offset] = encoded_positions[i];
197*993b0882SAndroid Build Coastguard Worker }
198*993b0882SAndroid Build Coastguard Worker
199*993b0882SAndroid Build Coastguard Worker // Save output encoded length.
200*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_lengths =
201*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[kOutputLengthsInd]];
202*993b0882SAndroid Build Coastguard Worker output_lengths.data.i32[0] = output_offset;
203*993b0882SAndroid Build Coastguard Worker
204*993b0882SAndroid Build Coastguard Worker // Do padding.
205*993b0882SAndroid Build Coastguard Worker for (; output_offset < max_output_length; ++output_offset) {
206*993b0882SAndroid Build Coastguard Worker output_buffer[output_offset] = 0;
207*993b0882SAndroid Build Coastguard Worker output_positions_buffer[output_offset] = 0;
208*993b0882SAndroid Build Coastguard Worker }
209*993b0882SAndroid Build Coastguard Worker
210*993b0882SAndroid Build Coastguard Worker // Process attributes, all checks of sizes and types are done in Prepare.
211*993b0882SAndroid Build Coastguard Worker const int num_output_attrs = node->outputs->size - kOutputAttrInd;
212*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
213*993b0882SAndroid Build Coastguard Worker num_output_attrs);
214*993b0882SAndroid Build Coastguard Worker for (int i = 0; i < num_output_attrs; ++i) {
215*993b0882SAndroid Build Coastguard Worker TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
216*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs->data[kInputAttrInd + i]],
217*993b0882SAndroid Build Coastguard Worker encoded_offsets, start_offset, context,
218*993b0882SAndroid Build Coastguard Worker &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
219*993b0882SAndroid Build Coastguard Worker if (attr_status != kTfLiteOk) {
220*993b0882SAndroid Build Coastguard Worker return attr_status;
221*993b0882SAndroid Build Coastguard Worker }
222*993b0882SAndroid Build Coastguard Worker }
223*993b0882SAndroid Build Coastguard Worker
224*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
225*993b0882SAndroid Build Coastguard Worker }
226*993b0882SAndroid Build Coastguard Worker
227*993b0882SAndroid Build Coastguard Worker } // namespace
228*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
229*993b0882SAndroid Build Coastguard Worker
230*993b0882SAndroid Build Coastguard Worker namespace tflite {
231*993b0882SAndroid Build Coastguard Worker namespace ops {
232*993b0882SAndroid Build Coastguard Worker namespace custom {
233*993b0882SAndroid Build Coastguard Worker
Register_TEXT_ENCODER3S()234*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_TEXT_ENCODER3S() {
235*993b0882SAndroid Build Coastguard Worker static TfLiteRegistration registration = {
236*993b0882SAndroid Build Coastguard Worker libtextclassifier3::Initialize, libtextclassifier3::Free,
237*993b0882SAndroid Build Coastguard Worker libtextclassifier3::Prepare, libtextclassifier3::Eval};
238*993b0882SAndroid Build Coastguard Worker return ®istration;
239*993b0882SAndroid Build Coastguard Worker }
240*993b0882SAndroid Build Coastguard Worker
241*993b0882SAndroid Build Coastguard Worker } // namespace custom
242*993b0882SAndroid Build Coastguard Worker } // namespace ops
243*993b0882SAndroid Build Coastguard Worker } // namespace tflite
244