xref: /aosp_15_r20/external/XNNPACK/test/global-average-pooling-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
#include <gtest/gtest.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
24*4bdc9457SAndroid Build Coastguard Worker 
25*4bdc9457SAndroid Build Coastguard Worker 
26*4bdc9457SAndroid Build Coastguard Worker class GlobalAveragePoolingOperatorTester {
27*4bdc9457SAndroid Build Coastguard Worker  public:
channels(size_t channels)28*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& channels(size_t channels) {
29*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
30*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
31*4bdc9457SAndroid Build Coastguard Worker     return *this;
32*4bdc9457SAndroid Build Coastguard Worker   }
33*4bdc9457SAndroid Build Coastguard Worker 
channels()34*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
35*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
36*4bdc9457SAndroid Build Coastguard Worker   }
37*4bdc9457SAndroid Build Coastguard Worker 
width(size_t width)38*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& width(size_t width) {
39*4bdc9457SAndroid Build Coastguard Worker     assert(width != 0);
40*4bdc9457SAndroid Build Coastguard Worker     this->width_ = width;
41*4bdc9457SAndroid Build Coastguard Worker     return *this;
42*4bdc9457SAndroid Build Coastguard Worker   }
43*4bdc9457SAndroid Build Coastguard Worker 
width()44*4bdc9457SAndroid Build Coastguard Worker   inline size_t width() const {
45*4bdc9457SAndroid Build Coastguard Worker     return this->width_;
46*4bdc9457SAndroid Build Coastguard Worker   }
47*4bdc9457SAndroid Build Coastguard Worker 
input_stride(size_t input_stride)48*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& input_stride(size_t input_stride) {
49*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride != 0);
50*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
51*4bdc9457SAndroid Build Coastguard Worker     return *this;
52*4bdc9457SAndroid Build Coastguard Worker   }
53*4bdc9457SAndroid Build Coastguard Worker 
input_stride()54*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_stride() const {
55*4bdc9457SAndroid Build Coastguard Worker     if (this->input_stride_ == 0) {
56*4bdc9457SAndroid Build Coastguard Worker       return channels();
57*4bdc9457SAndroid Build Coastguard Worker     } else {
58*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_stride_ >= channels());
59*4bdc9457SAndroid Build Coastguard Worker       return this->input_stride_;
60*4bdc9457SAndroid Build Coastguard Worker     }
61*4bdc9457SAndroid Build Coastguard Worker   }
62*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)63*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& output_stride(size_t output_stride) {
64*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
65*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
66*4bdc9457SAndroid Build Coastguard Worker     return *this;
67*4bdc9457SAndroid Build Coastguard Worker   }
68*4bdc9457SAndroid Build Coastguard Worker 
output_stride()69*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const {
70*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
71*4bdc9457SAndroid Build Coastguard Worker       return channels();
72*4bdc9457SAndroid Build Coastguard Worker     } else {
73*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
74*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
75*4bdc9457SAndroid Build Coastguard Worker     }
76*4bdc9457SAndroid Build Coastguard Worker   }
77*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)78*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& batch_size(size_t batch_size) {
79*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
80*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
81*4bdc9457SAndroid Build Coastguard Worker     return *this;
82*4bdc9457SAndroid Build Coastguard Worker   }
83*4bdc9457SAndroid Build Coastguard Worker 
batch_size()84*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
85*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
86*4bdc9457SAndroid Build Coastguard Worker   }
87*4bdc9457SAndroid Build Coastguard Worker 
input_scale(float input_scale)88*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& input_scale(float input_scale) {
89*4bdc9457SAndroid Build Coastguard Worker     assert(input_scale > 0.0f);
90*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(input_scale));
91*4bdc9457SAndroid Build Coastguard Worker     this->input_scale_ = input_scale;
92*4bdc9457SAndroid Build Coastguard Worker     return *this;
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
input_scale()95*4bdc9457SAndroid Build Coastguard Worker   inline float input_scale() const {
96*4bdc9457SAndroid Build Coastguard Worker     return this->input_scale_;
97*4bdc9457SAndroid Build Coastguard Worker   }
98*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point(uint8_t input_zero_point)99*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& input_zero_point(uint8_t input_zero_point) {
100*4bdc9457SAndroid Build Coastguard Worker     this->input_zero_point_ = input_zero_point;
101*4bdc9457SAndroid Build Coastguard Worker     return *this;
102*4bdc9457SAndroid Build Coastguard Worker   }
103*4bdc9457SAndroid Build Coastguard Worker 
input_zero_point()104*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t input_zero_point() const {
105*4bdc9457SAndroid Build Coastguard Worker     return this->input_zero_point_;
106*4bdc9457SAndroid Build Coastguard Worker   }
107*4bdc9457SAndroid Build Coastguard Worker 
output_scale(float output_scale)108*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& output_scale(float output_scale) {
109*4bdc9457SAndroid Build Coastguard Worker     assert(output_scale > 0.0f);
110*4bdc9457SAndroid Build Coastguard Worker     assert(std::isnormal(output_scale));
111*4bdc9457SAndroid Build Coastguard Worker     this->output_scale_ = output_scale;
112*4bdc9457SAndroid Build Coastguard Worker     return *this;
113*4bdc9457SAndroid Build Coastguard Worker   }
114*4bdc9457SAndroid Build Coastguard Worker 
output_scale()115*4bdc9457SAndroid Build Coastguard Worker   inline float output_scale() const {
116*4bdc9457SAndroid Build Coastguard Worker     return this->output_scale_;
117*4bdc9457SAndroid Build Coastguard Worker   }
118*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point(uint8_t output_zero_point)119*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& output_zero_point(uint8_t output_zero_point) {
120*4bdc9457SAndroid Build Coastguard Worker     this->output_zero_point_ = output_zero_point;
121*4bdc9457SAndroid Build Coastguard Worker     return *this;
122*4bdc9457SAndroid Build Coastguard Worker   }
123*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point()124*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t output_zero_point() const {
125*4bdc9457SAndroid Build Coastguard Worker     return this->output_zero_point_;
126*4bdc9457SAndroid Build Coastguard Worker   }
127*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)128*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& qmin(uint8_t qmin) {
129*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
130*4bdc9457SAndroid Build Coastguard Worker     return *this;
131*4bdc9457SAndroid Build Coastguard Worker   }
132*4bdc9457SAndroid Build Coastguard Worker 
qmin()133*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
134*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
135*4bdc9457SAndroid Build Coastguard Worker   }
136*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)137*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& qmax(uint8_t qmax) {
138*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
139*4bdc9457SAndroid Build Coastguard Worker     return *this;
140*4bdc9457SAndroid Build Coastguard Worker   }
141*4bdc9457SAndroid Build Coastguard Worker 
qmax()142*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
143*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
144*4bdc9457SAndroid Build Coastguard Worker   }
145*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)146*4bdc9457SAndroid Build Coastguard Worker   inline GlobalAveragePoolingOperatorTester& iterations(size_t iterations) {
147*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
148*4bdc9457SAndroid Build Coastguard Worker     return *this;
149*4bdc9457SAndroid Build Coastguard Worker   }
150*4bdc9457SAndroid Build Coastguard Worker 
iterations()151*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
152*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
153*4bdc9457SAndroid Build Coastguard Worker   }
154*4bdc9457SAndroid Build Coastguard Worker 
TestNWCxQU8()155*4bdc9457SAndroid Build Coastguard Worker   void TestNWCxQU8() const {
156*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
157*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
158*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
159*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
160*4bdc9457SAndroid Build Coastguard Worker 
161*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t));
162*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(batch_size() * output_stride());
163*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * channels());
164*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
165*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
166*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
167*4bdc9457SAndroid Build Coastguard Worker 
168*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
169*4bdc9457SAndroid Build Coastguard Worker       const double scale = double(input_scale()) / (double(width()) * double(output_scale()));
170*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
171*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < channels(); j++) {
172*4bdc9457SAndroid Build Coastguard Worker           double acc = 0.0f;
173*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < width(); k++) {
174*4bdc9457SAndroid Build Coastguard Worker             acc += double(int32_t(input[(i * width() + k) * input_stride() + j]) - int32_t(input_zero_point()));
175*4bdc9457SAndroid Build Coastguard Worker           }
176*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = float(acc * scale + double(output_zero_point()));
177*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = std::min<float>(output_ref[i * channels() + j], float(qmax()));
178*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = std::max<float>(output_ref[i * channels() + j], float(qmin()));
179*4bdc9457SAndroid Build Coastguard Worker         }
180*4bdc9457SAndroid Build Coastguard Worker       }
181*4bdc9457SAndroid Build Coastguard Worker 
182*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Global Average Pooling operator.
183*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
184*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t global_average_pooling_op = nullptr;
185*4bdc9457SAndroid Build Coastguard Worker 
186*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_global_average_pooling_nwc_qu8(
187*4bdc9457SAndroid Build Coastguard Worker           channels(), input_stride(), output_stride(),
188*4bdc9457SAndroid Build Coastguard Worker           input_zero_point(), input_scale(),
189*4bdc9457SAndroid Build Coastguard Worker           output_zero_point(), output_scale(),
190*4bdc9457SAndroid Build Coastguard Worker           qmin(), qmax(),
191*4bdc9457SAndroid Build Coastguard Worker           0, &global_average_pooling_op);
192*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
193*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
194*4bdc9457SAndroid Build Coastguard Worker       }
195*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
196*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, global_average_pooling_op);
197*4bdc9457SAndroid Build Coastguard Worker 
198*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete global_average_pooling_op.
199*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);
200*4bdc9457SAndroid Build Coastguard Worker 
201*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
202*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_global_average_pooling_nwc_qu8(
203*4bdc9457SAndroid Build Coastguard Worker           global_average_pooling_op,
204*4bdc9457SAndroid Build Coastguard Worker           batch_size(), width(),
205*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
206*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
207*4bdc9457SAndroid Build Coastguard Worker 
208*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
209*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));
210*4bdc9457SAndroid Build Coastguard Worker 
211*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
212*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
213*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
214*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(uint32_t(output[i * output_stride() + c]), uint32_t(qmax()));
215*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(uint32_t(output[i * output_stride() + c]), uint32_t(qmin()));
216*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.80f)
217*4bdc9457SAndroid Build Coastguard Worker             << "at batch index " << i << " / " << batch_size()
218*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
219*4bdc9457SAndroid Build Coastguard Worker         }
220*4bdc9457SAndroid Build Coastguard Worker       }
221*4bdc9457SAndroid Build Coastguard Worker     }
222*4bdc9457SAndroid Build Coastguard Worker   }
223*4bdc9457SAndroid Build Coastguard Worker 
TestNWCxQS8()224*4bdc9457SAndroid Build Coastguard Worker   void TestNWCxQS8() const {
225*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
226*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
227*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
228*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
229*4bdc9457SAndroid Build Coastguard Worker 
230*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t));
231*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(batch_size() * output_stride());
232*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * channels());
233*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
234*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
235*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
236*4bdc9457SAndroid Build Coastguard Worker 
237*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
238*4bdc9457SAndroid Build Coastguard Worker       const double scale = double(input_scale()) / (double(width()) * double(output_scale()));
239*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
240*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < channels(); j++) {
241*4bdc9457SAndroid Build Coastguard Worker           double acc = 0.0f;
242*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < width(); k++) {
243*4bdc9457SAndroid Build Coastguard Worker             acc += double(int32_t(input[(i * width() + k) * input_stride() + j]) - int32_t(input_zero_point() - 0x80));
244*4bdc9457SAndroid Build Coastguard Worker           }
245*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = float(acc * scale + double(output_zero_point() - 0x80));
246*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = std::min<float>(output_ref[i * channels() + j], float(qmax() - 0x80));
247*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = std::max<float>(output_ref[i * channels() + j], float(qmin() - 0x80));
248*4bdc9457SAndroid Build Coastguard Worker         }
249*4bdc9457SAndroid Build Coastguard Worker       }
250*4bdc9457SAndroid Build Coastguard Worker 
251*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Global Average Pooling operator.
252*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
253*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t global_average_pooling_op = nullptr;
254*4bdc9457SAndroid Build Coastguard Worker 
255*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_global_average_pooling_nwc_qs8(
256*4bdc9457SAndroid Build Coastguard Worker           channels(), input_stride(), output_stride(),
257*4bdc9457SAndroid Build Coastguard Worker           int8_t(input_zero_point() - 0x80), input_scale(),
258*4bdc9457SAndroid Build Coastguard Worker           int8_t(output_zero_point() - 0x80), output_scale(),
259*4bdc9457SAndroid Build Coastguard Worker           int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
260*4bdc9457SAndroid Build Coastguard Worker           0, &global_average_pooling_op);
261*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
262*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
263*4bdc9457SAndroid Build Coastguard Worker       }
264*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
265*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, global_average_pooling_op);
266*4bdc9457SAndroid Build Coastguard Worker 
267*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete global_average_pooling_op.
268*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);
269*4bdc9457SAndroid Build Coastguard Worker 
270*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
271*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_global_average_pooling_nwc_qs8(
272*4bdc9457SAndroid Build Coastguard Worker           global_average_pooling_op,
273*4bdc9457SAndroid Build Coastguard Worker           batch_size(), width(),
274*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
275*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
276*4bdc9457SAndroid Build Coastguard Worker 
277*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
278*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));
279*4bdc9457SAndroid Build Coastguard Worker 
280*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
281*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
282*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
283*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80));
284*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80));
285*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.80f)
286*4bdc9457SAndroid Build Coastguard Worker             << "at batch index " << i << " / " << batch_size()
287*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
288*4bdc9457SAndroid Build Coastguard Worker         }
289*4bdc9457SAndroid Build Coastguard Worker       }
290*4bdc9457SAndroid Build Coastguard Worker     }
291*4bdc9457SAndroid Build Coastguard Worker   }
292*4bdc9457SAndroid Build Coastguard Worker 
TestNWCxF16()293*4bdc9457SAndroid Build Coastguard Worker   void TestNWCxF16() const {
294*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
295*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
296*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(1.0e-3f, 1.0f);
297*4bdc9457SAndroid Build Coastguard Worker 
298*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
299*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(batch_size() * output_stride());
300*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * channels());
301*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
302*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
303*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
304*4bdc9457SAndroid Build Coastguard Worker 
305*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
306*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
307*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < channels(); j++) {
308*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
309*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < width(); k++) {
310*4bdc9457SAndroid Build Coastguard Worker             acc += fp16_ieee_to_fp32_value(input[(i * width() + k) * input_stride() + j]);
311*4bdc9457SAndroid Build Coastguard Worker           }
312*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = acc / float(width());
313*4bdc9457SAndroid Build Coastguard Worker         }
314*4bdc9457SAndroid Build Coastguard Worker       }
315*4bdc9457SAndroid Build Coastguard Worker 
316*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
317*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
318*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
319*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
320*4bdc9457SAndroid Build Coastguard Worker       const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
321*4bdc9457SAndroid Build Coastguard Worker       const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
322*4bdc9457SAndroid Build Coastguard Worker       const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
323*4bdc9457SAndroid Build Coastguard Worker       const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;
324*4bdc9457SAndroid Build Coastguard Worker 
325*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
326*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
327*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
328*4bdc9457SAndroid Build Coastguard Worker       }
329*4bdc9457SAndroid Build Coastguard Worker 
330*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Global Average Pooling operator.
331*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
332*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t global_average_pooling_op = nullptr;
333*4bdc9457SAndroid Build Coastguard Worker 
334*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_global_average_pooling_nwc_f16(
335*4bdc9457SAndroid Build Coastguard Worker           channels(), input_stride(), output_stride(),
336*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
337*4bdc9457SAndroid Build Coastguard Worker           0, &global_average_pooling_op);
338*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
339*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
340*4bdc9457SAndroid Build Coastguard Worker       }
341*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
342*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, global_average_pooling_op);
343*4bdc9457SAndroid Build Coastguard Worker 
344*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete global_average_pooling_op.
345*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);
346*4bdc9457SAndroid Build Coastguard Worker 
347*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
348*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_global_average_pooling_nwc_f16(
349*4bdc9457SAndroid Build Coastguard Worker           global_average_pooling_op,
350*4bdc9457SAndroid Build Coastguard Worker           batch_size(), width(),
351*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
352*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
353*4bdc9457SAndroid Build Coastguard Worker 
354*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
355*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));
356*4bdc9457SAndroid Build Coastguard Worker 
357*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
358*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
359*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
360*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max);
361*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min);
362*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-4f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f))
363*4bdc9457SAndroid Build Coastguard Worker             << "at batch index " << i << " / " << batch_size()
364*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
365*4bdc9457SAndroid Build Coastguard Worker         }
366*4bdc9457SAndroid Build Coastguard Worker       }
367*4bdc9457SAndroid Build Coastguard Worker     }
368*4bdc9457SAndroid Build Coastguard Worker   }
369*4bdc9457SAndroid Build Coastguard Worker 
TestNWCxF32()370*4bdc9457SAndroid Build Coastguard Worker   void TestNWCxF32() const {
371*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
372*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
373*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
374*4bdc9457SAndroid Build Coastguard Worker 
375*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
376*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(batch_size() * output_stride());
377*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * channels());
378*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
379*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
380*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
381*4bdc9457SAndroid Build Coastguard Worker 
382*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
383*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
384*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < channels(); j++) {
385*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
386*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < width(); k++) {
387*4bdc9457SAndroid Build Coastguard Worker             acc += input[(i * width() + k) * input_stride() + j];
388*4bdc9457SAndroid Build Coastguard Worker           }
389*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = acc / float(width());
390*4bdc9457SAndroid Build Coastguard Worker         }
391*4bdc9457SAndroid Build Coastguard Worker       }
392*4bdc9457SAndroid Build Coastguard Worker 
393*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
394*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
395*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
396*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
397*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_range == 0.0f ?
398*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity() :
399*4bdc9457SAndroid Build Coastguard Worker         accumulated_min + accumulated_range / 255.0f * float(qmin());
400*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_range == 0.0f ?
401*4bdc9457SAndroid Build Coastguard Worker         +std::numeric_limits<float>::infinity() :
402*4bdc9457SAndroid Build Coastguard Worker         accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
403*4bdc9457SAndroid Build Coastguard Worker 
404*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
405*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
406*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
407*4bdc9457SAndroid Build Coastguard Worker       }
408*4bdc9457SAndroid Build Coastguard Worker 
409*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Global Average Pooling operator.
410*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
411*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t global_average_pooling_op = nullptr;
412*4bdc9457SAndroid Build Coastguard Worker 
413*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_global_average_pooling_nwc_f32(
414*4bdc9457SAndroid Build Coastguard Worker           channels(), input_stride(), output_stride(),
415*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
416*4bdc9457SAndroid Build Coastguard Worker           0, &global_average_pooling_op);
417*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
418*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
419*4bdc9457SAndroid Build Coastguard Worker       }
420*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
421*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, global_average_pooling_op);
422*4bdc9457SAndroid Build Coastguard Worker 
423*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete global_average_pooling_op.
424*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);
425*4bdc9457SAndroid Build Coastguard Worker 
426*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
427*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_global_average_pooling_nwc_f32(
428*4bdc9457SAndroid Build Coastguard Worker           global_average_pooling_op,
429*4bdc9457SAndroid Build Coastguard Worker           batch_size(), width(),
430*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
431*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
432*4bdc9457SAndroid Build Coastguard Worker 
433*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
434*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));
435*4bdc9457SAndroid Build Coastguard Worker 
436*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
437*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
438*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
439*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[i * output_stride() + c], output_max);
440*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[i * output_stride() + c], output_min);
441*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(output[i * output_stride() + c], output_ref[i * channels() + c], std::abs(output_ref[i * channels() + c]) * 1.0e-6f)
442*4bdc9457SAndroid Build Coastguard Worker             << "at batch index " << i << " / " << batch_size()
443*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
444*4bdc9457SAndroid Build Coastguard Worker         }
445*4bdc9457SAndroid Build Coastguard Worker       }
446*4bdc9457SAndroid Build Coastguard Worker     }
447*4bdc9457SAndroid Build Coastguard Worker   }
448*4bdc9457SAndroid Build Coastguard Worker 
TestNCWxF32()449*4bdc9457SAndroid Build Coastguard Worker   void TestNCWxF32() const {
450*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
451*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
452*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
453*4bdc9457SAndroid Build Coastguard Worker 
454*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(batch_size() * channels() * width() + XNN_EXTRA_BYTES / sizeof(float));
455*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(batch_size() * channels());
456*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * channels());
457*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
458*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
459*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), std::nanf(""));
460*4bdc9457SAndroid Build Coastguard Worker 
461*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
462*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
463*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < channels(); j++) {
464*4bdc9457SAndroid Build Coastguard Worker           float acc = 0.0f;
465*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < width(); k++) {
466*4bdc9457SAndroid Build Coastguard Worker             acc += input[(i * channels() + j) * width() + k];
467*4bdc9457SAndroid Build Coastguard Worker           }
468*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + j] = acc / float(width());
469*4bdc9457SAndroid Build Coastguard Worker         }
470*4bdc9457SAndroid Build Coastguard Worker       }
471*4bdc9457SAndroid Build Coastguard Worker 
472*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
473*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
474*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
475*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
476*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_range == 0.0f ?
477*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity() :
478*4bdc9457SAndroid Build Coastguard Worker         accumulated_min + accumulated_range / 255.0f * float(qmin());
479*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_range == 0.0f ?
480*4bdc9457SAndroid Build Coastguard Worker         +std::numeric_limits<float>::infinity() :
481*4bdc9457SAndroid Build Coastguard Worker         accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
482*4bdc9457SAndroid Build Coastguard Worker 
483*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
484*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
485*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
486*4bdc9457SAndroid Build Coastguard Worker       }
487*4bdc9457SAndroid Build Coastguard Worker 
488*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Global Average Pooling operator.
489*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
490*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t global_average_pooling_op = nullptr;
491*4bdc9457SAndroid Build Coastguard Worker 
492*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_global_average_pooling_ncw_f32(
493*4bdc9457SAndroid Build Coastguard Worker         channels(), output_min, output_max,
494*4bdc9457SAndroid Build Coastguard Worker         0, &global_average_pooling_op);
495*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_parameter) {
496*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
497*4bdc9457SAndroid Build Coastguard Worker       }
498*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
499*4bdc9457SAndroid Build Coastguard Worker 
500*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete global_average_pooling_op.
501*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator);
502*4bdc9457SAndroid Build Coastguard Worker 
503*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
504*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_global_average_pooling_ncw_f32(
505*4bdc9457SAndroid Build Coastguard Worker           global_average_pooling_op,
506*4bdc9457SAndroid Build Coastguard Worker           batch_size(), width(),
507*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
508*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
509*4bdc9457SAndroid Build Coastguard Worker 
510*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
511*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */));
512*4bdc9457SAndroid Build Coastguard Worker 
513*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
514*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
515*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
516*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(output[i * channels() + c], output_max);
517*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(output[i * channels() + c], output_min);
518*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(output[i * channels() + c], output_ref[i * channels() + c], std::abs(output_ref[i * channels() + c]) * 1.0e-5f)
519*4bdc9457SAndroid Build Coastguard Worker             << "at batch index " << i << " / " << batch_size()
520*4bdc9457SAndroid Build Coastguard Worker             << ", channel " << c << " / " << channels();
521*4bdc9457SAndroid Build Coastguard Worker         }
522*4bdc9457SAndroid Build Coastguard Worker       }
523*4bdc9457SAndroid Build Coastguard Worker     }
524*4bdc9457SAndroid Build Coastguard Worker   }
525*4bdc9457SAndroid Build Coastguard Worker 
526*4bdc9457SAndroid Build Coastguard Worker  private:
527*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_{1};
528*4bdc9457SAndroid Build Coastguard Worker   size_t width_{1};
529*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
530*4bdc9457SAndroid Build Coastguard Worker   size_t input_stride_{0};
531*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride_{0};
532*4bdc9457SAndroid Build Coastguard Worker   float input_scale_{1.0f};
533*4bdc9457SAndroid Build Coastguard Worker   float output_scale_{1.0f};
534*4bdc9457SAndroid Build Coastguard Worker   uint8_t input_zero_point_{121};
535*4bdc9457SAndroid Build Coastguard Worker   uint8_t output_zero_point_{133};
536*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};
537*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};
538*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{1};
539*4bdc9457SAndroid Build Coastguard Worker };
540