// xref: /aosp_15_r20/external/XNNPACK/src/operators/max-pooling-nhwc.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/indirection.h>
#include <xnnpack/log.h>
#include <xnnpack/math.h>
#include <xnnpack/operator.h>
#include <xnnpack/microparams-init.h>
#include <xnnpack/params.h>


// Output extent of one spatial dimension under TensorFlow "SAME" padding:
// ceil(input_dimension / stride_dimension). The pooling window size is not an input here.
static inline size_t compute_output_dimension_with_tf_same_padding(
    size_t input_dimension,
    size_t stride_dimension)
{
  const size_t output_dimension = divide_round_up(input_dimension, stride_dimension);
  return output_dimension;
}

// Validates max-pooling hyper-parameters and allocates a max-pooling operator descriptor.
//
// Shared implementation behind the per-datatype xnn_create_max_pooling2d_nhwc_* entry points.
//   params / params_size  datatype-specific micro-kernel parameters prepared by the caller;
//                         copied byte-for-byte into the operator on success.
//   datatype_init_flags   XNN_INIT_FLAG_* bits that must all be set in xnn_params.init_flags.
//   operator_type         stored in the operator and used in diagnostics.
//
// Returns xnn_status_success and stores the new operator in *max_pooling_op_out, or:
//   xnn_status_uninitialized        - XNNPACK is not initialized;
//   xnn_status_unsupported_hardware - the datatype flags are not all set;
//   xnn_status_invalid_parameter    - a size, stride, dilation, channel count, pixel stride,
//                                     or padding specification is invalid;
//   xnn_status_out_of_memory        - the operator descriptor could not be allocated.
// *max_pooling_op_out is only written on success.
static enum xnn_status create_max_pooling2d_nhwc(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t pooling_height,
    uint32_t pooling_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    size_t channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    uint32_t flags,
    const void* params,
    size_t params_size,
    uint32_t datatype_init_flags,
    enum xnn_operator_type operator_type,
    xnn_operator_t* max_pooling_op_out)
{
  xnn_operator_t max_pooling_op = NULL;
  enum xnn_status status = xnn_status_uninitialized;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error("failed to create %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(operator_type));
    return xnn_status_uninitialized;
  }

  status = xnn_status_unsupported_hardware;

  if ((xnn_params.init_flags & datatype_init_flags) != datatype_init_flags) {
    xnn_log_error(
      "failed to create %s operator: operations on data type are not supported",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  status = xnn_status_invalid_parameter;

  // Test the dimensions individually instead of their 32-bit product: the product can wrap
  // around (e.g. 65536 x 65536 wraps to 0) and would then misreport or slip past these checks.
  if (pooling_height == 0 || pooling_width == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " pooling size: "
      "pooling size dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type),
      pooling_width, pooling_height);
    goto error;
  }

  if (pooling_height == 1 && pooling_width == 1) {
    xnn_log_error(
      "failed to create %s operator with 1 pooling element: 1x1 pooling is meaningless",
      xnn_operator_type_to_string(operator_type));
    goto error;
  }

  if (stride_height == 0 || stride_width == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " stride: stride dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), stride_width, stride_height);
    goto error;
  }

  if (dilation_height == 0 || dilation_width == 0) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 "x%" PRIu32 " dilation: dilation dimensions must be non-zero",
      xnn_operator_type_to_string(operator_type), dilation_width, dilation_height);
    goto error;
  }

  // The stride may not exceed the pooling window in either dimension (equality is allowed).
  if (stride_height > pooling_height) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 " stride height: "
      "must be less than or equal to pooling height %" PRIu32,
      xnn_operator_type_to_string(operator_type), stride_height, pooling_height);
    goto error;
  }

  if (stride_width > pooling_width) {
    xnn_log_error(
      "failed to create %s operator with %" PRIu32 " stride width: "
      "must be less than or equal to pooling width %" PRIu32,
      xnn_operator_type_to_string(operator_type), stride_width, pooling_width);
    goto error;
  }

  if (channels == 0) {
    xnn_log_error(
      "failed to create %s operator with %zu channels: number of channels must be non-zero",
      xnn_operator_type_to_string(operator_type), channels);
    goto error;
  }

  if (input_pixel_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with input pixel stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), input_pixel_stride, channels);
    goto error;
  }

  if (output_pixel_stride < channels) {
    xnn_log_error(
      "failed to create %s operator with output pixel stride of %zu: "
      "stride must be at least as large as the number of channels (%zu)",
      xnn_operator_type_to_string(operator_type), output_pixel_stride, channels);
    goto error;
  }

  // With TensorFlow SAME padding the padding is derived from the input size at setup time,
  // so an explicit padding specification is contradictory.
  const bool any_padding = (input_padding_left | input_padding_top | input_padding_right | input_padding_bottom) != 0;
  if ((flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) != 0) {
    if (any_padding) {
      xnn_log_error(
        "failed to create %s operator with %" PRIu32 "+%" PRIu32 "x%" PRIu32 "+%" PRIu32" padding: "
        "TensorFlow SAME padding can't be combined with explicit padding specification",
        xnn_operator_type_to_string(operator_type),
        input_padding_top, input_padding_left, input_padding_bottom, input_padding_right);
      goto error;
    }
  }

  status = xnn_status_out_of_memory;

  max_pooling_op = xnn_allocate_zero_simd_memory(sizeof(struct xnn_operator));
  if (max_pooling_op == NULL) {
    xnn_log_error(
      "failed to allocate %zu bytes for %s operator descriptor",
      sizeof(struct xnn_operator), xnn_operator_type_to_string(operator_type));
    goto error;
  }

  max_pooling_op->padding_top = input_padding_top;
  max_pooling_op->padding_right = input_padding_right;
  max_pooling_op->padding_bottom = input_padding_bottom;
  max_pooling_op->padding_left = input_padding_left;

  max_pooling_op->kernel_height = pooling_height;
  max_pooling_op->kernel_width = pooling_width;
  max_pooling_op->stride_height = stride_height;
  max_pooling_op->stride_width = stride_width;
  max_pooling_op->dilation_height = dilation_height;
  max_pooling_op->dilation_width = dilation_width;
  max_pooling_op->channels = channels;
  max_pooling_op->input_pixel_stride = input_pixel_stride;
  max_pooling_op->output_pixel_stride = output_pixel_stride;

  memcpy(&max_pooling_op->params, params, params_size);
  max_pooling_op->type = operator_type;
  max_pooling_op->flags = flags;

  // The operator cannot run until a setup call succeeds.
  max_pooling_op->state = xnn_run_state_invalid;

  *max_pooling_op_out = max_pooling_op;
  return xnn_status_success;

error:
  // max_pooling_op is still NULL on every path except a post-allocation failure.
  xnn_delete_operator(max_pooling_op);
  return status;
}

// Computes a * b + c into *result and returns true when the value fits in size_t; returns false
// (leaving *result untouched) when the computation would wrap around.
static inline bool max_pooling_checked_multiply_add(size_t a, size_t b, size_t c, size_t* result)
{
  if (b != 0 && a > (SIZE_MAX - c) / b) {
    return false;
  }
  *result = a * b + c;
  return true;
}

// Binds input/output data to a max-pooling operator created by create_max_pooling2d_nhwc and
// prepares its 2D (batch x output row) compute descriptor.
//
//   max_pooling_op            operator to set up; its type must equal expected_operator_type.
//   batch_size                number of images; 0 marks the operator as skipped.
//   input_height/input_width  spatial input size; both must be non-zero.
//   input, output             data pointers recorded in the operator / compute context.
//   log2_input_element_size   log2 of the input element size in bytes (scales input strides).
//   log2_output_element_size  log2 of the output element size in bytes (scales output strides).
//   maxpool                   micro-kernel descriptor (mr, qr, ukernel).
//   params / params_size      micro-kernel parameters copied into the compute context.
//   num_threads               not used by this function.
//
// With XNN_FLAG_TENSORFLOW_SAME_PADDING the output size and padding are derived from the input
// size; otherwise the padding fixed at creation is used. The indirection buffer is rebuilt only
// when the input size differs from the previous setup; a moved input pointer is handled through
// context.input_offset instead.
//
// Returns xnn_status_success, xnn_status_invalid_parameter (type mismatch or zero input size),
// xnn_status_uninitialized, or xnn_status_out_of_memory (indirection buffer allocation failed or
// its size does not fit in size_t). The operator state stays invalid on every failure path.
static enum xnn_status setup_max_pooling2d_nhwc(
  xnn_operator_t max_pooling_op,
  enum xnn_operator_type expected_operator_type,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  const void* input,
  void* output,
  uint32_t log2_input_element_size,
  uint32_t log2_output_element_size,
  struct maxpool_parameters maxpool[restrict XNN_MIN_ELEMENTS(1)],
  const void* params,
  size_t params_size,
  size_t num_threads)
{
  if (max_pooling_op->type != expected_operator_type) {
    xnn_log_error("failed to setup operator: operator type mismatch (expected %s, got %s)",
      xnn_operator_type_to_string(expected_operator_type),
      xnn_operator_type_to_string(max_pooling_op->type));
    return xnn_status_invalid_parameter;
  }
  max_pooling_op->state = xnn_run_state_invalid;

  if ((xnn_params.init_flags & XNN_INIT_FLAG_XNNPACK) == 0) {
    xnn_log_error(
      "failed to setup %s operator: XNNPACK is not initialized",
      xnn_operator_type_to_string(max_pooling_op->type));
    return xnn_status_uninitialized;
  }

  if (input_width == 0 || input_height == 0) {
    xnn_log_error(
      "failed to setup %s operator with %zux%zu input: input dimensions must be non-zero",
      xnn_operator_type_to_string(max_pooling_op->type), input_width, input_height);
    return xnn_status_invalid_parameter;
  }

  if (batch_size == 0) {
    max_pooling_op->state = xnn_run_state_skip;
    return xnn_status_success;
  }

  max_pooling_op->input_height = input_height;
  max_pooling_op->input_width = input_width;
  max_pooling_op->input = input;

  if (max_pooling_op->flags & XNN_FLAG_TENSORFLOW_SAME_PADDING) {
    max_pooling_op->output_height = compute_output_dimension_with_tf_same_padding(
        input_height, max_pooling_op->stride_height);
    max_pooling_op->output_width = compute_output_dimension_with_tf_same_padding(
        input_width, max_pooling_op->stride_width);

    // Split the padding needed to cover the last (dilated) window between the two sides,
    // putting the odd element at the bottom/right.
    const uint32_t effective_kernel_height = (max_pooling_op->kernel_height - 1) * max_pooling_op->dilation_height + 1;
    const uint32_t effective_kernel_width = (max_pooling_op->kernel_width - 1) * max_pooling_op->dilation_width + 1;
    const uint32_t total_padding_height =
      doz((max_pooling_op->output_height - 1) * max_pooling_op->stride_height + effective_kernel_height, input_height);
    const uint32_t total_padding_width =
      doz((max_pooling_op->output_width - 1) * max_pooling_op->stride_width + effective_kernel_width, input_width);
    max_pooling_op->padding_top = total_padding_height / 2;
    max_pooling_op->padding_left = total_padding_width / 2;
    max_pooling_op->padding_bottom = total_padding_height - max_pooling_op->padding_top;
    max_pooling_op->padding_right = total_padding_width - max_pooling_op->padding_left;
  } else {
    max_pooling_op->output_height = xnn_compute_convolution_output_dimension(
        max_pooling_op->padding_top + input_height + max_pooling_op->padding_bottom,
        max_pooling_op->kernel_height,
        max_pooling_op->dilation_height,
        max_pooling_op->stride_height);
    max_pooling_op->output_width = xnn_compute_convolution_output_dimension(
        max_pooling_op->padding_left + input_width + max_pooling_op->padding_right,
        max_pooling_op->kernel_width,
        max_pooling_op->dilation_width,
        max_pooling_op->stride_width);
  }

  const size_t pooling_height = max_pooling_op->kernel_height;
  const size_t pooling_width = max_pooling_op->kernel_width;
  const size_t output_height = max_pooling_op->output_height;
  const size_t output_width = max_pooling_op->output_width;
  const uint32_t mr = maxpool->mr;

  // Indirection layout: the first output pixel of a row uses pooling_size pointers, and each
  // further pixel adds step_width columns of pooling_height pointers (the full window width when
  // dilated, otherwise the stride capped at the window width).
  const size_t step_width =
    max_pooling_op->dilation_width > 1 ? pooling_width : min(max_pooling_op->stride_width, pooling_width);

  // Compute pooling_size, the per-output-row pointer count (step_height), and the indirection
  // buffer size with overflow checks: on 32-bit targets these products can wrap, which would
  // allocate an undersized buffer that the indirection initializer then overruns.
  // Micro-kernel may read up to (mr - 1) elements after the end of indirection buffer.
  size_t pooling_size = 0;
  size_t step_height = 0;
  size_t indirection_buffer_size = 0;
  {
    size_t pixel_step = 0;
    size_t indirection_entries = 0;
    if (!max_pooling_checked_multiply_add(pooling_height, pooling_width, 0, &pooling_size) ||
        !max_pooling_checked_multiply_add(step_width, pooling_height, 0, &pixel_step) ||
        !max_pooling_checked_multiply_add(output_width - 1, pixel_step, pooling_size, &step_height) ||
        !max_pooling_checked_multiply_add(output_height, step_height, (size_t) (mr - 1), &indirection_entries) ||
        !max_pooling_checked_multiply_add(indirection_entries, sizeof(void*), 0, &indirection_buffer_size))
    {
      xnn_log_error(
        "failed to setup %s operator with %zux%zu input: indirection buffer size overflows size_t",
        xnn_operator_type_to_string(max_pooling_op->type), input_width, input_height);
      return xnn_status_out_of_memory;
    }
  }

  if (input_height != max_pooling_op->last_input_height ||
      input_width != max_pooling_op->last_input_width)
  {
    // Assign only on success so the previous buffer stays owned by the operator on failure.
    const void** indirection_buffer =
      (const void**) xnn_reallocate_memory(max_pooling_op->indirection_buffer, indirection_buffer_size);
    if (indirection_buffer == NULL) {
      xnn_log_error("failed to allocate %zu bytes for indirection buffer", indirection_buffer_size);
      return xnn_status_out_of_memory;
    }
    max_pooling_op->indirection_buffer = indirection_buffer;

    xnn_indirection_init_maxpool2d(max_pooling_op, step_height, step_width, log2_input_element_size);

    // The buffer now holds pointers into `input`; later setups with a different input pointer
    // but the same size reuse it via input_offset.
    max_pooling_op->last_input = input;
    max_pooling_op->last_input_height = input_height;
    max_pooling_op->last_input_width = input_width;
  }

  const uint32_t qr = maxpool->qr;
  const size_t channels = max_pooling_op->channels;

  const size_t indirect_input_height_stride = step_height * sizeof(void*);
  const size_t output_width_stride = max_pooling_op->output_pixel_stride << log2_output_element_size;
  const size_t output_height_stride = output_width * output_width_stride;
  const size_t multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr;

  max_pooling_op->context.max_pooling = (struct max_pooling_context) {
    .indirect_input = max_pooling_op->indirection_buffer,
    .indirect_input_height_stride = indirect_input_height_stride,
    .input_offset = (size_t) ((uintptr_t) input - (uintptr_t) max_pooling_op->last_input),
    .input_batch_stride = (input_height * input_width * max_pooling_op->input_pixel_stride) << log2_input_element_size,
    .output = output,
    .output_batch_stride = output_height * output_height_stride,
    .output_height_stride = output_height_stride,
    .output_width = output_width,
    .pooling_size = pooling_size,
    .channels = channels,
    .input_increment = (pooling_height * step_width - multipass_adjustment) * sizeof(void*),
    .output_increment = output_width_stride - (channels << log2_output_element_size),
    .ukernel = maxpool->ukernel,
  };
  memcpy(&max_pooling_op->context.max_pooling.params, params, params_size);

  max_pooling_op->compute.type = xnn_parallelization_type_2d;
  max_pooling_op->compute.task_2d = (pthreadpool_task_2d_t) xnn_compute_max_pooling;
  max_pooling_op->compute.range[0] = batch_size;
  max_pooling_op->compute.range[1] = output_height;
  max_pooling_op->state = xnn_run_state_ready;

  return xnn_status_success;
}

// Creates an NHWC max-pooling operator on signed 8-bit data with outputs clamped to
// [output_min, output_max].
//
// Validates the output range, prepares the S8 micro-kernel parameters, and delegates all other
// validation and allocation to create_max_pooling2d_nhwc.
//
// Returns xnn_status_invalid_parameter if output_min >= output_max; otherwise returns whatever
// create_max_pooling2d_nhwc returns.
enum xnn_status xnn_create_max_pooling2d_nhwc_s8(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t pooling_height,
    uint32_t pooling_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    size_t channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* max_pooling_op_out)
{
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRId8 ", %" PRId8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_s8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Only call the S8 parameter initializer once XNNPACK is initialized with S8 support. Before
  // that the function pointer is presumably not populated (NOTE(review): confirm against
  // xnn_initialize), and calling it would crash instead of returning an error status. When the
  // flags are missing, create_max_pooling2d_nhwc returns xnn_status_uninitialized or
  // xnn_status_unsupported_hardware before it ever reads `params`.
  union xnn_s8_minmax_params params;
  memset(&params, 0, sizeof(params));
  const uint32_t required_init_flags = XNN_INIT_FLAG_XNNPACK | XNN_INIT_FLAG_S8;
  if ((xnn_params.init_flags & required_init_flags) == required_init_flags) {
    xnn_params.s8.maxpool.init.s8(&params, output_min, output_max);
  }

  return create_max_pooling2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    pooling_height, pooling_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    channels, input_pixel_stride, output_pixel_stride,
    flags,
    &params, sizeof(params), XNN_INIT_FLAG_S8,
    xnn_operator_type_max_pooling_nhwc_s8,
    max_pooling_op_out);
}

// Creates an NHWC 2D max-pooling operator for uint8_t data whose outputs are
// clamped to [output_min, output_max].
//
// Rejects an empty clamping range, packs the bounds into u8 min/max parameters,
// and delegates all remaining validation and operator construction to
// create_max_pooling2d_nhwc().
//
// Returns xnn_status_invalid_parameter if output_min >= output_max; otherwise
// returns whatever create_max_pooling2d_nhwc() returns.
enum xnn_status xnn_create_max_pooling2d_nhwc_u8(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t pooling_height,
    uint32_t pooling_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    size_t channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    xnn_operator_t* max_pooling_op_out)
{
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%" PRIu8 ", %" PRIu8 "] output range: range min must be below range max",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_u8), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Start from all-zero bytes so the helper never receives indeterminate
  // parameter memory, even if the init call below is skipped.
  union xnn_u8_minmax_params params;
  memset(&params, 0, sizeof(params));
  // The init function is a pointer in the global xnn_params table and may be
  // NULL (presumably when xnn_params has not been populated, e.g. before
  // xnn_initialize() -- TODO confirm). Calling through NULL would crash, so
  // guard it the same way xnn_create_max_pooling2d_nhwc_f16 does and let
  // create_max_pooling2d_nhwc() report the failure (assumed to happen via its
  // init-flag checks -- TODO confirm in that helper).
  if (xnn_params.u8.maxpool.init.u8 != NULL) {
    xnn_params.u8.maxpool.init.u8(&params, output_min, output_max);
  }
  return create_max_pooling2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    pooling_height, pooling_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    channels, input_pixel_stride, output_pixel_stride,
    flags,
    &params, sizeof(params), XNN_INIT_FLAG_U8,
    xnn_operator_type_max_pooling_nhwc_u8,
    max_pooling_op_out);
}
416*4bdc9457SAndroid Build Coastguard Worker 
// Creates an NHWC 2D max-pooling operator for float data whose outputs are
// clamped to [output_min, output_max].
//
// Both bounds must be non-NaN and output_min must be strictly below output_max.
// The bounds are packed into f32 min/max parameters. All remaining validation
// and operator construction is delegated to create_max_pooling2d_nhwc().
//
// Returns xnn_status_invalid_parameter for a NaN bound or an empty range;
// otherwise returns whatever create_max_pooling2d_nhwc() returns.
enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t pooling_height,
    uint32_t pooling_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    size_t channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* max_pooling_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_f32));
    return xnn_status_invalid_parameter;
  }

  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_f32), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Start from all-zero bytes so the helper never receives indeterminate
  // parameter memory, even if the init call below is skipped.
  union xnn_f32_minmax_params params;
  memset(&params, 0, sizeof(params));
  // The init function is a pointer in the global xnn_params table and may be
  // NULL (presumably when xnn_params has not been populated, e.g. before
  // xnn_initialize() -- TODO confirm). Calling through NULL would crash, so
  // guard it the same way xnn_create_max_pooling2d_nhwc_f16 does and let
  // create_max_pooling2d_nhwc() report the failure (assumed to happen via its
  // init-flag checks -- TODO confirm in that helper).
  if (xnn_params.f32.maxpool.init.f32 != NULL) {
    xnn_params.f32.maxpool.init.f32(&params, output_min, output_max);
  }
  return create_max_pooling2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    pooling_height, pooling_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    channels, input_pixel_stride, output_pixel_stride,
    flags,
    &params, sizeof(params), XNN_INIT_FLAG_F32,
    xnn_operator_type_max_pooling_nhwc_f32,
    max_pooling_op_out);
}
470*4bdc9457SAndroid Build Coastguard Worker 
// Creates an NHWC 2D max-pooling operator for IEEE half-precision data whose
// outputs are clamped to [output_min, output_max].
//
// The bounds are given as float. Both must be non-NaN. They are converted to
// half precision, and the range check is performed on the converted values.
// The half-precision bounds are then packed into f16 min/max parameters, and
// all remaining validation and operator construction is delegated to
// create_max_pooling2d_nhwc().
//
// Returns xnn_status_invalid_parameter for a NaN bound or for a range that is
// empty after conversion to half precision; otherwise returns whatever
// create_max_pooling2d_nhwc() returns.
enum xnn_status xnn_create_max_pooling2d_nhwc_f16(
    uint32_t input_padding_top,
    uint32_t input_padding_right,
    uint32_t input_padding_bottom,
    uint32_t input_padding_left,
    uint32_t pooling_height,
    uint32_t pooling_width,
    uint32_t stride_height,
    uint32_t stride_width,
    uint32_t dilation_height,
    uint32_t dilation_width,
    size_t channels,
    size_t input_pixel_stride,
    size_t output_pixel_stride,
    float output_min,
    float output_max,
    uint32_t flags,
    xnn_operator_t* max_pooling_op_out)
{
  if (isnan(output_min)) {
    xnn_log_error(
      "failed to create %s with NaN output lower bound: lower bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  if (isnan(output_max)) {
    xnn_log_error(
      "failed to create %s with NaN output upper bound: upper bound must be non-NaN",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_f16));
    return xnn_status_invalid_parameter;
  }

  // Convert the bounds to half precision and back. Distinct float bounds can
  // become equal once converted, so the range check (and the error message)
  // uses the values the kernel will actually see.
  const uint16_t output_min_as_half = fp16_ieee_from_fp32_value(output_min);
  const uint16_t output_max_as_half = fp16_ieee_from_fp32_value(output_max);
  output_min = fp16_ieee_to_fp32_value(output_min_as_half);
  output_max = fp16_ieee_to_fp32_value(output_max_as_half);
  if (output_min >= output_max) {
    xnn_log_error(
      "failed to create %s operator with [%.7g, %.7g] output range: lower bound must be below upper bound",
      xnn_operator_type_to_string(xnn_operator_type_max_pooling_nhwc_f16), output_min, output_max);
    return xnn_status_invalid_parameter;
  }

  // Zero the parameters first. When no f16 init function is available, the
  // helper would otherwise be handed a pointer to indeterminate bytes.
  union xnn_f16_minmax_params params;
  memset(&params, 0, sizeof(params));
  // The f16 init function may be NULL; the helper is still called so it can
  // report the failure (presumably via its XNN_INIT_FLAG_F16 check -- TODO
  // confirm in create_max_pooling2d_nhwc).
  if (xnn_params.f16.maxpool.init.f16 != NULL) {
    xnn_params.f16.maxpool.init.f16(&params, output_min_as_half, output_max_as_half);
  }
  return create_max_pooling2d_nhwc(
    input_padding_top, input_padding_right, input_padding_bottom, input_padding_left,
    pooling_height, pooling_width,
    stride_height, stride_width,
    dilation_height, dilation_width,
    channels, input_pixel_stride, output_pixel_stride,
    flags,
    &params, sizeof(params), XNN_INIT_FLAG_F16,
    xnn_operator_type_max_pooling_nhwc_f16,
    max_pooling_op_out);
}
530*4bdc9457SAndroid Build Coastguard Worker 
// Binds batch size, spatial dimensions and input/output buffers to an S8
// max-pooling operator, presumably one created by
// xnn_create_max_pooling2d_nhwc_s8().
//
// Input and output elements are int8_t, so log2 of both element sizes is 0.
// The thread count of `threadpool` is queried here and forwarded to the shared
// setup helper, along with the operator's stored s8 min/max parameters.
enum xnn_status xnn_setup_max_pooling2d_nhwc_s8(
    xnn_operator_t max_pooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool)
{
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_max_pooling2d_nhwc(
    max_pooling_op,
    xnn_operator_type_max_pooling_nhwc_s8,
    batch_size,
    input_height,
    input_width,
    input,
    output,
    0,  // log2(sizeof(int8_t)) for the input element
    0,  // log2(sizeof(int8_t)) for the output element
    &xnn_params.s8.maxpool,
    &max_pooling_op->params.s8_minmax,
    sizeof max_pooling_op->params.s8_minmax,
    num_threads);
}
550*4bdc9457SAndroid Build Coastguard Worker 
// Binds batch size, spatial dimensions and input/output buffers to a U8
// max-pooling operator, presumably one created by
// xnn_create_max_pooling2d_nhwc_u8().
//
// Input and output elements are uint8_t, so log2 of both element sizes is 0.
// The thread count of `threadpool` is queried here and forwarded to the shared
// setup helper, along with the operator's stored u8 min/max parameters.
enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
    xnn_operator_t max_pooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const uint8_t* input,
    uint8_t* output,
    pthreadpool_t threadpool)
{
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_max_pooling2d_nhwc(
    max_pooling_op,
    xnn_operator_type_max_pooling_nhwc_u8,
    batch_size,
    input_height,
    input_width,
    input,
    output,
    0,  // log2(sizeof(uint8_t)) for the input element
    0,  // log2(sizeof(uint8_t)) for the output element
    &xnn_params.u8.maxpool,
    &max_pooling_op->params.u8_minmax,
    sizeof max_pooling_op->params.u8_minmax,
    num_threads);
}
570*4bdc9457SAndroid Build Coastguard Worker 
// Binds batch size, spatial dimensions and input/output buffers to an F16
// max-pooling operator, presumably one created by
// xnn_create_max_pooling2d_nhwc_f16().
//
// The buffers are untyped (void*). Each element is treated as a 16-bit half,
// so log2 of both element sizes is 1. The thread count of `threadpool` is
// queried here and forwarded to the shared setup helper, along with the
// operator's stored f16 min/max parameters.
enum xnn_status xnn_setup_max_pooling2d_nhwc_f16(
    xnn_operator_t max_pooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const void* input,
    void* output,
    pthreadpool_t threadpool)
{
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_max_pooling2d_nhwc(
    max_pooling_op,
    xnn_operator_type_max_pooling_nhwc_f16,
    batch_size,
    input_height,
    input_width,
    input,
    output,
    1,  // log2(sizeof(uint16_t)) for the input element
    1,  // log2(sizeof(uint16_t)) for the output element
    &xnn_params.f16.maxpool,
    &max_pooling_op->params.f16_minmax,
    sizeof max_pooling_op->params.f16_minmax,
    num_threads);
}
590*4bdc9457SAndroid Build Coastguard Worker 
// Binds batch size, spatial dimensions and input/output buffers to an F32
// max-pooling operator, presumably one created by
// xnn_create_max_pooling2d_nhwc_f32().
//
// Input and output elements are float, so log2 of both element sizes is 2.
// The thread count of `threadpool` is queried here and forwarded to the shared
// setup helper, along with the operator's stored f32 min/max parameters.
enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
    xnn_operator_t max_pooling_op,
    size_t batch_size,
    size_t input_height,
    size_t input_width,
    const float* input,
    float* output,
    pthreadpool_t threadpool)
{
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);

  return setup_max_pooling2d_nhwc(
    max_pooling_op,
    xnn_operator_type_max_pooling_nhwc_f32,
    batch_size,
    input_height,
    input_width,
    input,
    output,
    2,  // log2(sizeof(float)) for the input element
    2,  // log2(sizeof(float)) for the output element
    &xnn_params.f32.maxpool,
    &max_pooling_op->params.f32_minmax,
    sizeof max_pooling_op->params.f32_minmax,
    num_threads);
}
610*4bdc9457SAndroid Build Coastguard Worker 
611