// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-dynamic-run.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <cstddef>
#include <cstdint>
#include <cstring>

#include <pytorch_qnnpack.h>
#include <qnnpack_func.h>

5 namespace qnnpack {
6 struct q8gemm_dq_context {
7   size_t k;
8   size_t k_stride;
9   size_t n;
10   size_t n_stride;
11   const uint8_t* a;
12   size_t a_stride;
13   const uint8_t* packed_w;
14   const float* bias;
15   float* c;
16   size_t c_stride;
17   struct pytorch_qnnp_conv_dynamic_quantization_params quantization_params;
18   const pytorch_q8gemm_dq_ukernel_function ukernel;
19 };
20 
compute_q8gemm_dq(const struct q8gemm_dq_context * context,size_t group_index,size_t pixel_index,size_t mr_block_start,size_t nr_block_start,size_t group_range,size_t pixel_range,size_t mr_block_size,size_t nr_block_size)21 static void compute_q8gemm_dq(
22     const struct q8gemm_dq_context* context,
23     size_t group_index,
24     size_t pixel_index,
25     size_t mr_block_start,
26     size_t nr_block_start,
27     size_t group_range /* always 1 */,
28     size_t pixel_range,
29     size_t mr_block_size,
30     size_t nr_block_size) {
31   const size_t k = context->k;
32   const size_t k_stride = context->k_stride;
33   const size_t n = context->n;
34   const size_t n_stride = context->n_stride;
35   const uint8_t* a = context->a;
36   const size_t a_stride = context->a_stride;
37   const void* packed_w = context->packed_w;
38   float* c = context->c;
39   const size_t c_stride = context->c_stride;
40   const float* bias = context->bias;
41 
42   size_t output_channel_index = nr_block_start;
43   context->ukernel(
44       mr_block_size,
45       nr_block_size,
46       k,
47       a + (pixel_index + mr_block_start) * a_stride + group_index * k,
48       a_stride,
49       (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
50       bias + nr_block_start,
51       c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
52           group_index * n,
53       c_stride,
54       output_channel_index,
55       &context->quantization_params);
56 }
57 
qnnpackLinearDynamic(const size_t batch_size,const size_t input_channels,const size_t output_channels,const uint8_t input_zero_point,const uint8_t * kernel_zero_points,const float * dequantization_scales,const uint8_t * input,const size_t input_stride,void * packed_weights,const float * bias,float * output,const size_t output_stride,pthreadpool_t threadpool)58 enum pytorch_qnnp_status qnnpackLinearDynamic(
59     const size_t batch_size,
60     const size_t input_channels,
61     const size_t output_channels,
62     const uint8_t input_zero_point,
63     const uint8_t* kernel_zero_points,
64     const float* dequantization_scales,
65     const uint8_t* input,
66     const size_t input_stride,
67     void* packed_weights,
68     const float* bias,
69     float* output,
70     const size_t output_stride,
71     pthreadpool_t threadpool) {
72   const size_t groups = 1;
73   const size_t group_input_channels = input_channels;
74   const size_t group_output_channels = output_channels;
75   const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
76   const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
77   const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
78   const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
79   const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
80 
81   const size_t output_size = batch_size * 1;
82 
83   const struct pytorch_qnnp_conv_dynamic_quantization_params
84       quantizationParams {
85     input_zero_point, kernel_zero_points, dequantization_scales,
86   };
87 
88   struct q8gemm_dq_context q8gemm_dq_context = {
89       .k = group_input_channels,
90       .k_stride = k_stride,
91       .n = group_output_channels,
92       .n_stride = n_stride,
93       .a = input,
94       .a_stride = input_stride,
95       .packed_w = (uint8_t*)packed_weights,
96       .bias = bias,
97       .c = output,
98       .c_stride = output_stride,
99       .quantization_params = quantizationParams,
100       .ukernel = pytorch_qnnp_params.q8conv.gemm_dq,
101   };
102 
103   if (output_size == 0) {
104       // pthreadpool can tolerate a range of 0, but not a tile of 0.
105       // We use output_size as a tile size, so bail here if it's 0.
106       return pytorch_qnnp_status_success;
107   }
108 
109   pthreadpool_compute_4d_tiled(
110       threadpool,
111       (pthreadpool_function_4d_tiled_t)compute_q8gemm_dq,
112       &q8gemm_dq_context,
113       groups,
114       1 * output_size,
115       output_size,
116       group_output_channels,
117       1,
118       output_size,
119       mr,
120       nr);
121 
122   return pytorch_qnnp_status_success;
123 }
124 } // namespace qnnpack
125