#include <pytorch_qnnpack.h>
#include <qnnpack_func.h>
#include <cstring>

namespace qnnpack {
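// Bundles the operands of one dynamically quantized GEMM so the
// pthreadpool workers in compute_q8gemm_dq() below can share them.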
struct q8gemm_dq_context {
  size_t k;                // input channels (GEMM depth)
  size_t k_stride;         // k rounded up to a multiple of kr
  size_t n;                // output channels
  size_t n_stride;         // n rounded up to a multiple of nr
  const uint8_t* a;        // quantized input activations
  size_t a_stride;
  const uint8_t* packed_w; // pre-packed quantized weights
  const float* bias;
  float* c;                // dequantized float output
  size_t c_stride;
  struct pytorch_qnnp_conv_dynamic_quantization_params quantization_params;
  const pytorch_q8gemm_dq_ukernel_function ukernel;
};

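// Worker invoked by pthreadpool for one tile of the output: it runs the
// Q8 GEMM-with-dequantization microkernel on output rows
// [mr_block_start, mr_block_start + mr_block_size) and output channels
// [nr_block_start, nr_block_start + nr_block_size).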
static void compute_q8gemm_dq(
    const struct q8gemm_dq_context* context,
    size_t group_index,
    size_t pixel_index,
    size_t mr_block_start,
    size_t nr_block_start,
    size_t group_range /* always 1 */,
    size_t pixel_range,
    size_t mr_block_size,
    size_t nr_block_size) {
  const size_t k = context->k;
  const size_t k_stride = context->k_stride;
  const size_t n = context->n;
  const size_t n_stride = context->n_stride;
  const uint8_t* a = context->a;
  const size_t a_stride = context->a_stride;
  const void* packed_w = context->packed_w;
  float* c = context->c;
  const size_t c_stride = context->c_stride;
  const float* bias = context->bias;

  // Per-channel quantization parameters are indexed by absolute output
  // channel, so hand the microkernel this tile's channel offset.
  size_t output_channel_index = nr_block_start;
  context->ukernel(
      mr_block_size,
      nr_block_size,
      k,
      a + (pixel_index + mr_block_start) * a_stride + group_index * k,
      a_stride,
      // Each packed output channel occupies k_stride weight bytes plus one
      // int32, so advance by that block size per channel.
      (const void*)((uintptr_t)packed_w +
          (nr_block_start + group_index * n_stride) *
              (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
      bias + nr_block_start,
      c + (pixel_index + mr_block_start) * c_stride + nr_block_start +
          group_index * n,
      c_stride,
      output_channel_index,
      &context->quantization_params);
}

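// Dynamically quantized fully connected (linear) layer: multiplies the
// uint8 input by the pre-packed uint8 weights, dequantizes the int32
// accumulators with the per-output-channel scales, adds the float bias,
// and writes float output.
//
// Illustrative call (a sketch, not from this file; the buffer names are
// hypothetical, and the weights are assumed to have been packed by the
// PackBMatrix helper declared in qnnpack_func.h):
//
//   qnnpack::qnnpackLinearDynamic(
//       batch_size,
//       input_channels,
//       output_channels,
//       input_zero_point,
//       kernel_zero_points,      // one per output channel
//       dequantization_scales,   // one per output channel
//       input,                   // batch_size x input_channels, uint8
//       /*input_stride=*/input_channels,
//       packed_weights,          // e.g. PackBMatrix::getPackedWeights()
//       bias,                    // output_channels floats
//       output,                  // batch_size x output_channels, float
//       /*output_stride=*/output_channels,
//       threadpool);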
enum pytorch_qnnp_status qnnpackLinearDynamic(
    const size_t batch_size,
    const size_t input_channels,
    const size_t output_channels,
    const uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const float* dequantization_scales,
    const uint8_t* input,
    const size_t input_stride,
    void* packed_weights,
    const float* bias,
    float* output,
    const size_t output_stride,
    pthreadpool_t threadpool) {
  const size_t groups = 1;
  const size_t group_input_channels = input_channels;
  const size_t group_output_channels = output_channels;
  const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
  const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
  const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
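  // Round the channel counts up to the microkernel's register tiling;
  // kr and nr are powers of two, so (x + (r - 1)) & -r rounds x up to a
  // multiple of r.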
  const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
  const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;

  // A fully connected layer is treated as a 1x1 convolution, so the
  // spatial output size is just batch_size.
  const size_t output_size = batch_size * 1;

  const struct pytorch_qnnp_conv_dynamic_quantization_params
      quantizationParams{
          input_zero_point, kernel_zero_points, dequantization_scales};

  struct q8gemm_dq_context q8gemm_dq_context = {
      .k = group_input_channels,
      .k_stride = k_stride,
      .n = group_output_channels,
      .n_stride = n_stride,
      .a = input,
      .a_stride = input_stride,
      .packed_w = (uint8_t*)packed_weights,
      .bias = bias,
      .c = output,
      .c_stride = output_stride,
      .quantization_params = quantizationParams,
      .ukernel = pytorch_qnnp_params.q8conv.gemm_dq,
  };

  if (output_size == 0) {
    // pthreadpool can tolerate a range of 0, but not a tile of 0.
    // We use output_size as a tile size, so bail here if it's 0.
    return pytorch_qnnp_status_success;
  }
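  // 4-D tiled dispatch over (group, pixel, output row block, output
  // channel block). groups == 1 and the pixel dimension is covered by a
  // single tile of output_size, so the effective parallelism is over
  // mr-row by nr-column blocks of the batch_size x output_channels
  // output matrix.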

  pthreadpool_compute_4d_tiled(
      threadpool,
      (pthreadpool_function_4d_tiled_t)compute_q8gemm_dq,
      &q8gemm_dq_context,
      groups,
      1 * output_size,
      output_size,
      group_output_channels,
      1,
      output_size,
      mr,
      nr);

  return pytorch_qnnp_status_success;
}
} // namespace qnnpack