1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "utils/tflite/encoder_common.h"
18
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 #include "tensorflow/lite/string_util.h"
21
22 namespace libtextclassifier3 {
23
CreateIntArray(const std::initializer_list<int> & values)24 TfLiteIntArray* CreateIntArray(const std::initializer_list<int>& values) {
25 TfLiteIntArray* array_size = TfLiteIntArrayCreate(values.size());
26 int index = 0;
27 for (const int size : values) {
28 array_size->data[index++] = size;
29 }
30 return array_size;
31 }
32
CopyValuesToTensorAndPadOrTruncate(const TfLiteTensor & in,const std::vector<int> & encoding_end_offsets,int start_offset,TfLiteContext * context,TfLiteTensor * out)33 TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
34 const TfLiteTensor& in, const std::vector<int>& encoding_end_offsets,
35 int start_offset, TfLiteContext* context, TfLiteTensor* out) {
36 TF_LITE_ENSURE_EQ(context, in.dims->size, kEncoderInputRank);
37 TF_LITE_ENSURE_EQ(context, in.dims->data[0], kEncoderBatchSize);
38 const int output_size = out->dims->data[1];
39 int output_offset = 0;
40 for (int value_index = 0;
41 value_index < encoding_end_offsets.size() && output_offset < output_size;
42 ++value_index) {
43 // Calculate how many elements need to be set with this value.
44 // The low bound depends on the offset from the beginning. If this is 0, it
45 // means that this value it truncated.
46 // The upper bound depends on how many elements are in the output tensor.
47 const int from_this_element =
48 std::min(std::max(0, encoding_end_offsets[value_index] - start_offset -
49 output_offset),
50 output_size - output_offset);
51 if (from_this_element == 0) {
52 continue;
53 }
54
55 switch (in.type) {
56 case kTfLiteInt32: {
57 std::fill(out->data.i32 + output_offset,
58 out->data.i32 + output_offset + from_this_element,
59 in.data.i32[value_index]);
60 } break;
61 case kTfLiteInt64: {
62 std::fill(out->data.i64 + output_offset,
63 out->data.i64 + output_offset + from_this_element,
64 in.data.i64[value_index]);
65 } break;
66 case kTfLiteFloat32: {
67 std::fill(out->data.f + output_offset,
68 out->data.f + output_offset + from_this_element,
69 in.data.f[value_index]);
70 } break;
71 default:
72 context->ReportError(
73 (context), __FILE__ " Not supported attribute type %d", in.type);
74 return kTfLiteError;
75 }
76 output_offset += from_this_element;
77 }
78 // Do final padding.
79 switch (in.type) {
80 case kTfLiteInt32: {
81 const int32_t value =
82 (output_offset > 0) ? out->data.i32[output_offset - 1] : 0;
83 std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
84 value);
85 } break;
86 case kTfLiteInt64: {
87 const int64_t value =
88 (output_offset > 0) ? out->data.i64[output_offset - 1] : 0;
89 std::fill(out->data.i64 + output_offset, out->data.i64 + output_size,
90 value);
91 } break;
92 case kTfLiteFloat32: {
93 const float value =
94 (output_offset > 0) ? out->data.f[output_offset - 1] : 0;
95 std::fill(out->data.f + output_offset, out->data.f + output_size, value);
96 } break;
97 default:
98 break;
99 }
100 return kTfLiteOk;
101 }
102
ResizeOutputTensor(const int max_output_length,TfLiteTensor * tensor,TfLiteContext * context)103 TfLiteStatus ResizeOutputTensor(const int max_output_length,
104 TfLiteTensor* tensor, TfLiteContext* context) {
105 TF_LITE_ENSURE_OK(
106 context, context->ResizeTensor(
107 context, tensor,
108 CreateIntArray({kEncoderBatchSize, max_output_length})));
109 return kTfLiteOk;
110 }
111
CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,const std::vector<int32_t> & data,const int32_t padding_value,TfLiteTensor * output_tensor)112 int CopyDataToTensorAndPadOrTruncate(const int32_t max_output_length,
113 const std::vector<int32_t>& data,
114 const int32_t padding_value,
115 TfLiteTensor* output_tensor) {
116 const int num_skip =
117 std::max(0, static_cast<int>(data.size()) - max_output_length);
118 int output_offset = 0;
119 int32_t* output_buffer = output_tensor->data.i32;
120 for (int i = num_skip; i < data.size(); ++i, ++output_offset) {
121 output_buffer[output_offset] = data[i];
122 }
123
124 // Do padding.
125 for (; output_offset < max_output_length; ++output_offset) {
126 output_buffer[output_offset] = padding_value;
127 }
128
129 // Return number of skipped entries from the beginning.
130 return num_skip;
131 }
132
133 } // namespace libtextclassifier3
134