xref: /aosp_15_r20/external/XNNPACK/test/lut-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <gtest/gtest.h>
12 
13 #include <algorithm>
14 #include <array>
15 #include <cassert>
16 #include <cstddef>
17 #include <cstdlib>
18 #include <functional>
19 #include <limits>
20 #include <random>
21 #include <vector>
22 
23 #include <xnnpack.h>
24 #include <xnnpack/microfnptr.h>
25 
26 
27 class LUTMicrokernelTester {
28  public:
batch_size(size_t batch_size)29   inline LUTMicrokernelTester& batch_size(size_t batch_size) {
30     assert(batch_size != 0);
31     this->batch_size_ = batch_size;
32     return *this;
33   }
34 
batch_size()35   inline size_t batch_size() const {
36     return this->batch_size_;
37   }
38 
inplace(bool inplace)39   inline LUTMicrokernelTester& inplace(bool inplace) {
40     this->inplace_ = inplace;
41     return *this;
42   }
43 
inplace()44   inline bool inplace() const {
45     return this->inplace_;
46   }
47 
iterations(size_t iterations)48   inline LUTMicrokernelTester& iterations(size_t iterations) {
49     this->iterations_ = iterations;
50     return *this;
51   }
52 
iterations()53   inline size_t iterations() const {
54     return this->iterations_;
55   }
56 
Test(xnn_x8_lut_ukernel_function lut)57   void Test(xnn_x8_lut_ukernel_function lut) const {
58     std::random_device random_device;
59     auto rng = std::mt19937(random_device());
60     auto u8rng = std::bind(
61       std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
62 
63     std::vector<uint8_t> x(batch_size() + XNN_EXTRA_BYTES / sizeof(uint8_t));
64     XNN_ALIGN(64) std::array<uint8_t, 256> t;
65     std::vector<uint8_t> y(batch_size() + (inplace() ? XNN_EXTRA_BYTES / sizeof(uint8_t) : 0));
66     std::vector<uint8_t> y_ref(batch_size());
67     for (size_t iteration = 0; iteration < iterations(); iteration++) {
68       std::generate(x.begin(), x.end(), std::ref(u8rng));
69       std::generate(t.begin(), t.end(), std::ref(u8rng));
70       if (inplace()) {
71         std::generate(y.begin(), y.end(), std::ref(u8rng));
72       } else {
73         std::fill(y.begin(), y.end(), 0xA5);
74       }
75       const uint8_t* x_data = inplace() ? y.data() : x.data();
76 
77       // Compute reference results.
78       for (size_t i = 0; i < batch_size(); i++) {
79         y_ref[i] = t[x_data[i]];
80       }
81 
82       // Call optimized micro-kernel.
83       lut(batch_size(), x_data, y.data(), t.data());
84 
85       // Verify results.
86       for (size_t i = 0; i < batch_size(); i++) {
87         ASSERT_EQ(uint32_t(y_ref[i]), uint32_t(y[i]))
88           << "at position " << i << " / " << batch_size();
89       }
90     }
91   }
92 
93  private:
94   size_t batch_size_{1};
95   bool inplace_{false};
96   size_t iterations_{15};
97 };
98