xref: /aosp_15_r20/external/executorch/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Qualcomm Innovation Center, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 #pragma once
9 
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>
13 
14 #include "QnnTypes.h"
15 namespace executorch {
16 namespace backends {
17 namespace qnn {
18 class QuantizeParamsWrapper {
19  public:
20   // To create the QuantizeParams_t using data from this class:
21   virtual Qnn_QuantizeParams_t CreateQuantizeParams() = 0;
22   // Other accessors:
GetEncodingDefinition()23   Qnn_Definition_t GetEncodingDefinition() const {
24     return encoding_definition_;
25   };
GetQuantizationEncoding()26   Qnn_QuantizationEncoding_t GetQuantizationEncoding() const {
27     return quantization_encoding_;
28   };
29 
30   virtual std::unique_ptr<QuantizeParamsWrapper> Clone() = 0;
31   virtual ~QuantizeParamsWrapper() = default;
32 
33   QuantizeParamsWrapper(QuantizeParamsWrapper&& rhs) = default;
34   QuantizeParamsWrapper(const QuantizeParamsWrapper& rhs) = default;
35   QuantizeParamsWrapper& operator=(const QuantizeParamsWrapper& rhs) = default;
36   QuantizeParamsWrapper& operator=(QuantizeParamsWrapper&& rhs) = default;
37 
38  protected:
QuantizeParamsWrapper(Qnn_Definition_t encoding_definition,Qnn_QuantizationEncoding_t quantization_encoding)39   explicit QuantizeParamsWrapper(
40       Qnn_Definition_t encoding_definition,
41       Qnn_QuantizationEncoding_t quantization_encoding)
42       : encoding_definition_(encoding_definition),
43         quantization_encoding_(quantization_encoding) {}
44 
45  private:
46   Qnn_Definition_t encoding_definition_;
47   Qnn_QuantizationEncoding_t quantization_encoding_;
48 };
49 
50 class UndefinedQuantizeParamsWrapper final : public QuantizeParamsWrapper {
51  public:
UndefinedQuantizeParamsWrapper()52   UndefinedQuantizeParamsWrapper()
53       : QuantizeParamsWrapper(
54             QNN_DEFINITION_UNDEFINED,
55             QNN_QUANTIZATION_ENCODING_UNDEFINED) {}
UndefinedQuantizeParamsWrapper(const UndefinedQuantizeParamsWrapper & rhs)56   UndefinedQuantizeParamsWrapper(const UndefinedQuantizeParamsWrapper& rhs)
57       : QuantizeParamsWrapper(
58             rhs.GetEncodingDefinition(),
59             rhs.GetQuantizationEncoding()) {}
60   UndefinedQuantizeParamsWrapper(UndefinedQuantizeParamsWrapper&& rhs) = delete;
61   UndefinedQuantizeParamsWrapper& operator=(
62       const UndefinedQuantizeParamsWrapper& rhs) = delete;
63   UndefinedQuantizeParamsWrapper& operator=(
64       UndefinedQuantizeParamsWrapper&& rhs) = delete;
65 
66   ~UndefinedQuantizeParamsWrapper() override = default;
67 
Clone()68   std::unique_ptr<QuantizeParamsWrapper> Clone() override {
69     return std::make_unique<UndefinedQuantizeParamsWrapper>(*this);
70   }
71 
CreateQuantizeParams()72   Qnn_QuantizeParams_t CreateQuantizeParams() override {
73     Qnn_QuantizeParams_t rval = {
74         .encodingDefinition = GetEncodingDefinition(),
75         .quantizationEncoding = GetQuantizationEncoding()};
76     return rval;
77   }
78 };
79 
80 class BwAxisScaleOffsetQuantizeParamsWrapper final
81     : public QuantizeParamsWrapper {
82  public:
BwAxisScaleOffsetQuantizeParamsWrapper(std::uint32_t bitwidth,std::int32_t axis,std::uint32_t num_elements,std::vector<float> scales,std::vector<int32_t> offsets)83   explicit BwAxisScaleOffsetQuantizeParamsWrapper(
84       std::uint32_t bitwidth,
85       std::int32_t axis,
86       std::uint32_t num_elements,
87       std::vector<float> scales,
88       std::vector<int32_t> offsets)
89       : QuantizeParamsWrapper(
90             QNN_DEFINITION_DEFINED,
91             QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET),
92         bitwidth_(bitwidth),
93         axis_(axis),
94         num_elements_(num_elements),
95         scales_(scales),
96         offsets_(offsets) {}
97 
BwAxisScaleOffsetQuantizeParamsWrapper(const BwAxisScaleOffsetQuantizeParamsWrapper & rhs)98   BwAxisScaleOffsetQuantizeParamsWrapper(
99       const BwAxisScaleOffsetQuantizeParamsWrapper& rhs)
100       : QuantizeParamsWrapper(
101             rhs.GetEncodingDefinition(),
102             rhs.GetQuantizationEncoding()),
103         bitwidth_(rhs.bitwidth_),
104         axis_(rhs.axis_),
105         num_elements_(rhs.num_elements_),
106         scales_(rhs.scales_),
107         offsets_(rhs.offsets_) {}
108   BwAxisScaleOffsetQuantizeParamsWrapper(
109       BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
110   BwAxisScaleOffsetQuantizeParamsWrapper& operator=(
111       const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) = delete;
112   BwAxisScaleOffsetQuantizeParamsWrapper& operator=(
113       BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
114 
115   ~BwAxisScaleOffsetQuantizeParamsWrapper() override = default;
116 
Clone()117   std::unique_ptr<QuantizeParamsWrapper> Clone() override {
118     return std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(*this);
119   }
120 
CreateQuantizeParams()121   Qnn_QuantizeParams_t CreateQuantizeParams() override {
122     Qnn_QuantizeParams_t rval;
123     rval.encodingDefinition = GetEncodingDefinition();
124     rval.quantizationEncoding = GetQuantizationEncoding();
125     rval.bwAxisScaleOffsetEncoding.bitwidth = bitwidth_;
126     rval.bwAxisScaleOffsetEncoding.axis = axis_;
127     rval.bwAxisScaleOffsetEncoding.numElements = num_elements_;
128     rval.bwAxisScaleOffsetEncoding.scales = scales_.data();
129     rval.bwAxisScaleOffsetEncoding.offsets = offsets_.data();
130     return rval;
131   }
132 
133  private:
134   std::uint32_t bitwidth_;
135   std::int32_t axis_;
136   std::uint32_t num_elements_;
137   std::vector<float> scales_;
138   std::vector<int32_t> offsets_;
139 };
140 
141 class BwScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper {
142  public:
BwScaleOffsetQuantizeParamsWrapper(std::uint32_t bitwidth,float scale,std::int32_t offset)143   explicit BwScaleOffsetQuantizeParamsWrapper(
144       std::uint32_t bitwidth,
145       float scale,
146       std::int32_t offset)
147       : QuantizeParamsWrapper(
148             QNN_DEFINITION_DEFINED,
149             QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET),
150         bitwidth_(bitwidth),
151         scale_(scale),
152         offset_(offset) {}
153 
BwScaleOffsetQuantizeParamsWrapper(const BwScaleOffsetQuantizeParamsWrapper & rhs)154   BwScaleOffsetQuantizeParamsWrapper(
155       const BwScaleOffsetQuantizeParamsWrapper& rhs)
156       : QuantizeParamsWrapper(
157             rhs.GetEncodingDefinition(),
158             rhs.GetQuantizationEncoding()),
159         bitwidth_(rhs.bitwidth_),
160         scale_(rhs.scale_),
161         offset_(rhs.offset_) {}
162   BwScaleOffsetQuantizeParamsWrapper(BwScaleOffsetQuantizeParamsWrapper&& rhs) =
163       delete;
164   BwScaleOffsetQuantizeParamsWrapper& operator=(
165       const BwScaleOffsetQuantizeParamsWrapper& rhs) = delete;
166   BwScaleOffsetQuantizeParamsWrapper& operator=(
167       BwScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
168 
169   ~BwScaleOffsetQuantizeParamsWrapper() override = default;
170 
Clone()171   std::unique_ptr<QuantizeParamsWrapper> Clone() override {
172     return std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(*this);
173   }
174 
CreateQuantizeParams()175   Qnn_QuantizeParams_t CreateQuantizeParams() override {
176     Qnn_QuantizeParams_t rval;
177     rval.encodingDefinition = GetEncodingDefinition();
178     rval.quantizationEncoding = GetQuantizationEncoding();
179     rval.bwScaleOffsetEncoding.bitwidth = bitwidth_;
180     rval.bwScaleOffsetEncoding.scale = scale_;
181     rval.bwScaleOffsetEncoding.offset = offset_;
182     return rval;
183   }
184 
185  private:
186   std::uint32_t bitwidth_;
187   float scale_;
188   std::int32_t offset_;
189 };
190 
191 class ScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper {
192  public:
ScaleOffsetQuantizeParamsWrapper(float scale,std::int32_t offset)193   explicit ScaleOffsetQuantizeParamsWrapper(float scale, std::int32_t offset)
194       : QuantizeParamsWrapper(
195             QNN_DEFINITION_DEFINED,
196             QNN_QUANTIZATION_ENCODING_SCALE_OFFSET),
197         scale_(scale),
198         offset_(offset) {}
199 
ScaleOffsetQuantizeParamsWrapper(const ScaleOffsetQuantizeParamsWrapper & rhs)200   ScaleOffsetQuantizeParamsWrapper(const ScaleOffsetQuantizeParamsWrapper& rhs)
201       : QuantizeParamsWrapper(
202             rhs.GetEncodingDefinition(),
203             rhs.GetQuantizationEncoding()),
204         scale_(rhs.scale_),
205         offset_(rhs.offset_) {}
206   ScaleOffsetQuantizeParamsWrapper(ScaleOffsetQuantizeParamsWrapper&& rhs) =
207       delete;
208   ScaleOffsetQuantizeParamsWrapper& operator=(
209       const ScaleOffsetQuantizeParamsWrapper& rhs) = delete;
210   ScaleOffsetQuantizeParamsWrapper& operator=(
211       ScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
212 
213   ~ScaleOffsetQuantizeParamsWrapper() override = default;
214 
Clone()215   std::unique_ptr<QuantizeParamsWrapper> Clone() override {
216     return std::make_unique<ScaleOffsetQuantizeParamsWrapper>(*this);
217   }
218 
CreateQuantizeParams()219   Qnn_QuantizeParams_t CreateQuantizeParams() override {
220     Qnn_QuantizeParams_t rval;
221     rval.encodingDefinition = GetEncodingDefinition();
222     rval.quantizationEncoding = GetQuantizationEncoding();
223     rval.scaleOffsetEncoding.scale = scale_;
224     rval.scaleOffsetEncoding.offset = offset_;
225     return rval;
226   }
227 
228  private:
229   float scale_;
230   std::int32_t offset_;
231 };
232 
233 class AxisScaleOffsetQuantizeParamsWrapper final
234     : public QuantizeParamsWrapper {
235  public:
AxisScaleOffsetQuantizeParamsWrapper(std::int32_t axis,const std::vector<Qnn_ScaleOffset_t> & scale_offsets)236   explicit AxisScaleOffsetQuantizeParamsWrapper(
237       std::int32_t axis,
238       const std::vector<Qnn_ScaleOffset_t>& scale_offsets)
239       : QuantizeParamsWrapper(
240             QNN_DEFINITION_DEFINED,
241             QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET),
242         axis_(axis),
243         scale_offsets_(scale_offsets) {}
244 
AxisScaleOffsetQuantizeParamsWrapper(const AxisScaleOffsetQuantizeParamsWrapper & rhs)245   AxisScaleOffsetQuantizeParamsWrapper(
246       const AxisScaleOffsetQuantizeParamsWrapper& rhs)
247       : QuantizeParamsWrapper(
248             rhs.GetEncodingDefinition(),
249             rhs.GetQuantizationEncoding()),
250         axis_(rhs.axis_),
251         scale_offsets_(rhs.scale_offsets_) {}
252   AxisScaleOffsetQuantizeParamsWrapper(
253       AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
254   AxisScaleOffsetQuantizeParamsWrapper& operator=(
255       const AxisScaleOffsetQuantizeParamsWrapper& rhs) = delete;
256   AxisScaleOffsetQuantizeParamsWrapper& operator=(
257       AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete;
258 
259   ~AxisScaleOffsetQuantizeParamsWrapper() override = default;
260 
SetAxis(std::int32_t axis)261   void SetAxis(std::int32_t axis) {
262     axis_ = axis;
263   }
264 
Clone()265   std::unique_ptr<QuantizeParamsWrapper> Clone() override {
266     return std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(*this);
267   }
268 
CreateQuantizeParams()269   Qnn_QuantizeParams_t CreateQuantizeParams() override {
270     Qnn_QuantizeParams_t rval;
271     rval.encodingDefinition = GetEncodingDefinition();
272     rval.quantizationEncoding = GetQuantizationEncoding();
273     rval.axisScaleOffsetEncoding.axis = axis_;
274     rval.axisScaleOffsetEncoding.numScaleOffsets = scale_offsets_.size();
275     rval.axisScaleOffsetEncoding.scaleOffset = scale_offsets_.data();
276     return rval;
277   }
278 
279  private:
280   std::int32_t axis_;
281   std::vector<Qnn_ScaleOffset_t> scale_offsets_;
282 };
283 
284 // Factory function to create quantization param wrapper from QnnQuantization
285 std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
286     const Qnn_QuantizeParams_t& quantization);
287 } // namespace qnn
288 } // namespace backends
289 } // namespace executorch
290