xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdlib>
4 #include <qnnpack/operator.h>
5 
6 namespace qnnpack {
7 class PrePackConvWeights final {
8  public:
9   PrePackConvWeights(
10       const pytorch_qnnp_operator_t convolution,
11       const uint8_t* kernel_zero_points,
12       const uint8_t* kernel,
13       const int32_t* bias);
14 
getPackedWeights()15   void* getPackedWeights() const
16   {
17     return packed_weights_;
18   }
19 
getOutputChannels()20   int64_t getOutputChannels() const
21   {
22     return output_channels_;
23   }
24 
~PrePackConvWeights()25   ~PrePackConvWeights()
26   {
27     if (packed_weights_ != nullptr) {
28       free(packed_weights_);
29     }
30   }
31 
32   PrePackConvWeights() = delete;
33   PrePackConvWeights(const PrePackConvWeights&) = delete;
34   PrePackConvWeights& operator=(const PrePackConvWeights&) = delete;
35 
36  private:
37   void* packed_weights_ = nullptr;
38   int64_t output_channels_;
39 };
40 
// Owns a prepacked B (weight) matrix for quantized linear layers.
// Constructors are defined out of line; the packed buffer is released
// with free() on destruction. Non-copyable, single owner.
class PackBMatrix final {
 public:
  // Per-output-channel quantization: zero points and requantization
  // scales are arrays of length output_channels.
  PackBMatrix(
      size_t input_channels,
      size_t output_channels,
      const uint8_t* kernel_zero_points,
      const float* requantization_scale,
      const uint8_t* kernel,
      const int32_t* bias);

  // Dynamic-mode quantization variant: one scalar zero point and one
  // scalar scale for the whole kernel. Per-channel quantization is not
  // yet supported in dynamic mode, and allocating per-channel arrays
  // there would cost memory-allocation overhead for no benefit.
  PackBMatrix(
      size_t input_channels,
      size_t output_channels,
      const uint8_t kernel_zero_point,
      const float requantization_scale,
      const uint8_t* kernel,
      const int32_t* bias);

  // Single-owner semantics: no default construction, no copies.
  PackBMatrix() = delete;
  PackBMatrix(const PackBMatrix&) = delete;
  PackBMatrix& operator=(const PackBMatrix&) = delete;

  // Reverses the packing into `kernel` (defined out of line).
  void unpackWeights(
      const uint8_t* kernel_zero_points,
      int8_t* kernel) const;

  // Non-owning pointer to the packed buffer; valid while this object
  // is alive.
  void* getPackedWeights() const {
    return packed_weights_;
  }

  size_t getInputChannels() const {
    return input_channels_;
  }

  size_t getOutputChannels() const {
    return output_channels_;
  }

  ~PackBMatrix() {
    // Explicit guard, though free(nullptr) is already a no-op.
    if (packed_weights_ != nullptr) {
      free(packed_weights_);
    }
  }

 private:
  void* packed_weights_ = nullptr;
  size_t input_channels_;
  size_t output_channels_;
};
100 
// Quantized (uint8) fully-connected layer entry point.
// NOTE(review): the implementation lives elsewhere; parameter meanings
// below are inferred from names — confirm against the definition.
// kernel_zero_points / requantization_scales are pointers, which
// suggests per-output-channel quantization parameters.
// `packed_weights` is presumably the buffer produced by PackBMatrix.
enum pytorch_qnnp_status qnnpackLinear(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* requantization_scales,
    const uint8_t output_zero_point,
    const uint8_t output_min,  // output clamp bounds (inclusive, presumably)
    const uint8_t output_max,
    const uint8_t* input,
    const size_t input_stride,   // elements between consecutive input rows
    void* packed_weights,
    uint8_t* output,
    const size_t output_stride,  // elements between consecutive output rows
    pthreadpool_t threadpool);
117 
// Quantized (uint8) convolution entry point.
// NOTE(review): semantics inferred from parameter names; the operator
// configuration (kernel size, strides, padding) lives inside
// `convolution` — confirm against the implementation.
// `packed_weights` is presumably the buffer from PrePackConvWeights.
enum pytorch_qnnp_status qnnpackConv(
    const pytorch_qnnp_operator_t convolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_depth,   // depth present => 3D (volumetric) conv supported
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,   // per-channel zero points, presumably
    const float* requantization_scales,  // per-channel scales, presumably
    const uint8_t output_zero_point,
    const uint8_t output_min,  // output clamp bounds
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool);
134 
// Quantized (uint8) transposed-convolution (deconvolution) entry point.
// Mirrors qnnpackConv but without an input_depth parameter (2D only,
// apparently). NOTE(review): semantics inferred from parameter names —
// confirm against the implementation.
enum pytorch_qnnp_status qnnpackDeConv(
    const pytorch_qnnp_operator_t deconvolution,
    void* packed_weights,
    const size_t batch_size,
    const size_t input_height,
    const size_t input_width,
    const uint8_t input_zero_point,
    const uint8_t* input,
    const uint8_t* kernel_zero_points,   // per-channel zero points, presumably
    const float* requantization_scales,  // per-channel scales, presumably
    const uint8_t output_zero_point,
    const uint8_t output_min,  // output clamp bounds
    const uint8_t output_max,
    uint8_t* output,
    pthreadpool_t threadpool);
150 
// Dynamically-quantized fully-connected layer: uint8 input against
// packed uint8 weights, float bias, float output (hence dequantization
// scales and no output quantization parameters).
// NOTE(review): per the PackBMatrix comment above, dynamic mode does
// not yet support per-channel quantization — the pointer parameters
// here presumably point at single-element values; confirm against the
// implementation.
enum pytorch_qnnp_status qnnpackLinearDynamic(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* dequantization_scales,
    const uint8_t* input,
    const size_t input_stride,   // elements between consecutive input rows
    void* packed_weights,
    const float* bias,
    float* output,
    const size_t output_stride,  // elements between consecutive output rows
    pthreadpool_t threadpool);
165 
166 } // namespace qnnpack
167