// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack.h>
#include <xnnpack/log.h>
#include <xnnpack/operator.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>
#include <xnnpack/subgraph.h>
#include <xnnpack/subgraph-validation.h>


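// Creates the backing NHWC max-pooling operator for a subgraph node, picking
// the F16, F32, S8, or U8 variant according to the node's compute type, and
// records the input geometry needed later by setup.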
static enum xnn_status create_max_pooling_operator(
  const struct xnn_node* node,
  const struct xnn_value* values,
  size_t num_values,
  struct xnn_operator_data* opdata,
  const struct xnn_caches* caches)
{
  assert(node->num_inputs == 1);
  const uint32_t input_id = node->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_values);

  assert(node->num_outputs == 1);
  const uint32_t output_id = node->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_values);

  const size_t channel_dim = values[input_id].shape.dim[3];
  assert(channel_dim == values[output_id].shape.dim[3]);

  enum xnn_status status;
  switch (node->compute_type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_compute_type_fp16:
      status = xnn_create_max_pooling2d_nhwc_f16(
        node->params.pooling_2d.padding_top,
        node->params.pooling_2d.padding_right,
        node->params.pooling_2d.padding_bottom,
        node->params.pooling_2d.padding_left,
        node->params.pooling_2d.pooling_height,
        node->params.pooling_2d.pooling_width,
        node->params.pooling_2d.stride_height,
        node->params.pooling_2d.stride_width,
        node->params.pooling_2d.dilation_height,
        node->params.pooling_2d.dilation_width,
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_compute_type_fp32:
      status = xnn_create_max_pooling2d_nhwc_f32(
        node->params.pooling_2d.padding_top,
        node->params.pooling_2d.padding_right,
        node->params.pooling_2d.padding_bottom,
        node->params.pooling_2d.padding_left,
        node->params.pooling_2d.pooling_height,
        node->params.pooling_2d.pooling_width,
        node->params.pooling_2d.stride_height,
        node->params.pooling_2d.stride_width,
        node->params.pooling_2d.dilation_height,
        node->params.pooling_2d.dilation_width,
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        node->activation.output_min,
        node->activation.output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_compute_type_qs8:
    {
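      // Convert the node's float clipping bounds into the output tensor's
      // signed 8-bit quantized domain before creating the operator.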
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const int8_t output_min = xnn_qs8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const int8_t output_max = xnn_qs8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_max_pooling2d_nhwc_s8(
        node->params.pooling_2d.padding_top,
        node->params.pooling_2d.padding_right,
        node->params.pooling_2d.padding_bottom,
        node->params.pooling_2d.padding_left,
        node->params.pooling_2d.pooling_height,
        node->params.pooling_2d.pooling_width,
        node->params.pooling_2d.stride_height,
        node->params.pooling_2d.stride_width,
        node->params.pooling_2d.dilation_height,
        node->params.pooling_2d.dilation_width,
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_compute_type_qu8:
    {
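      // Same conversion as the S8 path, but into the unsigned 8-bit domain.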
      const float output_scale = values[output_id].quantization.scale;
      const int32_t output_zero_point = values[output_id].quantization.zero_point;
      const uint8_t output_min = xnn_qu8_quantize(node->activation.output_min, output_scale, output_zero_point);
      const uint8_t output_max = xnn_qu8_quantize(node->activation.output_max, output_scale, output_zero_point);
      status = xnn_create_max_pooling2d_nhwc_u8(
        node->params.pooling_2d.padding_top,
        node->params.pooling_2d.padding_right,
        node->params.pooling_2d.padding_bottom,
        node->params.pooling_2d.padding_left,
        node->params.pooling_2d.pooling_height,
        node->params.pooling_2d.pooling_width,
        node->params.pooling_2d.stride_height,
        node->params.pooling_2d.stride_width,
        node->params.pooling_2d.dilation_height,
        node->params.pooling_2d.dilation_width,
        channel_dim /* channels */, channel_dim /* input stride */, channel_dim /* output stride */,
        output_min,
        output_max,
        node->flags,
        &opdata->operator_objects[0]);
      break;
    }
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
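  // On success, remember the input geometry and tensor IDs; setup needs them
  // to bind the operator to concrete blobs.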
  if (status == xnn_status_success) {
    opdata->batch_size = values[input_id].shape.dim[0];
    opdata->input_height = values[input_id].shape.dim[1];
    opdata->input_width = values[input_id].shape.dim[2];
    opdata->inputs[0] = input_id;
    opdata->outputs[0] = output_id;
  }
  return status;
}

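// Binds a previously created max-pooling operator to the concrete input and
// output blobs and prepares it to run on the given threadpool.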
static enum xnn_status setup_max_pooling_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  const uint32_t input_id = opdata->inputs[0];
  assert(input_id != XNN_INVALID_VALUE_ID);
  assert(input_id < num_blobs);

  const uint32_t output_id = opdata->outputs[0];
  assert(output_id != XNN_INVALID_VALUE_ID);
  assert(output_id < num_blobs);

  const struct xnn_blob* input_blob = blobs + input_id;
  const void* input_data = input_blob->data;
  assert(input_data != NULL);

  const struct xnn_blob* output_blob = blobs + output_id;
  void* output_data = output_blob->data;
  assert(output_data != NULL);

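  // Dispatch on the operator type that was instantiated at creation time.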
  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_max_pooling_nhwc_f16:
      return xnn_setup_max_pooling2d_nhwc_f16(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_max_pooling_nhwc_f32:
      return xnn_setup_max_pooling2d_nhwc_f32(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#ifndef XNN_NO_S8_OPERATORS
    case xnn_operator_type_max_pooling_nhwc_s8:
      return xnn_setup_max_pooling2d_nhwc_s8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_operator_type_max_pooling_nhwc_u8:
      return xnn_setup_max_pooling2d_nhwc_u8(
        opdata->operator_objects[0],
        opdata->batch_size,
        opdata->input_height,
        opdata->input_width,
        input_data,
        output_data,
        threadpool);
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      XNN_UNREACHABLE;
  }
}

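// Illustrative sketch (not part of the original source): how a client might
// add a max-pooling node to a subgraph through this API. The subgraph handle,
// value IDs, and parameter choices below are hypothetical.
//
//   enum xnn_status status = xnn_define_max_pooling_2d(
//     subgraph,
//     /*input_padding_top=*/0, /*input_padding_right=*/0,
//     /*input_padding_bottom=*/0, /*input_padding_left=*/0,
//     /*pooling_height=*/2, /*pooling_width=*/2,
//     /*stride_height=*/2, /*stride_width=*/2,
//     /*dilation_height=*/1, /*dilation_width=*/1,
//     /*output_min=*/-INFINITY, /*output_max=*/INFINITY,
//     input_id, output_id, /*flags=*/0);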
enum xnn_status xnn_define_max_pooling_2d(
  xnn_subgraph_t subgraph,
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status;
  if ((status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_max_pooling_2d)) != xnn_status_success) {
    return status;
  }

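  // Validate the pooling geometry: the pooling window, stride, and dilation
  // must all be non-zero, a 1x1 window is rejected, and the stride may not
  // exceed the pooling window in either dimension.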
  const uint32_t pooling_size = pooling_height * pooling_width;
  if (pooling_size == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " pooling size: "
      "pooling size dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_max_pooling_2d), pooling_width, pooling_height);
    return xnn_status_invalid_parameter;
  }

  if (pooling_size == 1) {
    xnn_log_error(
      "failed to define %s operator with 1 pooling element: 1x1 pooling is meaningless",
      xnn_node_type_to_string(xnn_node_type_max_pooling_2d));
    return xnn_status_invalid_parameter;
  }

  if (stride_height == 0 || stride_width == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " stride: stride dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_max_pooling_2d), stride_width, stride_height);
    return xnn_status_invalid_parameter;
  }

  if (dilation_height == 0 || dilation_width == 0) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
      xnn_node_type_to_string(xnn_node_type_max_pooling_2d), dilation_width, dilation_height);
    return xnn_status_invalid_parameter;
  }

  if (stride_height > pooling_height) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 " stride height: must not exceed pooling height %" PRIu32,
      xnn_node_type_to_string(xnn_node_type_max_pooling_2d), stride_height, pooling_height);
    return xnn_status_invalid_parameter;
  }

  if (stride_width > pooling_width) {
    xnn_log_error(
      "failed to define %s operator with %" PRIu32 " stride width: must not exceed pooling width %" PRIu32,
      xnn_node_type_to_string(xnn_node_type_max_pooling_2d), stride_width, pooling_width);
    return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_min_max(xnn_node_type_max_pooling_2d, output_min, output_max);
  if (status != xnn_status_success) {
    return status;
  }

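  // TensorFlow SAME padding is computed implicitly from the input size, so it
  // is mutually exclusive with an explicit padding specification.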
  const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
  if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0) {
    if (any_padding) {
      xnn_log_error(
        "failed to define %s operator with %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32 " padding: "
        "TensorFlow SAME padding can't be combined with explicit padding specification",
        xnn_node_type_to_string(xnn_node_type_max_pooling_2d),
        input_padding_top, input_padding_left, input_padding_bottom, input_padding_right);
      return xnn_status_invalid_parameter;
    }
  }

  if ((status = xnn_subgraph_check_input_node_id(xnn_node_type_max_pooling_2d, input_id, subgraph->num_values)) !=
      xnn_status_success) {
    return status;
  }

  const struct xnn_value* input_value = &subgraph->values[input_id];
  status = xnn_subgraph_check_input_type_dense(xnn_node_type_max_pooling_2d, input_id, input_value);
  if (status != xnn_status_success) {
    return status;
  }

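  // Only fp32 and, when the corresponding operators are compiled in, signed or
  // unsigned 8-bit quantized input datatypes are accepted here.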
  switch (input_value->datatype) {
    case xnn_datatype_fp32:
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
#endif  // !defined(XNN_NO_U8_OPERATORS)
      break;
    default:
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_max_pooling_2d), input_id,
        xnn_datatype_to_string(input_value->datatype), input_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_output_node_id(xnn_node_type_max_pooling_2d, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* output_value = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_max_pooling_2d, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

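  // The compute type recorded on the node is derived from the output datatype;
  // it later selects the operator variant in create_max_pooling_operator.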
  enum xnn_compute_type compute_type = xnn_compute_type_invalid;
  switch (output_value->datatype) {
    case xnn_datatype_fp32:
      compute_type = xnn_compute_type_fp32;
      break;
#ifndef XNN_NO_S8_OPERATORS
    case xnn_datatype_qint8:
      compute_type = xnn_compute_type_qs8;
      break;
#endif  // !defined(XNN_NO_S8_OPERATORS)
#ifndef XNN_NO_U8_OPERATORS
    case xnn_datatype_quint8:
      compute_type = xnn_compute_type_qu8;
      break;
#endif  // !defined(XNN_NO_U8_OPERATORS)
    default:
      xnn_log_error(
        "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
        xnn_node_type_to_string(xnn_node_type_max_pooling_2d), output_id,
        xnn_datatype_to_string(output_value->datatype), output_value->datatype);
      return xnn_status_invalid_parameter;
  }

  status = xnn_subgraph_check_datatype_matches(
    xnn_node_type_max_pooling_2d, input_id, input_value, output_id, output_value);
  if (status != xnn_status_success) {
    return status;
  }

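  // Max pooling only forwards existing input values without rescaling, so for
  // quantized tensors the output must share the input's zero point and scale.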
#if !defined(XNN_NO_S8_OPERATORS) || !defined(XNN_NO_U8_OPERATORS)
  if (output_value->datatype == xnn_datatype_qint8 || output_value->datatype == xnn_datatype_quint8) {
    if (input_value->quantization.zero_point != output_value->quantization.zero_point) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching zero point quantization parameter across input (%" PRId32 ") and output (%" PRId32 ")",
        xnn_node_type_to_string(xnn_node_type_max_pooling_2d), input_id, output_id,
        input_value->quantization.zero_point, output_value->quantization.zero_point);
      return xnn_status_invalid_parameter;
    }
    if (input_value->quantization.scale != output_value->quantization.scale) {
      xnn_log_error(
        "failed to define %s operator with input ID #%" PRIu32 " and output ID #%" PRIu32
        ": mismatching scale quantization parameter across input (%.7g) and output (%.7g)",
        xnn_node_type_to_string(xnn_node_type_max_pooling_2d), input_id, output_id,
        input_value->quantization.scale, output_value->quantization.scale);
      return xnn_status_invalid_parameter;
    }
  }
#endif  // !defined(XNN_NO_S8_OPERATORS) || !defined(XNN_NO_U8_OPERATORS)

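  // All parameters are valid: allocate a node and record everything needed to
  // create and set up the operator when the subgraph is instantiated.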
  struct xnn_node* node = xnn_subgraph_new_node(subgraph);
  if (node == NULL) {
    return xnn_status_out_of_memory;
  }

  node->type = xnn_node_type_max_pooling_2d;
  node->compute_type = compute_type;
  node->params.pooling_2d.padding_top = input_padding_top;
  node->params.pooling_2d.padding_right = input_padding_right;
  node->params.pooling_2d.padding_bottom = input_padding_bottom;
  node->params.pooling_2d.padding_left = input_padding_left;
  node->params.pooling_2d.pooling_height = pooling_height;
  node->params.pooling_2d.pooling_width = pooling_width;
  node->params.pooling_2d.stride_height = stride_height;
  node->params.pooling_2d.stride_width = stride_width;
  node->params.pooling_2d.dilation_height = dilation_height;
  node->params.pooling_2d.dilation_width = dilation_width;
  node->activation.output_min = output_min;
  node->activation.output_max = output_max;
  node->num_inputs = 1;
  node->inputs[0] = input_id;
  node->num_outputs = 1;
  node->outputs[0] = output_id;
  node->flags = flags;

  node->create = create_max_pooling_operator;
  node->setup = setup_max_pooling_operator;

  return xnn_status_success;
}