xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Encoders.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #pragma once
7 
8 #include "BaseIterator.hpp"
9 
10 #include <armnnUtils/TensorUtils.hpp>
11 
12 #include <armnn/utility/Assert.hpp>
13 
14 namespace armnn
15 {
16 
/// Primary template: builds an Encoder that writes values of type T into the raw
/// buffer 'data', dispatching on info.GetDataType(). Only declared here — the
/// explicit specializations below (float, bool, int32_t) provide definitions.
template<typename T>
inline std::unique_ptr<Encoder<T>> MakeEncoder(const TensorInfo& info, void* data = nullptr);
19 
20 template<>
MakeEncoder(const TensorInfo & info,void * data)21 inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* data)
22 {
23     switch(info.GetDataType())
24     {
25         case armnn::DataType::QAsymmS8:
26         {
27             return std::make_unique<QASymmS8Encoder>(
28                 static_cast<int8_t*>(data),
29                 info.GetQuantizationScale(),
30                 info.GetQuantizationOffset());
31         }
32         case armnn::DataType::QAsymmU8:
33         {
34             return std::make_unique<QASymm8Encoder>(
35                 static_cast<uint8_t*>(data),
36                 info.GetQuantizationScale(),
37                 info.GetQuantizationOffset());
38         }
39         case DataType::QSymmS8:
40         {
41             if (info.HasPerAxisQuantization())
42             {
43                 std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
44                 return std::make_unique<QSymm8PerAxisEncoder>(
45                         static_cast<int8_t*>(data),
46                         params.second,
47                         params.first);
48             }
49             else
50             {
51                 return std::make_unique<QSymmS8Encoder>(
52                         static_cast<int8_t*>(data),
53                         info.GetQuantizationScale(),
54                         info.GetQuantizationOffset());
55             }
56         }
57         case armnn::DataType::QSymmS16:
58         {
59             return std::make_unique<QSymm16Encoder>(
60                 static_cast<int16_t*>(data),
61                 info.GetQuantizationScale(),
62                 info.GetQuantizationOffset());
63         }
64         case armnn::DataType::Signed32:
65         {
66             return std::make_unique<Int32Encoder>(static_cast<int32_t*>(data));
67         }
68         case armnn::DataType::Float16:
69         {
70             return std::make_unique<Float16Encoder>(static_cast<Half*>(data));
71         }
72         case armnn::DataType::Float32:
73         {
74             return std::make_unique<Float32Encoder>(static_cast<float*>(data));
75         }
76         default:
77         {
78             ARMNN_ASSERT_MSG(false, "Unsupported target Data Type!");
79             break;
80         }
81     }
82     return nullptr;
83 }
84 
85 template<>
MakeEncoder(const TensorInfo & info,void * data)86 inline std::unique_ptr<Encoder<bool>> MakeEncoder(const TensorInfo& info, void* data)
87 {
88     switch(info.GetDataType())
89     {
90         case armnn::DataType::Boolean:
91         {
92             return std::make_unique<BooleanEncoder>(static_cast<uint8_t*>(data));
93         }
94         default:
95         {
96             ARMNN_ASSERT_MSG(false, "Cannot encode from boolean. Not supported target Data Type!");
97             break;
98         }
99     }
100     return nullptr;
101 }
102 
103 template<>
MakeEncoder(const TensorInfo & info,void * data)104 inline std::unique_ptr<Encoder<int32_t>> MakeEncoder(const TensorInfo& info, void* data)
105 {
106     switch(info.GetDataType())
107     {
108         case DataType::Signed32:
109         {
110             return std::make_unique<Int32ToInt32tEncoder>(static_cast<int32_t*>(data));
111         }
112         default:
113         {
114             ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
115             break;
116         }
117     }
118     return nullptr;
119 }
120 
121 } //namespace armnn
122