xref: /aosp_15_r20/external/XNNPACK/test/pad-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #pragma once
7 
#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/microfnptr.h>
21 
22 
23 class PadMicrokernelTester {
24  public:
rows(size_t rows)25   inline PadMicrokernelTester& rows(size_t rows) {
26     assert(rows != 0);
27     this->rows_ = rows;
28     return *this;
29   }
30 
rows()31   inline size_t rows() const {
32     return this->rows_;
33   }
34 
input_channels(size_t input_channels)35   inline PadMicrokernelTester& input_channels(size_t input_channels) {
36     assert(input_channels != 0);
37     this->input_channels_ = input_channels;
38     return *this;
39   }
40 
input_channels()41   inline size_t input_channels() const {
42     return this->input_channels_;
43   }
44 
pre_padding(size_t pre_padding)45   inline PadMicrokernelTester& pre_padding(size_t pre_padding) {
46     this->pre_padding_ = pre_padding;
47     return *this;
48   }
49 
pre_padding()50   inline size_t pre_padding() const {
51     return this->pre_padding_;
52   }
53 
post_padding(size_t post_padding)54   inline PadMicrokernelTester& post_padding(size_t post_padding) {
55     this->post_padding_ = post_padding;
56     return *this;
57   }
58 
post_padding()59   inline size_t post_padding() const {
60     return this->post_padding_;
61   }
62 
output_channels()63   inline size_t output_channels() const {
64     return pre_padding() + input_channels() + post_padding();
65   }
66 
input_stride(size_t input_stride)67   inline PadMicrokernelTester& input_stride(size_t input_stride) {
68     assert(input_stride != 0);
69     this->input_stride_ = input_stride;
70     return *this;
71   }
72 
input_stride()73   inline size_t input_stride() const {
74     if (this->input_stride_ == 0) {
75       return input_channels();
76     } else {
77       assert(this->input_stride_ >= input_channels());
78       return this->input_stride_;
79     }
80   }
81 
output_stride(size_t output_stride)82   inline PadMicrokernelTester& output_stride(size_t output_stride) {
83     assert(output_stride != 0);
84     this->output_stride_ = output_stride;
85     return *this;
86   }
87 
output_stride()88   inline size_t output_stride() const {
89     if (this->output_stride_ == 0) {
90       return pre_padding() + input_channels() + post_padding();
91     } else {
92       assert(this->output_stride_ >= pre_padding() + input_channels() + post_padding());
93       return this->output_stride_;
94     }
95   }
96 
iterations(size_t iterations)97   inline PadMicrokernelTester& iterations(size_t iterations) {
98     this->iterations_ = iterations;
99     return *this;
100   }
101 
iterations()102   inline size_t iterations() const {
103     return this->iterations_;
104   }
105 
Test(xnn_pad_ukernel_function pad)106   void Test(xnn_pad_ukernel_function pad) const {
107     std::random_device random_device;
108     auto rng = std::mt19937(random_device());
109     auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
110 
111     std::vector<uint8_t> input(input_channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(uint8_t));
112     std::vector<uint8_t> output((pre_padding() + input_channels() + post_padding()) + (rows() - 1) * output_stride());
113     for (size_t iteration = 0; iteration < iterations(); iteration++) {
114       std::generate(input.begin(), input.end(), std::ref(u8rng));
115       std::generate(output.begin(), output.end(), std::ref(u8rng));
116       std::array<uint8_t, 4> fill_pattern;
117       std::generate(fill_pattern.begin(), fill_pattern.end(), std::ref(u8rng));
118       uint32_t fill_value = 0;
119       memcpy(&fill_value, fill_pattern.data(), sizeof(fill_value));
120 
121       // Call optimized micro-kernel.
122       pad(
123         rows(),
124         input_channels() * sizeof(uint8_t),
125         pre_padding() * sizeof(uint8_t),
126         post_padding() * sizeof(uint8_t),
127         input.data(), input_stride() * sizeof(uint8_t),
128         output.data(), output_stride() * sizeof(uint8_t),
129         fill_value);
130 
131       // Verify results.
132       for (size_t i = 0; i < rows(); i++) {
133         for (size_t l = 0; l < pre_padding(); l++) {
134           ASSERT_EQ(
135               uint32_t(output[i * output_stride() + l]),
136               uint32_t(fill_pattern[l % fill_pattern.size()]))
137             << "at row " << i << " / " << rows() << ", channel " << i << " / " << output_channels()
138             << " (" << pre_padding() << " + " << input_channels() << " + " << post_padding() << ")"
139             << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
140             << ", output value 0x" << std::hex << std::setw(2) << std::setfill('0')
141             << uint32_t(output[i * output_stride() + l]);
142         }
143         for (size_t c = 0; c < input_channels(); c++) {
144           ASSERT_EQ(
145               uint32_t(output[i * output_stride() + pre_padding() + c]),
146               uint32_t(input[i * input_stride() + c]))
147             << "at row " << i << " / " << rows() << ", channel " << i << " / " << output_channels()
148             << " (" << pre_padding() << " + " << input_channels() << " + " << post_padding() << ")"
149             << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
150             << ", output value 0x" << std::hex << std::setw(2) << std::setfill('0')
151             << uint32_t(output[i * output_stride() + pre_padding() + c]);
152         }
153         for (size_t r = 0; r < post_padding(); r++) {
154           ASSERT_EQ(
155               uint32_t(output[i * output_stride() + pre_padding() + input_channels() + r]),
156               uint32_t(fill_pattern[r % fill_pattern.size()]))
157             << "at row " << i << " / " << rows() << ", channel " << i << " / " << output_channels()
158             << " (" << pre_padding() << " + " << input_channels() << " + " << post_padding() << ")"
159             << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
160             << ", output value 0x" << std::hex << std::setw(2) << std::setfill('0')
161             << uint32_t(output[i * output_stride() + pre_padding() + input_channels() + r]);
162         }
163       }
164     }
165   }
166 
167  private:
168   size_t rows_{1};
169   size_t input_channels_{1};
170   size_t pre_padding_{0};
171   size_t post_padding_{0};
172   size_t input_stride_{0};
173   size_t output_stride_{0};
174   size_t iterations_{15};
175 };
176