/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/math.h>
#include <qnnpack/operator.h>
#include <qnnpack/pack.h>
#include <qnnpack/params.h>
#include <qnnpack/requantization.h>

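/*
 * Creates a dynamically-quantized (dq) sparse fully-connected operator.
 * The weight matrix is supplied in block-sparse form (values plus
 * column-index and row-value arrays) with the index width selected by
 * kernel_indices_dtype, and requantization_scales supplies one positive,
 * finite scale per output channel. Note that output_zero_point,
 * output_min/output_max, flags, and use_prepack_kernel are accepted but
 * not referenced in this file.
 */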
enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
    size_t input_channels,
    size_t output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const void* kernel_col_indices,
    const void* kernel_row_values,
    const uint8_t* kernel_values,
    const uint32_t kernel_row_block_size,
    const uint32_t kernel_col_block_size,
    enum pytorch_qnnp_sparse_matrix_indices_dtype kernel_indices_dtype,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    bool use_prepack_kernel,
    pytorch_qnnp_operator_t* fully_connected_out) {
  pytorch_qnnp_operator_t fully_connected = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  for (size_t i = 0; i < output_channels; ++i) {
    if (requantization_scales[i] <= 0.0f ||
        !isnormal(requantization_scales[i])) {
      pytorch_qnnp_log_error(
          "failed to create fully connected operator with %.7g requantization scale: scale must be finite and positive",
          requantization_scales[i]);
      goto error;
    }
  }

  status = pytorch_qnnp_status_out_of_memory;

  fully_connected = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (fully_connected == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  if (kernel_row_block_size == 8 && kernel_col_block_size == 1) {
    // This gates the 8x1 pattern on SSE2, since no SSE2 kernel that
    // supports the 8x1 sparsity pattern has been implemented.
    if (pytorch_qnnp_params.q8gemm_sparse_c8x1.packA == NULL) {
      status = pytorch_qnnp_status_invalid_parameter;
      goto error;
    }
  }

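  // Store the sparse-matrix index arrays through the pointer fields that
  // match kernel_indices_dtype, so downstream sparse kernels read the
  // column indices and row values at the intended width (8/16/32 bits).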
  fully_connected->sparse_matrix.indices_dtype = kernel_indices_dtype;
  switch (kernel_indices_dtype) {
    case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t:
      fully_connected->sparse_matrix.col_indices_w32 =
          (const uint32_t*)kernel_col_indices;
      fully_connected->sparse_matrix.row_values_w32 =
          (const uint32_t*)kernel_row_values;
      break;
    case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t:
      fully_connected->sparse_matrix.col_indices_w16 =
          (const uint16_t*)kernel_col_indices;
      fully_connected->sparse_matrix.row_values_w16 =
          (const uint16_t*)kernel_row_values;
      break;
    case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t:
      fully_connected->sparse_matrix.col_indices_w8 =
          (const uint8_t*)kernel_col_indices;
      fully_connected->sparse_matrix.row_values_w8 =
          (const uint8_t*)kernel_row_values;
      break;
    case pytorch_qnnp_sparse_matrix_indices_dtype_invalid:
      status = pytorch_qnnp_status_invalid_parameter;
      pytorch_qnnp_log_error(
          "Invalid indices dtype specified for qnnpack fully connected sparse");
      goto error;
  }

  fully_connected->sparse_matrix.values = kernel_values;
  fully_connected->sparse_matrix.row_block_size = kernel_row_block_size;
  fully_connected->sparse_matrix.col_block_size = kernel_col_block_size;

  fully_connected->groups = 1;
  fully_connected->group_input_channels = input_channels;
  fully_connected->group_output_channels = output_channels;

  fully_connected->kernel_zero_point = kernel_zero_points[0];

  fully_connected->dynamic_conv_quantization_params.input_zero_point =
    input_zero_point;
  fully_connected->dynamic_conv_quantization_params.kernel_zero_points =
    kernel_zero_points;
  fully_connected->dynamic_conv_quantization_params.multipliers =
    requantization_scales;

  // Always use the prepacking-based kernel.
  fully_connected->ukernel_type = pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq;
  fully_connected->format = pytorch_qnnp_format_quint8;

  *fully_connected_out = fully_connected;
  return pytorch_qnnp_status_success;

error:
  pytorch_qnnp_delete_operator(fully_connected);
  return status;
}

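/*
 * Binds the per-call tensors (quantized input, float bias, float output)
 * and their strides to a previously created sparse dq fully-connected
 * operator. A batch_size of zero is recorded and reported as success
 * without touching the remaining fields.
 */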
enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8(
    pytorch_qnnp_operator_t fully_connected,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    const float* bias,
    float* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    fully_connected->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

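  // The sparse GEMM path reuses the convolution-style operator fields:
  // the batch is presented as a single "image" of height batch_size and
  // width 1, so the operator's batch_size is fixed at 1.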
  fully_connected->batch_size = 1;
  fully_connected->input_height = batch_size;
  fully_connected->input_width = 1;
  fully_connected->input = input;
  fully_connected->input_pixel_stride = input_stride;

  fully_connected->bias = bias;

  fully_connected->output_height = batch_size;
  fully_connected->output_width = 1;
  fully_connected->output = output;
  fully_connected->output_pixel_stride = output_stride;

  return pytorch_qnnp_status_success;
}

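/*
 * Minimal usage sketch (illustrative only): names such as col_indices,
 * row_values, values, scales, the 8x1 block shape, the contiguous
 * channel strides, and the threadpool handle are hypothetical
 * placeholders, and error checking is omitted.
 *
 *   pytorch_qnnp_initialize();
 *   pytorch_qnnp_operator_t fc = NULL;
 *   pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
 *       input_channels, output_channels, input_zero_point,
 *       kernel_zero_points, col_indices, row_values, values,
 *       8, 1,  // 8x1 row/column block sparsity
 *       pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t,
 *       output_zero_point, 0, 255, 0, scales, true, &fc);
 *   pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8(
 *       fc, batch_size, input, input_channels, bias, output, output_channels);
 *   pytorch_qnnp_run_operator(fc, threadpool);
 *   pytorch_qnnp_delete_operator(fc);
 */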