xref: /aosp_15_r20/external/armnn/src/backends/tosaCommon/operatorMappings/TosaOperatorUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <Layer.hpp>
9 #include <armnn/Tensor.hpp>
10 #include <armnn/Types.hpp>
11 
12 #include "common/include/ProfilingGuid.hpp"
13 
14 #include <tosa_serialization_handler.h>
15 
16 using namespace armnn;
17 using namespace tosa;
18 
19 // Function to return Tosa datatype from input ArmNN datatype.
ArmNNToDType(const DataType & type)20 inline DType ArmNNToDType(const DataType& type)
21 {
22     switch (type)
23     {
24         case DataType::Float16:
25         case DataType::BFloat16:
26             return DType_FP16;
27         case DataType::Float32:
28             return DType_FP32;
29         case DataType::QAsymmU8:
30             return DType_UINT8;
31         case DataType::QSymmS8:
32         case DataType::QAsymmS8:
33             return DType_INT8;
34         case DataType::QSymmS16:
35             return DType_INT16;
36         case DataType::Signed32:
37             return DType_INT32;
38         case DataType::Signed64:
39             // No signed 64, only DType_INT48.
40             return DType_UNKNOWN;
41         case DataType::Boolean:
42             return DType_BOOL;
43         default:
44             return DType_UNKNOWN;
45     }
46 }
47 
48 // Function to return Tosa tensor shape from input ArmNN tensor shape.
GetTosaTensorShape(const TensorShape & shape)49 inline std::vector<int32_t> GetTosaTensorShape(const TensorShape& shape)
50 {
51     std::vector<int32_t> returnShape;
52     for (u_int32_t i = 0; i < shape.GetNumDimensions(); i++)
53     {
54         returnShape.push_back(static_cast<int32_t>(shape[i]));
55     }
56     return returnShape;
57 }
58 
59 // Function that generates unique name using the layer type, input slot and layer guid.
GenerateUniqueName(const Layer & layer,uint32_t layerSlot)60 inline std::string GenerateUniqueName(const Layer& layer, uint32_t layerSlot)
61 {
62     std::string guid        = std::to_string(layer.GetGuid());
63     std::string slotAndGuid = std::to_string(layerSlot) + "_" + guid;
64 
65     switch (layer.GetType())
66     {
67         case LayerType::Input:
68             return "input" + slotAndGuid;
69         case LayerType::Output:
70             return "output" + slotAndGuid;
71         case LayerType::Constant:
72             return "constant_" + guid;
73         default:
74             return "intermediate" + slotAndGuid;
75     }
76 }
77 
78 // Function that generates unique output name using the layer type, input slot and layer guid.
GenerateUniqueOutputName(const Layer & layer,uint32_t layerSlot)79 inline std::string GenerateUniqueOutputName(const Layer& layer, uint32_t layerSlot)
80 {
81     Layer& connectedLayer = layer.GetOutputSlot().GetConnection(0)->GetOwningLayer();
82 
83     // Get the layer connected to the output slot, if output use that layer and id,
84     // otherwise use current layer and id.
85     if(connectedLayer.GetType() == LayerType::Output)
86     {
87         return GenerateUniqueName(connectedLayer, layerSlot);
88     }
89     else
90     {
91         return GenerateUniqueName(layer, layerSlot);
92     }
93 }
94 
// Counter backing GetUniqueTosaMappingID(). It starts at 0, so the first ID handed out is "1".
// NOTE(review): 'static' at namespace scope gives this variable internal linkage, so every translation
// unit that includes this header gets its own counter - confirm that per-translation-unit numbering is
// acceptable for the uniqueness the mapping names rely on. Access is not synchronised.
static int uniqueTosaMappingID = 0;

// Function to return the next value of the counter above as a decimal string, used to keep input,
// output and block names distinct from one another.
inline std::string GetUniqueTosaMappingID()
{
    uniqueTosaMappingID += 1;
    return std::to_string(uniqueTosaMappingID);
}
101 
102 // Function to return Tosa DType as string.
TosaDTypeToString(DType tosaDType)103 inline std::string TosaDTypeToString(DType tosaDType)
104 {
105     switch (tosaDType)
106     {
107         case DType_UNKNOWN:
108             return "DType_UNKNOWN";
109         case DType_BOOL:
110             return "DType_BOOL";
111         case DType_UINT8:
112             return "DType_UINT8";
113         case DType_INT4:
114             return "DType_INT4";
115         case DType_INT8:
116             return "DType_INT8";
117         case DType_INT16:
118             return "DType_INT16";
119         case DType_INT32:
120             return "DType_INT32";
121         case DType_INT48:
122             return "DType_INT48";
123         case DType_FP32:
124             return "DType_FP32";
125         case DType_UINT16:
126             return "DType_UINT16";
127         case DType_FP16:
128             return "DType_FP16";
129     }
130     return "";
131 }
132 
133 // Function to return Tosa Op as string.
TosaOpToString(Op tosaOp)134 inline std::string TosaOpToString(Op tosaOp)
135 {
136     switch (tosaOp)
137     {
138         case Op_ADD:
139             return "Op_ADD";
140         case Op_AVG_POOL2D:
141             return "Op_AVG_POOL2D";
142         case Op_MAX_POOL2D:
143             return "Op_MAX_POOL2D";
144         case Op_PAD:
145             return "Op_PAD";
146         case Op_UNKNOWN:
147             return "Op_UNKNOWN";
148         case Op_ARGMAX:
149             return "Op_ARGMAX";
150         case Op_CONV2D:
151             return "Op_CONV2D";
152         case Op_CONV3D:
153             return "Op_CONV3D";
154         case Op_DEPTHWISE_CONV2D:
155             return "Op_DEPTHWISE_CONV2D";
156         case Op_FULLY_CONNECTED:
157             return "Op_FULLY_CONNECTED";
158         case Op_MATMUL:
159             return "Op_MATMUL";
160         case Op_TRANSPOSE_CONV2D:
161             return "Op_TRANSPOSE_CONV2D";
162         case Op_CLAMP:
163             return "Op_CLAMP";
164         case Op_RESERVED:
165             return "Op_RESERVED";
166         case Op_SIGMOID:
167             return "Op_SIGMOID";
168         case Op_TANH:
169             return "Op_TANH";
170         case Op_ARITHMETIC_RIGHT_SHIFT:
171             return "Op_ARITHMETIC_RIGHT_SHIFT";
172         case Op_BITWISE_AND:
173             return "Op_BITWISE_AND";
174         case Op_BITWISE_OR:
175             return "Op_BITWISE_OR";
176         case Op_BITWISE_XOR:
177             return "Op_BITWISE_XOR";
178         case Op_INTDIV:
179             return "Op_INTDIV";
180         case Op_LOGICAL_AND:
181             return "Op_LOGICAL_AND";
182         case Op_LOGICAL_LEFT_SHIFT:
183             return "Op_LOGICAL_LEFT_SHIFT";
184         case Op_LOGICAL_RIGHT_SHIFT:
185             return "Op_LOGICAL_RIGHT_SHIFT";
186         case Op_LOGICAL_OR:
187             return "Op_LOGICAL_OR";
188         case Op_LOGICAL_XOR:
189             return "Op_LOGICAL_XOR";
190         case Op_MAXIMUM:
191             return "Op_MAXIMUM";
192         case Op_MINIMUM:
193             return "Op_MINIMUM";
194         case Op_MUL:
195             return "Op_MUL";
196         case Op_POW:
197             return "Op_POW";
198         case Op_SUB:
199             return "Op_SUB";
200         case Op_TABLE:
201             return "Op_TABLE";
202         case Op_ABS:
203             return "Op_ABS";
204         case Op_BITWISE_NOT:
205             return "Op_BITWISE_NOT";
206         case Op_CEIL:
207             return "Op_CEIL";
208         case Op_CLZ:
209             return "Op_CLZ";
210         case Op_EXP:
211             return "Op_EXP";
212         case Op_FLOOR:
213             return "Op_FLOOR";
214         case Op_LOG:
215             return "Op_LOG";
216         case Op_LOGICAL_NOT:
217             return "Op_LOGICAL_NOT";
218         case Op_NEGATE:
219             return "Op_NEGATE";
220         case Op_RECIPROCAL:
221             return "Op_RECIPROCAL";
222         case Op_RSQRT:
223             return "Op_RSQRT";
224         case Op_SELECT:
225             return "Op_SELECT";
226         case Op_EQUAL:
227             return "Op_EQUAL";
228         case Op_GREATER:
229             return "Op_GREATER";
230         case Op_GREATER_EQUAL:
231             return "Op_GREATER_EQUAL";
232         case Op_REDUCE_ANY:
233             return "Op_REDUCE_ANY";
234         case Op_REDUCE_ALL:
235             return "Op_REDUCE_ALL";
236         case Op_REDUCE_MAX:
237             return "Op_REDUCE_MAX";
238         case Op_REDUCE_MIN:
239             return "Op_REDUCE_MIN";
240         case Op_REDUCE_PRODUCT:
241             return "Op_REDUCE_PRODUCT";
242         case Op_REDUCE_SUM:
243             return "Op_REDUCE_SUM";
244         case Op_CONCAT:
245             return "Op_CONCAT";
246         case Op_RESHAPE:
247             return "Op_RESHAPE";
248         case Op_REVERSE:
249             return "Op_REVERSE";
250         case Op_SLICE:
251             return "Op_SLICE";
252         case Op_TILE:
253             return "Op_TILE";
254         case Op_TRANSPOSE:
255             return "Op_TRANSPOSE";
256         case Op_GATHER:
257             return "Op_GATHER";
258         case Op_SCATTER:
259             return "Op_SCATTER";
260         case Op_RESIZE:
261             return "Op_RESIZE";
262         case Op_CAST:
263             return "Op_CAST";
264         case Op_RESCALE:
265             return "Op_RESCALE";
266         case Op_CONST:
267             return "Op_CONST";
268         case Op_IDENTITY:
269             return "Op_IDENTITY";
270         case Op_CUSTOM:
271             return "Op_CUSTOM";
272         case Op_COND_IF:
273             return "Op_COND_IF";
274         case Op_WHILE_LOOP:
275             return "Op_WHILE_LOOP";
276     }
277     return "";
278 }
279 
ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle> & tensorHandle)280 inline std::vector<uint8_t> ConvertConstantTensorDataToBuffer(const std::shared_ptr<ConstTensorHandle>& tensorHandle)
281 {
282     tosa_err_t error = tosa_err_t::TOSA_OK;
283     std::vector<uint8_t> uint8Data;
284     auto tensorInfo = tensorHandle->GetTensorInfo();
285 
286     switch (tensorInfo.GetDataType())
287     {
288         case DataType::Float32:
289         {
290             std::vector<float> data(tensorInfo.GetNumElements());
291             memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
292 
293             error = TosaSerializationHandler::ConvertF32toU8(data, uint8Data);
294             break;
295         }
296         case DataType::Float16:
297         {
298             std::vector<float> data(tensorInfo.GetNumElements());
299             memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
300 
301             error = TosaSerializationHandler::ConvertF16toU8(data, uint8Data);
302             break;
303         }
304         case DataType::QSymmS8:
305         case DataType::QAsymmS8:
306         {
307             std::vector<int8_t> data(tensorInfo.GetNumElements());
308             memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
309 
310             error = TosaSerializationHandler::ConvertI8toU8(data, uint8Data);
311             break;
312         }
313         case DataType::QAsymmU8:
314         {
315             memcpy(uint8Data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
316             break;
317         }
318         case DataType::QSymmS16:
319         {
320             std::vector<int16_t> data(tensorInfo.GetNumElements());
321             memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
322 
323             error = TosaSerializationHandler::ConvertI16toU8(data, uint8Data);
324             break;
325         }
326         case DataType::Signed32:
327         {
328             std::vector<int32_t> data(tensorInfo.GetNumElements());
329             memcpy(data.data(), tensorHandle->Map(true), tensorInfo.GetNumBytes());
330 
331             error = TosaSerializationHandler::ConvertI32toU8(data, uint8Data);
332             break;
333         }
334         default:
335         {
336             throw armnn::Exception("SetConstantTensorData: An unsupported data type was encountered.");
337         }
338     }
339 
340     if(error != tosa_err_t::TOSA_OK)
341     {
342         throw armnn::Exception("SetConstantTensorData: An error occurred when converting constant data");
343     }
344 
345     tensorHandle->Unmap();
346     return uint8Data;
347 }
348