1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/requantization.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18
19
create_fully_connected_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_fully_connected_operator(
21 const struct xnn_node* node,
22 const struct xnn_value* values,
23 size_t num_values,
24 struct xnn_operator_data* opdata,
25 const struct xnn_caches* caches)
26 {
27 assert(node->num_inputs >= 2);
28 assert(node->num_inputs <= 3);
29 const uint32_t input_id = node->inputs[0];
30 assert(input_id != XNN_INVALID_VALUE_ID);
31 assert(input_id < num_values);
32 const uint32_t filter_id = node->inputs[1];
33 assert(filter_id != XNN_INVALID_VALUE_ID);
34 assert(filter_id < num_values);
35
36 assert(node->num_outputs == 1);
37 const uint32_t output_id = node->outputs[0];
38 assert(output_id != XNN_INVALID_VALUE_ID);
39 assert(output_id < num_values);
40
41 const size_t num_input_elements = xnn_shape_multiply_all_dims(&values[node->inputs[0]].shape);
42 size_t output_channels, input_channels;
43 if (node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
44 input_channels = values[node->inputs[1]].shape.dim[0];
45 output_channels = values[node->inputs[1]].shape.dim[1];
46 } else {
47 output_channels = values[node->inputs[1]].shape.dim[0];
48 input_channels = values[node->inputs[1]].shape.dim[1];
49 }
50
51 const void* filter_data = values[filter_id].data;
52 assert(filter_data != NULL);
53
54 const void* bias_data = NULL;
55 if (node->num_inputs > 2) {
56 const uint32_t bias_id = node->inputs[2];
57 assert(bias_id != XNN_INVALID_VALUE_ID);
58 assert(bias_id < num_values);
59
60 bias_data = values[bias_id].data;
61 assert(bias_data != NULL);
62 }
63
64 enum xnn_status status;
65 switch (node->compute_type) {
66 #ifndef XNN_NO_F16_OPERATORS
67 case xnn_compute_type_fp16:
68 status = xnn_create_fully_connected_nc_f16(
69 input_channels,
70 output_channels,
71 input_channels /* input stride */,
72 output_channels /* output stride */,
73 filter_data,
74 bias_data,
75 node->activation.output_min,
76 node->activation.output_max,
77 node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
78 caches,
79 &opdata->operator_objects[0]);
80 break;
81 #endif // XNN_NO_F16_OPERATORS
82 case xnn_compute_type_fp32:
83 status = xnn_create_fully_connected_nc_f32(
84 input_channels,
85 output_channels,
86 input_channels /* input stride */,
87 output_channels /* output stride */,
88 filter_data,
89 bias_data,
90 node->activation.output_min,
91 node->activation.output_max,
92 node->flags /* flags */,
93 caches,
94 &opdata->operator_objects[0]);
95 break;
96 #ifndef XNN_NO_QS8_OPERATORS
97 case xnn_compute_type_qs8:
98 {
99 const float output_scale = values[output_id].quantization.scale;
100 const int32_t output_zero_point = values[output_id].quantization.zero_point;
101 const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
102 const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
103 status = xnn_create_fully_connected_nc_qs8(
104 input_channels,
105 output_channels,
106 input_channels /* input stride */,
107 output_channels /* output stride */,
108 (int8_t) values[input_id].quantization.zero_point,
109 values[input_id].quantization.scale,
110 values[filter_id].quantization.scale,
111 filter_data,
112 bias_data,
113 (int8_t) output_zero_point,
114 output_scale, output_min, output_max,
115 node->flags /* flags */,
116 caches,
117 &opdata->operator_objects[0]);
118 break;
119 }
120 #endif // !defined(XNN_NO_QS8_OPERATORS)
121 #ifndef XNN_NO_QU8_OPERATORS
122 case xnn_compute_type_qu8:
123 {
124 const float output_scale = values[output_id].quantization.scale;
125 const int32_t output_zero_point = values[output_id].quantization.zero_point;
126 const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
127 const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
128 status = xnn_create_fully_connected_nc_qu8(
129 input_channels,
130 output_channels,
131 input_channels /* input stride */,
132 output_channels /* output stride */,
133 (uint8_t) values[input_id].quantization.zero_point,
134 values[input_id].quantization.scale,
135 (uint8_t) values[filter_id].quantization.zero_point,
136 values[filter_id].quantization.scale,
137 filter_data,
138 bias_data,
139 (uint8_t) output_zero_point,
140 output_scale, output_min, output_max,
141 node->flags /* flags */,
142 caches,
143 &opdata->operator_objects[0]);
144 break;
145 }
146 #endif // !defined(XNN_NO_QU8_OPERATORS)
147 default:
148 XNN_UNREACHABLE;
149 }
150 if (status == xnn_status_success) {
151 opdata->batch_size = num_input_elements / input_channels;
152 opdata->inputs[0] = input_id;
153 opdata->outputs[0] = output_id;
154 }
155 return status;
156 }
157
// Binds the runtime input and output buffers to the fully-connected operator previously
// created for this node. The batch size recorded at creation time (opdata->batch_size) is
// forwarded as-is, and the xnn_setup_fully_connected_nc_* entry point is chosen from the
// created operator's type. Returns the status reported by that setup call.
static enum xnn_status setup_fully_connected_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  const uint32_t output_id = opdata->outputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* input_data = blobs[input_id].data;
  void* output_data = blobs[output_id].data;
  assert(input_data != NULL);
  assert(output_data != NULL);

  xnn_operator_t op = opdata->operator_objects[0];
  switch (op->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_fully_connected_nc_f16:
      return xnn_setup_fully_connected_nc_f16(op, opdata->batch_size, input_data, output_data, threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_fully_connected_nc_f32:
      return xnn_setup_fully_connected_nc_f32(op, opdata->batch_size, input_data, output_data, threadpool);
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_fully_connected_nc_qs8:
      return xnn_setup_fully_connected_nc_qs8(op, opdata->batch_size, input_data, output_data, threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_fully_connected_nc_qu8:
      return xnn_setup_fully_connected_nc_qu8(op, opdata->batch_size, input_data, output_data, threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
219
validate_datatypes_with_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype bias_datatype,enum xnn_datatype output_datatype)220 static inline enum xnn_compute_type validate_datatypes_with_bias(
221 enum xnn_datatype input_datatype,
222 enum xnn_datatype filter_datatype,
223 enum xnn_datatype bias_datatype,
224 enum xnn_datatype output_datatype)
225 {
226 switch (filter_datatype) {
227 case xnn_datatype_fp32:
228 if (input_datatype == xnn_datatype_fp32 &&
229 bias_datatype == xnn_datatype_fp32 &&
230 output_datatype == xnn_datatype_fp32)
231 {
232 return xnn_compute_type_fp32;
233 }
234 break;
235 #ifndef XNN_NO_QS8_OPERATORS
236 case xnn_datatype_qint8:
237 if (input_datatype == xnn_datatype_qint8 &&
238 bias_datatype == xnn_datatype_qint32 &&
239 output_datatype == xnn_datatype_qint8)
240 {
241 return xnn_compute_type_qs8;
242 }
243 break;
244 #endif // !defined(XNN_NO_QS8_OPERATORS)
245 #ifndef XNN_NO_QU8_OPERATORS
246 case xnn_datatype_quint8:
247 if (input_datatype == xnn_datatype_quint8 &&
248 bias_datatype == xnn_datatype_qint32 &&
249 output_datatype == xnn_datatype_quint8)
250 {
251 return xnn_compute_type_qu8;
252 }
253 break;
254 #endif // !defined(XNN_NO_QU8_OPERATORS)
255 default:
256 XNN_UNREACHABLE;
257 }
258 return xnn_compute_type_invalid;
259 }
260
validate_datatypes_without_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype output_datatype)261 static inline enum xnn_compute_type validate_datatypes_without_bias(
262 enum xnn_datatype input_datatype,
263 enum xnn_datatype filter_datatype,
264 enum xnn_datatype output_datatype)
265 {
266 switch (filter_datatype) {
267 case xnn_datatype_fp32:
268 if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
269 return xnn_compute_type_fp32;
270 }
271 break;
272 #ifndef XNN_NO_QS8_OPERATORS
273 case xnn_datatype_qint8:
274 if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
275 return xnn_compute_type_qs8;
276 }
277 break;
278 #endif // !defined(XNN_NO_QS8_OPERATORS)
279 #ifndef XNN_NO_QU8_OPERATORS
280 case xnn_datatype_quint8:
281 if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
282 return xnn_compute_type_qu8;
283 }
284 break;
285 #endif // !defined(XNN_NO_QU8_OPERATORS)
286 default:
287 XNN_UNREACHABLE;
288 }
289 return xnn_compute_type_invalid;
290 }
291
// Defines a Fully Connected node in a subgraph.
//
// Validation order: XNNPACK initialization, the [output_min, output_max] range, then the
// input, filter, optional bias and output Values (ID range, dense-tensor type and supported
// datatype). Finally the datatype combination across all tensors must map to a compute type.
// On success a node is appended whose operator is built by create_fully_connected_operator
// and bound by setup_fully_connected_operator.
//
// subgraph   - subgraph that receives the new node.
// output_min - lower activation bound, stored on the node.
// output_max - upper activation bound, stored on the node.
// input_id   - dense tensor Value; datatype fp32, qint8 or quint8.
// filter_id  - static (data != NULL) dense tensor Value; fp32, qint8 (zero point must be 0)
//              or quint8. Its first two dimensions give the channel counts; their order is
//              selected by XNN_FLAG_TRANSPOSE_WEIGHTS in `flags`.
// bias_id    - XNN_INVALID_VALUE_ID for no bias; otherwise a static dense tensor Value of
//              datatype fp32 or qint32.
// output_id  - dense tensor Value; datatype fp32, qint8 or quint8.
// flags      - stored on the node and forwarded to operator creation.
//
// Returns xnn_status_success; xnn_status_invalid_parameter for a malformed definition;
// xnn_status_out_of_memory if the node cannot be allocated; or any error status reported by
// the xnn_subgraph_check_* helpers.
enum xnn_status xnn_define_fully_connected(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_fully_connected)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_fully_connected, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_fully_connected, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_fully_connected, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (filter_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* filter_value = &subgraph->values[filter_id];
  if (filter_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id, filter_value->type);
    return xnn_status_invalid_parameter;
  }

  // The filter must be static: its data is consumed when the operator is created.
  if (filter_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id);
    return xnn_status_invalid_parameter;
  }

  switch (filter_value->datatype) {
    case xnn_datatype_fp32:
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      // The signed 8-bit operator is created with a filter scale but no filter zero point
      // (see create_fully_connected_operator), so a non-zero zero point cannot be honored
      // and must be rejected rather than silently ignored.
      if (filter_value->quantization.zero_point != 0) {
        xnn_log_error(
          "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
          xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id,
          filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
        return xnn_status_invalid_parameter;
      }
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id,
        xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // The bias is optional; XNN_INVALID_VALUE_ID means "no bias".
  const struct xnn_value* bias_value = NULL;
  if (bias_id != XNN_INVALID_VALUE_ID) {
    if (bias_id >= subgraph->num_values) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id);
      return xnn_status_invalid_parameter;
    }

    bias_value = &subgraph->values[bias_id];
    if (bias_value->type != xnn_value_type_dense_tensor) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id, bias_value->type);
      return xnn_status_invalid_parameter;
    }

    if (bias_value->data == NULL) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id);
      return xnn_status_invalid_parameter;
    }

    switch (bias_value->datatype) {
      case xnn_datatype_fp32:
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
      case xnn_datatype_qint32:
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
        break;
      default:
        xnn_log_error(
          "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
          xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id,
          xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
        return xnn_status_invalid_parameter;
    }
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_fully_connected, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_fully_connected, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Each datatype is individually supported at this point; now require a consistent
  // combination across all tensors and derive the compute type from it.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  if (bias_value != NULL) {
    compute_type = validate_datatypes_with_bias(
      input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, filter_id, bias_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(bias_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  } else {
    compute_type = validate_datatypes_without_bias(
      input_value->datatype, filter_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, filter_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_fully_connected;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  // Two inputs (input, filter) plus one more when a bias is present.
  node->num_inputs = 2 + (size_t) (bias_id != XNN_INVALID_VALUE_ID);
  node->inputs[0] = input_id;
  node->inputs[1] = filter_id;
  node->inputs[2] = bias_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_fully_connected_operator;
  node->setup = setup_fully_connected_operator;

  return xnn_status_success;
}
506