xref: /aosp_15_r20/external/XNNPACK/src/subgraph/fully-connected.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/requantization.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18 
19 
create_fully_connected_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_fully_connected_operator(
21   const struct xnn_node* node,
22   const struct xnn_value* values,
23   size_t num_values,
24   struct xnn_operator_data* opdata,
25   const struct xnn_caches* caches)
26 {
27   assert(node->num_inputs >= 2);
28   assert(node->num_inputs <= 3);
29   const uint32_t input_id = node->inputs[0];
30   assert(input_id != XNN_INVALID_VALUE_ID);
31   assert(input_id < num_values);
32   const uint32_t filter_id = node->inputs[1];
33   assert(filter_id != XNN_INVALID_VALUE_ID);
34   assert(filter_id < num_values);
35 
36   assert(node->num_outputs == 1);
37   const uint32_t output_id = node->outputs[0];
38   assert(output_id != XNN_INVALID_VALUE_ID);
39   assert(output_id < num_values);
40 
41   const size_t num_input_elements = xnn_shape_multiply_all_dims(&values[node->inputs[0]].shape);
42   size_t output_channels, input_channels;
43   if (node->flags & XNN_FLAG_TRANSPOSE_WEIGHTS) {
44     input_channels = values[node->inputs[1]].shape.dim[0];
45     output_channels = values[node->inputs[1]].shape.dim[1];
46   } else {
47     output_channels = values[node->inputs[1]].shape.dim[0];
48     input_channels = values[node->inputs[1]].shape.dim[1];
49   }
50 
51   const void* filter_data = values[filter_id].data;
52   assert(filter_data != NULL);
53 
54   const void* bias_data = NULL;
55   if (node->num_inputs > 2) {
56     const uint32_t bias_id = node->inputs[2];
57     assert(bias_id != XNN_INVALID_VALUE_ID);
58     assert(bias_id < num_values);
59 
60     bias_data = values[bias_id].data;
61     assert(bias_data != NULL);
62   }
63 
64   enum xnn_status status;
65   switch (node->compute_type) {
66 #ifndef XNN_NO_F16_OPERATORS
67     case xnn_compute_type_fp16:
68       status = xnn_create_fully_connected_nc_f16(
69         input_channels,
70         output_channels,
71         input_channels /* input stride */,
72         output_channels /* output stride */,
73         filter_data,
74         bias_data,
75         node->activation.output_min,
76         node->activation.output_max,
77         node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
78         caches,
79         &opdata->operator_objects[0]);
80       break;
81 #endif  // XNN_NO_F16_OPERATORS
82     case xnn_compute_type_fp32:
83       status = xnn_create_fully_connected_nc_f32(
84         input_channels,
85         output_channels,
86         input_channels /* input stride */,
87         output_channels /* output stride */,
88         filter_data,
89         bias_data,
90         node->activation.output_min,
91         node->activation.output_max,
92         node->flags /* flags */,
93         caches,
94         &opdata->operator_objects[0]);
95       break;
96 #ifndef XNN_NO_QS8_OPERATORS
97     case xnn_compute_type_qs8:
98     {
99       const float output_scale = values[output_id].quantization.scale;
100       const int32_t output_zero_point = values[output_id].quantization.zero_point;
101       const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
102       const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
103       status = xnn_create_fully_connected_nc_qs8(
104         input_channels,
105         output_channels,
106         input_channels /* input stride */,
107         output_channels /* output stride */,
108         (int8_t) values[input_id].quantization.zero_point,
109         values[input_id].quantization.scale,
110         values[filter_id].quantization.scale,
111         filter_data,
112         bias_data,
113         (int8_t) output_zero_point,
114         output_scale, output_min, output_max,
115         node->flags /* flags */,
116         caches,
117         &opdata->operator_objects[0]);
118       break;
119     }
120 #endif  // !defined(XNN_NO_QS8_OPERATORS)
121 #ifndef XNN_NO_QU8_OPERATORS
122     case xnn_compute_type_qu8:
123     {
124       const float output_scale = values[output_id].quantization.scale;
125       const int32_t output_zero_point = values[output_id].quantization.zero_point;
126       const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
127       const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
128       status = xnn_create_fully_connected_nc_qu8(
129         input_channels,
130         output_channels,
131         input_channels /* input stride */,
132         output_channels /* output stride */,
133         (uint8_t) values[input_id].quantization.zero_point,
134         values[input_id].quantization.scale,
135         (uint8_t) values[filter_id].quantization.zero_point,
136         values[filter_id].quantization.scale,
137         filter_data,
138         bias_data,
139         (uint8_t) output_zero_point,
140         output_scale, output_min, output_max,
141         node->flags /* flags */,
142         caches,
143         &opdata->operator_objects[0]);
144       break;
145     }
146 #endif  // !defined(XNN_NO_QU8_OPERATORS)
147     default:
148       XNN_UNREACHABLE;
149   }
150   if (status == xnn_status_success) {
151     opdata->batch_size = num_input_elements / input_channels;
152     opdata->inputs[0] = input_id;
153     opdata->outputs[0] = output_id;
154   }
155   return status;
156 }
157 
// Binds the runtime input and output buffers to a previously created Fully Connected operator.
//
// The input/output blobs are looked up by the Value IDs stored in opdata, and the setup call is chosen
// by the type of the operator held in opdata->operator_objects[0]. Returns the status of that setup call.
static enum xnn_status setup_fully_connected_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const void* src = blobs[input_id].data;
  assert(src != NULL);

  void* dst = blobs[output_id].data;
  assert(dst != NULL);

  enum xnn_status status;
  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_fully_connected_nc_f16:
      status = xnn_setup_fully_connected_nc_f16(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_fully_connected_nc_f32:
      status = xnn_setup_fully_connected_nc_f32(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_fully_connected_nc_qs8:
      status = xnn_setup_fully_connected_nc_qs8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_fully_connected_nc_qu8:
      status = xnn_setup_fully_connected_nc_qu8(
        opdata->operator_objects[0], opdata->batch_size, src, dst, threadpool);
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
  return status;
}
219 
validate_datatypes_with_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype bias_datatype,enum xnn_datatype output_datatype)220 static inline enum xnn_compute_type validate_datatypes_with_bias(
221   enum xnn_datatype input_datatype,
222   enum xnn_datatype filter_datatype,
223   enum xnn_datatype bias_datatype,
224   enum xnn_datatype output_datatype)
225 {
226   switch (filter_datatype) {
227     case xnn_datatype_fp32:
228       if (input_datatype == xnn_datatype_fp32 &&
229           bias_datatype == xnn_datatype_fp32 &&
230           output_datatype == xnn_datatype_fp32)
231       {
232         return xnn_compute_type_fp32;
233       }
234       break;
235 #ifndef XNN_NO_QS8_OPERATORS
236     case xnn_datatype_qint8:
237       if (input_datatype == xnn_datatype_qint8 &&
238           bias_datatype == xnn_datatype_qint32 &&
239           output_datatype == xnn_datatype_qint8)
240       {
241         return xnn_compute_type_qs8;
242       }
243       break;
244 #endif  // !defined(XNN_NO_QS8_OPERATORS)
245 #ifndef XNN_NO_QU8_OPERATORS
246     case xnn_datatype_quint8:
247       if (input_datatype == xnn_datatype_quint8 &&
248           bias_datatype == xnn_datatype_qint32 &&
249           output_datatype == xnn_datatype_quint8)
250       {
251         return xnn_compute_type_qu8;
252       }
253       break;
254 #endif  // !defined(XNN_NO_QU8_OPERATORS)
255     default:
256       XNN_UNREACHABLE;
257   }
258   return xnn_compute_type_invalid;
259 }
260 
validate_datatypes_without_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype output_datatype)261 static inline enum xnn_compute_type validate_datatypes_without_bias(
262   enum xnn_datatype input_datatype,
263   enum xnn_datatype filter_datatype,
264   enum xnn_datatype output_datatype)
265 {
266   switch (filter_datatype) {
267     case xnn_datatype_fp32:
268       if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
269         return xnn_compute_type_fp32;
270       }
271       break;
272 #ifndef XNN_NO_QS8_OPERATORS
273     case xnn_datatype_qint8:
274       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
275         return xnn_compute_type_qs8;
276       }
277       break;
278 #endif  // !defined(XNN_NO_QS8_OPERATORS)
279 #ifndef XNN_NO_QU8_OPERATORS
280     case xnn_datatype_quint8:
281       if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
282         return xnn_compute_type_qu8;
283       }
284       break;
285 #endif  // !defined(XNN_NO_QU8_OPERATORS)
286     default:
287       XNN_UNREACHABLE;
288   }
289   return xnn_compute_type_invalid;
290 }
291 
// Defines a Fully Connected node and appends it to the subgraph.
//
// Parameters:
//   subgraph    - subgraph that owns the referenced Values and receives the new node.
//   output_min  - lower activation bound, stored in the node (checked by xnn_subgraph_check_output_min_max).
//   output_max  - upper activation bound, stored in the node.
//   input_id    - Value ID of the input; must be a dense tensor of fp32, qint8 or quint8 datatype.
//   filter_id   - Value ID of the filter; must be a dense tensor with static data (non-NULL data pointer)
//                 of fp32, qint8 or quint8 datatype. A qint8 filter must have a zero point of 0.
//   bias_id     - Value ID of the bias, or XNN_INVALID_VALUE_ID to define the node without a bias.
//                 When given, it must be a dense tensor with static data of fp32 or qint32 datatype.
//   output_id   - Value ID of the output; must be a dense tensor of fp32, qint8 or quint8 datatype.
//   flags       - stored in the node unchanged and forwarded to the operator at creation time.
//
// Returns:
//   xnn_status_success            - the node was added to the subgraph.
//   xnn_status_invalid_parameter  - an ID, Value type, datatype, or datatype combination is rejected here.
//   xnn_status_out_of_memory      - a new node could not be allocated.
//   Any non-success status reported by the xnn_subgraph_check_* helpers is returned as-is.
enum xnn_status xnn_define_fully_connected(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_fully_connected)) != xnn_status_success) {
    return status;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_fully_connected, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  // Input: valid ID, dense tensor, supported datatype.
  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_fully_connected, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_fully_connected, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Filter: valid ID, dense tensor, static data, supported datatype.
  if (filter_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* filter_value = &subgraph->values[filter_id];
  if (filter_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id, filter_value->type);
    return xnn_status_invalid_parameter;
  }

  if (filter_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id);
    return xnn_status_invalid_parameter;
  }

  switch (filter_value->datatype) {
    case xnn_datatype_fp32:
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      // The QS8 operator is created without a filter zero point, so a non-zero value cannot be
      // honored and must be rejected rather than silently ignored.
      if (filter_value->quantization.zero_point != 0) {
        xnn_log_error(
          "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
          xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id,
          filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
        return xnn_status_invalid_parameter;
      }
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), filter_id,
        xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Bias is optional; XNN_INVALID_VALUE_ID means "no bias" and leaves bias_value NULL.
  const struct xnn_value* bias_value = NULL;
  if (bias_id != XNN_INVALID_VALUE_ID) {
    if (bias_id >= subgraph->num_values) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id);
      return xnn_status_invalid_parameter;
    }

    bias_value = &subgraph->values[bias_id];
    if (bias_value->type != xnn_value_type_dense_tensor) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id, bias_value->type);
      return xnn_status_invalid_parameter;
    }

    if (bias_value->data == NULL) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
        xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id);
      return xnn_status_invalid_parameter;
    }

    switch (bias_value->datatype) {
      case xnn_datatype_fp32:
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
      case xnn_datatype_qint32:
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
        break;
      default:
        xnn_log_error(
          "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
          xnn_node_type_to_string(xnn_node_type_fully_connected), bias_id,
          xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
        return xnn_status_invalid_parameter;
    }
  }

  // Output: valid ID, dense tensor, supported datatype.
  status = xnn_subgraph_check_output_node_id(xnn_node_type_fully_connected, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_fully_connected, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Each datatype was checked individually above; now check that the combination is consistent
  // and derive the compute type from it.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  if (bias_value != NULL) {
    compute_type = validate_datatypes_with_bias(
      input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, filter_id, bias_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(bias_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  } else {
    compute_type = validate_datatypes_without_bias(
      input_value->datatype, filter_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_fully_connected), input_id, filter_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  }

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_fully_connected;
  node->compute_type = compute_type;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  // Two inputs (input, filter) plus one more when a bias was supplied.
  node->num_inputs = 2 + (size_t) (bias_id != XNN_INVALID_VALUE_ID);
  node->inputs[0] = input_id;
  node->inputs[1] = filter_id;
  node->inputs[2] = bias_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_fully_connected_operator;
  node->setup = setup_fully_connected_operator;

  return xnn_status_success;
}
506