1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <string.h>
15
16 #include <pytorch_qnnpack.h>
17 #include <qnnpack/log.h>
18 #include <qnnpack/math.h>
19 #include <qnnpack/operator.h>
20 #include <qnnpack/pack.h>
21 #include <qnnpack/params.h>
22 #include <qnnpack/requantization.h>
23
/*
 * Creates a fully connected operator backed by a block-sparse quantized
 * kernel matrix and returns it through fully_connected_out.
 *
 * Only pointers are stored for kernel_col_indices, kernel_row_values,
 * kernel_values, kernel_zero_points and requantization_scales; this function
 * does not copy the data they point to.
 *
 * kernel_col_indices and kernel_row_values are interpreted as arrays of
 * uint32_t, uint16_t or uint8_t according to kernel_indices_dtype.
 *
 * requantization_scales must hold at least output_channels entries; every
 * entry is validated. kernel_zero_points must hold at least one entry.
 *
 * output_zero_point, output_min, output_max, flags and use_prepack_kernel are
 * accepted for interface compatibility but are not read by this function; the
 * prepacking-based sparse kernel type is always selected.
 *
 * Returns:
 *   pytorch_qnnp_status_success               on success
 *   pytorch_qnnp_status_uninitialized         if QNNPACK is not initialized
 *   pytorch_qnnp_status_unsupported_parameter if a requantization scale is
 *                                             not a positive normal float
 *   pytorch_qnnp_status_out_of_memory         if the operator allocation fails
 *   pytorch_qnnp_status_invalid_parameter     if an 8x1 block pattern is
 *                                             requested but no 8x1 packing
 *                                             routine is available, or if
 *                                             kernel_indices_dtype is not one
 *                                             of the supported index widths
 *
 * On failure *fully_connected_out is left untouched.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8(
    size_t input_channels,
    size_t output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const void* kernel_col_indices,
    const void* kernel_row_values,
    const uint8_t* kernel_values,
    const uint32_t kernel_row_block_size,
    const uint32_t kernel_col_block_size,
    enum pytorch_qnnp_sparse_matrix_indices_dtype kernel_indices_dtype,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    bool use_prepack_kernel,
    pytorch_qnnp_operator_t* fully_connected_out) {
  pytorch_qnnp_operator_t fully_connected = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_fully_connected_sparse_dq_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  // size_t index: output_channels is size_t, so an int counter would mix
  // signedness in the comparison and overflow for very large channel counts.
  for (size_t i = 0; i < output_channels; ++i) {
    // isnormal() rejects zero, subnormal, infinite and NaN values.
    if (requantization_scales[i] <= 0.0f ||
        !isnormal(requantization_scales[i])) {
      pytorch_qnnp_log_error(
          "failed to create fully connected operator with %.7g requantization scale: scale must be finite and positive",
          requantization_scales[i]);
      goto error;
    }
  }

  status = pytorch_qnnp_status_out_of_memory;

  fully_connected = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (fully_connected == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  if (kernel_row_block_size == 8 && kernel_col_block_size == 1) {
    // This is to gate 8x1 on SSE2 since we have not implemented SSE2
    // kernel that supports 8x1 sparsity pattern.
    if (pytorch_qnnp_params.q8gemm_sparse_c8x1.packA == NULL) {
      status = pytorch_qnnp_status_invalid_parameter;
      pytorch_qnnp_log_error(
          "failed to create fully connected sparse operator: 8x1 block sparsity pattern is not supported on this platform");
      goto error;
    }
  }

  fully_connected->sparse_matrix.indices_dtype = kernel_indices_dtype;
  switch (kernel_indices_dtype) {
    case pytorch_qnnp_sparse_matrix_indices_dtype_uint32_t:
      fully_connected->sparse_matrix.col_indices_w32 =
          (const uint32_t*)kernel_col_indices;
      fully_connected->sparse_matrix.row_values_w32 =
          (const uint32_t*)kernel_row_values;
      break;
    case pytorch_qnnp_sparse_matrix_indices_dtype_uint16_t:
      fully_connected->sparse_matrix.col_indices_w16 =
          (const uint16_t*)kernel_col_indices;
      fully_connected->sparse_matrix.row_values_w16 =
          (const uint16_t*)kernel_row_values;
      break;
    case pytorch_qnnp_sparse_matrix_indices_dtype_uint8_t:
      fully_connected->sparse_matrix.col_indices_w8 =
          (const uint8_t*)kernel_col_indices;
      fully_connected->sparse_matrix.row_values_w8 =
          (const uint8_t*)kernel_row_values;
      break;
    case pytorch_qnnp_sparse_matrix_indices_dtype_invalid:
    default:
      // Any value outside the supported index widths is rejected; without
      // the default label an out-of-range value would leave the index
      // pointers unset and still report success.
      status = pytorch_qnnp_status_invalid_parameter;
      pytorch_qnnp_log_error(
          "Invalid indices dtype specified for qnnpack fully connected sparse");
      goto error;
  }

  fully_connected->sparse_matrix.values = kernel_values;
  fully_connected->sparse_matrix.row_block_size = kernel_row_block_size;
  fully_connected->sparse_matrix.col_block_size = kernel_col_block_size;

  fully_connected->groups = 1;
  fully_connected->group_input_channels = input_channels;
  fully_connected->group_output_channels = output_channels;

  fully_connected->kernel_zero_point = kernel_zero_points[0];

  fully_connected->dynamic_conv_quantization_params.input_zero_point =
      input_zero_point;
  fully_connected->dynamic_conv_quantization_params.kernel_zero_points =
      kernel_zero_points;
  fully_connected->dynamic_conv_quantization_params.multipliers =
      requantization_scales;

  // Always use prepacking based kernel
  fully_connected->ukernel_type =
      pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq;
  fully_connected->format = pytorch_qnnp_format_quint8;

  *fully_connected_out = fully_connected;
  return pytorch_qnnp_status_success;

error:
  // fully_connected is still NULL when failure happens before allocation.
  pytorch_qnnp_delete_operator(fully_connected);
  return status;
}
137
/*
 * Binds input, bias and output buffers to a sparse fully connected operator
 * for a subsequent run.
 *
 * The batch is recorded as the operator's height dimension: input_height and
 * output_height are set to batch_size, the widths to 1, and the operator's
 * own batch_size field to 1. input_stride and output_stride are stored as the
 * input and output pixel strides. Only the pointers are stored; no data is
 * read or written here.
 *
 * When batch_size is 0 the operator's batch_size is set to 0 and the function
 * returns immediately, leaving all other fields unchanged.
 *
 * Returns pytorch_qnnp_status_uninitialized if QNNPACK has not been
 * initialized, pytorch_qnnp_status_success otherwise.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8(
    pytorch_qnnp_operator_t fully_connected,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    const float* bias,
    float* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_fully_connected_sparse_dq_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    // Nothing to compute: mark the operator as empty and skip buffer setup.
    fully_connected->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  // The whole batch is expressed as batch_size rows of a single image.
  fully_connected->batch_size = 1;
  fully_connected->input_height = batch_size;
  fully_connected->input_width = 1;
  fully_connected->input = input;
  fully_connected->input_pixel_stride = input_stride;

  fully_connected->bias = bias;

  fully_connected->output_height = batch_size;
  fully_connected->output_width = 1;
  fully_connected->output = output;
  fully_connected->output_pixel_stride = output_stride;

  return pytorch_qnnp_status_success;
}
172