xref: /aosp_15_r20/external/armnn/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/BackendRegistry.hpp>      // for BackendRegistryInstance
9 #include <armnn/Logging.hpp>              // for ScopedRecord, ARMNN_LOG
10 #include <armnn/utility/NumericCast.hpp>  // for numeric_cast
11 #include <armnn/utility/StringUtils.hpp>  // for StringTokenizer
12 #include <armnn/BackendId.hpp>            // for BackendId, BackendIdSet
13 #include <armnn/Optional.hpp>             // for Optional, EmptyOptional
14 #include <armnn/Tensor.hpp>               // for Tensor, TensorInfo
15 #include <armnn/TypesUtils.hpp>           // for Dequantize
16 #include <chrono>                         // for duration
17 #include <functional>                     // for function
18 #include <fstream>
19 #include <iomanip>
20 #include <iostream>                       // for ofstream, basic_istream
21 #include <ratio>                          // for milli
22 #include <string>                         // for string, getline, basic_string
23 #include <type_traits>                    // for enable_if_t, is_floating_point
24 #include <unordered_set>                  // for operator!=, operator==, _No...
25 #include <vector>                         // for vector
26 #include <math.h>                         // for pow, sqrt
27 #include <stdint.h>                       // for int32_t
28 #include <stdio.h>                        // for printf, size_t
29 #include <stdlib.h>                       // for abs
30 #include <algorithm>                      // for find, for_each
31 
/// Compares a measured inference duration against a threshold and tells the user whether the
/// threshold was met.
///
/// @param duration      The measured inference duration (milliseconds, as encoded in the type).
/// @param thresholdTime The threshold time in milliseconds.
/// @return false if the measured time exceeded the threshold.
bool CheckInferenceTimeThreshold(const std::chrono::duration<double, std::milli>& duration, const double& thresholdTime);
41 
CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId> & backendIds,armnn::Optional<std::string &> invalidBackendIds=armnn::EmptyOptional ())42 inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
43                                            armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
44 {
45     if (backendIds.empty())
46     {
47         return false;
48     }
49 
50     armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
51 
52     bool allValid = true;
53     for (const auto& backendId : backendIds)
54     {
55         if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
56         {
57             allValid = false;
58             if (invalidBackendIds)
59             {
60                 if (!invalidBackendIds.value().empty())
61                 {
62                     invalidBackendIds.value() += ", ";
63                 }
64                 invalidBackendIds.value() += backendId;
65             }
66         }
67     }
68     return allValid;
69 }
70 
/// Parses the contents of @p stream into a vector of unsigned integers.
/// (Only declared here; see the definition for the accepted token format.)
std::vector<unsigned int> ParseArray(std::istream& stream);

/// Splits @p inputString at every occurrence of @p delimiter into a vector of strings.
std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter);

/// Computes a byte-level root-mean-square error between the @p expected and @p actual buffers.
/// NOTE(review): @p size is presumably a byte count given the "ByteLevel" name; confirm against
/// the definition.
double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size);
77 
78 /// Dequantize an array of a given type
79 /// @param array Type erased array to dequantize
80 /// @param numElements Elements in the array
81 /// @param array Type erased array to dequantize
82 template <typename T>
DequantizeArray(const void * array,unsigned int numElements,float scale,int32_t offset)83 std::vector<float> DequantizeArray(const void* array, unsigned int numElements, float scale, int32_t offset)
84 {
85     const T* quantizedArray = reinterpret_cast<const T*>(array);
86     std::vector<float> dequantizedVector;
87     dequantizedVector.reserve(numElements);
88     for (unsigned int i = 0; i < numElements; ++i)
89     {
90         float f = armnn::Dequantize(*(quantizedArray + i), scale, offset);
91         dequantizedVector.push_back(f);
92     }
93     return dequantizedVector;
94 }
95 
/// Reports @p eMsg and raises an error. Only the declaration lives here; judging by the name it
/// logs the message and then throws. Confirm the exception type against the definition.
void LogAndThrow(std::string eMsg);

/**
 * Verifies whether the given string is a valid path. Invalid paths are reported as errors
 * (presumably on std::cerr; the definition is not in this header).
 * @param file string - A string containing the path to check.
 * @param expectFile bool - If true, also checks that the path is a regular file.
 * @return bool - True if the given string is a valid path, false otherwise.
 * */
bool ValidatePath(const std::string& file, const bool expectFile);

/**
 * Verifies whether every string in a vector is a valid path. Invalid paths are reported as errors
 * (presumably on std::cerr; the definition is not in this header).
 * @param fileVec vector of string - The paths to check.
 * @param expectFile bool - If true, also checks that each path is a regular file.
 * @return bool - True if all given strings are valid paths, false otherwise.
 * */
bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
113 
114 /// Returns a function of read the given type as a string
115 template <typename Integer, typename std::enable_if_t<std::is_integral<Integer>::value>* = nullptr>
GetParseElementFunc()116 std::function<Integer(const std::string&)> GetParseElementFunc()
117 {
118     return [](const std::string& s) { return armnn::numeric_cast<Integer>(std::stoi(s)); };
119 }
120 
/// Builds a parser that converts a text token into the floating-point type @p Float.
///
/// The conversion routine matches the target precision (std::stof for float, std::stod for
/// double, std::stold for long double). Routing double or long double through std::stof would
/// round them to float precision and reject values outside float's range. Every routine throws
/// std::invalid_argument for non-numeric text and std::out_of_range for unrepresentable values.
template <typename Float, std::enable_if_t<std::is_floating_point<Float>::value>* = nullptr>
std::function<Float(const std::string&)> GetParseElementFunc()
{
    return [](const std::string& s) -> Float
    {
        if (std::is_same<Float, float>::value)
        {
            return static_cast<Float>(std::stof(s));
        }
        if (std::is_same<Float, double>::value)
        {
            return static_cast<Float>(std::stod(s));
        }
        return static_cast<Float>(std::stold(s));
    };
}
126 
127 template <typename T>
PopulateTensorWithData(T * tensor,const unsigned int numElements,const armnn::Optional<std::string> & dataFile,const std::string & inputName)128 void PopulateTensorWithData(T* tensor,
129                             const unsigned int numElements,
130                             const armnn::Optional<std::string>& dataFile,
131                             const std::string& inputName)
132 {
133     const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
134 
135     std::ifstream inputTensorFile;
136     if (!readFromFile)
137     {
138         std::fill(tensor, tensor + numElements, 0);
139         return;
140     }
141     else
142     {
143         inputTensorFile = std::ifstream(dataFile.value());
144     }
145 
146     auto parseElementFunc = GetParseElementFunc<T>();
147     std::string line;
148     unsigned int index = 0;
149     while (std::getline(inputTensorFile, line))
150     {
151         std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, "\t ,:");
152         for (const std::string& token : tokens)
153         {
154             if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
155             {
156                 try
157                 {
158                     if (index == numElements)
159                     {
160                         ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << dataFile.value()
161                                          << "\" does not match number of elements: " << numElements
162                                          << " for input \"" << inputName << "\".";
163                     }
164                     *(tensor + index) = parseElementFunc(token);
165                     index++;
166                 }
167                 catch (const std::exception&)
168                 {
169                     ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
170                 }
171             }
172         }
173     }
174 
175     if (index != numElements)
176     {
177         ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << inputName
178                          << "\" does not match number of elements: " << numElements
179                          << " for input \"" << inputName << "\".";
180     }
181 }
182 
183 template<typename T>
WriteToFile(const std::string & outputTensorFileName,const std::string & outputName,const T * const array,const unsigned int numElements)184 void WriteToFile(const std::string& outputTensorFileName,
185                  const std::string& outputName,
186                  const T* const array,
187                  const unsigned int numElements)
188 {
189     std::ofstream outputTensorFile;
190     outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
191     if (outputTensorFile.is_open())
192     {
193         outputTensorFile << outputName << ": ";
194         for (std::size_t i = 0; i < numElements; ++i)
195         {
196             outputTensorFile << +array[i] << " ";
197         }
198     }
199     else
200     {
201         ARMNN_LOG(info) << "Output Tensor File: " << outputTensorFileName << " could not be opened!";
202     }
203     outputTensorFile.close();
204 }
205 
206 struct OutputWriteInfo
207 {
208     const armnn::Optional<std::string>& m_OutputTensorFile;
209     const std::string& m_OutputName;
210     const armnn::Tensor& m_Tensor;
211     const bool m_PrintTensor;
212 };
213 
214 template <typename T>
PrintTensor(OutputWriteInfo & info,const char * formatString)215 void PrintTensor(OutputWriteInfo& info, const char* formatString)
216 {
217     const T* array = reinterpret_cast<const T*>(info.m_Tensor.GetMemoryArea());
218 
219     if (info.m_OutputTensorFile.has_value())
220     {
221         WriteToFile(info.m_OutputTensorFile.value(),
222                     info.m_OutputName,
223                     array,
224                     info.m_Tensor.GetNumElements());
225     }
226 
227     if (info.m_PrintTensor)
228     {
229         for (unsigned int i = 0; i < info.m_Tensor.GetNumElements(); i++)
230         {
231             printf(formatString, array[i]);
232         }
233     }
234 }
235 
236 template <typename T>
PrintQuantizedTensor(OutputWriteInfo & info)237 void PrintQuantizedTensor(OutputWriteInfo& info)
238 {
239     std::vector<float> dequantizedValues;
240     auto tensor = info.m_Tensor;
241     dequantizedValues = DequantizeArray<T>(tensor.GetMemoryArea(),
242                                            tensor.GetNumElements(),
243                                            tensor.GetInfo().GetQuantizationScale(),
244                                            tensor.GetInfo().GetQuantizationOffset());
245 
246     if (info.m_OutputTensorFile.has_value())
247     {
248         WriteToFile(info.m_OutputTensorFile.value(),
249                     info.m_OutputName,
250                     dequantizedValues.data(),
251                     tensor.GetNumElements());
252     }
253 
254     if (info.m_PrintTensor)
255     {
256         std::for_each(dequantizedValues.begin(), dequantizedValues.end(), [&](float value)
257         {
258             printf("%f ", value);
259         });
260     }
261 }
262 
263 template<typename T, typename TParseElementFunc>
ParseArrayImpl(std::istream & stream,TParseElementFunc parseElementFunc,const char * chars="\\t ,:")264 std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
265 {
266     std::vector<T> result;
267     // Processes line-by-line.
268     std::string line;
269     while (std::getline(stream, line))
270     {
271         std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
272         for (const std::string& token : tokens)
273         {
274             if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
275             {
276                 try
277                 {
278                     result.push_back(parseElementFunc(token));
279                 }
280                 catch (const std::exception&)
281                 {
282                     ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
283                 }
284             }
285         }
286     }
287 
288     return result;
289 }
290