1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <functional> 15*4bdc9457SAndroid Build Coastguard Worker #include <limits> 16*4bdc9457SAndroid Build Coastguard Worker #include <random> 17*4bdc9457SAndroid Build Coastguard Worker #include <vector> 18*4bdc9457SAndroid Build Coastguard Worker 19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker class UnpoolingOperatorTester { 23*4bdc9457SAndroid Build Coastguard Worker public: padding(uint32_t padding)24*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding(uint32_t padding) { 25*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding; 26*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding; 27*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding; 28*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding; 29*4bdc9457SAndroid Build Coastguard Worker return *this; 30*4bdc9457SAndroid Build Coastguard Worker } 31*4bdc9457SAndroid Build Coastguard Worker 
padding(uint32_t padding_height,uint32_t padding_width)32*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) { 33*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 34*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 35*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 36*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 37*4bdc9457SAndroid Build Coastguard Worker return *this; 38*4bdc9457SAndroid Build Coastguard Worker } 39*4bdc9457SAndroid Build Coastguard Worker padding_height(uint32_t padding_height)40*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding_height(uint32_t padding_height) { 41*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 42*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 43*4bdc9457SAndroid Build Coastguard Worker return *this; 44*4bdc9457SAndroid Build Coastguard Worker } 45*4bdc9457SAndroid Build Coastguard Worker padding_width(uint32_t padding_width)46*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding_width(uint32_t padding_width) { 47*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 48*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 49*4bdc9457SAndroid Build Coastguard Worker return *this; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)52*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding_top(uint32_t padding_top) { 53*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 54*4bdc9457SAndroid Build Coastguard Worker return *this; 55*4bdc9457SAndroid Build Coastguard Worker } 56*4bdc9457SAndroid Build Coastguard Worker 
padding_top()57*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 58*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)61*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding_right(uint32_t padding_right) { 62*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 63*4bdc9457SAndroid Build Coastguard Worker return *this; 64*4bdc9457SAndroid Build Coastguard Worker } 65*4bdc9457SAndroid Build Coastguard Worker padding_right()66*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 67*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)70*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding_bottom(uint32_t padding_bottom) { 71*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 72*4bdc9457SAndroid Build Coastguard Worker return *this; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker padding_bottom()75*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 76*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 77*4bdc9457SAndroid Build Coastguard Worker } 78*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)79*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& padding_left(uint32_t padding_left) { 80*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 81*4bdc9457SAndroid Build Coastguard Worker return *this; 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker padding_left()84*4bdc9457SAndroid Build Coastguard Worker inline 
uint32_t padding_left() const { 85*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 86*4bdc9457SAndroid Build Coastguard Worker } 87*4bdc9457SAndroid Build Coastguard Worker input_size(size_t input_height,size_t input_width)88*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& input_size(size_t input_height, size_t input_width) { 89*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 90*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 91*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 92*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 93*4bdc9457SAndroid Build Coastguard Worker return *this; 94*4bdc9457SAndroid Build Coastguard Worker } 95*4bdc9457SAndroid Build Coastguard Worker input_height(size_t input_height)96*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& input_height(size_t input_height) { 97*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 98*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 99*4bdc9457SAndroid Build Coastguard Worker return *this; 100*4bdc9457SAndroid Build Coastguard Worker } 101*4bdc9457SAndroid Build Coastguard Worker input_height()102*4bdc9457SAndroid Build Coastguard Worker inline size_t input_height() const { 103*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 104*4bdc9457SAndroid Build Coastguard Worker } 105*4bdc9457SAndroid Build Coastguard Worker input_width(size_t input_width)106*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& input_width(size_t input_width) { 107*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 108*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 109*4bdc9457SAndroid Build Coastguard Worker return *this; 110*4bdc9457SAndroid Build Coastguard Worker } 111*4bdc9457SAndroid Build Coastguard Worker 
input_width()112*4bdc9457SAndroid Build Coastguard Worker inline size_t input_width() const { 113*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)116*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& channels(size_t channels) { 117*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 118*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 119*4bdc9457SAndroid Build Coastguard Worker return *this; 120*4bdc9457SAndroid Build Coastguard Worker } 121*4bdc9457SAndroid Build Coastguard Worker channels()122*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 123*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 124*4bdc9457SAndroid Build Coastguard Worker } 125*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)126*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& batch_size(size_t batch_size) { 127*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 128*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 129*4bdc9457SAndroid Build Coastguard Worker return *this; 130*4bdc9457SAndroid Build Coastguard Worker } 131*4bdc9457SAndroid Build Coastguard Worker batch_size()132*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 133*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 134*4bdc9457SAndroid Build Coastguard Worker } 135*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_size)136*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& pooling_size(uint32_t pooling_size) { 137*4bdc9457SAndroid Build Coastguard Worker assert(pooling_size >= 1); 138*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_size; 139*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = 
pooling_size; 140*4bdc9457SAndroid Build Coastguard Worker return *this; 141*4bdc9457SAndroid Build Coastguard Worker } 142*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_height,uint32_t pooling_width)143*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& pooling_size(uint32_t pooling_height, uint32_t pooling_width) { 144*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 145*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 146*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 147*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 148*4bdc9457SAndroid Build Coastguard Worker return *this; 149*4bdc9457SAndroid Build Coastguard Worker } 150*4bdc9457SAndroid Build Coastguard Worker pooling_height(uint32_t pooling_height)151*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& pooling_height(uint32_t pooling_height) { 152*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 153*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 154*4bdc9457SAndroid Build Coastguard Worker return *this; 155*4bdc9457SAndroid Build Coastguard Worker } 156*4bdc9457SAndroid Build Coastguard Worker pooling_height()157*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_height() const { 158*4bdc9457SAndroid Build Coastguard Worker return this->pooling_height_; 159*4bdc9457SAndroid Build Coastguard Worker } 160*4bdc9457SAndroid Build Coastguard Worker pooling_width(uint32_t pooling_width)161*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& pooling_width(uint32_t pooling_width) { 162*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 163*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 164*4bdc9457SAndroid Build Coastguard Worker return *this; 165*4bdc9457SAndroid Build Coastguard Worker } 
166*4bdc9457SAndroid Build Coastguard Worker pooling_width()167*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_width() const { 168*4bdc9457SAndroid Build Coastguard Worker return this->pooling_width_; 169*4bdc9457SAndroid Build Coastguard Worker } 170*4bdc9457SAndroid Build Coastguard Worker output_height()171*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 172*4bdc9457SAndroid Build Coastguard Worker const size_t padding_height = padding_top() + padding_bottom(); 173*4bdc9457SAndroid Build Coastguard Worker return std::max<size_t>(input_height() * pooling_height(), padding_height) - padding_height; 174*4bdc9457SAndroid Build Coastguard Worker } 175*4bdc9457SAndroid Build Coastguard Worker output_width()176*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 177*4bdc9457SAndroid Build Coastguard Worker const size_t padding_width = padding_left() + padding_right(); 178*4bdc9457SAndroid Build Coastguard Worker return std::max<size_t>(input_width() * pooling_width(), padding_width) - padding_width; 179*4bdc9457SAndroid Build Coastguard Worker } 180*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(size_t input_pixel_stride)181*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& input_pixel_stride(size_t input_pixel_stride) { 182*4bdc9457SAndroid Build Coastguard Worker assert(input_pixel_stride != 0); 183*4bdc9457SAndroid Build Coastguard Worker this->input_pixel_stride_ = input_pixel_stride; 184*4bdc9457SAndroid Build Coastguard Worker return *this; 185*4bdc9457SAndroid Build Coastguard Worker } 186*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride()187*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const { 188*4bdc9457SAndroid Build Coastguard Worker if (this->input_pixel_stride_ == 0) { 189*4bdc9457SAndroid Build Coastguard Worker return channels(); 190*4bdc9457SAndroid Build Coastguard Worker } else { 191*4bdc9457SAndroid 
Build Coastguard Worker assert(this->input_pixel_stride_ >= channels()); 192*4bdc9457SAndroid Build Coastguard Worker return this->input_pixel_stride_; 193*4bdc9457SAndroid Build Coastguard Worker } 194*4bdc9457SAndroid Build Coastguard Worker } 195*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride(size_t output_pixel_stride)196*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& output_pixel_stride(size_t output_pixel_stride) { 197*4bdc9457SAndroid Build Coastguard Worker assert(output_pixel_stride != 0); 198*4bdc9457SAndroid Build Coastguard Worker this->output_pixel_stride_ = output_pixel_stride; 199*4bdc9457SAndroid Build Coastguard Worker return *this; 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride()202*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const { 203*4bdc9457SAndroid Build Coastguard Worker if (this->output_pixel_stride_ == 0) { 204*4bdc9457SAndroid Build Coastguard Worker return channels(); 205*4bdc9457SAndroid Build Coastguard Worker } else { 206*4bdc9457SAndroid Build Coastguard Worker assert(this->output_pixel_stride_ >= channels()); 207*4bdc9457SAndroid Build Coastguard Worker return this->output_pixel_stride_; 208*4bdc9457SAndroid Build Coastguard Worker } 209*4bdc9457SAndroid Build Coastguard Worker } 210*4bdc9457SAndroid Build Coastguard Worker next_input_size(uint32_t next_input_height,uint32_t next_input_width)211*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) { 212*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 213*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 214*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 215*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 216*4bdc9457SAndroid Build 
Coastguard Worker return *this; 217*4bdc9457SAndroid Build Coastguard Worker } 218*4bdc9457SAndroid Build Coastguard Worker next_input_height(uint32_t next_input_height)219*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& next_input_height(uint32_t next_input_height) { 220*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 221*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 222*4bdc9457SAndroid Build Coastguard Worker return *this; 223*4bdc9457SAndroid Build Coastguard Worker } 224*4bdc9457SAndroid Build Coastguard Worker next_input_height()225*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_height() const { 226*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) { 227*4bdc9457SAndroid Build Coastguard Worker return input_height(); 228*4bdc9457SAndroid Build Coastguard Worker } else { 229*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_; 230*4bdc9457SAndroid Build Coastguard Worker } 231*4bdc9457SAndroid Build Coastguard Worker } 232*4bdc9457SAndroid Build Coastguard Worker next_input_width(uint32_t next_input_width)233*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& next_input_width(uint32_t next_input_width) { 234*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 235*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 236*4bdc9457SAndroid Build Coastguard Worker return *this; 237*4bdc9457SAndroid Build Coastguard Worker } 238*4bdc9457SAndroid Build Coastguard Worker next_input_width()239*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const { 240*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) { 241*4bdc9457SAndroid Build Coastguard Worker return input_width(); 242*4bdc9457SAndroid Build Coastguard Worker } else { 243*4bdc9457SAndroid Build Coastguard Worker return 
this->next_input_width_; 244*4bdc9457SAndroid Build Coastguard Worker } 245*4bdc9457SAndroid Build Coastguard Worker } 246*4bdc9457SAndroid Build Coastguard Worker next_output_height()247*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_height() const { 248*4bdc9457SAndroid Build Coastguard Worker const size_t padding_height = padding_top() + padding_bottom(); 249*4bdc9457SAndroid Build Coastguard Worker return std::max<size_t>(next_input_height() * pooling_height(), padding_height) - padding_height; 250*4bdc9457SAndroid Build Coastguard Worker } 251*4bdc9457SAndroid Build Coastguard Worker next_output_width()252*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_width() const { 253*4bdc9457SAndroid Build Coastguard Worker const size_t padding_width = padding_left() + padding_right(); 254*4bdc9457SAndroid Build Coastguard Worker return std::max<size_t>(next_input_width() * pooling_width(), padding_width) - padding_width; 255*4bdc9457SAndroid Build Coastguard Worker } 256*4bdc9457SAndroid Build Coastguard Worker next_batch_size(size_t next_batch_size)257*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& next_batch_size(size_t next_batch_size) { 258*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1); 259*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size; 260*4bdc9457SAndroid Build Coastguard Worker return *this; 261*4bdc9457SAndroid Build Coastguard Worker } 262*4bdc9457SAndroid Build Coastguard Worker next_batch_size()263*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const { 264*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) { 265*4bdc9457SAndroid Build Coastguard Worker return batch_size(); 266*4bdc9457SAndroid Build Coastguard Worker } else { 267*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_; 268*4bdc9457SAndroid Build Coastguard Worker } 269*4bdc9457SAndroid Build Coastguard 
Worker } 270*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)271*4bdc9457SAndroid Build Coastguard Worker inline UnpoolingOperatorTester& iterations(size_t iterations) { 272*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 273*4bdc9457SAndroid Build Coastguard Worker return *this; 274*4bdc9457SAndroid Build Coastguard Worker } 275*4bdc9457SAndroid Build Coastguard Worker iterations()276*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 277*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 278*4bdc9457SAndroid Build Coastguard Worker } 279*4bdc9457SAndroid Build Coastguard Worker TestX32()280*4bdc9457SAndroid Build Coastguard Worker void TestX32() const { 281*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 282*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 283*4bdc9457SAndroid Build Coastguard Worker auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), std::ref(rng)); 284*4bdc9457SAndroid Build Coastguard Worker auto idx_rng = std::bind(std::uniform_int_distribution<uint32_t>(0, pooling_height() * pooling_width() - 1), std::ref(rng)); 285*4bdc9457SAndroid Build Coastguard Worker 286*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels()); 287*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> index(batch_size() * input_height() * input_width() * channels()); 288*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 289*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output_ref(batch_size() * output_height() * output_width() * channels()); 290*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 291*4bdc9457SAndroid 
Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(u32rng)); 292*4bdc9457SAndroid Build Coastguard Worker std::generate(index.begin(), index.end(), std::ref(idx_rng)); 293*4bdc9457SAndroid Build Coastguard Worker std::generate(output.begin(), output.end(), std::ref(u32rng)); 294*4bdc9457SAndroid Build Coastguard Worker 295*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 296*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0); 297*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 298*4bdc9457SAndroid Build Coastguard Worker for (size_t iy = 0; iy < input_height(); iy++) { 299*4bdc9457SAndroid Build Coastguard Worker for (size_t ix = 0; ix < input_width(); ix++) { 300*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 301*4bdc9457SAndroid Build Coastguard Worker const uint32_t pooling_index = index[((i * input_height() + iy) * input_width() + ix) * channels() + c]; 302*4bdc9457SAndroid Build Coastguard Worker const uint32_t py = pooling_index % pooling_height(); 303*4bdc9457SAndroid Build Coastguard Worker const uint32_t px = pooling_index / pooling_height(); 304*4bdc9457SAndroid Build Coastguard Worker const size_t oy = std::min<size_t>(std::max<size_t>(iy * pooling_height() + py, padding_top()) - padding_top(), output_height() - 1); 305*4bdc9457SAndroid Build Coastguard Worker const size_t ox = std::min<size_t>(std::max<size_t>(ix * pooling_width() + px, padding_left()) - padding_left(), output_width() - 1); 306*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = 307*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]; 308*4bdc9457SAndroid Build Coastguard Worker } 309*4bdc9457SAndroid Build Coastguard Worker } 310*4bdc9457SAndroid Build Coastguard Worker } 
311*4bdc9457SAndroid Build Coastguard Worker } 312*4bdc9457SAndroid Build Coastguard Worker 313*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Unpooling operator. 314*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 315*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t unpooling_op = nullptr; 316*4bdc9457SAndroid Build Coastguard Worker 317*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 318*4bdc9457SAndroid Build Coastguard Worker xnn_create_unpooling2d_nhwc_x32( 319*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 320*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 321*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 322*4bdc9457SAndroid Build Coastguard Worker 0, &unpooling_op)); 323*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, unpooling_op); 324*4bdc9457SAndroid Build Coastguard Worker 325*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete unpooling_op. 
326*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_unpooling_op(unpooling_op, xnn_delete_operator); 327*4bdc9457SAndroid Build Coastguard Worker 328*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 329*4bdc9457SAndroid Build Coastguard Worker xnn_setup_unpooling2d_nhwc_x32( 330*4bdc9457SAndroid Build Coastguard Worker unpooling_op, 331*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 332*4bdc9457SAndroid Build Coastguard Worker input.data(), index.data(), output.data(), 333*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 334*4bdc9457SAndroid Build Coastguard Worker 335*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 336*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(unpooling_op, nullptr /* thread pool */)); 337*4bdc9457SAndroid Build Coastguard Worker 338*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
339*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 340*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 341*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 342*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 343*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 344*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) << 345*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 346*4bdc9457SAndroid Build Coastguard Worker } 347*4bdc9457SAndroid Build Coastguard Worker } 348*4bdc9457SAndroid Build Coastguard Worker } 349*4bdc9457SAndroid Build Coastguard Worker } 350*4bdc9457SAndroid Build Coastguard Worker } 351*4bdc9457SAndroid Build Coastguard Worker } 352*4bdc9457SAndroid Build Coastguard Worker TestSetupX32()353*4bdc9457SAndroid Build Coastguard Worker void TestSetupX32() const { 354*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 355*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 356*4bdc9457SAndroid Build Coastguard Worker auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), std::ref(rng)); 357*4bdc9457SAndroid Build Coastguard Worker auto idx_rng = std::bind(std::uniform_int_distribution<uint32_t>(0, pooling_height() * pooling_width() - 1), std::ref(rng)); 358*4bdc9457SAndroid Build Coastguard Worker 359*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> input(std::max<size_t>( 360*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 361*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * 
next_input_width() - 1) * input_pixel_stride() + channels())); 362*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> index(std::max<size_t>( 363*4bdc9457SAndroid Build Coastguard Worker batch_size() * input_height() * input_width() * channels(), 364*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * next_input_height() * next_input_width() * channels())); 365*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output(std::max<size_t>( 366*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 367*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() * channels())); 368*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> output_ref(batch_size() * output_height() * output_width() * channels()); 369*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 370*4bdc9457SAndroid Build Coastguard Worker 371*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 372*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(u32rng)); 373*4bdc9457SAndroid Build Coastguard Worker std::generate(index.begin(), index.end(), std::ref(idx_rng)); 374*4bdc9457SAndroid Build Coastguard Worker std::generate(output.begin(), output.end(), std::ref(u32rng)); 375*4bdc9457SAndroid Build Coastguard Worker 376*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 
377*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0); 378*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 379*4bdc9457SAndroid Build Coastguard Worker for (size_t iy = 0; iy < input_height(); iy++) { 380*4bdc9457SAndroid Build Coastguard Worker for (size_t ix = 0; ix < input_width(); ix++) { 381*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 382*4bdc9457SAndroid Build Coastguard Worker const uint32_t pooling_index = index[((i * input_height() + iy) * input_width() + ix) * channels() + c]; 383*4bdc9457SAndroid Build Coastguard Worker const uint32_t py = pooling_index % pooling_height(); 384*4bdc9457SAndroid Build Coastguard Worker const uint32_t px = pooling_index / pooling_height(); 385*4bdc9457SAndroid Build Coastguard Worker const size_t oy = std::min<size_t>(std::max<size_t>(iy * pooling_height() + py, padding_top()) - padding_top(), output_height() - 1); 386*4bdc9457SAndroid Build Coastguard Worker const size_t ox = std::min<size_t>(std::max<size_t>(ix * pooling_width() + px, padding_left()) - padding_left(), output_width() - 1); 387*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = 388*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]; 389*4bdc9457SAndroid Build Coastguard Worker } 390*4bdc9457SAndroid Build Coastguard Worker } 391*4bdc9457SAndroid Build Coastguard Worker } 392*4bdc9457SAndroid Build Coastguard Worker } 393*4bdc9457SAndroid Build Coastguard Worker 394*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Unpooling operator once. 
395*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 396*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t unpooling_op = nullptr; 397*4bdc9457SAndroid Build Coastguard Worker 398*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 399*4bdc9457SAndroid Build Coastguard Worker xnn_create_unpooling2d_nhwc_x32( 400*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 401*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 402*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 403*4bdc9457SAndroid Build Coastguard Worker 0, &unpooling_op)); 404*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, unpooling_op); 405*4bdc9457SAndroid Build Coastguard Worker 406*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete unpooling_op. 407*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_unpooling_op(unpooling_op, xnn_delete_operator); 408*4bdc9457SAndroid Build Coastguard Worker 409*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 410*4bdc9457SAndroid Build Coastguard Worker xnn_setup_unpooling2d_nhwc_x32( 411*4bdc9457SAndroid Build Coastguard Worker unpooling_op, 412*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 413*4bdc9457SAndroid Build Coastguard Worker input.data(), index.data(), output.data(), 414*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 415*4bdc9457SAndroid Build Coastguard Worker 416*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 417*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(unpooling_op, nullptr /* thread pool */)); 418*4bdc9457SAndroid Build Coastguard Worker 419*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first 
run. 420*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 421*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 422*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 423*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 424*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 425*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) << 426*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 427*4bdc9457SAndroid Build Coastguard Worker } 428*4bdc9457SAndroid Build Coastguard Worker } 429*4bdc9457SAndroid Build Coastguard Worker } 430*4bdc9457SAndroid Build Coastguard Worker } 431*4bdc9457SAndroid Build Coastguard Worker 432*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 433*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), std::ref(u32rng)); 434*4bdc9457SAndroid Build Coastguard Worker std::generate(index.begin(), index.end(), std::ref(idx_rng)); 435*4bdc9457SAndroid Build Coastguard Worker std::generate(output.begin(), output.end(), std::ref(u32rng)); 436*4bdc9457SAndroid Build Coastguard Worker 437*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping. 
438*4bdc9457SAndroid Build Coastguard Worker std::fill(next_output_ref.begin(), next_output_ref.end(), 0); 439*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 440*4bdc9457SAndroid Build Coastguard Worker for (size_t iy = 0; iy < next_input_height(); iy++) { 441*4bdc9457SAndroid Build Coastguard Worker for (size_t ix = 0; ix < next_input_width(); ix++) { 442*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 443*4bdc9457SAndroid Build Coastguard Worker const uint32_t pooling_index = index[((i * next_input_height() + iy) * next_input_width() + ix) * channels() + c]; 444*4bdc9457SAndroid Build Coastguard Worker const uint32_t py = pooling_index % pooling_height(); 445*4bdc9457SAndroid Build Coastguard Worker const uint32_t px = pooling_index / pooling_height(); 446*4bdc9457SAndroid Build Coastguard Worker const size_t oy = std::min<size_t>(std::max<size_t>(iy * pooling_height() + py, padding_top()) - padding_top(), next_output_height() - 1); 447*4bdc9457SAndroid Build Coastguard Worker const size_t ox = std::min<size_t>(std::max<size_t>(ix * pooling_width() + px, padding_left()) - padding_left(), next_output_width() - 1); 448*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = 449*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]; 450*4bdc9457SAndroid Build Coastguard Worker } 451*4bdc9457SAndroid Build Coastguard Worker } 452*4bdc9457SAndroid Build Coastguard Worker } 453*4bdc9457SAndroid Build Coastguard Worker } 454*4bdc9457SAndroid Build Coastguard Worker 455*4bdc9457SAndroid Build Coastguard Worker // Setup and run Unpooling operator the second time, and destroy the operator. 
456*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 457*4bdc9457SAndroid Build Coastguard Worker xnn_setup_unpooling2d_nhwc_x32( 458*4bdc9457SAndroid Build Coastguard Worker unpooling_op, 459*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 460*4bdc9457SAndroid Build Coastguard Worker input.data(), index.data(), output.data(), 461*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 462*4bdc9457SAndroid Build Coastguard Worker 463*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 464*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(unpooling_op, nullptr /* thread pool */)); 465*4bdc9457SAndroid Build Coastguard Worker 466*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 467*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 468*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 469*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 470*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 471*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 472*4bdc9457SAndroid Build Coastguard Worker output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]) << 473*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 474*4bdc9457SAndroid Build Coastguard Worker } 475*4bdc9457SAndroid Build Coastguard Worker } 476*4bdc9457SAndroid Build Coastguard Worker } 477*4bdc9457SAndroid Build Coastguard Worker } 478*4bdc9457SAndroid Build Coastguard Worker } 479*4bdc9457SAndroid Build Coastguard Worker } 480*4bdc9457SAndroid Build Coastguard Worker 481*4bdc9457SAndroid Build Coastguard Worker 
// Private configuration state of the tester; every field carries an in-class default.
// - padding_top_/padding_right_/padding_bottom_/padding_left_: default 0; padding(uint32_t)
//   assigns the same value to all four.
// - input_height_, input_width_, channels_, batch_size_, pooling_height_, pooling_width_: default 1.
// - input_pixel_stride_, output_pixel_stride_, next_input_height_, next_input_width_,
//   next_batch_size_: default 0.
//   NOTE(review): 0 looks like a "not explicitly set" sentinel that the corresponding accessors
//   resolve to another value -- confirm against the accessor definitions.
// - iterations_: default 1; presumably the number of test repetitions -- TODO confirm.
private: 482*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 483*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 484*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 485*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 486*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 487*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 488*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 489*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 490*4bdc9457SAndroid Build Coastguard Worker size_t input_pixel_stride_{0}; 491*4bdc9457SAndroid Build Coastguard Worker size_t output_pixel_stride_{0}; 492*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_height_{1}; 493*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_width_{1}; 494*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0}; 495*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0}; 496*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0}; 497*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 498*4bdc9457SAndroid Build Coastguard Worker }; 499