xref: /aosp_15_r20/external/XNNPACK/test/average-pooling-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
14*4bdc9457SAndroid Build Coastguard Worker 
15*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
16*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
17*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
18*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
19*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
20*4bdc9457SAndroid Build Coastguard Worker #include <limits>
21*4bdc9457SAndroid Build Coastguard Worker #include <random>
22*4bdc9457SAndroid Build Coastguard Worker #include <vector>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
25*4bdc9457SAndroid Build Coastguard Worker 
26*4bdc9457SAndroid Build Coastguard Worker 
27*4bdc9457SAndroid Build Coastguard Worker class AveragePoolingOperatorTester {
28*4bdc9457SAndroid Build Coastguard Worker  public:
padding_tf_same(bool padding_same)29*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_tf_same(bool padding_same) {
30*4bdc9457SAndroid Build Coastguard Worker     if (padding_same) {
31*4bdc9457SAndroid Build Coastguard Worker       assert(padding_top() == 0);
32*4bdc9457SAndroid Build Coastguard Worker       assert(padding_left() == 0);
33*4bdc9457SAndroid Build Coastguard Worker       assert(padding_bottom() == 0);
34*4bdc9457SAndroid Build Coastguard Worker       assert(padding_right() == 0);
35*4bdc9457SAndroid Build Coastguard Worker     }
36*4bdc9457SAndroid Build Coastguard Worker     this->padding_tf_same_ = padding_same;
37*4bdc9457SAndroid Build Coastguard Worker     return *this;
38*4bdc9457SAndroid Build Coastguard Worker   }
39*4bdc9457SAndroid Build Coastguard Worker 
padding_tf_same()40*4bdc9457SAndroid Build Coastguard Worker   inline bool padding_tf_same() const {
41*4bdc9457SAndroid Build Coastguard Worker     return this->padding_tf_same_;
42*4bdc9457SAndroid Build Coastguard Worker   }
43*4bdc9457SAndroid Build Coastguard Worker 
padding(uint32_t padding)44*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding(uint32_t padding) {
45*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
46*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding;
47*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding;
48*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding;
49*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding;
50*4bdc9457SAndroid Build Coastguard Worker     return *this;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
padding(uint32_t padding_height,uint32_t padding_width)53*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) {
54*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
55*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding_height;
56*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding_width;
57*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding_height;
58*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding_width;
59*4bdc9457SAndroid Build Coastguard Worker     return *this;
60*4bdc9457SAndroid Build Coastguard Worker   }
61*4bdc9457SAndroid Build Coastguard Worker 
padding_height(uint32_t padding_height)62*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_height(uint32_t padding_height) {
63*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
64*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding_height;
65*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding_height;
66*4bdc9457SAndroid Build Coastguard Worker     return *this;
67*4bdc9457SAndroid Build Coastguard Worker   }
68*4bdc9457SAndroid Build Coastguard Worker 
padding_width(uint32_t padding_width)69*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_width(uint32_t padding_width) {
70*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
71*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding_width;
72*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding_width;
73*4bdc9457SAndroid Build Coastguard Worker     return *this;
74*4bdc9457SAndroid Build Coastguard Worker   }
75*4bdc9457SAndroid Build Coastguard Worker 
padding_top(uint32_t padding_top)76*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_top(uint32_t padding_top) {
77*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
78*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding_top;
79*4bdc9457SAndroid Build Coastguard Worker     return *this;
80*4bdc9457SAndroid Build Coastguard Worker   }
81*4bdc9457SAndroid Build Coastguard Worker 
padding_top()82*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_top() const {
83*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
84*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_height =
85*4bdc9457SAndroid Build Coastguard Worker         (output_height() - 1) * stride_height() + pooling_height() - input_height();
86*4bdc9457SAndroid Build Coastguard Worker       return total_padding_height / 2;
87*4bdc9457SAndroid Build Coastguard Worker     } else {
88*4bdc9457SAndroid Build Coastguard Worker       return this->padding_top_;
89*4bdc9457SAndroid Build Coastguard Worker     }
90*4bdc9457SAndroid Build Coastguard Worker   }
91*4bdc9457SAndroid Build Coastguard Worker 
padding_left(uint32_t padding_left)92*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_left(uint32_t padding_left) {
93*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
94*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding_left;
95*4bdc9457SAndroid Build Coastguard Worker     return *this;
96*4bdc9457SAndroid Build Coastguard Worker   }
97*4bdc9457SAndroid Build Coastguard Worker 
padding_left()98*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_left() const {
99*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
100*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_width =
101*4bdc9457SAndroid Build Coastguard Worker         (output_width() - 1) * stride_width() + pooling_width() - input_width();
102*4bdc9457SAndroid Build Coastguard Worker       return total_padding_width / 2;
103*4bdc9457SAndroid Build Coastguard Worker     } else {
104*4bdc9457SAndroid Build Coastguard Worker       return this->padding_left_;
105*4bdc9457SAndroid Build Coastguard Worker     }
106*4bdc9457SAndroid Build Coastguard Worker   }
107*4bdc9457SAndroid Build Coastguard Worker 
padding_bottom(uint32_t padding_bottom)108*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_bottom(uint32_t padding_bottom) {
109*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
110*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding_bottom;
111*4bdc9457SAndroid Build Coastguard Worker     return *this;
112*4bdc9457SAndroid Build Coastguard Worker   }
113*4bdc9457SAndroid Build Coastguard Worker 
padding_bottom()114*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_bottom() const {
115*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
116*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_height =
117*4bdc9457SAndroid Build Coastguard Worker         (output_height() - 1) * stride_height() + pooling_height() - input_height();
118*4bdc9457SAndroid Build Coastguard Worker       return total_padding_height - total_padding_height / 2;
119*4bdc9457SAndroid Build Coastguard Worker     } else {
120*4bdc9457SAndroid Build Coastguard Worker       return this->padding_bottom_;
121*4bdc9457SAndroid Build Coastguard Worker     }
122*4bdc9457SAndroid Build Coastguard Worker   }
123*4bdc9457SAndroid Build Coastguard Worker 
padding_right(uint32_t padding_right)124*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& padding_right(uint32_t padding_right) {
125*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
126*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding_right;
127*4bdc9457SAndroid Build Coastguard Worker     return *this;
128*4bdc9457SAndroid Build Coastguard Worker   }
129*4bdc9457SAndroid Build Coastguard Worker 
padding_right()130*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_right() const {
131*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
132*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_width =
133*4bdc9457SAndroid Build Coastguard Worker         (output_width() - 1) * stride_width() + pooling_width() - input_width();
134*4bdc9457SAndroid Build Coastguard Worker       return total_padding_width - total_padding_width / 2;
135*4bdc9457SAndroid Build Coastguard Worker     } else {
136*4bdc9457SAndroid Build Coastguard Worker       return this->padding_right_;
137*4bdc9457SAndroid Build Coastguard Worker     }
138*4bdc9457SAndroid Build Coastguard Worker   }
139*4bdc9457SAndroid Build Coastguard Worker 
input_size(size_t input_height,size_t input_width)140*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& input_size(size_t input_height, size_t input_width) {
141*4bdc9457SAndroid Build Coastguard Worker     assert(input_height >= 1);
142*4bdc9457SAndroid Build Coastguard Worker     assert(input_width >= 1);
143*4bdc9457SAndroid Build Coastguard Worker     this->input_height_ = input_height;
144*4bdc9457SAndroid Build Coastguard Worker     this->input_width_ = input_width;
145*4bdc9457SAndroid Build Coastguard Worker     return *this;
146*4bdc9457SAndroid Build Coastguard Worker   }
147*4bdc9457SAndroid Build Coastguard Worker 
input_height(size_t input_height)148*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& input_height(size_t input_height) {
149*4bdc9457SAndroid Build Coastguard Worker     assert(input_height >= 1);
150*4bdc9457SAndroid Build Coastguard Worker     this->input_height_ = input_height;
151*4bdc9457SAndroid Build Coastguard Worker     return *this;
152*4bdc9457SAndroid Build Coastguard Worker   }
153*4bdc9457SAndroid Build Coastguard Worker 
input_height()154*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_height() const {
155*4bdc9457SAndroid Build Coastguard Worker     return this->input_height_;
156*4bdc9457SAndroid Build Coastguard Worker   }
157*4bdc9457SAndroid Build Coastguard Worker 
input_width(size_t input_width)158*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& input_width(size_t input_width) {
159*4bdc9457SAndroid Build Coastguard Worker     assert(input_width >= 1);
160*4bdc9457SAndroid Build Coastguard Worker     this->input_width_ = input_width;
161*4bdc9457SAndroid Build Coastguard Worker     return *this;
162*4bdc9457SAndroid Build Coastguard Worker   }
163*4bdc9457SAndroid Build Coastguard Worker 
input_width()164*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_width() const {
165*4bdc9457SAndroid Build Coastguard Worker     return this->input_width_;
166*4bdc9457SAndroid Build Coastguard Worker   }
167*4bdc9457SAndroid Build Coastguard Worker 
channels(size_t channels)168*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& channels(size_t channels) {
169*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
170*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
171*4bdc9457SAndroid Build Coastguard Worker     return *this;
172*4bdc9457SAndroid Build Coastguard Worker   }
173*4bdc9457SAndroid Build Coastguard Worker 
channels()174*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
175*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
176*4bdc9457SAndroid Build Coastguard Worker   }
177*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)178*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& batch_size(size_t batch_size) {
179*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
180*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
181*4bdc9457SAndroid Build Coastguard Worker     return *this;
182*4bdc9457SAndroid Build Coastguard Worker   }
183*4bdc9457SAndroid Build Coastguard Worker 
batch_size()184*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
185*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
186*4bdc9457SAndroid Build Coastguard Worker   }
187*4bdc9457SAndroid Build Coastguard Worker 
pooling_size(uint32_t pooling_size)188*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& pooling_size(uint32_t pooling_size) {
189*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_size >= 1);
190*4bdc9457SAndroid Build Coastguard Worker     this->pooling_height_ = pooling_size;
191*4bdc9457SAndroid Build Coastguard Worker     this->pooling_width_ = pooling_size;
192*4bdc9457SAndroid Build Coastguard Worker     return *this;
193*4bdc9457SAndroid Build Coastguard Worker   }
194*4bdc9457SAndroid Build Coastguard Worker 
pooling_size(uint32_t pooling_height,uint32_t pooling_width)195*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& pooling_size(uint32_t pooling_height, uint32_t pooling_width) {
196*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_height >= 1);
197*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_width >= 1);
198*4bdc9457SAndroid Build Coastguard Worker     this->pooling_height_ = pooling_height;
199*4bdc9457SAndroid Build Coastguard Worker     this->pooling_width_ = pooling_width;
200*4bdc9457SAndroid Build Coastguard Worker     return *this;
201*4bdc9457SAndroid Build Coastguard Worker   }
202*4bdc9457SAndroid Build Coastguard Worker 
pooling_height(uint32_t pooling_height)203*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& pooling_height(uint32_t pooling_height) {
204*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_height >= 1);
205*4bdc9457SAndroid Build Coastguard Worker     this->pooling_height_ = pooling_height;
206*4bdc9457SAndroid Build Coastguard Worker     return *this;
207*4bdc9457SAndroid Build Coastguard Worker   }
208*4bdc9457SAndroid Build Coastguard Worker 
pooling_height()209*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t pooling_height() const {
210*4bdc9457SAndroid Build Coastguard Worker     return this->pooling_height_;
211*4bdc9457SAndroid Build Coastguard Worker   }
212*4bdc9457SAndroid Build Coastguard Worker 
pooling_width(uint32_t pooling_width)213*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& pooling_width(uint32_t pooling_width) {
214*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_width >= 1);
215*4bdc9457SAndroid Build Coastguard Worker     this->pooling_width_ = pooling_width;
216*4bdc9457SAndroid Build Coastguard Worker     return *this;
217*4bdc9457SAndroid Build Coastguard Worker   }
218*4bdc9457SAndroid Build Coastguard Worker 
pooling_width()219*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t pooling_width() const {
220*4bdc9457SAndroid Build Coastguard Worker     return this->pooling_width_;
221*4bdc9457SAndroid Build Coastguard Worker   }
222*4bdc9457SAndroid Build Coastguard Worker 
stride(uint32_t stride)223*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& stride(uint32_t stride) {
224*4bdc9457SAndroid Build Coastguard Worker     assert(stride >= 1);
225*4bdc9457SAndroid Build Coastguard Worker     this->stride_height_ = stride;
226*4bdc9457SAndroid Build Coastguard Worker     this->stride_width_ = stride;
227*4bdc9457SAndroid Build Coastguard Worker     return *this;
228*4bdc9457SAndroid Build Coastguard Worker   }
229*4bdc9457SAndroid Build Coastguard Worker 
stride(uint32_t stride_height,uint32_t stride_width)230*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& stride(uint32_t stride_height, uint32_t stride_width) {
231*4bdc9457SAndroid Build Coastguard Worker     assert(stride_height >= 1);
232*4bdc9457SAndroid Build Coastguard Worker     assert(stride_width >= 1);
233*4bdc9457SAndroid Build Coastguard Worker     this->stride_height_ = stride_height;
234*4bdc9457SAndroid Build Coastguard Worker     this->stride_width_ = stride_width;
235*4bdc9457SAndroid Build Coastguard Worker     return *this;
236*4bdc9457SAndroid Build Coastguard Worker   }
237*4bdc9457SAndroid Build Coastguard Worker 
stride_height(uint32_t stride_height)238*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& stride_height(uint32_t stride_height) {
239*4bdc9457SAndroid Build Coastguard Worker     assert(stride_height >= 1);
240*4bdc9457SAndroid Build Coastguard Worker     this->stride_height_ = stride_height;
241*4bdc9457SAndroid Build Coastguard Worker     return *this;
242*4bdc9457SAndroid Build Coastguard Worker   }
243*4bdc9457SAndroid Build Coastguard Worker 
stride_height()244*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t stride_height() const {
245*4bdc9457SAndroid Build Coastguard Worker     return this->stride_height_;
246*4bdc9457SAndroid Build Coastguard Worker   }
247*4bdc9457SAndroid Build Coastguard Worker 
stride_width(uint32_t stride_width)248*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& stride_width(uint32_t stride_width) {
249*4bdc9457SAndroid Build Coastguard Worker     assert(stride_width >= 1);
250*4bdc9457SAndroid Build Coastguard Worker     this->stride_width_ = stride_width;
251*4bdc9457SAndroid Build Coastguard Worker     return *this;
252*4bdc9457SAndroid Build Coastguard Worker   }
253*4bdc9457SAndroid Build Coastguard Worker 
stride_width()254*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t stride_width() const {
255*4bdc9457SAndroid Build Coastguard Worker     return this->stride_width_;
256*4bdc9457SAndroid Build Coastguard Worker   }
257*4bdc9457SAndroid Build Coastguard Worker 
output_height()258*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_height() const {
259*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
260*4bdc9457SAndroid Build Coastguard Worker       return (input_height() + stride_height() - 1) / stride_height();
261*4bdc9457SAndroid Build Coastguard Worker     } else {
262*4bdc9457SAndroid Build Coastguard Worker       const size_t padded_input_height = padding_top() + input_height() + padding_bottom();
263*4bdc9457SAndroid Build Coastguard Worker       if (padded_input_height <= pooling_height()) {
264*4bdc9457SAndroid Build Coastguard Worker         return 1;
265*4bdc9457SAndroid Build Coastguard Worker       } else {
266*4bdc9457SAndroid Build Coastguard Worker         return (padded_input_height - pooling_height()) / stride_height() + 1;
267*4bdc9457SAndroid Build Coastguard Worker       }
268*4bdc9457SAndroid Build Coastguard Worker     }
269*4bdc9457SAndroid Build Coastguard Worker   }
270*4bdc9457SAndroid Build Coastguard Worker 
output_width()271*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_width() const {
272*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
273*4bdc9457SAndroid Build Coastguard Worker       return (input_width() + stride_width() - 1) / stride_width();
274*4bdc9457SAndroid Build Coastguard Worker     } else {
275*4bdc9457SAndroid Build Coastguard Worker       const size_t padded_input_width = padding_left() + input_width() + padding_right();
276*4bdc9457SAndroid Build Coastguard Worker       if (padded_input_width <= pooling_width()) {
277*4bdc9457SAndroid Build Coastguard Worker         return 1;
278*4bdc9457SAndroid Build Coastguard Worker       } else {
279*4bdc9457SAndroid Build Coastguard Worker         return (padded_input_width - pooling_width()) / stride_width() + 1;
280*4bdc9457SAndroid Build Coastguard Worker       }
281*4bdc9457SAndroid Build Coastguard Worker     }
282*4bdc9457SAndroid Build Coastguard Worker   }
283*4bdc9457SAndroid Build Coastguard Worker 
input_pixel_stride(size_t input_pixel_stride)284*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& input_pixel_stride(size_t input_pixel_stride) {
285*4bdc9457SAndroid Build Coastguard Worker     assert(input_pixel_stride != 0);
286*4bdc9457SAndroid Build Coastguard Worker     this->input_pixel_stride_ = input_pixel_stride;
287*4bdc9457SAndroid Build Coastguard Worker     return *this;
288*4bdc9457SAndroid Build Coastguard Worker   }
289*4bdc9457SAndroid Build Coastguard Worker 
input_pixel_stride()290*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_pixel_stride() const {
291*4bdc9457SAndroid Build Coastguard Worker     if (this->input_pixel_stride_ == 0) {
292*4bdc9457SAndroid Build Coastguard Worker       return channels();
293*4bdc9457SAndroid Build Coastguard Worker     } else {
294*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_pixel_stride_ >= channels());
295*4bdc9457SAndroid Build Coastguard Worker       return this->input_pixel_stride_;
296*4bdc9457SAndroid Build Coastguard Worker     }
297*4bdc9457SAndroid Build Coastguard Worker   }
298*4bdc9457SAndroid Build Coastguard Worker 
output_pixel_stride(size_t output_pixel_stride)299*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& output_pixel_stride(size_t output_pixel_stride) {
300*4bdc9457SAndroid Build Coastguard Worker     assert(output_pixel_stride != 0);
301*4bdc9457SAndroid Build Coastguard Worker     this->output_pixel_stride_ = output_pixel_stride;
302*4bdc9457SAndroid Build Coastguard Worker     return *this;
303*4bdc9457SAndroid Build Coastguard Worker   }
304*4bdc9457SAndroid Build Coastguard Worker 
output_pixel_stride()305*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_pixel_stride() const {
306*4bdc9457SAndroid Build Coastguard Worker     if (this->output_pixel_stride_ == 0) {
307*4bdc9457SAndroid Build Coastguard Worker       return channels();
308*4bdc9457SAndroid Build Coastguard Worker     } else {
309*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_pixel_stride_ >= channels());
310*4bdc9457SAndroid Build Coastguard Worker       return this->output_pixel_stride_;
311*4bdc9457SAndroid Build Coastguard Worker     }
312*4bdc9457SAndroid Build Coastguard Worker   }
313*4bdc9457SAndroid Build Coastguard Worker 
next_input_size(uint32_t next_input_height,uint32_t next_input_width)314*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) {
315*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_height >= 1);
316*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_width >= 1);
317*4bdc9457SAndroid Build Coastguard Worker     this->next_input_height_ = next_input_height;
318*4bdc9457SAndroid Build Coastguard Worker     this->next_input_width_ = next_input_width;
319*4bdc9457SAndroid Build Coastguard Worker     return *this;
320*4bdc9457SAndroid Build Coastguard Worker   }
321*4bdc9457SAndroid Build Coastguard Worker 
next_input_height(uint32_t next_input_height)322*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& next_input_height(uint32_t next_input_height) {
323*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_height >= 1);
324*4bdc9457SAndroid Build Coastguard Worker     this->next_input_height_ = next_input_height;
325*4bdc9457SAndroid Build Coastguard Worker     return *this;
326*4bdc9457SAndroid Build Coastguard Worker   }
327*4bdc9457SAndroid Build Coastguard Worker 
next_input_height()328*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t next_input_height() const {
329*4bdc9457SAndroid Build Coastguard Worker     if (this->next_input_height_ == 0) {
330*4bdc9457SAndroid Build Coastguard Worker       return input_height();
331*4bdc9457SAndroid Build Coastguard Worker     } else {
332*4bdc9457SAndroid Build Coastguard Worker       return this->next_input_height_;
333*4bdc9457SAndroid Build Coastguard Worker     }
334*4bdc9457SAndroid Build Coastguard Worker   }
335*4bdc9457SAndroid Build Coastguard Worker 
next_input_width(uint32_t next_input_width)336*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& next_input_width(uint32_t next_input_width) {
337*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_width >= 1);
338*4bdc9457SAndroid Build Coastguard Worker     this->next_input_width_ = next_input_width;
339*4bdc9457SAndroid Build Coastguard Worker     return *this;
340*4bdc9457SAndroid Build Coastguard Worker   }
341*4bdc9457SAndroid Build Coastguard Worker 
next_input_width()342*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t next_input_width() const {
343*4bdc9457SAndroid Build Coastguard Worker     if (this->next_input_width_ == 0) {
344*4bdc9457SAndroid Build Coastguard Worker       return input_width();
345*4bdc9457SAndroid Build Coastguard Worker     } else {
346*4bdc9457SAndroid Build Coastguard Worker       return this->next_input_width_;
347*4bdc9457SAndroid Build Coastguard Worker     }
348*4bdc9457SAndroid Build Coastguard Worker   }
349*4bdc9457SAndroid Build Coastguard Worker 
next_output_height()350*4bdc9457SAndroid Build Coastguard Worker   inline size_t next_output_height() const {
351*4bdc9457SAndroid Build Coastguard Worker     const size_t padded_next_input_height = padding_top() + next_input_height() + padding_bottom();
352*4bdc9457SAndroid Build Coastguard Worker     if (padded_next_input_height <= pooling_height()) {
353*4bdc9457SAndroid Build Coastguard Worker       return 1;
354*4bdc9457SAndroid Build Coastguard Worker     } else {
355*4bdc9457SAndroid Build Coastguard Worker       return (padded_next_input_height - pooling_height()) / stride_height() + 1;
356*4bdc9457SAndroid Build Coastguard Worker     }
357*4bdc9457SAndroid Build Coastguard Worker   }
358*4bdc9457SAndroid Build Coastguard Worker 
next_output_width()359*4bdc9457SAndroid Build Coastguard Worker   inline size_t next_output_width() const {
360*4bdc9457SAndroid Build Coastguard Worker     const size_t padded_next_input_width = padding_left() + next_input_width() + padding_right();
361*4bdc9457SAndroid Build Coastguard Worker     if (padded_next_input_width <= pooling_width()) {
362*4bdc9457SAndroid Build Coastguard Worker       return 1;
363*4bdc9457SAndroid Build Coastguard Worker     } else {
364*4bdc9457SAndroid Build Coastguard Worker       return (padded_next_input_width - pooling_width()) / stride_width() + 1;
365*4bdc9457SAndroid Build Coastguard Worker     }
366*4bdc9457SAndroid Build Coastguard Worker   }
367*4bdc9457SAndroid Build Coastguard Worker 
next_batch_size(size_t next_batch_size)368*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& next_batch_size(size_t next_batch_size) {
369*4bdc9457SAndroid Build Coastguard Worker     assert(next_batch_size >= 1);
370*4bdc9457SAndroid Build Coastguard Worker     this->next_batch_size_ = next_batch_size;
371*4bdc9457SAndroid Build Coastguard Worker     return *this;
372*4bdc9457SAndroid Build Coastguard Worker   }
373*4bdc9457SAndroid Build Coastguard Worker 
next_batch_size()374*4bdc9457SAndroid Build Coastguard Worker   inline size_t next_batch_size() const {
375*4bdc9457SAndroid Build Coastguard Worker     if (this->next_batch_size_ == 0) {
376*4bdc9457SAndroid Build Coastguard Worker       return batch_size();
377*4bdc9457SAndroid Build Coastguard Worker     } else {
378*4bdc9457SAndroid Build Coastguard Worker       return this->next_batch_size_;
379*4bdc9457SAndroid Build Coastguard Worker     }
380*4bdc9457SAndroid Build Coastguard Worker   }
381*4bdc9457SAndroid Build Coastguard Worker 
input_scale(float input_scale)382*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& input_scale(float input_scale) {
383*4bdc9457SAndroid Build Coastguard Worker     assert(input_scale > 0.0f);
384*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(input_scale));
385*4bdc9457SAndroid Build Coastguard Worker     this->input_scale_ = input_scale;
386*4bdc9457SAndroid Build Coastguard Worker     return *this;
387*4bdc9457SAndroid Build Coastguard Worker   }
388*4bdc9457SAndroid Build Coastguard Worker 
input_scale()389*4bdc9457SAndroid Build Coastguard Worker   inline float input_scale() const {
390*4bdc9457SAndroid Build Coastguard Worker     return this->input_scale_;
391*4bdc9457SAndroid Build Coastguard Worker   }
392*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point(uint8_t input_zero_point)393*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& input_zero_point(uint8_t input_zero_point) {
394*4bdc9457SAndroid Build Coastguard Worker     this->input_zero_point_ = input_zero_point;
395*4bdc9457SAndroid Build Coastguard Worker     return *this;
396*4bdc9457SAndroid Build Coastguard Worker   }
397*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point()398*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t input_zero_point() const {
399*4bdc9457SAndroid Build Coastguard Worker     return this->input_zero_point_;
400*4bdc9457SAndroid Build Coastguard Worker   }
401*4bdc9457SAndroid Build Coastguard Worker 
output_scale(float output_scale)402*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& output_scale(float output_scale) {
403*4bdc9457SAndroid Build Coastguard Worker     assert(output_scale > 0.0f);
404*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(output_scale));
405*4bdc9457SAndroid Build Coastguard Worker     this->output_scale_ = output_scale;
406*4bdc9457SAndroid Build Coastguard Worker     return *this;
407*4bdc9457SAndroid Build Coastguard Worker   }
408*4bdc9457SAndroid Build Coastguard Worker 
output_scale()409*4bdc9457SAndroid Build Coastguard Worker   inline float output_scale() const {
410*4bdc9457SAndroid Build Coastguard Worker     return this->output_scale_;
411*4bdc9457SAndroid Build Coastguard Worker   }
412*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point(uint8_t output_zero_point)413*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& output_zero_point(uint8_t output_zero_point) {
414*4bdc9457SAndroid Build Coastguard Worker     this->output_zero_point_ = output_zero_point;
415*4bdc9457SAndroid Build Coastguard Worker     return *this;
416*4bdc9457SAndroid Build Coastguard Worker   }
417*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point()418*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t output_zero_point() const {
419*4bdc9457SAndroid Build Coastguard Worker     return this->output_zero_point_;
420*4bdc9457SAndroid Build Coastguard Worker   }
421*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)422*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& qmin(uint8_t qmin) {
423*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
424*4bdc9457SAndroid Build Coastguard Worker     return *this;
425*4bdc9457SAndroid Build Coastguard Worker   }
426*4bdc9457SAndroid Build Coastguard Worker 
qmin()427*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
428*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
429*4bdc9457SAndroid Build Coastguard Worker   }
430*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)431*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& qmax(uint8_t qmax) {
432*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
433*4bdc9457SAndroid Build Coastguard Worker     return *this;
434*4bdc9457SAndroid Build Coastguard Worker   }
435*4bdc9457SAndroid Build Coastguard Worker 
qmax()436*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
437*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
438*4bdc9457SAndroid Build Coastguard Worker   }
439*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)440*4bdc9457SAndroid Build Coastguard Worker   inline AveragePoolingOperatorTester& iterations(size_t iterations) {
441*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
442*4bdc9457SAndroid Build Coastguard Worker     return *this;
443*4bdc9457SAndroid Build Coastguard Worker   }
444*4bdc9457SAndroid Build Coastguard Worker 
iterations()445*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
446*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
447*4bdc9457SAndroid Build Coastguard Worker   }
448*4bdc9457SAndroid Build Coastguard Worker 
TestF16()449*4bdc9457SAndroid Build Coastguard Worker   void TestF16() const {
450*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
451*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
452*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
453*4bdc9457SAndroid Build Coastguard Worker 
454*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
455*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels());
456*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
457*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
458*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
459*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
460*4bdc9457SAndroid Build Coastguard Worker 
461*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
462*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
463*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
464*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
465*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
466*4bdc9457SAndroid Build Coastguard Worker               float acc = 0.0f;
467*4bdc9457SAndroid Build Coastguard Worker               int32_t n = 0;
468*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
469*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
470*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
471*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
472*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width() && iy < input_height()) {
473*4bdc9457SAndroid Build Coastguard Worker                     acc += fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]);
474*4bdc9457SAndroid Build Coastguard Worker                     n += 1;
475*4bdc9457SAndroid Build Coastguard Worker                   }
476*4bdc9457SAndroid Build Coastguard Worker                 }
477*4bdc9457SAndroid Build Coastguard Worker               }
478*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n);
479*4bdc9457SAndroid Build Coastguard Worker             }
480*4bdc9457SAndroid Build Coastguard Worker           }
481*4bdc9457SAndroid Build Coastguard Worker         }
482*4bdc9457SAndroid Build Coastguard Worker       }
483*4bdc9457SAndroid Build Coastguard Worker 
484*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
485*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
486*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
487*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
488*4bdc9457SAndroid Build Coastguard Worker       float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin());
489*4bdc9457SAndroid Build Coastguard Worker       float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
490*4bdc9457SAndroid Build Coastguard Worker       output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min));
491*4bdc9457SAndroid Build Coastguard Worker       output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max));
492*4bdc9457SAndroid Build Coastguard Worker       if (accumulated_range == 0.0f) {
493*4bdc9457SAndroid Build Coastguard Worker         output_min = -std::numeric_limits<float>::infinity();
494*4bdc9457SAndroid Build Coastguard Worker         output_max = +std::numeric_limits<float>::infinity();
495*4bdc9457SAndroid Build Coastguard Worker       }
496*4bdc9457SAndroid Build Coastguard Worker       if (qmin() == std::numeric_limits<uint8_t>::min()) {
497*4bdc9457SAndroid Build Coastguard Worker         output_min = -std::numeric_limits<float>::infinity();
498*4bdc9457SAndroid Build Coastguard Worker       }
499*4bdc9457SAndroid Build Coastguard Worker       if (qmax() == std::numeric_limits<uint8_t>::max()) {
500*4bdc9457SAndroid Build Coastguard Worker         output_max = +std::numeric_limits<float>::infinity();
501*4bdc9457SAndroid Build Coastguard Worker       }
502*4bdc9457SAndroid Build Coastguard Worker 
503*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
504*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
505*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
506*4bdc9457SAndroid Build Coastguard Worker       }
507*4bdc9457SAndroid Build Coastguard Worker 
508*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Average Pooling operator.
509*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
510*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t average_pooling_op = nullptr;
511*4bdc9457SAndroid Build Coastguard Worker 
512*4bdc9457SAndroid Build Coastguard Worker       const xnn_status status = xnn_create_average_pooling2d_nhwc_f16(
513*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
514*4bdc9457SAndroid Build Coastguard Worker           pooling_height(), pooling_width(),
515*4bdc9457SAndroid Build Coastguard Worker           stride_height(), stride_width(),
516*4bdc9457SAndroid Build Coastguard Worker           channels(), input_pixel_stride(), output_pixel_stride(),
517*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
518*4bdc9457SAndroid Build Coastguard Worker           0, &average_pooling_op);
519*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
520*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
521*4bdc9457SAndroid Build Coastguard Worker       }
522*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
523*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, average_pooling_op);
524*4bdc9457SAndroid Build Coastguard Worker 
525*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete average_pooling_op.
526*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_average_pooling_op(average_pooling_op, xnn_delete_operator);
527*4bdc9457SAndroid Build Coastguard Worker 
528*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
529*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_f16(
530*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
531*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
532*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
533*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
534*4bdc9457SAndroid Build Coastguard Worker 
535*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
536*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
537*4bdc9457SAndroid Build Coastguard Worker 
538*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
539*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
540*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
541*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
542*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
543*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_max);
544*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_min);
545*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
546*4bdc9457SAndroid Build Coastguard Worker                   fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]),
547*4bdc9457SAndroid Build Coastguard Worker                   output_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
548*4bdc9457SAndroid Build Coastguard Worker                   std::max(1.0e-3f, std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-2f)) <<
549*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
550*4bdc9457SAndroid Build Coastguard Worker             }
551*4bdc9457SAndroid Build Coastguard Worker           }
552*4bdc9457SAndroid Build Coastguard Worker         }
553*4bdc9457SAndroid Build Coastguard Worker       }
554*4bdc9457SAndroid Build Coastguard Worker     }
555*4bdc9457SAndroid Build Coastguard Worker   }
556*4bdc9457SAndroid Build Coastguard Worker 
TestF32()557*4bdc9457SAndroid Build Coastguard Worker   void TestF32() const {
558*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
559*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
560*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
561*4bdc9457SAndroid Build Coastguard Worker 
562*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
563*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels());
564*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
565*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
566*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
567*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
568*4bdc9457SAndroid Build Coastguard Worker 
569*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
570*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
571*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
572*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
573*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
574*4bdc9457SAndroid Build Coastguard Worker               float acc = 0.0f;
575*4bdc9457SAndroid Build Coastguard Worker               int32_t n = 0;
576*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
577*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
578*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
579*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
580*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width() && iy < input_height()) {
581*4bdc9457SAndroid Build Coastguard Worker                     acc += input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c];
582*4bdc9457SAndroid Build Coastguard Worker                     n += 1;
583*4bdc9457SAndroid Build Coastguard Worker                   }
584*4bdc9457SAndroid Build Coastguard Worker                 }
585*4bdc9457SAndroid Build Coastguard Worker               }
586*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n);
587*4bdc9457SAndroid Build Coastguard Worker             }
588*4bdc9457SAndroid Build Coastguard Worker           }
589*4bdc9457SAndroid Build Coastguard Worker         }
590*4bdc9457SAndroid Build Coastguard Worker       }
591*4bdc9457SAndroid Build Coastguard Worker 
592*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
593*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
594*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
595*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
596*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_range == 0.0f ?
597*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity() :
598*4bdc9457SAndroid Build Coastguard Worker         accumulated_min + accumulated_range / 255.0f * float(qmin());
599*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_range == 0.0f ?
600*4bdc9457SAndroid Build Coastguard Worker         +std::numeric_limits<float>::infinity() :
601*4bdc9457SAndroid Build Coastguard Worker         accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
602*4bdc9457SAndroid Build Coastguard Worker 
603*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
604*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
605*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
606*4bdc9457SAndroid Build Coastguard Worker       }
607*4bdc9457SAndroid Build Coastguard Worker 
608*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Average Pooling operator.
609*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
610*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t average_pooling_op = nullptr;
611*4bdc9457SAndroid Build Coastguard Worker 
612*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
613*4bdc9457SAndroid Build Coastguard Worker         xnn_create_average_pooling2d_nhwc_f32(
614*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
615*4bdc9457SAndroid Build Coastguard Worker           pooling_height(), pooling_width(),
616*4bdc9457SAndroid Build Coastguard Worker           stride_height(), stride_width(),
617*4bdc9457SAndroid Build Coastguard Worker           channels(), input_pixel_stride(), output_pixel_stride(),
618*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
619*4bdc9457SAndroid Build Coastguard Worker           0, &average_pooling_op));
620*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, average_pooling_op);
621*4bdc9457SAndroid Build Coastguard Worker 
622*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete average_pooling_op.
623*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_average_pooling_op(average_pooling_op, xnn_delete_operator);
624*4bdc9457SAndroid Build Coastguard Worker 
625*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
626*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_f32(
627*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
628*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
629*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
630*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
631*4bdc9457SAndroid Build Coastguard Worker 
632*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
633*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
634*4bdc9457SAndroid Build Coastguard Worker 
635*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
636*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
637*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
638*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
639*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
640*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max);
641*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min);
642*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c],
643*4bdc9457SAndroid Build Coastguard Worker                   output_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
644*4bdc9457SAndroid Build Coastguard Worker                   std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-6f) <<
645*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
646*4bdc9457SAndroid Build Coastguard Worker             }
647*4bdc9457SAndroid Build Coastguard Worker           }
648*4bdc9457SAndroid Build Coastguard Worker         }
649*4bdc9457SAndroid Build Coastguard Worker       }
650*4bdc9457SAndroid Build Coastguard Worker     }
651*4bdc9457SAndroid Build Coastguard Worker   }
652*4bdc9457SAndroid Build Coastguard Worker 
TestQU8()653*4bdc9457SAndroid Build Coastguard Worker   void TestQU8() const {
654*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
655*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
656*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
657*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
658*4bdc9457SAndroid Build Coastguard Worker 
659*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
660*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels());
661*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
662*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
663*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
664*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
665*4bdc9457SAndroid Build Coastguard Worker 
666*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
667*4bdc9457SAndroid Build Coastguard Worker       const double scale = double(input_scale()) / (double(output_scale()) * double(pooling_height() * pooling_width()));
668*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
669*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
670*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
671*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
672*4bdc9457SAndroid Build Coastguard Worker               double acc = 0.0f;
673*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
674*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
675*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
676*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
677*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width() && iy < input_height()) {
678*4bdc9457SAndroid Build Coastguard Worker                     acc += double(int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]) - int32_t(input_zero_point()));
679*4bdc9457SAndroid Build Coastguard Worker                   }
680*4bdc9457SAndroid Build Coastguard Worker                 }
681*4bdc9457SAndroid Build Coastguard Worker               }
682*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = float(acc * scale + double(output_zero_point()));
683*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] =
684*4bdc9457SAndroid Build Coastguard Worker                 std::min<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmax()));
685*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] =
686*4bdc9457SAndroid Build Coastguard Worker                 std::max<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmin()));
687*4bdc9457SAndroid Build Coastguard Worker             }
688*4bdc9457SAndroid Build Coastguard Worker           }
689*4bdc9457SAndroid Build Coastguard Worker         }
690*4bdc9457SAndroid Build Coastguard Worker       }
691*4bdc9457SAndroid Build Coastguard Worker 
692*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Average Pooling operator.
693*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
694*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t average_pooling_op = nullptr;
695*4bdc9457SAndroid Build Coastguard Worker 
696*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
697*4bdc9457SAndroid Build Coastguard Worker         xnn_create_average_pooling2d_nhwc_qu8(
698*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
699*4bdc9457SAndroid Build Coastguard Worker           pooling_height(), pooling_width(),
700*4bdc9457SAndroid Build Coastguard Worker           stride_height(), stride_width(),
701*4bdc9457SAndroid Build Coastguard Worker           channels(), input_pixel_stride(), output_pixel_stride(),
702*4bdc9457SAndroid Build Coastguard Worker           input_zero_point(), input_scale(),
703*4bdc9457SAndroid Build Coastguard Worker           output_zero_point(), output_scale(),
704*4bdc9457SAndroid Build Coastguard Worker           qmin(), qmax(),
705*4bdc9457SAndroid Build Coastguard Worker           0, &average_pooling_op));
706*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, average_pooling_op);
707*4bdc9457SAndroid Build Coastguard Worker 
708*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete average_pooling_op.
709*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_average_pooling_op(average_pooling_op, xnn_delete_operator);
710*4bdc9457SAndroid Build Coastguard Worker 
711*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
712*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_qu8(
713*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
714*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
715*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
716*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
717*4bdc9457SAndroid Build Coastguard Worker 
718*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
719*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
720*4bdc9457SAndroid Build Coastguard Worker 
721*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
722*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
723*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
724*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
725*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
726*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax()));
727*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin()));
728*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(float(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])),
729*4bdc9457SAndroid Build Coastguard Worker                 output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 0.80f) <<
730*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
731*4bdc9457SAndroid Build Coastguard Worker             }
732*4bdc9457SAndroid Build Coastguard Worker           }
733*4bdc9457SAndroid Build Coastguard Worker         }
734*4bdc9457SAndroid Build Coastguard Worker       }
735*4bdc9457SAndroid Build Coastguard Worker     }
736*4bdc9457SAndroid Build Coastguard Worker   }
737*4bdc9457SAndroid Build Coastguard Worker 
TestSetupF16()738*4bdc9457SAndroid Build Coastguard Worker   void TestSetupF16() const {
739*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
740*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
741*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
742*4bdc9457SAndroid Build Coastguard Worker 
743*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max<size_t>(
744*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(),
745*4bdc9457SAndroid Build Coastguard Worker       (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels()));
746*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(std::max<size_t>(
747*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(),
748*4bdc9457SAndroid Build Coastguard Worker       (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels()));
749*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
750*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels());
751*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
752*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
753*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
754*4bdc9457SAndroid Build Coastguard Worker 
755*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
756*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
757*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
758*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
759*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
760*4bdc9457SAndroid Build Coastguard Worker               float acc = 0.0f;
761*4bdc9457SAndroid Build Coastguard Worker               size_t n = 0;
762*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
763*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
764*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
765*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
766*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width() && iy < input_height()) {
767*4bdc9457SAndroid Build Coastguard Worker                     acc += fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]);
768*4bdc9457SAndroid Build Coastguard Worker                     n += 1;
769*4bdc9457SAndroid Build Coastguard Worker                   }
770*4bdc9457SAndroid Build Coastguard Worker                 }
771*4bdc9457SAndroid Build Coastguard Worker               }
772*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n);
773*4bdc9457SAndroid Build Coastguard Worker             }
774*4bdc9457SAndroid Build Coastguard Worker           }
775*4bdc9457SAndroid Build Coastguard Worker         }
776*4bdc9457SAndroid Build Coastguard Worker       }
777*4bdc9457SAndroid Build Coastguard Worker 
778*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
779*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
780*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
781*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
782*4bdc9457SAndroid Build Coastguard Worker       float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin());
783*4bdc9457SAndroid Build Coastguard Worker       float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
784*4bdc9457SAndroid Build Coastguard Worker       output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min));
785*4bdc9457SAndroid Build Coastguard Worker       output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max));
786*4bdc9457SAndroid Build Coastguard Worker       if (accumulated_range == 0.0f) {
787*4bdc9457SAndroid Build Coastguard Worker         output_min = -std::numeric_limits<float>::infinity();
788*4bdc9457SAndroid Build Coastguard Worker         output_max = +std::numeric_limits<float>::infinity();
789*4bdc9457SAndroid Build Coastguard Worker       }
790*4bdc9457SAndroid Build Coastguard Worker       if (qmin() == std::numeric_limits<uint8_t>::min()) {
791*4bdc9457SAndroid Build Coastguard Worker         output_min = -std::numeric_limits<float>::infinity();
792*4bdc9457SAndroid Build Coastguard Worker       }
793*4bdc9457SAndroid Build Coastguard Worker       if (qmax() == std::numeric_limits<uint8_t>::max()) {
794*4bdc9457SAndroid Build Coastguard Worker         output_max = +std::numeric_limits<float>::infinity();
795*4bdc9457SAndroid Build Coastguard Worker       }
796*4bdc9457SAndroid Build Coastguard Worker 
797*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
798*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
799*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
800*4bdc9457SAndroid Build Coastguard Worker       }
801*4bdc9457SAndroid Build Coastguard Worker 
802*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Average Pooling operator once.
803*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
804*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t average_pooling_op = nullptr;
805*4bdc9457SAndroid Build Coastguard Worker 
806*4bdc9457SAndroid Build Coastguard Worker       const xnn_status status = xnn_create_average_pooling2d_nhwc_f16(
807*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
808*4bdc9457SAndroid Build Coastguard Worker           pooling_height(), pooling_width(),
809*4bdc9457SAndroid Build Coastguard Worker           stride_height(), stride_width(),
810*4bdc9457SAndroid Build Coastguard Worker           channels(), input_pixel_stride(), output_pixel_stride(),
811*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
812*4bdc9457SAndroid Build Coastguard Worker           0, &average_pooling_op);
813*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
814*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
815*4bdc9457SAndroid Build Coastguard Worker       }
816*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
817*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, average_pooling_op);
818*4bdc9457SAndroid Build Coastguard Worker 
819*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
820*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_f16(
821*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
822*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
823*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
824*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
825*4bdc9457SAndroid Build Coastguard Worker 
826*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
827*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
828*4bdc9457SAndroid Build Coastguard Worker 
829*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
830*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
831*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
832*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
833*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
834*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_max);
835*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_min);
836*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
837*4bdc9457SAndroid Build Coastguard Worker                   fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]),
838*4bdc9457SAndroid Build Coastguard Worker                   output_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
839*4bdc9457SAndroid Build Coastguard Worker                   std::max(1.0e-3f, std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-2f)) <<
840*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
841*4bdc9457SAndroid Build Coastguard Worker             }
842*4bdc9457SAndroid Build Coastguard Worker           }
843*4bdc9457SAndroid Build Coastguard Worker         }
844*4bdc9457SAndroid Build Coastguard Worker       }
845*4bdc9457SAndroid Build Coastguard Worker 
846*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
847*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
848*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
849*4bdc9457SAndroid Build Coastguard Worker 
850*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run.
851*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
852*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
853*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
854*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
855*4bdc9457SAndroid Build Coastguard Worker               float acc = 0.0f;
856*4bdc9457SAndroid Build Coastguard Worker               int32_t n = 0;
857*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
858*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
859*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
860*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
861*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width() && iy < next_input_height()) {
862*4bdc9457SAndroid Build Coastguard Worker                     acc += fp16_ieee_to_fp32_value(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]);
863*4bdc9457SAndroid Build Coastguard Worker                     n += 1;
864*4bdc9457SAndroid Build Coastguard Worker                   }
865*4bdc9457SAndroid Build Coastguard Worker                 }
866*4bdc9457SAndroid Build Coastguard Worker               }
867*4bdc9457SAndroid Build Coastguard Worker               next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] =
868*4bdc9457SAndroid Build Coastguard Worker                 std::max(std::min(acc / float(n), output_max), output_min);
869*4bdc9457SAndroid Build Coastguard Worker             }
870*4bdc9457SAndroid Build Coastguard Worker           }
871*4bdc9457SAndroid Build Coastguard Worker         }
872*4bdc9457SAndroid Build Coastguard Worker       }
873*4bdc9457SAndroid Build Coastguard Worker 
874*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Average Pooling operator the second time, and destroy the operator.
875*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
876*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_f16(
877*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
878*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
879*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
880*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
881*4bdc9457SAndroid Build Coastguard Worker 
882*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
883*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
884*4bdc9457SAndroid Build Coastguard Worker 
885*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
886*4bdc9457SAndroid Build Coastguard Worker         xnn_delete_operator(average_pooling_op));
887*4bdc9457SAndroid Build Coastguard Worker       average_pooling_op = nullptr;
888*4bdc9457SAndroid Build Coastguard Worker 
889*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
890*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
891*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
892*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
893*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
894*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), output_max);
895*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), output_min);
896*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
897*4bdc9457SAndroid Build Coastguard Worker                   fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]),
898*4bdc9457SAndroid Build Coastguard Worker                   next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c],
899*4bdc9457SAndroid Build Coastguard Worker                   std::max(1.0e-3f, std::abs(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]) * 1.0e-2f)) <<
900*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
901*4bdc9457SAndroid Build Coastguard Worker             }
902*4bdc9457SAndroid Build Coastguard Worker           }
903*4bdc9457SAndroid Build Coastguard Worker         }
904*4bdc9457SAndroid Build Coastguard Worker       }
905*4bdc9457SAndroid Build Coastguard Worker     }
906*4bdc9457SAndroid Build Coastguard Worker   }
907*4bdc9457SAndroid Build Coastguard Worker 
TestSetupF32()908*4bdc9457SAndroid Build Coastguard Worker   void TestSetupF32() const {
909*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
910*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
911*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
912*4bdc9457SAndroid Build Coastguard Worker 
913*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max<size_t>(
914*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(),
915*4bdc9457SAndroid Build Coastguard Worker       (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels()));
916*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(std::max<size_t>(
917*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(),
918*4bdc9457SAndroid Build Coastguard Worker       (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels()));
919*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
920*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels());
921*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
922*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
923*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
924*4bdc9457SAndroid Build Coastguard Worker 
925*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
926*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
927*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
928*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
929*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
930*4bdc9457SAndroid Build Coastguard Worker               float acc = 0.0f;
931*4bdc9457SAndroid Build Coastguard Worker               size_t n = 0;
932*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
933*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
934*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
935*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
936*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width() && iy < input_height()) {
937*4bdc9457SAndroid Build Coastguard Worker                     acc += input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c];
938*4bdc9457SAndroid Build Coastguard Worker                     n += 1;
939*4bdc9457SAndroid Build Coastguard Worker                   }
940*4bdc9457SAndroid Build Coastguard Worker                 }
941*4bdc9457SAndroid Build Coastguard Worker               }
942*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n);
943*4bdc9457SAndroid Build Coastguard Worker             }
944*4bdc9457SAndroid Build Coastguard Worker           }
945*4bdc9457SAndroid Build Coastguard Worker         }
946*4bdc9457SAndroid Build Coastguard Worker       }
947*4bdc9457SAndroid Build Coastguard Worker 
948*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
949*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
950*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
951*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
952*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_range == 0.0f ?
953*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity() :
954*4bdc9457SAndroid Build Coastguard Worker         accumulated_min + accumulated_range / 255.0f * float(qmin());
955*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_range == 0.0f ?
956*4bdc9457SAndroid Build Coastguard Worker         +std::numeric_limits<float>::infinity() :
957*4bdc9457SAndroid Build Coastguard Worker         accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
958*4bdc9457SAndroid Build Coastguard Worker 
959*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
960*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
961*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
962*4bdc9457SAndroid Build Coastguard Worker       }
963*4bdc9457SAndroid Build Coastguard Worker 
964*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Average Pooling operator once.
965*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
966*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t average_pooling_op = nullptr;
967*4bdc9457SAndroid Build Coastguard Worker 
968*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
969*4bdc9457SAndroid Build Coastguard Worker         xnn_create_average_pooling2d_nhwc_f32(
970*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
971*4bdc9457SAndroid Build Coastguard Worker           pooling_height(), pooling_width(),
972*4bdc9457SAndroid Build Coastguard Worker           stride_height(), stride_width(),
973*4bdc9457SAndroid Build Coastguard Worker           channels(), input_pixel_stride(), output_pixel_stride(),
974*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
975*4bdc9457SAndroid Build Coastguard Worker           0, &average_pooling_op));
976*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, average_pooling_op);
977*4bdc9457SAndroid Build Coastguard Worker 
978*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
979*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_f32(
980*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
981*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
982*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
983*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
984*4bdc9457SAndroid Build Coastguard Worker 
985*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
986*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
987*4bdc9457SAndroid Build Coastguard Worker 
988*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
989*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
990*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
991*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
992*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
993*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max);
994*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min);
995*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c],
996*4bdc9457SAndroid Build Coastguard Worker                   output_ref[((i * output_height() + y) * output_width() + x) * channels() + c],
997*4bdc9457SAndroid Build Coastguard Worker                   std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-6f) <<
998*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
999*4bdc9457SAndroid Build Coastguard Worker             }
1000*4bdc9457SAndroid Build Coastguard Worker           }
1001*4bdc9457SAndroid Build Coastguard Worker         }
1002*4bdc9457SAndroid Build Coastguard Worker       }
1003*4bdc9457SAndroid Build Coastguard Worker 
1004*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
1005*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
1006*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
1007*4bdc9457SAndroid Build Coastguard Worker 
1008*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run.
1009*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
1010*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
1011*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
1012*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
1013*4bdc9457SAndroid Build Coastguard Worker               float acc = 0.0f;
1014*4bdc9457SAndroid Build Coastguard Worker               int32_t n = 0;
1015*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
1016*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
1017*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
1018*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
1019*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width() && iy < next_input_height()) {
1020*4bdc9457SAndroid Build Coastguard Worker                     acc += input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c];
1021*4bdc9457SAndroid Build Coastguard Worker                     n += 1;
1022*4bdc9457SAndroid Build Coastguard Worker                   }
1023*4bdc9457SAndroid Build Coastguard Worker                 }
1024*4bdc9457SAndroid Build Coastguard Worker               }
1025*4bdc9457SAndroid Build Coastguard Worker               next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] =
1026*4bdc9457SAndroid Build Coastguard Worker                 std::max(std::min(acc / float(n), output_max), output_min);
1027*4bdc9457SAndroid Build Coastguard Worker             }
1028*4bdc9457SAndroid Build Coastguard Worker           }
1029*4bdc9457SAndroid Build Coastguard Worker         }
1030*4bdc9457SAndroid Build Coastguard Worker       }
1031*4bdc9457SAndroid Build Coastguard Worker 
1032*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Average Pooling operator the second time, and destroy the operator.
1033*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1034*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_f32(
1035*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
1036*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
1037*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1038*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1039*4bdc9457SAndroid Build Coastguard Worker 
1040*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1041*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
1042*4bdc9457SAndroid Build Coastguard Worker 
1043*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1044*4bdc9457SAndroid Build Coastguard Worker         xnn_delete_operator(average_pooling_op));
1045*4bdc9457SAndroid Build Coastguard Worker       average_pooling_op = nullptr;
1046*4bdc9457SAndroid Build Coastguard Worker 
1047*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
1048*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
1049*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
1050*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
1051*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
1052*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], output_max);
1053*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], output_min);
1054*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c],
1055*4bdc9457SAndroid Build Coastguard Worker                   next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c],
1056*4bdc9457SAndroid Build Coastguard Worker                   std::abs(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]) * 1.0e-6f) <<
1057*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
1058*4bdc9457SAndroid Build Coastguard Worker             }
1059*4bdc9457SAndroid Build Coastguard Worker           }
1060*4bdc9457SAndroid Build Coastguard Worker         }
1061*4bdc9457SAndroid Build Coastguard Worker       }
1062*4bdc9457SAndroid Build Coastguard Worker     }
1063*4bdc9457SAndroid Build Coastguard Worker   }
1064*4bdc9457SAndroid Build Coastguard Worker 
TestSetupQU8()1065*4bdc9457SAndroid Build Coastguard Worker   void TestSetupQU8() const {
1066*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1067*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1068*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
1069*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
1070*4bdc9457SAndroid Build Coastguard Worker 
1071*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max<size_t>(
1072*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(),
1073*4bdc9457SAndroid Build Coastguard Worker       (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels()));
1074*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(std::max<size_t>(
1075*4bdc9457SAndroid Build Coastguard Worker       (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(),
1076*4bdc9457SAndroid Build Coastguard Worker       (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels()));
1077*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels());
1078*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels());
1079*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1080*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
1081*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
1082*4bdc9457SAndroid Build Coastguard Worker 
1083*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
1084*4bdc9457SAndroid Build Coastguard Worker       const double scale = double(input_scale()) / (double(output_scale()) * double(pooling_height() * pooling_width()));
1085*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1086*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
1087*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
1088*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
1089*4bdc9457SAndroid Build Coastguard Worker               double acc = 0.0f;
1090*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
1091*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
1092*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
1093*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
1094*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width() && iy < input_height()) {
1095*4bdc9457SAndroid Build Coastguard Worker                     acc += double(int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]) - int32_t(input_zero_point()));
1096*4bdc9457SAndroid Build Coastguard Worker                   }
1097*4bdc9457SAndroid Build Coastguard Worker                 }
1098*4bdc9457SAndroid Build Coastguard Worker               }
1099*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = float(acc * scale + double(output_zero_point()));
1100*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] =
1101*4bdc9457SAndroid Build Coastguard Worker                 std::min<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmax()));
1102*4bdc9457SAndroid Build Coastguard Worker               output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] =
1103*4bdc9457SAndroid Build Coastguard Worker                 std::max<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmin()));
1104*4bdc9457SAndroid Build Coastguard Worker             }
1105*4bdc9457SAndroid Build Coastguard Worker           }
1106*4bdc9457SAndroid Build Coastguard Worker         }
1107*4bdc9457SAndroid Build Coastguard Worker       }
1108*4bdc9457SAndroid Build Coastguard Worker 
1109*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Average Pooling operator once.
1110*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1111*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t average_pooling_op = nullptr;
1112*4bdc9457SAndroid Build Coastguard Worker 
1113*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1114*4bdc9457SAndroid Build Coastguard Worker         xnn_create_average_pooling2d_nhwc_qu8(
1115*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
1116*4bdc9457SAndroid Build Coastguard Worker           pooling_height(), pooling_width(),
1117*4bdc9457SAndroid Build Coastguard Worker           stride_height(), stride_width(),
1118*4bdc9457SAndroid Build Coastguard Worker           channels(), input_pixel_stride(), output_pixel_stride(),
1119*4bdc9457SAndroid Build Coastguard Worker           input_zero_point(), input_scale(),
1120*4bdc9457SAndroid Build Coastguard Worker           output_zero_point(), output_scale(),
1121*4bdc9457SAndroid Build Coastguard Worker           qmin(), qmax(),
1122*4bdc9457SAndroid Build Coastguard Worker           0, &average_pooling_op));
1123*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, average_pooling_op);
1124*4bdc9457SAndroid Build Coastguard Worker 
1125*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1126*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_qu8(
1127*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
1128*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
1129*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1130*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1131*4bdc9457SAndroid Build Coastguard Worker 
1132*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1133*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
1134*4bdc9457SAndroid Build Coastguard Worker 
1135*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
1136*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1137*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
1138*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
1139*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
1140*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax()));
1141*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin()));
1142*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(float(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])),
1143*4bdc9457SAndroid Build Coastguard Worker                 output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 0.80f) <<
1144*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
1145*4bdc9457SAndroid Build Coastguard Worker             }
1146*4bdc9457SAndroid Build Coastguard Worker           }
1147*4bdc9457SAndroid Build Coastguard Worker         }
1148*4bdc9457SAndroid Build Coastguard Worker       }
1149*4bdc9457SAndroid Build Coastguard Worker 
1150*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
1151*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
1152*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
1153*4bdc9457SAndroid Build Coastguard Worker 
1154*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run.
1155*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
1156*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
1157*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
1158*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
1159*4bdc9457SAndroid Build Coastguard Worker               double acc = 0.0f;
1160*4bdc9457SAndroid Build Coastguard Worker               for (size_t py = 0; py < pooling_height(); py++) {
1161*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * stride_height() + py - padding_top();
1162*4bdc9457SAndroid Build Coastguard Worker                 for (size_t px = 0; px < pooling_width(); px++) {
1163*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * stride_width() + px - padding_left();
1164*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width() && iy < next_input_height()) {
1165*4bdc9457SAndroid Build Coastguard Worker                     acc += double(int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]) - int32_t(input_zero_point()));
1166*4bdc9457SAndroid Build Coastguard Worker                   }
1167*4bdc9457SAndroid Build Coastguard Worker                 }
1168*4bdc9457SAndroid Build Coastguard Worker               }
1169*4bdc9457SAndroid Build Coastguard Worker               next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = float(acc * scale + double(output_zero_point()));
1170*4bdc9457SAndroid Build Coastguard Worker               next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] =
1171*4bdc9457SAndroid Build Coastguard Worker                 std::min<float>(next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c], float(qmax()));
1172*4bdc9457SAndroid Build Coastguard Worker               next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] =
1173*4bdc9457SAndroid Build Coastguard Worker                 std::max<float>(next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c], float(qmin()));
1174*4bdc9457SAndroid Build Coastguard Worker             }
1175*4bdc9457SAndroid Build Coastguard Worker           }
1176*4bdc9457SAndroid Build Coastguard Worker         }
1177*4bdc9457SAndroid Build Coastguard Worker       }
1178*4bdc9457SAndroid Build Coastguard Worker 
1179*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Average Pooling operator the second time, and destroy the operator.
1180*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1181*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_average_pooling2d_nhwc_qu8(
1182*4bdc9457SAndroid Build Coastguard Worker           average_pooling_op,
1183*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
1184*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1185*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1186*4bdc9457SAndroid Build Coastguard Worker 
1187*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1188*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(average_pooling_op, nullptr /* thread pool */));
1189*4bdc9457SAndroid Build Coastguard Worker 
1190*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1191*4bdc9457SAndroid Build Coastguard Worker         xnn_delete_operator(average_pooling_op));
1192*4bdc9457SAndroid Build Coastguard Worker       average_pooling_op = nullptr;
1193*4bdc9457SAndroid Build Coastguard Worker 
1194*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
1195*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
1196*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
1197*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
1198*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < channels(); c++) {
1199*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax()));
1200*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin()));
1201*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(float(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c])),
1202*4bdc9457SAndroid Build Coastguard Worker                 next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 0.80f) <<
1203*4bdc9457SAndroid Build Coastguard Worker                 "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c;
1204*4bdc9457SAndroid Build Coastguard Worker             }
1205*4bdc9457SAndroid Build Coastguard Worker           }
1206*4bdc9457SAndroid Build Coastguard Worker         }
1207*4bdc9457SAndroid Build Coastguard Worker       }
1208*4bdc9457SAndroid Build Coastguard Worker     }
1209*4bdc9457SAndroid Build Coastguard Worker   }
1210*4bdc9457SAndroid Build Coastguard Worker 
1211*4bdc9457SAndroid Build Coastguard Worker  private:
1212*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_top_{0};
1213*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_right_{0};
1214*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_bottom_{0};
1215*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_left_{0};
1216*4bdc9457SAndroid Build Coastguard Worker   bool padding_tf_same_{false};
1217*4bdc9457SAndroid Build Coastguard Worker   size_t input_height_{1};
1218*4bdc9457SAndroid Build Coastguard Worker   size_t input_width_{1};
1219*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
1220*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_{1};
1221*4bdc9457SAndroid Build Coastguard Worker   size_t input_pixel_stride_{0};
1222*4bdc9457SAndroid Build Coastguard Worker   size_t output_pixel_stride_{0};
1223*4bdc9457SAndroid Build Coastguard Worker   uint32_t pooling_height_{1};
1224*4bdc9457SAndroid Build Coastguard Worker   uint32_t pooling_width_{1};
1225*4bdc9457SAndroid Build Coastguard Worker   uint32_t stride_height_{1};
1226*4bdc9457SAndroid Build Coastguard Worker   uint32_t stride_width_{1};
1227*4bdc9457SAndroid Build Coastguard Worker   size_t next_input_height_{0};
1228*4bdc9457SAndroid Build Coastguard Worker   size_t next_input_width_{0};
1229*4bdc9457SAndroid Build Coastguard Worker   size_t next_batch_size_{0};
1230*4bdc9457SAndroid Build Coastguard Worker   float input_scale_{1.0f};
1231*4bdc9457SAndroid Build Coastguard Worker   float output_scale_{1.0f};
1232*4bdc9457SAndroid Build Coastguard Worker   uint8_t input_zero_point_{121};
1233*4bdc9457SAndroid Build Coastguard Worker   uint8_t output_zero_point_{133};
1234*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};
1235*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};
1236*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{1};
1237*4bdc9457SAndroid Build Coastguard Worker };
1238