1 // Copyright (c) Facebook, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // Copyright 2019 Google LLC 5 // 6 // This source code is licensed under the BSD-style license found in the 7 // LICENSE file in the root directory of this source tree. 8 9 #pragma once 10 11 #include <gtest/gtest.h> 12 13 #include <algorithm> 14 #include <array> 15 #include <cassert> 16 #include <cstddef> 17 #include <cstdlib> 18 #include <functional> 19 #include <limits> 20 #include <random> 21 #include <vector> 22 23 #include <xnnpack.h> 24 #include <xnnpack/microfnptr.h> 25 26 27 class LUTMicrokernelTester { 28 public: batch_size(size_t batch_size)29 inline LUTMicrokernelTester& batch_size(size_t batch_size) { 30 assert(batch_size != 0); 31 this->batch_size_ = batch_size; 32 return *this; 33 } 34 batch_size()35 inline size_t batch_size() const { 36 return this->batch_size_; 37 } 38 inplace(bool inplace)39 inline LUTMicrokernelTester& inplace(bool inplace) { 40 this->inplace_ = inplace; 41 return *this; 42 } 43 inplace()44 inline bool inplace() const { 45 return this->inplace_; 46 } 47 iterations(size_t iterations)48 inline LUTMicrokernelTester& iterations(size_t iterations) { 49 this->iterations_ = iterations; 50 return *this; 51 } 52 iterations()53 inline size_t iterations() const { 54 return this->iterations_; 55 } 56 Test(xnn_x8_lut_ukernel_function lut)57 void Test(xnn_x8_lut_ukernel_function lut) const { 58 std::random_device random_device; 59 auto rng = std::mt19937(random_device()); 60 auto u8rng = std::bind( 61 std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng)); 62 63 std::vector<uint8_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 64 XNN_ALIGN(64) std::array<uint8_t, 256> t; 65 std::vector<uint8_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint8_t) : 0)); 66 std::vector<uint8_t> y_ref(batch_size()); 67 for (size_t iteration = 0; iteration < iterations(); iteration++) { 68 std::generate(x.begin(), x.end(), std::ref(u8rng)); 69 std::generate(t.begin(), t.end(), std::ref(u8rng)); 70 if (inplace()) { 71 std::generate(y.begin(), y.end(), std::ref(u8rng)); 72 } else { 73 std::fill(y.begin(), y.end(), 0xA5); 74 } 75 const uint8_t* x_data = inplace() ? y.data() : x.data(); 76 77 // Compute reference results. 78 for (size_t i = 0; i < batch_size(); i++) { 79 y_ref[i] = t[x_data[i]]; 80 } 81 82 // Call optimized micro-kernel. 83 lut(batch_size(), x_data, y.data(), t.data()); 84 85 // Verify results. 86 for (size_t i = 0; i < batch_size(); i++) { 87 ASSERT_EQ(uint32_t(y_ref[i]), uint32_t(y[i])) 88 << "at position " << i << " / " << batch_size(); 89 } 90 } 91 } 92 93 private: 94 size_t batch_size_{1}; 95 bool inplace_{false}; 96 size_t iterations_{15}; 97 }; 98