1 // Copyright 2019 Google LLC 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #pragma once 7 8 #include <gtest/gtest.h> 9 10 #include <array> 11 #include <algorithm> 12 #include <cassert> 13 #include <cstddef> 14 #include <cstdlib> 15 #include <functional> 16 #include <random> 17 #include <vector> 18 19 #include <xnnpack.h> 20 #include <xnnpack/microfnptr.h> 21 22 23 class PadMicrokernelTester { 24 public: rows(size_t rows)25 inline PadMicrokernelTester& rows(size_t rows) { 26 assert(rows != 0); 27 this->rows_ = rows; 28 return *this; 29 } 30 rows()31 inline size_t rows() const { 32 return this->rows_; 33 } 34 input_channels(size_t input_channels)35 inline PadMicrokernelTester& input_channels(size_t input_channels) { 36 assert(input_channels != 0); 37 this->input_channels_ = input_channels; 38 return *this; 39 } 40 input_channels()41 inline size_t input_channels() const { 42 return this->input_channels_; 43 } 44 pre_padding(size_t pre_padding)45 inline PadMicrokernelTester& pre_padding(size_t pre_padding) { 46 this->pre_padding_ = pre_padding; 47 return *this; 48 } 49 pre_padding()50 inline size_t pre_padding() const { 51 return this->pre_padding_; 52 } 53 post_padding(size_t post_padding)54 inline PadMicrokernelTester& post_padding(size_t post_padding) { 55 this->post_padding_ = post_padding; 56 return *this; 57 } 58 post_padding()59 inline size_t post_padding() const { 60 return this->post_padding_; 61 } 62 output_channels()63 inline size_t output_channels() const { 64 return pre_padding() + input_channels() + post_padding(); 65 } 66 input_stride(size_t input_stride)67 inline PadMicrokernelTester& input_stride(size_t input_stride) { 68 assert(input_stride != 0); 69 this->input_stride_ = input_stride; 70 return *this; 71 } 72 input_stride()73 inline size_t input_stride() const { 74 if (this->input_stride_ == 0) { 75 return input_channels(); 76 } else { 77 assert(this->input_stride_ >= input_channels()); 78 return this->input_stride_; 79 } 80 } 81 output_stride(size_t output_stride)82 inline PadMicrokernelTester& output_stride(size_t output_stride) { 83 assert(output_stride != 0); 84 this->output_stride_ = output_stride; 85 return *this; 86 } 87 output_stride()88 inline size_t output_stride() const { 89 if (this->output_stride_ == 0) { 90 return pre_padding() + input_channels() + post_padding(); 91 } else { 92 assert(this->output_stride_ >= pre_padding() + input_channels() + post_padding()); 93 return this->output_stride_; 94 } 95 } 96 iterations(size_t iterations)97 inline PadMicrokernelTester& iterations(size_t iterations) { 98 this->iterations_ = iterations; 99 return *this; 100 } 101 iterations()102 inline size_t iterations() const { 103 return this->iterations_; 104 } 105 Test(xnn_pad_ukernel_function pad)106 void Test(xnn_pad_ukernel_function pad) const { 107 std::random_device random_device; 108 auto rng = std::mt19937(random_device()); 109 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng); 110 111 std::vector<uint8_t> input(input_channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 112 std::vector<uint8_t> output((pre_padding() + input_channels() + post_padding()) + (rows() - 1) * output_stride()); 113 for (size_t iteration = 0; iteration < iterations(); iteration++) { 114 std::generate(input.begin(), input.end(), std::ref(u8rng)); 115 std::generate(output.begin(), output.end(), std::ref(u8rng)); 116 std::array<uint8_t, 4> fill_pattern; 117 std::generate(fill_pattern.begin(), fill_pattern.end(), std::ref(u8rng)); 118 uint32_t fill_value = 0; 119 memcpy(&fill_value, fill_pattern.data(), sizeof(fill_value)); 120 121 // Call optimized micro-kernel. 122 pad( 123 rows(), 124 input_channels() * sizeof(uint8_t), 125 pre_padding() * sizeof(uint8_t), 126 post_padding() * sizeof(uint8_t), 127 input.data(), input_stride() * sizeof(uint8_t), 128 output.data(), output_stride() * sizeof(uint8_t), 129 fill_value); 130 131 // Verify results. 132 for (size_t i = 0; i < rows(); i++) { 133 for (size_t l = 0; l < pre_padding(); l++) { 134 ASSERT_EQ( 135 uint32_t(output[i * output_stride() + l]), 136 uint32_t(fill_pattern[l % fill_pattern.size()])) 137 << "at row " << i << " / " << rows() << ", channel " << i << " / " << output_channels() 138 << " (" << pre_padding() << " + " << input_channels() << " + " << post_padding() << ")" 139 << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value 140 << ", output value 0x" << std::hex << std::setw(2) << std::setfill('0') 141 << uint32_t(output[i * output_stride() + l]); 142 } 143 for (size_t c = 0; c < input_channels(); c++) { 144 ASSERT_EQ( 145 uint32_t(output[i * output_stride() + pre_padding() + c]), 146 uint32_t(input[i * input_stride() + c])) 147 << "at row " << i << " / " << rows() << ", channel " << i << " / " << output_channels() 148 << " (" << pre_padding() << " + " << input_channels() << " + " << post_padding() << ")" 149 << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value 150 << ", output value 0x" << std::hex << std::setw(2) << std::setfill('0') 151 << uint32_t(output[i * output_stride() + pre_padding() + c]); 152 } 153 for (size_t r = 0; r < post_padding(); r++) { 154 ASSERT_EQ( 155 uint32_t(output[i * output_stride() + pre_padding() + input_channels() + r]), 156 uint32_t(fill_pattern[r % fill_pattern.size()])) 157 << "at row " << i << " / " << rows() << ", channel " << i << " / " << output_channels() 158 << " (" << pre_padding() << " + " << input_channels() << " + " << post_padding() << ")" 159 << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value 160 << ", output value 0x" << std::hex << std::setw(2) << std::setfill('0') 161 << uint32_t(output[i * output_stride() + pre_padding() + input_channels() + r]); 162 } 163 } 164 } 165 } 166 167 private: 168 size_t rows_{1}; 169 size_t input_channels_{1}; 170 size_t pre_padding_{0}; 171 size_t post_padding_{0}; 172 size_t input_stride_{0}; 173 size_t output_stride_{0}; 174 size_t iterations_{15}; 175 }; 176