1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <functional> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker class VBinaryCMicrokernelTester { 26*4bdc9457SAndroid Build Coastguard Worker public: 27*4bdc9457SAndroid Build Coastguard Worker enum class OpType { 28*4bdc9457SAndroid Build Coastguard Worker AddC, 29*4bdc9457SAndroid Build Coastguard Worker DivC, 30*4bdc9457SAndroid Build Coastguard Worker RDivC, 31*4bdc9457SAndroid Build Coastguard Worker MaxC, 32*4bdc9457SAndroid Build Coastguard Worker MinC, 33*4bdc9457SAndroid Build Coastguard Worker MulC, 34*4bdc9457SAndroid Build Coastguard Worker SqrDiffC, 35*4bdc9457SAndroid Build Coastguard Worker SubC, 36*4bdc9457SAndroid Build Coastguard Worker RSubC, 37*4bdc9457SAndroid Build Coastguard Worker }; 38*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)39*4bdc9457SAndroid Build Coastguard Worker inline VBinaryCMicrokernelTester& batch_size(size_t batch_size) { 40*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 41*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 42*4bdc9457SAndroid Build Coastguard Worker return *this; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker batch_size()45*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 46*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker inplace(bool inplace)49*4bdc9457SAndroid Build Coastguard Worker inline VBinaryCMicrokernelTester& inplace(bool inplace) { 50*4bdc9457SAndroid Build Coastguard Worker this->inplace_ = inplace; 51*4bdc9457SAndroid Build Coastguard Worker return *this; 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker inplace()54*4bdc9457SAndroid Build Coastguard Worker inline bool inplace() const { 55*4bdc9457SAndroid Build Coastguard Worker return this->inplace_; 56*4bdc9457SAndroid Build Coastguard Worker } 57*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)58*4bdc9457SAndroid Build Coastguard Worker inline VBinaryCMicrokernelTester& qmin(uint8_t qmin) { 59*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 60*4bdc9457SAndroid Build Coastguard Worker return *this; 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker qmin()63*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 64*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)67*4bdc9457SAndroid Build Coastguard Worker inline VBinaryCMicrokernelTester& qmax(uint8_t qmax) { 68*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 69*4bdc9457SAndroid Build Coastguard Worker return *this; 70*4bdc9457SAndroid Build Coastguard Worker } 71*4bdc9457SAndroid Build Coastguard Worker qmax()72*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 73*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 74*4bdc9457SAndroid Build Coastguard Worker } 75*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)76*4bdc9457SAndroid Build Coastguard Worker inline VBinaryCMicrokernelTester& iterations(size_t iterations) { 77*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 78*4bdc9457SAndroid Build Coastguard Worker return *this; 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker iterations()81*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 82*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 83*4bdc9457SAndroid Build Coastguard Worker } 84*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vbinary_ukernel_function vbinaryc,OpType op_type)85*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vbinary_ukernel_function vbinaryc, OpType op_type) const { 86*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 87*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 88*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 89*4bdc9457SAndroid Build Coastguard Worker 90*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 91*4bdc9457SAndroid Build Coastguard Worker const uint16_t b = fp16_ieee_from_fp32_value(f32dist(rng)); 92*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 93*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 94*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 95*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 96*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 97*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 98*4bdc9457SAndroid Build Coastguard Worker } else { 99*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 100*4bdc9457SAndroid Build Coastguard Worker } 101*4bdc9457SAndroid Build Coastguard Worker const uint16_t* a_data = inplace() ? y.data() : a.data(); 102*4bdc9457SAndroid Build Coastguard Worker 103*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 104*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 105*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 106*4bdc9457SAndroid Build Coastguard Worker case OpType::AddC: 107*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) + fp16_ieee_to_fp32_value(b); 108*4bdc9457SAndroid Build Coastguard Worker break; 109*4bdc9457SAndroid Build Coastguard Worker case OpType::DivC: 110*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) / fp16_ieee_to_fp32_value(b); 111*4bdc9457SAndroid Build Coastguard Worker break; 112*4bdc9457SAndroid Build Coastguard Worker case OpType::RDivC: 113*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(b) / fp16_ieee_to_fp32_value(a_data[i]); 114*4bdc9457SAndroid Build Coastguard Worker break; 115*4bdc9457SAndroid Build Coastguard Worker case OpType::MaxC: 116*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b)); 117*4bdc9457SAndroid Build Coastguard Worker break; 118*4bdc9457SAndroid Build Coastguard Worker case OpType::MinC: 119*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b)); 120*4bdc9457SAndroid Build Coastguard Worker break; 121*4bdc9457SAndroid Build Coastguard Worker case OpType::MulC: 122*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) * fp16_ieee_to_fp32_value(b); 123*4bdc9457SAndroid Build Coastguard Worker break; 124*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiffC: 125*4bdc9457SAndroid Build Coastguard Worker { 126*4bdc9457SAndroid Build Coastguard Worker const float diff = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b); 127*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 128*4bdc9457SAndroid Build Coastguard Worker break; 129*4bdc9457SAndroid Build Coastguard Worker } 130*4bdc9457SAndroid Build Coastguard Worker case OpType::SubC: 131*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b); 132*4bdc9457SAndroid Build Coastguard Worker break; 133*4bdc9457SAndroid Build Coastguard Worker case OpType::RSubC: 134*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(b) - fp16_ieee_to_fp32_value(a_data[i]); 135*4bdc9457SAndroid Build Coastguard Worker break; 136*4bdc9457SAndroid Build Coastguard Worker } 137*4bdc9457SAndroid Build Coastguard Worker } 138*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 139*4bdc9457SAndroid Build Coastguard Worker vbinaryc(batch_size() * sizeof(uint16_t), a_data, &b, y.data(), nullptr); 140*4bdc9457SAndroid Build Coastguard Worker 141*4bdc9457SAndroid Build Coastguard Worker // Verify results. 142*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 143*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(y[i]), y_ref[i], std::max(1.0e-4f, std::abs(y_ref[i]) * 1.0e-2f)) 144*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 145*4bdc9457SAndroid Build Coastguard Worker } 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker } 148*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vbinary_minmax_ukernel_function vbinaryc_minmax,OpType op_type,xnn_init_f16_minmax_params_fn init_params)149*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vbinary_minmax_ukernel_function vbinaryc_minmax, OpType op_type, xnn_init_f16_minmax_params_fn init_params) const { 150*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 151*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 152*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 153*4bdc9457SAndroid Build Coastguard Worker 154*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 155*4bdc9457SAndroid Build Coastguard Worker const uint16_t b = fp16_ieee_from_fp32_value(f32dist(rng)); 156*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 157*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 158*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 159*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 160*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 161*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 162*4bdc9457SAndroid Build Coastguard Worker } else { 163*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 164*4bdc9457SAndroid Build Coastguard Worker } 165*4bdc9457SAndroid Build Coastguard Worker const uint16_t* a_data = inplace() ? y.data() : a.data(); 166*4bdc9457SAndroid Build Coastguard Worker 167*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 168*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 169*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 170*4bdc9457SAndroid Build Coastguard Worker case OpType::AddC: 171*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) + fp16_ieee_to_fp32_value(b); 172*4bdc9457SAndroid Build Coastguard Worker break; 173*4bdc9457SAndroid Build Coastguard Worker case OpType::DivC: 174*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) / fp16_ieee_to_fp32_value(b); 175*4bdc9457SAndroid Build Coastguard Worker break; 176*4bdc9457SAndroid Build Coastguard Worker case OpType::RDivC: 177*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(b) / fp16_ieee_to_fp32_value(a_data[i]); 178*4bdc9457SAndroid Build Coastguard Worker break; 179*4bdc9457SAndroid Build Coastguard Worker case OpType::MaxC: 180*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b)); 181*4bdc9457SAndroid Build Coastguard Worker break; 182*4bdc9457SAndroid Build Coastguard Worker case OpType::MinC: 183*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b)); 184*4bdc9457SAndroid Build Coastguard Worker break; 185*4bdc9457SAndroid Build Coastguard Worker case OpType::MulC: 186*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) * fp16_ieee_to_fp32_value(b); 187*4bdc9457SAndroid Build Coastguard Worker break; 188*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiffC: 189*4bdc9457SAndroid Build Coastguard Worker { 190*4bdc9457SAndroid Build Coastguard Worker const float diff = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b); 191*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 192*4bdc9457SAndroid Build Coastguard Worker break; 193*4bdc9457SAndroid Build Coastguard Worker } 194*4bdc9457SAndroid Build Coastguard Worker case OpType::SubC: 195*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b); 196*4bdc9457SAndroid Build Coastguard Worker break; 197*4bdc9457SAndroid Build Coastguard Worker case OpType::RSubC: 198*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(b) - fp16_ieee_to_fp32_value(a_data[i]); 199*4bdc9457SAndroid Build Coastguard Worker break; 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend()); 203*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend()); 204*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 205*4bdc9457SAndroid Build Coastguard Worker const float y_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_range > 0.0f ? 206*4bdc9457SAndroid Build Coastguard Worker (accumulated_max - accumulated_range / 255.0f * float(255 - qmax())) : 207*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity())); 208*4bdc9457SAndroid Build Coastguard Worker const float y_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_range > 0.0f ? 209*4bdc9457SAndroid Build Coastguard Worker (accumulated_min + accumulated_range / 255.0f * float(qmin())) : 210*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity())); 211*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 212*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(std::min<float>(y_ref[i], y_max), y_min); 213*4bdc9457SAndroid Build Coastguard Worker } 214*4bdc9457SAndroid Build Coastguard Worker 215*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 216*4bdc9457SAndroid Build Coastguard Worker xnn_f16_minmax_params params; 217*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, 218*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(y_min), fp16_ieee_from_fp32_value(y_max)); 219*4bdc9457SAndroid Build Coastguard Worker 220*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 221*4bdc9457SAndroid Build Coastguard Worker vbinaryc_minmax(batch_size() * sizeof(uint16_t), a_data, &b, y.data(), ¶ms); 222*4bdc9457SAndroid Build Coastguard Worker 223*4bdc9457SAndroid Build Coastguard Worker // Verify results. 224*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 225*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(y[i]), y_ref[i], std::max(1.0e-4f, std::abs(y_ref[i]) * 1.0e-2f)) 226*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 227*4bdc9457SAndroid Build Coastguard Worker } 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker } 230*4bdc9457SAndroid Build Coastguard Worker 231*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vbinary_ukernel_function vbinaryc, OpType op_type, xnn_init_f32_default_params_fn init_params = nullptr) const { 232*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 233*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 234*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 235*4bdc9457SAndroid Build Coastguard Worker 236*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 237*4bdc9457SAndroid Build Coastguard Worker const float b = f32dist(rng); 238*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 239*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 240*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 241*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); }); 242*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 243*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 244*4bdc9457SAndroid Build Coastguard Worker } else { 245*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 246*4bdc9457SAndroid Build Coastguard Worker } 247*4bdc9457SAndroid Build Coastguard Worker const float* a_data = inplace() ? y.data() : a.data(); 248*4bdc9457SAndroid Build Coastguard Worker 249*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 250*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 251*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 252*4bdc9457SAndroid Build Coastguard Worker case OpType::AddC: 253*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] + b; 254*4bdc9457SAndroid Build Coastguard Worker break; 255*4bdc9457SAndroid Build Coastguard Worker case OpType::DivC: 256*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] / b; 257*4bdc9457SAndroid Build Coastguard Worker break; 258*4bdc9457SAndroid Build Coastguard Worker case OpType::RDivC: 259*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = b / a_data[i]; 260*4bdc9457SAndroid Build Coastguard Worker break; 261*4bdc9457SAndroid Build Coastguard Worker case OpType::MaxC: 262*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(a_data[i], b); 263*4bdc9457SAndroid Build Coastguard Worker break; 264*4bdc9457SAndroid Build Coastguard Worker case OpType::MinC: 265*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(a_data[i], b); 266*4bdc9457SAndroid Build Coastguard Worker break; 267*4bdc9457SAndroid Build Coastguard Worker case OpType::MulC: 268*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] * b; 269*4bdc9457SAndroid Build Coastguard Worker break; 270*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiffC: 271*4bdc9457SAndroid Build Coastguard Worker { 272*4bdc9457SAndroid Build Coastguard Worker const float diff = a_data[i] - b; 273*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 274*4bdc9457SAndroid Build Coastguard Worker break; 275*4bdc9457SAndroid Build Coastguard Worker } 276*4bdc9457SAndroid Build Coastguard Worker case OpType::SubC: 277*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] - b; 278*4bdc9457SAndroid Build Coastguard Worker break; 279*4bdc9457SAndroid Build Coastguard Worker case OpType::RSubC: 280*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = b - a_data[i]; 281*4bdc9457SAndroid Build Coastguard Worker break; 282*4bdc9457SAndroid Build Coastguard Worker } 283*4bdc9457SAndroid Build Coastguard Worker } 284*4bdc9457SAndroid Build Coastguard Worker 285*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 286*4bdc9457SAndroid Build Coastguard Worker xnn_f32_default_params params; 287*4bdc9457SAndroid Build Coastguard Worker if (init_params) { 288*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 289*4bdc9457SAndroid Build Coastguard Worker } 290*4bdc9457SAndroid Build Coastguard Worker 291*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 292*4bdc9457SAndroid Build Coastguard Worker vbinaryc(batch_size() * sizeof(float), a_data, &b, y.data(), init_params != nullptr ? ¶ms : nullptr); 293*4bdc9457SAndroid Build Coastguard Worker 294*4bdc9457SAndroid Build Coastguard Worker // Verify results. 295*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 296*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f) 297*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 298*4bdc9457SAndroid Build Coastguard Worker } 299*4bdc9457SAndroid Build Coastguard Worker } 300*4bdc9457SAndroid Build Coastguard Worker } 301*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vbinary_relu_ukernel_function vbinaryc_relu,OpType op_type)302*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vbinary_relu_ukernel_function vbinaryc_relu, OpType op_type) const { 303*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 304*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 305*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 306*4bdc9457SAndroid Build Coastguard Worker 307*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 308*4bdc9457SAndroid Build Coastguard Worker const float b = f32dist(rng); 309*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 310*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 311*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 312*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); }); 313*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 314*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 315*4bdc9457SAndroid Build Coastguard Worker } else { 316*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 317*4bdc9457SAndroid Build Coastguard Worker } 318*4bdc9457SAndroid Build Coastguard Worker const float* a_data = inplace() ? y.data() : a.data(); 319*4bdc9457SAndroid Build Coastguard Worker 320*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 321*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 322*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 323*4bdc9457SAndroid Build Coastguard Worker case OpType::AddC: 324*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] + b; 325*4bdc9457SAndroid Build Coastguard Worker break; 326*4bdc9457SAndroid Build Coastguard Worker case OpType::DivC: 327*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] / b; 328*4bdc9457SAndroid Build Coastguard Worker break; 329*4bdc9457SAndroid Build Coastguard Worker case OpType::RDivC: 330*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = b / a_data[i]; 331*4bdc9457SAndroid Build Coastguard Worker break; 332*4bdc9457SAndroid Build Coastguard Worker case OpType::MaxC: 333*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(a_data[i], b); 334*4bdc9457SAndroid Build Coastguard Worker break; 335*4bdc9457SAndroid Build Coastguard Worker case OpType::MinC: 336*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(a_data[i], b); 337*4bdc9457SAndroid Build Coastguard Worker break; 338*4bdc9457SAndroid Build Coastguard Worker case OpType::MulC: 339*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] * b; 340*4bdc9457SAndroid Build Coastguard Worker break; 341*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiffC: 342*4bdc9457SAndroid Build Coastguard Worker { 343*4bdc9457SAndroid Build Coastguard Worker const float diff = a_data[i] - b; 344*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 345*4bdc9457SAndroid Build Coastguard Worker break; 346*4bdc9457SAndroid Build Coastguard Worker } 347*4bdc9457SAndroid Build Coastguard Worker case OpType::SubC: 348*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] - b; 349*4bdc9457SAndroid Build Coastguard Worker break; 350*4bdc9457SAndroid Build Coastguard Worker case OpType::RSubC: 351*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = b - a_data[i]; 352*4bdc9457SAndroid Build Coastguard Worker break; 353*4bdc9457SAndroid Build Coastguard Worker } 354*4bdc9457SAndroid Build Coastguard Worker } 355*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 356*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max(y_ref[i], 0.0f); 357*4bdc9457SAndroid Build Coastguard Worker } 358*4bdc9457SAndroid Build Coastguard Worker 359*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 360*4bdc9457SAndroid Build Coastguard Worker vbinaryc_relu(batch_size() * sizeof(float), a_data, &b, y.data(), nullptr); 361*4bdc9457SAndroid Build Coastguard Worker 362*4bdc9457SAndroid Build Coastguard Worker // Verify results. 363*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 364*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(y[i], 0.0f) 365*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 366*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f) 367*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 368*4bdc9457SAndroid Build Coastguard Worker } 369*4bdc9457SAndroid Build Coastguard Worker } 370*4bdc9457SAndroid Build Coastguard Worker } 371*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vbinary_minmax_ukernel_function vbinaryc_minmax,OpType op_type,xnn_init_f32_minmax_params_fn init_params)372*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vbinary_minmax_ukernel_function vbinaryc_minmax, OpType op_type, xnn_init_f32_minmax_params_fn init_params) const { 373*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 374*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 375*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 376*4bdc9457SAndroid Build Coastguard Worker 377*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 378*4bdc9457SAndroid Build Coastguard Worker const float b = f32dist(rng); 379*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 380*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 381*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 382*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); }); 383*4bdc9457SAndroid Build Coastguard Worker if (inplace()) { 384*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 385*4bdc9457SAndroid Build Coastguard Worker } else { 386*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 387*4bdc9457SAndroid Build Coastguard Worker } 388*4bdc9457SAndroid Build Coastguard Worker const float* a_data = inplace() ? y.data() : a.data(); 389*4bdc9457SAndroid Build Coastguard Worker 390*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 391*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 392*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 393*4bdc9457SAndroid Build Coastguard Worker case OpType::AddC: 394*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] + b; 395*4bdc9457SAndroid Build Coastguard Worker break; 396*4bdc9457SAndroid Build Coastguard Worker case OpType::DivC: 397*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] / b; 398*4bdc9457SAndroid Build Coastguard Worker break; 399*4bdc9457SAndroid Build Coastguard Worker case OpType::RDivC: 400*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = b / a_data[i]; 401*4bdc9457SAndroid Build Coastguard Worker break; 402*4bdc9457SAndroid Build Coastguard Worker case OpType::MaxC: 403*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(a_data[i], b); 404*4bdc9457SAndroid Build Coastguard Worker break; 405*4bdc9457SAndroid Build Coastguard Worker case OpType::MinC: 406*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(a_data[i], b); 407*4bdc9457SAndroid Build Coastguard Worker break; 408*4bdc9457SAndroid Build Coastguard Worker case OpType::MulC: 409*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] * b; 410*4bdc9457SAndroid Build Coastguard Worker break; 411*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiffC: 412*4bdc9457SAndroid Build Coastguard Worker { 413*4bdc9457SAndroid Build Coastguard Worker const float diff = a_data[i] - b; 414*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 415*4bdc9457SAndroid Build Coastguard Worker break; 416*4bdc9457SAndroid Build Coastguard Worker } 417*4bdc9457SAndroid Build Coastguard Worker case OpType::SubC: 418*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] - b; 419*4bdc9457SAndroid Build Coastguard Worker break; 420*4bdc9457SAndroid Build Coastguard Worker case OpType::RSubC: 421*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = b - a_data[i]; 422*4bdc9457SAndroid Build Coastguard Worker break; 423*4bdc9457SAndroid Build Coastguard Worker } 424*4bdc9457SAndroid Build Coastguard Worker } 425*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend()); 426*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend()); 427*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 428*4bdc9457SAndroid Build Coastguard Worker const float y_max = accumulated_range > 0.0f ? 429*4bdc9457SAndroid Build Coastguard Worker (accumulated_max - accumulated_range / 255.0f * float(255 - qmax())) : 430*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity(); 431*4bdc9457SAndroid Build Coastguard Worker const float y_min = accumulated_range > 0.0f ? 432*4bdc9457SAndroid Build Coastguard Worker (accumulated_min + accumulated_range / 255.0f * float(qmin())) : 433*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(); 434*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 435*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(std::min<float>(y_ref[i], y_max), y_min); 436*4bdc9457SAndroid Build Coastguard Worker } 437*4bdc9457SAndroid Build Coastguard Worker 438*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 439*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params; 440*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, y_min, y_max); 441*4bdc9457SAndroid Build Coastguard Worker 442*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 443*4bdc9457SAndroid Build Coastguard Worker vbinaryc_minmax(batch_size() * sizeof(float), a_data, &b, y.data(), ¶ms); 444*4bdc9457SAndroid Build Coastguard Worker 445*4bdc9457SAndroid Build Coastguard Worker // Verify results. 446*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 447*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f) 448*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 449*4bdc9457SAndroid Build Coastguard Worker } 450*4bdc9457SAndroid Build Coastguard Worker } 451*4bdc9457SAndroid Build Coastguard Worker } 452*4bdc9457SAndroid Build Coastguard Worker 453*4bdc9457SAndroid Build Coastguard Worker private: 454*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 455*4bdc9457SAndroid Build Coastguard Worker bool inplace_{false}; 456*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 457*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 458*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 459*4bdc9457SAndroid Build Coastguard Worker }; 460