xref: /aosp_15_r20/external/XNNPACK/test/gavgpool-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
14*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
15*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
18*4bdc9457SAndroid Build Coastguard Worker #include <limits>
19*4bdc9457SAndroid Build Coastguard Worker #include <random>
20*4bdc9457SAndroid Build Coastguard Worker #include <vector>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
28*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h>
29*4bdc9457SAndroid Build Coastguard Worker 
30*4bdc9457SAndroid Build Coastguard Worker 
31*4bdc9457SAndroid Build Coastguard Worker class GAvgPoolMicrokernelTester {
32*4bdc9457SAndroid Build Coastguard Worker  public:
rows(size_t rows)33*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& rows(size_t rows) {
34*4bdc9457SAndroid Build Coastguard Worker     assert(rows != 0);
35*4bdc9457SAndroid Build Coastguard Worker     this->rows_ = rows;
36*4bdc9457SAndroid Build Coastguard Worker     return *this;
37*4bdc9457SAndroid Build Coastguard Worker   }
38*4bdc9457SAndroid Build Coastguard Worker 
rows()39*4bdc9457SAndroid Build Coastguard Worker   inline size_t rows() const {
40*4bdc9457SAndroid Build Coastguard Worker     return this->rows_;
41*4bdc9457SAndroid Build Coastguard Worker   }
42*4bdc9457SAndroid Build Coastguard Worker 
channels(size_t channels)43*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& channels(size_t channels) {
44*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
45*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
46*4bdc9457SAndroid Build Coastguard Worker     return *this;
47*4bdc9457SAndroid Build Coastguard Worker   }
48*4bdc9457SAndroid Build Coastguard Worker 
channels()49*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
50*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
channel_tile(size_t channel_tile)53*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& channel_tile(size_t channel_tile) {
54*4bdc9457SAndroid Build Coastguard Worker     assert(channel_tile != 0);
55*4bdc9457SAndroid Build Coastguard Worker     this->channel_tile_ = channel_tile;
56*4bdc9457SAndroid Build Coastguard Worker     return *this;
57*4bdc9457SAndroid Build Coastguard Worker   }
58*4bdc9457SAndroid Build Coastguard Worker 
channel_tile()59*4bdc9457SAndroid Build Coastguard Worker   inline size_t channel_tile() const {
60*4bdc9457SAndroid Build Coastguard Worker     return this->channel_tile_;
61*4bdc9457SAndroid Build Coastguard Worker   }
62*4bdc9457SAndroid Build Coastguard Worker 
input_stride(size_t input_stride)63*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& input_stride(size_t input_stride) {
64*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride != 0);
65*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
66*4bdc9457SAndroid Build Coastguard Worker     return *this;
67*4bdc9457SAndroid Build Coastguard Worker   }
68*4bdc9457SAndroid Build Coastguard Worker 
input_stride()69*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_stride() const {
70*4bdc9457SAndroid Build Coastguard Worker     if (this->input_stride_ == 0) {
71*4bdc9457SAndroid Build Coastguard Worker       return channels();
72*4bdc9457SAndroid Build Coastguard Worker     } else {
73*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_stride_ >= channels());
74*4bdc9457SAndroid Build Coastguard Worker       return this->input_stride_;
75*4bdc9457SAndroid Build Coastguard Worker     }
76*4bdc9457SAndroid Build Coastguard Worker   }
77*4bdc9457SAndroid Build Coastguard Worker 
input_scale(float input_scale)78*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& input_scale(float input_scale) {
79*4bdc9457SAndroid Build Coastguard Worker     assert(input_scale > 0.0f);
80*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(input_scale));
81*4bdc9457SAndroid Build Coastguard Worker     this->input_scale_ = input_scale;
82*4bdc9457SAndroid Build Coastguard Worker     return *this;
83*4bdc9457SAndroid Build Coastguard Worker   }
84*4bdc9457SAndroid Build Coastguard Worker 
input_scale()85*4bdc9457SAndroid Build Coastguard Worker   inline float input_scale() const {
86*4bdc9457SAndroid Build Coastguard Worker     return this->input_scale_;
87*4bdc9457SAndroid Build Coastguard Worker   }
88*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point(uint8_t input_zero_point)89*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& input_zero_point(uint8_t input_zero_point) {
90*4bdc9457SAndroid Build Coastguard Worker     this->input_zero_point_ = input_zero_point;
91*4bdc9457SAndroid Build Coastguard Worker     return *this;
92*4bdc9457SAndroid Build Coastguard Worker   }
93*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point()94*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t input_zero_point() const {
95*4bdc9457SAndroid Build Coastguard Worker     return this->input_zero_point_;
96*4bdc9457SAndroid Build Coastguard Worker   }
97*4bdc9457SAndroid Build Coastguard Worker 
output_scale(float output_scale)98*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& output_scale(float output_scale) {
99*4bdc9457SAndroid Build Coastguard Worker     assert(output_scale > 0.0f);
100*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(output_scale));
101*4bdc9457SAndroid Build Coastguard Worker     this->output_scale_ = output_scale;
102*4bdc9457SAndroid Build Coastguard Worker     return *this;
103*4bdc9457SAndroid Build Coastguard Worker   }
104*4bdc9457SAndroid Build Coastguard Worker 
output_scale()105*4bdc9457SAndroid Build Coastguard Worker   inline float output_scale() const {
106*4bdc9457SAndroid Build Coastguard Worker     return this->output_scale_;
107*4bdc9457SAndroid Build Coastguard Worker   }
108*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point(uint8_t output_zero_point)109*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& output_zero_point(uint8_t output_zero_point) {
110*4bdc9457SAndroid Build Coastguard Worker     this->output_zero_point_ = output_zero_point;
111*4bdc9457SAndroid Build Coastguard Worker     return *this;
112*4bdc9457SAndroid Build Coastguard Worker   }
113*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point()114*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t output_zero_point() const {
115*4bdc9457SAndroid Build Coastguard Worker     return this->output_zero_point_;
116*4bdc9457SAndroid Build Coastguard Worker   }
117*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)118*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& qmin(uint8_t qmin) {
119*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
120*4bdc9457SAndroid Build Coastguard Worker     return *this;
121*4bdc9457SAndroid Build Coastguard Worker   }
122*4bdc9457SAndroid Build Coastguard Worker 
qmin()123*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
124*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
125*4bdc9457SAndroid Build Coastguard Worker   }
126*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)127*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& qmax(uint8_t qmax) {
128*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
129*4bdc9457SAndroid Build Coastguard Worker     return *this;
130*4bdc9457SAndroid Build Coastguard Worker   }
131*4bdc9457SAndroid Build Coastguard Worker 
qmax()132*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
133*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
134*4bdc9457SAndroid Build Coastguard Worker   }
135*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)136*4bdc9457SAndroid Build Coastguard Worker   inline GAvgPoolMicrokernelTester& iterations(size_t iterations) {
137*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
138*4bdc9457SAndroid Build Coastguard Worker     return *this;
139*4bdc9457SAndroid Build Coastguard Worker   }
140*4bdc9457SAndroid Build Coastguard Worker 
iterations()141*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
142*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
143*4bdc9457SAndroid Build Coastguard Worker   }
144*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_qu8_avgpool_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize)145*4bdc9457SAndroid Build Coastguard Worker   void Test(
146*4bdc9457SAndroid Build Coastguard Worker       xnn_qu8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,
147*4bdc9457SAndroid Build Coastguard Worker       xnn_init_qu8_avgpool_minmax_params_fn init_params,
148*4bdc9457SAndroid Build Coastguard Worker       xnn_qu8_requantize_fn requantize) const
149*4bdc9457SAndroid Build Coastguard Worker   {
150*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
151*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
152*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
153*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
154*4bdc9457SAndroid Build Coastguard Worker 
155*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
156*4bdc9457SAndroid Build Coastguard Worker       (rows() - 1) * input_stride() + channels());
157*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
158*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(channels());
159*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(channels());
160*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_fp(channels());
161*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(channels());
162*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
163*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
164*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
165*4bdc9457SAndroid Build Coastguard Worker 
166*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
167*4bdc9457SAndroid Build Coastguard Worker       union xnn_qu8_avgpool_minmax_params params;
168*4bdc9457SAndroid Build Coastguard Worker       init_params(
169*4bdc9457SAndroid Build Coastguard Worker         &params,
170*4bdc9457SAndroid Build Coastguard Worker         -int32_t(input_zero_point()) * int32_t(rows()),
171*4bdc9457SAndroid Build Coastguard Worker         input_scale() / (output_scale() * float(rows())),
172*4bdc9457SAndroid Build Coastguard Worker         output_zero_point(), qmin(), qmax());
173*4bdc9457SAndroid Build Coastguard Worker 
174*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
175*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
176*4bdc9457SAndroid Build Coastguard Worker         int32_t acc = 0;
177*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
178*4bdc9457SAndroid Build Coastguard Worker           acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point());
179*4bdc9457SAndroid Build Coastguard Worker         }
180*4bdc9457SAndroid Build Coastguard Worker         accumulators[c] = acc;
181*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = requantize(
182*4bdc9457SAndroid Build Coastguard Worker           acc, input_scale() / (output_scale() * float(rows())), output_zero_point(), qmin(), qmax());
183*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point());
184*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::min<float>(output_fp[c], float(qmax()));
185*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::max<float>(output_fp[c], float(qmin()));
186*4bdc9457SAndroid Build Coastguard Worker       }
187*4bdc9457SAndroid Build Coastguard Worker 
188*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
189*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
190*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(uint8_t),
191*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
192*4bdc9457SAndroid Build Coastguard Worker         output.data(),
193*4bdc9457SAndroid Build Coastguard Worker         &params);
194*4bdc9457SAndroid Build Coastguard Worker 
195*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
196*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
197*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(uint32_t(output[c]), uint32_t(qmax()))
198*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
199*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(uint32_t(output[c]), uint32_t(qmin()))
200*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
201*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f)
202*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels()
203*4bdc9457SAndroid Build Coastguard Worker           << ", acc = " << accumulators[c];
204*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(uint32_t(output_ref[c]), uint32_t(output[c]))
205*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels()
206*4bdc9457SAndroid Build Coastguard Worker           << ", acc = " << accumulators[c];
207*4bdc9457SAndroid Build Coastguard Worker       }
208*4bdc9457SAndroid Build Coastguard Worker     }
209*4bdc9457SAndroid Build Coastguard Worker   }
210*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qu8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_qu8_avgpool_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize)211*4bdc9457SAndroid Build Coastguard Worker   void Test(
212*4bdc9457SAndroid Build Coastguard Worker       xnn_qu8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,
213*4bdc9457SAndroid Build Coastguard Worker       xnn_init_qu8_avgpool_minmax_params_fn init_params,
214*4bdc9457SAndroid Build Coastguard Worker       xnn_qu8_requantize_fn requantize) const
215*4bdc9457SAndroid Build Coastguard Worker   {
216*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
217*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
218*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
219*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
220*4bdc9457SAndroid Build Coastguard Worker 
221*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
222*4bdc9457SAndroid Build Coastguard Worker       (rows() - 1) * input_stride() + channels());
223*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t, AlignedAllocator<int32_t, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
224*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
225*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(channels());
226*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(channels());
227*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_fp(channels());
228*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(channels());
229*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
230*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
231*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
232*4bdc9457SAndroid Build Coastguard Worker 
233*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
234*4bdc9457SAndroid Build Coastguard Worker       union xnn_qu8_avgpool_minmax_params params;
235*4bdc9457SAndroid Build Coastguard Worker       init_params(
236*4bdc9457SAndroid Build Coastguard Worker         &params,
237*4bdc9457SAndroid Build Coastguard Worker         -int32_t(input_zero_point()) * int32_t(rows()),
238*4bdc9457SAndroid Build Coastguard Worker         input_scale() / (output_scale() * float(rows())),
239*4bdc9457SAndroid Build Coastguard Worker         output_zero_point(), qmin(), qmax());
240*4bdc9457SAndroid Build Coastguard Worker 
241*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
242*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
243*4bdc9457SAndroid Build Coastguard Worker         int32_t acc = 0;
244*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
245*4bdc9457SAndroid Build Coastguard Worker           acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point());
246*4bdc9457SAndroid Build Coastguard Worker         }
247*4bdc9457SAndroid Build Coastguard Worker 
248*4bdc9457SAndroid Build Coastguard Worker         accumulators[c] = acc;
249*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = requantize(
250*4bdc9457SAndroid Build Coastguard Worker           acc, input_scale() / (output_scale() * float(rows())), output_zero_point(), qmin(), qmax());
251*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point());
252*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::min<float>(output_fp[c], float(qmax()));
253*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::max<float>(output_fp[c], float(qmin()));
254*4bdc9457SAndroid Build Coastguard Worker       }
255*4bdc9457SAndroid Build Coastguard Worker 
256*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
257*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
258*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(uint8_t),
259*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
260*4bdc9457SAndroid Build Coastguard Worker         buffer.data(),
261*4bdc9457SAndroid Build Coastguard Worker         output.data(),
262*4bdc9457SAndroid Build Coastguard Worker         &params);
263*4bdc9457SAndroid Build Coastguard Worker 
264*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
265*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
266*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(uint32_t(output[c]), uint32_t(qmax()))
267*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
268*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(uint32_t(output[c]), uint32_t(qmin()))
269*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
270*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f)
271*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels()
272*4bdc9457SAndroid Build Coastguard Worker           << ", acc = " << accumulators[c];
273*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(uint32_t(output_ref[c]), uint32_t(output[c]))
274*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels()
275*4bdc9457SAndroid Build Coastguard Worker           << ", acc = " << accumulators[c];
276*4bdc9457SAndroid Build Coastguard Worker       }
277*4bdc9457SAndroid Build Coastguard Worker     }
278*4bdc9457SAndroid Build Coastguard Worker   }
279*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qs8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_qs8_avgpool_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize)280*4bdc9457SAndroid Build Coastguard Worker   void Test(
281*4bdc9457SAndroid Build Coastguard Worker       xnn_qs8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,
282*4bdc9457SAndroid Build Coastguard Worker       xnn_init_qs8_avgpool_minmax_params_fn init_params,
283*4bdc9457SAndroid Build Coastguard Worker       xnn_qs8_requantize_fn requantize) const
284*4bdc9457SAndroid Build Coastguard Worker   {
285*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
286*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
287*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
288*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
289*4bdc9457SAndroid Build Coastguard Worker 
290*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
291*4bdc9457SAndroid Build Coastguard Worker       (rows() - 1) * input_stride() + channels());
292*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
293*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(channels());
294*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(channels());
295*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_fp(channels());
296*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(channels());
297*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
298*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
299*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
300*4bdc9457SAndroid Build Coastguard Worker 
301*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
302*4bdc9457SAndroid Build Coastguard Worker       union xnn_qs8_avgpool_minmax_params params;
303*4bdc9457SAndroid Build Coastguard Worker       init_params(
304*4bdc9457SAndroid Build Coastguard Worker         &params,
305*4bdc9457SAndroid Build Coastguard Worker         -int32_t(input_zero_point() - 0x80) * int32_t(rows()),
306*4bdc9457SAndroid Build Coastguard Worker         input_scale() / (output_scale() * float(rows())),
307*4bdc9457SAndroid Build Coastguard Worker         int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
308*4bdc9457SAndroid Build Coastguard Worker 
309*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
310*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
311*4bdc9457SAndroid Build Coastguard Worker         int32_t acc = 0;
312*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
313*4bdc9457SAndroid Build Coastguard Worker           acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point() - 0x80);
314*4bdc9457SAndroid Build Coastguard Worker         }
315*4bdc9457SAndroid Build Coastguard Worker         accumulators[c] = acc;
316*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = requantize(
317*4bdc9457SAndroid Build Coastguard Worker           acc, input_scale() / (output_scale() * float(rows())), int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
318*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point() - 0x80);
319*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::min<float>(output_fp[c], float(qmax() - 0x80));
320*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::max<float>(output_fp[c], float(qmin() - 0x80));
321*4bdc9457SAndroid Build Coastguard Worker       }
322*4bdc9457SAndroid Build Coastguard Worker 
323*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
324*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
325*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(int8_t),
326*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
327*4bdc9457SAndroid Build Coastguard Worker         output.data(),
328*4bdc9457SAndroid Build Coastguard Worker         &params);
329*4bdc9457SAndroid Build Coastguard Worker 
330*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
331*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
332*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(int32_t(output[c]), int32_t(qmax() - 0x80))
333*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows();
334*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(int32_t(output[c]), int32_t(qmin() - 0x80))
335*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows();
336*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f)
337*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows()
338*4bdc9457SAndroid Build Coastguard Worker           << ", accumulator = " << accumulators[c];
339*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(output_ref[c]), int32_t(output[c]))
340*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows()
341*4bdc9457SAndroid Build Coastguard Worker           << ", accumulator = " << accumulators[c];
342*4bdc9457SAndroid Build Coastguard Worker       }
343*4bdc9457SAndroid Build Coastguard Worker     }
344*4bdc9457SAndroid Build Coastguard Worker   }
345*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_qs8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_qs8_avgpool_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize)346*4bdc9457SAndroid Build Coastguard Worker   void Test(
347*4bdc9457SAndroid Build Coastguard Worker       xnn_qs8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,
348*4bdc9457SAndroid Build Coastguard Worker       xnn_init_qs8_avgpool_minmax_params_fn init_params,
349*4bdc9457SAndroid Build Coastguard Worker       xnn_qs8_requantize_fn requantize) const
350*4bdc9457SAndroid Build Coastguard Worker   {
351*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
352*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
353*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
354*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
355*4bdc9457SAndroid Build Coastguard Worker 
356*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
357*4bdc9457SAndroid Build Coastguard Worker       (rows() - 1) * input_stride() + channels());
358*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t, AlignedAllocator<int32_t, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
359*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
360*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(channels());
361*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(channels());
362*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_fp(channels());
363*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(channels());
364*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
365*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
366*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
367*4bdc9457SAndroid Build Coastguard Worker 
368*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
369*4bdc9457SAndroid Build Coastguard Worker       union xnn_qs8_avgpool_minmax_params params;
370*4bdc9457SAndroid Build Coastguard Worker       init_params(
371*4bdc9457SAndroid Build Coastguard Worker         &params,
372*4bdc9457SAndroid Build Coastguard Worker         -int32_t(input_zero_point() - 0x80) * int32_t(rows()),
373*4bdc9457SAndroid Build Coastguard Worker         input_scale() / (output_scale() * float(rows())),
374*4bdc9457SAndroid Build Coastguard Worker         int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
375*4bdc9457SAndroid Build Coastguard Worker 
376*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
377*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
378*4bdc9457SAndroid Build Coastguard Worker         int32_t acc = 0;
379*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
380*4bdc9457SAndroid Build Coastguard Worker           acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point() - 0x80);
381*4bdc9457SAndroid Build Coastguard Worker         }
382*4bdc9457SAndroid Build Coastguard Worker         accumulators[c] = acc;
383*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = requantize(
384*4bdc9457SAndroid Build Coastguard Worker           acc, input_scale() / (output_scale() * float(rows())), int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
385*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point() - 0x80);
386*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::min<float>(output_fp[c], float(qmax() - 0x80));
387*4bdc9457SAndroid Build Coastguard Worker         output_fp[c] = std::max<float>(output_fp[c], float(qmin() - 0x80));
388*4bdc9457SAndroid Build Coastguard Worker       }
389*4bdc9457SAndroid Build Coastguard Worker 
390*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
391*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
392*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(int8_t),
393*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
394*4bdc9457SAndroid Build Coastguard Worker         buffer.data(),
395*4bdc9457SAndroid Build Coastguard Worker         output.data(),
396*4bdc9457SAndroid Build Coastguard Worker         &params);
397*4bdc9457SAndroid Build Coastguard Worker 
398*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
399*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
400*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(int32_t(output[c]), int32_t(qmax() - 0x80))
401*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows();
402*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(int32_t(output[c]), int32_t(qmin() - 0x80))
403*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows();
404*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f)
405*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows()
406*4bdc9457SAndroid Build Coastguard Worker           << ", accumulator = " << accumulators[c];
407*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(output_ref[c]), int32_t(output[c]))
408*4bdc9457SAndroid Build Coastguard Worker           << "at channel " << c << " / " << channels() << ", rows = " << rows()
409*4bdc9457SAndroid Build Coastguard Worker           << ", accumulator = " << accumulators[c];
410*4bdc9457SAndroid Build Coastguard Worker       }
411*4bdc9457SAndroid Build Coastguard Worker     }
412*4bdc9457SAndroid Build Coastguard Worker   }
413*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_f16_scaleminmax_params_fn init_params)414*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const {
415*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
416*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
417*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
418*4bdc9457SAndroid Build Coastguard Worker 
419*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
420*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
421*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(channels());
422*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(channels());
423*4bdc9457SAndroid Build Coastguard Worker 
424*4bdc9457SAndroid Build Coastguard Worker     std::fill(zero.begin(), zero.end(), 0);
425*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
426*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
427*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
428*4bdc9457SAndroid Build Coastguard Worker 
429*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
430*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
431*4bdc9457SAndroid Build Coastguard Worker         float acc = 0.0f;
432*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
433*4bdc9457SAndroid Build Coastguard Worker           acc += fp16_ieee_to_fp32_value(input[n * input_stride() + c]);
434*4bdc9457SAndroid Build Coastguard Worker         }
435*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = acc / float(rows());
436*4bdc9457SAndroid Build Coastguard Worker       }
437*4bdc9457SAndroid Build Coastguard Worker 
438*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
439*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
440*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
441*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
442*4bdc9457SAndroid Build Coastguard Worker       const float output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + float(qmin()) / 255.0f * accumulated_range));
443*4bdc9457SAndroid Build Coastguard Worker       const float output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range));
444*4bdc9457SAndroid Build Coastguard Worker 
445*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
446*4bdc9457SAndroid Build Coastguard Worker       for (float& output_values : output_ref) {
447*4bdc9457SAndroid Build Coastguard Worker         output_values = std::max(std::min(output_values, output_max), output_min);
448*4bdc9457SAndroid Build Coastguard Worker       }
449*4bdc9457SAndroid Build Coastguard Worker 
450*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
451*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_scaleminmax_params params;
452*4bdc9457SAndroid Build Coastguard Worker       init_params(&params,
453*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(1.0f / float(rows())),
454*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(output_min),
455*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(output_max));
456*4bdc9457SAndroid Build Coastguard Worker 
457*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
458*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
459*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(uint16_t),
460*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
461*4bdc9457SAndroid Build Coastguard Worker         output.data(),
462*4bdc9457SAndroid Build Coastguard Worker         &params);
463*4bdc9457SAndroid Build Coastguard Worker 
464*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
465*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
466*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(fp16_ieee_to_fp32_value(output[c]), output_max)
467*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
468*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(fp16_ieee_to_fp32_value(output[c]), output_min)
469*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
470*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(fp16_ieee_to_fp32_value(output[c]), output_ref[c], std::max(1.0e-4f, std::abs(output_ref[c]) * 1.0e-2f))
471*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
472*4bdc9457SAndroid Build Coastguard Worker       }
473*4bdc9457SAndroid Build Coastguard Worker     }
474*4bdc9457SAndroid Build Coastguard Worker   }
475*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_f16_scaleminmax_params_fn init_params)476*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const {
477*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
478*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
479*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
480*4bdc9457SAndroid Build Coastguard Worker 
481*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
482*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
483*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
484*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(channels());
485*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(channels());
486*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
487*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
488*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
489*4bdc9457SAndroid Build Coastguard Worker 
490*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
491*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
492*4bdc9457SAndroid Build Coastguard Worker         float acc = 0.0f;
493*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
494*4bdc9457SAndroid Build Coastguard Worker           acc += fp16_ieee_to_fp32_value(input[n * input_stride() + c]);
495*4bdc9457SAndroid Build Coastguard Worker         }
496*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = acc / float(rows());
497*4bdc9457SAndroid Build Coastguard Worker       }
498*4bdc9457SAndroid Build Coastguard Worker 
499*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
500*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
501*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
502*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
503*4bdc9457SAndroid Build Coastguard Worker       const float output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + float(qmin()) / 255.0f * accumulated_range));
504*4bdc9457SAndroid Build Coastguard Worker       const float output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range));
505*4bdc9457SAndroid Build Coastguard Worker 
506*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
507*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_scaleminmax_params params;
508*4bdc9457SAndroid Build Coastguard Worker       init_params(&params,
509*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(1.0f / float(rows())),
510*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(output_min),
511*4bdc9457SAndroid Build Coastguard Worker         fp16_ieee_from_fp32_value(output_max));
512*4bdc9457SAndroid Build Coastguard Worker 
513*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
514*4bdc9457SAndroid Build Coastguard Worker       for (float& output_values : output_ref) {
515*4bdc9457SAndroid Build Coastguard Worker         output_values = std::max(std::min(output_values, output_max), output_min);
516*4bdc9457SAndroid Build Coastguard Worker       }
517*4bdc9457SAndroid Build Coastguard Worker 
518*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
519*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
520*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(uint16_t),
521*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
522*4bdc9457SAndroid Build Coastguard Worker         buffer.data(),
523*4bdc9457SAndroid Build Coastguard Worker         output.data(),
524*4bdc9457SAndroid Build Coastguard Worker         &params);
525*4bdc9457SAndroid Build Coastguard Worker 
526*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
527*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
528*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(fp16_ieee_to_fp32_value(output[c]), output_max)
529*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
530*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(fp16_ieee_to_fp32_value(output[c]), output_min)
531*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
532*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(fp16_ieee_to_fp32_value(output[c]), output_ref[c], std::abs(output_ref[c]) * 1.0e-0f)
533*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
534*4bdc9457SAndroid Build Coastguard Worker       }
535*4bdc9457SAndroid Build Coastguard Worker     }
536*4bdc9457SAndroid Build Coastguard Worker   }
537*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_f32_scaleminmax_params_fn init_params)538*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax, xnn_init_f32_scaleminmax_params_fn init_params) const {
539*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
540*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
541*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
542*4bdc9457SAndroid Build Coastguard Worker 
543*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
544*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
545*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(channels());
546*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(channels());
547*4bdc9457SAndroid Build Coastguard Worker 
548*4bdc9457SAndroid Build Coastguard Worker     std::fill(zero.begin(), zero.end(), 0.0f);
549*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
550*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
551*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
552*4bdc9457SAndroid Build Coastguard Worker 
553*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
554*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
555*4bdc9457SAndroid Build Coastguard Worker         float acc = 0.0f;
556*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
557*4bdc9457SAndroid Build Coastguard Worker           acc += input[n * input_stride() + c];
558*4bdc9457SAndroid Build Coastguard Worker         }
559*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = acc / float(rows());
560*4bdc9457SAndroid Build Coastguard Worker       }
561*4bdc9457SAndroid Build Coastguard Worker 
562*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
563*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
564*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
565*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
566*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
567*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
568*4bdc9457SAndroid Build Coastguard Worker 
569*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
570*4bdc9457SAndroid Build Coastguard Worker       for (float& output_values : output_ref) {
571*4bdc9457SAndroid Build Coastguard Worker         output_values = std::max(std::min(output_values, output_max), output_min);
572*4bdc9457SAndroid Build Coastguard Worker       }
573*4bdc9457SAndroid Build Coastguard Worker 
574*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
575*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_scaleminmax_params params;
576*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, 1.0f / float(rows()), output_min, output_max);
577*4bdc9457SAndroid Build Coastguard Worker 
578*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
579*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
580*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(float),
581*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
582*4bdc9457SAndroid Build Coastguard Worker         output.data(),
583*4bdc9457SAndroid Build Coastguard Worker         &params);
584*4bdc9457SAndroid Build Coastguard Worker 
585*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
586*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
587*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(output[c], output_max)
588*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
589*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(output[c], output_min)
590*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
591*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(output[c], output_ref[c], std::abs(output_ref[c]) * 1.0e-6f)
592*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
593*4bdc9457SAndroid Build Coastguard Worker       }
594*4bdc9457SAndroid Build Coastguard Worker     }
595*4bdc9457SAndroid Build Coastguard Worker   }
596*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_f32_scaleminmax_params_fn init_params)597*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax, xnn_init_f32_scaleminmax_params_fn init_params) const {
598*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
599*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
600*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
601*4bdc9457SAndroid Build Coastguard Worker 
602*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
603*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(float));
604*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float));
605*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(channels());
606*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(channels());
607*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
608*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
609*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
610*4bdc9457SAndroid Build Coastguard Worker 
611*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
612*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
613*4bdc9457SAndroid Build Coastguard Worker         float acc = 0.0f;
614*4bdc9457SAndroid Build Coastguard Worker         for (size_t n = 0; n < rows(); n++) {
615*4bdc9457SAndroid Build Coastguard Worker           acc += input[n * input_stride() + c];
616*4bdc9457SAndroid Build Coastguard Worker         }
617*4bdc9457SAndroid Build Coastguard Worker         output_ref[c] = acc / float(rows());
618*4bdc9457SAndroid Build Coastguard Worker       }
619*4bdc9457SAndroid Build Coastguard Worker 
620*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
621*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
622*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
623*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
624*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range;
625*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range;
626*4bdc9457SAndroid Build Coastguard Worker 
627*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
628*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_scaleminmax_params params;
629*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, 1.0f / float(rows()), output_min, output_max);
630*4bdc9457SAndroid Build Coastguard Worker 
631*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
632*4bdc9457SAndroid Build Coastguard Worker       for (float& output_values : output_ref) {
633*4bdc9457SAndroid Build Coastguard Worker         output_values = std::max(std::min(output_values, output_max), output_min);
634*4bdc9457SAndroid Build Coastguard Worker       }
635*4bdc9457SAndroid Build Coastguard Worker 
636*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
637*4bdc9457SAndroid Build Coastguard Worker       gavgpool_minmax(rows(), channels(),
638*4bdc9457SAndroid Build Coastguard Worker         input.data(), input_stride() * sizeof(float),
639*4bdc9457SAndroid Build Coastguard Worker         zero.data(),
640*4bdc9457SAndroid Build Coastguard Worker         buffer.data(),
641*4bdc9457SAndroid Build Coastguard Worker         output.data(),
642*4bdc9457SAndroid Build Coastguard Worker         &params);
643*4bdc9457SAndroid Build Coastguard Worker 
644*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
645*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
646*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(output[c], output_max)
647*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
648*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(output[c], output_min)
649*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
650*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(output[c], output_ref[c], std::abs(output_ref[c]) * 1.0e-6f)
651*4bdc9457SAndroid Build Coastguard Worker           << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
652*4bdc9457SAndroid Build Coastguard Worker       }
653*4bdc9457SAndroid Build Coastguard Worker     }
654*4bdc9457SAndroid Build Coastguard Worker   }
655*4bdc9457SAndroid Build Coastguard Worker 
656*4bdc9457SAndroid Build Coastguard Worker  private:
657*4bdc9457SAndroid Build Coastguard Worker   size_t rows_{1};
658*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
659*4bdc9457SAndroid Build Coastguard Worker   size_t channel_tile_{1};
660*4bdc9457SAndroid Build Coastguard Worker   size_t input_stride_{0};
661*4bdc9457SAndroid Build Coastguard Worker   float input_scale_{1.25f};
662*4bdc9457SAndroid Build Coastguard Worker   float output_scale_{0.75f};
663*4bdc9457SAndroid Build Coastguard Worker   uint8_t input_zero_point_{121};
664*4bdc9457SAndroid Build Coastguard Worker   uint8_t output_zero_point_{133};
665*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};
666*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};
667*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{15};
668*4bdc9457SAndroid Build Coastguard Worker };
669