1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker class VLShiftMicrokernelTester { 23*4bdc9457SAndroid Build Coastguard Worker public: batch(size_t batch)24*4bdc9457SAndroid Build Coastguard Worker inline VLShiftMicrokernelTester& batch(size_t batch) { 25*4bdc9457SAndroid Build Coastguard Worker assert(batch != 0); 26*4bdc9457SAndroid Build Coastguard Worker this->batch_ = batch; 27*4bdc9457SAndroid Build Coastguard Worker return *this; 28*4bdc9457SAndroid Build Coastguard Worker } 29*4bdc9457SAndroid Build Coastguard Worker batch()30*4bdc9457SAndroid Build Coastguard Worker inline size_t batch() const { 31*4bdc9457SAndroid Build Coastguard Worker return this->batch_; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker shift(uint32_t shift)34*4bdc9457SAndroid Build Coastguard Worker inline VLShiftMicrokernelTester& shift(uint32_t shift) { 35*4bdc9457SAndroid Build Coastguard Worker assert(shift < 32); 36*4bdc9457SAndroid Build Coastguard Worker this->shift_ = shift; 37*4bdc9457SAndroid Build Coastguard Worker return *this; 38*4bdc9457SAndroid Build Coastguard Worker } 39*4bdc9457SAndroid Build Coastguard Worker shift()40*4bdc9457SAndroid Build Coastguard Worker inline uint32_t shift() const { 41*4bdc9457SAndroid Build Coastguard Worker return this->shift_; 42*4bdc9457SAndroid Build Coastguard Worker } 43*4bdc9457SAndroid Build Coastguard Worker inplace(bool inplace)44*4bdc9457SAndroid Build Coastguard Worker inline VLShiftMicrokernelTester& inplace(bool inplace) { 45*4bdc9457SAndroid Build Coastguard Worker this->inplace_ = inplace; 46*4bdc9457SAndroid Build Coastguard Worker return *this; 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker inplace()49*4bdc9457SAndroid Build Coastguard Worker inline bool inplace() const { 50*4bdc9457SAndroid Build Coastguard Worker return this->inplace_; 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)53*4bdc9457SAndroid Build Coastguard Worker inline VLShiftMicrokernelTester& iterations(size_t iterations) { 54*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 55*4bdc9457SAndroid Build Coastguard Worker return *this; 56*4bdc9457SAndroid Build Coastguard Worker } 57*4bdc9457SAndroid Build Coastguard Worker iterations()58*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 59*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker Test(xnn_s16_vlshift_ukernel_function vlshift)62*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_s16_vlshift_ukernel_function vlshift) const { 63*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 64*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 65*4bdc9457SAndroid Build Coastguard Worker auto i16rng = std::bind(std::uniform_int_distribution<int16_t>(), std::ref(rng)); 66*4bdc9457SAndroid Build Coastguard Worker 67*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t> x(batch() + XNN_EXTRA_BYTES / sizeof(int16_t)); 68*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t> y(batch() + (inplace() ? XNN_EXTRA_BYTES / sizeof(int16_t) : 0)); 69*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t> y_ref(batch()); 70*4bdc9457SAndroid Build Coastguard Worker const int16_t* x_data = inplace() ? y.data() : x.data(); 71*4bdc9457SAndroid Build Coastguard Worker 72*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 73*4bdc9457SAndroid Build Coastguard Worker std::generate(x.begin(), x.end(), std::ref(i16rng)); 74*4bdc9457SAndroid Build Coastguard Worker std::generate(y.begin(), y.end(), std::ref(i16rng)); 75*4bdc9457SAndroid Build Coastguard Worker std::generate(y_ref.begin(), y_ref.end(), std::ref(i16rng)); 76*4bdc9457SAndroid Build Coastguard Worker 77*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 78*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < batch(); n++) { 79*4bdc9457SAndroid Build Coastguard Worker const uint16_t i = static_cast<uint16_t>(x_data[n]); 80*4bdc9457SAndroid Build Coastguard Worker uint16_t value = i << shift(); 81*4bdc9457SAndroid Build Coastguard Worker y_ref[n] = reinterpret_cast<uint16_t>(value); 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker 84*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 85*4bdc9457SAndroid Build Coastguard Worker vlshift(batch(), x_data, y.data(), shift()); 86*4bdc9457SAndroid Build Coastguard Worker 87*4bdc9457SAndroid Build Coastguard Worker // Verify results. 88*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < batch(); n++) { 89*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(y[n], y_ref[n]) 90*4bdc9457SAndroid Build Coastguard Worker << ", shift " << shift() 91*4bdc9457SAndroid Build Coastguard Worker << ", batch " << n << " / " << batch(); 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker } 95*4bdc9457SAndroid Build Coastguard Worker 96*4bdc9457SAndroid Build Coastguard Worker private: 97*4bdc9457SAndroid Build Coastguard Worker size_t batch_{1}; 98*4bdc9457SAndroid Build Coastguard Worker uint32_t shift_{12}; 99*4bdc9457SAndroid Build Coastguard Worker bool inplace_{false}; 100*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 101*4bdc9457SAndroid Build Coastguard Worker }; 102