1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 16*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 17*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 18*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 19*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 20*4bdc9457SAndroid Build Coastguard Worker #include <limits> 21*4bdc9457SAndroid Build Coastguard Worker #include <random> 22*4bdc9457SAndroid Build Coastguard Worker #include <vector> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker 27*4bdc9457SAndroid Build Coastguard Worker class AveragePoolingOperatorTester { 28*4bdc9457SAndroid Build Coastguard Worker public: padding_tf_same(bool padding_same)29*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_tf_same(bool padding_same) { 30*4bdc9457SAndroid Build Coastguard Worker if (padding_same) { 31*4bdc9457SAndroid Build Coastguard Worker assert(padding_top() == 0); 32*4bdc9457SAndroid Build Coastguard Worker assert(padding_left() == 0); 33*4bdc9457SAndroid Build Coastguard Worker assert(padding_bottom() == 0); 34*4bdc9457SAndroid Build Coastguard Worker assert(padding_right() == 0); 35*4bdc9457SAndroid Build Coastguard Worker } 36*4bdc9457SAndroid Build Coastguard Worker this->padding_tf_same_ = padding_same; 37*4bdc9457SAndroid Build Coastguard Worker return *this; 38*4bdc9457SAndroid Build Coastguard Worker } 39*4bdc9457SAndroid Build Coastguard Worker padding_tf_same()40*4bdc9457SAndroid Build Coastguard Worker inline bool padding_tf_same() const { 41*4bdc9457SAndroid Build Coastguard Worker return this->padding_tf_same_; 42*4bdc9457SAndroid Build Coastguard Worker } 43*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding)44*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding(uint32_t padding) { 45*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 46*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding; 47*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding; 48*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding; 49*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding; 50*4bdc9457SAndroid Build Coastguard Worker return *this; 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding_height,uint32_t padding_width)53*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) { 54*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 55*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 56*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 57*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 58*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 59*4bdc9457SAndroid Build Coastguard Worker return *this; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker padding_height(uint32_t padding_height)62*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_height(uint32_t padding_height) { 63*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 64*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 65*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 66*4bdc9457SAndroid Build Coastguard Worker return *this; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker padding_width(uint32_t padding_width)69*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_width(uint32_t padding_width) { 70*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 71*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 72*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 73*4bdc9457SAndroid Build Coastguard Worker return *this; 74*4bdc9457SAndroid Build Coastguard Worker } 75*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)76*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_top(uint32_t padding_top) { 77*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 78*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 79*4bdc9457SAndroid Build Coastguard Worker return *this; 80*4bdc9457SAndroid Build Coastguard Worker } 81*4bdc9457SAndroid Build Coastguard Worker padding_top()82*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 83*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 84*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = 85*4bdc9457SAndroid Build Coastguard Worker (output_height() - 1) * stride_height() + pooling_height() - input_height(); 86*4bdc9457SAndroid Build Coastguard Worker return total_padding_height / 2; 87*4bdc9457SAndroid Build Coastguard Worker } else { 88*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 89*4bdc9457SAndroid Build Coastguard Worker } 90*4bdc9457SAndroid Build Coastguard Worker } 91*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)92*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_left(uint32_t padding_left) { 93*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 94*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 95*4bdc9457SAndroid Build Coastguard Worker return *this; 96*4bdc9457SAndroid Build Coastguard Worker } 97*4bdc9457SAndroid Build Coastguard Worker padding_left()98*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { 99*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 100*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = 101*4bdc9457SAndroid Build Coastguard Worker (output_width() - 1) * stride_width() + pooling_width() - input_width(); 102*4bdc9457SAndroid Build Coastguard Worker return total_padding_width / 2; 103*4bdc9457SAndroid Build Coastguard Worker } else { 104*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 105*4bdc9457SAndroid Build Coastguard Worker } 106*4bdc9457SAndroid Build Coastguard Worker } 107*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)108*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_bottom(uint32_t padding_bottom) { 109*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 110*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 111*4bdc9457SAndroid Build Coastguard Worker return *this; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker padding_bottom()114*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 115*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 116*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = 117*4bdc9457SAndroid Build Coastguard Worker (output_height() - 1) * stride_height() + pooling_height() - input_height(); 118*4bdc9457SAndroid Build Coastguard Worker return total_padding_height - total_padding_height / 2; 119*4bdc9457SAndroid Build Coastguard Worker } else { 120*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 121*4bdc9457SAndroid Build Coastguard Worker } 122*4bdc9457SAndroid Build Coastguard Worker } 123*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)124*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& padding_right(uint32_t padding_right) { 125*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 126*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 127*4bdc9457SAndroid Build Coastguard Worker return *this; 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker padding_right()130*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 131*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 132*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = 133*4bdc9457SAndroid Build Coastguard Worker (output_width() - 1) * stride_width() + pooling_width() - input_width(); 134*4bdc9457SAndroid Build Coastguard Worker return total_padding_width - total_padding_width / 2; 135*4bdc9457SAndroid Build Coastguard Worker } else { 136*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 137*4bdc9457SAndroid Build Coastguard Worker } 138*4bdc9457SAndroid Build Coastguard Worker } 139*4bdc9457SAndroid Build Coastguard Worker input_size(size_t input_height,size_t input_width)140*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& input_size(size_t input_height, size_t input_width) { 141*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 142*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 143*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 144*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 145*4bdc9457SAndroid Build Coastguard Worker return *this; 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker input_height(size_t input_height)148*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& input_height(size_t input_height) { 149*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 150*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 151*4bdc9457SAndroid Build Coastguard Worker return *this; 152*4bdc9457SAndroid Build Coastguard Worker } 153*4bdc9457SAndroid Build Coastguard Worker input_height()154*4bdc9457SAndroid Build Coastguard Worker inline size_t input_height() const { 155*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 156*4bdc9457SAndroid Build Coastguard Worker } 157*4bdc9457SAndroid Build Coastguard Worker input_width(size_t input_width)158*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& input_width(size_t input_width) { 159*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 160*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 161*4bdc9457SAndroid Build Coastguard Worker return *this; 162*4bdc9457SAndroid Build Coastguard Worker } 163*4bdc9457SAndroid Build Coastguard Worker input_width()164*4bdc9457SAndroid Build Coastguard Worker inline size_t input_width() const { 165*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 166*4bdc9457SAndroid Build Coastguard Worker } 167*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)168*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& channels(size_t channels) { 169*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 170*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 171*4bdc9457SAndroid Build Coastguard Worker return *this; 172*4bdc9457SAndroid Build Coastguard Worker } 173*4bdc9457SAndroid Build Coastguard Worker channels()174*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 175*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 176*4bdc9457SAndroid Build Coastguard Worker } 177*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)178*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& batch_size(size_t batch_size) { 179*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 180*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 181*4bdc9457SAndroid Build Coastguard Worker return *this; 182*4bdc9457SAndroid Build Coastguard Worker } 183*4bdc9457SAndroid Build Coastguard Worker batch_size()184*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 185*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 186*4bdc9457SAndroid Build Coastguard Worker } 187*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_size)188*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& pooling_size(uint32_t pooling_size) { 189*4bdc9457SAndroid Build Coastguard Worker assert(pooling_size >= 1); 190*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_size; 191*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_size; 192*4bdc9457SAndroid Build Coastguard Worker return *this; 193*4bdc9457SAndroid Build Coastguard Worker } 194*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_height,uint32_t pooling_width)195*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& pooling_size(uint32_t pooling_height, uint32_t pooling_width) { 196*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 197*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 198*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 199*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 200*4bdc9457SAndroid Build Coastguard Worker return *this; 201*4bdc9457SAndroid Build Coastguard Worker } 202*4bdc9457SAndroid Build Coastguard Worker pooling_height(uint32_t pooling_height)203*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& pooling_height(uint32_t pooling_height) { 204*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 205*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 206*4bdc9457SAndroid Build Coastguard Worker return *this; 207*4bdc9457SAndroid Build Coastguard Worker } 208*4bdc9457SAndroid Build Coastguard Worker pooling_height()209*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_height() const { 210*4bdc9457SAndroid Build Coastguard Worker return this->pooling_height_; 211*4bdc9457SAndroid Build Coastguard Worker } 212*4bdc9457SAndroid Build Coastguard Worker pooling_width(uint32_t pooling_width)213*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& pooling_width(uint32_t pooling_width) { 214*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 215*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 216*4bdc9457SAndroid Build Coastguard Worker return *this; 217*4bdc9457SAndroid Build Coastguard Worker } 218*4bdc9457SAndroid Build Coastguard Worker pooling_width()219*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_width() const { 220*4bdc9457SAndroid Build Coastguard Worker return this->pooling_width_; 221*4bdc9457SAndroid Build Coastguard Worker } 222*4bdc9457SAndroid Build Coastguard Worker stride(uint32_t stride)223*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& stride(uint32_t stride) { 224*4bdc9457SAndroid Build Coastguard Worker assert(stride >= 1); 225*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride; 226*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride; 227*4bdc9457SAndroid Build Coastguard Worker return *this; 228*4bdc9457SAndroid Build Coastguard Worker } 229*4bdc9457SAndroid Build Coastguard Worker stride(uint32_t stride_height,uint32_t stride_width)230*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& stride(uint32_t stride_height, uint32_t stride_width) { 231*4bdc9457SAndroid Build Coastguard Worker assert(stride_height >= 1); 232*4bdc9457SAndroid Build Coastguard Worker assert(stride_width >= 1); 233*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride_height; 234*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride_width; 235*4bdc9457SAndroid Build Coastguard Worker return *this; 236*4bdc9457SAndroid Build Coastguard Worker } 237*4bdc9457SAndroid Build Coastguard Worker stride_height(uint32_t stride_height)238*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& stride_height(uint32_t stride_height) { 239*4bdc9457SAndroid Build Coastguard Worker assert(stride_height >= 1); 240*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride_height; 241*4bdc9457SAndroid Build Coastguard Worker return *this; 242*4bdc9457SAndroid Build Coastguard Worker } 243*4bdc9457SAndroid Build Coastguard Worker stride_height()244*4bdc9457SAndroid Build Coastguard Worker inline uint32_t stride_height() const { 245*4bdc9457SAndroid Build Coastguard Worker return this->stride_height_; 246*4bdc9457SAndroid Build Coastguard Worker } 247*4bdc9457SAndroid Build Coastguard Worker stride_width(uint32_t stride_width)248*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& stride_width(uint32_t stride_width) { 249*4bdc9457SAndroid Build Coastguard Worker assert(stride_width >= 1); 250*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride_width; 251*4bdc9457SAndroid Build Coastguard Worker return *this; 252*4bdc9457SAndroid Build Coastguard Worker } 253*4bdc9457SAndroid Build Coastguard Worker stride_width()254*4bdc9457SAndroid Build Coastguard Worker inline uint32_t stride_width() const { 255*4bdc9457SAndroid Build Coastguard Worker return this->stride_width_; 256*4bdc9457SAndroid Build Coastguard Worker } 257*4bdc9457SAndroid Build Coastguard Worker output_height()258*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 259*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 260*4bdc9457SAndroid Build Coastguard Worker return (input_height() + stride_height() - 1) / stride_height(); 261*4bdc9457SAndroid Build Coastguard Worker } else { 262*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_height = padding_top() + input_height() + padding_bottom(); 263*4bdc9457SAndroid Build Coastguard Worker if (padded_input_height <= pooling_height()) { 264*4bdc9457SAndroid Build Coastguard Worker return 1; 265*4bdc9457SAndroid Build Coastguard Worker } else { 266*4bdc9457SAndroid Build Coastguard Worker return (padded_input_height - pooling_height()) / stride_height() + 1; 267*4bdc9457SAndroid Build Coastguard Worker } 268*4bdc9457SAndroid Build Coastguard Worker } 269*4bdc9457SAndroid Build Coastguard Worker } 270*4bdc9457SAndroid Build Coastguard Worker output_width()271*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 272*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 273*4bdc9457SAndroid Build Coastguard Worker return (input_width() + stride_width() - 1) / stride_width(); 274*4bdc9457SAndroid Build Coastguard Worker } else { 275*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_width = padding_left() + input_width() + padding_right(); 276*4bdc9457SAndroid Build Coastguard Worker if (padded_input_width <= pooling_width()) { 277*4bdc9457SAndroid Build Coastguard Worker return 1; 278*4bdc9457SAndroid Build Coastguard Worker } else { 279*4bdc9457SAndroid Build Coastguard Worker return (padded_input_width - pooling_width()) / stride_width() + 1; 280*4bdc9457SAndroid Build Coastguard Worker } 281*4bdc9457SAndroid Build Coastguard Worker } 282*4bdc9457SAndroid Build Coastguard Worker } 283*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(size_t input_pixel_stride)284*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& input_pixel_stride(size_t input_pixel_stride) { 285*4bdc9457SAndroid Build Coastguard Worker assert(input_pixel_stride != 0); 286*4bdc9457SAndroid Build Coastguard Worker this->input_pixel_stride_ = input_pixel_stride; 287*4bdc9457SAndroid Build Coastguard Worker return *this; 288*4bdc9457SAndroid Build Coastguard Worker } 289*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride()290*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const { 291*4bdc9457SAndroid Build Coastguard Worker if (this->input_pixel_stride_ == 0) { 292*4bdc9457SAndroid Build Coastguard Worker return channels(); 293*4bdc9457SAndroid Build Coastguard Worker } else { 294*4bdc9457SAndroid Build Coastguard Worker assert(this->input_pixel_stride_ >= channels()); 295*4bdc9457SAndroid Build Coastguard Worker return this->input_pixel_stride_; 296*4bdc9457SAndroid Build Coastguard Worker } 297*4bdc9457SAndroid Build Coastguard Worker } 298*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride(size_t output_pixel_stride)299*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& output_pixel_stride(size_t output_pixel_stride) { 300*4bdc9457SAndroid Build Coastguard Worker assert(output_pixel_stride != 0); 301*4bdc9457SAndroid Build Coastguard Worker this->output_pixel_stride_ = output_pixel_stride; 302*4bdc9457SAndroid Build Coastguard Worker return *this; 303*4bdc9457SAndroid Build Coastguard Worker } 304*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride()305*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const { 306*4bdc9457SAndroid Build Coastguard Worker if (this->output_pixel_stride_ == 0) { 307*4bdc9457SAndroid Build Coastguard Worker return channels(); 308*4bdc9457SAndroid Build Coastguard Worker } else { 309*4bdc9457SAndroid Build Coastguard Worker assert(this->output_pixel_stride_ >= channels()); 310*4bdc9457SAndroid Build Coastguard Worker return this->output_pixel_stride_; 311*4bdc9457SAndroid Build Coastguard Worker } 312*4bdc9457SAndroid Build Coastguard Worker } 313*4bdc9457SAndroid Build Coastguard Worker next_input_size(uint32_t next_input_height,uint32_t next_input_width)314*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) { 315*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 316*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 317*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 318*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 319*4bdc9457SAndroid Build Coastguard Worker return *this; 320*4bdc9457SAndroid Build Coastguard Worker } 321*4bdc9457SAndroid Build Coastguard Worker next_input_height(uint32_t next_input_height)322*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& next_input_height(uint32_t next_input_height) { 323*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 324*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 325*4bdc9457SAndroid Build Coastguard Worker return *this; 326*4bdc9457SAndroid Build Coastguard Worker } 327*4bdc9457SAndroid Build Coastguard Worker next_input_height()328*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_height() const { 329*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) { 330*4bdc9457SAndroid Build Coastguard Worker return input_height(); 331*4bdc9457SAndroid Build Coastguard Worker } else { 332*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_; 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker } 335*4bdc9457SAndroid Build Coastguard Worker next_input_width(uint32_t next_input_width)336*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& next_input_width(uint32_t next_input_width) { 337*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 338*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 339*4bdc9457SAndroid Build Coastguard Worker return *this; 340*4bdc9457SAndroid Build Coastguard Worker } 341*4bdc9457SAndroid Build Coastguard Worker next_input_width()342*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const { 343*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) { 344*4bdc9457SAndroid Build Coastguard Worker return input_width(); 345*4bdc9457SAndroid Build Coastguard Worker } else { 346*4bdc9457SAndroid Build Coastguard Worker return this->next_input_width_; 347*4bdc9457SAndroid Build Coastguard Worker } 348*4bdc9457SAndroid Build Coastguard Worker } 349*4bdc9457SAndroid Build Coastguard Worker next_output_height()350*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_height() const { 351*4bdc9457SAndroid Build Coastguard Worker const size_t padded_next_input_height = padding_top() + next_input_height() + padding_bottom(); 352*4bdc9457SAndroid Build Coastguard Worker if (padded_next_input_height <= pooling_height()) { 353*4bdc9457SAndroid Build Coastguard Worker return 1; 354*4bdc9457SAndroid Build Coastguard Worker } else { 355*4bdc9457SAndroid Build Coastguard Worker return (padded_next_input_height - pooling_height()) / stride_height() + 1; 356*4bdc9457SAndroid Build Coastguard Worker } 357*4bdc9457SAndroid Build Coastguard Worker } 358*4bdc9457SAndroid Build Coastguard Worker next_output_width()359*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_width() const { 360*4bdc9457SAndroid Build Coastguard Worker const size_t padded_next_input_width = padding_left() + next_input_width() + padding_right(); 361*4bdc9457SAndroid Build Coastguard Worker if (padded_next_input_width <= pooling_width()) { 362*4bdc9457SAndroid Build Coastguard Worker return 1; 363*4bdc9457SAndroid Build Coastguard Worker } else { 364*4bdc9457SAndroid Build Coastguard Worker return (padded_next_input_width - pooling_width()) / stride_width() + 1; 365*4bdc9457SAndroid Build Coastguard Worker } 366*4bdc9457SAndroid Build Coastguard Worker } 367*4bdc9457SAndroid Build Coastguard Worker next_batch_size(size_t next_batch_size)368*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& next_batch_size(size_t next_batch_size) { 369*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1); 370*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size; 371*4bdc9457SAndroid Build Coastguard Worker return *this; 372*4bdc9457SAndroid Build Coastguard Worker } 373*4bdc9457SAndroid Build Coastguard Worker next_batch_size()374*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const { 375*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) { 376*4bdc9457SAndroid Build Coastguard Worker return batch_size(); 377*4bdc9457SAndroid Build Coastguard Worker } else { 378*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_; 379*4bdc9457SAndroid Build Coastguard Worker } 380*4bdc9457SAndroid Build Coastguard Worker } 381*4bdc9457SAndroid Build Coastguard Worker input_scale(float input_scale)382*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& input_scale(float input_scale) { 383*4bdc9457SAndroid Build Coastguard Worker assert(input_scale > 0.0f); 384*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(input_scale)); 385*4bdc9457SAndroid Build Coastguard Worker this->input_scale_ = input_scale; 386*4bdc9457SAndroid Build Coastguard Worker return *this; 387*4bdc9457SAndroid Build Coastguard Worker } 388*4bdc9457SAndroid Build Coastguard Worker input_scale()389*4bdc9457SAndroid Build Coastguard Worker inline float input_scale() const { 390*4bdc9457SAndroid Build Coastguard Worker return this->input_scale_; 391*4bdc9457SAndroid Build Coastguard Worker } 392*4bdc9457SAndroid Build Coastguard Worker input_zero_point(uint8_t input_zero_point)393*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& input_zero_point(uint8_t input_zero_point) { 394*4bdc9457SAndroid Build Coastguard Worker this->input_zero_point_ = input_zero_point; 395*4bdc9457SAndroid Build Coastguard Worker return *this; 396*4bdc9457SAndroid Build Coastguard Worker } 397*4bdc9457SAndroid Build Coastguard Worker input_zero_point()398*4bdc9457SAndroid Build Coastguard Worker inline uint8_t input_zero_point() const { 399*4bdc9457SAndroid Build Coastguard Worker return this->input_zero_point_; 400*4bdc9457SAndroid Build Coastguard Worker } 401*4bdc9457SAndroid Build Coastguard Worker output_scale(float output_scale)402*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& output_scale(float output_scale) { 403*4bdc9457SAndroid Build Coastguard Worker assert(output_scale > 0.0f); 404*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(output_scale)); 405*4bdc9457SAndroid Build Coastguard Worker this->output_scale_ = output_scale; 406*4bdc9457SAndroid Build Coastguard Worker return *this; 407*4bdc9457SAndroid Build Coastguard Worker } 408*4bdc9457SAndroid Build Coastguard Worker output_scale()409*4bdc9457SAndroid Build Coastguard Worker inline float output_scale() const { 410*4bdc9457SAndroid Build Coastguard Worker return this->output_scale_; 411*4bdc9457SAndroid Build Coastguard Worker } 412*4bdc9457SAndroid Build Coastguard Worker output_zero_point(uint8_t output_zero_point)413*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& output_zero_point(uint8_t output_zero_point) { 414*4bdc9457SAndroid Build Coastguard Worker this->output_zero_point_ = output_zero_point; 415*4bdc9457SAndroid Build Coastguard Worker return *this; 416*4bdc9457SAndroid Build Coastguard Worker } 417*4bdc9457SAndroid Build Coastguard Worker output_zero_point()418*4bdc9457SAndroid Build Coastguard Worker inline uint8_t output_zero_point() const { 419*4bdc9457SAndroid Build Coastguard Worker return this->output_zero_point_; 420*4bdc9457SAndroid Build Coastguard Worker } 421*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)422*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& qmin(uint8_t qmin) { 423*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 424*4bdc9457SAndroid Build Coastguard Worker return *this; 425*4bdc9457SAndroid Build Coastguard Worker } 426*4bdc9457SAndroid Build Coastguard Worker qmin()427*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 428*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 429*4bdc9457SAndroid Build Coastguard Worker } 430*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)431*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& qmax(uint8_t qmax) { 432*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 433*4bdc9457SAndroid Build Coastguard Worker return *this; 434*4bdc9457SAndroid Build Coastguard Worker } 435*4bdc9457SAndroid Build Coastguard Worker qmax()436*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 437*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 438*4bdc9457SAndroid Build Coastguard Worker } 439*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)440*4bdc9457SAndroid Build Coastguard Worker inline AveragePoolingOperatorTester& iterations(size_t iterations) { 441*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 442*4bdc9457SAndroid Build Coastguard Worker return *this; 443*4bdc9457SAndroid Build Coastguard Worker } 444*4bdc9457SAndroid Build Coastguard Worker iterations()445*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 446*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 447*4bdc9457SAndroid Build Coastguard Worker } 448*4bdc9457SAndroid Build Coastguard Worker TestF16()449*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 450*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 451*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 452*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 453*4bdc9457SAndroid Build Coastguard Worker 454*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 455*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 456*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 457*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 458*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 459*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 460*4bdc9457SAndroid Build Coastguard Worker 461*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 462*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 463*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 464*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 465*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 466*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 467*4bdc9457SAndroid Build Coastguard Worker int32_t n = 0; 468*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 469*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 470*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 471*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 472*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 473*4bdc9457SAndroid Build Coastguard Worker acc += fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 474*4bdc9457SAndroid Build Coastguard Worker n += 1; 475*4bdc9457SAndroid Build Coastguard Worker } 476*4bdc9457SAndroid Build Coastguard Worker } 477*4bdc9457SAndroid Build Coastguard Worker } 478*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n); 479*4bdc9457SAndroid Build Coastguard Worker } 480*4bdc9457SAndroid Build Coastguard Worker } 481*4bdc9457SAndroid Build Coastguard Worker } 482*4bdc9457SAndroid Build Coastguard Worker } 483*4bdc9457SAndroid Build Coastguard Worker 484*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 485*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 486*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 487*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 488*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin()); 489*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 490*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)); 491*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)); 492*4bdc9457SAndroid Build Coastguard Worker if (accumulated_range == 0.0f) { 493*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 494*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 495*4bdc9457SAndroid Build Coastguard Worker } 496*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<uint8_t>::min()) { 497*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 498*4bdc9457SAndroid Build Coastguard Worker } 499*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<uint8_t>::max()) { 500*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 501*4bdc9457SAndroid Build Coastguard Worker } 502*4bdc9457SAndroid Build Coastguard Worker 503*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 504*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 505*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 506*4bdc9457SAndroid Build Coastguard Worker } 507*4bdc9457SAndroid Build Coastguard Worker 508*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Average Pooling operator. 509*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 510*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t average_pooling_op = nullptr; 511*4bdc9457SAndroid Build Coastguard Worker 512*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_average_pooling2d_nhwc_f16( 513*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 514*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 515*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 516*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 517*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 518*4bdc9457SAndroid Build Coastguard Worker 0, &average_pooling_op); 519*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 520*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 521*4bdc9457SAndroid Build Coastguard Worker } 522*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 523*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, average_pooling_op); 524*4bdc9457SAndroid Build Coastguard Worker 525*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete average_pooling_op. 526*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_average_pooling_op(average_pooling_op, xnn_delete_operator); 527*4bdc9457SAndroid Build Coastguard Worker 528*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 529*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_f16( 530*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 531*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 532*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 533*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 534*4bdc9457SAndroid Build Coastguard Worker 535*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 536*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 537*4bdc9457SAndroid Build Coastguard Worker 538*4bdc9457SAndroid Build Coastguard Worker // Verify results. 539*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 540*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 541*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 542*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 543*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_max); 544*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_min); 545*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 546*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), 547*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 548*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-3f, std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-2f)) << 549*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 550*4bdc9457SAndroid Build Coastguard Worker } 551*4bdc9457SAndroid Build Coastguard Worker } 552*4bdc9457SAndroid Build Coastguard Worker } 553*4bdc9457SAndroid Build Coastguard Worker } 554*4bdc9457SAndroid Build Coastguard Worker } 555*4bdc9457SAndroid Build Coastguard Worker } 556*4bdc9457SAndroid Build Coastguard Worker TestF32()557*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 558*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 559*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 560*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 561*4bdc9457SAndroid Build Coastguard Worker 562*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 563*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 564*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 565*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 566*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 567*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 568*4bdc9457SAndroid Build Coastguard Worker 569*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 570*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 571*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 572*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 573*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 574*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 575*4bdc9457SAndroid Build Coastguard Worker int32_t n = 0; 576*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 577*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 578*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 579*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 580*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 581*4bdc9457SAndroid Build Coastguard Worker acc += input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]; 582*4bdc9457SAndroid Build Coastguard Worker n += 1; 583*4bdc9457SAndroid Build Coastguard Worker } 584*4bdc9457SAndroid Build Coastguard Worker } 585*4bdc9457SAndroid Build Coastguard Worker } 586*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n); 587*4bdc9457SAndroid Build Coastguard Worker } 588*4bdc9457SAndroid Build Coastguard Worker } 589*4bdc9457SAndroid Build Coastguard Worker } 590*4bdc9457SAndroid Build Coastguard Worker } 591*4bdc9457SAndroid Build Coastguard Worker 592*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 593*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 594*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 595*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 596*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_range == 0.0f ? 597*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity() : 598*4bdc9457SAndroid Build Coastguard Worker accumulated_min + accumulated_range / 255.0f * float(qmin()); 599*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_range == 0.0f ? 600*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity() : 601*4bdc9457SAndroid Build Coastguard Worker accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 602*4bdc9457SAndroid Build Coastguard Worker 603*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 604*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 605*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 606*4bdc9457SAndroid Build Coastguard Worker } 607*4bdc9457SAndroid Build Coastguard Worker 608*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Average Pooling operator. 609*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 610*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t average_pooling_op = nullptr; 611*4bdc9457SAndroid Build Coastguard Worker 612*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 613*4bdc9457SAndroid Build Coastguard Worker xnn_create_average_pooling2d_nhwc_f32( 614*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 615*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 616*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 617*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 618*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 619*4bdc9457SAndroid Build Coastguard Worker 0, &average_pooling_op)); 620*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, average_pooling_op); 621*4bdc9457SAndroid Build Coastguard Worker 622*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete average_pooling_op. 623*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_average_pooling_op(average_pooling_op, xnn_delete_operator); 624*4bdc9457SAndroid Build Coastguard Worker 625*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 626*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_f32( 627*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 628*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 629*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 630*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 631*4bdc9457SAndroid Build Coastguard Worker 632*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 633*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 634*4bdc9457SAndroid Build Coastguard Worker 635*4bdc9457SAndroid Build Coastguard Worker // Verify results. 636*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 637*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 638*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 639*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 640*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max); 641*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min); 642*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], 643*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 644*4bdc9457SAndroid Build Coastguard Worker std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-6f) << 645*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 646*4bdc9457SAndroid Build Coastguard Worker } 647*4bdc9457SAndroid Build Coastguard Worker } 648*4bdc9457SAndroid Build Coastguard Worker } 649*4bdc9457SAndroid Build Coastguard Worker } 650*4bdc9457SAndroid Build Coastguard Worker } 651*4bdc9457SAndroid Build Coastguard Worker } 652*4bdc9457SAndroid Build Coastguard Worker TestQU8()653*4bdc9457SAndroid Build Coastguard Worker void TestQU8() const { 654*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 655*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 656*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 657*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 658*4bdc9457SAndroid Build Coastguard Worker 659*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 660*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels()); 661*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 662*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 663*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 664*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 665*4bdc9457SAndroid Build Coastguard Worker 666*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 667*4bdc9457SAndroid Build Coastguard Worker const double scale = double(input_scale()) / (double(output_scale()) * double(pooling_height() * pooling_width())); 668*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 669*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 670*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 671*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 672*4bdc9457SAndroid Build Coastguard Worker double acc = 0.0f; 673*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 674*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 675*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 676*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 677*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 678*4bdc9457SAndroid Build Coastguard Worker acc += double(int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]) - int32_t(input_zero_point())); 679*4bdc9457SAndroid Build Coastguard Worker } 680*4bdc9457SAndroid Build Coastguard Worker } 681*4bdc9457SAndroid Build Coastguard Worker } 682*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = float(acc * scale + double(output_zero_point())); 683*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = 684*4bdc9457SAndroid Build Coastguard Worker std::min<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmax())); 685*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = 686*4bdc9457SAndroid Build Coastguard Worker std::max<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmin())); 687*4bdc9457SAndroid Build Coastguard Worker } 688*4bdc9457SAndroid Build Coastguard Worker } 689*4bdc9457SAndroid Build Coastguard Worker } 690*4bdc9457SAndroid Build Coastguard Worker } 691*4bdc9457SAndroid Build Coastguard Worker 692*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Average Pooling operator. 693*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 694*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t average_pooling_op = nullptr; 695*4bdc9457SAndroid Build Coastguard Worker 696*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 697*4bdc9457SAndroid Build Coastguard Worker xnn_create_average_pooling2d_nhwc_qu8( 698*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 699*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 700*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 701*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 702*4bdc9457SAndroid Build Coastguard Worker input_zero_point(), input_scale(), 703*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), output_scale(), 704*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 705*4bdc9457SAndroid Build Coastguard Worker 0, &average_pooling_op)); 706*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, average_pooling_op); 707*4bdc9457SAndroid Build Coastguard Worker 708*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete average_pooling_op. 709*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_average_pooling_op(average_pooling_op, xnn_delete_operator); 710*4bdc9457SAndroid Build Coastguard Worker 711*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 712*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_qu8( 713*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 714*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 715*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 716*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 717*4bdc9457SAndroid Build Coastguard Worker 718*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 719*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 720*4bdc9457SAndroid Build Coastguard Worker 721*4bdc9457SAndroid Build Coastguard Worker // Verify results. 722*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 723*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 724*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 725*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 726*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax())); 727*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin())); 728*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])), 729*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 0.80f) << 730*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 731*4bdc9457SAndroid Build Coastguard Worker } 732*4bdc9457SAndroid Build Coastguard Worker } 733*4bdc9457SAndroid Build Coastguard Worker } 734*4bdc9457SAndroid Build Coastguard Worker } 735*4bdc9457SAndroid Build Coastguard Worker } 736*4bdc9457SAndroid Build Coastguard Worker } 737*4bdc9457SAndroid Build Coastguard Worker TestSetupF16()738*4bdc9457SAndroid Build Coastguard Worker void TestSetupF16() const { 739*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 740*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 741*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 742*4bdc9457SAndroid Build Coastguard Worker 743*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max<size_t>( 744*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 745*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 746*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(std::max<size_t>( 747*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 748*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 749*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 750*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 751*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 752*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 753*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 754*4bdc9457SAndroid Build Coastguard Worker 755*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 756*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 757*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 758*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 759*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 760*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 761*4bdc9457SAndroid Build Coastguard Worker size_t n = 0; 762*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 763*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 764*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 765*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 766*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 767*4bdc9457SAndroid Build Coastguard Worker acc += fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 768*4bdc9457SAndroid Build Coastguard Worker n += 1; 769*4bdc9457SAndroid Build Coastguard Worker } 770*4bdc9457SAndroid Build Coastguard Worker } 771*4bdc9457SAndroid Build Coastguard Worker } 772*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n); 773*4bdc9457SAndroid Build Coastguard Worker } 774*4bdc9457SAndroid Build Coastguard Worker } 775*4bdc9457SAndroid Build Coastguard Worker } 776*4bdc9457SAndroid Build Coastguard Worker } 777*4bdc9457SAndroid Build Coastguard Worker 778*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 779*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 780*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 781*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 782*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin()); 783*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 784*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)); 785*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)); 786*4bdc9457SAndroid Build Coastguard Worker if (accumulated_range == 0.0f) { 787*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 788*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 789*4bdc9457SAndroid Build Coastguard Worker } 790*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<uint8_t>::min()) { 791*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 792*4bdc9457SAndroid Build Coastguard Worker } 793*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<uint8_t>::max()) { 794*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 795*4bdc9457SAndroid Build Coastguard Worker } 796*4bdc9457SAndroid Build Coastguard Worker 797*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 798*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 799*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 800*4bdc9457SAndroid Build Coastguard Worker } 801*4bdc9457SAndroid Build Coastguard Worker 802*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Average Pooling operator once. 803*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 804*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t average_pooling_op = nullptr; 805*4bdc9457SAndroid Build Coastguard Worker 806*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_average_pooling2d_nhwc_f16( 807*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 808*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 809*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 810*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 811*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 812*4bdc9457SAndroid Build Coastguard Worker 0, &average_pooling_op); 813*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 814*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 815*4bdc9457SAndroid Build Coastguard Worker } 816*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 817*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, average_pooling_op); 818*4bdc9457SAndroid Build Coastguard Worker 819*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 820*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_f16( 821*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 822*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 823*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 824*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 825*4bdc9457SAndroid Build Coastguard Worker 826*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 827*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 828*4bdc9457SAndroid Build Coastguard Worker 829*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 830*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 831*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 832*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 833*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 834*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_max); 835*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_min); 836*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 837*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), 838*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 839*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-3f, std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-2f)) << 840*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 841*4bdc9457SAndroid Build Coastguard Worker } 842*4bdc9457SAndroid Build Coastguard Worker } 843*4bdc9457SAndroid Build Coastguard Worker } 844*4bdc9457SAndroid Build Coastguard Worker } 845*4bdc9457SAndroid Build Coastguard Worker 846*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 847*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 848*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 849*4bdc9457SAndroid Build Coastguard Worker 850*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run. 851*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 852*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 853*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 854*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 855*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 856*4bdc9457SAndroid Build Coastguard Worker int32_t n = 0; 857*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 858*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 859*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 860*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 861*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 862*4bdc9457SAndroid Build Coastguard Worker acc += fp16_ieee_to_fp32_value(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]); 863*4bdc9457SAndroid Build Coastguard Worker n += 1; 864*4bdc9457SAndroid Build Coastguard Worker } 865*4bdc9457SAndroid Build Coastguard Worker } 866*4bdc9457SAndroid Build Coastguard Worker } 867*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = 868*4bdc9457SAndroid Build Coastguard Worker std::max(std::min(acc / float(n), output_max), output_min); 869*4bdc9457SAndroid Build Coastguard Worker } 870*4bdc9457SAndroid Build Coastguard Worker } 871*4bdc9457SAndroid Build Coastguard Worker } 872*4bdc9457SAndroid Build Coastguard Worker } 873*4bdc9457SAndroid Build Coastguard Worker 874*4bdc9457SAndroid Build Coastguard Worker // Setup and run Average Pooling operator the second time, and destroy the operator. 875*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 876*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_f16( 877*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 878*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 879*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 880*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 881*4bdc9457SAndroid Build Coastguard Worker 882*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 883*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 884*4bdc9457SAndroid Build Coastguard Worker 885*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 886*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(average_pooling_op)); 887*4bdc9457SAndroid Build Coastguard Worker average_pooling_op = nullptr; 888*4bdc9457SAndroid Build Coastguard Worker 889*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 890*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 891*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 892*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 893*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 894*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), output_max); 895*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), output_min); 896*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 897*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), 898*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 899*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-3f, std::abs(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]) * 1.0e-2f)) << 900*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 901*4bdc9457SAndroid Build Coastguard Worker } 902*4bdc9457SAndroid Build Coastguard Worker } 903*4bdc9457SAndroid Build Coastguard Worker } 904*4bdc9457SAndroid Build Coastguard Worker } 905*4bdc9457SAndroid Build Coastguard Worker } 906*4bdc9457SAndroid Build Coastguard Worker } 907*4bdc9457SAndroid Build Coastguard Worker TestSetupF32()908*4bdc9457SAndroid Build Coastguard Worker void TestSetupF32() const { 909*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 910*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 911*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 912*4bdc9457SAndroid Build Coastguard Worker 913*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max<size_t>( 914*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 915*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 916*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(std::max<size_t>( 917*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 918*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 919*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 920*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 921*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 922*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 923*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 924*4bdc9457SAndroid Build Coastguard Worker 925*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 926*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 927*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 928*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 929*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 930*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 931*4bdc9457SAndroid Build Coastguard Worker size_t n = 0; 932*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 933*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 934*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 935*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 936*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 937*4bdc9457SAndroid Build Coastguard Worker acc += input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]; 938*4bdc9457SAndroid Build Coastguard Worker n += 1; 939*4bdc9457SAndroid Build Coastguard Worker } 940*4bdc9457SAndroid Build Coastguard Worker } 941*4bdc9457SAndroid Build Coastguard Worker } 942*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = acc / float(n); 943*4bdc9457SAndroid Build Coastguard Worker } 944*4bdc9457SAndroid Build Coastguard Worker } 945*4bdc9457SAndroid Build Coastguard Worker } 946*4bdc9457SAndroid Build Coastguard Worker } 947*4bdc9457SAndroid Build Coastguard Worker 948*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 949*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 950*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 951*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 952*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_range == 0.0f ? 953*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity() : 954*4bdc9457SAndroid Build Coastguard Worker accumulated_min + accumulated_range / 255.0f * float(qmin()); 955*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_range == 0.0f ? 956*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity() : 957*4bdc9457SAndroid Build Coastguard Worker accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 958*4bdc9457SAndroid Build Coastguard Worker 959*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 960*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 961*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 962*4bdc9457SAndroid Build Coastguard Worker } 963*4bdc9457SAndroid Build Coastguard Worker 964*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Average Pooling operator once. 965*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 966*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t average_pooling_op = nullptr; 967*4bdc9457SAndroid Build Coastguard Worker 968*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 969*4bdc9457SAndroid Build Coastguard Worker xnn_create_average_pooling2d_nhwc_f32( 970*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 971*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 972*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 973*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 974*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 975*4bdc9457SAndroid Build Coastguard Worker 0, &average_pooling_op)); 976*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, average_pooling_op); 977*4bdc9457SAndroid Build Coastguard Worker 978*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 979*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_f32( 980*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 981*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 982*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 983*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 984*4bdc9457SAndroid Build Coastguard Worker 985*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 986*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 987*4bdc9457SAndroid Build Coastguard Worker 988*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 989*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 990*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 991*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 992*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 993*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max); 994*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min); 995*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], 996*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 997*4bdc9457SAndroid Build Coastguard Worker std::abs(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) * 1.0e-6f) << 998*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 999*4bdc9457SAndroid Build Coastguard Worker } 1000*4bdc9457SAndroid Build Coastguard Worker } 1001*4bdc9457SAndroid Build Coastguard Worker } 1002*4bdc9457SAndroid Build Coastguard Worker } 1003*4bdc9457SAndroid Build Coastguard Worker 1004*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 1005*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 1006*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 1007*4bdc9457SAndroid Build Coastguard Worker 1008*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run. 1009*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1010*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 1011*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 1012*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1013*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 1014*4bdc9457SAndroid Build Coastguard Worker int32_t n = 0; 1015*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1016*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 1017*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1018*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 1019*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 1020*4bdc9457SAndroid Build Coastguard Worker acc += input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]; 1021*4bdc9457SAndroid Build Coastguard Worker n += 1; 1022*4bdc9457SAndroid Build Coastguard Worker } 1023*4bdc9457SAndroid Build Coastguard Worker } 1024*4bdc9457SAndroid Build Coastguard Worker } 1025*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = 1026*4bdc9457SAndroid Build Coastguard Worker std::max(std::min(acc / float(n), output_max), output_min); 1027*4bdc9457SAndroid Build Coastguard Worker } 1028*4bdc9457SAndroid Build Coastguard Worker } 1029*4bdc9457SAndroid Build Coastguard Worker } 1030*4bdc9457SAndroid Build Coastguard Worker } 1031*4bdc9457SAndroid Build Coastguard Worker 1032*4bdc9457SAndroid Build Coastguard Worker // Setup and run Average Pooling operator the second time, and destroy the operator. 1033*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1034*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_f32( 1035*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 1036*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 1037*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1038*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1039*4bdc9457SAndroid Build Coastguard Worker 1040*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1041*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 1042*4bdc9457SAndroid Build Coastguard Worker 1043*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1044*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(average_pooling_op)); 1045*4bdc9457SAndroid Build Coastguard Worker average_pooling_op = nullptr; 1046*4bdc9457SAndroid Build Coastguard Worker 1047*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 1048*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1049*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 1050*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 1051*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1052*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], output_max); 1053*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], output_min); 1054*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], 1055*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 1056*4bdc9457SAndroid Build Coastguard Worker std::abs(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]) * 1.0e-6f) << 1057*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1058*4bdc9457SAndroid Build Coastguard Worker } 1059*4bdc9457SAndroid Build Coastguard Worker } 1060*4bdc9457SAndroid Build Coastguard Worker } 1061*4bdc9457SAndroid Build Coastguard Worker } 1062*4bdc9457SAndroid Build Coastguard Worker } 1063*4bdc9457SAndroid Build Coastguard Worker } 1064*4bdc9457SAndroid Build Coastguard Worker TestSetupQU8()1065*4bdc9457SAndroid Build Coastguard Worker void TestSetupQU8() const { 1066*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1067*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1068*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 1069*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 1070*4bdc9457SAndroid Build Coastguard Worker 1071*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max<size_t>( 1072*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 1073*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 1074*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(std::max<size_t>( 1075*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 1076*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 1077*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 1078*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 1079*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1080*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 1081*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 1082*4bdc9457SAndroid Build Coastguard Worker 1083*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 1084*4bdc9457SAndroid Build Coastguard Worker const double scale = double(input_scale()) / (double(output_scale()) * double(pooling_height() * pooling_width())); 1085*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1086*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1087*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1088*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1089*4bdc9457SAndroid Build Coastguard Worker double acc = 0.0f; 1090*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1091*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 1092*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1093*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 1094*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 1095*4bdc9457SAndroid Build Coastguard Worker acc += double(int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]) - int32_t(input_zero_point())); 1096*4bdc9457SAndroid Build Coastguard Worker } 1097*4bdc9457SAndroid Build Coastguard Worker } 1098*4bdc9457SAndroid Build Coastguard Worker } 1099*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = float(acc * scale + double(output_zero_point())); 1100*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = 1101*4bdc9457SAndroid Build Coastguard Worker std::min<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmax())); 1102*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = 1103*4bdc9457SAndroid Build Coastguard Worker std::max<float>(output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c], float(qmin())); 1104*4bdc9457SAndroid Build Coastguard Worker } 1105*4bdc9457SAndroid Build Coastguard Worker } 1106*4bdc9457SAndroid Build Coastguard Worker } 1107*4bdc9457SAndroid Build Coastguard Worker } 1108*4bdc9457SAndroid Build Coastguard Worker 1109*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Average Pooling operator once. 1110*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1111*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t average_pooling_op = nullptr; 1112*4bdc9457SAndroid Build Coastguard Worker 1113*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1114*4bdc9457SAndroid Build Coastguard Worker xnn_create_average_pooling2d_nhwc_qu8( 1115*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 1116*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 1117*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 1118*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 1119*4bdc9457SAndroid Build Coastguard Worker input_zero_point(), input_scale(), 1120*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), output_scale(), 1121*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 1122*4bdc9457SAndroid Build Coastguard Worker 0, &average_pooling_op)); 1123*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, average_pooling_op); 1124*4bdc9457SAndroid Build Coastguard Worker 1125*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1126*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_qu8( 1127*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 1128*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1129*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1130*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1131*4bdc9457SAndroid Build Coastguard Worker 1132*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1133*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 1134*4bdc9457SAndroid Build Coastguard Worker 1135*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 1136*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1137*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1138*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1139*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1140*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax())); 1141*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin())); 1142*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])), 1143*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 0.80f) << 1144*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1145*4bdc9457SAndroid Build Coastguard Worker } 1146*4bdc9457SAndroid Build Coastguard Worker } 1147*4bdc9457SAndroid Build Coastguard Worker } 1148*4bdc9457SAndroid Build Coastguard Worker } 1149*4bdc9457SAndroid Build Coastguard Worker 1150*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 1151*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 1152*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 1153*4bdc9457SAndroid Build Coastguard Worker 1154*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run. 1155*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1156*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 1157*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 1158*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1159*4bdc9457SAndroid Build Coastguard Worker double acc = 0.0f; 1160*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1161*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py - padding_top(); 1162*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1163*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px - padding_left(); 1164*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 1165*4bdc9457SAndroid Build Coastguard Worker acc += double(int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]) - int32_t(input_zero_point())); 1166*4bdc9457SAndroid Build Coastguard Worker } 1167*4bdc9457SAndroid Build Coastguard Worker } 1168*4bdc9457SAndroid Build Coastguard Worker } 1169*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = float(acc * scale + double(output_zero_point())); 1170*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = 1171*4bdc9457SAndroid Build Coastguard Worker std::min<float>(next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c], float(qmax())); 1172*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = 1173*4bdc9457SAndroid Build Coastguard Worker std::max<float>(next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c], float(qmin())); 1174*4bdc9457SAndroid Build Coastguard Worker } 1175*4bdc9457SAndroid Build Coastguard Worker } 1176*4bdc9457SAndroid Build Coastguard Worker } 1177*4bdc9457SAndroid Build Coastguard Worker } 1178*4bdc9457SAndroid Build Coastguard Worker 1179*4bdc9457SAndroid Build Coastguard Worker // Setup and run Average Pooling operator the second time, and destroy the operator. 1180*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1181*4bdc9457SAndroid Build Coastguard Worker xnn_setup_average_pooling2d_nhwc_qu8( 1182*4bdc9457SAndroid Build Coastguard Worker average_pooling_op, 1183*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 1184*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1185*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1186*4bdc9457SAndroid Build Coastguard Worker 1187*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1188*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(average_pooling_op, nullptr /* thread pool */)); 1189*4bdc9457SAndroid Build Coastguard Worker 1190*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1191*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(average_pooling_op)); 1192*4bdc9457SAndroid Build Coastguard Worker average_pooling_op = nullptr; 1193*4bdc9457SAndroid Build Coastguard Worker 1194*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 1195*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1196*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 1197*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 1198*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1199*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax())); 1200*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin())); 1201*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c])), 1202*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 0.80f) << 1203*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1204*4bdc9457SAndroid Build Coastguard Worker } 1205*4bdc9457SAndroid Build Coastguard Worker } 1206*4bdc9457SAndroid Build Coastguard Worker } 1207*4bdc9457SAndroid Build Coastguard Worker } 1208*4bdc9457SAndroid Build Coastguard Worker } 1209*4bdc9457SAndroid Build Coastguard Worker } 1210*4bdc9457SAndroid Build Coastguard Worker 1211*4bdc9457SAndroid Build Coastguard Worker private: 1212*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 1213*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 1214*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 1215*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 1216*4bdc9457SAndroid Build Coastguard Worker bool padding_tf_same_{false}; 1217*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 1218*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 1219*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 1220*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 1221*4bdc9457SAndroid Build Coastguard Worker size_t input_pixel_stride_{0}; 1222*4bdc9457SAndroid Build Coastguard Worker size_t output_pixel_stride_{0}; 1223*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_height_{1}; 1224*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_width_{1}; 1225*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_height_{1}; 1226*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_width_{1}; 1227*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0}; 1228*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0}; 1229*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0}; 1230*4bdc9457SAndroid Build Coastguard Worker float input_scale_{1.0f}; 1231*4bdc9457SAndroid Build Coastguard Worker float output_scale_{1.0f}; 1232*4bdc9457SAndroid Build Coastguard Worker uint8_t input_zero_point_{121}; 1233*4bdc9457SAndroid Build Coastguard Worker uint8_t output_zero_point_{133}; 1234*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 1235*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 1236*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 1237*4bdc9457SAndroid Build Coastguard Worker }; 1238