xref: /aosp_15_r20/external/XNNPACK/test/ibilinear-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/aligned-allocator.h>
#include <xnnpack/microfnptr.h>
#include <xnnpack/math.h>
24*4bdc9457SAndroid Build Coastguard Worker 
25*4bdc9457SAndroid Build Coastguard Worker 
26*4bdc9457SAndroid Build Coastguard Worker class IBilinearMicrokernelTester {
27*4bdc9457SAndroid Build Coastguard Worker  public:
pixels(uint32_t pixels)28*4bdc9457SAndroid Build Coastguard Worker   inline IBilinearMicrokernelTester& pixels(uint32_t pixels) {
29*4bdc9457SAndroid Build Coastguard Worker     assert(pixels >= 1);
30*4bdc9457SAndroid Build Coastguard Worker     this->pixels_ = pixels;
31*4bdc9457SAndroid Build Coastguard Worker     return *this;
32*4bdc9457SAndroid Build Coastguard Worker   }
33*4bdc9457SAndroid Build Coastguard Worker 
pixels()34*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t pixels() const {
35*4bdc9457SAndroid Build Coastguard Worker     return this->pixels_;
36*4bdc9457SAndroid Build Coastguard Worker   }
37*4bdc9457SAndroid Build Coastguard Worker 
channels(uint32_t channels)38*4bdc9457SAndroid Build Coastguard Worker   inline IBilinearMicrokernelTester& channels(uint32_t channels) {
39*4bdc9457SAndroid Build Coastguard Worker     assert(channels >= 1);
40*4bdc9457SAndroid Build Coastguard Worker     this->channels_ = channels;
41*4bdc9457SAndroid Build Coastguard Worker     return *this;
42*4bdc9457SAndroid Build Coastguard Worker   }
43*4bdc9457SAndroid Build Coastguard Worker 
channels()44*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t channels() const {
45*4bdc9457SAndroid Build Coastguard Worker     return this->channels_;
46*4bdc9457SAndroid Build Coastguard Worker   }
47*4bdc9457SAndroid Build Coastguard Worker 
input_offset(uint32_t input_offset)48*4bdc9457SAndroid Build Coastguard Worker   inline IBilinearMicrokernelTester& input_offset(uint32_t input_offset) {
49*4bdc9457SAndroid Build Coastguard Worker     this->input_offset_ = input_offset;
50*4bdc9457SAndroid Build Coastguard Worker     return *this;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
input_offset()53*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t input_offset() const {
54*4bdc9457SAndroid Build Coastguard Worker     return this->input_offset_;
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
output_stride(uint32_t output_stride)57*4bdc9457SAndroid Build Coastguard Worker   inline IBilinearMicrokernelTester& output_stride(uint32_t output_stride) {
58*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride != 0);
59*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
60*4bdc9457SAndroid Build Coastguard Worker     return *this;
61*4bdc9457SAndroid Build Coastguard Worker   }
62*4bdc9457SAndroid Build Coastguard Worker 
output_stride()63*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t output_stride() const {
64*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
65*4bdc9457SAndroid Build Coastguard Worker       return channels();
66*4bdc9457SAndroid Build Coastguard Worker     } else {
67*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= channels());
68*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
69*4bdc9457SAndroid Build Coastguard Worker     }
70*4bdc9457SAndroid Build Coastguard Worker   }
71*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)72*4bdc9457SAndroid Build Coastguard Worker   inline IBilinearMicrokernelTester& iterations(size_t iterations) {
73*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
74*4bdc9457SAndroid Build Coastguard Worker     return *this;
75*4bdc9457SAndroid Build Coastguard Worker   }
76*4bdc9457SAndroid Build Coastguard Worker 
iterations()77*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
78*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
input_stride(uint32_t input_stride)81*4bdc9457SAndroid Build Coastguard Worker   inline IBilinearMicrokernelTester& input_stride(uint32_t input_stride) {
82*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride != 0);
83*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
84*4bdc9457SAndroid Build Coastguard Worker     return *this;
85*4bdc9457SAndroid Build Coastguard Worker   }
86*4bdc9457SAndroid Build Coastguard Worker 
input_stride()87*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t input_stride() const {
88*4bdc9457SAndroid Build Coastguard Worker     if (this->input_stride_ == 0) {
89*4bdc9457SAndroid Build Coastguard Worker       return 4 * pixels();
90*4bdc9457SAndroid Build Coastguard Worker     } else {
91*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_stride_ >= 4 * pixels());
92*4bdc9457SAndroid Build Coastguard Worker       return this->input_stride_;
93*4bdc9457SAndroid Build Coastguard Worker     }
94*4bdc9457SAndroid Build Coastguard Worker   }
95*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_ibilinear_ukernel_function ibilinear)96*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_ibilinear_ukernel_function ibilinear) const {
97*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
98*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
99*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
100*4bdc9457SAndroid Build Coastguard Worker 
101*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirection(pixels() * 4);
102*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + indirection.size() * channels());
103*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_weights(pixels() * 2);
104*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((pixels() - 1) * output_stride() + channels());
105*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(pixels() * channels());
106*4bdc9457SAndroid Build Coastguard Worker 
107*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
108*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
109*4bdc9457SAndroid Build Coastguard Worker       std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
110*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
111*4bdc9457SAndroid Build Coastguard Worker 
112*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
113*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
114*4bdc9457SAndroid Build Coastguard Worker       }
115*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
116*4bdc9457SAndroid Build Coastguard Worker 
117*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
118*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
119*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
120*4bdc9457SAndroid Build Coastguard Worker           const float alpha_h = fp16_ieee_to_fp32_value(packed_weights[i * 2 + 0]);
121*4bdc9457SAndroid Build Coastguard Worker           const float alpha_v = fp16_ieee_to_fp32_value(packed_weights[i * 2 + 1]);
122*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + c] =
123*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(indirection[i * 4 + 0][c + input_offset()]) * (1.0f - alpha_h) * (1.0f - alpha_v) +
124*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(indirection[i * 4 + 1][c + input_offset()]) * alpha_h * (1.0f - alpha_v) +
125*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(indirection[i * 4 + 2][c + input_offset()]) * (1.0f - alpha_h) * alpha_v +
126*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(indirection[i * 4 + 3][c + input_offset()]) * alpha_h * alpha_v;
127*4bdc9457SAndroid Build Coastguard Worker         }
128*4bdc9457SAndroid Build Coastguard Worker       }
129*4bdc9457SAndroid Build Coastguard Worker 
130*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
131*4bdc9457SAndroid Build Coastguard Worker       ibilinear(
132*4bdc9457SAndroid Build Coastguard Worker         pixels(), channels() * sizeof(uint16_t),
133*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirection.data()), input_offset() * sizeof(uint16_t),
134*4bdc9457SAndroid Build Coastguard Worker         packed_weights.data(), output.data(),
135*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint16_t));
136*4bdc9457SAndroid Build Coastguard Worker 
137*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
138*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
139*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
140*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
141*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
142*4bdc9457SAndroid Build Coastguard Worker               output_ref[i * channels() + c],
143*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[i * channels() + c]) * 1.0e-2f)
144*4bdc9457SAndroid Build Coastguard Worker             << "pixel " << i << " / " << pixels() << ", channel " << c << " / " << channels();
145*4bdc9457SAndroid Build Coastguard Worker         }
146*4bdc9457SAndroid Build Coastguard Worker       }
147*4bdc9457SAndroid Build Coastguard Worker     }
148*4bdc9457SAndroid Build Coastguard Worker   }
149*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_ibilinear_ukernel_function ibilinear)150*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_ibilinear_ukernel_function ibilinear) const {
151*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
152*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
153*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
154*4bdc9457SAndroid Build Coastguard Worker 
155*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirection(pixels() * 4);
156*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + indirection.size() * channels());
157*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> packed_weights(pixels() * 2);
158*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((pixels() - 1) * output_stride() + channels());
159*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(pixels() * channels());
160*4bdc9457SAndroid Build Coastguard Worker 
161*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
162*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
163*4bdc9457SAndroid Build Coastguard Worker       std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return f32dist(rng); });
164*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
165*4bdc9457SAndroid Build Coastguard Worker 
166*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
167*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
168*4bdc9457SAndroid Build Coastguard Worker       }
169*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
170*4bdc9457SAndroid Build Coastguard Worker 
171*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
172*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
173*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
174*4bdc9457SAndroid Build Coastguard Worker           const float alpha_h = packed_weights[i * 2 + 0];
175*4bdc9457SAndroid Build Coastguard Worker           const float alpha_v = packed_weights[i * 2 + 1];
176*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + c] =
177*4bdc9457SAndroid Build Coastguard Worker             indirection[i * 4 + 0][c + input_offset()] * (1.0f - alpha_h) * (1.0f - alpha_v) +
178*4bdc9457SAndroid Build Coastguard Worker             indirection[i * 4 + 1][c + input_offset()] * alpha_h * (1.0f - alpha_v) +
179*4bdc9457SAndroid Build Coastguard Worker             indirection[i * 4 + 2][c + input_offset()] * (1.0f - alpha_h) * alpha_v +
180*4bdc9457SAndroid Build Coastguard Worker             indirection[i * 4 + 3][c + input_offset()] * alpha_h * alpha_v;
181*4bdc9457SAndroid Build Coastguard Worker         }
182*4bdc9457SAndroid Build Coastguard Worker       }
183*4bdc9457SAndroid Build Coastguard Worker 
184*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
185*4bdc9457SAndroid Build Coastguard Worker       ibilinear(
186*4bdc9457SAndroid Build Coastguard Worker         pixels(), channels() * sizeof(float),
187*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), input_offset() * sizeof(float),
188*4bdc9457SAndroid Build Coastguard Worker         packed_weights.data(), output.data(),
189*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(float));
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
192*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
193*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
194*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
195*4bdc9457SAndroid Build Coastguard Worker               output_ref[i * channels() + c],
196*4bdc9457SAndroid Build Coastguard Worker               output[i * output_stride() + c],
197*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[i * channels() + c]) * 1.0e-4)
198*4bdc9457SAndroid Build Coastguard Worker             << "pixel " << i << " / " << pixels() << ", channel " << c << " / " << channels();
199*4bdc9457SAndroid Build Coastguard Worker         }
200*4bdc9457SAndroid Build Coastguard Worker       }
201*4bdc9457SAndroid Build Coastguard Worker     }
202*4bdc9457SAndroid Build Coastguard Worker   }
203*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_s8_ibilinear_ukernel_function ibilinear)204*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_s8_ibilinear_ukernel_function ibilinear) const {
205*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
206*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
207*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
208*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
209*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int16_t> w11dist(0, 2047);
210*4bdc9457SAndroid Build Coastguard Worker 
211*4bdc9457SAndroid Build Coastguard Worker     std::vector<const int8_t*> indirection(pixels() * 4);
212*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + indirection.size() * channels());
213*4bdc9457SAndroid Build Coastguard Worker     std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_weights(pixels() * 2);
214*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output((pixels() - 1) * output_stride() + channels());
215*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output_ref(pixels() * channels());
216*4bdc9457SAndroid Build Coastguard Worker 
217*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
218*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
219*4bdc9457SAndroid Build Coastguard Worker       std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return w11dist(rng); });
220*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xFA));
221*4bdc9457SAndroid Build Coastguard Worker 
222*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
223*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
224*4bdc9457SAndroid Build Coastguard Worker       }
225*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
226*4bdc9457SAndroid Build Coastguard Worker 
227*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
228*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
229*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
230*4bdc9457SAndroid Build Coastguard Worker           const int32_t alpha_h = packed_weights[i * 2 + 0];
231*4bdc9457SAndroid Build Coastguard Worker           const int32_t alpha_v = packed_weights[i * 2 + 1];
232*4bdc9457SAndroid Build Coastguard Worker           const int32_t acc = math_asr_s32(
233*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 0][c + input_offset()]) * (2048 - alpha_h) * (2048 - alpha_v) +
234*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 1][c + input_offset()]) * alpha_h * (2048 - alpha_v) +
235*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 2][c + input_offset()]) * (2048 - alpha_h) * alpha_v +
236*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 3][c + input_offset()]) * alpha_h * alpha_v +
237*4bdc9457SAndroid Build Coastguard Worker             2097152, 22);
238*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GE(acc, std::numeric_limits<int8_t>::min());
239*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(acc, std::numeric_limits<int8_t>::max());
240*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + c] = (int8_t) acc;
241*4bdc9457SAndroid Build Coastguard Worker         }
242*4bdc9457SAndroid Build Coastguard Worker       }
243*4bdc9457SAndroid Build Coastguard Worker 
244*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
245*4bdc9457SAndroid Build Coastguard Worker       ibilinear(
246*4bdc9457SAndroid Build Coastguard Worker         pixels(), channels() * sizeof(int8_t),
247*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), input_offset() * sizeof(int8_t),
248*4bdc9457SAndroid Build Coastguard Worker         packed_weights.data(), output.data(),
249*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(int8_t));
250*4bdc9457SAndroid Build Coastguard Worker 
251*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
252*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
253*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
254*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(int32_t(output_ref[i * channels() + c]), int32_t(output[i * output_stride() + c]))
255*4bdc9457SAndroid Build Coastguard Worker             << "pixel " << i << " / " << pixels() << ", channel " << c << " / " << channels();
256*4bdc9457SAndroid Build Coastguard Worker         }
257*4bdc9457SAndroid Build Coastguard Worker       }
258*4bdc9457SAndroid Build Coastguard Worker     }
259*4bdc9457SAndroid Build Coastguard Worker   }
260*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_u8_ibilinear_ukernel_function ibilinear)261*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_u8_ibilinear_ukernel_function ibilinear) const {
262*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
263*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
264*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
265*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
266*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int16_t> w11dist(0, 2047);
267*4bdc9457SAndroid Build Coastguard Worker 
268*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint8_t*> indirection(pixels() * 4);
269*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + indirection.size() * channels());
270*4bdc9457SAndroid Build Coastguard Worker     std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_weights(pixels() * 2);
271*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((pixels() - 1) * output_stride() + channels());
272*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(pixels() * channels());
273*4bdc9457SAndroid Build Coastguard Worker 
274*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
275*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
276*4bdc9457SAndroid Build Coastguard Worker       std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return w11dist(rng); });
277*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xFA));
278*4bdc9457SAndroid Build Coastguard Worker 
279*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
280*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = input.data() + i * channels() - input_offset();
281*4bdc9457SAndroid Build Coastguard Worker       }
282*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
283*4bdc9457SAndroid Build Coastguard Worker 
284*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
285*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
286*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
287*4bdc9457SAndroid Build Coastguard Worker           const uint32_t alpha_h = uint32_t(int32_t(packed_weights[i * 2 + 0]));
288*4bdc9457SAndroid Build Coastguard Worker           const uint32_t alpha_v = uint32_t(int32_t(packed_weights[i * 2 + 1]));
289*4bdc9457SAndroid Build Coastguard Worker           const uint32_t acc = (2097152 +
290*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 0][c + input_offset()]) * (2048 - alpha_h) * (2048 - alpha_v) +
291*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 1][c + input_offset()]) * alpha_h * (2048 - alpha_v) +
292*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 2][c + input_offset()]) * (2048 - alpha_h) * alpha_v +
293*4bdc9457SAndroid Build Coastguard Worker             int32_t(indirection[i * 4 + 3][c + input_offset()]) * alpha_h * alpha_v) >> 22;
294*4bdc9457SAndroid Build Coastguard Worker           ASSERT_LE(acc, std::numeric_limits<uint8_t>::max());
295*4bdc9457SAndroid Build Coastguard Worker           output_ref[i * channels() + c] = (uint8_t) acc;
296*4bdc9457SAndroid Build Coastguard Worker         }
297*4bdc9457SAndroid Build Coastguard Worker       }
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
300*4bdc9457SAndroid Build Coastguard Worker       ibilinear(
301*4bdc9457SAndroid Build Coastguard Worker         pixels(), channels() * sizeof(uint8_t),
302*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), input_offset() * sizeof(uint8_t),
303*4bdc9457SAndroid Build Coastguard Worker         packed_weights.data(), output.data(),
304*4bdc9457SAndroid Build Coastguard Worker         (output_stride() - channels()) * sizeof(uint8_t));
305*4bdc9457SAndroid Build Coastguard Worker 
306*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
307*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
308*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
309*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(uint32_t(output_ref[i * channels() + c]), uint32_t(output[i * output_stride() + c]))
310*4bdc9457SAndroid Build Coastguard Worker             << "pixel " << i << " / " << pixels() << ", channel " << c << " / " << channels();
311*4bdc9457SAndroid Build Coastguard Worker         }
312*4bdc9457SAndroid Build Coastguard Worker       }
313*4bdc9457SAndroid Build Coastguard Worker     }
314*4bdc9457SAndroid Build Coastguard Worker   }
315*4bdc9457SAndroid Build Coastguard Worker 
TestCHW(xnn_f16_ibilinear_chw_ukernel_function ibilinear)316*4bdc9457SAndroid Build Coastguard Worker   void TestCHW(xnn_f16_ibilinear_chw_ukernel_function ibilinear) const {
317*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
318*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
319*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
320*4bdc9457SAndroid Build Coastguard Worker 
321*4bdc9457SAndroid Build Coastguard Worker     std::vector<const uint16_t*> indirection(pixels() * 2);
322*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + (channels() - 1) * input_stride() + 4 * pixels());
323*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_weights(pixels() * 2);
324*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(pixels() * channels());
325*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(pixels() * channels());
326*4bdc9457SAndroid Build Coastguard Worker 
327*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
328*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
329*4bdc9457SAndroid Build Coastguard Worker       std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
330*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
331*4bdc9457SAndroid Build Coastguard Worker 
332*4bdc9457SAndroid Build Coastguard Worker       // Indirection will point to the even ("left") pixels of the input.
333*4bdc9457SAndroid Build Coastguard Worker       // The kernels will expect "right" pixels to be placed right next to them.
334*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
335*4bdc9457SAndroid Build Coastguard Worker         const uint16_t* left_corner = input.data() + 2 * i - input_offset();
336*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = left_corner;
337*4bdc9457SAndroid Build Coastguard Worker       }
338*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
339*4bdc9457SAndroid Build Coastguard Worker 
340*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
341*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
342*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
343*4bdc9457SAndroid Build Coastguard Worker           const float alpha_h = fp16_ieee_to_fp32_value(packed_weights[i * 2 + 0]);
344*4bdc9457SAndroid Build Coastguard Worker           const float alpha_v = fp16_ieee_to_fp32_value(packed_weights[i * 2 + 1]);
345*4bdc9457SAndroid Build Coastguard Worker           // `c * pixels() + i` because the output is NCHW.
346*4bdc9457SAndroid Build Coastguard Worker           output_ref[c * pixels() + i] =
347*4bdc9457SAndroid Build Coastguard Worker             // `c * indirection.size()` because the input is NCHW.
348*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value((indirection[i * 2 + 0] + 0)[c * input_stride() + input_offset()]) * (1.0f - alpha_h) * (1.0f - alpha_v) +
349*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value((indirection[i * 2 + 0] + 1)[c * input_stride() + input_offset()]) * alpha_h * (1.0f - alpha_v) +
350*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value((indirection[i * 2 + 1] + 0)[c * input_stride() + input_offset()]) * (1.0f - alpha_h) * alpha_v +
351*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value((indirection[i * 2 + 1] + 1)[c * input_stride() + input_offset()]) * alpha_h * alpha_v;
352*4bdc9457SAndroid Build Coastguard Worker         }
353*4bdc9457SAndroid Build Coastguard Worker       }
354*4bdc9457SAndroid Build Coastguard Worker 
355*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
356*4bdc9457SAndroid Build Coastguard Worker       ibilinear(
357*4bdc9457SAndroid Build Coastguard Worker         pixels(), channels(),
358*4bdc9457SAndroid Build Coastguard Worker         reinterpret_cast<const void**>(indirection.data()), input_offset() * sizeof(uint16_t),
359*4bdc9457SAndroid Build Coastguard Worker         packed_weights.data(), output.data(), input_stride() * sizeof(uint16_t));
360*4bdc9457SAndroid Build Coastguard Worker 
361*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
362*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
363*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < pixels(); i++) {
364*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
365*4bdc9457SAndroid Build Coastguard Worker               fp16_ieee_to_fp32_value(output[c * pixels() + i]),
366*4bdc9457SAndroid Build Coastguard Worker               output_ref[c * pixels() + i],
367*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[c * pixels() + i]) * 1.0e-2f)
368*4bdc9457SAndroid Build Coastguard Worker             << "i = " << i << ", channel = " << c;
369*4bdc9457SAndroid Build Coastguard Worker         }
370*4bdc9457SAndroid Build Coastguard Worker       }
371*4bdc9457SAndroid Build Coastguard Worker     }
372*4bdc9457SAndroid Build Coastguard Worker   }
373*4bdc9457SAndroid Build Coastguard Worker 
TestCHW(xnn_f32_ibilinear_chw_ukernel_function ibilinear)374*4bdc9457SAndroid Build Coastguard Worker   void TestCHW(xnn_f32_ibilinear_chw_ukernel_function ibilinear) const {
375*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
376*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
377*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist;
378*4bdc9457SAndroid Build Coastguard Worker 
379*4bdc9457SAndroid Build Coastguard Worker     std::vector<const float*> indirection(pixels() * 2);
380*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (channels() - 1) * input_stride() + 4 * pixels());
381*4bdc9457SAndroid Build Coastguard Worker     std::vector<float, AlignedAllocator<float, 64>> packed_weights(pixels() * 2);
382*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(pixels() * channels());
383*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(pixels() * channels());
384*4bdc9457SAndroid Build Coastguard Worker 
385*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
386*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
387*4bdc9457SAndroid Build Coastguard Worker       std::generate(packed_weights.begin(), packed_weights.end(), [&]() { return f32dist(rng); });
388*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
389*4bdc9457SAndroid Build Coastguard Worker 
390*4bdc9457SAndroid Build Coastguard Worker       // Indirection will point to the even ("left") pixels of the input.
391*4bdc9457SAndroid Build Coastguard Worker       // The kernels will expect "right" pixels to be placed right next to them.
392*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < indirection.size(); i++) {
393*4bdc9457SAndroid Build Coastguard Worker         const float* left_corner = input.data() + 2 * i - input_offset();
394*4bdc9457SAndroid Build Coastguard Worker         indirection[i] = left_corner;
395*4bdc9457SAndroid Build Coastguard Worker       }
396*4bdc9457SAndroid Build Coastguard Worker       std::shuffle(indirection.begin(), indirection.end(), rng);
397*4bdc9457SAndroid Build Coastguard Worker 
398*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
399*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < pixels(); i++) {
400*4bdc9457SAndroid Build Coastguard Worker         for (size_t c = 0; c < channels(); c++) {
401*4bdc9457SAndroid Build Coastguard Worker           const float alpha_h = packed_weights[i * 2 + 0];
402*4bdc9457SAndroid Build Coastguard Worker           const float alpha_v = packed_weights[i * 2 + 1];
403*4bdc9457SAndroid Build Coastguard Worker           // `c * pixels() + i` because the output is NCHW.
404*4bdc9457SAndroid Build Coastguard Worker           output_ref[c * pixels() + i] =
405*4bdc9457SAndroid Build Coastguard Worker             // `c * indirection.size()` because the input is NCHW.
406*4bdc9457SAndroid Build Coastguard Worker             (indirection[i * 2 + 0] + 0)[c * input_stride() + input_offset()] * (1.0f - alpha_h) * (1.0f - alpha_v) +
407*4bdc9457SAndroid Build Coastguard Worker             (indirection[i * 2 + 0] + 1)[c * input_stride() + input_offset()] * alpha_h * (1.0f - alpha_v) +
408*4bdc9457SAndroid Build Coastguard Worker             (indirection[i * 2 + 1] + 0)[c * input_stride() + input_offset()] * (1.0f - alpha_h) * alpha_v +
409*4bdc9457SAndroid Build Coastguard Worker             (indirection[i * 2 + 1] + 1)[c * input_stride() + input_offset()] * alpha_h * alpha_v;
410*4bdc9457SAndroid Build Coastguard Worker         }
411*4bdc9457SAndroid Build Coastguard Worker       }
412*4bdc9457SAndroid Build Coastguard Worker 
413*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
414*4bdc9457SAndroid Build Coastguard Worker       ibilinear(
415*4bdc9457SAndroid Build Coastguard Worker         pixels(), channels(),
416*4bdc9457SAndroid Build Coastguard Worker         indirection.data(), input_offset() * sizeof(float),
417*4bdc9457SAndroid Build Coastguard Worker         packed_weights.data(), output.data(), input_stride() * sizeof(float));
418*4bdc9457SAndroid Build Coastguard Worker 
419*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
420*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < channels(); c++) {
421*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < pixels(); i++) {
422*4bdc9457SAndroid Build Coastguard Worker           ASSERT_NEAR(
423*4bdc9457SAndroid Build Coastguard Worker               output_ref[c * pixels() + i],
424*4bdc9457SAndroid Build Coastguard Worker               output[c * pixels() + i],
425*4bdc9457SAndroid Build Coastguard Worker               std::abs(output_ref[c * pixels() + i]) * 1.0e-4)
426*4bdc9457SAndroid Build Coastguard Worker             << "i = " << i << ", channel = " << c;
427*4bdc9457SAndroid Build Coastguard Worker         }
428*4bdc9457SAndroid Build Coastguard Worker       }
429*4bdc9457SAndroid Build Coastguard Worker     }
430*4bdc9457SAndroid Build Coastguard Worker   }
431*4bdc9457SAndroid Build Coastguard Worker 
432*4bdc9457SAndroid Build Coastguard Worker  private:
433*4bdc9457SAndroid Build Coastguard Worker   uint32_t channels_{1};
434*4bdc9457SAndroid Build Coastguard Worker   uint32_t pixels_{1};
435*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_stride_{0};
436*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_stride_{0};
437*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_offset_{0};
438*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{3};
439*4bdc9457SAndroid Build Coastguard Worker };
440