xref: /aosp_15_r20/external/XNNPACK/test/argmaxpool-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
11*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
14*4bdc9457SAndroid Build Coastguard Worker #include <random>
15*4bdc9457SAndroid Build Coastguard Worker #include <vector>
16*4bdc9457SAndroid Build Coastguard Worker 
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker class ArgMaxPoolMicrokernelTester {
24*4bdc9457SAndroid Build Coastguard Worker  public:
25*4bdc9457SAndroid Build Coastguard Worker   enum class Variant {
26*4bdc9457SAndroid Build Coastguard Worker     Native,
27*4bdc9457SAndroid Build Coastguard Worker     Scalar,
28*4bdc9457SAndroid Build Coastguard Worker   };
29*4bdc9457SAndroid Build Coastguard Worker 
output_pixels(size_t output_pixels)30*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& output_pixels(size_t output_pixels) {
31*4bdc9457SAndroid Build Coastguard Worker     assert(output_pixels != 0);
32*4bdc9457SAndroid Build Coastguard Worker     this->output_pixels_ = output_pixels;
33*4bdc9457SAndroid Build Coastguard Worker     return *this;
34*4bdc9457SAndroid Build Coastguard Worker   }
35*4bdc9457SAndroid Build Coastguard Worker 
output_pixels()36*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_pixels() const {
37*4bdc9457SAndroid Build Coastguard Worker     return this->output_pixels_;
38*4bdc9457SAndroid Build Coastguard Worker   }
39*4bdc9457SAndroid Build Coastguard Worker 
step(size_t step)40*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& step(size_t step) {
41*4bdc9457SAndroid Build Coastguard Worker     assert(step != 0);
42*4bdc9457SAndroid Build Coastguard Worker     this->step_ = step;
43*4bdc9457SAndroid Build Coastguard Worker     return *this;
44*4bdc9457SAndroid Build Coastguard Worker   }
45*4bdc9457SAndroid Build Coastguard Worker 
step()46*4bdc9457SAndroid Build Coastguard Worker   inline size_t step() const {
47*4bdc9457SAndroid Build Coastguard Worker     return this->step_;
48*4bdc9457SAndroid Build Coastguard Worker   }
49*4bdc9457SAndroid Build Coastguard Worker 
input_offset(size_t input_offset)50*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& input_offset(size_t input_offset) {
51*4bdc9457SAndroid Build Coastguard Worker     assert(input_offset != 0);
52*4bdc9457SAndroid Build Coastguard Worker     this->input_offset_ = input_offset;
53*4bdc9457SAndroid Build Coastguard Worker     return *this;
54*4bdc9457SAndroid Build Coastguard Worker   }
55*4bdc9457SAndroid Build Coastguard Worker 
input_offset()56*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_offset() const {
57*4bdc9457SAndroid Build Coastguard Worker     return this->input_offset_;
58*4bdc9457SAndroid Build Coastguard Worker   }
59*4bdc9457SAndroid Build Coastguard Worker 
pooling_elements(size_t pooling_elements)60*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& pooling_elements(size_t pooling_elements) {
61*4bdc9457SAndroid Build Coastguard Worker     assert(pooling_elements != 0);
62*4bdc9457SAndroid Build Coastguard Worker     this->pooling_elements_ = pooling_elements;
63*4bdc9457SAndroid Build Coastguard Worker     return *this;
64*4bdc9457SAndroid Build Coastguard Worker   }
65*4bdc9457SAndroid Build Coastguard Worker 
pooling_elements()66*4bdc9457SAndroid Build Coastguard Worker   inline size_t pooling_elements() const {
67*4bdc9457SAndroid Build Coastguard Worker     return this->pooling_elements_;
68*4bdc9457SAndroid Build Coastguard Worker   }
69*4bdc9457SAndroid Build Coastguard Worker 
packed_pooling_elements()70*4bdc9457SAndroid Build Coastguard Worker   inline size_t packed_pooling_elements() const {
71*4bdc9457SAndroid Build Coastguard Worker     if (pooling_elements() <= primary_pooling_tile()) {
72*4bdc9457SAndroid Build Coastguard Worker       return primary_pooling_tile();
73*4bdc9457SAndroid Build Coastguard Worker     } else {
74*4bdc9457SAndroid Build Coastguard Worker       return (pooling_elements() - primary_pooling_tile()) % incremental_pooling_tile() == 0 ? pooling_elements() : ((pooling_elements() - primary_pooling_tile()) / incremental_pooling_tile() + 1) * incremental_pooling_tile() + primary_pooling_tile();
75*4bdc9457SAndroid Build Coastguard Worker     }
76*4bdc9457SAndroid Build Coastguard Worker   }
77*4bdc9457SAndroid Build Coastguard Worker 
pooling_tile(size_t primary_tile)78*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& pooling_tile(size_t primary_tile) {
79*4bdc9457SAndroid Build Coastguard Worker     assert(primary_tile != 0);
80*4bdc9457SAndroid Build Coastguard Worker     this->primary_pooling_tile_ = primary_tile;
81*4bdc9457SAndroid Build Coastguard Worker     this->incremental_pooling_tile_ = 0;
82*4bdc9457SAndroid Build Coastguard Worker     return *this;
83*4bdc9457SAndroid Build Coastguard Worker   }
84*4bdc9457SAndroid Build Coastguard Worker 
pooling_tile(size_t primary_tile,size_t incremental_tile)85*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& pooling_tile(size_t primary_tile, size_t incremental_tile) {
86*4bdc9457SAndroid Build Coastguard Worker     assert(primary_tile != 0);
87*4bdc9457SAndroid Build Coastguard Worker     this->primary_pooling_tile_ = primary_tile;
88*4bdc9457SAndroid Build Coastguard Worker     this->incremental_pooling_tile_ = incremental_tile;
89*4bdc9457SAndroid Build Coastguard Worker     return *this;
90*4bdc9457SAndroid Build Coastguard Worker   }
91*4bdc9457SAndroid Build Coastguard Worker 
primary_pooling_tile(size_t primary_pooling_tile)92*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& primary_pooling_tile(size_t primary_pooling_tile) {
93*4bdc9457SAndroid Build Coastguard Worker     assert(primary_pooling_tile != 0);
94*4bdc9457SAndroid Build Coastguard Worker     this->primary_pooling_tile_ = primary_pooling_tile;
95*4bdc9457SAndroid Build Coastguard Worker     return *this;
96*4bdc9457SAndroid Build Coastguard Worker   }
97*4bdc9457SAndroid Build Coastguard Worker 
primary_pooling_tile()98*4bdc9457SAndroid Build Coastguard Worker   inline size_t primary_pooling_tile() const {
99*4bdc9457SAndroid Build Coastguard Worker     return this->primary_pooling_tile_;
100*4bdc9457SAndroid Build Coastguard Worker   }
101*4bdc9457SAndroid Build Coastguard Worker 
incremental_pooling_tile(size_t incremental_pooling_tile)102*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& incremental_pooling_tile(size_t incremental_pooling_tile) {
103*4bdc9457SAndroid Build Coastguard Worker     assert(incremental_pooling_tile != 0);
104*4bdc9457SAndroid Build Coastguard Worker     this->incremental_pooling_tile_ = incremental_pooling_tile;
105*4bdc9457SAndroid Build Coastguard Worker     return *this;
106*4bdc9457SAndroid Build Coastguard Worker   }
107*4bdc9457SAndroid Build Coastguard Worker 
incremental_pooling_tile()108*4bdc9457SAndroid Build Coastguard Worker   inline size_t incremental_pooling_tile() const {
109*4bdc9457SAndroid Build Coastguard Worker     return this->incremental_pooling_tile_;
110*4bdc9457SAndroid Build Coastguard Worker   }
111*4bdc9457SAndroid Build Coastguard Worker 
channels(size_t channels)112*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& channels(size_t channels) {
113*4bdc9457SAndroid Build Coastguard Worker     assert(channels != 0);
114*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
115*4bdc9457SAndroid Build Coastguard Worker     return *this;
116*4bdc9457SAndroid Build Coastguard Worker   }
117*4bdc9457SAndroid Build Coastguard Worker 
channels()118*4bdc9457SAndroid Build Coastguard Worker   inline size_t channels() const {
119*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
120*4bdc9457SAndroid Build Coastguard Worker   }
121*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)122*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& output_stride(size_t output_stride) {
123*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
124*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
125*4bdc9457SAndroid Build Coastguard Worker     return *this;
126*4bdc9457SAndroid Build Coastguard Worker   }
127*4bdc9457SAndroid Build Coastguard Worker 
output_stride()128*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const {
129*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
130*4bdc9457SAndroid Build Coastguard Worker       return channels();
131*4bdc9457SAndroid Build Coastguard Worker     } else {
132*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
133*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
134*4bdc9457SAndroid Build Coastguard Worker     }
135*4bdc9457SAndroid Build Coastguard Worker   }
136*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)137*4bdc9457SAndroid Build Coastguard Worker   inline ArgMaxPoolMicrokernelTester& iterations(size_t iterations) {
138*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
139*4bdc9457SAndroid Build Coastguard Worker     return *this;
140*4bdc9457SAndroid Build Coastguard Worker   }
141*4bdc9457SAndroid Build Coastguard Worker 
iterations()142*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
143*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
144*4bdc9457SAndroid Build Coastguard Worker   }
145*4bdc9457SAndroid Build Coastguard Worker 
146*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_argmaxpool_unipass_ukernel_function argmaxpool, Variant variant = Variant::Native) const {
147*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
148*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
149*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
150*4bdc9457SAndroid Build Coastguard Worker 
151*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
152*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
153*4bdc9457SAndroid Build Coastguard Worker       ((output_pixels() - 1) * step() + pooling_elements()) * channels());
154*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
155*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> index(output_pixels() * channels());
156*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
157*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> index_ref(output_pixels() * channels());
158*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
159*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
160*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
161*4bdc9457SAndroid Build Coastguard Worker 
162*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
163*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels() - input_offset();
164*4bdc9457SAndroid Build Coastguard Worker       }
165*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
166*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
167*4bdc9457SAndroid Build Coastguard Worker 
168*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
169*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
170*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
171*4bdc9457SAndroid Build Coastguard Worker           float max_value = indirect_input[x * step()][c + input_offset()];
172*4bdc9457SAndroid Build Coastguard Worker           uint32_t max_index = 0;
173*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
174*4bdc9457SAndroid Build Coastguard Worker             const float value = indirect_input[x * step() + p][c + input_offset()];
175*4bdc9457SAndroid Build Coastguard Worker             if (value > max_value) {
176*4bdc9457SAndroid Build Coastguard Worker               max_value = value;
177*4bdc9457SAndroid Build Coastguard Worker               max_index = p;
178*4bdc9457SAndroid Build Coastguard Worker             }
179*4bdc9457SAndroid Build Coastguard Worker           }
180*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = max_value;
181*4bdc9457SAndroid Build Coastguard Worker           index_ref[x * channels() + c] = max_index;
182*4bdc9457SAndroid Build Coastguard Worker         }
183*4bdc9457SAndroid Build Coastguard Worker       }
184*4bdc9457SAndroid Build Coastguard Worker 
185*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
186*4bdc9457SAndroid Build Coastguard Worker       argmaxpool(output_pixels(), pooling_elements(), channels(),
187*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(float), output.data(), index.data(),
188*4bdc9457SAndroid Build Coastguard Worker         step() * sizeof(void*),
189*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float));
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
192*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
193*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
194*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(output_ref[x * channels() + c], output[x * output_stride() + c])
195*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
196*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
197*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
198*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(
199*4bdc9457SAndroid Build Coastguard Worker               indirect_input[x * step() + index_ref[x * channels() + c]][c + input_offset()],
200*4bdc9457SAndroid Build Coastguard Worker               indirect_input[x * step() + index[x * channels() + c]][c + input_offset()])
201*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
202*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
203*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
204*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(index_ref[x * channels() + c], index[x * channels() + c])
205*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
206*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
207*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
208*4bdc9457SAndroid Build Coastguard Worker         }
209*4bdc9457SAndroid Build Coastguard Worker       }
210*4bdc9457SAndroid Build Coastguard Worker     }
211*4bdc9457SAndroid Build Coastguard Worker   }
212*4bdc9457SAndroid Build Coastguard Worker 
213*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_argmaxpool_multipass_ukernel_function argmaxpool, Variant variant = Variant::Native) const {
214*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
215*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
216*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
217*4bdc9457SAndroid Build Coastguard Worker 
218*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements());
219*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
220*4bdc9457SAndroid Build Coastguard Worker       ((output_pixels() - 1) * step() + pooling_elements()) * channels());
221*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((output_pixels() - 1) * output_stride() + channels());
222*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> index(output_pixels() * channels());
223*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t, AlignedAllocator<uint32_t, 64>> index_buffer(
224*4bdc9457SAndroid Build Coastguard Worker       channels() + XNN_EXTRA_BYTES / sizeof(uint32_t));
225*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> output_buffer(
226*4bdc9457SAndroid Build Coastguard Worker       channels() + XNN_EXTRA_BYTES / sizeof(float));
227*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(output_pixels() * channels());
228*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> index_ref(output_pixels() * channels());
229*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
230*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
231*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
232*4bdc9457SAndroid Build Coastguard Worker 
233*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) {
234*4bdc9457SAndroid Build Coastguard Worker         indirect_input[i] = input.data() + i * channels() - input_offset();
235*4bdc9457SAndroid Build Coastguard Worker       }
236*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirect_input.begin(),
237*4bdc9457SAndroid Build Coastguard Worker         indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng);
238*4bdc9457SAndroid Build Coastguard Worker 
239*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
240*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
241*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
242*4bdc9457SAndroid Build Coastguard Worker           float max_value = indirect_input[x * step()][c + input_offset()];
243*4bdc9457SAndroid Build Coastguard Worker           uint32_t max_index = 0;
244*4bdc9457SAndroid Build Coastguard Worker           for (size_t p = 0; p < pooling_elements(); p++) {
245*4bdc9457SAndroid Build Coastguard Worker             const float value = indirect_input[x * step() + p][c + input_offset()];
246*4bdc9457SAndroid Build Coastguard Worker             if (value > max_value) {
247*4bdc9457SAndroid Build Coastguard Worker               max_value = value;
248*4bdc9457SAndroid Build Coastguard Worker               max_index = p;
249*4bdc9457SAndroid Build Coastguard Worker             }
250*4bdc9457SAndroid Build Coastguard Worker           }
251*4bdc9457SAndroid Build Coastguard Worker           output_ref[x * channels() + c] = max_value;
252*4bdc9457SAndroid Build Coastguard Worker           index_ref[x * channels() + c] = max_index;
253*4bdc9457SAndroid Build Coastguard Worker         }
254*4bdc9457SAndroid Build Coastguard Worker       }
255*4bdc9457SAndroid Build Coastguard Worker 
256*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
257*4bdc9457SAndroid Build Coastguard Worker       argmaxpool(output_pixels(), pooling_elements(), channels(),
258*4bdc9457SAndroid Build Coastguard Worker         indirect_input.data(), input_offset() * sizeof(float),
259*4bdc9457SAndroid Build Coastguard Worker         output_buffer.data(), index_buffer.data(),
260*4bdc9457SAndroid Build Coastguard Worker         output.data(), index.data(),
261*4bdc9457SAndroid Build Coastguard Worker         (step() - (packed_pooling_elements() - incremental_pooling_tile())) * sizeof(void*),
262*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float));
263*4bdc9457SAndroid Build Coastguard Worker 
264*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
265*4bdc9457SAndroid Build Coastguard Worker       for (size_t x = 0; x < output_pixels(); x++) {
266*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
267*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(output_ref[x * channels() + c], output[x * output_stride() + c])
268*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
269*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
270*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
271*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(
272*4bdc9457SAndroid Build Coastguard Worker               indirect_input[x * step() + index_ref[x * channels() + c]][c + input_offset()],
273*4bdc9457SAndroid Build Coastguard Worker               indirect_input[x * step() + index[x * channels() + c]][c + input_offset()])
274*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
275*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
276*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
277*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(index_ref[x * channels() + c], index[x * channels() + c])
278*4bdc9457SAndroid Build Coastguard Worker             << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels()
279*4bdc9457SAndroid Build Coastguard Worker             << ", pooling elements = " << pooling_elements() << ", step = " << step()
280*4bdc9457SAndroid Build Coastguard Worker             << ", input offset = " << input_offset();
281*4bdc9457SAndroid Build Coastguard Worker         }
282*4bdc9457SAndroid Build Coastguard Worker       }
283*4bdc9457SAndroid Build Coastguard Worker     }
284*4bdc9457SAndroid Build Coastguard Worker   }
285*4bdc9457SAndroid Build Coastguard Worker 
286*4bdc9457SAndroid Build Coastguard Worker  private:
287*4bdc9457SAndroid Build Coastguard Worker   size_t output_pixels_{1};
288*4bdc9457SAndroid Build Coastguard Worker   size_t pooling_elements_{1};
289*4bdc9457SAndroid Build Coastguard Worker   size_t channels_{1};
290*4bdc9457SAndroid Build Coastguard Worker   size_t input_offset_{0};
291*4bdc9457SAndroid Build Coastguard Worker   size_t step_{1};
292*4bdc9457SAndroid Build Coastguard Worker   size_t primary_pooling_tile_{1};
293*4bdc9457SAndroid Build Coastguard Worker   size_t incremental_pooling_tile_{1};
294*4bdc9457SAndroid Build Coastguard Worker   size_t output_stride_{0};
295*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{3};
296*4bdc9457SAndroid Build Coastguard Worker };
297