xref: /aosp_15_r20/external/XNNPACK/test/fill-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/microfnptr.h>
21 
22 
23 class FillMicrokernelTester {
24  public:
rows(size_t rows)25   inline FillMicrokernelTester& rows(size_t rows) {
26     assert(rows != 0);
27     this->rows_ = rows;
28     return *this;
29   }
30 
rows()31   inline size_t rows() const {
32     return this->rows_;
33   }
34 
channels(size_t channels)35   inline FillMicrokernelTester& channels(size_t channels) {
36     assert(channels != 0);
37     this->channels_ = channels;
38     return *this;
39   }
40 
channels()41   inline size_t channels() const {
42     return this->channels_;
43   }
44 
output_stride(size_t output_stride)45   inline FillMicrokernelTester& output_stride(size_t output_stride) {
46     assert(output_stride != 0);
47     this->output_stride_ = output_stride;
48     return *this;
49   }
50 
output_stride()51   inline size_t output_stride() const {
52     if (this->output_stride_ == 0) {
53       return channels();
54     } else {
55       return this->output_stride_;
56     }
57   }
58 
iterations(size_t iterations)59   inline FillMicrokernelTester& iterations(size_t iterations) {
60     this->iterations_ = iterations;
61     return *this;
62   }
63 
iterations()64   inline size_t iterations() const {
65     return this->iterations_;
66   }
67 
Test(xnn_fill_ukernel_function fill)68   void Test(xnn_fill_ukernel_function fill) const {
69     ASSERT_GE(output_stride(), channels());
70 
71     std::random_device random_device;
72     auto rng = std::mt19937(random_device());
73     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
74 
75     std::vector<uint8_t> output((rows() - 1) * output_stride() + channels());
76     std::vector<uint8_t> output_copy(output.size());
77     for (size_t iteration = 0; iteration < iterations(); iteration++) {
78       std::generate(output.begin(), output.end(), std::ref(u8rng));
79       std::copy(output.cbegin(), output.cend(), output_copy.begin());
80       std::array<uint8_t, 4> fill_pattern;
81       std::generate(fill_pattern.begin(), fill_pattern.end(), std::ref(u8rng));
82       uint32_t fill_value = 0;
83       memcpy(&fill_value, fill_pattern.data(), sizeof(fill_value));
84 
85       // Call optimized micro-kernel.
86       fill(
87         rows(),
88         channels() * sizeof(uint8_t),
89         output.data(),
90         output_stride() * sizeof(uint8_t),
91         fill_value);
92 
93       // Verify results.
94       for (size_t i = 0; i < rows(); i++) {
95         for (size_t c = 0; c < channels(); c++) {
96           ASSERT_EQ(uint32_t(output[i * output_stride() + c]), uint32_t(fill_pattern[c % fill_pattern.size()]))
97             << "at row " << i << " / " << rows()
98             << ", channel " << c << " / " << channels()
99             << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
100             << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
101         }
102       }
103       for (size_t i = 0; i + 1 < rows(); i++) {
104         for (size_t c = channels(); c < output_stride(); c++) {
105           ASSERT_EQ(uint32_t(output[i * output_stride() + c]), uint32_t(output_copy[i * output_stride() + c]))
106             << "at row " << i << " / " << rows()
107             << ", channel " << c << " / " << channels()
108             << ", original value 0x" << std::hex << std::setw(8) << std::setfill('0')
109             << output_copy[i * output_stride() + c]
110             << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
111         }
112       }
113     }
114   }
115 
116  private:
117   size_t rows_{1};
118   size_t channels_{1};
119   size_t output_stride_{0};
120   size_t iterations_{15};
121 };
122