xref: /aosp_15_r20/external/XNNPACK/src/subgraph/convolution-2d.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>
18 
19 
// Creates the underlying XNNPACK convolution operator for a Convolution2D
// subgraph node. The operator variant is selected from the output value's
// layout (NCHW vs. NHWC) and the node's compute type (fp32/fp16/qs8/qc8/qu8).
// On success, caches the input batch/height/width and the input/output value
// IDs in opdata for later use by setup_convolution_operator.
static enum xnn_status create_convolution_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  // Inputs: [0] = input tensor, [1] = static filter, optional [2] = static bias.
  assert(node->num_inputs >= 2);
  assert(node->num_inputs <= 3);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);
  const uint32_t filter_id = node->inputs[1];
  assert(filter_id != XNN_INVALID_VALUE_ID);
  assert(filter_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  // Filter (and bias, when present) must be static values; non-NULL data was
  // already enforced when the node was defined.
  const void* filter_data = values[filter_id].data;
  assert(filter_data != NULL);

  const void* bias_data = NULL;
  if (node->num_inputs > 2) {
    const uint32_t bias_id = node->inputs[2];
    assert(bias_id != XNN_INVALID_VALUE_ID);
    assert(bias_id < num_values);

    bias_data = values[bias_id].data;
    assert(bias_data != NULL);
  }

  enum xnn_status status;
  if (values[output_id].layout == xnn_layout_type_nchw) {
    // NCHW output is only supported for fp32 compute.
    assert(node->compute_type == xnn_compute_type_fp32);
    status = xnn_create_convolution2d_nchw_f32(
      node->params.convolution_2d.input_padding_top,
      node->params.convolution_2d.input_padding_right,
      node->params.convolution_2d.input_padding_bottom,
      node->params.convolution_2d.input_padding_left,
      node->params.convolution_2d.kernel_height,
      node->params.convolution_2d.kernel_width,
      node->params.convolution_2d.subsampling_height,
      node->params.convolution_2d.subsampling_width,
      node->params.convolution_2d.dilation_height,
      node->params.convolution_2d.dilation_width,
      node->params.convolution_2d.groups,
      node->params.convolution_2d.group_input_channels,
      node->params.convolution_2d.group_output_channels,
      node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
      node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
      filter_data,
      bias_data,
      node->activation.output_min,
      node->activation.output_max,
      // The input value may still be NHWC; flag the operator to read it as such.
      node->flags | (values[input_id].layout == xnn_layout_type_nhwc ? XNN_FLAG_INPUT_NHWC : 0),
      caches,
      &opdata->operator_objects[0]);
  } else {
    assert(values[input_id].layout == xnn_layout_type_nhwc);
    assert(values[output_id].layout == xnn_layout_type_nhwc);
    switch (node->compute_type) {
      case xnn_compute_type_fp32:
        status = xnn_create_convolution2d_nhwc_f32(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          filter_data,
          bias_data,
          node->activation.output_min,
          node->activation.output_max,
          node->flags,
          caches,
          &opdata->operator_objects[0]);
        break;
#ifndef XNN_NO_F16_OPERATORS
      case xnn_compute_type_fp16:
        status = xnn_create_convolution2d_nhwc_f16(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          filter_data,
          bias_data,
          node->activation.output_min,
          node->activation.output_max,
          // Static weights/bias are supplied as fp32; the operator converts them.
          node->flags | XNN_FLAG_FP32_STATIC_WEIGHTS,
          // NOTE(review): no cache is passed for the fp16 and quantized variants
          // below — presumably unsupported for them; confirm against the operator API.
          NULL,
          &opdata->operator_objects[0]);
        break;
#endif  // XNN_NO_F16_OPERATORS
#ifndef XNN_NO_QS8_OPERATORS
      case xnn_compute_type_qs8:
      {
        // Quantize the float activation bounds into the output's int8 domain.
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
        const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
        status = xnn_create_convolution2d_nhwc_qs8(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          (int8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          values[filter_id].quantization.scale,
          filter_data,
          bias_data,
          (int8_t) output_zero_point,
          output_scale, output_min, output_max,
          node->flags,
          NULL,
          &opdata->operator_objects[0]);
        break;
      }
      case xnn_compute_type_qc8:
      {
        // Same as QS8, but the filter carries per-channel scales.
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
        const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
        status = xnn_create_convolution2d_nhwc_qc8(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          (int8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          values[filter_id].quantization.channelwise_scale,
          filter_data,
          bias_data,
          (int8_t) output_zero_point,
          output_scale, output_min, output_max,
          node->flags,
          NULL,
          &opdata->operator_objects[0]);
        break;
      }
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
      case xnn_compute_type_qu8:
      {
        // Quantize the float activation bounds into the output's uint8 domain.
        const float output_scale = values[output_id].quantization.scale;
        const int32_t output_zero_point = values[output_id].quantization.zero_point;
        const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
        const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
        status = xnn_create_convolution2d_nhwc_qu8(
          node->params.convolution_2d.input_padding_top,
          node->params.convolution_2d.input_padding_right,
          node->params.convolution_2d.input_padding_bottom,
          node->params.convolution_2d.input_padding_left,
          node->params.convolution_2d.kernel_height,
          node->params.convolution_2d.kernel_width,
          node->params.convolution_2d.subsampling_height,
          node->params.convolution_2d.subsampling_width,
          node->params.convolution_2d.dilation_height,
          node->params.convolution_2d.dilation_width,
          node->params.convolution_2d.groups,
          node->params.convolution_2d.group_input_channels,
          node->params.convolution_2d.group_output_channels,
          node->params.convolution_2d.group_input_channels * node->params.convolution_2d.groups /* input_pixel_stride */,
          node->params.convolution_2d.group_output_channels * node->params.convolution_2d.groups /* output_pixel_stride */,
          (uint8_t) values[input_id].quantization.zero_point,
          values[input_id].quantization.scale,
          (uint8_t) values[filter_id].quantization.zero_point,
          values[filter_id].quantization.scale,
          filter_data,
          bias_data,
          (uint8_t) output_zero_point,
          output_scale, output_min, output_max,
          node->flags,
          NULL,
          &opdata->operator_objects[0]);
        break;
      }
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      default:
        XNN_UNREACHABLE;
    }
  }
  if (status == xnn_status_success) {
    // Remember the input geometry and tensor IDs for the setup phase.
    opdata->batch_size = values[input_id].shape.dim[0];
    opdata->input_height = values[input_id].shape.dim[1];
    opdata->input_width = values[input_id].shape.dim[2];
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}
256 
// Binds the runtime input/output blobs to the convolution operator that
// create_convolution_operator instantiated, dispatching on the concrete
// operator type that was created.
static enum xnn_status setup_convolution_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  const uint32_t output_id = opdata->outputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  // Blob data pointers must have been materialized by the runtime.
  const void* input_data = blobs[input_id].data;
  assert(input_data != NULL);
  void* output_data = blobs[output_id].data;
  assert(output_data != NULL);

  xnn_operator_t convolution_op = opdata->operator_objects[0];
  switch (convolution_op->type) {
    case xnn_operator_type_convolution_nchw_f32:
      return xnn_setup_convolution2d_nchw_f32(
        convolution_op,
        opdata->batch_size, opdata->input_height, opdata->input_width,
        input_data, output_data,
        threadpool);
    case xnn_operator_type_convolution_nhwc_f32:
      return xnn_setup_convolution2d_nhwc_f32(
        convolution_op,
        opdata->batch_size, opdata->input_height, opdata->input_width,
        input_data, output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_convolution_nhwc_f16:
      return xnn_setup_convolution2d_nhwc_f16(
        convolution_op,
        opdata->batch_size, opdata->input_height, opdata->input_width,
        input_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_convolution_nhwc_qc8:
      return xnn_setup_convolution2d_nhwc_qc8(
        convolution_op,
        opdata->batch_size, opdata->input_height, opdata->input_width,
        input_data, output_data,
        threadpool);
    case xnn_operator_type_convolution_nhwc_qs8:
      return xnn_setup_convolution2d_nhwc_qs8(
        convolution_op,
        opdata->batch_size, opdata->input_height, opdata->input_width,
        input_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_convolution_nhwc_qu8:
      return xnn_setup_convolution2d_nhwc_qu8(
        convolution_op,
        opdata->batch_size, opdata->input_height, opdata->input_width,
        input_data, output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}
350 
validate_datatypes_with_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype bias_datatype,enum xnn_datatype output_datatype)351 static inline enum xnn_compute_type validate_datatypes_with_bias(
352   enum xnn_datatype input_datatype,
353   enum xnn_datatype filter_datatype,
354   enum xnn_datatype bias_datatype,
355   enum xnn_datatype output_datatype)
356 {
357   switch (filter_datatype) {
358     case xnn_datatype_fp32:
359       if (input_datatype == xnn_datatype_fp32 &&
360           bias_datatype == xnn_datatype_fp32 &&
361           output_datatype == xnn_datatype_fp32)
362       {
363         return xnn_compute_type_fp32;
364       }
365       break;
366 #ifndef XNN_NO_QS8_OPERATORS
367     case xnn_datatype_qint8:
368       if (input_datatype == xnn_datatype_qint8 &&
369           bias_datatype == xnn_datatype_qint32 &&
370           output_datatype == xnn_datatype_qint8)
371       {
372         return xnn_compute_type_qs8;
373       }
374       break;
375     case xnn_datatype_qcint8:
376       if (input_datatype == xnn_datatype_qint8 &&
377           bias_datatype == xnn_datatype_qcint32 &&
378           output_datatype == xnn_datatype_qint8)
379       {
380         return xnn_compute_type_qc8;
381       }
382       break;
383 #endif  // !defined(XNN_NO_QS8_OPERATORS)
384 #ifndef XNN_NO_QU8_OPERATORS
385     case xnn_datatype_quint8:
386       if (input_datatype == xnn_datatype_quint8 &&
387           bias_datatype == xnn_datatype_qint32 &&
388           output_datatype == xnn_datatype_quint8)
389       {
390         return xnn_compute_type_qu8;
391       }
392       break;
393 #endif  // !defined(XNN_NO_QU8_OPERATORS)
394     default:
395       XNN_UNREACHABLE;
396   }
397   return xnn_compute_type_invalid;
398 }
399 
validate_datatypes_without_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype output_datatype)400 static inline enum xnn_compute_type validate_datatypes_without_bias(
401   enum xnn_datatype input_datatype,
402   enum xnn_datatype filter_datatype,
403   enum xnn_datatype output_datatype)
404 {
405   switch (filter_datatype) {
406     case xnn_datatype_fp32:
407       if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
408         return xnn_compute_type_fp32;
409       }
410       break;
411 #ifndef XNN_NO_QS8_OPERATORS
412     case xnn_datatype_qint8:
413       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
414         return xnn_compute_type_qs8;
415       }
416       break;
417     case xnn_datatype_qcint8:
418       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
419         return xnn_compute_type_qc8;
420       }
421       break;
422 #endif  // !defined(XNN_NO_QS8_OPERATORS)
423 #ifndef XNN_NO_QU8_OPERATORS
424     case xnn_datatype_quint8:
425       if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
426         return xnn_compute_type_qu8;
427       }
428       break;
429 #endif  // !defined(XNN_NO_QU8_OPERATORS)
430     default:
431       XNN_UNREACHABLE;
432   }
433   return xnn_compute_type_invalid;
434 }
435 
// Defines a 2D Convolution node in a subgraph.
//
// Validates the convolution hyper-parameters (kernel, subsampling, dilation,
// groups, channels), the output activation range, the flags, and the input,
// filter, optional bias, and output values (including datatype compatibility
// and per-channel quantization metadata), then appends a node that will be
// instantiated by create_convolution_operator.
//
// bias_id may be XNN_INVALID_VALUE_ID to define a convolution without bias.
// Returns xnn_status_success on success, xnn_status_invalid_parameter (or a
// more specific status) on validation failure, and xnn_status_out_of_memory
// if the node cannot be allocated.
enum xnn_status xnn_define_convolution_2d(
  xnn_subgraph_t subgraph,
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t kernel_height,
  uint32_t kernel_width,
  uint32_t subsampling_height,
  uint32_t subsampling_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  uint32_t groups,
  size_t group_input_channels,
  size_t group_output_channels,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_convolution_2d)) != xnn_status_success) {
    return status;
  }

  if (kernel_width == 0 || kernel_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), kernel_width, kernel_height);
    return xnn_status_invalid_parameter;
  }

  if (subsampling_width == 0 || subsampling_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " subsampling: subsampling dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), subsampling_width, subsampling_height);
    return xnn_status_invalid_parameter;
  }

  if (dilation_width == 0 || dilation_height == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), dilation_width, dilation_height);
    return xnn_status_invalid_parameter;
  }

  if (groups == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 " groups: number of groups must be non-zero",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), groups);
    return xnn_status_invalid_parameter;
  }

  if (group_input_channels == 0) {
    xnn_log_error(
      "failed to define %s operator with %zu input channels per group: number of channels must be non-zero",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), group_input_channels);
    return xnn_status_invalid_parameter;
  }

  if (group_output_channels == 0) {
    xnn_log_error(
      "failed to define %s operator with %zu output channels per group: number of channels must be non-zero",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), group_output_channels);
    return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_convolution_2d, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

  const uint32_t supported_flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
  const uint32_t invalid_flags = flags & ~supported_flags;
  if (invalid_flags != 0) {
    xnn_log_error(
      "failed to define %s operator with 0x%08" PRIx32 " flags: invalid flags 0x%08" PRIx32,
      xnn_node_type_to_string(xnn_node_type_convolution_2d), flags, invalid_flags);
    return xnn_status_invalid_parameter;
  }

  // TensorFlow SAME padding and explicit padding are mutually exclusive.
  const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
  if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && any_padding) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32" padding: "
      "TensorFlow SAME padding can't be combined with explicit padding specification",
      xnn_node_type_to_string(xnn_node_type_convolution_2d),
      input_padding_top, input_padding_left, input_padding_bottom, input_padding_right);
    return xnn_status_invalid_parameter;
  }

  // Convert TensorFlow SAME padding to explicit padding specification whenever possible
  // (only valid for unit stride, where SAME padding is independent of input size).
  if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && (subsampling_height | subsampling_width) == 1) {
    flags &= ~XNN_FLAG_TENSORFLOW_SAME_PADDING;
    const uint32_t padding_height = (kernel_height - 1) * dilation_height;
    const uint32_t padding_width = (kernel_width - 1) * dilation_width;
    input_padding_left = padding_width / 2;
    input_padding_top = padding_height / 2;
    input_padding_right = padding_width - input_padding_left;
    input_padding_bottom = padding_height - input_padding_top;
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_convolution_2d, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_convolution_2d, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  if (filter_id >= subgraph->num_values) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id);
    return xnn_status_invalid_parameter;
  }

  const struct xnn_value* filter_value = &subgraph->values[filter_id];
  if (filter_value->type != xnn_value_type_dense_tensor) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id, filter_value->type);
    return xnn_status_invalid_parameter;
  }

  // The filter must be a static (constant) value: its data is packed at
  // operator-creation time.
  if (filter_value->data == NULL) {
    xnn_log_error(
      "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
      xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id);
    return xnn_status_invalid_parameter;
  }

  switch (filter_value->datatype) {
    case xnn_datatype_fp32:
      break;
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
      // Per-tensor int8 filters must be symmetrically quantized.
      if (filter_value->quantization.zero_point != 0) {
        xnn_log_error(
          "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
          xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id,
          filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
        // Bug fix: previously only logged the error without rejecting the node.
        return xnn_status_invalid_parameter;
      }
      break;
    case xnn_datatype_qcint8:
      break;
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
      break;
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id,
        xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
      return xnn_status_invalid_parameter;
  }

  const struct xnn_value* bias_value = NULL;
  if (bias_id != XNN_INVALID_VALUE_ID) {
    if (bias_id >= subgraph->num_values) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id);
      return xnn_status_invalid_parameter;
    }

    bias_value = &subgraph->values[bias_id];
    if (bias_value->type != xnn_value_type_dense_tensor) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id, bias_value->type);
      return xnn_status_invalid_parameter;
    }

    // Like the filter, the bias must be static.
    if (bias_value->data == NULL) {
      xnn_log_error(
        "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id);
      return xnn_status_invalid_parameter;
    }

    switch (bias_value->datatype) {
      case xnn_datatype_fp32:
#if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
      case xnn_datatype_qint32:
#endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
      case xnn_datatype_qcint32:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
        break;
      default:
        xnn_log_error(
          "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
          xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id,
          xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
        return xnn_status_invalid_parameter;
    }
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_convolution_2d, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_convolution_2d, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

  switch (output_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_QU8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  // Cross-check that the datatype combination maps to a supported compute type.
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  if (bias_value != NULL) {
    compute_type = validate_datatypes_with_bias(
      input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, filter_id, bias_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(bias_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  } else {
    compute_type = validate_datatypes_without_bias(
      input_value->datatype, filter_value->datatype, output_value->datatype);
    if (compute_type == xnn_compute_type_invalid) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
        ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
        xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, filter_id, output_id,
        xnn_datatype_to_string(input_value->datatype),
        xnn_datatype_to_string(filter_value->datatype),
        xnn_datatype_to_string(output_value->datatype));
      return xnn_status_invalid_parameter;
    }
  }

#ifndef XNN_NO_QS8_OPERATORS
  // Per-channel quantization is only supported along dimension 0
  // (output channels) for both the filter and the bias.
  if (filter_value->datatype == xnn_datatype_qcint8) {
    if (filter_value->quantization.channel_dimension != 0) {
      xnn_log_error(
        "failed to define %s operator with filter ID #%" PRIu32 ": invalid channel dimension %zu",
        // Bug fix: previously passed input_id in a message that reports the filter ID.
        xnn_node_type_to_string(xnn_node_type_convolution_2d), filter_id, filter_value->quantization.channel_dimension);
      return xnn_status_invalid_parameter;
    }

    if (bias_value != NULL) {
      assert(bias_value->datatype == xnn_datatype_qcint32);
      if (bias_value->quantization.channel_dimension != 0) {
        xnn_log_error(
          "failed to define %s operator with bias ID #%" PRIu32 ": invalid channel dimension %zu",
          xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id, bias_value->quantization.channel_dimension);
        return xnn_status_invalid_parameter;
      }
    }
  }
#endif  // !defined(XNN_NO_QS8_OPERATORS)

  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_convolution_2d;
  node->compute_type = compute_type;
  node->params.convolution_2d.input_padding_top = input_padding_top;
  node->params.convolution_2d.input_padding_right = input_padding_right;
  node->params.convolution_2d.input_padding_bottom = input_padding_bottom;
  node->params.convolution_2d.input_padding_left = input_padding_left;
  node->params.convolution_2d.kernel_height = kernel_height;
  node->params.convolution_2d.kernel_width = kernel_width;
  node->params.convolution_2d.subsampling_height = subsampling_height;
  node->params.convolution_2d.subsampling_width = subsampling_width;
  node->params.convolution_2d.dilation_height = dilation_height;
  node->params.convolution_2d.dilation_width = dilation_width;
  node->params.convolution_2d.groups = groups;
  node->params.convolution_2d.group_input_channels = group_input_channels;
  node->params.convolution_2d.group_output_channels = group_output_channels;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 2 + (size_t) (bias_id != XNN_INVALID_VALUE_ID);
  node->inputs[0] = input_id;
  node->inputs[1] = filter_id;
  node->inputs[2] = bias_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_convolution_operator;
  node->setup = setup_convolution_operator;

  return xnn_status_success;
}
774