1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>
21
22
create_constant_pad_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)23 static enum xnn_status create_constant_pad_operator(
24 const struct xnn_node* node,
25 const struct xnn_value* values,
26 size_t num_values,
27 struct xnn_operator_data* opdata,
28 const struct xnn_caches* caches)
29 {
30 assert(node->num_inputs == 1);
31 const uint32_t input_id = node->inputs[0];
32 assert(input_id != XNN_INVALID_VALUE_ID);
33 assert(input_id < num_values);
34
35 assert(node->num_outputs == 1);
36 const uint32_t output_id = node->outputs[0];
37 assert(output_id != XNN_INVALID_VALUE_ID);
38 assert(output_id < num_values);
39
40 enum xnn_status status;
41 switch (node->compute_type) {
42 #ifndef XNN_NO_F16_OPERATORS
43 case xnn_compute_type_fp16:
44 status = xnn_create_constant_pad_nd_x16(
45 &node->params.static_pad.padding_value,
46 node->flags,
47 &opdata->operator_objects[0]);
48 break;
49 #endif // !defined(XNN_NO_F16_OPERATORS)
50 case xnn_compute_type_fp32:
51 status = xnn_create_constant_pad_nd_x32(
52 &node->params.static_pad.padding_value,
53 node->flags,
54 &opdata->operator_objects[0]);
55 break;
56 #ifndef XNN_NO_QS8_OPERATORS
57 case xnn_compute_type_qs8:
58 #endif // !defined(XNN_NO_QS8_OPERATORS)
59 #ifndef XNN_NO_QU8_OPERATORS
60 case xnn_compute_type_qu8:
61 #endif // !defined(XNN_NO_QU8_OPERATORS)
62 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
63 status = xnn_create_constant_pad_nd_x8(
64 &node->params.static_pad.padding_value,
65 node->flags,
66 &opdata->operator_objects[0]);
67 break;
68 #endif // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
69 default:
70 XNN_UNREACHABLE;
71 }
72 if (status == xnn_status_success) {
73 opdata->shape1 = values[input_id].shape;
74 memcpy(opdata->pre_paddings, node->params.static_pad.pre_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
75 memcpy(opdata->post_paddings, node->params.static_pad.post_paddings, sizeof(size_t) * XNN_MAX_TENSOR_DIMS);
76 opdata->inputs[0] = input_id;
77 opdata->outputs[0] = output_id;
78 }
79 return status;
80 }
81
// Binds the input and output buffers of a static_constant_pad node to the
// operator previously created by create_constant_pad_operator.
//
// The setup routine is chosen from the created operator's type (x8, x16 or
// x32). It receives the input shape and the pre/post paddings stored in opdata,
// along with the input/output blob data pointers and the thread pool. The
// status of that setup call is returned directly.
static enum xnn_status setup_constant_pad_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t in_id = opdata->inputs[0];
  assert(in_id != XNN_INVALID_VALUE_ID);
  assert(in_id < num_blobs);

  const uint32_t out_id = opdata->outputs[0];
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  // Both buffers must have been allocated/bound before setup runs.
  const void* src = blobs[in_id].data;
  assert(src != NULL);

  void* dst = blobs[out_id].data;
  assert(dst != NULL);

  switch (opdata->operator_objects[0]->type) {
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
    case xnn_operator_type_constant_pad_nd_x8:
      return xnn_setup_constant_pad_nd_x8(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->pre_paddings, opdata->post_paddings,
        src, dst, threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_constant_pad_nd_x16:
      return xnn_setup_constant_pad_nd_x16(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->pre_paddings, opdata->post_paddings,
        src, dst, threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_constant_pad_nd_x32:
      return xnn_setup_constant_pad_nd_x32(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->pre_paddings, opdata->post_paddings,
        src, dst, threadpool);
    default:
      XNN_UNREACHABLE;
  }
}
146
// Defines a Static Constant Pad node and appends it to `subgraph`.
//
// Parameters:
//   subgraph      - subgraph the new node is added to.
//   pre_paddings  - leading padding for each input dimension (presumably the
//                   element count to prepend along that dimension). Exactly
//                   num_dims entries are read, where num_dims is the input
//                   Value's rank. May be NULL only when the input has rank 0.
//   post_paddings - trailing padding for each input dimension, with the same
//                   size and NULL rules as pre_paddings.
//   padding_value - fill value given as a float. For fp32 outputs its bit
//                   pattern is stored as-is. For qint8/quint8 outputs it is
//                   quantized with the output's scale and zero point.
//   input_id      - ID of the input Value. It must be dense and of type
//                   fp32, qint8 or quint8.
//   output_id     - ID of the output Value. It must be dense, its datatype
//                   must match the input's, and for quantized types its zero
//                   point and scale must equal the input's.
//   flags         - stored on the node and forwarded to the operator when it
//                   is created.
//
// Returns xnn_status_success on success. Returns xnn_status_invalid_parameter
// when validation fails, xnn_status_out_of_memory when the node cannot be
// allocated, or the failure status of the XNNPACK-initialization check. All
// validation happens before the node is allocated, so a rejected definition
// does not add a node.
enum xnn_status xnn_define_static_constant_pad(
  xnn_subgraph_t subgraph,
  const size_t* pre_paddings,
  const size_t* post_paddings,
  float padding_value,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_static_constant_pad)) != xnn_status_success) {
    return status;
  }

  if (input_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_static_constant_pad, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_static_constant_pad, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_static_constant_pad, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches(
    xnn_node_type_static_constant_pad, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

#if !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)
  // Padding copies input elements unchanged, so input and output must share
  // the same quantization parameters.
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%"PRId32") and output (%"PRId32")",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_QU8_OPERATORS) || !defined(XNN_NO_QS8_OPERATORS)

  // One padding entry is read per input dimension. Reject NULL arrays up front
  // (before allocating the node) instead of dereferencing them in memcpy.
  const size_t num_dims = input_value->shape.num_dims;
  if (num_dims != 0 && (pre_paddings == NULL || post_paddings == NULL)) {
    xnn_log_error(
      "failed to define %s operator with input ID #%" PRIu32 ": NULL padding array for %zu-dimensional input",
      xnn_node_type_to_string(xnn_node_type_static_constant_pad), input_id, num_dims);
    return xnn_status_invalid_parameter;
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  // memcpy requires valid pointers even for a zero-byte copy, so skip it for
  // rank-0 inputs, where the padding arrays may legitimately be NULL.
  if (num_dims != 0) {
    memcpy(node->params.static_pad.pre_paddings, pre_paddings, num_dims * sizeof(size_t));
    memcpy(node->params.static_pad.post_paddings, post_paddings, num_dims * sizeof(size_t));
  }

  // Store the fill value in the output's element representation.
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      node->params.static_pad.padding_value = float_as_uint32(padding_value);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
    {
      const float output_scale = output_value->quantization.scale;
      const int32_t output_zero_point = output_value->quantization.zero_point;
      node->params.static_pad.padding_value = xnn_qs8_quantize(padding_value, output_scale, output_zero_point);
      break;
    }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
    {
      const float output_scale = output_value->quantization.scale;
      const int32_t output_zero_point = output_value->quantization.zero_point;
      node->params.static_pad.padding_value = xnn_qu8_quantize(padding_value, output_scale, output_zero_point);
      break;
    }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }

  node->type = xnn_node_type_static_constant_pad;
  node->compute_type = compute_type;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_constant_pad_operator;
  node->setup = setup_constant_pad_operator;

  return xnn_status_success;
}
299