1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2021 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <numeric> 9*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 10*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 11*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstring> 13*4bdc9457SAndroid Build Coastguard Worker #include <vector> 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 16*4bdc9457SAndroid Build Coastguard Worker 17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker class TransposeMicrokernelTester { 22*4bdc9457SAndroid Build Coastguard Worker public: element_size(size_t element_size)23*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& element_size(size_t element_size) { 24*4bdc9457SAndroid Build Coastguard Worker assert(element_size != 0); 25*4bdc9457SAndroid Build Coastguard Worker this->element_size_ = element_size; 26*4bdc9457SAndroid Build Coastguard Worker return *this; 27*4bdc9457SAndroid Build Coastguard Worker } 28*4bdc9457SAndroid Build Coastguard Worker element_size()29*4bdc9457SAndroid Build Coastguard Worker inline size_t element_size() const { return this->element_size_; } 30*4bdc9457SAndroid Build Coastguard Worker block_height(size_t block_height)31*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& block_height(size_t block_height) { 32*4bdc9457SAndroid Build Coastguard Worker assert(block_height != 0); 33*4bdc9457SAndroid Build Coastguard Worker this->block_height_ = block_height; 34*4bdc9457SAndroid Build Coastguard Worker return *this; 35*4bdc9457SAndroid Build Coastguard Worker } 36*4bdc9457SAndroid Build Coastguard Worker block_height()37*4bdc9457SAndroid Build Coastguard Worker inline size_t block_height() const { return this->block_height_; } 38*4bdc9457SAndroid Build Coastguard Worker block_width(size_t block_width)39*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& block_width(size_t block_width) { 40*4bdc9457SAndroid Build Coastguard Worker assert(block_width != 0); 41*4bdc9457SAndroid Build Coastguard Worker this->block_width_ = block_width; 42*4bdc9457SAndroid Build Coastguard Worker return *this; 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker block_width()45*4bdc9457SAndroid Build Coastguard Worker inline size_t block_width() const { return this->block_width_; } 46*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)47*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& input_stride(size_t input_stride) { 48*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 49*4bdc9457SAndroid Build Coastguard Worker return *this; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker input_stride()52*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { return this->input_stride_; } 53*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)54*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& output_stride(size_t output_stride) { 55*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 56*4bdc9457SAndroid Build Coastguard Worker return *this; 57*4bdc9457SAndroid Build Coastguard Worker } 58*4bdc9457SAndroid Build Coastguard Worker output_stride()59*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { return this->output_stride_; } 60*4bdc9457SAndroid Build Coastguard Worker input_element_stride(size_t input_element_stride)61*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& input_element_stride(size_t input_element_stride) { 62*4bdc9457SAndroid Build Coastguard Worker assert(input_element_stride >= element_size_); 63*4bdc9457SAndroid Build Coastguard Worker this->input_element_stride_ = input_element_stride; 64*4bdc9457SAndroid Build Coastguard Worker return *this; 65*4bdc9457SAndroid Build Coastguard Worker } 66*4bdc9457SAndroid Build Coastguard Worker input_element_stride()67*4bdc9457SAndroid Build Coastguard Worker inline size_t input_element_stride() const { 68*4bdc9457SAndroid Build Coastguard Worker if (input_element_stride_ == 0) { 69*4bdc9457SAndroid Build Coastguard Worker return element_size_; 70*4bdc9457SAndroid Build Coastguard Worker } else { 71*4bdc9457SAndroid Build Coastguard Worker return input_element_stride_; 72*4bdc9457SAndroid Build Coastguard Worker } 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker output_element_stride(size_t output_element_stride)75*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& output_element_stride(size_t output_element_stride) { 76*4bdc9457SAndroid Build Coastguard Worker assert(output_element_stride >= element_size_); 77*4bdc9457SAndroid Build Coastguard Worker this->output_element_stride_ = output_element_stride; 78*4bdc9457SAndroid Build Coastguard Worker return *this; 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker output_element_stride()81*4bdc9457SAndroid Build Coastguard Worker inline size_t output_element_stride() const { 82*4bdc9457SAndroid Build Coastguard Worker if (output_element_stride_ == 0) { 83*4bdc9457SAndroid Build Coastguard Worker return element_size_; 84*4bdc9457SAndroid Build Coastguard Worker } else { 85*4bdc9457SAndroid Build Coastguard Worker return output_element_stride_; 86*4bdc9457SAndroid Build Coastguard Worker } 87*4bdc9457SAndroid Build Coastguard Worker } 88*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)89*4bdc9457SAndroid Build Coastguard Worker inline TransposeMicrokernelTester& iterations(size_t iterations) { 90*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 91*4bdc9457SAndroid Build Coastguard Worker return *this; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker iterations()94*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { return this->iterations_; } 95*4bdc9457SAndroid Build Coastguard Worker Test(xnn_transposev_ukernel_function transpose)96*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_transposev_ukernel_function transpose) const { 97*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(input_stride() * block_height() * input_element_stride() + XNN_EXTRA_BYTES); 98*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(output_stride() * block_width() * output_element_stride()); 99*4bdc9457SAndroid Build Coastguard Worker std::iota(input.begin(), input.end(), 0); 100*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 101*4bdc9457SAndroid Build Coastguard Worker 102*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 103*4bdc9457SAndroid Build Coastguard Worker transpose(input.data(), 104*4bdc9457SAndroid Build Coastguard Worker output.data(), 105*4bdc9457SAndroid Build Coastguard Worker input_stride() * input_element_stride(), 106*4bdc9457SAndroid Build Coastguard Worker output_stride() * output_element_stride(), 107*4bdc9457SAndroid Build Coastguard Worker input_element_stride(), 108*4bdc9457SAndroid Build Coastguard Worker output_element_stride(), 109*4bdc9457SAndroid Build Coastguard Worker element_size(), 110*4bdc9457SAndroid Build Coastguard Worker block_width(), 111*4bdc9457SAndroid Build Coastguard Worker block_height()); 112*4bdc9457SAndroid Build Coastguard Worker 113*4bdc9457SAndroid Build Coastguard Worker // Verify results. 114*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < block_width(); c++) { 115*4bdc9457SAndroid Build Coastguard Worker for (size_t r = 0; r < block_height(); r++) { 116*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(std::memcmp(&input[input_element_stride() * (c+ r * input_stride())], 117*4bdc9457SAndroid Build Coastguard Worker &output[output_element_stride() * (r + c * output_stride())], 118*4bdc9457SAndroid Build Coastguard Worker element_size()), 0) 119*4bdc9457SAndroid Build Coastguard Worker << "at row " << r << " / " << block_height() 120*4bdc9457SAndroid Build Coastguard Worker << ", at column " << c << " / " << block_width(); 121*4bdc9457SAndroid Build Coastguard Worker } 122*4bdc9457SAndroid Build Coastguard Worker } 123*4bdc9457SAndroid Build Coastguard Worker } 124*4bdc9457SAndroid Build Coastguard Worker Test(xnn_x64_transposec_ukernel_function transpose)125*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_x64_transposec_ukernel_function transpose) const { 126*4bdc9457SAndroid Build Coastguard Worker std::vector<uint64_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES / sizeof(uint64_t)); 127*4bdc9457SAndroid Build Coastguard Worker std::vector<uint64_t> output(input_stride() * output_stride()); 128*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 129*4bdc9457SAndroid Build Coastguard Worker std::iota(input.begin(), input.end(), 0); 130*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT64_C(0xBADC0FFEE0DDF00D)); 131*4bdc9457SAndroid Build Coastguard Worker 132*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 133*4bdc9457SAndroid Build Coastguard Worker transpose(input.data(), 134*4bdc9457SAndroid Build Coastguard Worker output.data(), 135*4bdc9457SAndroid Build Coastguard Worker input_stride() * sizeof(uint64_t), 136*4bdc9457SAndroid Build Coastguard Worker output_stride() * sizeof(uint64_t), 137*4bdc9457SAndroid Build Coastguard Worker block_width(), 138*4bdc9457SAndroid Build Coastguard Worker block_height()); 139*4bdc9457SAndroid Build Coastguard Worker 140*4bdc9457SAndroid Build Coastguard Worker // Verify results. 141*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < block_width(); c++) { 142*4bdc9457SAndroid Build Coastguard Worker for (size_t r = 0; r < block_height(); r++) { 143*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(input[c + r * input_stride()], output[r + c * output_stride()]) 144*4bdc9457SAndroid Build Coastguard Worker << "at row " << r << " / " << block_height() 145*4bdc9457SAndroid Build Coastguard Worker << ", at column " << c << " / " << block_width(); 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker } 148*4bdc9457SAndroid Build Coastguard Worker } 149*4bdc9457SAndroid Build Coastguard Worker } 150*4bdc9457SAndroid Build Coastguard Worker Test(xnn_x32_transposec_ukernel_function transpose)151*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_x32_transposec_ukernel_function transpose) const { 152*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES / sizeof(uint32_t)); 153*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output(input_stride() * output_stride()); 154*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 155*4bdc9457SAndroid Build Coastguard Worker std::iota(input.begin(), input.end(), 0); 156*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF)); 157*4bdc9457SAndroid Build Coastguard Worker 158*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 159*4bdc9457SAndroid Build Coastguard Worker transpose(input.data(), 160*4bdc9457SAndroid Build Coastguard Worker output.data(), 161*4bdc9457SAndroid Build Coastguard Worker input_stride() * sizeof(uint32_t), 162*4bdc9457SAndroid Build Coastguard Worker output_stride() * sizeof(uint32_t), 163*4bdc9457SAndroid Build Coastguard Worker block_width(), 164*4bdc9457SAndroid Build Coastguard Worker block_height()); 165*4bdc9457SAndroid Build Coastguard Worker 166*4bdc9457SAndroid Build Coastguard Worker // Verify results. 167*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < block_width(); c++) { 168*4bdc9457SAndroid Build Coastguard Worker for (size_t r = 0; r < block_height(); r++) { 169*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(input[c + r * input_stride()], output[r + c * output_stride()]) 170*4bdc9457SAndroid Build Coastguard Worker << "at row " << r << " / " << block_height() 171*4bdc9457SAndroid Build Coastguard Worker << ", at column " << c << " / " << block_width(); 172*4bdc9457SAndroid Build Coastguard Worker } 173*4bdc9457SAndroid Build Coastguard Worker } 174*4bdc9457SAndroid Build Coastguard Worker } 175*4bdc9457SAndroid Build Coastguard Worker } 176*4bdc9457SAndroid Build Coastguard Worker Test(xnn_x24_transposec_ukernel_function transpose)177*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_x24_transposec_ukernel_function transpose) const { 178*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(input_stride() * output_stride() * element_size() + XNN_EXTRA_BYTES); 179*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(input_stride() * output_stride() * element_size()); 180*4bdc9457SAndroid Build Coastguard Worker std::iota(input.begin(), input.end(), 0); 181*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 182*4bdc9457SAndroid Build Coastguard Worker 183*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 184*4bdc9457SAndroid Build Coastguard Worker transpose(input.data(), 185*4bdc9457SAndroid Build Coastguard Worker output.data(), 186*4bdc9457SAndroid Build Coastguard Worker input_stride() * element_size(), 187*4bdc9457SAndroid Build Coastguard Worker output_stride() * element_size(), 188*4bdc9457SAndroid Build Coastguard Worker block_width(), 189*4bdc9457SAndroid Build Coastguard Worker block_height()); 190*4bdc9457SAndroid Build Coastguard Worker 191*4bdc9457SAndroid Build Coastguard Worker // Verify results. 192*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < block_width(); c++) { 193*4bdc9457SAndroid Build Coastguard Worker for (size_t r = 0; r < block_height(); r++) { 194*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(std::memcmp(&input[element_size() * (c+ r * input_stride())], 195*4bdc9457SAndroid Build Coastguard Worker &output[element_size() * (r + c * output_stride())], 196*4bdc9457SAndroid Build Coastguard Worker element_size()), 0) 197*4bdc9457SAndroid Build Coastguard Worker << "at row " << r << " / " << block_height() 198*4bdc9457SAndroid Build Coastguard Worker << ", at column " << c << " / " << block_width(); 199*4bdc9457SAndroid Build Coastguard Worker } 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker Test(xnn_x16_transposec_ukernel_function transpose)203*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_x16_transposec_ukernel_function transpose) const { 204*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 205*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(input_stride() * output_stride()); 206*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 207*4bdc9457SAndroid Build Coastguard Worker std::iota(input.begin(), input.end(), 0); 208*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0xDEAD)); 209*4bdc9457SAndroid Build Coastguard Worker 210*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 211*4bdc9457SAndroid Build Coastguard Worker transpose(input.data(), 212*4bdc9457SAndroid Build Coastguard Worker output.data(), 213*4bdc9457SAndroid Build Coastguard Worker input_stride() * sizeof(uint16_t), 214*4bdc9457SAndroid Build Coastguard Worker output_stride() * sizeof(uint16_t), 215*4bdc9457SAndroid Build Coastguard Worker block_width(), 216*4bdc9457SAndroid Build Coastguard Worker block_height()); 217*4bdc9457SAndroid Build Coastguard Worker 218*4bdc9457SAndroid Build Coastguard Worker // Verify results. 219*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < block_width(); c++) { 220*4bdc9457SAndroid Build Coastguard Worker for (size_t r = 0; r < block_height(); r++) { 221*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(input[c + r * input_stride()], output[r + c * output_stride()]) 222*4bdc9457SAndroid Build Coastguard Worker << "at row " << r << " / " << block_height() 223*4bdc9457SAndroid Build Coastguard Worker << ", at column " << c << " / " << block_width(); 224*4bdc9457SAndroid Build Coastguard Worker } 225*4bdc9457SAndroid Build Coastguard Worker } 226*4bdc9457SAndroid Build Coastguard Worker } 227*4bdc9457SAndroid Build Coastguard Worker } 228*4bdc9457SAndroid Build Coastguard Worker Test(xnn_x8_transposec_ukernel_function transpose)229*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_x8_transposec_ukernel_function transpose) const { 230*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(input_stride() * output_stride() + XNN_EXTRA_BYTES); 231*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(input_stride() * output_stride()); 232*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 233*4bdc9457SAndroid Build Coastguard Worker std::iota(input.begin(), input.end(), 0); 234*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 235*4bdc9457SAndroid Build Coastguard Worker 236*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 237*4bdc9457SAndroid Build Coastguard Worker transpose(input.data(), 238*4bdc9457SAndroid Build Coastguard Worker output.data(), 239*4bdc9457SAndroid Build Coastguard Worker input_stride() * sizeof(uint8_t), 240*4bdc9457SAndroid Build Coastguard Worker output_stride() * sizeof(uint8_t), 241*4bdc9457SAndroid Build Coastguard Worker block_width(), 242*4bdc9457SAndroid Build Coastguard Worker block_height()); 243*4bdc9457SAndroid Build Coastguard Worker 244*4bdc9457SAndroid Build Coastguard Worker // Verify results. 245*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < block_width(); c++) { 246*4bdc9457SAndroid Build Coastguard Worker for (size_t r = 0; r < block_height(); r++) { 247*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ((int)input[c + r * input_stride()], (int)output[r + c * output_stride()]) 248*4bdc9457SAndroid Build Coastguard Worker << "at row " << r << " / " << block_height() 249*4bdc9457SAndroid Build Coastguard Worker << ", at column " << c << " / " << block_width(); 250*4bdc9457SAndroid Build Coastguard Worker } 251*4bdc9457SAndroid Build Coastguard Worker } 252*4bdc9457SAndroid Build Coastguard Worker } 253*4bdc9457SAndroid Build Coastguard Worker } 254*4bdc9457SAndroid Build Coastguard Worker 255*4bdc9457SAndroid Build Coastguard Worker private: 256*4bdc9457SAndroid Build Coastguard Worker size_t element_size_ = 1; 257*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_ = 1; 258*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_ = 1; 259*4bdc9457SAndroid Build Coastguard Worker size_t input_element_stride_ = 0; 260*4bdc9457SAndroid Build Coastguard Worker size_t output_element_stride_ = 0; 261*4bdc9457SAndroid Build Coastguard Worker size_t block_height_ = 1; 262*4bdc9457SAndroid Build Coastguard Worker size_t block_width_ = 1; 263*4bdc9457SAndroid Build Coastguard Worker size_t iterations_ = 15; 264*4bdc9457SAndroid Build Coastguard Worker }; 265