#pragma once

#include <cstdint>
#include <cstdlib>

#include <qnnpack/operator.h>

namespace qnnpack {

// Packs convolution weights and bias into the layout expected by the
// QNNPACK convolution kernels. The packed buffer is owned by this class
// and released with free() on destruction.
class PrePackConvWeights final {
 public:
  PrePackConvWeights(
      const pytorch_qnnp_operator_t convolution,
      const uint8_t* kernel_zero_points,
      const uint8_t* kernel,
      const int32_t* bias);

  void* getPackedWeights() const
  {
    return packed_weights_;
  }

  int64_t getOutputChannels() const
  {
    return output_channels_;
  }

  ~PrePackConvWeights()
  {
    if (packed_weights_ != nullptr) {
      free(packed_weights_);
    }
  }

  // Non-copyable: the class owns a raw heap allocation.
  PrePackConvWeights() = delete;
  PrePackConvWeights(const PrePackConvWeights&) = delete;
  PrePackConvWeights& operator=(const PrePackConvWeights&) = delete;

 private:
  void* packed_weights_ = nullptr;
  int64_t output_channels_;
};

// Packs a fully-connected (linear) weight matrix and bias for the
// QNNPACK GEMM kernels.
class PackBMatrix final {
 public:
  PackBMatrix(
      size_t input_channels,
      size_t output_channels,
      const uint8_t* kernel_zero_points,
      const float* requantization_scale,
      const uint8_t* kernel,
      const int32_t* bias);

  // This constructor is to be used for dynamic mode quantization.
  // In dynamic mode, we don't yet support per-channel quantization,
  // and paying the cost of memory allocation for per-channel zero
  // points and requantization scales would hurt performance.
  PackBMatrix(
      size_t input_channels,
      size_t output_channels,
      const uint8_t kernel_zero_point,
      const float requantization_scale,
      const uint8_t* kernel,
      const int32_t* bias);

  void* getPackedWeights() const
  {
    return packed_weights_;
  }

  void unpackWeights(
      const uint8_t* kernel_zero_points,
      int8_t* kernel) const;

  size_t getInputChannels() const
  {
    return input_channels_;
  }

  size_t getOutputChannels() const
  {
    return output_channels_;
  }

  ~PackBMatrix()
  {
    if (packed_weights_ != nullptr) {
      free(packed_weights_);
    }
  }

  // Non-copyable: the class owns a raw heap allocation.
  PackBMatrix() = delete;
  PackBMatrix(const PackBMatrix&) = delete;
  PackBMatrix& operator=(const PackBMatrix&) = delete;

 private:
  void* packed_weights_ = nullptr;
  size_t input_channels_;
  size_t output_channels_;
};

enum pytorch_qnnp_status qnnpackLinear(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    const uint8_t* input,
    const size_t input_stride,
    void* packed_weights,
    uint8_t* output,
    const size_t output_stride,
    pthreadpool_t threadpool);

enum pytorch_qnnp_status qnnpackConv(
    const pytorch_qnnp_operator_t convolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_depth,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool);
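// Example: a minimal sketch of the statically-quantized linear path,
// pairing PackBMatrix with qnnpackLinear. Weights are packed once and the
// packed buffer is reused across calls. All sizes, quantization
// parameters, and buffer names below are illustrative assumptions, not
// values mandated by this API; passing a null threadpool to run on the
// calling thread follows the usual pthreadpool convention.
//
//   const size_t batch_size = 1, input_channels = 64, output_channels = 32;
//   std::vector<uint8_t> kernel(output_channels * input_channels);
//   std::vector<int32_t> bias(output_channels, 0);
//   std::vector<uint8_t> kernel_zero_points(output_channels, 128);
//   std::vector<float> requantization_scales(output_channels, 0.5f);
//
//   qnnpack::PackBMatrix packed_weights(
//       input_channels,
//       output_channels,
//       kernel_zero_points.data(),
//       requantization_scales.data(),
//       kernel.data(),
//       bias.data());
//
//   std::vector<uint8_t> input(batch_size * input_channels);
//   std::vector<uint8_t> output(batch_size * output_channels);
//   const enum pytorch_qnnp_status status = qnnpack::qnnpackLinear(
//       batch_size,
//       input_channels,
//       output_channels,
//       /*input_zero_point=*/127,
//       kernel_zero_points.data(),
//       requantization_scales.data(),
//       /*output_zero_point=*/127,
//       /*output_min=*/0,
//       /*output_max=*/255,
//       input.data(),
//       /*input_stride=*/input_channels,
//       packed_weights.getPackedWeights(),
//       output.data(),
//       /*output_stride=*/output_channels,
//       /*threadpool=*/nullptr);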
enum pytorch_qnnp_status qnnpackDeConv(
    const pytorch_qnnp_operator_t deconvolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool);

enum pytorch_qnnp_status qnnpackLinearDynamic(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* dequantization_scales,
    const uint8_t* input,
    const size_t input_stride,
    void* packed_weights,
    const float* bias,
    float* output,
    const size_t output_stride,
    pthreadpool_t threadpool);

} // namespace qnnpack
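// Example: a minimal sketch of the dynamic-quantization linear path,
// pairing the scalar-zero-point PackBMatrix constructor with
// qnnpackLinearDynamic. The input stays quantized but the output is
// produced as float, with the float bias applied at run time. All sizes
// and quantization parameters below are illustrative assumptions; the
// runtime entry point still takes per-channel pointers, so the single
// zero point and scale are assumed here to be replicated into
// output_channels-sized arrays.
//
//   const size_t batch_size = 4, input_channels = 64, output_channels = 32;
//   const uint8_t kernel_zero_point = 128;
//   const float dequantization_scale = 0.02f;
//   std::vector<uint8_t> kernel(output_channels * input_channels);
//   std::vector<int32_t> kernel_bias(output_channels, 0);  // placeholder for
//       // packing; the float bias below is what qnnpackLinearDynamic adds
//
//   qnnpack::PackBMatrix packed_weights(
//       input_channels,
//       output_channels,
//       kernel_zero_point,
//       dequantization_scale,
//       kernel.data(),
//       kernel_bias.data());
//
//   std::vector<uint8_t> kernel_zero_points(output_channels, kernel_zero_point);
//   std::vector<float> dequantization_scales(output_channels, dequantization_scale);
//   std::vector<uint8_t> input(batch_size * input_channels);
//   std::vector<float> bias(output_channels, 0.0f);
//   std::vector<float> output(batch_size * output_channels);
//
//   const enum pytorch_qnnp_status status = qnnpack::qnnpackLinearDynamic(
//       batch_size,
//       input_channels,
//       output_channels,
//       /*input_zero_point=*/121,
//       kernel_zero_points.data(),
//       dequantization_scales.data(),
//       input.data(),
//       /*input_stride=*/input_channels,
//       packed_weights.getPackedWeights(),
//       bias.data(),
//       output.data(),
//       /*output_stride=*/output_channels,
//       /*threadpool=*/nullptr);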