xref: /aosp_15_r20/external/XNNPACK/src/subgraph/depthwise-convolution-2d.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/log.h>
13 #include <xnnpack/operator.h>
14 #include <xnnpack/params.h>
15 #include <xnnpack/requantization.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18 
19 
create_convolution_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_convolution_operator(
21   const struct xnn_node* node,
22   const struct xnn_value* values,
23   size_t num_values,
24   struct xnn_operator_data* opdata,
25   const struct xnn_caches* caches)
26 {
27   assert(node->num_inputs >= 2);
28   assert(node->num_inputs <= 3);
29   const uint32_t input_id = node->inputs[0];
30   assert(input_id != XNN_INVALID_VALUE_ID);
31   assert(input_id < num_values);
32   const uint32_t filter_id = node->inputs[1];
33   assert(filter_id != XNN_INVALID_VALUE_ID);
34   assert(filter_id < num_values);
35 
36   assert(node->num_outputs == 1);
37   const uint32_t output_id = node->outputs[0];
38   assert(output_id != XNN_INVALID_VALUE_ID);
39   assert(output_id < num_values);
40 
41   const void* filter_data = values[filter_id].data;
42   assert(filter_data != NULL);
43 
44   const void* bias_data = NULL;
45   if (node->num_inputs > 2) {
46     const uint32_t bias_id = node->inputs[2];
47     assert(bias_id != XNN_INVALID_VALUE_ID);
48     assert(bias_id < num_values);
49 
50     bias_data = values[bias_id].data;
51     assert(bias_data != NULL);
52   }
53 
54   enum xnn_status status;
55   if (values[output_id].layout == xnn_layout_type_nchw) {
56     assert(values[input_id].layout == xnn_layout_type_nchw);
57     assert(node->compute_type == xnn_compute_type_fp32);
58     status = xnn_create_convolution2d_nchw_f32(
59       node->params.depthwise_convolution_2d.input_padding_top,
60       node->params.depthwise_convolution_2d.input_padding_right,
61       node->params.depthwise_convolution_2d.input_padding_bottom,
62       node->params.depthwise_convolution_2d.input_padding_left,
63       node->params.depthwise_convolution_2d.kernel_height,
64       node->params.depthwise_convolution_2d.kernel_width,
65       node->params.depthwise_convolution_2d.subsampling_height,
66       node->params.depthwise_convolution_2d.subsampling_width,
67       node->params.depthwise_convolution_2d.dilation_height,
68       node->params.depthwise_convolution_2d.dilation_width,
69       node->params.depthwise_convolution_2d.input_channels /* groups */,
70       1 /* group_input_channels */,
71       node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
72       node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
73       node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
74       filter_data,
75       bias_data,
76       node->activation.output_min,
77       node->activation.output_max,
78       node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
79       caches,
80       &opdata->operator_objects[0]);
81   } else {
82     assert(values[input_id].layout == xnn_layout_type_nhwc);
83     assert(values[output_id].layout == xnn_layout_type_nhwc);
84     switch (node->compute_type) {
85       case xnn_compute_type_fp32:
86         status = xnn_create_convolution2d_nhwc_f32(
87           node->params.depthwise_convolution_2d.input_padding_top,
88           node->params.depthwise_convolution_2d.input_padding_right,
89           node->params.depthwise_convolution_2d.input_padding_bottom,
90           node->params.depthwise_convolution_2d.input_padding_left,
91           node->params.depthwise_convolution_2d.kernel_height,
92           node->params.depthwise_convolution_2d.kernel_width,
93           node->params.depthwise_convolution_2d.subsampling_height,
94           node->params.depthwise_convolution_2d.subsampling_width,
95           node->params.depthwise_convolution_2d.dilation_height,
96           node->params.depthwise_convolution_2d.dilation_width,
97           node->params.depthwise_convolution_2d.input_channels /* groups */,
98           1 /* group_input_channels */,
99           node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
100           node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
101           node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
102           filter_data,
103           bias_data,
104           node->activation.output_min,
105           node->activation.output_max,
106           node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
107           NULL,
108           &opdata->operator_objects[0]);
109         break;
110 #ifndef XNN_NO_F16_OPERATORS
111       case xnn_compute_type_fp16:
112         status = xnn_create_convolution2d_nhwc_f16(
113           node->params.depthwise_convolution_2d.input_padding_top,
114           node->params.depthwise_convolution_2d.input_padding_right,
115           node->params.depthwise_convolution_2d.input_padding_bottom,
116           node->params.depthwise_convolution_2d.input_padding_left,
117           node->params.depthwise_convolution_2d.kernel_height,
118           node->params.depthwise_convolution_2d.kernel_width,
119           node->params.depthwise_convolution_2d.subsampling_height,
120           node->params.depthwise_convolution_2d.subsampling_width,
121           node->params.depthwise_convolution_2d.dilation_height,
122           node->params.depthwise_convolution_2d.dilation_width,
123           node->params.depthwise_convolution_2d.input_channels /* groups */,
124           1 /* group_input_channels */,
125           node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
126           node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
127           node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
128           filter_data,
129           bias_data,
130           node->activation.output_min,
131           node->activation.output_max,
132           node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION | XNN_FLAG_FP32_STATIC_WEIGHTS,
133           NULL,
134           &opdata->operator_objects[0]);
135         break;
136 #endif  // XNN_NO_F16_OPERATORS
137 #ifndef XNN_NO_QS8_OPERATORS
138       case xnn_compute_type_qs8:
139       {
140         const float output_scale = values[output_id].quantization.scale;
141         const int32_t output_zero_point = values[output_id].quantization.zero_point;
142         const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
143         const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
144         status = xnn_create_convolution2d_nhwc_qs8(
145           node->params.depthwise_convolution_2d.input_padding_top,
146           node->params.depthwise_convolution_2d.input_padding_right,
147           node->params.depthwise_convolution_2d.input_padding_bottom,
148           node->params.depthwise_convolution_2d.input_padding_left,
149           node->params.depthwise_convolution_2d.kernel_height,
150           node->params.depthwise_convolution_2d.kernel_width,
151           node->params.depthwise_convolution_2d.subsampling_height,
152           node->params.depthwise_convolution_2d.subsampling_width,
153           node->params.depthwise_convolution_2d.dilation_height,
154           node->params.depthwise_convolution_2d.dilation_width,
155           node->params.depthwise_convolution_2d.input_channels /* groups */,
156           1 /* group_input_channels */,
157           node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
158           node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
159           node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
160           (int8_t) values[input_id].quantization.zero_point,
161           values[input_id].quantization.scale,
162           values[filter_id].quantization.scale,
163           filter_data,
164           bias_data,
165           (int8_t) output_zero_point,
166           output_scale, output_min, output_max,
167           node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
168           NULL,
169           &opdata->operator_objects[0]);
170         break;
171       }
172       case xnn_compute_type_qc8:
173       {
174         const float output_scale = values[output_id].quantization.scale;
175         const int32_t output_zero_point = values[output_id].quantization.zero_point;
176         const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
177         const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
178         status = xnn_create_convolution2d_nhwc_qc8(
179           node->params.depthwise_convolution_2d.input_padding_top,
180           node->params.depthwise_convolution_2d.input_padding_right,
181           node->params.depthwise_convolution_2d.input_padding_bottom,
182           node->params.depthwise_convolution_2d.input_padding_left,
183           node->params.depthwise_convolution_2d.kernel_height,
184           node->params.depthwise_convolution_2d.kernel_width,
185           node->params.depthwise_convolution_2d.subsampling_height,
186           node->params.depthwise_convolution_2d.subsampling_width,
187           node->params.depthwise_convolution_2d.dilation_height,
188           node->params.depthwise_convolution_2d.dilation_width,
189           node->params.depthwise_convolution_2d.input_channels /* groups */,
190           1 /* group_input_channels */,
191           node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
192           node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
193           node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
194           (int8_t) values[input_id].quantization.zero_point,
195           values[input_id].quantization.scale,
196           values[filter_id].quantization.channelwise_scale,
197           filter_data,
198           bias_data,
199           (int8_t) output_zero_point,
200           output_scale, output_min, output_max,
201           node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
202           NULL,
203           &opdata->operator_objects[0]);
204         break;
205       }
206 #endif  // !defined(XNN_NO_QS8_OPERATORS)
207 #ifndef XNN_NO_QU8_OPERATORS
208       case xnn_compute_type_qu8:
209       {
210         const float output_scale = values[output_id].quantization.scale;
211         const int32_t output_zero_point = values[output_id].quantization.zero_point;
212         const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
213         const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
214         status = xnn_create_convolution2d_nhwc_qu8(
215           node->params.depthwise_convolution_2d.input_padding_top,
216           node->params.depthwise_convolution_2d.input_padding_right,
217           node->params.depthwise_convolution_2d.input_padding_bottom,
218           node->params.depthwise_convolution_2d.input_padding_left,
219           node->params.depthwise_convolution_2d.kernel_height,
220           node->params.depthwise_convolution_2d.kernel_width,
221           node->params.depthwise_convolution_2d.subsampling_height,
222           node->params.depthwise_convolution_2d.subsampling_width,
223           node->params.depthwise_convolution_2d.dilation_height,
224           node->params.depthwise_convolution_2d.dilation_width,
225           node->params.depthwise_convolution_2d.input_channels /* groups */,
226           1 /* group_input_channels */,
227           node->params.depthwise_convolution_2d.depth_multiplier /* group_output_channels */,
228           node->params.depthwise_convolution_2d.input_channels /* input_channel_stride */,
229           node->params.depthwise_convolution_2d.input_channels * node->params.depthwise_convolution_2d.depth_multiplier /* output_channel_stride */,
230           (uint8_t) values[input_id].quantization.zero_point,
231           values[input_id].quantization.scale,
232           (uint8_t) values[filter_id].quantization.zero_point,
233           values[filter_id].quantization.scale,
234           filter_data,
235           bias_data,
236           (uint8_t) output_zero_point,
237           output_scale, output_min, output_max,
238           node->flags | XNN_FLAG_DEPTHWISE_CONVOLUTION,
239           NULL,
240           &opdata->operator_objects[0]);
241         break;
242       }
243 #endif  // !defined(XNN_NO_QU8_OPERATORS)
244       default:
245         XNN_UNREACHABLE;
246     }
247   }
248   if (status == xnn_status_success) {
249     opdata->batch_size = values[input_id].shape.dim[0];
250     opdata->input_height = values[input_id].shape.dim[1];
251     opdata->input_width = values[input_id].shape.dim[2];
252     opdata->inputs[0] = input_id;
253     opdata->outputs[0] = output_id;
254   }
255   return status;
256 }
257 
// Binds runtime input/output buffers to the previously-created convolution
// operator and hands it the cached batch/height/width, dispatching on the
// concrete operator type stored in opdata->operator_objects[0].
//
// Returns the status of the underlying xnn_setup_convolution2d_* call.
// (Dead `break` statements after each `return` in the original were removed.)
static enum xnn_status setup_convolution_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  // Blob data must already be allocated by the runtime before setup is called.
  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

  switch (opdata->operator_objects[0]->type) {
    case xnn_operator_type_convolution_nchw_f32:
      return xnn_setup_convolution2d_nchw_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convolution_nhwc_f32:
      return xnn_setup_convolution2d_nhwc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_convolution_nhwc_f16:
      return xnn_setup_convolution2d_nhwc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
#ifndef XNN_NO_QS8_OPERATORS
    case xnn_operator_type_convolution_nhwc_qc8:
      return xnn_setup_convolution2d_nhwc_qc8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
    case xnn_operator_type_convolution_nhwc_qs8:
      return xnn_setup_convolution2d_nhwc_qs8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QS8_OPERATORS)
#ifndef XNN_NO_QU8_OPERATORS
    case xnn_operator_type_convolution_nhwc_qu8:
      return xnn_setup_convolution2d_nhwc_qu8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_QU8_OPERATORS)
    default:
      // Operator type is fixed at creation time; any other value is a logic error.
      XNN_UNREACHABLE;
  }
}
351 
validate_datatypes_with_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype bias_datatype,enum xnn_datatype output_datatype)352 static inline enum xnn_compute_type validate_datatypes_with_bias(
353   enum xnn_datatype input_datatype,
354   enum xnn_datatype filter_datatype,
355   enum xnn_datatype bias_datatype,
356   enum xnn_datatype output_datatype)
357 {
358   switch (filter_datatype) {
359     case xnn_datatype_fp32:
360       if (input_datatype == xnn_datatype_fp32 &&
361           bias_datatype == xnn_datatype_fp32 &&
362           output_datatype == xnn_datatype_fp32)
363       {
364         return xnn_compute_type_fp32;
365       }
366       break;
367 #ifndef XNN_NO_QS8_OPERATORS
368     case xnn_datatype_qint8:
369       if (input_datatype == xnn_datatype_qint8 &&
370           bias_datatype == xnn_datatype_qint32 &&
371           output_datatype == xnn_datatype_qint8)
372       {
373         return xnn_compute_type_qs8;
374       }
375       break;
376     case xnn_datatype_qcint8:
377       if (input_datatype == xnn_datatype_qint8 &&
378           bias_datatype == xnn_datatype_qcint32 &&
379           output_datatype == xnn_datatype_qint8)
380       {
381         return xnn_compute_type_qc8;
382       }
383       break;
384 #endif  // !defined(XNN_NO_QS8_OPERATORS)
385 #ifndef XNN_NO_QU8_OPERATORS
386     case xnn_datatype_quint8:
387       if (input_datatype == xnn_datatype_quint8 &&
388           bias_datatype == xnn_datatype_qint32 &&
389           output_datatype == xnn_datatype_quint8)
390       {
391         return xnn_compute_type_qu8;
392       }
393       break;
394 #endif  // !defined(XNN_NO_QU8_OPERATORS)
395     default:
396       XNN_UNREACHABLE;
397   }
398   return xnn_compute_type_invalid;
399 }
400 
validate_datatypes_without_bias(enum xnn_datatype input_datatype,enum xnn_datatype filter_datatype,enum xnn_datatype output_datatype)401 static inline enum xnn_compute_type validate_datatypes_without_bias(
402   enum xnn_datatype input_datatype,
403   enum xnn_datatype filter_datatype,
404   enum xnn_datatype output_datatype)
405 {
406   switch (filter_datatype) {
407     case xnn_datatype_fp32:
408       if (input_datatype == xnn_datatype_fp32 && output_datatype == xnn_datatype_fp32) {
409         return xnn_compute_type_fp32;
410       }
411       break;
412 #ifndef XNN_NO_QS8_OPERATORS
413     case xnn_datatype_qint8:
414       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
415         return xnn_compute_type_qs8;
416       }
417       break;
418     case xnn_datatype_qcint8:
419       if (input_datatype == xnn_datatype_qint8 && output_datatype == xnn_datatype_qint8) {
420         return xnn_compute_type_qc8;
421       }
422       break;
423 #endif  // !defined(XNN_NO_QS8_OPERATORS)
424 #ifndef XNN_NO_QU8_OPERATORS
425     case xnn_datatype_quint8:
426       if (input_datatype == xnn_datatype_quint8 && output_datatype == xnn_datatype_quint8) {
427         return xnn_compute_type_qu8;
428       }
429       break;
430 #endif  // !defined(XNN_NO_QU8_OPERATORS)
431     default:
432       XNN_UNREACHABLE;
433   }
434   return xnn_compute_type_invalid;
435 }
436 
xnn_define_depthwise_convolution_2d(xnn_subgraph_t subgraph,uint32_t input_padding_top,uint32_t input_padding_right,uint32_t input_padding_bottom,uint32_t input_padding_left,uint32_t kernel_height,uint32_t kernel_width,uint32_t subsampling_height,uint32_t subsampling_width,uint32_t dilation_height,uint32_t dilation_width,uint32_t depth_multiplier,size_t input_channels,float output_min,float output_max,uint32_t input_id,uint32_t filter_id,uint32_t bias_id,uint32_t output_id,uint32_t flags)437 enum xnn_status xnn_define_depthwise_convolution_2d(
438   xnn_subgraph_t subgraph,
439   uint32_t input_padding_top,
440   uint32_t input_padding_right,
441   uint32_t input_padding_bottom,
442   uint32_t input_padding_left,
443   uint32_t kernel_height,
444   uint32_t kernel_width,
445   uint32_t subsampling_height,
446   uint32_t subsampling_width,
447   uint32_t dilation_height,
448   uint32_t dilation_width,
449   uint32_t depth_multiplier,
450   size_t input_channels,
451   float output_min,
452   float output_max,
453   uint32_t input_id,
454   uint32_t filter_id,
455   uint32_t bias_id,
456   uint32_t output_id,
457   uint32_t flags)
458 {
459   enum xnn_status status;
460   if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_depthwise_convolution_2d)) != xnn_status_success) {
461     return status;
462   }
463 
464   if (kernel_width == 0 || kernel_height == 0) {
465     xnn_log_error(
466       "failed to define %s operator with %" PRIu32 "x%" PRIu32 " kernel: kernel dimensions must be non-zero",
467       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), kernel_width, kernel_height);
468     return xnn_status_invalid_parameter;
469   }
470 
471   if (subsampling_width == 0 || subsampling_height == 0) {
472     xnn_log_error(
473       "failed to define %s operator with %" PRIu32 "x%" PRIu32 " subsampling: subsampling dimensions must be non-zero",
474       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), subsampling_width, subsampling_height);
475     return xnn_status_invalid_parameter;
476   }
477 
478   if (dilation_width == 0 || dilation_height == 0) {
479     xnn_log_error(
480       "failed to define %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
481       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), dilation_width, dilation_height);
482     return xnn_status_invalid_parameter;
483   }
484 
485   if (depth_multiplier == 0) {
486     xnn_log_error(
487       "failed to define %s operator with %" PRIu32 " depth multiplier: depth multiplier must be non-zero",
488       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), depth_multiplier);
489     return xnn_status_invalid_parameter;
490   }
491 
492   if (input_channels == 0) {
493     xnn_log_error(
494       "failed to define %s operator with %zu input channels: number of channels must be non-zero",
495       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), input_channels);
496     return xnn_status_invalid_parameter;
497   }
498 
499   status = xnn_subgraph_check_output_min_max(xnn_node_type_depthwise_convolution_2d, output_min, output_max);
500   if (status != xnn_status_success) {
501     return status;
502   }
503 
504   const uint32_t supported_flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
505   const uint32_t invalid_flags = flags & ~supported_flags;
506   if (invalid_flags != 0) {
507     xnn_log_error(
508       "failed to define %s operator with 0x%08" PRIx32 " flags: invalid flags 0x%08" PRIx32,
509       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), flags, invalid_flags);
510     return xnn_status_invalid_parameter;
511   }
512 
513   const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
514   if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && any_padding) {
515     xnn_log_error(
516       "failed to define %s operator with %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32" padding: "
517       "TensorFlow SAME padding can't be combined with explicit padding specification",
518       xnn_node_type_to_string(xnn_node_type_convolution_2d),
519       input_padding_top, input_padding_left, input_padding_bottom, input_padding_right);
520     return xnn_status_invalid_parameter;
521   }
522 
523   // Convert TensorFlow SAME padding to explicit padding specification whenever possible
524   if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0 && (subsampling_height | subsampling_width) == 1) {
525     flags &= ~XNN_FLAG_TENSORFLOW_SAME_PADDING;
526     const uint32_t padding_height = (kernel_height - 1) * dilation_height;
527     const uint32_t padding_width = (kernel_width - 1) * dilation_width;
528     input_padding_left = padding_width / 2;
529     input_padding_top = padding_height / 2;
530     input_padding_right = padding_width - input_padding_left;
531     input_padding_bottom = padding_height - input_padding_top;
532   }
533 
534   if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_depthwise_convolution_2d, input_id, subgraph->num_values)) !=
535       xnn_status_success) {
536     return status;
537   }
538 
539   const struct xnn_value* input_value = &subgraph->values[input_id];
540   status = xnn_subgraph_check_input_type_dense(xnn_node_type_depthwise_convolution_2d, input_id, input_value);
541   if (status != xnn_status_success) {
542     return status;
543   }
544 
545   switch (input_value->datatype) {
546     case xnn_datatype_fp32:
547 #ifndef XNN_NO_QS8_OPERATORS
548     case xnn_datatype_qint8:
549 #endif  // !defined(XNN_NO_QS8_OPERATORS)
550 #ifndef XNN_NO_QU8_OPERATORS
551     case xnn_datatype_quint8:
552 #endif  // !defined(XNN_NO_QU8_OPERATORS)
553       break;
554     default:
555       xnn_log_error(
556         "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
557         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), input_id,
558         xnn_datatype_to_string(input_value->datatype), input_value->datatype);
559       return xnn_status_invalid_parameter;
560   }
561 
562   if (filter_id >= subgraph->num_values) {
563     xnn_log_error(
564       "failed to define %s operator with filter ID #%" PRIu32 ": invalid Value ID",
565       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), filter_id);
566     return xnn_status_invalid_parameter;
567   }
568 
569   const struct xnn_value* filter_value = &subgraph->values[filter_id];
570   if (filter_value->type != xnn_value_type_dense_tensor) {
571     xnn_log_error(
572       "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
573       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), filter_id, filter_value->type);
574     return xnn_status_invalid_parameter;
575   }
576 
577   if (filter_value->data == NULL) {
578     xnn_log_error(
579       "failed to define %s operator with filter ID #%" PRIu32 ": non-static Value",
580       xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), filter_id);
581     return xnn_status_invalid_parameter;
582   }
583 
584   switch (filter_value->datatype) {
585     case xnn_datatype_fp32:
586       break;
587 #ifndef XNN_NO_QS8_OPERATORS
588     case xnn_datatype_qint8:
589       if (filter_value->quantization.zero_point != 0) {
590         xnn_log_error(
591           "failed to define %s operator with filter ID #%" PRIu32 ": unsupported quantization zero point %" PRId32 " for datatype %s",
592           xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), filter_id,
593           filter_value->quantization.zero_point, xnn_datatype_to_string(filter_value->datatype));
594       }
595       break;
596     case xnn_datatype_qcint8:
597       break;
598 #endif  // !defined(XNN_NO_QS8_OPERATORS)
599 #ifndef XNN_NO_QU8_OPERATORS
600     case xnn_datatype_quint8:
601       break;
602 #endif  // !defined(XNN_NO_QU8_OPERATORS)
603     default:
604       xnn_log_error(
605         "failed to define %s operator with filter ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
606         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), filter_id,
607         xnn_datatype_to_string(filter_value->datatype), filter_value->datatype);
608       return xnn_status_invalid_parameter;
609   }
610 
611   const struct xnn_value* bias_value = NULL;
612   if (bias_id != XNN_INVALID_VALUE_ID) {
613     if (bias_id >= subgraph->num_values) {
614       xnn_log_error(
615         "failed to define %s operator with bias ID #%" PRIu32 ": invalid Value ID",
616         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), bias_id);
617       return xnn_status_invalid_parameter;
618     }
619 
620     bias_value = &subgraph->values[bias_id];
621     if (bias_value->type != xnn_value_type_dense_tensor) {
622       xnn_log_error(
623         "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value type %d (expected dense tensor)",
624         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), bias_id, bias_value->type);
625       return xnn_status_invalid_parameter;
626     }
627 
628     if (bias_value->data == NULL) {
629       xnn_log_error(
630         "failed to define %s operator with bias ID #%" PRIu32 ": non-static Value",
631         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), bias_id);
632       return xnn_status_invalid_parameter;
633     }
634 
635     switch (bias_value->datatype) {
636       case xnn_datatype_fp32:
637 #if !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
638       case xnn_datatype_qint32:
639 #endif  // !defined(XNN_NO_QS8_OPERATORS) || !defined(XNN_NO_QU8_OPERATORS)
640 #ifndef XNN_NO_QS8_OPERATORS
641       case xnn_datatype_qcint32:
642 #endif  // !defined(XNN_NO_QS8_OPERATORS)
643         break;
644       default:
645         xnn_log_error(
646           "failed to define %s operator with bias ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
647           xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), bias_id,
648           xnn_datatype_to_string(bias_value->datatype), bias_value->datatype);
649         return xnn_status_invalid_parameter;
650     }
651   }
652 
653   status = xnn_subgraph_check_output_node_id(xnn_node_type_depthwise_convolution_2d, output_id, subgraph->num_values);
654   if (status != xnn_status_success) {
655     return status;
656   }
657 
658   const struct xnn_value* output_value = &subgraph->values[output_id];
659   status = xnn_subgraph_check_output_type_dense(xnn_node_type_depthwise_convolution_2d, output_id, output_value);
660   if (status != xnn_status_success) {
661     return status;
662   }
663 
664   switch (output_value->datatype) {
665     case xnn_datatype_fp32:
666 #ifndef XNN_NO_QS8_OPERATORS
667     case xnn_datatype_qint8:
668 #endif  // !defined(XNN_NO_QS8_OPERATORS)
669 #ifndef XNN_NO_QU8_OPERATORS
670     case xnn_datatype_quint8:
671 #endif  // !defined(XNN_NO_QU8_OPERATORS)
672       break;
673     default:
674       xnn_log_error(
675         "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
676         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), output_id,
677         xnn_datatype_to_string(output_value->datatype), output_value->datatype);
678       return xnn_status_invalid_parameter;
679   }
680 
681   enum xnn_compute_type compute_type = xnn_compute_type_invalid;
682   if (bias_value != NULL) {
683     compute_type = validate_datatypes_with_bias(
684       input_value->datatype, filter_value->datatype, bias_value->datatype, output_value->datatype);
685     if (compute_type == xnn_compute_type_invalid) {
686       xnn_log_error(
687         "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", bias ID #%" PRIu32 ", and output ID #%" PRIu32
688         ": mismatching datatypes across input (%s), filter (%s), bias (%s), and output (%s)",
689         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), input_id, filter_id, bias_id, output_id,
690         xnn_datatype_to_string(input_value->datatype),
691         xnn_datatype_to_string(filter_value->datatype),
692         xnn_datatype_to_string(bias_value->datatype),
693         xnn_datatype_to_string(output_value->datatype));
694       return xnn_status_invalid_parameter;
695     }
696   } else {
697     compute_type = validate_datatypes_without_bias(input_value->datatype, filter_value->datatype, output_value->datatype);
698     if (compute_type == xnn_compute_type_invalid) {
699       xnn_log_error(
700         "failed to define %s operator with input ID #%" PRIu32 ", filter ID #%" PRIu32 ", and output ID #%" PRIu32
701         ": mismatching datatypes across input (%s), filter (%s), and output (%s)",
702         xnn_node_type_to_string(xnn_node_type_depthwise_convolution_2d), input_id, filter_id, output_id,
703         xnn_datatype_to_string(input_value->datatype),
704         xnn_datatype_to_string(filter_value->datatype),
705         xnn_datatype_to_string(output_value->datatype));
706       return xnn_status_invalid_parameter;
707     }
708   }
709 
710 #ifndef XNN_NO_QS8_OPERATORS
711   if (filter_value->datatype == xnn_datatype_qcint8) {
712     if (filter_value->quantization.channel_dimension != filter_value->shape.num_dims - 1) {
713       xnn_log_error(
714         "failed to define %s operator with filter ID #%" PRIu32 ": invalid channel dimension %zu",
715         xnn_node_type_to_string(xnn_node_type_convolution_2d), input_id, filter_value->quantization.channel_dimension);
716       return xnn_status_invalid_parameter;
717     }
718 
719     if (bias_value != NULL) {
720       assert(bias_value->datatype == xnn_datatype_qcint32);
721       if (bias_value->quantization.channel_dimension != 0) {
722         xnn_log_error(
723           "failed to define %s operator with bias ID #%" PRIu32 ": invalid channel dimension %zu",
724           xnn_node_type_to_string(xnn_node_type_convolution_2d), bias_id, bias_value->quantization.channel_dimension);
725         return xnn_status_invalid_parameter;
726       }
727     }
728   }
729 #endif  // !defined(XNN_NO_QS8_OPERATORS)
730 
731   struct xnn_node* node = xnn_subgraph_new_node(subgraph);
732   if (node == NULL) {
733     return xnn_status_out_of_memory;
734   }
735 
736   node->type = xnn_node_type_depthwise_convolution_2d;
737   node->compute_type = compute_type;
738   node->params.depthwise_convolution_2d.input_padding_top = input_padding_top;
739   node->params.depthwise_convolution_2d.input_padding_right = input_padding_right;
740   node->params.depthwise_convolution_2d.input_padding_bottom = input_padding_bottom;
741   node->params.depthwise_convolution_2d.input_padding_left = input_padding_left;
742   node->params.depthwise_convolution_2d.kernel_height = kernel_height;
743   node->params.depthwise_convolution_2d.kernel_width = kernel_width;
744   node->params.depthwise_convolution_2d.subsampling_height = subsampling_height;
745   node->params.depthwise_convolution_2d.subsampling_width = subsampling_width;
746   node->params.depthwise_convolution_2d.dilation_height = dilation_height;
747   node->params.depthwise_convolution_2d.dilation_width = dilation_width;
748   node->params.depthwise_convolution_2d.depth_multiplier = depth_multiplier;
749   node->params.depthwise_convolution_2d.input_channels = input_channels;
750   node->activation.output_min = output_min;
751   node->activation.output_max = output_max;
752   node->num_inputs = 2 + (size_t) (bias_id != XNN_INVALID_VALUE_ID);
753   node->inputs[0] = input_id;
754   node->inputs[1] = filter_id;
755   node->inputs[2] = bias_id;
756   node->num_outputs = 1;
757   node->outputs[0] = output_id;
758   node->flags = flags;
759 
760   node->create = create_convolution_operator;
761   node->setup = setup_convolution_operator;
762 
763   return xnn_status_success;
764 };
765