xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/utils.h"
17 
18 #include <algorithm>
19 #include <complex>
20 #include <random>
21 
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/tools/logging.h"
24 
25 namespace tflite {
26 namespace utils {
27 
28 namespace {
// Returns the process-wide random engine shared by all tensor-data generators.
//
// The engine is default-constructed, i.e. seeded with std::mt19937's standard
// default seed, so the generated sequence is the same on every run. It is
// allocated on first use and deliberately never deleted, which keeps the
// pointer valid for the whole lifetime of the process. Access to the engine
// itself is not synchronized.
std::mt19937* get_random_engine() {
  static std::mt19937* const shared_engine = new std::mt19937();
  return shared_engine;
}
35 
36 template <typename T, typename Distribution>
CreateInputTensorData(int num_elements,Distribution distribution)37 inline InputTensorData CreateInputTensorData(int num_elements,
38                                              Distribution distribution) {
39   InputTensorData tmp;
40   auto* random_engine = get_random_engine();
41   tmp.bytes = sizeof(T) * num_elements;
42   T* raw = new T[num_elements];
43   std::generate_n(raw, num_elements, [&]() {
44     if (std::is_same<T, std::complex<float>>::value) {
45       return static_cast<T>(distribution(*random_engine),
46                             distribution(*random_engine));
47     } else {
48       return static_cast<T>(distribution(*random_engine));
49     }
50   });
51   tmp.data = VoidUniquePtr(static_cast<void*>(raw),
52                            [](void* ptr) { delete[] static_cast<T*>(ptr); });
53   return tmp;
54 }
55 
56 }  // namespace
57 
CreateRandomTensorData(const TfLiteTensor & tensor,float low_range,float high_range)58 InputTensorData CreateRandomTensorData(const TfLiteTensor& tensor,
59                                        float low_range, float high_range) {
60   int num_elements = NumElements(tensor.dims);
61   switch (tensor.type) {
62     case kTfLiteComplex64: {
63       return CreateInputTensorData<std::complex<float>>(
64           num_elements,
65           std::uniform_real_distribution<float>(low_range, high_range));
66     }
67     case kTfLiteFloat32: {
68       return CreateInputTensorData<float>(
69           num_elements,
70           std::uniform_real_distribution<float>(low_range, high_range));
71     }
72     case kTfLiteFloat16: {
73       // TODO(b/138843274): Remove this preprocessor guard when bug is fixed.
74 #if TFLITE_ENABLE_FP16_CPU_BENCHMARKS
75 #if __GNUC__ && \
76     (__clang__ || __ARM_FP16_FORMAT_IEEE || __ARM_FP16_FORMAT_ALTERNATIVE)
77       // __fp16 is available on Clang or when __ARM_FP16_FORMAT_* is defined.
78       return CreateInputTensorData<__fp16>(
79           num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
80 #else
81       TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
82                         << " of type FLOAT16 on this platform.";
83 #endif
84 #else
85       // You need to build with -DTFLITE_ENABLE_FP16_CPU_BENCHMARKS=1 using a
86       // compiler that supports __fp16 type. Note: when using Clang and *not*
87       // linking with compiler-rt, a definition of __gnu_h2f_ieee and
88       // __gnu_f2h_ieee must be supplied.
89       TFLITE_LOG(FATAL) << "Populating the tensor " << tensor.name
90                         << " of type FLOAT16 is disabled.";
91 #endif  // TFLITE_ENABLE_FP16_CPU_BENCHMARKS
92       break;
93     }
94     case kTfLiteFloat64: {
95       return CreateInputTensorData<double>(
96           num_elements,
97           std::uniform_real_distribution<double>(low_range, high_range));
98     }
99     case kTfLiteInt64: {
100       return CreateInputTensorData<int64_t>(
101           num_elements,
102           std::uniform_int_distribution<int64_t>(low_range, high_range));
103     }
104     case kTfLiteInt32: {
105       return CreateInputTensorData<int32_t>(
106           num_elements,
107           std::uniform_int_distribution<int32_t>(low_range, high_range));
108     }
109     case kTfLiteUInt32: {
110       return CreateInputTensorData<uint32_t>(
111           num_elements,
112           std::uniform_int_distribution<uint32_t>(low_range, high_range));
113     }
114     case kTfLiteInt16: {
115       return CreateInputTensorData<int16_t>(
116           num_elements,
117           std::uniform_int_distribution<int16_t>(low_range, high_range));
118     }
119     case kTfLiteUInt8: {
120       // std::uniform_int_distribution is specified not to support char types.
121       return CreateInputTensorData<uint8_t>(
122           num_elements,
123           std::uniform_int_distribution<uint32_t>(low_range, high_range));
124     }
125     case kTfLiteInt8: {
126       // std::uniform_int_distribution is specified not to support char types.
127       return CreateInputTensorData<int8_t>(
128           num_elements,
129           std::uniform_int_distribution<int32_t>(low_range, high_range));
130     }
131     case kTfLiteString: {
132       // Don't populate input for string. Instead, return a default-initialized
133       // `InputTensorData` object directly.
134       break;
135     }
136     case kTfLiteBool: {
137       // According to std::uniform_int_distribution specification, non-int type
138       // is not supported.
139       return CreateInputTensorData<bool>(
140           num_elements, std::uniform_int_distribution<uint32_t>(0, 1));
141     }
142     default: {
143       TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << tensor.name
144                         << " of type " << tensor.type;
145     }
146   }
147   return InputTensorData();
148 }
149 
GetDataRangesForType(TfLiteType type,float * low_range,float * high_range)150 void GetDataRangesForType(TfLiteType type, float* low_range,
151                           float* high_range) {
152   if (type == kTfLiteComplex64 || type == kTfLiteFloat32 ||
153       type == kTfLiteFloat64) {
154     *low_range = -0.5f;
155     *high_range = 0.5f;
156   } else if (type == kTfLiteInt64 || type == kTfLiteInt64 ||
157              type == kTfLiteInt64 || type == kTfLiteInt64) {
158     *low_range = 0;
159     *high_range = 99;
160   } else if (type == kTfLiteUInt8) {
161     *low_range = 0;
162     *high_range = 254;
163   } else if (type == kTfLiteInt8) {
164     *low_range = -127;
165     *high_range = 127;
166   }
167 }
168 
169 }  // namespace utils
170 }  // namespace tflite
171