xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/encoder_common.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/tflite/encoder_common.h"
18 
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 #include "tensorflow/lite/string_util.h"
21 
22 namespace libtextclassifier3 {
23 
CreateIntArray(const std::initializer_list<int> & values)24 TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
25   TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
26   int index = 0;
27   for (const int size : values) {
28     array_size->data[index++] = size;
29   }
30   return array_size;
31 }
32 
// Expands the per-encoding values of `in` across the output tensor `out`:
// value i of `in` is replicated into every output position that falls inside
// its encoding span (ending at encoding_end_offsets[i]), shifted left by
// `start_offset`. Values whose span lies entirely before `start_offset` are
// truncated; positions past the last encoding are padded with the last value
// written (or zero if nothing was written). Supported element types are
// int32, int64 and float32; any other type reports an error and returns
// kTfLiteError.
TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
    const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
    int start_offset, TfLiteContext* context, TfLiteTensor* out) {
  // `in` must be rank kEncoderInputRank with batch dimension kEncoderBatchSize.
  TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
  TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
  const int output_size = out->dims->data[1];
  int output_offset = 0;
  for (int value_index = 0;
       value_index < encoding_end_offsets.size() && output_offset < output_size;
       ++value_index) {
    // Calculate how many elements need to be set with this value.
    // The low bound depends on the offset from the beginning. If this is 0, it
    // means that this value is truncated.
    // The upper bound depends on how many elements are in the output tensor.
    const int from_this_element =
        std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
                                 output_offset),
                 output_size - output_offset);
    if (from_this_element == 0) {
      continue;
    }

    // Replicate this input value into the next `from_this_element` slots of
    // the output buffer, using the buffer view matching the element type.
    switch (in.type) {
      case kTfLiteInt32: {
        std::fill(out->data.i32 + output_offset,
                  out->data.i32 + output_offset + from_this_element,
                  in.data.i32[value_index]);
      } break;
      case kTfLiteInt64: {
        std::fill(out->data.i64 + output_offset,
                  out->data.i64 + output_offset + from_this_element,
                  in.data.i64[value_index]);
      } break;
      case kTfLiteFloat32: {
        std::fill(out->data.f + output_offset,
                  out->data.f + output_offset + from_this_element,
                  in.data.f[value_index]);
      } break;
      default:
        context->ReportError(
            (context), __FILE__ " Not supported attribute type %d", in.type);
        return kTfLiteError;
    }
    output_offset += from_this_element;
  }
  // Do final padding.
  // Repeat the last element written above (reads out[output_offset - 1], so
  // this must run after the copy loop); falls back to zero for empty output.
  switch (in.type) {
    case kTfLiteInt32: {
      const int32_t value =
          (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
      std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
                value);
    } break;
    case kTfLiteInt64: {
      const int64_t value =
          (output_offset > 0) ? out->data.i64[output_offset - 1] : 0;
      std::fill(out->data.i64 + output_offset, out->data.i64 + output_size,
                value);
    } break;
    case kTfLiteFloat32: {
      const float value =
          (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
      std::fill(out->data.f + output_offset, out->data.f + output_size, value);
    } break;
    default:
      // Unsupported types were already rejected by the loop above (or the
      // loop body never ran); nothing to pad.
      break;
  }
  return kTfLiteOk;
}
102 
ResizeOutputTensor(const int max_output_length,TfLiteTensor * tensor,TfLiteContext * context)103 TfLiteStatus ResizeOutputTensor(const int max_output_length,
104                                 TfLiteTensor* tensor, TfLiteContext* context) {
105   TF_LITE_ENSURE_OK(
106       context, context->ResizeTensor(
107                    context, tensor,
108                    CreateIntArray({kEncoderBatchSize, max_output_length})));
109   return kTfLiteOk;
110 }
111 
CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,const std::vector<int32_t> & data,const int32_t padding_value,TfLiteTensor * output_tensor)112 int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
113                                      const std::vector<int32_t>& data,
114                                      const int32_t padding_value,
115                                      TfLiteTensor* output_tensor) {
116   const int num_skip =
117       std::max(0, static_cast<int>(data.size()) - max_output_length);
118   int output_offset = 0;
119   int32_t* output_buffer = output_tensor->data.i32;
120   for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
121     output_buffer[output_offset] = data[i];
122   }
123 
124   // Do padding.
125   for (; output_offset < max_output_length; ++output_offset) {
126     output_buffer[output_offset] = padding_value;
127   }
128 
129   // Return number of skipped entries from the beginning.
130   return num_skip;
131 }
132 
133 }  // namespace libtextclassifier3
134