xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/util.h"
16 
17 #include <stddef.h>
18 #include <stdint.h>
19 
20 #include <algorithm>
21 #include <complex>
22 #include <cstring>
23 #include <initializer_list>
24 #include <memory>
25 #include <string>
26 #include <vector>
27 
28 #include "tensorflow/lite/builtin_ops.h"
29 #include "tensorflow/lite/c/common.h"
30 #include "tensorflow/lite/core/macros.h"
31 #include "tensorflow/lite/schema/schema_generated.h"
32 
33 namespace tflite {
34 namespace {
35 
UnresolvedOpInvoke(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus UnresolvedOpInvoke(TfLiteContext* context, TfLiteNode* node) {
37   TF_LITE_KERNEL_LOG(context,
38                      "Encountered an unresolved custom op. Did you miss "
39                      "a custom op or delegate?");
40   return kTfLiteError;
41 }
42 
43 }  // namespace
44 
IsFlexOp(const char * custom_name)45 bool IsFlexOp(const char* custom_name) {
46   return custom_name && strncmp(custom_name, kFlexCustomCodePrefix,
47                                 strlen(kFlexCustomCodePrefix)) == 0;
48 }
49 
BuildTfLiteIntArray(const std::vector<int> & data)50 std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
51     const std::vector<int>& data) {
52   std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
53       TfLiteIntArrayCreate(data.size()));
54   std::copy(data.begin(), data.end(), result->data);
55   return result;
56 }
57 
ConvertVectorToTfLiteIntArray(const std::vector<int> & input)58 TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
59   return ConvertArrayToTfLiteIntArray(static_cast<int>(input.size()),
60                                       input.data());
61 }
62 
ConvertArrayToTfLiteIntArray(const int ndims,const int * dims)63 TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int ndims, const int* dims) {
64   TfLiteIntArray* output = TfLiteIntArrayCreate(ndims);
65   for (size_t i = 0; i < ndims; i++) {
66     output->data[i] = dims[i];
67   }
68   return output;
69 }
70 
EqualArrayAndTfLiteIntArray(const TfLiteIntArray * a,const int b_size,const int * b)71 bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
72                                  const int* b) {
73   if (!a) return false;
74   if (a->size != b_size) return false;
75   for (int i = 0; i < a->size; ++i) {
76     if (a->data[i] != b[i]) return false;
77   }
78   return true;
79 }
80 
// Folds `hashes`, in order, into a single value. Starting from 0, each input
// hash h updates the accumulator r as
//   r = r ^ (h + 0x9e3779b97f4a7800 + (r << 10) + (r >> 4)).
// An empty list yields 0. The result depends on the order of the inputs.
size_t CombineHashes(std::initializer_list<size_t> hashes) {
  // Same mixing formula as the hash combiner used by TensorFlow core.
  size_t combined = 0;
  for (auto it = hashes.begin(); it != hashes.end(); ++it) {
    const auto mixed =
        *it + 0x9e3779b97f4a7800ULL + (combined << 10) + (combined >> 4);
    combined = combined ^ mixed;
  }
  return combined;
}
90 
GetSizeOfType(TfLiteContext * context,const TfLiteType type,size_t * bytes)91 TfLiteStatus GetSizeOfType(TfLiteContext* context, const TfLiteType type,
92                            size_t* bytes) {
93   // TODO(levp): remove the default case so that new types produce compilation
94   // error.
95   switch (type) {
96     case kTfLiteFloat32:
97       *bytes = sizeof(float);
98       break;
99     case kTfLiteInt32:
100       *bytes = sizeof(int32_t);
101       break;
102     case kTfLiteUInt32:
103       *bytes = sizeof(uint32_t);
104       break;
105     case kTfLiteUInt8:
106       *bytes = sizeof(uint8_t);
107       break;
108     case kTfLiteInt64:
109       *bytes = sizeof(int64_t);
110       break;
111     case kTfLiteUInt64:
112       *bytes = sizeof(uint64_t);
113       break;
114     case kTfLiteBool:
115       *bytes = sizeof(bool);
116       break;
117     case kTfLiteComplex64:
118       *bytes = sizeof(std::complex<float>);
119       break;
120     case kTfLiteComplex128:
121       *bytes = sizeof(std::complex<double>);
122       break;
123     case kTfLiteUInt16:
124       *bytes = sizeof(uint16_t);
125       break;
126     case kTfLiteInt16:
127       *bytes = sizeof(int16_t);
128       break;
129     case kTfLiteInt8:
130       *bytes = sizeof(int8_t);
131       break;
132     case kTfLiteFloat16:
133       *bytes = sizeof(TfLiteFloat16);
134       break;
135     case kTfLiteFloat64:
136       *bytes = sizeof(double);
137       break;
138     default:
139       if (context) {
140         TF_LITE_KERNEL_LOG(
141             context,
142             "Type %d is unsupported. Only float16, float32, float64, int8, "
143             "int16, int32, int64, uint8, uint64, bool, complex64 and "
144             "complex128 supported currently.",
145             type);
146       }
147       return kTfLiteError;
148   }
149   return kTfLiteOk;
150 }
151 
CreateUnresolvedCustomOp(const char * custom_op_name)152 TfLiteRegistration CreateUnresolvedCustomOp(const char* custom_op_name) {
153   return TfLiteRegistration{nullptr,
154                             nullptr,
155                             nullptr,
156                             /*invoke*/ &UnresolvedOpInvoke,
157                             nullptr,
158                             BuiltinOperator_CUSTOM,
159                             custom_op_name,
160                             1};
161 }
162 
IsUnresolvedCustomOp(const TfLiteRegistration & registration)163 bool IsUnresolvedCustomOp(const TfLiteRegistration& registration) {
164   return registration.builtin_code == tflite::BuiltinOperator_CUSTOM &&
165          registration.invoke == &UnresolvedOpInvoke;
166 }
167 
GetOpNameByRegistration(const TfLiteRegistration & registration)168 std::string GetOpNameByRegistration(const TfLiteRegistration& registration) {
169   auto op = registration.builtin_code;
170   std::string result =
171       EnumNameBuiltinOperator(static_cast<BuiltinOperator>(op));
172   if ((op == kTfLiteBuiltinCustom || op == kTfLiteBuiltinDelegate) &&
173       registration.custom_name) {
174     result += " " + std::string(registration.custom_name);
175   }
176   return result;
177 }
178 
IsValidationSubgraph(const char * name)179 bool IsValidationSubgraph(const char* name) {
180   // NOLINTNEXTLINE: can't use absl::StartsWith as absl is not allowed.
181   return name && std::string(name).find(kValidationSubgraphNamePrefix) == 0;
182 }
183 
MultiplyAndCheckOverflow(size_t a,size_t b,size_t * product)184 TfLiteStatus MultiplyAndCheckOverflow(size_t a, size_t b, size_t* product) {
185   // Multiplying a * b where a and b are size_t cannot result in overflow in a
186   // size_t accumulator if both numbers have no non-zero bits in their upper
187   // half.
188   constexpr size_t size_t_bits = 8 * sizeof(size_t);
189   constexpr size_t overflow_upper_half_bit_position = size_t_bits / 2;
190   *product = a * b;
191   // If neither integers have non-zero bits past 32 bits can't overflow.
192   // Otherwise check using slow devision.
193   if (TFLITE_EXPECT_FALSE((a | b) >> overflow_upper_half_bit_position != 0)) {
194     if (a != 0 && *product / a != b) return kTfLiteError;
195   }
196   return kTfLiteOk;
197 }
198 }  // namespace tflite
199