xref: /aosp_15_r20/external/XNNPACK/test/vunary-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
11*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
14*4bdc9457SAndroid Build Coastguard Worker #include <random>
15*4bdc9457SAndroid Build Coastguard Worker #include <vector>
16*4bdc9457SAndroid Build Coastguard Worker 
17*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
18*4bdc9457SAndroid Build Coastguard Worker 
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker class VUnaryMicrokernelTester {
25*4bdc9457SAndroid Build Coastguard Worker  public:
// Identifies a unary operation by name. The enumerators are not referenced in
// the visible part of this class; presumably they select a reference
// implementation in overloads defined further down (TODO confirm).
enum class OpType {
  ReLU, RoundToNearestEven, RoundTowardsZero, RoundUp, RoundDown,
};
33*4bdc9457SAndroid Build Coastguard Worker 
// Distinguishes two flavours of a micro-kernel. Not referenced in the visible
// part of this class; presumably it picks between a native and a scalar
// reference/parameter path elsewhere (TODO confirm).
enum class Variant { Native, Scalar };
38*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)39*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& batch_size(size_t batch_size) {
40*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size != 0);
41*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
42*4bdc9457SAndroid Build Coastguard Worker     return *this;
43*4bdc9457SAndroid Build Coastguard Worker   }
44*4bdc9457SAndroid Build Coastguard Worker 
batch_size()45*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
46*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
47*4bdc9457SAndroid Build Coastguard Worker   }
48*4bdc9457SAndroid Build Coastguard Worker 
inplace(bool inplace)49*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& inplace(bool inplace) {
50*4bdc9457SAndroid Build Coastguard Worker     this->inplace_ = inplace;
51*4bdc9457SAndroid Build Coastguard Worker     return *this;
52*4bdc9457SAndroid Build Coastguard Worker   }
53*4bdc9457SAndroid Build Coastguard Worker 
inplace()54*4bdc9457SAndroid Build Coastguard Worker   inline bool inplace() const {
55*4bdc9457SAndroid Build Coastguard Worker     return this->inplace_;
56*4bdc9457SAndroid Build Coastguard Worker   }
57*4bdc9457SAndroid Build Coastguard Worker 
slope(float slope)58*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& slope(float slope) {
59*4bdc9457SAndroid Build Coastguard Worker     this->slope_ = slope;
60*4bdc9457SAndroid Build Coastguard Worker     return *this;
61*4bdc9457SAndroid Build Coastguard Worker   }
62*4bdc9457SAndroid Build Coastguard Worker 
slope()63*4bdc9457SAndroid Build Coastguard Worker   inline float slope() const {
64*4bdc9457SAndroid Build Coastguard Worker     return this->slope_;
65*4bdc9457SAndroid Build Coastguard Worker   }
66*4bdc9457SAndroid Build Coastguard Worker 
prescale(float prescale)67*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& prescale(float prescale) {
68*4bdc9457SAndroid Build Coastguard Worker     this->prescale_ = prescale;
69*4bdc9457SAndroid Build Coastguard Worker     return *this;
70*4bdc9457SAndroid Build Coastguard Worker   }
71*4bdc9457SAndroid Build Coastguard Worker 
prescale()72*4bdc9457SAndroid Build Coastguard Worker   inline float prescale() const {
73*4bdc9457SAndroid Build Coastguard Worker     return this->prescale_;
74*4bdc9457SAndroid Build Coastguard Worker   }
75*4bdc9457SAndroid Build Coastguard Worker 
alpha(float alpha)76*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& alpha(float alpha) {
77*4bdc9457SAndroid Build Coastguard Worker     this->alpha_ = alpha;
78*4bdc9457SAndroid Build Coastguard Worker     return *this;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
alpha()81*4bdc9457SAndroid Build Coastguard Worker   inline float alpha() const {
82*4bdc9457SAndroid Build Coastguard Worker     return this->alpha_;
83*4bdc9457SAndroid Build Coastguard Worker   }
84*4bdc9457SAndroid Build Coastguard Worker 
beta(float beta)85*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& beta(float beta) {
86*4bdc9457SAndroid Build Coastguard Worker     this->beta_ = beta;
87*4bdc9457SAndroid Build Coastguard Worker     return *this;
88*4bdc9457SAndroid Build Coastguard Worker   }
89*4bdc9457SAndroid Build Coastguard Worker 
beta()90*4bdc9457SAndroid Build Coastguard Worker   inline float beta() const {
91*4bdc9457SAndroid Build Coastguard Worker     return this->beta_;
92*4bdc9457SAndroid Build Coastguard Worker   }
93*4bdc9457SAndroid Build Coastguard Worker 
shift(uint32_t shift)94*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& shift(uint32_t shift) {
95*4bdc9457SAndroid Build Coastguard Worker     this->shift_ = shift;
96*4bdc9457SAndroid Build Coastguard Worker     return *this;
97*4bdc9457SAndroid Build Coastguard Worker   }
98*4bdc9457SAndroid Build Coastguard Worker 
shift()99*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t shift() const {
100*4bdc9457SAndroid Build Coastguard Worker     return this->shift_;
101*4bdc9457SAndroid Build Coastguard Worker   }
102*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)103*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& qmin(uint8_t qmin) {
104*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
105*4bdc9457SAndroid Build Coastguard Worker     return *this;
106*4bdc9457SAndroid Build Coastguard Worker   }
107*4bdc9457SAndroid Build Coastguard Worker 
qmin()108*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
109*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
110*4bdc9457SAndroid Build Coastguard Worker   }
111*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)112*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& qmax(uint8_t qmax) {
113*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
114*4bdc9457SAndroid Build Coastguard Worker     return *this;
115*4bdc9457SAndroid Build Coastguard Worker   }
116*4bdc9457SAndroid Build Coastguard Worker 
qmax()117*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
118*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
119*4bdc9457SAndroid Build Coastguard Worker   }
120*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)121*4bdc9457SAndroid Build Coastguard Worker   inline VUnaryMicrokernelTester& iterations(size_t iterations) {
122*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
123*4bdc9457SAndroid Build Coastguard Worker     return *this;
124*4bdc9457SAndroid Build Coastguard Worker   }
125*4bdc9457SAndroid Build Coastguard Worker 
iterations()126*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
127*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
128*4bdc9457SAndroid Build Coastguard Worker   }
129*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_vrelu_ukernel_function vrelu)130*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vrelu_ukernel_function vrelu) const {
131*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
132*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
133*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
134*4bdc9457SAndroid Build Coastguard Worker 
135*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
136*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
137*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> y_ref(batch_size());
138*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
139*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
140*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
141*4bdc9457SAndroid Build Coastguard Worker       } else {
142*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
143*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
144*4bdc9457SAndroid Build Coastguard Worker       }
145*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
146*4bdc9457SAndroid Build Coastguard Worker 
147*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
148*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
149*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::max(x_data[i], 0.0f);
150*4bdc9457SAndroid Build Coastguard Worker       }
151*4bdc9457SAndroid Build Coastguard Worker 
152*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
153*4bdc9457SAndroid Build Coastguard Worker       vrelu(batch_size() * sizeof(float), x_data, y.data(), nullptr);
154*4bdc9457SAndroid Build Coastguard Worker 
155*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
156*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
157*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
158*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
159*4bdc9457SAndroid Build Coastguard Worker       }
160*4bdc9457SAndroid Build Coastguard Worker     }
161*4bdc9457SAndroid Build Coastguard Worker   }
162*4bdc9457SAndroid Build Coastguard Worker 
163*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vabs_ukernel_function vabs, xnn_init_f16_abs_params_fn init_params = nullptr) const {
164*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
165*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
166*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
167*4bdc9457SAndroid Build Coastguard Worker 
168*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
169*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
170*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y_ref(batch_size());
171*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
172*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
173*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
174*4bdc9457SAndroid Build Coastguard Worker       } else {
175*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
176*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
177*4bdc9457SAndroid Build Coastguard Worker       }
178*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
179*4bdc9457SAndroid Build Coastguard Worker 
180*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
181*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
182*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = x_data[i] & UINT16_C(0x7FFF);
183*4bdc9457SAndroid Build Coastguard Worker       }
184*4bdc9457SAndroid Build Coastguard Worker 
185*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
186*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_abs_params params;
187*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
188*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
189*4bdc9457SAndroid Build Coastguard Worker       }
190*4bdc9457SAndroid Build Coastguard Worker 
191*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
192*4bdc9457SAndroid Build Coastguard Worker       vabs(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
193*4bdc9457SAndroid Build Coastguard Worker 
194*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
195*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
196*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
197*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
198*4bdc9457SAndroid Build Coastguard Worker       }
199*4bdc9457SAndroid Build Coastguard Worker     }
200*4bdc9457SAndroid Build Coastguard Worker   }
201*4bdc9457SAndroid Build Coastguard Worker 
202*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vabs_ukernel_function vabs, xnn_init_f32_abs_params_fn init_params = nullptr) const {
203*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
204*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
205*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
206*4bdc9457SAndroid Build Coastguard Worker 
207*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
208*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
209*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
210*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
211*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
212*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
213*4bdc9457SAndroid Build Coastguard Worker       } else {
214*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
215*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
216*4bdc9457SAndroid Build Coastguard Worker       }
217*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
218*4bdc9457SAndroid Build Coastguard Worker 
219*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
220*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
221*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::abs(x_data[i]);
222*4bdc9457SAndroid Build Coastguard Worker       }
223*4bdc9457SAndroid Build Coastguard Worker 
224*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
225*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_abs_params params;
226*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
227*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
228*4bdc9457SAndroid Build Coastguard Worker       }
229*4bdc9457SAndroid Build Coastguard Worker 
230*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
231*4bdc9457SAndroid Build Coastguard Worker       vabs(batch_size() * sizeof(float), x_data, y.data(), &params);
232*4bdc9457SAndroid Build Coastguard Worker 
233*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
234*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
235*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
236*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
237*4bdc9457SAndroid Build Coastguard Worker       }
238*4bdc9457SAndroid Build Coastguard Worker     }
239*4bdc9457SAndroid Build Coastguard Worker   }
240*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_vclamp_ukernel_function vclamp,xnn_init_f32_minmax_params_fn init_params)241*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vclamp_ukernel_function vclamp, xnn_init_f32_minmax_params_fn init_params) const {
242*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
243*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
244*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.0f, 255.0f);
245*4bdc9457SAndroid Build Coastguard Worker 
246*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
247*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
248*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
249*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
250*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
251*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
252*4bdc9457SAndroid Build Coastguard Worker       } else {
253*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
254*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
255*4bdc9457SAndroid Build Coastguard Worker       }
256*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
257*4bdc9457SAndroid Build Coastguard Worker 
258*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
259*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
260*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::max(std::min(x_data[i], float(qmax())), float(qmin()));
261*4bdc9457SAndroid Build Coastguard Worker       }
262*4bdc9457SAndroid Build Coastguard Worker 
263*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
264*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_minmax_params params;
265*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, float(qmin()), float(qmax()));
266*4bdc9457SAndroid Build Coastguard Worker 
267*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
268*4bdc9457SAndroid Build Coastguard Worker       vclamp(batch_size() * sizeof(float), x_data, y.data(), &params);
269*4bdc9457SAndroid Build Coastguard Worker 
270*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
271*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
272*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
273*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
274*4bdc9457SAndroid Build Coastguard Worker       }
275*4bdc9457SAndroid Build Coastguard Worker     }
276*4bdc9457SAndroid Build Coastguard Worker   }
277*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_velu_ukernel_function velu,xnn_init_f16_elu_params_fn init_params)278*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_velu_ukernel_function velu, xnn_init_f16_elu_params_fn init_params) const {
279*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
280*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
281*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-9.0f, 9.0f);
282*4bdc9457SAndroid Build Coastguard Worker 
283*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
284*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
285*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
286*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
287*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
288*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
289*4bdc9457SAndroid Build Coastguard Worker       } else {
290*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
291*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
292*4bdc9457SAndroid Build Coastguard Worker       }
293*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
294*4bdc9457SAndroid Build Coastguard Worker 
295*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
296*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
297*4bdc9457SAndroid Build Coastguard Worker         const float x_value = fp16_ieee_to_fp32_value(x_data[i]);
298*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::signbit(x_value) ? alpha() * std::expm1(x_value * prescale()) : x_value * beta();
299*4bdc9457SAndroid Build Coastguard Worker       }
300*4bdc9457SAndroid Build Coastguard Worker 
301*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
302*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_elu_params params;
303*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, fp16_ieee_from_fp32_value(prescale()), fp16_ieee_from_fp32_value(alpha()), fp16_ieee_from_fp32_value(beta()));
304*4bdc9457SAndroid Build Coastguard Worker 
305*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
306*4bdc9457SAndroid Build Coastguard Worker       velu(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
307*4bdc9457SAndroid Build Coastguard Worker 
308*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
309*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
310*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
311*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(y[i]),
312*4bdc9457SAndroid Build Coastguard Worker             y_ref[i],
313*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f))
314*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
315*4bdc9457SAndroid Build Coastguard Worker       }
316*4bdc9457SAndroid Build Coastguard Worker     }
317*4bdc9457SAndroid Build Coastguard Worker   }
318*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_velu_ukernel_function velu,xnn_init_f32_elu_params_fn init_params)319*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_velu_ukernel_function velu, xnn_init_f32_elu_params_fn init_params) const {
320*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
321*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
322*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-20.0f, 20.0f);
323*4bdc9457SAndroid Build Coastguard Worker 
324*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
325*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
326*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> y_ref(batch_size());
327*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
328*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
329*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
330*4bdc9457SAndroid Build Coastguard Worker       } else {
331*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
332*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
333*4bdc9457SAndroid Build Coastguard Worker       }
334*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
335*4bdc9457SAndroid Build Coastguard Worker 
336*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
337*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
338*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::signbit(x_data[i]) ? alpha() * std::expm1(double(x_data[i]) * prescale()) : double(x_data[i]) * beta();
339*4bdc9457SAndroid Build Coastguard Worker       }
340*4bdc9457SAndroid Build Coastguard Worker 
341*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
342*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_elu_params params;
343*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, prescale(), alpha(), beta());
344*4bdc9457SAndroid Build Coastguard Worker 
345*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
346*4bdc9457SAndroid Build Coastguard Worker       velu(batch_size() * sizeof(float), x_data, y.data(), &params);
347*4bdc9457SAndroid Build Coastguard Worker 
348*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
349*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
350*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5))
351*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
352*4bdc9457SAndroid Build Coastguard Worker       }
353*4bdc9457SAndroid Build Coastguard Worker     }
354*4bdc9457SAndroid Build Coastguard Worker   }
355*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_vhswish_ukernel_function vhswish,xnn_init_f16_hswish_params_fn init_params)356*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vhswish_ukernel_function vhswish, xnn_init_f16_hswish_params_fn init_params) const {
357*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
358*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
359*4bdc9457SAndroid Build Coastguard Worker     auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), std::ref(rng));
360*4bdc9457SAndroid Build Coastguard Worker     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
361*4bdc9457SAndroid Build Coastguard Worker 
362*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
363*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
364*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
365*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
366*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(f16rng));
367*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
368*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), std::ref(f16rng));
369*4bdc9457SAndroid Build Coastguard Worker       } else {
370*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
371*4bdc9457SAndroid Build Coastguard Worker       }
372*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
373*4bdc9457SAndroid Build Coastguard Worker 
374*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
375*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
376*4bdc9457SAndroid Build Coastguard Worker         const float x_value = fp16_ieee_to_fp32_value(x_data[i]);
377*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = (x_value / 6.0f) * std::max(std::min(x_value + 3.0f, 6.0f), 0.0f);
378*4bdc9457SAndroid Build Coastguard Worker       }
379*4bdc9457SAndroid Build Coastguard Worker 
380*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
381*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_hswish_params params;
382*4bdc9457SAndroid Build Coastguard Worker       init_params(&params);
383*4bdc9457SAndroid Build Coastguard Worker 
384*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
385*4bdc9457SAndroid Build Coastguard Worker       vhswish(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
386*4bdc9457SAndroid Build Coastguard Worker 
387*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
388*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
389*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(y_ref[i], fp16_ieee_to_fp32_value(y[i]), std::max(1.0e-3f, std::abs(y_ref[i]) * 1.0e-2f))
390*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
391*4bdc9457SAndroid Build Coastguard Worker       }
392*4bdc9457SAndroid Build Coastguard Worker     }
393*4bdc9457SAndroid Build Coastguard Worker   }
394*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_vhswish_ukernel_function vhswish,xnn_init_f32_hswish_params_fn init_params)395*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vhswish_ukernel_function vhswish, xnn_init_f32_hswish_params_fn init_params) const {
396*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
397*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
398*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-4.0f, 4.0f);
399*4bdc9457SAndroid Build Coastguard Worker 
400*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
401*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
402*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> y_ref(batch_size());
403*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
404*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
405*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
406*4bdc9457SAndroid Build Coastguard Worker       } else {
407*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
408*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
409*4bdc9457SAndroid Build Coastguard Worker       }
410*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
411*4bdc9457SAndroid Build Coastguard Worker 
412*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
413*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
414*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = (x_data[i] / 6.0f) * std::max(std::min(x_data[i] + 3.0f, 6.0f), 0.0f);
415*4bdc9457SAndroid Build Coastguard Worker       }
416*4bdc9457SAndroid Build Coastguard Worker 
417*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
418*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_hswish_params params;
419*4bdc9457SAndroid Build Coastguard Worker       init_params(&params);
420*4bdc9457SAndroid Build Coastguard Worker 
421*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
422*4bdc9457SAndroid Build Coastguard Worker       vhswish(batch_size() * sizeof(float), x_data, y.data(), &params);
423*4bdc9457SAndroid Build Coastguard Worker 
424*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
425*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
426*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5))
427*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
428*4bdc9457SAndroid Build Coastguard Worker       }
429*4bdc9457SAndroid Build Coastguard Worker     }
430*4bdc9457SAndroid Build Coastguard Worker   }
431*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_vlrelu_ukernel_function vlrelu,xnn_init_f16_lrelu_params_fn init_params)432*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vlrelu_ukernel_function vlrelu, xnn_init_f16_lrelu_params_fn init_params) const {
433*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
434*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
435*4bdc9457SAndroid Build Coastguard Worker     auto f32rng = std::bind(std::uniform_real_distribution<float>(-125.0f, 125.0f), std::ref(rng));
436*4bdc9457SAndroid Build Coastguard Worker     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
437*4bdc9457SAndroid Build Coastguard Worker 
438*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
439*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
440*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
441*4bdc9457SAndroid Build Coastguard Worker     const uint16_t slope_as_half = fp16_ieee_from_fp32_value(slope());
442*4bdc9457SAndroid Build Coastguard Worker     const float slope_as_float = fp16_ieee_to_fp32_value(slope_as_half);
443*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
444*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
445*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), std::ref(f16rng));
446*4bdc9457SAndroid Build Coastguard Worker       } else {
447*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), std::ref(f16rng));
448*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
449*4bdc9457SAndroid Build Coastguard Worker       }
450*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
451*4bdc9457SAndroid Build Coastguard Worker 
452*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
453*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
454*4bdc9457SAndroid Build Coastguard Worker         const float x_value = fp16_ieee_to_fp32_value(x_data[i]);
455*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::signbit(x_value) ? x_value * slope_as_float : x_value;
456*4bdc9457SAndroid Build Coastguard Worker       }
457*4bdc9457SAndroid Build Coastguard Worker 
458*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
459*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_lrelu_params params;
460*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, slope_as_half);
461*4bdc9457SAndroid Build Coastguard Worker 
462*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
463*4bdc9457SAndroid Build Coastguard Worker       vlrelu(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
464*4bdc9457SAndroid Build Coastguard Worker 
465*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
466*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
467*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
468*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(y[i]),
469*4bdc9457SAndroid Build Coastguard Worker             y_ref[i],
470*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-4f, std::abs(y_ref[i]) * 1.0e-3f))
471*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
472*4bdc9457SAndroid Build Coastguard Worker       }
473*4bdc9457SAndroid Build Coastguard Worker     }
474*4bdc9457SAndroid Build Coastguard Worker   }
475*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_vlrelu_ukernel_function vlrelu,xnn_init_f32_lrelu_params_fn init_params)476*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vlrelu_ukernel_function vlrelu, xnn_init_f32_lrelu_params_fn init_params) const {
477*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
478*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
479*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-125.0f, 125.0f);
480*4bdc9457SAndroid Build Coastguard Worker 
481*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
482*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
483*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> y_ref(batch_size());
484*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
485*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
486*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
487*4bdc9457SAndroid Build Coastguard Worker       } else {
488*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
489*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
490*4bdc9457SAndroid Build Coastguard Worker       }
491*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
492*4bdc9457SAndroid Build Coastguard Worker 
493*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
494*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
495*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::signbit(x_data[i]) ? x_data[i] * slope() : x_data[i];
496*4bdc9457SAndroid Build Coastguard Worker       }
497*4bdc9457SAndroid Build Coastguard Worker 
498*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
499*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_lrelu_params params;
500*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, slope());
501*4bdc9457SAndroid Build Coastguard Worker 
502*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
503*4bdc9457SAndroid Build Coastguard Worker       vlrelu(batch_size() * sizeof(float), x_data, y.data(), &params);
504*4bdc9457SAndroid Build Coastguard Worker 
505*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
506*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
507*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
508*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
509*4bdc9457SAndroid Build Coastguard Worker       }
510*4bdc9457SAndroid Build Coastguard Worker     }
511*4bdc9457SAndroid Build Coastguard Worker   }
512*4bdc9457SAndroid Build Coastguard Worker 
513*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vneg_ukernel_function vneg, xnn_init_f16_neg_params_fn init_params = nullptr) const {
514*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
515*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
516*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
517*4bdc9457SAndroid Build Coastguard Worker 
518*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
519*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
520*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y_ref(batch_size());
521*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
522*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
523*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
524*4bdc9457SAndroid Build Coastguard Worker       } else {
525*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
526*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
527*4bdc9457SAndroid Build Coastguard Worker       }
528*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
529*4bdc9457SAndroid Build Coastguard Worker 
530*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
531*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
532*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = x_data[i] ^ UINT16_C(0x8000);
533*4bdc9457SAndroid Build Coastguard Worker       }
534*4bdc9457SAndroid Build Coastguard Worker 
535*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
536*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_neg_params params;
537*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
538*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
539*4bdc9457SAndroid Build Coastguard Worker       }
540*4bdc9457SAndroid Build Coastguard Worker 
541*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
542*4bdc9457SAndroid Build Coastguard Worker       vneg(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
543*4bdc9457SAndroid Build Coastguard Worker 
544*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
545*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
546*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
547*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
548*4bdc9457SAndroid Build Coastguard Worker       }
549*4bdc9457SAndroid Build Coastguard Worker     }
550*4bdc9457SAndroid Build Coastguard Worker   }
551*4bdc9457SAndroid Build Coastguard Worker 
552*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vneg_ukernel_function vneg, xnn_init_f32_neg_params_fn init_params = nullptr) const {
553*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
554*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
555*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f);
556*4bdc9457SAndroid Build Coastguard Worker 
557*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
558*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
559*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
560*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
561*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
562*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
563*4bdc9457SAndroid Build Coastguard Worker       } else {
564*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
565*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
566*4bdc9457SAndroid Build Coastguard Worker       }
567*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
568*4bdc9457SAndroid Build Coastguard Worker 
569*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
570*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
571*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = -x_data[i];
572*4bdc9457SAndroid Build Coastguard Worker       }
573*4bdc9457SAndroid Build Coastguard Worker 
574*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
575*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_neg_params params;
576*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
577*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
578*4bdc9457SAndroid Build Coastguard Worker       }
579*4bdc9457SAndroid Build Coastguard Worker 
580*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
581*4bdc9457SAndroid Build Coastguard Worker       vneg(batch_size() * sizeof(float), x_data, y.data(), &params);
582*4bdc9457SAndroid Build Coastguard Worker 
583*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
584*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
585*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
586*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
587*4bdc9457SAndroid Build Coastguard Worker       }
588*4bdc9457SAndroid Build Coastguard Worker     }
589*4bdc9457SAndroid Build Coastguard Worker   }
590*4bdc9457SAndroid Build Coastguard Worker 
591*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vround_ukernel_function vrnd, OpType op_type, xnn_init_f16_rnd_params_fn init_params = nullptr) const {
592*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
593*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
594*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f);
595*4bdc9457SAndroid Build Coastguard Worker 
596*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
597*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
598*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y_ref(batch_size());
599*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
600*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
601*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
602*4bdc9457SAndroid Build Coastguard Worker       } else {
603*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
604*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
605*4bdc9457SAndroid Build Coastguard Worker       }
606*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
607*4bdc9457SAndroid Build Coastguard Worker 
608*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
609*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
610*4bdc9457SAndroid Build Coastguard Worker         switch (op_type) {
611*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundToNearestEven:
612*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = fp16_ieee_from_fp32_value(std::nearbyint(fp16_ieee_to_fp32_value(x_data[i])));
613*4bdc9457SAndroid Build Coastguard Worker             break;
614*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundTowardsZero:
615*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = fp16_ieee_from_fp32_value(std::trunc(fp16_ieee_to_fp32_value(x_data[i])));
616*4bdc9457SAndroid Build Coastguard Worker             break;
617*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundUp:
618*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = fp16_ieee_from_fp32_value(std::ceil(fp16_ieee_to_fp32_value(x_data[i])));
619*4bdc9457SAndroid Build Coastguard Worker             break;
620*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundDown:
621*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = fp16_ieee_from_fp32_value(std::floor(fp16_ieee_to_fp32_value(x_data[i])));
622*4bdc9457SAndroid Build Coastguard Worker             break;
623*4bdc9457SAndroid Build Coastguard Worker           default:
624*4bdc9457SAndroid Build Coastguard Worker             GTEST_FAIL() << "Unexpected operation type";
625*4bdc9457SAndroid Build Coastguard Worker             return;
626*4bdc9457SAndroid Build Coastguard Worker         }
627*4bdc9457SAndroid Build Coastguard Worker       }
628*4bdc9457SAndroid Build Coastguard Worker 
629*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
630*4bdc9457SAndroid Build Coastguard Worker       xnn_f16_rnd_params params;
631*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
632*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
633*4bdc9457SAndroid Build Coastguard Worker       }
634*4bdc9457SAndroid Build Coastguard Worker 
635*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
636*4bdc9457SAndroid Build Coastguard Worker       vrnd(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
637*4bdc9457SAndroid Build Coastguard Worker 
638*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
639*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
640*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
641*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
642*4bdc9457SAndroid Build Coastguard Worker       }
643*4bdc9457SAndroid Build Coastguard Worker     }
644*4bdc9457SAndroid Build Coastguard Worker   }
645*4bdc9457SAndroid Build Coastguard Worker 
646*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vround_ukernel_function vrnd, OpType op_type, xnn_init_f32_rnd_params_fn init_params = nullptr) const {
647*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
648*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
649*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f);
650*4bdc9457SAndroid Build Coastguard Worker 
651*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
652*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
653*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
654*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
655*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
656*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
657*4bdc9457SAndroid Build Coastguard Worker       } else {
658*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
659*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
660*4bdc9457SAndroid Build Coastguard Worker       }
661*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
662*4bdc9457SAndroid Build Coastguard Worker 
663*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
664*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
665*4bdc9457SAndroid Build Coastguard Worker         switch (op_type) {
666*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundToNearestEven:
667*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = std::nearbyint(x_data[i]);
668*4bdc9457SAndroid Build Coastguard Worker             break;
669*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundTowardsZero:
670*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = std::trunc(x_data[i]);
671*4bdc9457SAndroid Build Coastguard Worker             break;
672*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundUp:
673*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = std::ceil(x_data[i]);
674*4bdc9457SAndroid Build Coastguard Worker             break;
675*4bdc9457SAndroid Build Coastguard Worker           case OpType::RoundDown:
676*4bdc9457SAndroid Build Coastguard Worker             y_ref[i] = std::floor(x_data[i]);
677*4bdc9457SAndroid Build Coastguard Worker             break;
678*4bdc9457SAndroid Build Coastguard Worker           default:
679*4bdc9457SAndroid Build Coastguard Worker             GTEST_FAIL() << "Unexpected operation type";
680*4bdc9457SAndroid Build Coastguard Worker             return;
681*4bdc9457SAndroid Build Coastguard Worker         }
682*4bdc9457SAndroid Build Coastguard Worker       }
683*4bdc9457SAndroid Build Coastguard Worker 
684*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
685*4bdc9457SAndroid Build Coastguard Worker       xnn_f32_rnd_params params;
686*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
687*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
688*4bdc9457SAndroid Build Coastguard Worker       }
689*4bdc9457SAndroid Build Coastguard Worker 
690*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
691*4bdc9457SAndroid Build Coastguard Worker       vrnd(batch_size() * sizeof(float), x_data, y.data(), &params);
692*4bdc9457SAndroid Build Coastguard Worker 
693*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
694*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
695*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
696*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
697*4bdc9457SAndroid Build Coastguard Worker       }
698*4bdc9457SAndroid Build Coastguard Worker     }
699*4bdc9457SAndroid Build Coastguard Worker   }
700*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_vsigmoid_ukernel_function vsigmoid,xnn_init_f16_sigmoid_params_fn init_params)701*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vsigmoid_ukernel_function vsigmoid, xnn_init_f16_sigmoid_params_fn init_params) const {
702*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
703*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
704*4bdc9457SAndroid Build Coastguard Worker     auto distribution = std::uniform_real_distribution<float>(-25.0f, 25.0f);
705*4bdc9457SAndroid Build Coastguard Worker     auto f32rng = std::bind(distribution, std::ref(rng));
706*4bdc9457SAndroid Build Coastguard Worker     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
707*4bdc9457SAndroid Build Coastguard Worker 
708*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
709*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
710*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
711*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
712*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
713*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), std::ref(f16rng));
714*4bdc9457SAndroid Build Coastguard Worker       } else {
715*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), std::ref(f16rng));
716*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
717*4bdc9457SAndroid Build Coastguard Worker       }
718*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
719*4bdc9457SAndroid Build Coastguard Worker 
720*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
721*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
722*4bdc9457SAndroid Build Coastguard Worker         const float e = std::exp(fp16_ieee_to_fp32_value(x_data[i]));
723*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = e / (1.0f + e);
724*4bdc9457SAndroid Build Coastguard Worker       }
725*4bdc9457SAndroid Build Coastguard Worker 
726*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
727*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_sigmoid_params params;
728*4bdc9457SAndroid Build Coastguard Worker       init_params(&params);
729*4bdc9457SAndroid Build Coastguard Worker 
730*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
731*4bdc9457SAndroid Build Coastguard Worker       vsigmoid(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
732*4bdc9457SAndroid Build Coastguard Worker 
733*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
734*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
735*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
736*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(y[i]),
737*4bdc9457SAndroid Build Coastguard Worker             y_ref[i],
738*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f))
739*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
740*4bdc9457SAndroid Build Coastguard Worker       }
741*4bdc9457SAndroid Build Coastguard Worker     }
742*4bdc9457SAndroid Build Coastguard Worker   }
743*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f32_vsigmoid_ukernel_function vsigmoid,xnn_init_f32_sigmoid_params_fn init_params)744*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vsigmoid_ukernel_function vsigmoid, xnn_init_f32_sigmoid_params_fn init_params) const {
745*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
746*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
747*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-125.0f, 125.0f);
748*4bdc9457SAndroid Build Coastguard Worker 
749*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
750*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
751*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> y_ref(batch_size());
752*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
753*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
754*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
755*4bdc9457SAndroid Build Coastguard Worker       } else {
756*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
757*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
758*4bdc9457SAndroid Build Coastguard Worker       }
759*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
760*4bdc9457SAndroid Build Coastguard Worker 
761*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
762*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
763*4bdc9457SAndroid Build Coastguard Worker         const double e = std::exp(double(x_data[i]));
764*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = e / (1.0 + e);
765*4bdc9457SAndroid Build Coastguard Worker       }
766*4bdc9457SAndroid Build Coastguard Worker 
767*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
768*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_sigmoid_params params;
769*4bdc9457SAndroid Build Coastguard Worker       init_params(&params);
770*4bdc9457SAndroid Build Coastguard Worker 
771*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
772*4bdc9457SAndroid Build Coastguard Worker       vsigmoid(batch_size() * sizeof(float), x_data, y.data(), &params);
773*4bdc9457SAndroid Build Coastguard Worker 
774*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
775*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
776*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5))
777*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
778*4bdc9457SAndroid Build Coastguard Worker       }
779*4bdc9457SAndroid Build Coastguard Worker     }
780*4bdc9457SAndroid Build Coastguard Worker   }
781*4bdc9457SAndroid Build Coastguard Worker 
782*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vsqr_ukernel_function vsqr, xnn_init_f16_default_params_fn init_params = nullptr) const {
783*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
784*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
785*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-10.0f, 10.0f);
786*4bdc9457SAndroid Build Coastguard Worker 
787*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
788*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
789*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
790*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
791*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
792*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
793*4bdc9457SAndroid Build Coastguard Worker       } else {
794*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
795*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
796*4bdc9457SAndroid Build Coastguard Worker       }
797*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
798*4bdc9457SAndroid Build Coastguard Worker 
799*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
800*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
801*4bdc9457SAndroid Build Coastguard Worker         const float x_value = fp16_ieee_to_fp32_value(x_data[i]);
802*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = x_value * x_value;
803*4bdc9457SAndroid Build Coastguard Worker       }
804*4bdc9457SAndroid Build Coastguard Worker 
805*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
806*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_default_params params;
807*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
808*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
809*4bdc9457SAndroid Build Coastguard Worker       }
810*4bdc9457SAndroid Build Coastguard Worker 
811*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
812*4bdc9457SAndroid Build Coastguard Worker       vsqr(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
813*4bdc9457SAndroid Build Coastguard Worker 
814*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
815*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
816*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
817*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(y[i]),
818*4bdc9457SAndroid Build Coastguard Worker             y_ref[i],
819*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f))
820*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
821*4bdc9457SAndroid Build Coastguard Worker       }
822*4bdc9457SAndroid Build Coastguard Worker     }
823*4bdc9457SAndroid Build Coastguard Worker   }
824*4bdc9457SAndroid Build Coastguard Worker 
825*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vsqr_ukernel_function vsqr, xnn_init_f32_default_params_fn init_params = nullptr) const {
826*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
827*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
828*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(-10.0f, 10.0f);
829*4bdc9457SAndroid Build Coastguard Worker 
830*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
831*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
832*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
833*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
834*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
835*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
836*4bdc9457SAndroid Build Coastguard Worker       } else {
837*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
838*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
839*4bdc9457SAndroid Build Coastguard Worker       }
840*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
841*4bdc9457SAndroid Build Coastguard Worker 
842*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
843*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
844*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = x_data[i] * x_data[i];
845*4bdc9457SAndroid Build Coastguard Worker       }
846*4bdc9457SAndroid Build Coastguard Worker 
847*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
848*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_default_params params;
849*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
850*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
851*4bdc9457SAndroid Build Coastguard Worker       }
852*4bdc9457SAndroid Build Coastguard Worker 
853*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
854*4bdc9457SAndroid Build Coastguard Worker       vsqr(batch_size() * sizeof(float), x_data, y.data(), &params);
855*4bdc9457SAndroid Build Coastguard Worker 
856*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
857*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
858*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
859*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
860*4bdc9457SAndroid Build Coastguard Worker       }
861*4bdc9457SAndroid Build Coastguard Worker     }
862*4bdc9457SAndroid Build Coastguard Worker   }
863*4bdc9457SAndroid Build Coastguard Worker 
864*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vsqrt_ukernel_function vsqrt, xnn_init_f16_sqrt_params_fn init_params = nullptr) const {
865*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
866*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
867*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.0f, 10.0f);
868*4bdc9457SAndroid Build Coastguard Worker 
869*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
870*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
871*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
872*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
873*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
874*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
875*4bdc9457SAndroid Build Coastguard Worker       } else {
876*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
877*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
878*4bdc9457SAndroid Build Coastguard Worker       }
879*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
880*4bdc9457SAndroid Build Coastguard Worker 
881*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
882*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
883*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::sqrt(fp16_ieee_to_fp32_value(x_data[i]));
884*4bdc9457SAndroid Build Coastguard Worker       }
885*4bdc9457SAndroid Build Coastguard Worker 
886*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
887*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_sqrt_params params;
888*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
889*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
890*4bdc9457SAndroid Build Coastguard Worker       }
891*4bdc9457SAndroid Build Coastguard Worker 
892*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
893*4bdc9457SAndroid Build Coastguard Worker       vsqrt(batch_size() * sizeof(uint16_t), x_data, y.data(), init_params != nullptr ? &params : nullptr);
894*4bdc9457SAndroid Build Coastguard Worker 
895*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
896*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
897*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
898*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(y[i]),
899*4bdc9457SAndroid Build Coastguard Worker             y_ref[i],
900*4bdc9457SAndroid Build Coastguard Worker             std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f))
901*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
902*4bdc9457SAndroid Build Coastguard Worker       }
903*4bdc9457SAndroid Build Coastguard Worker     }
904*4bdc9457SAndroid Build Coastguard Worker   }
905*4bdc9457SAndroid Build Coastguard Worker 
906*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f32_vsqrt_ukernel_function vsqrt, xnn_init_f32_sqrt_params_fn init_params = nullptr) const {
907*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
908*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
909*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.0f, 10.0f);
910*4bdc9457SAndroid Build Coastguard Worker 
911*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float));
912*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0));
913*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
914*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
915*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
916*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); });
917*4bdc9457SAndroid Build Coastguard Worker       } else {
918*4bdc9457SAndroid Build Coastguard Worker         std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); });
919*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), nanf(""));
920*4bdc9457SAndroid Build Coastguard Worker       }
921*4bdc9457SAndroid Build Coastguard Worker       const float* x_data = inplace() ? y.data() : x.data();
922*4bdc9457SAndroid Build Coastguard Worker 
923*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
924*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
925*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::sqrt(x_data[i]);
926*4bdc9457SAndroid Build Coastguard Worker       }
927*4bdc9457SAndroid Build Coastguard Worker 
928*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
929*4bdc9457SAndroid Build Coastguard Worker       union xnn_f32_sqrt_params params;
930*4bdc9457SAndroid Build Coastguard Worker       if (init_params != nullptr) {
931*4bdc9457SAndroid Build Coastguard Worker         init_params(&params);
932*4bdc9457SAndroid Build Coastguard Worker       }
933*4bdc9457SAndroid Build Coastguard Worker 
934*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
935*4bdc9457SAndroid Build Coastguard Worker       vsqrt(batch_size() * sizeof(float), x_data, y.data(), init_params != nullptr ? &params : nullptr);
936*4bdc9457SAndroid Build Coastguard Worker 
937*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
938*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
939*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[i], y_ref[i])
940*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i];
941*4bdc9457SAndroid Build Coastguard Worker       }
942*4bdc9457SAndroid Build Coastguard Worker     }
943*4bdc9457SAndroid Build Coastguard Worker   }
944*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_f16_vclamp_ukernel_function vclamp,xnn_init_f16_minmax_params_fn init_params)945*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_f16_vclamp_ukernel_function vclamp, xnn_init_f16_minmax_params_fn init_params) const {
946*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
947*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
948*4bdc9457SAndroid Build Coastguard Worker     auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 255.0f), std::ref(rng));
949*4bdc9457SAndroid Build Coastguard Worker     auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
950*4bdc9457SAndroid Build Coastguard Worker 
951*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t));
952*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0));
953*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> y_ref(batch_size());
954*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
955*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(f16rng));
956*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
957*4bdc9457SAndroid Build Coastguard Worker         std::generate(y.begin(), y.end(), std::ref(f16rng));
958*4bdc9457SAndroid Build Coastguard Worker       } else {
959*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
960*4bdc9457SAndroid Build Coastguard Worker       }
961*4bdc9457SAndroid Build Coastguard Worker       const uint16_t* x_data = inplace() ? y.data() : x.data();
962*4bdc9457SAndroid Build Coastguard Worker 
963*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
964*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
965*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::max(std::min(fp16_ieee_to_fp32_value(x_data[i]), float(qmax())), float(qmin()));
966*4bdc9457SAndroid Build Coastguard Worker       }
967*4bdc9457SAndroid Build Coastguard Worker 
968*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
969*4bdc9457SAndroid Build Coastguard Worker       union xnn_f16_minmax_params params;
970*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, fp16_ieee_from_fp32_value(float(qmin())), fp16_ieee_from_fp32_value(float(qmax())));
971*4bdc9457SAndroid Build Coastguard Worker 
972*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
973*4bdc9457SAndroid Build Coastguard Worker       vclamp(batch_size() * sizeof(uint16_t), x_data, y.data(), &params);
974*4bdc9457SAndroid Build Coastguard Worker 
975*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
976*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
977*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(y_ref[i], fp16_ieee_to_fp32_value(y[i]), std::max(1.0e-3f, std::abs(y_ref[i]) * 1.0e-2f))
978*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]);
979*4bdc9457SAndroid Build Coastguard Worker       }
980*4bdc9457SAndroid Build Coastguard Worker     }
981*4bdc9457SAndroid Build Coastguard Worker   }
982*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_s8_vclamp_ukernel_function vclamp,xnn_init_s8_minmax_params_fn init_params)983*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_s8_vclamp_ukernel_function vclamp, xnn_init_s8_minmax_params_fn init_params) const {
984*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
985*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
986*4bdc9457SAndroid Build Coastguard Worker     auto i8rng = std::bind(
987*4bdc9457SAndroid Build Coastguard Worker       std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
988*4bdc9457SAndroid Build Coastguard Worker       std::ref(rng));
989*4bdc9457SAndroid Build Coastguard Worker 
990*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(int8_t));
991*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(int8_t) : 0));
992*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> y_ref(batch_size());
993*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
994*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(i8rng));
995*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
996*4bdc9457SAndroid Build Coastguard Worker         std::copy(x.cbegin(), x.cend(), y.begin());
997*4bdc9457SAndroid Build Coastguard Worker       } else {
998*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), INT8_C(0xA5));
999*4bdc9457SAndroid Build Coastguard Worker       }
1000*4bdc9457SAndroid Build Coastguard Worker       const int8_t* x_data = inplace() ? y.data() : x.data();
1001*4bdc9457SAndroid Build Coastguard Worker 
1002*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
1003*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1004*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::min(std::max(x_data[i], int8_t(qmin() - 0x80)), int8_t(qmax() - 0x80));
1005*4bdc9457SAndroid Build Coastguard Worker       }
1006*4bdc9457SAndroid Build Coastguard Worker 
1007*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
1008*4bdc9457SAndroid Build Coastguard Worker       union xnn_s8_minmax_params params;
1009*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
1010*4bdc9457SAndroid Build Coastguard Worker 
1011*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
1012*4bdc9457SAndroid Build Coastguard Worker       vclamp(batch_size() * sizeof(int8_t), x_data, y.data(), &params);
1013*4bdc9457SAndroid Build Coastguard Worker 
1014*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
1015*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1016*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(int32_t(y_ref[i]), int32_t(y[i]))
1017*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << int32_t(x[i]);
1018*4bdc9457SAndroid Build Coastguard Worker       }
1019*4bdc9457SAndroid Build Coastguard Worker     }
1020*4bdc9457SAndroid Build Coastguard Worker   }
1021*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_u8_vclamp_ukernel_function vclamp,xnn_init_u8_minmax_params_fn init_params)1022*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_u8_vclamp_ukernel_function vclamp, xnn_init_u8_minmax_params_fn init_params) const {
1023*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1024*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1025*4bdc9457SAndroid Build Coastguard Worker     auto u8rng = std::bind(
1026*4bdc9457SAndroid Build Coastguard Worker       std::uniform_int_distribution<int32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
1027*4bdc9457SAndroid Build Coastguard Worker 
1028*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint8_t));
1029*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint8_t) : 0));
1030*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> y_ref(batch_size());
1031*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1032*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(u8rng));
1033*4bdc9457SAndroid Build Coastguard Worker       if (inplace()) {
1034*4bdc9457SAndroid Build Coastguard Worker         std::copy(x.cbegin(), x.cend(), y.begin());
1035*4bdc9457SAndroid Build Coastguard Worker       } else {
1036*4bdc9457SAndroid Build Coastguard Worker         std::fill(y.begin(), y.end(), UINT8_C(0xA5));
1037*4bdc9457SAndroid Build Coastguard Worker       }
1038*4bdc9457SAndroid Build Coastguard Worker       const uint8_t* x_data = inplace() ? y.data() : x.data();
1039*4bdc9457SAndroid Build Coastguard Worker 
1040*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
1041*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1042*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = std::min(std::max(x_data[i], qmin()), qmax());
1043*4bdc9457SAndroid Build Coastguard Worker       }
1044*4bdc9457SAndroid Build Coastguard Worker 
1045*4bdc9457SAndroid Build Coastguard Worker       // Prepare parameters.
1046*4bdc9457SAndroid Build Coastguard Worker       union xnn_u8_minmax_params params;
1047*4bdc9457SAndroid Build Coastguard Worker       init_params(&params, qmin(), qmax());
1048*4bdc9457SAndroid Build Coastguard Worker 
1049*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
1050*4bdc9457SAndroid Build Coastguard Worker       vclamp(batch_size() * sizeof(uint8_t), x_data, y.data(), &params);
1051*4bdc9457SAndroid Build Coastguard Worker 
1052*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
1053*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1054*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(uint32_t(y_ref[i]), uint32_t(y[i]))
1055*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << uint32_t(x[i]);
1056*4bdc9457SAndroid Build Coastguard Worker       }
1057*4bdc9457SAndroid Build Coastguard Worker     }
1058*4bdc9457SAndroid Build Coastguard Worker   }
1059*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift)1060*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift) const {
1061*4bdc9457SAndroid Build Coastguard Worker     ASSERT_FALSE(inplace());
1062*4bdc9457SAndroid Build Coastguard Worker 
1063*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1064*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1065*4bdc9457SAndroid Build Coastguard Worker     auto u64rng = std::bind( std::uniform_int_distribution<uint64_t>(), std::ref(rng));
1066*4bdc9457SAndroid Build Coastguard Worker 
1067*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint64_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint64_t));
1068*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> y(batch_size());
1069*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> y_ref(batch_size());
1070*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1071*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(u64rng));
1072*4bdc9457SAndroid Build Coastguard Worker       std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF));
1073*4bdc9457SAndroid Build Coastguard Worker 
1074*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
1075*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1076*4bdc9457SAndroid Build Coastguard Worker         const uint64_t x_value = x[i];
1077*4bdc9457SAndroid Build Coastguard Worker         uint32_t y_value = 0;
1078*4bdc9457SAndroid Build Coastguard Worker         // Match TFLM semantics, including bugs
1079*4bdc9457SAndroid Build Coastguard Worker         if (uint32_t(x_value) == x_value) {
1080*4bdc9457SAndroid Build Coastguard Worker           y_value = (uint32_t) std::lrint(std::sqrt(double(int64_t(uint64_t(x_value)))));
1081*4bdc9457SAndroid Build Coastguard Worker           y_value = std::min<uint32_t>(y_value, std::numeric_limits<uint16_t>::max());
1082*4bdc9457SAndroid Build Coastguard Worker         } else if (x_value != 0) {
1083*4bdc9457SAndroid Build Coastguard Worker           uint64_t y0 = x_value >> 1;
1084*4bdc9457SAndroid Build Coastguard Worker           uint64_t y1 = (y0 + x_value / y0) >> 1;
1085*4bdc9457SAndroid Build Coastguard Worker           do {
1086*4bdc9457SAndroid Build Coastguard Worker             y0 = y1;
1087*4bdc9457SAndroid Build Coastguard Worker             y1 = (y0 + x_value / y0) >> 1;
1088*4bdc9457SAndroid Build Coastguard Worker           } while (y1 < y0);
1089*4bdc9457SAndroid Build Coastguard Worker 
1090*4bdc9457SAndroid Build Coastguard Worker           // y0 is sqrt(x_value) rounded down, round up if needed
1091*4bdc9457SAndroid Build Coastguard Worker           if (int64_t(y0 * y0 + y0 - x_value) < 0) {
1092*4bdc9457SAndroid Build Coastguard Worker             y0 += 1;
1093*4bdc9457SAndroid Build Coastguard Worker           }
1094*4bdc9457SAndroid Build Coastguard Worker           y_value = static_cast<uint32_t>(std::min<uint64_t>(y0, std::numeric_limits<uint32_t>::max()));
1095*4bdc9457SAndroid Build Coastguard Worker         }
1096*4bdc9457SAndroid Build Coastguard Worker         y_ref[i] = y_value >> shift();
1097*4bdc9457SAndroid Build Coastguard Worker       }
1098*4bdc9457SAndroid Build Coastguard Worker 
1099*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
1100*4bdc9457SAndroid Build Coastguard Worker       vsqrtshift(batch_size() * sizeof(uint64_t), x.data(), y.data(), shift());
1101*4bdc9457SAndroid Build Coastguard Worker 
1102*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
1103*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
1104*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y_ref[i], y[i])
1105*4bdc9457SAndroid Build Coastguard Worker           << "at " << i << " / " << batch_size()
1106*4bdc9457SAndroid Build Coastguard Worker           << ", x[" << i << "]: " << x[i]
1107*4bdc9457SAndroid Build Coastguard Worker           << ", shift: " << shift();
1108*4bdc9457SAndroid Build Coastguard Worker       }
1109*4bdc9457SAndroid Build Coastguard Worker     }
1110*4bdc9457SAndroid Build Coastguard Worker   }
1111*4bdc9457SAndroid Build Coastguard Worker 
1112*4bdc9457SAndroid Build Coastguard Worker  private:
1113*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_ = 1;
1114*4bdc9457SAndroid Build Coastguard Worker   bool inplace_ = false;
1115*4bdc9457SAndroid Build Coastguard Worker   float slope_ = 0.5f;
1116*4bdc9457SAndroid Build Coastguard Worker   float prescale_ = 1.0f;
1117*4bdc9457SAndroid Build Coastguard Worker   float alpha_ = 1.0f;
1118*4bdc9457SAndroid Build Coastguard Worker   float beta_ = 1.0f;
1119*4bdc9457SAndroid Build Coastguard Worker   uint32_t shift_ = 1;
1120*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_ = 0;
1121*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_ = 255;
1122*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_ = 15;
1123*4bdc9457SAndroid Build Coastguard Worker };
1124