1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <random> 15*4bdc9457SAndroid Build Coastguard Worker #include <vector> 16*4bdc9457SAndroid Build Coastguard Worker 17*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker class VUnaryMicrokernelTester { 25*4bdc9457SAndroid Build Coastguard Worker public: 26*4bdc9457SAndroid Build Coastguard Worker enum class OpType { 27*4bdc9457SAndroid Build Coastguard Worker ReLU, 28*4bdc9457SAndroid Build Coastguard Worker RoundToNearestEven, 29*4bdc9457SAndroid Build Coastguard Worker RoundTowardsZero, 30*4bdc9457SAndroid Build Coastguard Worker RoundUp, 31*4bdc9457SAndroid Build Coastguard Worker RoundDown, 32*4bdc9457SAndroid Build Coastguard Worker }; 33*4bdc9457SAndroid Build Coastguard Worker 34*4bdc9457SAndroid Build Coastguard Worker enum class Variant { 35*4bdc9457SAndroid Build Coastguard Worker Native, 36*4bdc9457SAndroid Build Coastguard Worker Scalar, 37*4bdc9457SAndroid Build Coastguard Worker }; 38*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)39*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& batch_size(size_t batch_size) { 40*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 41*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 42*4bdc9457SAndroid Build Coastguard Worker return *this; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker batch_size()45*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 46*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker inplace(bool inplace)49*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& inplace(bool inplace) { 50*4bdc9457SAndroid Build Coastguard Worker this->inplace_ = inplace; 51*4bdc9457SAndroid Build Coastguard Worker return *this; 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker inplace()54*4bdc9457SAndroid Build Coastguard Worker inline bool inplace() const { 55*4bdc9457SAndroid Build Coastguard Worker return this->inplace_; 56*4bdc9457SAndroid Build Coastguard Worker } 57*4bdc9457SAndroid Build Coastguard Worker slope(float slope)58*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& slope(float slope) { 59*4bdc9457SAndroid Build Coastguard Worker this->slope_ = slope; 60*4bdc9457SAndroid Build Coastguard Worker return *this; 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker slope()63*4bdc9457SAndroid Build Coastguard Worker inline float slope() const { 64*4bdc9457SAndroid Build Coastguard Worker return this->slope_; 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker prescale(float prescale)67*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& prescale(float prescale) { 68*4bdc9457SAndroid Build Coastguard Worker this->prescale_ = prescale; 69*4bdc9457SAndroid Build Coastguard Worker return *this; 70*4bdc9457SAndroid Build Coastguard Worker } 71*4bdc9457SAndroid Build Coastguard Worker prescale()72*4bdc9457SAndroid Build Coastguard Worker inline float prescale() const { 73*4bdc9457SAndroid Build Coastguard Worker return this->prescale_; 74*4bdc9457SAndroid Build Coastguard Worker } 75*4bdc9457SAndroid Build Coastguard Worker alpha(float alpha)76*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& alpha(float alpha) { 77*4bdc9457SAndroid Build Coastguard Worker this->alpha_ = alpha; 78*4bdc9457SAndroid Build Coastguard Worker return *this; 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker alpha()81*4bdc9457SAndroid Build Coastguard Worker inline float alpha() const { 82*4bdc9457SAndroid Build Coastguard Worker return this->alpha_; 83*4bdc9457SAndroid Build Coastguard Worker } 84*4bdc9457SAndroid Build Coastguard Worker beta(float beta)85*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& beta(float beta) { 86*4bdc9457SAndroid Build Coastguard Worker this->beta_ = beta; 87*4bdc9457SAndroid Build Coastguard Worker return *this; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker beta()90*4bdc9457SAndroid Build Coastguard Worker inline float beta() const { 91*4bdc9457SAndroid Build Coastguard Worker return this->beta_; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker shift(uint32_t shift)94*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& shift(uint32_t shift) { 95*4bdc9457SAndroid Build Coastguard Worker this->shift_ = shift; 96*4bdc9457SAndroid Build Coastguard Worker return *this; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker shift()99*4bdc9457SAndroid Build Coastguard Worker inline uint32_t shift() const { 100*4bdc9457SAndroid Build Coastguard Worker return this->shift_; 101*4bdc9457SAndroid Build Coastguard Worker } 102*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)103*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& qmin(uint8_t qmin) { 104*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 105*4bdc9457SAndroid Build Coastguard Worker return *this; 106*4bdc9457SAndroid Build Coastguard Worker } 107*4bdc9457SAndroid Build Coastguard Worker qmin()108*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 109*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 110*4bdc9457SAndroid Build Coastguard Worker } 111*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)112*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& qmax(uint8_t qmax) { 113*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 114*4bdc9457SAndroid Build Coastguard Worker return *this; 115*4bdc9457SAndroid Build Coastguard Worker } 116*4bdc9457SAndroid Build Coastguard Worker qmax()117*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 118*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 119*4bdc9457SAndroid Build Coastguard Worker } 120*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)121*4bdc9457SAndroid Build Coastguard Worker inline VUnaryMicrokernelTester& iterations(size_t iterations) { 122*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 123*4bdc9457SAndroid Build Coastguard Worker return *this; 124*4bdc9457SAndroid Build Coastguard Worker } 125*4bdc9457SAndroid Build Coastguard Worker iterations()126*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 127*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vrelu_ukernel_function vrelu)130*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vrelu_ukernel_function vrelu) const { 131*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 132*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 133*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 134*4bdc9457SAndroid Build Coastguard Worker 135*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 136*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 137*4bdc9457SAndroid Build Coastguard Worker std::vector<double> y_ref(batch_size()); 138*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 139*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 140*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 141*4bdc9457SAndroid Build Coastguard Worker } else { 142*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 143*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 144*4bdc9457SAndroid Build Coastguard Worker } 145*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 146*4bdc9457SAndroid Build Coastguard Worker 147*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 148*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 149*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max(x_data[i], 0.0f); 150*4bdc9457SAndroid Build Coastguard Worker } 151*4bdc9457SAndroid Build Coastguard Worker 152*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 153*4bdc9457SAndroid Build Coastguard Worker vrelu(batch_size() * sizeof(float), x_data, y.data(), nullptr); 154*4bdc9457SAndroid Build Coastguard Worker 155*4bdc9457SAndroid Build Coastguard Worker // Verify results. 156*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 157*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 158*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 159*4bdc9457SAndroid Build Coastguard Worker } 160*4bdc9457SAndroid Build Coastguard Worker } 161*4bdc9457SAndroid Build Coastguard Worker } 162*4bdc9457SAndroid Build Coastguard Worker 163*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vabs_ukernel_function vabs, xnn_init_f16_abs_params_fn init_params = nullptr) const { 164*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 165*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 166*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 167*4bdc9457SAndroid Build Coastguard Worker 168*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 169*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 170*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y_ref(batch_size()); 171*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 172*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 173*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 174*4bdc9457SAndroid Build Coastguard Worker } else { 175*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 176*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 177*4bdc9457SAndroid Build Coastguard Worker } 178*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 179*4bdc9457SAndroid Build Coastguard Worker 180*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 181*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 182*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = x_data[i] & UINT16_C(0x7FFF); 183*4bdc9457SAndroid Build Coastguard Worker } 184*4bdc9457SAndroid Build Coastguard Worker 185*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 186*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_abs_params params; 187*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 188*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 189*4bdc9457SAndroid Build Coastguard Worker } 190*4bdc9457SAndroid Build Coastguard Worker 191*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 192*4bdc9457SAndroid Build Coastguard Worker vabs(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 193*4bdc9457SAndroid Build Coastguard Worker 194*4bdc9457SAndroid Build Coastguard Worker // Verify results. 195*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 196*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 197*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 198*4bdc9457SAndroid Build Coastguard Worker } 199*4bdc9457SAndroid Build Coastguard Worker } 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker 202*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vabs_ukernel_function vabs, xnn_init_f32_abs_params_fn init_params = nullptr) const { 203*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 204*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 205*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 206*4bdc9457SAndroid Build Coastguard Worker 207*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 208*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 209*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 210*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 211*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 212*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 213*4bdc9457SAndroid Build Coastguard Worker } else { 214*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 215*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 216*4bdc9457SAndroid Build Coastguard Worker } 217*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 218*4bdc9457SAndroid Build Coastguard Worker 219*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 220*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 221*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::abs(x_data[i]); 222*4bdc9457SAndroid Build Coastguard Worker } 223*4bdc9457SAndroid Build Coastguard Worker 224*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 225*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_abs_params params; 226*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 227*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker 230*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 231*4bdc9457SAndroid Build Coastguard Worker vabs(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 232*4bdc9457SAndroid Build Coastguard Worker 233*4bdc9457SAndroid Build Coastguard Worker // Verify results. 234*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 235*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 236*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 237*4bdc9457SAndroid Build Coastguard Worker } 238*4bdc9457SAndroid Build Coastguard Worker } 239*4bdc9457SAndroid Build Coastguard Worker } 240*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vclamp_ukernel_function vclamp,xnn_init_f32_minmax_params_fn init_params)241*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vclamp_ukernel_function vclamp, xnn_init_f32_minmax_params_fn init_params) const { 242*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 243*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 244*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.0f, 255.0f); 245*4bdc9457SAndroid Build Coastguard Worker 246*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 247*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 248*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 249*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 250*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 251*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 252*4bdc9457SAndroid Build Coastguard Worker } else { 253*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 254*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 255*4bdc9457SAndroid Build Coastguard Worker } 256*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 257*4bdc9457SAndroid Build Coastguard Worker 258*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 259*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 260*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max(std::min(x_data[i], float(qmax())), float(qmin())); 261*4bdc9457SAndroid Build Coastguard Worker } 262*4bdc9457SAndroid Build Coastguard Worker 263*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 264*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_minmax_params params; 265*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, float(qmin()), float(qmax())); 266*4bdc9457SAndroid Build Coastguard Worker 267*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 268*4bdc9457SAndroid Build Coastguard Worker vclamp(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 269*4bdc9457SAndroid Build Coastguard Worker 270*4bdc9457SAndroid Build Coastguard Worker // Verify results. 271*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 272*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 273*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 274*4bdc9457SAndroid Build Coastguard Worker } 275*4bdc9457SAndroid Build Coastguard Worker } 276*4bdc9457SAndroid Build Coastguard Worker } 277*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_velu_ukernel_function velu,xnn_init_f16_elu_params_fn init_params)278*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_velu_ukernel_function velu, xnn_init_f16_elu_params_fn init_params) const { 279*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 280*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 281*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-9.0f, 9.0f); 282*4bdc9457SAndroid Build Coastguard Worker 283*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 284*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 285*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 286*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 287*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 288*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 289*4bdc9457SAndroid Build Coastguard Worker } else { 290*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 291*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 292*4bdc9457SAndroid Build Coastguard Worker } 293*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 294*4bdc9457SAndroid Build Coastguard Worker 295*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 296*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 297*4bdc9457SAndroid Build Coastguard Worker const float x_value = fp16_ieee_to_fp32_value(x_data[i]); 298*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::signbit(x_value) ? alpha() * std::expm1(x_value * prescale()) : x_value * beta(); 299*4bdc9457SAndroid Build Coastguard Worker } 300*4bdc9457SAndroid Build Coastguard Worker 301*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 302*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_elu_params params; 303*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, fp16_ieee_from_fp32_value(prescale()), fp16_ieee_from_fp32_value(alpha()), fp16_ieee_from_fp32_value(beta())); 304*4bdc9457SAndroid Build Coastguard Worker 305*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 306*4bdc9457SAndroid Build Coastguard Worker velu(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 307*4bdc9457SAndroid Build Coastguard Worker 308*4bdc9457SAndroid Build Coastguard Worker // Verify results. 309*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 310*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 311*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(y[i]), 312*4bdc9457SAndroid Build Coastguard Worker y_ref[i], 313*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f)) 314*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 315*4bdc9457SAndroid Build Coastguard Worker } 316*4bdc9457SAndroid Build Coastguard Worker } 317*4bdc9457SAndroid Build Coastguard Worker } 318*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_velu_ukernel_function velu,xnn_init_f32_elu_params_fn init_params)319*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_velu_ukernel_function velu, xnn_init_f32_elu_params_fn init_params) const { 320*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 321*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 322*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-20.0f, 20.0f); 323*4bdc9457SAndroid Build Coastguard Worker 324*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 325*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 326*4bdc9457SAndroid Build Coastguard Worker std::vector<double> y_ref(batch_size()); 327*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 328*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 329*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 330*4bdc9457SAndroid Build Coastguard Worker } else { 331*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 332*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 335*4bdc9457SAndroid Build Coastguard Worker 336*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 337*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 338*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::signbit(x_data[i]) ? alpha() * std::expm1(double(x_data[i]) * prescale()) : double(x_data[i]) * beta(); 339*4bdc9457SAndroid Build Coastguard Worker } 340*4bdc9457SAndroid Build Coastguard Worker 341*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 342*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_elu_params params; 343*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, prescale(), alpha(), beta()); 344*4bdc9457SAndroid Build Coastguard Worker 345*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 346*4bdc9457SAndroid Build Coastguard Worker velu(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 347*4bdc9457SAndroid Build Coastguard Worker 348*4bdc9457SAndroid Build Coastguard Worker // Verify results. 349*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 350*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5)) 351*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 352*4bdc9457SAndroid Build Coastguard Worker } 353*4bdc9457SAndroid Build Coastguard Worker } 354*4bdc9457SAndroid Build Coastguard Worker } 355*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vhswish_ukernel_function vhswish,xnn_init_f16_hswish_params_fn init_params)356*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vhswish_ukernel_function vhswish, xnn_init_f16_hswish_params_fn init_params) const { 357*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 358*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 359*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), std::ref(rng)); 360*4bdc9457SAndroid Build Coastguard Worker auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng); 361*4bdc9457SAndroid Build Coastguard Worker 362*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 363*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 364*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 365*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 366*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f16rng)); 367*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 368*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(f16rng)); 369*4bdc9457SAndroid Build Coastguard Worker } else { 370*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 371*4bdc9457SAndroid Build Coastguard Worker } 372*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 373*4bdc9457SAndroid Build Coastguard Worker 374*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 375*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 376*4bdc9457SAndroid Build Coastguard Worker const float x_value = fp16_ieee_to_fp32_value(x_data[i]); 377*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = (x_value / 6.0f) * std::max(std::min(x_value + 3.0f, 6.0f), 0.0f); 378*4bdc9457SAndroid Build Coastguard Worker } 379*4bdc9457SAndroid Build Coastguard Worker 380*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 381*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_hswish_params params; 382*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 383*4bdc9457SAndroid Build Coastguard Worker 384*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 385*4bdc9457SAndroid Build Coastguard Worker vhswish(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 386*4bdc9457SAndroid Build Coastguard Worker 387*4bdc9457SAndroid Build Coastguard Worker // Verify results. 388*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 389*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y_ref[i], fp16_ieee_to_fp32_value(y[i]), std::max(1.0e-3f, std::abs(y_ref[i]) * 1.0e-2f)) 390*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 391*4bdc9457SAndroid Build Coastguard Worker } 392*4bdc9457SAndroid Build Coastguard Worker } 393*4bdc9457SAndroid Build Coastguard Worker } 394*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vhswish_ukernel_function vhswish,xnn_init_f32_hswish_params_fn init_params)395*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vhswish_ukernel_function vhswish, xnn_init_f32_hswish_params_fn init_params) const { 396*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 397*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 398*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-4.0f, 4.0f); 399*4bdc9457SAndroid Build Coastguard Worker 400*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 401*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 402*4bdc9457SAndroid Build Coastguard Worker std::vector<double> y_ref(batch_size()); 403*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 404*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 405*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 406*4bdc9457SAndroid Build Coastguard Worker } else { 407*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 408*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 409*4bdc9457SAndroid Build Coastguard Worker } 410*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 411*4bdc9457SAndroid Build Coastguard Worker 412*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 413*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 414*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = (x_data[i] / 6.0f) * std::max(std::min(x_data[i] + 3.0f, 6.0f), 0.0f); 415*4bdc9457SAndroid Build Coastguard Worker } 416*4bdc9457SAndroid Build Coastguard Worker 417*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 418*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_hswish_params params; 419*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 420*4bdc9457SAndroid Build Coastguard Worker 421*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 422*4bdc9457SAndroid Build Coastguard Worker vhswish(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 423*4bdc9457SAndroid Build Coastguard Worker 424*4bdc9457SAndroid Build Coastguard Worker // Verify results. 425*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 426*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5)) 427*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 428*4bdc9457SAndroid Build Coastguard Worker } 429*4bdc9457SAndroid Build Coastguard Worker } 430*4bdc9457SAndroid Build Coastguard Worker } 431*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vlrelu_ukernel_function vlrelu,xnn_init_f16_lrelu_params_fn init_params)432*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vlrelu_ukernel_function vlrelu, xnn_init_f16_lrelu_params_fn init_params) const { 433*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 434*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 435*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(-125.0f, 125.0f), std::ref(rng)); 436*4bdc9457SAndroid Build Coastguard Worker auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng); 437*4bdc9457SAndroid Build Coastguard Worker 438*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 439*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 440*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 441*4bdc9457SAndroid Build Coastguard Worker const uint16_t slope_as_half = fp16_ieee_from_fp32_value(slope()); 442*4bdc9457SAndroid Build Coastguard Worker const float slope_as_float = fp16_ieee_to_fp32_value(slope_as_half); 443*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 444*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 445*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(f16rng)); 446*4bdc9457SAndroid Build Coastguard Worker } else { 447*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f16rng)); 448*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 449*4bdc9457SAndroid Build Coastguard Worker } 450*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 451*4bdc9457SAndroid Build Coastguard Worker 452*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 453*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 454*4bdc9457SAndroid Build Coastguard Worker const float x_value = fp16_ieee_to_fp32_value(x_data[i]); 455*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::signbit(x_value) ? x_value * slope_as_float : x_value; 456*4bdc9457SAndroid Build Coastguard Worker } 457*4bdc9457SAndroid Build Coastguard Worker 458*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 459*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_lrelu_params params; 460*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, slope_as_half); 461*4bdc9457SAndroid Build Coastguard Worker 462*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 463*4bdc9457SAndroid Build Coastguard Worker vlrelu(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 464*4bdc9457SAndroid Build Coastguard Worker 465*4bdc9457SAndroid Build Coastguard Worker // Verify results. 466*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 467*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 468*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(y[i]), 469*4bdc9457SAndroid Build Coastguard Worker y_ref[i], 470*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(y_ref[i]) * 1.0e-3f)) 471*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 472*4bdc9457SAndroid Build Coastguard Worker } 473*4bdc9457SAndroid Build Coastguard Worker } 474*4bdc9457SAndroid Build Coastguard Worker } 475*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vlrelu_ukernel_function vlrelu,xnn_init_f32_lrelu_params_fn init_params)476*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vlrelu_ukernel_function vlrelu, xnn_init_f32_lrelu_params_fn init_params) const { 477*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 478*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 479*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-125.0f, 125.0f); 480*4bdc9457SAndroid Build Coastguard Worker 481*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 482*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 483*4bdc9457SAndroid Build Coastguard Worker std::vector<double> y_ref(batch_size()); 484*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 485*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 486*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 487*4bdc9457SAndroid Build Coastguard Worker } else { 488*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 489*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 490*4bdc9457SAndroid Build Coastguard Worker } 491*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 492*4bdc9457SAndroid Build Coastguard Worker 493*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 494*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 495*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::signbit(x_data[i]) ? x_data[i] * slope() : x_data[i]; 496*4bdc9457SAndroid Build Coastguard Worker } 497*4bdc9457SAndroid Build Coastguard Worker 498*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 499*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_lrelu_params params; 500*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, slope()); 501*4bdc9457SAndroid Build Coastguard Worker 502*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 503*4bdc9457SAndroid Build Coastguard Worker vlrelu(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 504*4bdc9457SAndroid Build Coastguard Worker 505*4bdc9457SAndroid Build Coastguard Worker // Verify results. 506*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 507*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 508*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 509*4bdc9457SAndroid Build Coastguard Worker } 510*4bdc9457SAndroid Build Coastguard Worker } 511*4bdc9457SAndroid Build Coastguard Worker } 512*4bdc9457SAndroid Build Coastguard Worker 513*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vneg_ukernel_function vneg, xnn_init_f16_neg_params_fn init_params = nullptr) const { 514*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 515*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 516*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 517*4bdc9457SAndroid Build Coastguard Worker 518*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 519*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 520*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y_ref(batch_size()); 521*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 522*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 523*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 524*4bdc9457SAndroid Build Coastguard Worker } else { 525*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 526*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 527*4bdc9457SAndroid Build Coastguard Worker } 528*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 529*4bdc9457SAndroid Build Coastguard Worker 530*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 531*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 532*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = x_data[i] ^ UINT16_C(0x8000); 533*4bdc9457SAndroid Build Coastguard Worker } 534*4bdc9457SAndroid Build Coastguard Worker 535*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 536*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_neg_params params; 537*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 538*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 539*4bdc9457SAndroid Build Coastguard Worker } 540*4bdc9457SAndroid Build Coastguard Worker 541*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 542*4bdc9457SAndroid Build Coastguard Worker vneg(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 543*4bdc9457SAndroid Build Coastguard Worker 544*4bdc9457SAndroid Build Coastguard Worker // Verify results. 545*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 546*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 547*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 548*4bdc9457SAndroid Build Coastguard Worker } 549*4bdc9457SAndroid Build Coastguard Worker } 550*4bdc9457SAndroid Build Coastguard Worker } 551*4bdc9457SAndroid Build Coastguard Worker 552*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vneg_ukernel_function vneg, xnn_init_f32_neg_params_fn init_params = nullptr) const { 553*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 554*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 555*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 556*4bdc9457SAndroid Build Coastguard Worker 557*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 558*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 559*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 560*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 561*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 562*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 563*4bdc9457SAndroid Build Coastguard Worker } else { 564*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 565*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 566*4bdc9457SAndroid Build Coastguard Worker } 567*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 568*4bdc9457SAndroid Build Coastguard Worker 569*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 570*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 571*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = -x_data[i]; 572*4bdc9457SAndroid Build Coastguard Worker } 573*4bdc9457SAndroid Build Coastguard Worker 574*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 575*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_neg_params params; 576*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 577*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 578*4bdc9457SAndroid Build Coastguard Worker } 579*4bdc9457SAndroid Build Coastguard Worker 580*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 581*4bdc9457SAndroid Build Coastguard Worker vneg(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 582*4bdc9457SAndroid Build Coastguard Worker 583*4bdc9457SAndroid Build Coastguard Worker // Verify results. 584*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 585*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 586*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 587*4bdc9457SAndroid Build Coastguard Worker } 588*4bdc9457SAndroid Build Coastguard Worker } 589*4bdc9457SAndroid Build Coastguard Worker } 590*4bdc9457SAndroid Build Coastguard Worker 591*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vround_ukernel_function vrnd, OpType op_type, xnn_init_f16_rnd_params_fn init_params = nullptr) const { 592*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 593*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 594*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f); 595*4bdc9457SAndroid Build Coastguard Worker 596*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 597*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 598*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y_ref(batch_size()); 599*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 600*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 601*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 602*4bdc9457SAndroid Build Coastguard Worker } else { 603*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 604*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 605*4bdc9457SAndroid Build Coastguard Worker } 606*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 607*4bdc9457SAndroid Build Coastguard Worker 608*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 609*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 610*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 611*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundToNearestEven: 612*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_from_fp32_value(std::nearbyint(fp16_ieee_to_fp32_value(x_data[i]))); 613*4bdc9457SAndroid Build Coastguard Worker break; 614*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundTowardsZero: 615*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_from_fp32_value(std::trunc(fp16_ieee_to_fp32_value(x_data[i]))); 616*4bdc9457SAndroid Build Coastguard Worker break; 617*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundUp: 618*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_from_fp32_value(std::ceil(fp16_ieee_to_fp32_value(x_data[i]))); 619*4bdc9457SAndroid Build Coastguard Worker break; 620*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundDown: 621*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_from_fp32_value(std::floor(fp16_ieee_to_fp32_value(x_data[i]))); 622*4bdc9457SAndroid Build Coastguard Worker break; 623*4bdc9457SAndroid Build Coastguard Worker default: 624*4bdc9457SAndroid Build Coastguard Worker GTEST_FAIL() << "Unexpected operation type"; 625*4bdc9457SAndroid Build Coastguard Worker return; 626*4bdc9457SAndroid Build Coastguard Worker } 627*4bdc9457SAndroid Build Coastguard Worker } 628*4bdc9457SAndroid Build Coastguard Worker 629*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 630*4bdc9457SAndroid Build Coastguard Worker xnn_f16_rnd_params params; 631*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 632*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 633*4bdc9457SAndroid Build Coastguard Worker } 634*4bdc9457SAndroid Build Coastguard Worker 635*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 636*4bdc9457SAndroid Build Coastguard Worker vrnd(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 637*4bdc9457SAndroid Build Coastguard Worker 638*4bdc9457SAndroid Build Coastguard Worker // Verify results. 639*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 640*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 641*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 642*4bdc9457SAndroid Build Coastguard Worker } 643*4bdc9457SAndroid Build Coastguard Worker } 644*4bdc9457SAndroid Build Coastguard Worker } 645*4bdc9457SAndroid Build Coastguard Worker 646*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vround_ukernel_function vrnd, OpType op_type, xnn_init_f32_rnd_params_fn init_params = nullptr) const { 647*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 648*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 649*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-5.0f, 5.0f); 650*4bdc9457SAndroid Build Coastguard Worker 651*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 652*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 653*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 654*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 655*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 656*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 657*4bdc9457SAndroid Build Coastguard Worker } else { 658*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 659*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 660*4bdc9457SAndroid Build Coastguard Worker } 661*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 662*4bdc9457SAndroid Build Coastguard Worker 663*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 664*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 665*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 666*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundToNearestEven: 667*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::nearbyint(x_data[i]); 668*4bdc9457SAndroid Build Coastguard Worker break; 669*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundTowardsZero: 670*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::trunc(x_data[i]); 671*4bdc9457SAndroid Build Coastguard Worker break; 672*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundUp: 673*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::ceil(x_data[i]); 674*4bdc9457SAndroid Build Coastguard Worker break; 675*4bdc9457SAndroid Build Coastguard Worker case OpType::RoundDown: 676*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::floor(x_data[i]); 677*4bdc9457SAndroid Build Coastguard Worker break; 678*4bdc9457SAndroid Build Coastguard Worker default: 679*4bdc9457SAndroid Build Coastguard Worker GTEST_FAIL() << "Unexpected operation type"; 680*4bdc9457SAndroid Build Coastguard Worker return; 681*4bdc9457SAndroid Build Coastguard Worker } 682*4bdc9457SAndroid Build Coastguard Worker } 683*4bdc9457SAndroid Build Coastguard Worker 684*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 685*4bdc9457SAndroid Build Coastguard Worker xnn_f32_rnd_params params; 686*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 687*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 688*4bdc9457SAndroid Build Coastguard Worker } 689*4bdc9457SAndroid Build Coastguard Worker 690*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 691*4bdc9457SAndroid Build Coastguard Worker vrnd(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 692*4bdc9457SAndroid Build Coastguard Worker 693*4bdc9457SAndroid Build Coastguard Worker // Verify results. 694*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 695*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 696*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 697*4bdc9457SAndroid Build Coastguard Worker } 698*4bdc9457SAndroid Build Coastguard Worker } 699*4bdc9457SAndroid Build Coastguard Worker } 700*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vsigmoid_ukernel_function vsigmoid,xnn_init_f16_sigmoid_params_fn init_params)701*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vsigmoid_ukernel_function vsigmoid, xnn_init_f16_sigmoid_params_fn init_params) const { 702*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 703*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 704*4bdc9457SAndroid Build Coastguard Worker auto distribution = std::uniform_real_distribution<float>(-25.0f, 25.0f); 705*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(distribution, std::ref(rng)); 706*4bdc9457SAndroid Build Coastguard Worker auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng); 707*4bdc9457SAndroid Build Coastguard Worker 708*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 709*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 710*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 711*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 712*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 713*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(f16rng)); 714*4bdc9457SAndroid Build Coastguard Worker } else { 715*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f16rng)); 716*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 717*4bdc9457SAndroid Build Coastguard Worker } 718*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 719*4bdc9457SAndroid Build Coastguard Worker 720*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 721*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 722*4bdc9457SAndroid Build Coastguard Worker const float e = std::exp(fp16_ieee_to_fp32_value(x_data[i])); 723*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = e / (1.0f + e); 724*4bdc9457SAndroid Build Coastguard Worker } 725*4bdc9457SAndroid Build Coastguard Worker 726*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 727*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_sigmoid_params params; 728*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 729*4bdc9457SAndroid Build Coastguard Worker 730*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 731*4bdc9457SAndroid Build Coastguard Worker vsigmoid(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 732*4bdc9457SAndroid Build Coastguard Worker 733*4bdc9457SAndroid Build Coastguard Worker // Verify results. 734*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 735*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 736*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(y[i]), 737*4bdc9457SAndroid Build Coastguard Worker y_ref[i], 738*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f)) 739*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 740*4bdc9457SAndroid Build Coastguard Worker } 741*4bdc9457SAndroid Build Coastguard Worker } 742*4bdc9457SAndroid Build Coastguard Worker } 743*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vsigmoid_ukernel_function vsigmoid,xnn_init_f32_sigmoid_params_fn init_params)744*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vsigmoid_ukernel_function vsigmoid, xnn_init_f32_sigmoid_params_fn init_params) const { 745*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 746*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 747*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-125.0f, 125.0f); 748*4bdc9457SAndroid Build Coastguard Worker 749*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 750*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 751*4bdc9457SAndroid Build Coastguard Worker std::vector<double> y_ref(batch_size()); 752*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 753*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 754*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 755*4bdc9457SAndroid Build Coastguard Worker } else { 756*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 757*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 758*4bdc9457SAndroid Build Coastguard Worker } 759*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 760*4bdc9457SAndroid Build Coastguard Worker 761*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 762*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 763*4bdc9457SAndroid Build Coastguard Worker const double e = std::exp(double(x_data[i])); 764*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = e / (1.0 + e); 765*4bdc9457SAndroid Build Coastguard Worker } 766*4bdc9457SAndroid Build Coastguard Worker 767*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 768*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_sigmoid_params params; 769*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 770*4bdc9457SAndroid Build Coastguard Worker 771*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 772*4bdc9457SAndroid Build Coastguard Worker vsigmoid(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 773*4bdc9457SAndroid Build Coastguard Worker 774*4bdc9457SAndroid Build Coastguard Worker // Verify results. 775*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 776*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::max(5.0e-6, std::abs(y_ref[i]) * 1.0e-5)) 777*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 778*4bdc9457SAndroid Build Coastguard Worker } 779*4bdc9457SAndroid Build Coastguard Worker } 780*4bdc9457SAndroid Build Coastguard Worker } 781*4bdc9457SAndroid Build Coastguard Worker 782*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vsqr_ukernel_function vsqr, xnn_init_f16_default_params_fn init_params = nullptr) const { 783*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 784*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 785*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-10.0f, 10.0f); 786*4bdc9457SAndroid Build Coastguard Worker 787*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 788*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 789*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 790*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 791*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 792*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 793*4bdc9457SAndroid Build Coastguard Worker } else { 794*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 795*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 796*4bdc9457SAndroid Build Coastguard Worker } 797*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 798*4bdc9457SAndroid Build Coastguard Worker 799*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 800*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 801*4bdc9457SAndroid Build Coastguard Worker const float x_value = fp16_ieee_to_fp32_value(x_data[i]); 802*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = x_value * x_value; 803*4bdc9457SAndroid Build Coastguard Worker } 804*4bdc9457SAndroid Build Coastguard Worker 805*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 806*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_default_params params; 807*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 808*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 809*4bdc9457SAndroid Build Coastguard Worker } 810*4bdc9457SAndroid Build Coastguard Worker 811*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 812*4bdc9457SAndroid Build Coastguard Worker vsqr(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 813*4bdc9457SAndroid Build Coastguard Worker 814*4bdc9457SAndroid Build Coastguard Worker // Verify results. 815*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 816*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 817*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(y[i]), 818*4bdc9457SAndroid Build Coastguard Worker y_ref[i], 819*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f)) 820*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 821*4bdc9457SAndroid Build Coastguard Worker } 822*4bdc9457SAndroid Build Coastguard Worker } 823*4bdc9457SAndroid Build Coastguard Worker } 824*4bdc9457SAndroid Build Coastguard Worker 825*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vsqr_ukernel_function vsqr, xnn_init_f32_default_params_fn init_params = nullptr) const { 826*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 827*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 828*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-10.0f, 10.0f); 829*4bdc9457SAndroid Build Coastguard Worker 830*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 831*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 832*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 833*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 834*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 835*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 836*4bdc9457SAndroid Build Coastguard Worker } else { 837*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 838*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 839*4bdc9457SAndroid Build Coastguard Worker } 840*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 841*4bdc9457SAndroid Build Coastguard Worker 842*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 843*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 844*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = x_data[i] * x_data[i]; 845*4bdc9457SAndroid Build Coastguard Worker } 846*4bdc9457SAndroid Build Coastguard Worker 847*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 848*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_default_params params; 849*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 850*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 851*4bdc9457SAndroid Build Coastguard Worker } 852*4bdc9457SAndroid Build Coastguard Worker 853*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 854*4bdc9457SAndroid Build Coastguard Worker vsqr(batch_size() * sizeof(float), x_data, y.data(), ¶ms); 855*4bdc9457SAndroid Build Coastguard Worker 856*4bdc9457SAndroid Build Coastguard Worker // Verify results. 857*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 858*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 859*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 860*4bdc9457SAndroid Build Coastguard Worker } 861*4bdc9457SAndroid Build Coastguard Worker } 862*4bdc9457SAndroid Build Coastguard Worker } 863*4bdc9457SAndroid Build Coastguard Worker 864*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vsqrt_ukernel_function vsqrt, xnn_init_f16_sqrt_params_fn init_params = nullptr) const { 865*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 866*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 867*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.0f, 10.0f); 868*4bdc9457SAndroid Build Coastguard Worker 869*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 870*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 871*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 872*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 873*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 874*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 875*4bdc9457SAndroid Build Coastguard Worker } else { 876*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 877*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 878*4bdc9457SAndroid Build Coastguard Worker } 879*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 880*4bdc9457SAndroid Build Coastguard Worker 881*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 882*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 883*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::sqrt(fp16_ieee_to_fp32_value(x_data[i])); 884*4bdc9457SAndroid Build Coastguard Worker } 885*4bdc9457SAndroid Build Coastguard Worker 886*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 887*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_sqrt_params params; 888*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 889*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 890*4bdc9457SAndroid Build Coastguard Worker } 891*4bdc9457SAndroid Build Coastguard Worker 892*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 893*4bdc9457SAndroid Build Coastguard Worker vsqrt(batch_size() * sizeof(uint16_t), x_data, y.data(), init_params != nullptr ? ¶ms : nullptr); 894*4bdc9457SAndroid Build Coastguard Worker 895*4bdc9457SAndroid Build Coastguard Worker // Verify results. 896*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 897*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 898*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(y[i]), 899*4bdc9457SAndroid Build Coastguard Worker y_ref[i], 900*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(y_ref[i]) * 5.0e-3f)) 901*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 902*4bdc9457SAndroid Build Coastguard Worker } 903*4bdc9457SAndroid Build Coastguard Worker } 904*4bdc9457SAndroid Build Coastguard Worker } 905*4bdc9457SAndroid Build Coastguard Worker 906*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vsqrt_ukernel_function vsqrt, xnn_init_f32_sqrt_params_fn init_params = nullptr) const { 907*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 908*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 909*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.0f, 10.0f); 910*4bdc9457SAndroid Build Coastguard Worker 911*4bdc9457SAndroid Build Coastguard Worker std::vector<float> x(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 912*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 913*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 914*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 915*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 916*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 917*4bdc9457SAndroid Build Coastguard Worker } else { 918*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), [&]() { return f32dist(rng); }); 919*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 920*4bdc9457SAndroid Build Coastguard Worker } 921*4bdc9457SAndroid Build Coastguard Worker const float* x_data = inplace() ? y.data() : x.data(); 922*4bdc9457SAndroid Build Coastguard Worker 923*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 924*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 925*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::sqrt(x_data[i]); 926*4bdc9457SAndroid Build Coastguard Worker } 927*4bdc9457SAndroid Build Coastguard Worker 928*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 929*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_sqrt_params params; 930*4bdc9457SAndroid Build Coastguard Worker if (init_params != nullptr) { 931*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 932*4bdc9457SAndroid Build Coastguard Worker } 933*4bdc9457SAndroid Build Coastguard Worker 934*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 935*4bdc9457SAndroid Build Coastguard Worker vsqrt(batch_size() * sizeof(float), x_data, y.data(), init_params != nullptr ? ¶ms : nullptr); 936*4bdc9457SAndroid Build Coastguard Worker 937*4bdc9457SAndroid Build Coastguard Worker // Verify results. 938*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 939*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[i], y_ref[i]) 940*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << x[i]; 941*4bdc9457SAndroid Build Coastguard Worker } 942*4bdc9457SAndroid Build Coastguard Worker } 943*4bdc9457SAndroid Build Coastguard Worker } 944*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vclamp_ukernel_function vclamp,xnn_init_f16_minmax_params_fn init_params)945*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vclamp_ukernel_function vclamp, xnn_init_f16_minmax_params_fn init_params) const { 946*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 947*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 948*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(0.0f, 255.0f), std::ref(rng)); 949*4bdc9457SAndroid Build Coastguard Worker auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng); 950*4bdc9457SAndroid Build Coastguard Worker 951*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 952*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 953*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 954*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 955*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(f16rng)); 956*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 957*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(f16rng)); 958*4bdc9457SAndroid Build Coastguard Worker } else { 959*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 960*4bdc9457SAndroid Build Coastguard Worker } 961*4bdc9457SAndroid Build Coastguard Worker const uint16_t* x_data = inplace() ? y.data() : x.data(); 962*4bdc9457SAndroid Build Coastguard Worker 963*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 964*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 965*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max(std::min(fp16_ieee_to_fp32_value(x_data[i]), float(qmax())), float(qmin())); 966*4bdc9457SAndroid Build Coastguard Worker } 967*4bdc9457SAndroid Build Coastguard Worker 968*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 969*4bdc9457SAndroid Build Coastguard Worker union xnn_f16_minmax_params params; 970*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, fp16_ieee_from_fp32_value(float(qmin())), fp16_ieee_from_fp32_value(float(qmax()))); 971*4bdc9457SAndroid Build Coastguard Worker 972*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 973*4bdc9457SAndroid Build Coastguard Worker vclamp(batch_size() * sizeof(uint16_t), x_data, y.data(), ¶ms); 974*4bdc9457SAndroid Build Coastguard Worker 975*4bdc9457SAndroid Build Coastguard Worker // Verify results. 976*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 977*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y_ref[i], fp16_ieee_to_fp32_value(y[i]), std::max(1.0e-3f, std::abs(y_ref[i]) * 1.0e-2f)) 978*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << fp16_ieee_to_fp32_value(x[i]); 979*4bdc9457SAndroid Build Coastguard Worker } 980*4bdc9457SAndroid Build Coastguard Worker } 981*4bdc9457SAndroid Build Coastguard Worker } 982*4bdc9457SAndroid Build Coastguard Worker Test(xnn_s8_vclamp_ukernel_function vclamp,xnn_init_s8_minmax_params_fn init_params)983*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_s8_vclamp_ukernel_function vclamp, xnn_init_s8_minmax_params_fn init_params) const { 984*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 985*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 986*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind( 987*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()), 988*4bdc9457SAndroid Build Coastguard Worker std::ref(rng)); 989*4bdc9457SAndroid Build Coastguard Worker 990*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(int8_t)); 991*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(int8_t) : 0)); 992*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> y_ref(batch_size()); 993*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 994*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(i8rng)); 995*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 996*4bdc9457SAndroid Build Coastguard Worker std::copy(x.cbegin(), x.cend(), y.begin()); 997*4bdc9457SAndroid Build Coastguard Worker } else { 998*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), INT8_C(0xA5)); 999*4bdc9457SAndroid Build Coastguard Worker } 1000*4bdc9457SAndroid Build Coastguard Worker const int8_t* x_data = inplace() ? y.data() : x.data(); 1001*4bdc9457SAndroid Build Coastguard Worker 1002*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 1003*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1004*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min(std::max(x_data[i], int8_t(qmin() - 0x80)), int8_t(qmax() - 0x80)); 1005*4bdc9457SAndroid Build Coastguard Worker } 1006*4bdc9457SAndroid Build Coastguard Worker 1007*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 1008*4bdc9457SAndroid Build Coastguard Worker union xnn_s8_minmax_params params; 1009*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80)); 1010*4bdc9457SAndroid Build Coastguard Worker 1011*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 1012*4bdc9457SAndroid Build Coastguard Worker vclamp(batch_size() * sizeof(int8_t), x_data, y.data(), ¶ms); 1013*4bdc9457SAndroid Build Coastguard Worker 1014*4bdc9457SAndroid Build Coastguard Worker // Verify results. 1015*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1016*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(y_ref[i]), int32_t(y[i])) 1017*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << int32_t(x[i]); 1018*4bdc9457SAndroid Build Coastguard Worker } 1019*4bdc9457SAndroid Build Coastguard Worker } 1020*4bdc9457SAndroid Build Coastguard Worker } 1021*4bdc9457SAndroid Build Coastguard Worker Test(xnn_u8_vclamp_ukernel_function vclamp,xnn_init_u8_minmax_params_fn init_params)1022*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_u8_vclamp_ukernel_function vclamp, xnn_init_u8_minmax_params_fn init_params) const { 1023*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1024*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1025*4bdc9457SAndroid Build Coastguard Worker auto u8rng = std::bind( 1026*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng)); 1027*4bdc9457SAndroid Build Coastguard Worker 1028*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 1029*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint8_t) : 0)); 1030*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> y_ref(batch_size()); 1031*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1032*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(u8rng)); 1033*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 1034*4bdc9457SAndroid Build Coastguard Worker std::copy(x.cbegin(), x.cend(), y.begin()); 1035*4bdc9457SAndroid Build Coastguard Worker } else { 1036*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT8_C(0xA5)); 1037*4bdc9457SAndroid Build Coastguard Worker } 1038*4bdc9457SAndroid Build Coastguard Worker const uint8_t* x_data = inplace() ? y.data() : x.data(); 1039*4bdc9457SAndroid Build Coastguard Worker 1040*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 1041*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1042*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min(std::max(x_data[i], qmin()), qmax()); 1043*4bdc9457SAndroid Build Coastguard Worker } 1044*4bdc9457SAndroid Build Coastguard Worker 1045*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 1046*4bdc9457SAndroid Build Coastguard Worker union xnn_u8_minmax_params params; 1047*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, qmin(), qmax()); 1048*4bdc9457SAndroid Build Coastguard Worker 1049*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 1050*4bdc9457SAndroid Build Coastguard Worker vclamp(batch_size() * sizeof(uint8_t), x_data, y.data(), ¶ms); 1051*4bdc9457SAndroid Build Coastguard Worker 1052*4bdc9457SAndroid Build Coastguard Worker // Verify results. 1053*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1054*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(y_ref[i]), uint32_t(y[i])) 1055*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() << ", x[" << i << "] = " << uint32_t(x[i]); 1056*4bdc9457SAndroid Build Coastguard Worker } 1057*4bdc9457SAndroid Build Coastguard Worker } 1058*4bdc9457SAndroid Build Coastguard Worker } 1059*4bdc9457SAndroid Build Coastguard Worker Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift)1060*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_u64_u32_vsqrtshift_ukernel_function vsqrtshift) const { 1061*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(inplace()); 1062*4bdc9457SAndroid Build Coastguard Worker 1063*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1064*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1065*4bdc9457SAndroid Build Coastguard Worker auto u64rng = std::bind( std::uniform_int_distribution<uint64_t>(), std::ref(rng)); 1066*4bdc9457SAndroid Build Coastguard Worker 1067*4bdc9457SAndroid Build Coastguard Worker std::vector<uint64_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint64_t)); 1068*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> y(batch_size()); 1069*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> y_ref(batch_size()); 1070*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1071*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(u64rng)); 1072*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT32_C(0xDEADBEEF)); 1073*4bdc9457SAndroid Build Coastguard Worker 1074*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 1075*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1076*4bdc9457SAndroid Build Coastguard Worker const uint64_t x_value = x[i]; 1077*4bdc9457SAndroid Build Coastguard Worker uint32_t y_value = 0; 1078*4bdc9457SAndroid Build Coastguard Worker // Match TFLM semantics, including bugs 1079*4bdc9457SAndroid Build Coastguard Worker if (uint32_t(x_value) == x_value) { 1080*4bdc9457SAndroid Build Coastguard Worker y_value = (uint32_t) std::lrint(std::sqrt(double(int64_t(uint64_t(x_value))))); 1081*4bdc9457SAndroid Build Coastguard Worker y_value = std::min<uint32_t>(y_value, std::numeric_limits<uint16_t>::max()); 1082*4bdc9457SAndroid Build Coastguard Worker } else if (x_value != 0) { 1083*4bdc9457SAndroid Build Coastguard Worker uint64_t y0 = x_value >> 1; 1084*4bdc9457SAndroid Build Coastguard Worker uint64_t y1 = (y0 + x_value / y0) >> 1; 1085*4bdc9457SAndroid Build Coastguard Worker do { 1086*4bdc9457SAndroid Build Coastguard Worker y0 = y1; 1087*4bdc9457SAndroid Build Coastguard Worker y1 = (y0 + x_value / y0) >> 1; 1088*4bdc9457SAndroid Build Coastguard Worker } while (y1 < y0); 1089*4bdc9457SAndroid Build Coastguard Worker 1090*4bdc9457SAndroid Build Coastguard Worker // y0 is sqrt(x_value) rounded down, round up if needed 1091*4bdc9457SAndroid Build Coastguard Worker if (int64_t(y0 * y0 + y0 - x_value) < 0) { 1092*4bdc9457SAndroid Build Coastguard Worker y0 += 1; 1093*4bdc9457SAndroid Build Coastguard Worker } 1094*4bdc9457SAndroid Build Coastguard Worker y_value = static_cast<uint32_t>(std::min<uint64_t>(y0, std::numeric_limits<uint32_t>::max())); 1095*4bdc9457SAndroid Build Coastguard Worker } 1096*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = y_value >> shift(); 1097*4bdc9457SAndroid Build Coastguard Worker } 1098*4bdc9457SAndroid Build Coastguard Worker 1099*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 1100*4bdc9457SAndroid Build Coastguard Worker vsqrtshift(batch_size() * sizeof(uint64_t), x.data(), y.data(), shift()); 1101*4bdc9457SAndroid Build Coastguard Worker 1102*4bdc9457SAndroid Build Coastguard Worker // Verify results. 1103*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1104*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y_ref[i], y[i]) 1105*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size() 1106*4bdc9457SAndroid Build Coastguard Worker << ", x[" << i << "]: " << x[i] 1107*4bdc9457SAndroid Build Coastguard Worker << ", shift: " << shift(); 1108*4bdc9457SAndroid Build Coastguard Worker } 1109*4bdc9457SAndroid Build Coastguard Worker } 1110*4bdc9457SAndroid Build Coastguard Worker } 1111*4bdc9457SAndroid Build Coastguard Worker 1112*4bdc9457SAndroid Build Coastguard Worker private: 1113*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_ = 1; 1114*4bdc9457SAndroid Build Coastguard Worker bool inplace_ = false; 1115*4bdc9457SAndroid Build Coastguard Worker float slope_ = 0.5f; 1116*4bdc9457SAndroid Build Coastguard Worker float prescale_ = 1.0f; 1117*4bdc9457SAndroid Build Coastguard Worker float alpha_ = 1.0f; 1118*4bdc9457SAndroid Build Coastguard Worker float beta_ = 1.0f; 1119*4bdc9457SAndroid Build Coastguard Worker uint32_t shift_ = 1; 1120*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_ = 0; 1121*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_ = 255; 1122*4bdc9457SAndroid Build Coastguard Worker size_t iterations_ = 15; 1123*4bdc9457SAndroid Build Coastguard Worker }; 1124