1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <functional> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker class VBinaryMicrokernelTester { 26*4bdc9457SAndroid Build Coastguard Worker public: 27*4bdc9457SAndroid Build Coastguard Worker enum class OpType { 28*4bdc9457SAndroid Build Coastguard Worker Add, 29*4bdc9457SAndroid Build Coastguard Worker Div, 30*4bdc9457SAndroid Build Coastguard Worker Max, 31*4bdc9457SAndroid Build Coastguard Worker Min, 32*4bdc9457SAndroid Build Coastguard Worker Mul, 33*4bdc9457SAndroid Build Coastguard Worker Sub, 34*4bdc9457SAndroid Build Coastguard Worker SqrDiff, 35*4bdc9457SAndroid Build Coastguard Worker }; 36*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)37*4bdc9457SAndroid Build Coastguard Worker inline VBinaryMicrokernelTester& batch_size(size_t batch_size) { 38*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 39*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 40*4bdc9457SAndroid Build Coastguard Worker return *this; 41*4bdc9457SAndroid Build Coastguard Worker } 42*4bdc9457SAndroid Build Coastguard Worker batch_size()43*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 44*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 45*4bdc9457SAndroid Build Coastguard Worker } 46*4bdc9457SAndroid Build Coastguard Worker inplace_a(bool inplace_a)47*4bdc9457SAndroid Build Coastguard Worker inline VBinaryMicrokernelTester& inplace_a(bool inplace_a) { 48*4bdc9457SAndroid Build Coastguard Worker this->inplace_a_ = inplace_a; 49*4bdc9457SAndroid Build Coastguard Worker return *this; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker inplace_a()52*4bdc9457SAndroid Build Coastguard Worker inline bool inplace_a() const { 53*4bdc9457SAndroid Build Coastguard Worker return this->inplace_a_; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker inplace_b(bool inplace_b)56*4bdc9457SAndroid Build Coastguard Worker inline VBinaryMicrokernelTester& inplace_b(bool inplace_b) { 57*4bdc9457SAndroid Build Coastguard Worker this->inplace_b_ = inplace_b; 58*4bdc9457SAndroid Build Coastguard Worker return *this; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker inplace_b()61*4bdc9457SAndroid Build Coastguard Worker inline bool inplace_b() const { 62*4bdc9457SAndroid Build Coastguard Worker return this->inplace_b_; 63*4bdc9457SAndroid Build Coastguard Worker } 64*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)65*4bdc9457SAndroid Build Coastguard Worker inline VBinaryMicrokernelTester& qmin(uint8_t qmin) { 66*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 67*4bdc9457SAndroid Build Coastguard Worker return *this; 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker qmin()70*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 71*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 72*4bdc9457SAndroid Build Coastguard Worker } 73*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)74*4bdc9457SAndroid Build Coastguard Worker inline VBinaryMicrokernelTester& qmax(uint8_t qmax) { 75*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 76*4bdc9457SAndroid Build Coastguard Worker return *this; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker qmax()79*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 80*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 81*4bdc9457SAndroid Build Coastguard Worker } 82*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)83*4bdc9457SAndroid Build Coastguard Worker inline VBinaryMicrokernelTester& iterations(size_t iterations) { 84*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 85*4bdc9457SAndroid Build Coastguard Worker return *this; 86*4bdc9457SAndroid Build Coastguard Worker } 87*4bdc9457SAndroid Build Coastguard Worker iterations()88*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 89*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 90*4bdc9457SAndroid Build Coastguard Worker } 91*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vbinary_ukernel_function vbinary,OpType op_type)92*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vbinary_ukernel_function vbinary, OpType op_type) const { 93*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 94*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 95*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 96*4bdc9457SAndroid Build Coastguard Worker 97*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 98*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 99*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace_a() || inplace_b() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 100*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 101*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 102*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 103*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 104*4bdc9457SAndroid Build Coastguard Worker if (inplace_a() || inplace_b()) { 105*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 106*4bdc9457SAndroid Build Coastguard Worker } else { 107*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 108*4bdc9457SAndroid Build Coastguard Worker } 109*4bdc9457SAndroid Build Coastguard Worker const uint16_t* a_data = inplace_a() ? y.data() : a.data(); 110*4bdc9457SAndroid Build Coastguard Worker const uint16_t* b_data = inplace_b() ? y.data() : b.data(); 111*4bdc9457SAndroid Build Coastguard Worker 112*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 113*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 114*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 115*4bdc9457SAndroid Build Coastguard Worker case OpType::Add: 116*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) + fp16_ieee_to_fp32_value(b_data[i]); 117*4bdc9457SAndroid Build Coastguard Worker break; 118*4bdc9457SAndroid Build Coastguard Worker case OpType::Div: 119*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) / fp16_ieee_to_fp32_value(b_data[i]); 120*4bdc9457SAndroid Build Coastguard Worker break; 121*4bdc9457SAndroid Build Coastguard Worker case OpType::Max: 122*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b_data[i])); 123*4bdc9457SAndroid Build Coastguard Worker break; 124*4bdc9457SAndroid Build Coastguard Worker case OpType::Min: 125*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b_data[i])); 126*4bdc9457SAndroid Build Coastguard Worker break; 127*4bdc9457SAndroid Build Coastguard Worker case OpType::Mul: 128*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) * fp16_ieee_to_fp32_value(b_data[i]); 129*4bdc9457SAndroid Build Coastguard Worker break; 130*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiff: 131*4bdc9457SAndroid Build Coastguard Worker { 132*4bdc9457SAndroid Build Coastguard Worker const float diff = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b_data[i]); 133*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 134*4bdc9457SAndroid Build Coastguard Worker break; 135*4bdc9457SAndroid Build Coastguard Worker } 136*4bdc9457SAndroid Build Coastguard Worker case OpType::Sub: 137*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b_data[i]); 138*4bdc9457SAndroid Build Coastguard Worker break; 139*4bdc9457SAndroid Build Coastguard Worker } 140*4bdc9457SAndroid Build Coastguard Worker } 141*4bdc9457SAndroid Build Coastguard Worker 142*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 143*4bdc9457SAndroid Build Coastguard Worker vbinary(batch_size() * sizeof(uint16_t), a_data, b_data, y.data(), nullptr); 144*4bdc9457SAndroid Build Coastguard Worker 145*4bdc9457SAndroid Build Coastguard Worker // Verify results. 146*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 147*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(y[i]), y_ref[i], std::max(1.0e-4f, std::abs(y_ref[i]) * 1.0e-2f)) 148*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 149*4bdc9457SAndroid Build Coastguard Worker } 150*4bdc9457SAndroid Build Coastguard Worker } 151*4bdc9457SAndroid Build Coastguard Worker } 152*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_vbinary_minmax_ukernel_function vbinary_minmax,OpType op_type,xnn_init_f16_minmax_params_fn init_params)153*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_vbinary_minmax_ukernel_function vbinary_minmax, OpType op_type, xnn_init_f16_minmax_params_fn init_params) const { 154*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 155*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 156*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 157*4bdc9457SAndroid Build Coastguard Worker 158*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 159*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(batch_size() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 160*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> y(batch_size() + (inplace_a() || inplace_b() ? XNN_EXTRA_BYTES / sizeof(uint16_t) : 0)); 161*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 162*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 163*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 164*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 165*4bdc9457SAndroid Build Coastguard Worker if (inplace_a() || inplace_b()) { 166*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 167*4bdc9457SAndroid Build Coastguard Worker } else { 168*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */); 169*4bdc9457SAndroid Build Coastguard Worker } 170*4bdc9457SAndroid Build Coastguard Worker const uint16_t* a_data = inplace_a() ? y.data() : a.data(); 171*4bdc9457SAndroid Build Coastguard Worker const uint16_t* b_data = inplace_b() ? y.data() : b.data(); 172*4bdc9457SAndroid Build Coastguard Worker 173*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 174*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 175*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 176*4bdc9457SAndroid Build Coastguard Worker case OpType::Add: 177*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) + fp16_ieee_to_fp32_value(b_data[i]); 178*4bdc9457SAndroid Build Coastguard Worker break; 179*4bdc9457SAndroid Build Coastguard Worker case OpType::Div: 180*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) / fp16_ieee_to_fp32_value(b_data[i]); 181*4bdc9457SAndroid Build Coastguard Worker break; 182*4bdc9457SAndroid Build Coastguard Worker case OpType::Max: 183*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b_data[i])); 184*4bdc9457SAndroid Build Coastguard Worker break; 185*4bdc9457SAndroid Build Coastguard Worker case OpType::Min: 186*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(fp16_ieee_to_fp32_value(a_data[i]), fp16_ieee_to_fp32_value(b_data[i])); 187*4bdc9457SAndroid Build Coastguard Worker break; 188*4bdc9457SAndroid Build Coastguard Worker case OpType::Mul: 189*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) * fp16_ieee_to_fp32_value(b_data[i]); 190*4bdc9457SAndroid Build Coastguard Worker break; 191*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiff: 192*4bdc9457SAndroid Build Coastguard Worker { 193*4bdc9457SAndroid Build Coastguard Worker const float diff = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b_data[i]); 194*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 195*4bdc9457SAndroid Build Coastguard Worker break; 196*4bdc9457SAndroid Build Coastguard Worker } 197*4bdc9457SAndroid Build Coastguard Worker case OpType::Sub: 198*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = fp16_ieee_to_fp32_value(a_data[i]) - fp16_ieee_to_fp32_value(b_data[i]); 199*4bdc9457SAndroid Build Coastguard Worker break; 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker 203*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend()); 204*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend()); 205*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 206*4bdc9457SAndroid Build Coastguard Worker const float y_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_range > 0.0f ? 207*4bdc9457SAndroid Build Coastguard Worker (accumulated_max - accumulated_range / 255.0f * float(255 - qmax())) : 208*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity())); 209*4bdc9457SAndroid Build Coastguard Worker const float y_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_range > 0.0f ? 210*4bdc9457SAndroid Build Coastguard Worker (accumulated_min + accumulated_range / 255.0f * float(qmin())) : 211*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity())); 212*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 213*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(std::min<float>(y_ref[i], y_max), y_min); 214*4bdc9457SAndroid Build Coastguard Worker } 215*4bdc9457SAndroid Build Coastguard Worker 216*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 217*4bdc9457SAndroid Build Coastguard Worker xnn_f16_minmax_params params; 218*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, 219*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(y_min), fp16_ieee_from_fp32_value(y_max)); 220*4bdc9457SAndroid Build Coastguard Worker 221*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 222*4bdc9457SAndroid Build Coastguard Worker vbinary_minmax(batch_size() * sizeof(uint16_t), a_data, b_data, y.data(), ¶ms); 223*4bdc9457SAndroid Build Coastguard Worker 224*4bdc9457SAndroid Build Coastguard Worker // Verify results. 225*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 226*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(y[i]), y_ref[i], std::max(1.0e-4f, std::abs(y_ref[i]) * 1.0e-2f)) 227*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker } 230*4bdc9457SAndroid Build Coastguard Worker } 231*4bdc9457SAndroid Build Coastguard Worker 232*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vbinary_ukernel_function vbinary, OpType op_type, xnn_init_f32_default_params_fn init_params = nullptr) const { 233*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 234*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 235*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 236*4bdc9457SAndroid Build Coastguard Worker 237*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 238*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 239*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace_a() || inplace_b() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 240*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 241*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 242*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); }); 243*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); }); 244*4bdc9457SAndroid Build Coastguard Worker if (inplace_a() || inplace_b()) { 245*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 246*4bdc9457SAndroid Build Coastguard Worker } else { 247*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 248*4bdc9457SAndroid Build Coastguard Worker } 249*4bdc9457SAndroid Build Coastguard Worker const float* a_data = inplace_a() ? y.data() : a.data(); 250*4bdc9457SAndroid Build Coastguard Worker const float* b_data = inplace_b() ? y.data() : b.data(); 251*4bdc9457SAndroid Build Coastguard Worker 252*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 253*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 254*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 255*4bdc9457SAndroid Build Coastguard Worker case OpType::Add: 256*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] + b_data[i]; 257*4bdc9457SAndroid Build Coastguard Worker break; 258*4bdc9457SAndroid Build Coastguard Worker case OpType::Div: 259*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] / b_data[i]; 260*4bdc9457SAndroid Build Coastguard Worker break; 261*4bdc9457SAndroid Build Coastguard Worker case OpType::Max: 262*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(a_data[i], b_data[i]); 263*4bdc9457SAndroid Build Coastguard Worker break; 264*4bdc9457SAndroid Build Coastguard Worker case OpType::Min: 265*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(a_data[i], b_data[i]); 266*4bdc9457SAndroid Build Coastguard Worker break; 267*4bdc9457SAndroid Build Coastguard Worker case OpType::Mul: 268*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] * b_data[i]; 269*4bdc9457SAndroid Build Coastguard Worker break; 270*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiff: 271*4bdc9457SAndroid Build Coastguard Worker { 272*4bdc9457SAndroid Build Coastguard Worker const float diff = a_data[i] - b_data[i]; 273*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 274*4bdc9457SAndroid Build Coastguard Worker break; 275*4bdc9457SAndroid Build Coastguard Worker } 276*4bdc9457SAndroid Build Coastguard Worker case OpType::Sub: 277*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] - b_data[i]; 278*4bdc9457SAndroid Build Coastguard Worker break; 279*4bdc9457SAndroid Build Coastguard Worker } 280*4bdc9457SAndroid Build Coastguard Worker } 281*4bdc9457SAndroid Build Coastguard Worker 282*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 283*4bdc9457SAndroid Build Coastguard Worker xnn_f32_default_params params; 284*4bdc9457SAndroid Build Coastguard Worker if (init_params) { 285*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms); 286*4bdc9457SAndroid Build Coastguard Worker } 287*4bdc9457SAndroid Build Coastguard Worker 288*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 289*4bdc9457SAndroid Build Coastguard Worker vbinary(batch_size() * sizeof(float), a_data, b_data, y.data(), init_params != nullptr ? ¶ms : nullptr); 290*4bdc9457SAndroid Build Coastguard Worker 291*4bdc9457SAndroid Build Coastguard Worker // Verify results. 292*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 293*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f) 294*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 295*4bdc9457SAndroid Build Coastguard Worker } 296*4bdc9457SAndroid Build Coastguard Worker } 297*4bdc9457SAndroid Build Coastguard Worker } 298*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vbinary_relu_ukernel_function vbinary_relu,OpType op_type)299*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vbinary_relu_ukernel_function vbinary_relu, OpType op_type) const { 300*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 301*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 302*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 303*4bdc9457SAndroid Build Coastguard Worker 304*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 305*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 306*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace_a() || inplace_b() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 307*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 308*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 309*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); }); 310*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); }); 311*4bdc9457SAndroid Build Coastguard Worker if (inplace_a() || inplace_b()) { 312*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 313*4bdc9457SAndroid Build Coastguard Worker } else { 314*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 315*4bdc9457SAndroid Build Coastguard Worker } 316*4bdc9457SAndroid Build Coastguard Worker const float* a_data = inplace_a() ? y.data() : a.data(); 317*4bdc9457SAndroid Build Coastguard Worker const float* b_data = inplace_b() ? y.data() : b.data(); 318*4bdc9457SAndroid Build Coastguard Worker 319*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 320*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 321*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 322*4bdc9457SAndroid Build Coastguard Worker case OpType::Add: 323*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] + b_data[i]; 324*4bdc9457SAndroid Build Coastguard Worker break; 325*4bdc9457SAndroid Build Coastguard Worker case OpType::Div: 326*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] / b_data[i]; 327*4bdc9457SAndroid Build Coastguard Worker break; 328*4bdc9457SAndroid Build Coastguard Worker case OpType::Max: 329*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(a_data[i], b_data[i]); 330*4bdc9457SAndroid Build Coastguard Worker break; 331*4bdc9457SAndroid Build Coastguard Worker case OpType::Min: 332*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(a_data[i], b_data[i]); 333*4bdc9457SAndroid Build Coastguard Worker break; 334*4bdc9457SAndroid Build Coastguard Worker case OpType::Mul: 335*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] * b_data[i]; 336*4bdc9457SAndroid Build Coastguard Worker break; 337*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiff: 338*4bdc9457SAndroid Build Coastguard Worker { 339*4bdc9457SAndroid Build Coastguard Worker const float diff = a_data[i] - b_data[i]; 340*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 341*4bdc9457SAndroid Build Coastguard Worker break; 342*4bdc9457SAndroid Build Coastguard Worker } 343*4bdc9457SAndroid Build Coastguard Worker case OpType::Sub: 344*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] - b_data[i]; 345*4bdc9457SAndroid Build Coastguard Worker break; 346*4bdc9457SAndroid Build Coastguard Worker } 347*4bdc9457SAndroid Build Coastguard Worker } 348*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 349*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max(y_ref[i], 0.0f); 350*4bdc9457SAndroid Build Coastguard Worker } 351*4bdc9457SAndroid Build Coastguard Worker 352*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 353*4bdc9457SAndroid Build Coastguard Worker vbinary_relu(batch_size() * sizeof(float), a_data, b_data, y.data(), nullptr); 354*4bdc9457SAndroid Build Coastguard Worker 355*4bdc9457SAndroid Build Coastguard Worker // Verify results. 356*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 357*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(y[i], 0.0f) 358*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 359*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f) 360*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 361*4bdc9457SAndroid Build Coastguard Worker } 362*4bdc9457SAndroid Build Coastguard Worker } 363*4bdc9457SAndroid Build Coastguard Worker } 364*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_vbinary_minmax_ukernel_function vbinary_minmax,OpType op_type,xnn_init_f32_minmax_params_fn init_params)365*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_vbinary_minmax_ukernel_function vbinary_minmax, OpType op_type, xnn_init_f32_minmax_params_fn init_params) const { 366*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 367*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 368*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.01f, 1.0f); 369*4bdc9457SAndroid Build Coastguard Worker 370*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 371*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(batch_size() + XNN_EXTRA_BYTES / sizeof(float)); 372*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y(batch_size() + (inplace_a() || inplace_b() ? XNN_EXTRA_BYTES / sizeof(float) : 0)); 373*4bdc9457SAndroid Build Coastguard Worker std::vector<float> y_ref(batch_size()); 374*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 375*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); }); 376*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); }); 377*4bdc9457SAndroid Build Coastguard Worker if (inplace_a() || inplace_b()) { 378*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), [&]() { return f32dist(rng); }); 379*4bdc9457SAndroid Build Coastguard Worker } else { 380*4bdc9457SAndroid Build Coastguard Worker std::fill(y.begin(), y.end(), nanf("")); 381*4bdc9457SAndroid Build Coastguard Worker } 382*4bdc9457SAndroid Build Coastguard Worker const float* a_data = inplace_a() ? y.data() : a.data(); 383*4bdc9457SAndroid Build Coastguard Worker const float* b_data = inplace_b() ? y.data() : b.data(); 384*4bdc9457SAndroid Build Coastguard Worker 385*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 386*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 387*4bdc9457SAndroid Build Coastguard Worker switch (op_type) { 388*4bdc9457SAndroid Build Coastguard Worker case OpType::Add: 389*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] + b_data[i]; 390*4bdc9457SAndroid Build Coastguard Worker break; 391*4bdc9457SAndroid Build Coastguard Worker case OpType::Div: 392*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] / b_data[i]; 393*4bdc9457SAndroid Build Coastguard Worker break; 394*4bdc9457SAndroid Build Coastguard Worker case OpType::Max: 395*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(a_data[i], b_data[i]); 396*4bdc9457SAndroid Build Coastguard Worker break; 397*4bdc9457SAndroid Build Coastguard Worker case OpType::Min: 398*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::min<float>(a_data[i], b_data[i]); 399*4bdc9457SAndroid Build Coastguard Worker break; 400*4bdc9457SAndroid Build Coastguard Worker case OpType::Mul: 401*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] * b_data[i]; 402*4bdc9457SAndroid Build Coastguard Worker break; 403*4bdc9457SAndroid Build Coastguard Worker case OpType::SqrDiff: 404*4bdc9457SAndroid Build Coastguard Worker { 405*4bdc9457SAndroid Build Coastguard Worker const float diff = a_data[i] - b_data[i]; 406*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = diff * diff; 407*4bdc9457SAndroid Build Coastguard Worker break; 408*4bdc9457SAndroid Build Coastguard Worker } 409*4bdc9457SAndroid Build Coastguard Worker case OpType::Sub: 410*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = a_data[i] - b_data[i]; 411*4bdc9457SAndroid Build Coastguard Worker break; 412*4bdc9457SAndroid Build Coastguard Worker } 413*4bdc9457SAndroid Build Coastguard Worker } 414*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend()); 415*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend()); 416*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 417*4bdc9457SAndroid Build Coastguard Worker const float y_max = accumulated_range > 0.0f ? 418*4bdc9457SAndroid Build Coastguard Worker (accumulated_max - accumulated_range / 255.0f * float(255 - qmax())) : 419*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity(); 420*4bdc9457SAndroid Build Coastguard Worker const float y_min = accumulated_range > 0.0f ? 421*4bdc9457SAndroid Build Coastguard Worker (accumulated_min + accumulated_range / 255.0f * float(qmin())) : 422*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity(); 423*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 424*4bdc9457SAndroid Build Coastguard Worker y_ref[i] = std::max<float>(std::min<float>(y_ref[i], y_max), y_min); 425*4bdc9457SAndroid Build Coastguard Worker } 426*4bdc9457SAndroid Build Coastguard Worker 427*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 428*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params; 429*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, y_min, y_max); 430*4bdc9457SAndroid Build Coastguard Worker 431*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 432*4bdc9457SAndroid Build Coastguard Worker vbinary_minmax(batch_size() * sizeof(float), a_data, b_data, y.data(), ¶ms); 433*4bdc9457SAndroid Build Coastguard Worker 434*4bdc9457SAndroid Build Coastguard Worker // Verify results. 435*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 436*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(y[i], y_ref[i], std::abs(y_ref[i]) * 1.0e-6f) 437*4bdc9457SAndroid Build Coastguard Worker << "at " << i << " / " << batch_size(); 438*4bdc9457SAndroid Build Coastguard Worker } 439*4bdc9457SAndroid Build Coastguard Worker } 440*4bdc9457SAndroid Build Coastguard Worker } 441*4bdc9457SAndroid Build Coastguard Worker 442*4bdc9457SAndroid Build Coastguard Worker private: 443*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 444*4bdc9457SAndroid Build Coastguard Worker bool inplace_a_{false}; 445*4bdc9457SAndroid Build Coastguard Worker bool inplace_b_{false}; 446*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 447*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 448*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 449*4bdc9457SAndroid Build Coastguard Worker }; 450