1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include <qnnpack/requantization.h>
15
/*
 * Packed element-size descriptor stored in pytorch_qnnp_operator::format.
 *
 * Each byte of the value is read back by one of the
 * pytorch_qnnp_operator_get_log2_*_element_size() accessors in this header:
 *   bits  0..7   output element
 *   bits  8..15  input element
 *   bits 16..23  kernel element
 *   bits 24..31  bias element
 * Going by the accessor names, each byte holds log2 of the element size in
 * bytes (e.g. quint8 = 0x02000000: 1-byte output/input/kernel, 4-byte bias).
 */
enum pytorch_qnnp_format {
  pytorch_qnnp_format_quint8 = 0x02000000,
  pytorch_qnnp_format_float32 = 0x02020202,
  pytorch_qnnp_format_float16 = 0x01010101,
};
21
/*
 * Tag identifying which micro-kernel family an operator uses; stored in
 * pytorch_qnnp_operator::ukernel_type.
 *
 * Only `none` has an explicit value (0); every other enumerator is numbered
 * implicitly in declaration order, so inserting an entry anywhere but at the
 * end renumbers all entries after it.
 */
enum pytorch_qnnp_ukernel_type {
  pytorch_qnnp_ukernel_type_none = 0,
  pytorch_qnnp_ukernel_type_add,
  pytorch_qnnp_ukernel_type_average_pooling,
  pytorch_qnnp_ukernel_type_channel_shuffle,
  pytorch_qnnp_ukernel_type_clamp,
  pytorch_qnnp_ukernel_type_conv,
  pytorch_qnnp_ukernel_type_dwconv,
  pytorch_qnnp_ukernel_type_gemm,
  pytorch_qnnp_ukernel_type_gemm_sparse_dq,
  pytorch_qnnp_ukernel_type_gemm_prepackA_sparse_dq,
  pytorch_qnnp_ukernel_type_global_average_pooling,
  pytorch_qnnp_ukernel_type_lut,
  pytorch_qnnp_ukernel_type_max_pooling,
  pytorch_qnnp_ukernel_type_softargmax,
  pytorch_qnnp_ukernel_type_xzp_gemm,
};
39
40 typedef struct {
41 union {
42 const uint32_t* col_indices_w32;
43 const uint16_t* col_indices_w16;
44 const uint8_t* col_indices_w8;
45 };
46 union {
47 const uint32_t* row_values_w32;
48 const uint16_t* row_values_w16;
49 const uint8_t* row_values_w8;
50 };
51 const uint8_t* values;
52 uint32_t row_block_size;
53 uint32_t col_block_size;
54 enum pytorch_qnnp_sparse_matrix_indices_dtype indices_dtype;
55 } sparse_matrix_t;
56
57 struct pytorch_qnnp_operator {
58 size_t batch_size;
59 uint32_t input_padding_depth;
60 uint32_t input_padding_height;
61 uint32_t input_padding_width;
62 uint32_t adjustment_height;
63 uint32_t adjustment_width;
64 uint32_t kernel_depth;
65 uint32_t kernel_height;
66 uint32_t kernel_width;
67 uint32_t stride_depth;
68 uint32_t stride_height;
69 uint32_t stride_width;
70 uint32_t dilation_depth;
71 uint32_t dilation_height;
72 uint32_t dilation_width;
73 uint32_t groups;
74 size_t group_stride;
75 size_t group_channels;
76 size_t group_input_channels;
77 size_t group_output_channels;
78 size_t channels;
79
80 size_t input_depth;
81 size_t input_height;
82 size_t input_width;
83 size_t input_pixel_stride;
84 const void* input;
85 const void** indirection_buffer;
86 void* a_sum;
87
88 size_t step_depth;
89 size_t step_height;
90 size_t step_width;
91
92 size_t input2_pixel_stride;
93 const void* input2;
94
95 size_t output_depth;
96 size_t output_height;
97 size_t output_width;
98 size_t output_pixel_stride;
99 void* output;
100
101 void* packed_weights;
102 float input_scale;
103 float output_scale;
104 uint8_t input_zero_point;
105 uint8_t kernel_zero_point;
106 uint8_t output_zero_point;
107 uint8_t output_min;
108 uint8_t output_max;
109
110 size_t valid_batch_size;
111 size_t last_input_height;
112 size_t last_input_width;
113 const void* last_input;
114
115 void* zero_buffer;
116 void* zero_pointer;
117 void* lookup_table;
118
119 union {
120 union pytorch_qnnp_q31_requantization_params requantization_params;
121 union pytorch_qnnp_conv_quantization_params conv_quantization_params;
122 union pytorch_qnnp_add_quantization_params add_quantization_params;
123 union pytorch_qnnp_avgpool_quantization_params avgpool_quantization_params;
124 union pytorch_qnnp_u8_clamping_params u8_clamping_params;
125 };
126 enum pytorch_qnnp_ukernel_type ukernel_type;
127 enum pytorch_qnnp_format format;
128
129 bool per_channel;
130 bool transpose;
131
132 // Sparsity support
133 sparse_matrix_t sparse_matrix;
134 const void* bias;
135 struct pytorch_qnnp_conv_dynamic_quantization_params dynamic_conv_quantization_params;
136 uint8_t* prepacked_a;
137 };
138
pytorch_qnnp_operator_get_log2_output_element_size(const struct pytorch_qnnp_operator * convolution)139 static inline uint32_t pytorch_qnnp_operator_get_log2_output_element_size(
140 const struct pytorch_qnnp_operator* convolution) {
141 return (uint32_t)(convolution->format & UINT32_C(0xFF));
142 }
143
pytorch_qnnp_operator_get_log2_input_element_size(const struct pytorch_qnnp_operator * convolution)144 static inline uint32_t pytorch_qnnp_operator_get_log2_input_element_size(
145 const struct pytorch_qnnp_operator* convolution) {
146 return (uint32_t)((convolution->format >> 8) & UINT32_C(0xFF));
147 }
148
pytorch_qnnp_operator_get_log2_kernel_element_size(const struct pytorch_qnnp_operator * convolution)149 static inline uint32_t pytorch_qnnp_operator_get_log2_kernel_element_size(
150 const struct pytorch_qnnp_operator* convolution) {
151 return (uint32_t)((convolution->format >> 16) & UINT32_C(0xFF));
152 }
153
pytorch_qnnp_operator_get_log2_bias_element_size(const struct pytorch_qnnp_operator * convolution)154 static inline uint32_t pytorch_qnnp_operator_get_log2_bias_element_size(
155 const struct pytorch_qnnp_operator* convolution) {
156 return (uint32_t)((convolution->format >> 24) & UINT32_C(0xFF));
157 }
158