1 /* 2 * Copyright (c) Qualcomm Innovation Center, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 #pragma once 9 10 #include <cstdint> 11 #include <memory> 12 #include <vector> 13 14 #include "QnnTypes.h" 15 namespace executorch { 16 namespace backends { 17 namespace qnn { 18 class QuantizeParamsWrapper { 19 public: 20 // To create the QuantizeParams_t using data from this class: 21 virtual Qnn_QuantizeParams_t CreateQuantizeParams() = 0; 22 // Other accessors: GetEncodingDefinition()23 Qnn_Definition_t GetEncodingDefinition() const { 24 return encoding_definition_; 25 }; GetQuantizationEncoding()26 Qnn_QuantizationEncoding_t GetQuantizationEncoding() const { 27 return quantization_encoding_; 28 }; 29 30 virtual std::unique_ptr<QuantizeParamsWrapper> Clone() = 0; 31 virtual ~QuantizeParamsWrapper() = default; 32 33 QuantizeParamsWrapper(QuantizeParamsWrapper&& rhs) = default; 34 QuantizeParamsWrapper(const QuantizeParamsWrapper& rhs) = default; 35 QuantizeParamsWrapper& operator=(const QuantizeParamsWrapper& rhs) = default; 36 QuantizeParamsWrapper& operator=(QuantizeParamsWrapper&& rhs) = default; 37 38 protected: QuantizeParamsWrapper(Qnn_Definition_t encoding_definition,Qnn_QuantizationEncoding_t quantization_encoding)39 explicit QuantizeParamsWrapper( 40 Qnn_Definition_t encoding_definition, 41 Qnn_QuantizationEncoding_t quantization_encoding) 42 : encoding_definition_(encoding_definition), 43 quantization_encoding_(quantization_encoding) {} 44 45 private: 46 Qnn_Definition_t encoding_definition_; 47 Qnn_QuantizationEncoding_t quantization_encoding_; 48 }; 49 50 class UndefinedQuantizeParamsWrapper final : public QuantizeParamsWrapper { 51 public: UndefinedQuantizeParamsWrapper()52 UndefinedQuantizeParamsWrapper() 53 : QuantizeParamsWrapper( 54 QNN_DEFINITION_UNDEFINED, 55 QNN_QUANTIZATION_ENCODING_UNDEFINED) {} UndefinedQuantizeParamsWrapper(const UndefinedQuantizeParamsWrapper & rhs)56 UndefinedQuantizeParamsWrapper(const UndefinedQuantizeParamsWrapper& rhs) 57 : QuantizeParamsWrapper( 58 rhs.GetEncodingDefinition(), 59 rhs.GetQuantizationEncoding()) {} 60 UndefinedQuantizeParamsWrapper(UndefinedQuantizeParamsWrapper&& rhs) = delete; 61 UndefinedQuantizeParamsWrapper& operator=( 62 const UndefinedQuantizeParamsWrapper& rhs) = delete; 63 UndefinedQuantizeParamsWrapper& operator=( 64 UndefinedQuantizeParamsWrapper&& rhs) = delete; 65 66 ~UndefinedQuantizeParamsWrapper() override = default; 67 Clone()68 std::unique_ptr<QuantizeParamsWrapper> Clone() override { 69 return std::make_unique<UndefinedQuantizeParamsWrapper>(*this); 70 } 71 CreateQuantizeParams()72 Qnn_QuantizeParams_t CreateQuantizeParams() override { 73 Qnn_QuantizeParams_t rval = { 74 .encodingDefinition = GetEncodingDefinition(), 75 .quantizationEncoding = GetQuantizationEncoding()}; 76 return rval; 77 } 78 }; 79 80 class BwAxisScaleOffsetQuantizeParamsWrapper final 81 : public QuantizeParamsWrapper { 82 public: BwAxisScaleOffsetQuantizeParamsWrapper(std::uint32_t bitwidth,std::int32_t axis,std::uint32_t num_elements,std::vector<float> scales,std::vector<int32_t> offsets)83 explicit BwAxisScaleOffsetQuantizeParamsWrapper( 84 std::uint32_t bitwidth, 85 std::int32_t axis, 86 std::uint32_t num_elements, 87 std::vector<float> scales, 88 std::vector<int32_t> offsets) 89 : QuantizeParamsWrapper( 90 QNN_DEFINITION_DEFINED, 91 QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET), 92 bitwidth_(bitwidth), 93 axis_(axis), 94 num_elements_(num_elements), 95 scales_(scales), 96 offsets_(offsets) {} 97 BwAxisScaleOffsetQuantizeParamsWrapper(const BwAxisScaleOffsetQuantizeParamsWrapper & rhs)98 BwAxisScaleOffsetQuantizeParamsWrapper( 99 const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) 100 : QuantizeParamsWrapper( 101 rhs.GetEncodingDefinition(), 102 rhs.GetQuantizationEncoding()), 103 bitwidth_(rhs.bitwidth_), 104 axis_(rhs.axis_), 105 num_elements_(rhs.num_elements_), 106 scales_(rhs.scales_), 107 offsets_(rhs.offsets_) {} 108 BwAxisScaleOffsetQuantizeParamsWrapper( 109 BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; 110 BwAxisScaleOffsetQuantizeParamsWrapper& operator=( 111 const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) = delete; 112 BwAxisScaleOffsetQuantizeParamsWrapper& operator=( 113 BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; 114 115 ~BwAxisScaleOffsetQuantizeParamsWrapper() override = default; 116 Clone()117 std::unique_ptr<QuantizeParamsWrapper> Clone() override { 118 return std::make_unique<BwAxisScaleOffsetQuantizeParamsWrapper>(*this); 119 } 120 CreateQuantizeParams()121 Qnn_QuantizeParams_t CreateQuantizeParams() override { 122 Qnn_QuantizeParams_t rval; 123 rval.encodingDefinition = GetEncodingDefinition(); 124 rval.quantizationEncoding = GetQuantizationEncoding(); 125 rval.bwAxisScaleOffsetEncoding.bitwidth = bitwidth_; 126 rval.bwAxisScaleOffsetEncoding.axis = axis_; 127 rval.bwAxisScaleOffsetEncoding.numElements = num_elements_; 128 rval.bwAxisScaleOffsetEncoding.scales = scales_.data(); 129 rval.bwAxisScaleOffsetEncoding.offsets = offsets_.data(); 130 return rval; 131 } 132 133 private: 134 std::uint32_t bitwidth_; 135 std::int32_t axis_; 136 std::uint32_t num_elements_; 137 std::vector<float> scales_; 138 std::vector<int32_t> offsets_; 139 }; 140 141 class BwScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper { 142 public: BwScaleOffsetQuantizeParamsWrapper(std::uint32_t bitwidth,float scale,std::int32_t offset)143 explicit BwScaleOffsetQuantizeParamsWrapper( 144 std::uint32_t bitwidth, 145 float scale, 146 std::int32_t offset) 147 : QuantizeParamsWrapper( 148 QNN_DEFINITION_DEFINED, 149 QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET), 150 bitwidth_(bitwidth), 151 scale_(scale), 152 offset_(offset) {} 153 BwScaleOffsetQuantizeParamsWrapper(const BwScaleOffsetQuantizeParamsWrapper & rhs)154 BwScaleOffsetQuantizeParamsWrapper( 155 const BwScaleOffsetQuantizeParamsWrapper& rhs) 156 : QuantizeParamsWrapper( 157 rhs.GetEncodingDefinition(), 158 rhs.GetQuantizationEncoding()), 159 bitwidth_(rhs.bitwidth_), 160 scale_(rhs.scale_), 161 offset_(rhs.offset_) {} 162 BwScaleOffsetQuantizeParamsWrapper(BwScaleOffsetQuantizeParamsWrapper&& rhs) = 163 delete; 164 BwScaleOffsetQuantizeParamsWrapper& operator=( 165 const BwScaleOffsetQuantizeParamsWrapper& rhs) = delete; 166 BwScaleOffsetQuantizeParamsWrapper& operator=( 167 BwScaleOffsetQuantizeParamsWrapper&& rhs) = delete; 168 169 ~BwScaleOffsetQuantizeParamsWrapper() override = default; 170 Clone()171 std::unique_ptr<QuantizeParamsWrapper> Clone() override { 172 return std::make_unique<BwScaleOffsetQuantizeParamsWrapper>(*this); 173 } 174 CreateQuantizeParams()175 Qnn_QuantizeParams_t CreateQuantizeParams() override { 176 Qnn_QuantizeParams_t rval; 177 rval.encodingDefinition = GetEncodingDefinition(); 178 rval.quantizationEncoding = GetQuantizationEncoding(); 179 rval.bwScaleOffsetEncoding.bitwidth = bitwidth_; 180 rval.bwScaleOffsetEncoding.scale = scale_; 181 rval.bwScaleOffsetEncoding.offset = offset_; 182 return rval; 183 } 184 185 private: 186 std::uint32_t bitwidth_; 187 float scale_; 188 std::int32_t offset_; 189 }; 190 191 class ScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper { 192 public: ScaleOffsetQuantizeParamsWrapper(float scale,std::int32_t offset)193 explicit ScaleOffsetQuantizeParamsWrapper(float scale, std::int32_t offset) 194 : QuantizeParamsWrapper( 195 QNN_DEFINITION_DEFINED, 196 QNN_QUANTIZATION_ENCODING_SCALE_OFFSET), 197 scale_(scale), 198 offset_(offset) {} 199 ScaleOffsetQuantizeParamsWrapper(const ScaleOffsetQuantizeParamsWrapper & rhs)200 ScaleOffsetQuantizeParamsWrapper(const ScaleOffsetQuantizeParamsWrapper& rhs) 201 : QuantizeParamsWrapper( 202 rhs.GetEncodingDefinition(), 203 rhs.GetQuantizationEncoding()), 204 scale_(rhs.scale_), 205 offset_(rhs.offset_) {} 206 ScaleOffsetQuantizeParamsWrapper(ScaleOffsetQuantizeParamsWrapper&& rhs) = 207 delete; 208 ScaleOffsetQuantizeParamsWrapper& operator=( 209 const ScaleOffsetQuantizeParamsWrapper& rhs) = delete; 210 ScaleOffsetQuantizeParamsWrapper& operator=( 211 ScaleOffsetQuantizeParamsWrapper&& rhs) = delete; 212 213 ~ScaleOffsetQuantizeParamsWrapper() override = default; 214 Clone()215 std::unique_ptr<QuantizeParamsWrapper> Clone() override { 216 return std::make_unique<ScaleOffsetQuantizeParamsWrapper>(*this); 217 } 218 CreateQuantizeParams()219 Qnn_QuantizeParams_t CreateQuantizeParams() override { 220 Qnn_QuantizeParams_t rval; 221 rval.encodingDefinition = GetEncodingDefinition(); 222 rval.quantizationEncoding = GetQuantizationEncoding(); 223 rval.scaleOffsetEncoding.scale = scale_; 224 rval.scaleOffsetEncoding.offset = offset_; 225 return rval; 226 } 227 228 private: 229 float scale_; 230 std::int32_t offset_; 231 }; 232 233 class AxisScaleOffsetQuantizeParamsWrapper final 234 : public QuantizeParamsWrapper { 235 public: AxisScaleOffsetQuantizeParamsWrapper(std::int32_t axis,const std::vector<Qnn_ScaleOffset_t> & scale_offsets)236 explicit AxisScaleOffsetQuantizeParamsWrapper( 237 std::int32_t axis, 238 const std::vector<Qnn_ScaleOffset_t>& scale_offsets) 239 : QuantizeParamsWrapper( 240 QNN_DEFINITION_DEFINED, 241 QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET), 242 axis_(axis), 243 scale_offsets_(scale_offsets) {} 244 AxisScaleOffsetQuantizeParamsWrapper(const AxisScaleOffsetQuantizeParamsWrapper & rhs)245 AxisScaleOffsetQuantizeParamsWrapper( 246 const AxisScaleOffsetQuantizeParamsWrapper& rhs) 247 : QuantizeParamsWrapper( 248 rhs.GetEncodingDefinition(), 249 rhs.GetQuantizationEncoding()), 250 axis_(rhs.axis_), 251 scale_offsets_(rhs.scale_offsets_) {} 252 AxisScaleOffsetQuantizeParamsWrapper( 253 AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; 254 AxisScaleOffsetQuantizeParamsWrapper& operator=( 255 const AxisScaleOffsetQuantizeParamsWrapper& rhs) = delete; 256 AxisScaleOffsetQuantizeParamsWrapper& operator=( 257 AxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; 258 259 ~AxisScaleOffsetQuantizeParamsWrapper() override = default; 260 SetAxis(std::int32_t axis)261 void SetAxis(std::int32_t axis) { 262 axis_ = axis; 263 } 264 Clone()265 std::unique_ptr<QuantizeParamsWrapper> Clone() override { 266 return std::make_unique<AxisScaleOffsetQuantizeParamsWrapper>(*this); 267 } 268 CreateQuantizeParams()269 Qnn_QuantizeParams_t CreateQuantizeParams() override { 270 Qnn_QuantizeParams_t rval; 271 rval.encodingDefinition = GetEncodingDefinition(); 272 rval.quantizationEncoding = GetQuantizationEncoding(); 273 rval.axisScaleOffsetEncoding.axis = axis_; 274 rval.axisScaleOffsetEncoding.numScaleOffsets = scale_offsets_.size(); 275 rval.axisScaleOffsetEncoding.scaleOffset = scale_offsets_.data(); 276 return rval; 277 } 278 279 private: 280 std::int32_t axis_; 281 std::vector<Qnn_ScaleOffset_t> scale_offsets_; 282 }; 283 284 // Factory function to create quantization param wrapper from QnnQuantization 285 std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper( 286 const Qnn_QuantizeParams_t& quantization); 287 } // namespace qnn 288 } // namespace backends 289 } // namespace executorch 290