1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 11*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 12*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 13*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 14*4bdc9457SAndroid Build Coastguard Worker #include <limits> 15*4bdc9457SAndroid Build Coastguard Worker #include <random> 16*4bdc9457SAndroid Build Coastguard Worker #include <vector> 17*4bdc9457SAndroid Build Coastguard Worker 18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 19*4bdc9457SAndroid Build Coastguard Worker 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker class ArgmaxPoolingOperatorTester { 22*4bdc9457SAndroid Build Coastguard Worker public: padding_tf_same(bool padding_same)23*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_tf_same(bool padding_same) { 24*4bdc9457SAndroid Build Coastguard Worker if (padding_same) { 25*4bdc9457SAndroid Build Coastguard Worker assert(padding_top() == 0); 26*4bdc9457SAndroid Build Coastguard Worker assert(padding_left() == 0); 27*4bdc9457SAndroid Build Coastguard Worker assert(padding_bottom() == 0); 28*4bdc9457SAndroid Build Coastguard Worker assert(padding_right() == 0); 29*4bdc9457SAndroid Build Coastguard Worker } 30*4bdc9457SAndroid Build Coastguard Worker this->padding_tf_same_ = padding_same; 31*4bdc9457SAndroid Build Coastguard Worker return *this; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker padding_tf_same()34*4bdc9457SAndroid Build Coastguard Worker inline bool padding_tf_same() const { 35*4bdc9457SAndroid Build Coastguard Worker return this->padding_tf_same_; 36*4bdc9457SAndroid Build Coastguard Worker } 37*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding)38*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding(uint32_t padding) { 39*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 40*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding; 41*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding; 42*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding; 43*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding; 44*4bdc9457SAndroid Build Coastguard Worker return *this; 45*4bdc9457SAndroid Build Coastguard Worker } 46*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding_height,uint32_t padding_width)47*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) { 48*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 49*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 50*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 51*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 52*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 53*4bdc9457SAndroid Build Coastguard Worker return *this; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker padding_height(uint32_t padding_height)56*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_height(uint32_t padding_height) { 57*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 58*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 59*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 60*4bdc9457SAndroid Build Coastguard Worker return *this; 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker padding_width(uint32_t padding_width)63*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_width(uint32_t padding_width) { 64*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 65*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 66*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 67*4bdc9457SAndroid Build Coastguard Worker return *this; 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)70*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_top(uint32_t padding_top) { 71*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 72*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 73*4bdc9457SAndroid Build Coastguard Worker return *this; 74*4bdc9457SAndroid Build Coastguard Worker } 75*4bdc9457SAndroid Build Coastguard Worker padding_top()76*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 77*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 78*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = output_height() * pooling_height() - input_height(); 79*4bdc9457SAndroid Build Coastguard Worker return total_padding_height / 2; 80*4bdc9457SAndroid Build Coastguard Worker } else { 81*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker } 84*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)85*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_left(uint32_t padding_left) { 86*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 87*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 88*4bdc9457SAndroid Build Coastguard Worker return *this; 89*4bdc9457SAndroid Build Coastguard Worker } 90*4bdc9457SAndroid Build Coastguard Worker padding_left()91*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { 92*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 93*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = output_width() * pooling_width() - input_width(); 94*4bdc9457SAndroid Build Coastguard Worker return total_padding_width / 2; 95*4bdc9457SAndroid Build Coastguard Worker } else { 96*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)100*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_bottom(uint32_t padding_bottom) { 101*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 102*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 103*4bdc9457SAndroid Build Coastguard Worker return *this; 104*4bdc9457SAndroid Build Coastguard Worker } 105*4bdc9457SAndroid Build Coastguard Worker padding_bottom()106*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 107*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 108*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = output_height() * pooling_height() - input_height(); 109*4bdc9457SAndroid Build Coastguard Worker return total_padding_height - total_padding_height / 2; 110*4bdc9457SAndroid Build Coastguard Worker } else { 111*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker } 114*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)115*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& padding_right(uint32_t padding_right) { 116*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 117*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 118*4bdc9457SAndroid Build Coastguard Worker return *this; 119*4bdc9457SAndroid Build Coastguard Worker } 120*4bdc9457SAndroid Build Coastguard Worker padding_right()121*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 122*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 123*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = output_width() * pooling_width() - input_width(); 124*4bdc9457SAndroid Build Coastguard Worker return total_padding_width - total_padding_width / 2; 125*4bdc9457SAndroid Build Coastguard Worker } else { 126*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 127*4bdc9457SAndroid Build Coastguard Worker } 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker input_size(size_t input_height,size_t input_width)130*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& input_size(size_t input_height, size_t input_width) { 131*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 132*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 133*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 134*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 135*4bdc9457SAndroid Build Coastguard Worker return *this; 136*4bdc9457SAndroid Build Coastguard Worker } 137*4bdc9457SAndroid Build Coastguard Worker input_height(size_t input_height)138*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& input_height(size_t input_height) { 139*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 140*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 141*4bdc9457SAndroid Build Coastguard Worker return *this; 142*4bdc9457SAndroid Build Coastguard Worker } 143*4bdc9457SAndroid Build Coastguard Worker input_height()144*4bdc9457SAndroid Build Coastguard Worker inline size_t input_height() const { 145*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker input_width(size_t input_width)148*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& input_width(size_t input_width) { 149*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 150*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 151*4bdc9457SAndroid Build Coastguard Worker return *this; 152*4bdc9457SAndroid Build Coastguard Worker } 153*4bdc9457SAndroid Build Coastguard Worker input_width()154*4bdc9457SAndroid Build Coastguard Worker inline size_t input_width() const { 155*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 156*4bdc9457SAndroid Build Coastguard Worker } 157*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)158*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& channels(size_t channels) { 159*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 160*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 161*4bdc9457SAndroid Build Coastguard Worker return *this; 162*4bdc9457SAndroid Build Coastguard Worker } 163*4bdc9457SAndroid Build Coastguard Worker channels()164*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 165*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 166*4bdc9457SAndroid Build Coastguard Worker } 167*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)168*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& batch_size(size_t batch_size) { 169*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 170*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 171*4bdc9457SAndroid Build Coastguard Worker return *this; 172*4bdc9457SAndroid Build Coastguard Worker } 173*4bdc9457SAndroid Build Coastguard Worker batch_size()174*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 175*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 176*4bdc9457SAndroid Build Coastguard Worker } 177*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_size)178*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& pooling_size(uint32_t pooling_size) { 179*4bdc9457SAndroid Build Coastguard Worker assert(pooling_size >= 1); 180*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_size; 181*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_size; 182*4bdc9457SAndroid Build Coastguard Worker return *this; 183*4bdc9457SAndroid Build Coastguard Worker } 184*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_height,uint32_t pooling_width)185*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& pooling_size(uint32_t pooling_height, uint32_t pooling_width) { 186*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 187*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 188*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 189*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 190*4bdc9457SAndroid Build Coastguard Worker return *this; 191*4bdc9457SAndroid Build Coastguard Worker } 192*4bdc9457SAndroid Build Coastguard Worker pooling_height(uint32_t pooling_height)193*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& pooling_height(uint32_t pooling_height) { 194*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 195*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 196*4bdc9457SAndroid Build Coastguard Worker return *this; 197*4bdc9457SAndroid Build Coastguard Worker } 198*4bdc9457SAndroid Build Coastguard Worker pooling_height()199*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_height() const { 200*4bdc9457SAndroid Build Coastguard Worker return this->pooling_height_; 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker pooling_width(uint32_t pooling_width)203*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& pooling_width(uint32_t pooling_width) { 204*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 205*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 206*4bdc9457SAndroid Build Coastguard Worker return *this; 207*4bdc9457SAndroid Build Coastguard Worker } 208*4bdc9457SAndroid Build Coastguard Worker pooling_width()209*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_width() const { 210*4bdc9457SAndroid Build Coastguard Worker return this->pooling_width_; 211*4bdc9457SAndroid Build Coastguard Worker } 212*4bdc9457SAndroid Build Coastguard Worker output_height()213*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 214*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 215*4bdc9457SAndroid Build Coastguard Worker return (input_height() + pooling_height() - 1) / pooling_height(); 216*4bdc9457SAndroid Build Coastguard Worker } else { 217*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_height = padding_top() + input_height() + padding_bottom(); 218*4bdc9457SAndroid Build Coastguard Worker return padded_input_height / pooling_height(); 219*4bdc9457SAndroid Build Coastguard Worker } 220*4bdc9457SAndroid Build Coastguard Worker } 221*4bdc9457SAndroid Build Coastguard Worker output_width()222*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 223*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 224*4bdc9457SAndroid Build Coastguard Worker return (input_width() + pooling_width() - 1) / pooling_width(); 225*4bdc9457SAndroid Build Coastguard Worker } else { 226*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_width = padding_left() + input_width() + padding_right(); 227*4bdc9457SAndroid Build Coastguard Worker return padded_input_width / pooling_width(); 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker } 230*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(size_t input_pixel_stride)231*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& input_pixel_stride(size_t input_pixel_stride) { 232*4bdc9457SAndroid Build Coastguard Worker assert(input_pixel_stride != 0); 233*4bdc9457SAndroid Build Coastguard Worker this->input_pixel_stride_ = input_pixel_stride; 234*4bdc9457SAndroid Build Coastguard Worker return *this; 235*4bdc9457SAndroid Build Coastguard Worker } 236*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride()237*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const { 238*4bdc9457SAndroid Build Coastguard Worker if (this->input_pixel_stride_ == 0) { 239*4bdc9457SAndroid Build Coastguard Worker return channels(); 240*4bdc9457SAndroid Build Coastguard Worker } else { 241*4bdc9457SAndroid Build Coastguard Worker assert(this->input_pixel_stride_ >= channels()); 242*4bdc9457SAndroid Build Coastguard Worker return this->input_pixel_stride_; 243*4bdc9457SAndroid Build Coastguard Worker } 244*4bdc9457SAndroid Build Coastguard Worker } 245*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride(size_t output_pixel_stride)246*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& output_pixel_stride(size_t output_pixel_stride) { 247*4bdc9457SAndroid Build Coastguard Worker assert(output_pixel_stride != 0); 248*4bdc9457SAndroid Build Coastguard Worker this->output_pixel_stride_ = output_pixel_stride; 249*4bdc9457SAndroid Build Coastguard Worker return *this; 250*4bdc9457SAndroid Build Coastguard Worker } 251*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride()252*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const { 253*4bdc9457SAndroid Build Coastguard Worker if (this->output_pixel_stride_ == 0) { 254*4bdc9457SAndroid Build Coastguard Worker return channels(); 255*4bdc9457SAndroid Build Coastguard Worker } else { 256*4bdc9457SAndroid Build Coastguard Worker assert(this->output_pixel_stride_ >= channels()); 257*4bdc9457SAndroid Build Coastguard Worker return this->output_pixel_stride_; 258*4bdc9457SAndroid Build Coastguard Worker } 259*4bdc9457SAndroid Build Coastguard Worker } 260*4bdc9457SAndroid Build Coastguard Worker next_input_size(uint32_t next_input_height,uint32_t next_input_width)261*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) { 262*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 263*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 264*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 265*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 266*4bdc9457SAndroid Build Coastguard Worker return *this; 267*4bdc9457SAndroid Build Coastguard Worker } 268*4bdc9457SAndroid Build Coastguard Worker next_input_height(uint32_t next_input_height)269*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& next_input_height(uint32_t next_input_height) { 270*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 271*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 272*4bdc9457SAndroid Build Coastguard Worker return *this; 273*4bdc9457SAndroid Build Coastguard Worker } 274*4bdc9457SAndroid Build Coastguard Worker next_input_height()275*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_height() const { 276*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) { 277*4bdc9457SAndroid Build Coastguard Worker return input_height(); 278*4bdc9457SAndroid Build Coastguard Worker } else { 279*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_; 280*4bdc9457SAndroid Build Coastguard Worker } 281*4bdc9457SAndroid Build Coastguard Worker } 282*4bdc9457SAndroid Build Coastguard Worker next_input_width(uint32_t next_input_width)283*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& next_input_width(uint32_t next_input_width) { 284*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 285*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 286*4bdc9457SAndroid Build Coastguard Worker return *this; 287*4bdc9457SAndroid Build Coastguard Worker } 288*4bdc9457SAndroid Build Coastguard Worker next_input_width()289*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const { 290*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) { 291*4bdc9457SAndroid Build Coastguard Worker return input_width(); 292*4bdc9457SAndroid Build Coastguard Worker } else { 293*4bdc9457SAndroid Build Coastguard Worker return this->next_input_width_; 294*4bdc9457SAndroid Build Coastguard Worker } 295*4bdc9457SAndroid Build Coastguard Worker } 296*4bdc9457SAndroid Build Coastguard Worker next_output_height()297*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_height() const { 298*4bdc9457SAndroid Build Coastguard Worker const size_t padded_next_input_height = padding_top() + next_input_height() + padding_bottom(); 299*4bdc9457SAndroid Build Coastguard Worker return padded_next_input_height / pooling_height(); 300*4bdc9457SAndroid Build Coastguard Worker } 301*4bdc9457SAndroid Build Coastguard Worker next_output_width()302*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_width() const { 303*4bdc9457SAndroid Build Coastguard Worker const size_t padded_next_input_width = padding_left() + next_input_width() + padding_right(); 304*4bdc9457SAndroid Build Coastguard Worker return padded_next_input_width / pooling_width(); 305*4bdc9457SAndroid Build Coastguard Worker } 306*4bdc9457SAndroid Build Coastguard Worker next_batch_size(size_t next_batch_size)307*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& next_batch_size(size_t next_batch_size) { 308*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1); 309*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size; 310*4bdc9457SAndroid Build Coastguard Worker return *this; 311*4bdc9457SAndroid Build Coastguard Worker } 312*4bdc9457SAndroid Build Coastguard Worker next_batch_size()313*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const { 314*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) { 315*4bdc9457SAndroid Build Coastguard Worker return batch_size(); 316*4bdc9457SAndroid Build Coastguard Worker } else { 317*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_; 318*4bdc9457SAndroid Build Coastguard Worker } 319*4bdc9457SAndroid Build Coastguard Worker } 320*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)321*4bdc9457SAndroid Build Coastguard Worker inline ArgmaxPoolingOperatorTester& iterations(size_t iterations) { 322*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 323*4bdc9457SAndroid Build Coastguard Worker return *this; 324*4bdc9457SAndroid Build Coastguard Worker } 325*4bdc9457SAndroid Build Coastguard Worker iterations()326*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 327*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 328*4bdc9457SAndroid Build Coastguard Worker } 329*4bdc9457SAndroid Build Coastguard Worker TestF32()330*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 331*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 332*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 333*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 334*4bdc9457SAndroid Build Coastguard Worker 335*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 336*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 337*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 338*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> index(batch_size() * output_height() * output_width() * channels()); 339*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> index_ref(batch_size() * output_height() * output_width() * channels()); 340*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 341*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 342*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 343*4bdc9457SAndroid Build Coastguard Worker 344*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 345*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 346*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 347*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 348*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 349*4bdc9457SAndroid Build Coastguard Worker const size_t iy_top_left = std::max<size_t>(oy * pooling_height(), padding_top()) - padding_top(); 350*4bdc9457SAndroid Build Coastguard Worker const size_t ix_top_left = std::max<size_t>(ox * pooling_width(), padding_left()) - padding_left(); 351*4bdc9457SAndroid Build Coastguard Worker float max_value = 352*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy_top_left) * input_width() + ix_top_left) * input_pixel_stride() + c]; 353*4bdc9457SAndroid Build Coastguard Worker uint32_t max_index = 0; 354*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 355*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * pooling_height() + py - padding_top(); 356*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 357*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * pooling_width() + px - padding_left(); 358*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 359*4bdc9457SAndroid Build Coastguard Worker const float value = input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]; 360*4bdc9457SAndroid Build Coastguard Worker if (value > max_value) { 361*4bdc9457SAndroid Build Coastguard Worker max_value = value; 362*4bdc9457SAndroid Build Coastguard Worker max_index = uint32_t(px * pooling_height() + py); 363*4bdc9457SAndroid Build Coastguard Worker } 364*4bdc9457SAndroid Build Coastguard Worker } 365*4bdc9457SAndroid Build Coastguard Worker } 366*4bdc9457SAndroid Build Coastguard Worker } 367*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 368*4bdc9457SAndroid Build Coastguard Worker index_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_index; 369*4bdc9457SAndroid Build Coastguard Worker } 370*4bdc9457SAndroid Build Coastguard Worker } 371*4bdc9457SAndroid Build Coastguard Worker } 372*4bdc9457SAndroid Build Coastguard Worker } 373*4bdc9457SAndroid Build Coastguard Worker 374*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Argmax Pooling operator. 375*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 376*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t argmax_pooling_op = nullptr; 377*4bdc9457SAndroid Build Coastguard Worker 378*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 379*4bdc9457SAndroid Build Coastguard Worker xnn_create_argmax_pooling2d_nhwc_f32( 380*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 381*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 382*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 383*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 384*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0, 385*4bdc9457SAndroid Build Coastguard Worker &argmax_pooling_op)); 386*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, argmax_pooling_op); 387*4bdc9457SAndroid Build Coastguard Worker 388*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete argmax_pooling_op. 389*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_argmax_pooling_op(argmax_pooling_op, xnn_delete_operator); 390*4bdc9457SAndroid Build Coastguard Worker 391*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 392*4bdc9457SAndroid Build Coastguard Worker xnn_setup_argmax_pooling2d_nhwc_f32( 393*4bdc9457SAndroid Build Coastguard Worker argmax_pooling_op, 394*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 395*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), index.data(), 396*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 397*4bdc9457SAndroid Build Coastguard Worker 398*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 399*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(argmax_pooling_op, nullptr /* thread pool */)); 400*4bdc9457SAndroid Build Coastguard Worker 401*4bdc9457SAndroid Build Coastguard Worker // Verify results. 402*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 403*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 404*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 405*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 406*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 407*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) << 408*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 409*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(index_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 410*4bdc9457SAndroid Build Coastguard Worker index[((i * output_height() + y) * output_width() + x) * channels() + c]) << 411*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 412*4bdc9457SAndroid Build Coastguard Worker } 413*4bdc9457SAndroid Build Coastguard Worker } 414*4bdc9457SAndroid Build Coastguard Worker } 415*4bdc9457SAndroid Build Coastguard Worker } 416*4bdc9457SAndroid Build Coastguard Worker } 417*4bdc9457SAndroid Build Coastguard Worker } 418*4bdc9457SAndroid Build Coastguard Worker TestSetupF32()419*4bdc9457SAndroid Build Coastguard Worker void TestSetupF32() const { 420*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 421*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 422*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 423*4bdc9457SAndroid Build Coastguard Worker 424*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max<size_t>( 425*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 426*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 427*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(std::max<size_t>( 428*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 429*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 430*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> index(std::max<size_t>( 431*4bdc9457SAndroid Build Coastguard Worker batch_size() * output_height() * output_width() * channels(), 432*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * next_output_height() * next_output_width() * channels())); 433*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 434*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 435*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> index_ref(batch_size() * output_height() * output_width() * channels()); 436*4bdc9457SAndroid Build Coastguard Worker std::vector<uint32_t> next_index_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 437*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 438*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 439*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 440*4bdc9457SAndroid Build Coastguard Worker 441*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 442*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 443*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 444*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 445*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 446*4bdc9457SAndroid Build Coastguard Worker const size_t iy_top_left = std::max<size_t>(oy * pooling_height(), padding_top()) - padding_top(); 447*4bdc9457SAndroid Build Coastguard Worker const size_t ix_top_left = std::max<size_t>(ox * pooling_width(), padding_left()) - padding_left(); 448*4bdc9457SAndroid Build Coastguard Worker float max_value = 449*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy_top_left) * input_width() + ix_top_left) * input_pixel_stride() + c]; 450*4bdc9457SAndroid Build Coastguard Worker uint32_t max_index = 0; 451*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 452*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * pooling_height() + py - padding_top(); 453*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 454*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * pooling_width() + px - padding_left(); 455*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 456*4bdc9457SAndroid Build Coastguard Worker const float value = input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]; 457*4bdc9457SAndroid Build Coastguard Worker if (value > max_value) { 458*4bdc9457SAndroid Build Coastguard Worker max_value = value; 459*4bdc9457SAndroid Build Coastguard Worker max_index = uint32_t(px * pooling_height() + py); 460*4bdc9457SAndroid Build Coastguard Worker } 461*4bdc9457SAndroid Build Coastguard Worker } 462*4bdc9457SAndroid Build Coastguard Worker } 463*4bdc9457SAndroid Build Coastguard Worker } 464*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 465*4bdc9457SAndroid Build Coastguard Worker index_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_index; 466*4bdc9457SAndroid Build Coastguard Worker } 467*4bdc9457SAndroid Build Coastguard Worker } 468*4bdc9457SAndroid Build Coastguard Worker } 469*4bdc9457SAndroid Build Coastguard Worker } 470*4bdc9457SAndroid Build Coastguard Worker 471*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Argmax Pooling operator once. 472*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 473*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t argmax_pooling_op = nullptr; 474*4bdc9457SAndroid Build Coastguard Worker 475*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 476*4bdc9457SAndroid Build Coastguard Worker xnn_create_argmax_pooling2d_nhwc_f32( 477*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 478*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 479*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 480*4bdc9457SAndroid Build Coastguard Worker 0, &argmax_pooling_op)); 481*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, argmax_pooling_op); 482*4bdc9457SAndroid Build Coastguard Worker 483*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 484*4bdc9457SAndroid Build Coastguard Worker xnn_setup_argmax_pooling2d_nhwc_f32( 485*4bdc9457SAndroid Build Coastguard Worker argmax_pooling_op, 486*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 487*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), index.data(), 488*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 489*4bdc9457SAndroid Build Coastguard Worker 490*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 491*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(argmax_pooling_op, nullptr /* thread pool */)); 492*4bdc9457SAndroid Build Coastguard Worker 493*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 494*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 495*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 496*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 497*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 498*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 499*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 500*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) 501*4bdc9457SAndroid Build Coastguard Worker << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 502*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 503*4bdc9457SAndroid Build Coastguard Worker index_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 504*4bdc9457SAndroid Build Coastguard Worker index[((i * output_height() + y) * output_width() + x) * channels() + c]) 505*4bdc9457SAndroid Build Coastguard Worker << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 506*4bdc9457SAndroid Build Coastguard Worker } 507*4bdc9457SAndroid Build Coastguard Worker } 508*4bdc9457SAndroid Build Coastguard Worker } 509*4bdc9457SAndroid Build Coastguard Worker } 510*4bdc9457SAndroid Build Coastguard Worker 511*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 512*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 513*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 514*4bdc9457SAndroid Build Coastguard Worker 515*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping. 516*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 517*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 518*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 519*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 520*4bdc9457SAndroid Build Coastguard Worker const size_t iy_top_left = std::max<size_t>(oy * pooling_height(), padding_top()) - padding_top(); 521*4bdc9457SAndroid Build Coastguard Worker const size_t ix_top_left = std::max<size_t>(ox * pooling_width(), padding_left()) - padding_left(); 522*4bdc9457SAndroid Build Coastguard Worker float max_value = 523*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy_top_left) * next_input_width() + ix_top_left) * input_pixel_stride() + c]; 524*4bdc9457SAndroid Build Coastguard Worker uint32_t max_index = 0; 525*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 526*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * pooling_height() + py - padding_top(); 527*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 528*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * pooling_width() + px - padding_left(); 529*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 530*4bdc9457SAndroid Build Coastguard Worker const float value = input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]; 531*4bdc9457SAndroid Build Coastguard Worker if (value > max_value) { 532*4bdc9457SAndroid Build Coastguard Worker max_value = value; 533*4bdc9457SAndroid Build Coastguard Worker max_index = uint32_t(px * pooling_height() + py); 534*4bdc9457SAndroid Build Coastguard Worker } 535*4bdc9457SAndroid Build Coastguard Worker } 536*4bdc9457SAndroid Build Coastguard Worker } 537*4bdc9457SAndroid Build Coastguard Worker } 538*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_value; 539*4bdc9457SAndroid Build Coastguard Worker next_index_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_index; 540*4bdc9457SAndroid Build Coastguard Worker } 541*4bdc9457SAndroid Build Coastguard Worker } 542*4bdc9457SAndroid Build Coastguard Worker } 543*4bdc9457SAndroid Build Coastguard Worker } 544*4bdc9457SAndroid Build Coastguard Worker 545*4bdc9457SAndroid Build Coastguard Worker // Setup and run Argmax Pooling operator the second time, and destroy the operator. 546*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 547*4bdc9457SAndroid Build Coastguard Worker xnn_setup_argmax_pooling2d_nhwc_f32( 548*4bdc9457SAndroid Build Coastguard Worker argmax_pooling_op, 549*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 550*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), index.data(), 551*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 552*4bdc9457SAndroid Build Coastguard Worker 553*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 554*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(argmax_pooling_op, nullptr /* thread pool */)); 555*4bdc9457SAndroid Build Coastguard Worker 556*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 557*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(argmax_pooling_op)); 558*4bdc9457SAndroid Build Coastguard Worker argmax_pooling_op = nullptr; 559*4bdc9457SAndroid Build Coastguard Worker 560*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 561*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 562*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 563*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 564*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 565*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 566*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 567*4bdc9457SAndroid Build Coastguard Worker output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]) 568*4bdc9457SAndroid Build Coastguard Worker << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 569*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 570*4bdc9457SAndroid Build Coastguard Worker next_index_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 571*4bdc9457SAndroid Build Coastguard Worker index[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]) 572*4bdc9457SAndroid Build Coastguard Worker << "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 573*4bdc9457SAndroid Build Coastguard Worker } 574*4bdc9457SAndroid Build Coastguard Worker } 575*4bdc9457SAndroid Build Coastguard Worker } 576*4bdc9457SAndroid Build Coastguard Worker } 577*4bdc9457SAndroid Build Coastguard Worker } 578*4bdc9457SAndroid Build Coastguard Worker } 579*4bdc9457SAndroid Build Coastguard Worker 580*4bdc9457SAndroid Build Coastguard Worker private: 581*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 582*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 583*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 584*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 585*4bdc9457SAndroid Build Coastguard Worker bool padding_tf_same_{false}; 586*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 587*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 588*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 589*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 590*4bdc9457SAndroid Build Coastguard Worker size_t input_pixel_stride_{0}; 591*4bdc9457SAndroid Build Coastguard Worker size_t output_pixel_stride_{0}; 592*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_height_{1}; 593*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_width_{1}; 594*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0}; 595*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0}; 596*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0}; 597*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 598*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 599*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 600*4bdc9457SAndroid Build Coastguard Worker }; 601