1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 14*4bdc9457SAndroid Build Coastguard Worker 15*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 16*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 17*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 18*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 19*4bdc9457SAndroid Build Coastguard Worker #include <limits> 20*4bdc9457SAndroid Build Coastguard Worker #include <random> 21*4bdc9457SAndroid Build Coastguard Worker #include <vector> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker class MaxPoolingOperatorTester { 27*4bdc9457SAndroid Build Coastguard Worker public: padding_tf_same(bool padding_same)28*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_tf_same(bool padding_same) { 29*4bdc9457SAndroid Build Coastguard Worker if (padding_same) { 30*4bdc9457SAndroid Build Coastguard Worker assert(padding_top() == 0); 31*4bdc9457SAndroid Build Coastguard Worker assert(padding_left() == 0); 32*4bdc9457SAndroid Build Coastguard Worker assert(padding_bottom() == 0); 33*4bdc9457SAndroid Build Coastguard Worker assert(padding_right() == 0); 34*4bdc9457SAndroid Build Coastguard Worker } 35*4bdc9457SAndroid Build Coastguard Worker this->padding_tf_same_ = padding_same; 36*4bdc9457SAndroid Build Coastguard Worker return *this; 37*4bdc9457SAndroid Build Coastguard Worker } 38*4bdc9457SAndroid Build Coastguard Worker padding_tf_same()39*4bdc9457SAndroid Build Coastguard Worker inline bool padding_tf_same() const { 40*4bdc9457SAndroid Build Coastguard Worker return this->padding_tf_same_; 41*4bdc9457SAndroid Build Coastguard Worker } 42*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding)43*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding(uint32_t padding) { 44*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 45*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding; 46*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding; 47*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding; 48*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding; 49*4bdc9457SAndroid Build Coastguard Worker return *this; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding_height,uint32_t padding_width)52*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) { 53*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 54*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 55*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 56*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 57*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 58*4bdc9457SAndroid Build Coastguard Worker return *this; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker padding_height(uint32_t padding_height)61*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_height(uint32_t padding_height) { 62*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 63*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 64*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 65*4bdc9457SAndroid Build Coastguard Worker return *this; 66*4bdc9457SAndroid Build Coastguard Worker } 67*4bdc9457SAndroid Build Coastguard Worker padding_width(uint32_t padding_width)68*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_width(uint32_t padding_width) { 69*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 70*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 71*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 72*4bdc9457SAndroid Build Coastguard Worker return *this; 73*4bdc9457SAndroid Build Coastguard Worker } 74*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)75*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_top(uint32_t padding_top) { 76*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 77*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 78*4bdc9457SAndroid Build Coastguard Worker return *this; 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker padding_top()81*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 82*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 83*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = 84*4bdc9457SAndroid Build Coastguard Worker (output_height() - 1) * stride_height() + dilated_pooling_height() - input_height(); 85*4bdc9457SAndroid Build Coastguard Worker return total_padding_height / 2; 86*4bdc9457SAndroid Build Coastguard Worker } else { 87*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker } 90*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)91*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_left(uint32_t padding_left) { 92*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 93*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 94*4bdc9457SAndroid Build Coastguard Worker return *this; 95*4bdc9457SAndroid Build Coastguard Worker } 96*4bdc9457SAndroid Build Coastguard Worker padding_left()97*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { 98*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 99*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = 100*4bdc9457SAndroid Build Coastguard Worker (output_width() - 1) * stride_width() + dilated_pooling_width() - input_width(); 101*4bdc9457SAndroid Build Coastguard Worker return total_padding_width / 2; 102*4bdc9457SAndroid Build Coastguard Worker } else { 103*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 104*4bdc9457SAndroid Build Coastguard Worker } 105*4bdc9457SAndroid Build Coastguard Worker } 106*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)107*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_bottom(uint32_t padding_bottom) { 108*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 109*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 110*4bdc9457SAndroid Build Coastguard Worker return *this; 111*4bdc9457SAndroid Build Coastguard Worker } 112*4bdc9457SAndroid Build Coastguard Worker padding_bottom()113*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 114*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 115*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = 116*4bdc9457SAndroid Build Coastguard Worker (output_height() - 1) * stride_height() + dilated_pooling_height() - input_height(); 117*4bdc9457SAndroid Build Coastguard Worker return total_padding_height - total_padding_height / 2; 118*4bdc9457SAndroid Build Coastguard Worker } else { 119*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 120*4bdc9457SAndroid Build Coastguard Worker } 121*4bdc9457SAndroid Build Coastguard Worker } 122*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)123*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& padding_right(uint32_t padding_right) { 124*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 125*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 126*4bdc9457SAndroid Build Coastguard Worker return *this; 127*4bdc9457SAndroid Build Coastguard Worker } 128*4bdc9457SAndroid Build Coastguard Worker padding_right()129*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 130*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 131*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = 132*4bdc9457SAndroid Build Coastguard Worker (output_width() - 1) * stride_width() + dilated_pooling_width() - input_width(); 133*4bdc9457SAndroid Build Coastguard Worker return total_padding_width - total_padding_width / 2; 134*4bdc9457SAndroid Build Coastguard Worker } else { 135*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 136*4bdc9457SAndroid Build Coastguard Worker } 137*4bdc9457SAndroid Build Coastguard Worker } 138*4bdc9457SAndroid Build Coastguard Worker input_size(size_t input_height,size_t input_width)139*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& input_size(size_t input_height, size_t input_width) { 140*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 141*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 142*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 143*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 144*4bdc9457SAndroid Build Coastguard Worker return *this; 145*4bdc9457SAndroid Build Coastguard Worker } 146*4bdc9457SAndroid Build Coastguard Worker input_height(size_t input_height)147*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& input_height(size_t input_height) { 148*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 149*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 150*4bdc9457SAndroid Build Coastguard Worker return *this; 151*4bdc9457SAndroid Build Coastguard Worker } 152*4bdc9457SAndroid Build Coastguard Worker input_height()153*4bdc9457SAndroid Build Coastguard Worker inline size_t input_height() const { 154*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 155*4bdc9457SAndroid Build Coastguard Worker } 156*4bdc9457SAndroid Build Coastguard Worker input_width(size_t input_width)157*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& input_width(size_t input_width) { 158*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 159*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 160*4bdc9457SAndroid Build Coastguard Worker return *this; 161*4bdc9457SAndroid Build Coastguard Worker } 162*4bdc9457SAndroid Build Coastguard Worker input_width()163*4bdc9457SAndroid Build Coastguard Worker inline size_t input_width() const { 164*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 165*4bdc9457SAndroid Build Coastguard Worker } 166*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)167*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& channels(size_t channels) { 168*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 169*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 170*4bdc9457SAndroid Build Coastguard Worker return *this; 171*4bdc9457SAndroid Build Coastguard Worker } 172*4bdc9457SAndroid Build Coastguard Worker channels()173*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 174*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 175*4bdc9457SAndroid Build Coastguard Worker } 176*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)177*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& batch_size(size_t batch_size) { 178*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 179*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 180*4bdc9457SAndroid Build Coastguard Worker return *this; 181*4bdc9457SAndroid Build Coastguard Worker } 182*4bdc9457SAndroid Build Coastguard Worker batch_size()183*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 184*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 185*4bdc9457SAndroid Build Coastguard Worker } 186*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_size)187*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& pooling_size(uint32_t pooling_size) { 188*4bdc9457SAndroid Build Coastguard Worker assert(pooling_size >= 1); 189*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_size; 190*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_size; 191*4bdc9457SAndroid Build Coastguard Worker return *this; 192*4bdc9457SAndroid Build Coastguard Worker } 193*4bdc9457SAndroid Build Coastguard Worker pooling_size(uint32_t pooling_height,uint32_t pooling_width)194*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& pooling_size(uint32_t pooling_height, uint32_t pooling_width) { 195*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 196*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 197*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 198*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 199*4bdc9457SAndroid Build Coastguard Worker return *this; 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker pooling_height(uint32_t pooling_height)202*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& pooling_height(uint32_t pooling_height) { 203*4bdc9457SAndroid Build Coastguard Worker assert(pooling_height >= 1); 204*4bdc9457SAndroid Build Coastguard Worker this->pooling_height_ = pooling_height; 205*4bdc9457SAndroid Build Coastguard Worker return *this; 206*4bdc9457SAndroid Build Coastguard Worker } 207*4bdc9457SAndroid Build Coastguard Worker pooling_height()208*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_height() const { 209*4bdc9457SAndroid Build Coastguard Worker return this->pooling_height_; 210*4bdc9457SAndroid Build Coastguard Worker } 211*4bdc9457SAndroid Build Coastguard Worker pooling_width(uint32_t pooling_width)212*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& pooling_width(uint32_t pooling_width) { 213*4bdc9457SAndroid Build Coastguard Worker assert(pooling_width >= 1); 214*4bdc9457SAndroid Build Coastguard Worker this->pooling_width_ = pooling_width; 215*4bdc9457SAndroid Build Coastguard Worker return *this; 216*4bdc9457SAndroid Build Coastguard Worker } 217*4bdc9457SAndroid Build Coastguard Worker pooling_width()218*4bdc9457SAndroid Build Coastguard Worker inline uint32_t pooling_width() const { 219*4bdc9457SAndroid Build Coastguard Worker return this->pooling_width_; 220*4bdc9457SAndroid Build Coastguard Worker } 221*4bdc9457SAndroid Build Coastguard Worker stride(uint32_t stride)222*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& stride(uint32_t stride) { 223*4bdc9457SAndroid Build Coastguard Worker assert(stride >= 1); 224*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride; 225*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride; 226*4bdc9457SAndroid Build Coastguard Worker return *this; 227*4bdc9457SAndroid Build Coastguard Worker } 228*4bdc9457SAndroid Build Coastguard Worker stride(uint32_t stride_height,uint32_t stride_width)229*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& stride(uint32_t stride_height, uint32_t stride_width) { 230*4bdc9457SAndroid Build Coastguard Worker assert(stride_height >= 1); 231*4bdc9457SAndroid Build Coastguard Worker assert(stride_width >= 1); 232*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride_height; 233*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride_width; 234*4bdc9457SAndroid Build Coastguard Worker return *this; 235*4bdc9457SAndroid Build Coastguard Worker } 236*4bdc9457SAndroid Build Coastguard Worker stride_height(uint32_t stride_height)237*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& stride_height(uint32_t stride_height) { 238*4bdc9457SAndroid Build Coastguard Worker assert(stride_height >= 1); 239*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride_height; 240*4bdc9457SAndroid Build Coastguard Worker return *this; 241*4bdc9457SAndroid Build Coastguard Worker } 242*4bdc9457SAndroid Build Coastguard Worker stride_height()243*4bdc9457SAndroid Build Coastguard Worker inline uint32_t stride_height() const { 244*4bdc9457SAndroid Build Coastguard Worker return this->stride_height_; 245*4bdc9457SAndroid Build Coastguard Worker } 246*4bdc9457SAndroid Build Coastguard Worker stride_width(uint32_t stride_width)247*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& stride_width(uint32_t stride_width) { 248*4bdc9457SAndroid Build Coastguard Worker assert(stride_width >= 1); 249*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride_width; 250*4bdc9457SAndroid Build Coastguard Worker return *this; 251*4bdc9457SAndroid Build Coastguard Worker } 252*4bdc9457SAndroid Build Coastguard Worker stride_width()253*4bdc9457SAndroid Build Coastguard Worker inline uint32_t stride_width() const { 254*4bdc9457SAndroid Build Coastguard Worker return this->stride_width_; 255*4bdc9457SAndroid Build Coastguard Worker } 256*4bdc9457SAndroid Build Coastguard Worker dilation(uint32_t dilation)257*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& dilation(uint32_t dilation) { 258*4bdc9457SAndroid Build Coastguard Worker assert(dilation >= 1); 259*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation; 260*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation; 261*4bdc9457SAndroid Build Coastguard Worker return *this; 262*4bdc9457SAndroid Build Coastguard Worker } 263*4bdc9457SAndroid Build Coastguard Worker dilation(uint32_t dilation_height,uint32_t dilation_width)264*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& dilation(uint32_t dilation_height, uint32_t dilation_width) { 265*4bdc9457SAndroid Build Coastguard Worker assert(dilation_height >= 1); 266*4bdc9457SAndroid Build Coastguard Worker assert(dilation_width >= 1); 267*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation_height; 268*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation_width; 269*4bdc9457SAndroid Build Coastguard Worker return *this; 270*4bdc9457SAndroid Build Coastguard Worker } 271*4bdc9457SAndroid Build Coastguard Worker dilation_height(uint32_t dilation_height)272*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& dilation_height(uint32_t dilation_height) { 273*4bdc9457SAndroid Build Coastguard Worker assert(dilation_height >= 1); 274*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation_height; 275*4bdc9457SAndroid Build Coastguard Worker return *this; 276*4bdc9457SAndroid Build Coastguard Worker } 277*4bdc9457SAndroid Build Coastguard Worker dilation_height()278*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilation_height() const { 279*4bdc9457SAndroid Build Coastguard Worker return this->dilation_height_; 280*4bdc9457SAndroid Build Coastguard Worker } 281*4bdc9457SAndroid Build Coastguard Worker dilation_width(uint32_t dilation_width)282*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& dilation_width(uint32_t dilation_width) { 283*4bdc9457SAndroid Build Coastguard Worker assert(dilation_width >= 1); 284*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation_width; 285*4bdc9457SAndroid Build Coastguard Worker return *this; 286*4bdc9457SAndroid Build Coastguard Worker } 287*4bdc9457SAndroid Build Coastguard Worker dilation_width()288*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilation_width() const { 289*4bdc9457SAndroid Build Coastguard Worker return this->dilation_width_; 290*4bdc9457SAndroid Build Coastguard Worker } 291*4bdc9457SAndroid Build Coastguard Worker dilated_pooling_height()292*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilated_pooling_height() const { 293*4bdc9457SAndroid Build Coastguard Worker return (pooling_height() - 1) * dilation_height() + 1; 294*4bdc9457SAndroid Build Coastguard Worker } 295*4bdc9457SAndroid Build Coastguard Worker dilated_pooling_width()296*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilated_pooling_width() const { 297*4bdc9457SAndroid Build Coastguard Worker return (pooling_width() - 1) * dilation_width() + 1; 298*4bdc9457SAndroid Build Coastguard Worker } 299*4bdc9457SAndroid Build Coastguard Worker output_height()300*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 301*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 302*4bdc9457SAndroid Build Coastguard Worker return (input_height() + stride_height() - 1) / stride_height(); 303*4bdc9457SAndroid Build Coastguard Worker } else { 304*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_height = padding_top() + input_height() + padding_bottom(); 305*4bdc9457SAndroid Build Coastguard Worker if (padded_input_height <= dilated_pooling_height()) { 306*4bdc9457SAndroid Build Coastguard Worker return 1; 307*4bdc9457SAndroid Build Coastguard Worker } else { 308*4bdc9457SAndroid Build Coastguard Worker return (padded_input_height - dilated_pooling_height()) / stride_height() + 1; 309*4bdc9457SAndroid Build Coastguard Worker } 310*4bdc9457SAndroid Build Coastguard Worker } 311*4bdc9457SAndroid Build Coastguard Worker } 312*4bdc9457SAndroid Build Coastguard Worker output_width()313*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 314*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 315*4bdc9457SAndroid Build Coastguard Worker return (input_width() + stride_width() - 1) / stride_width(); 316*4bdc9457SAndroid Build Coastguard Worker } else { 317*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_width = padding_left() + input_width() + padding_right(); 318*4bdc9457SAndroid Build Coastguard Worker if (padded_input_width <= dilated_pooling_width()) { 319*4bdc9457SAndroid Build Coastguard Worker return 1; 320*4bdc9457SAndroid Build Coastguard Worker } else { 321*4bdc9457SAndroid Build Coastguard Worker return (padded_input_width - dilated_pooling_width()) / stride_width() + 1; 322*4bdc9457SAndroid Build Coastguard Worker } 323*4bdc9457SAndroid Build Coastguard Worker } 324*4bdc9457SAndroid Build Coastguard Worker } 325*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(size_t input_pixel_stride)326*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& input_pixel_stride(size_t input_pixel_stride) { 327*4bdc9457SAndroid Build Coastguard Worker assert(input_pixel_stride != 0); 328*4bdc9457SAndroid Build Coastguard Worker this->input_pixel_stride_ = input_pixel_stride; 329*4bdc9457SAndroid Build Coastguard Worker return *this; 330*4bdc9457SAndroid Build Coastguard Worker } 331*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride()332*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const { 333*4bdc9457SAndroid Build Coastguard Worker if (this->input_pixel_stride_ == 0) { 334*4bdc9457SAndroid Build Coastguard Worker return channels(); 335*4bdc9457SAndroid Build Coastguard Worker } else { 336*4bdc9457SAndroid Build Coastguard Worker assert(this->input_pixel_stride_ >= channels()); 337*4bdc9457SAndroid Build Coastguard Worker return this->input_pixel_stride_; 338*4bdc9457SAndroid Build Coastguard Worker } 339*4bdc9457SAndroid Build Coastguard Worker } 340*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride(size_t output_pixel_stride)341*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& output_pixel_stride(size_t output_pixel_stride) { 342*4bdc9457SAndroid Build Coastguard Worker assert(output_pixel_stride != 0); 343*4bdc9457SAndroid Build Coastguard Worker this->output_pixel_stride_ = output_pixel_stride; 344*4bdc9457SAndroid Build Coastguard Worker return *this; 345*4bdc9457SAndroid Build Coastguard Worker } 346*4bdc9457SAndroid Build Coastguard Worker output_pixel_stride()347*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const { 348*4bdc9457SAndroid Build Coastguard Worker if (this->output_pixel_stride_ == 0) { 349*4bdc9457SAndroid Build Coastguard Worker return channels(); 350*4bdc9457SAndroid Build Coastguard Worker } else { 351*4bdc9457SAndroid Build Coastguard Worker assert(this->output_pixel_stride_ >= channels()); 352*4bdc9457SAndroid Build Coastguard Worker return this->output_pixel_stride_; 353*4bdc9457SAndroid Build Coastguard Worker } 354*4bdc9457SAndroid Build Coastguard Worker } 355*4bdc9457SAndroid Build Coastguard Worker next_input_size(uint32_t next_input_height,uint32_t next_input_width)356*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) { 357*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 358*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 359*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 360*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 361*4bdc9457SAndroid Build Coastguard Worker return *this; 362*4bdc9457SAndroid Build Coastguard Worker } 363*4bdc9457SAndroid Build Coastguard Worker next_input_height(uint32_t next_input_height)364*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& next_input_height(uint32_t next_input_height) { 365*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 366*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 367*4bdc9457SAndroid Build Coastguard Worker return *this; 368*4bdc9457SAndroid Build Coastguard Worker } 369*4bdc9457SAndroid Build Coastguard Worker next_input_height()370*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_height() const { 371*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) { 372*4bdc9457SAndroid Build Coastguard Worker return input_height(); 373*4bdc9457SAndroid Build Coastguard Worker } else { 374*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_; 375*4bdc9457SAndroid Build Coastguard Worker } 376*4bdc9457SAndroid Build Coastguard Worker } 377*4bdc9457SAndroid Build Coastguard Worker next_input_width(uint32_t next_input_width)378*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& next_input_width(uint32_t next_input_width) { 379*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 380*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 381*4bdc9457SAndroid Build Coastguard Worker return *this; 382*4bdc9457SAndroid Build Coastguard Worker } 383*4bdc9457SAndroid Build Coastguard Worker next_input_width()384*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const { 385*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) { 386*4bdc9457SAndroid Build Coastguard Worker return input_width(); 387*4bdc9457SAndroid Build Coastguard Worker } else { 388*4bdc9457SAndroid Build Coastguard Worker return this->next_input_width_; 389*4bdc9457SAndroid Build Coastguard Worker } 390*4bdc9457SAndroid Build Coastguard Worker } 391*4bdc9457SAndroid Build Coastguard Worker next_output_height()392*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_height() const { 393*4bdc9457SAndroid Build Coastguard Worker const size_t padded_next_input_height = padding_top() + next_input_height() + padding_bottom(); 394*4bdc9457SAndroid Build Coastguard Worker if (padded_next_input_height <= dilated_pooling_height()) { 395*4bdc9457SAndroid Build Coastguard Worker return 1; 396*4bdc9457SAndroid Build Coastguard Worker } else { 397*4bdc9457SAndroid Build Coastguard Worker return (padded_next_input_height - dilated_pooling_height()) / stride_height() + 1; 398*4bdc9457SAndroid Build Coastguard Worker } 399*4bdc9457SAndroid Build Coastguard Worker } 400*4bdc9457SAndroid Build Coastguard Worker next_output_width()401*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_width() const { 402*4bdc9457SAndroid Build Coastguard Worker const size_t padded_next_input_width = padding_left() + next_input_width() + padding_right(); 403*4bdc9457SAndroid Build Coastguard Worker if (padded_next_input_width <= dilated_pooling_width()) { 404*4bdc9457SAndroid Build Coastguard Worker return 1; 405*4bdc9457SAndroid Build Coastguard Worker } else { 406*4bdc9457SAndroid Build Coastguard Worker return (padded_next_input_width - dilated_pooling_width()) / stride_width() + 1; 407*4bdc9457SAndroid Build Coastguard Worker } 408*4bdc9457SAndroid Build Coastguard Worker } 409*4bdc9457SAndroid Build Coastguard Worker next_batch_size(size_t next_batch_size)410*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& next_batch_size(size_t next_batch_size) { 411*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1); 412*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size; 413*4bdc9457SAndroid Build Coastguard Worker return *this; 414*4bdc9457SAndroid Build Coastguard Worker } 415*4bdc9457SAndroid Build Coastguard Worker next_batch_size()416*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const { 417*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) { 418*4bdc9457SAndroid Build Coastguard Worker return batch_size(); 419*4bdc9457SAndroid Build Coastguard Worker } else { 420*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_; 421*4bdc9457SAndroid Build Coastguard Worker } 422*4bdc9457SAndroid Build Coastguard Worker } 423*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)424*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& qmin(uint8_t qmin) { 425*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 426*4bdc9457SAndroid Build Coastguard Worker return *this; 427*4bdc9457SAndroid Build Coastguard Worker } 428*4bdc9457SAndroid Build Coastguard Worker qmin()429*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 430*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 431*4bdc9457SAndroid Build Coastguard Worker } 432*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)433*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& qmax(uint8_t qmax) { 434*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 435*4bdc9457SAndroid Build Coastguard Worker return *this; 436*4bdc9457SAndroid Build Coastguard Worker } 437*4bdc9457SAndroid Build Coastguard Worker qmax()438*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 439*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 440*4bdc9457SAndroid Build Coastguard Worker } 441*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)442*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolingOperatorTester& iterations(size_t iterations) { 443*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 444*4bdc9457SAndroid Build Coastguard Worker return *this; 445*4bdc9457SAndroid Build Coastguard Worker } 446*4bdc9457SAndroid Build Coastguard Worker iterations()447*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 448*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 449*4bdc9457SAndroid Build Coastguard Worker } 450*4bdc9457SAndroid Build Coastguard Worker TestS8()451*4bdc9457SAndroid Build Coastguard Worker void TestS8() const { 452*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 453*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 454*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 455*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 456*4bdc9457SAndroid Build Coastguard Worker 457*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 458*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 459*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output_ref(batch_size() * output_height() * output_width() * channels()); 460*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 461*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 462*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 463*4bdc9457SAndroid Build Coastguard Worker 464*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 465*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 466*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 467*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 468*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 469*4bdc9457SAndroid Build Coastguard Worker int8_t max_value = std::numeric_limits<int8_t>::min(); 470*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 471*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 472*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 473*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 474*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 475*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 476*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 477*4bdc9457SAndroid Build Coastguard Worker } 478*4bdc9457SAndroid Build Coastguard Worker } 479*4bdc9457SAndroid Build Coastguard Worker } 480*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, int8_t(qmax() - 0x80)); 481*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, int8_t(qmin() - 0x80)); 482*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 483*4bdc9457SAndroid Build Coastguard Worker } 484*4bdc9457SAndroid Build Coastguard Worker } 485*4bdc9457SAndroid Build Coastguard Worker } 486*4bdc9457SAndroid Build Coastguard Worker } 487*4bdc9457SAndroid Build Coastguard Worker 488*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Max Pooling operator. 489*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 490*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 491*4bdc9457SAndroid Build Coastguard Worker 492*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 493*4bdc9457SAndroid Build Coastguard Worker xnn_create_max_pooling2d_nhwc_s8( 494*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 495*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 496*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 497*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 498*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 499*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 500*4bdc9457SAndroid Build Coastguard Worker int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 501*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0, 502*4bdc9457SAndroid Build Coastguard Worker &max_pooling_op)); 503*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 504*4bdc9457SAndroid Build Coastguard Worker 505*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 506*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 507*4bdc9457SAndroid Build Coastguard Worker 508*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 509*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_s8( 510*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 511*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 512*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 513*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 514*4bdc9457SAndroid Build Coastguard Worker 515*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 516*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 517*4bdc9457SAndroid Build Coastguard Worker 518*4bdc9457SAndroid Build Coastguard Worker // Verify results. 519*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 520*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 521*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 522*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 523*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), int32_t(qmax() - 0x80)); 524*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), int32_t(qmin() - 0x80)); 525*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]), 526*4bdc9457SAndroid Build Coastguard Worker int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])) << 527*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 528*4bdc9457SAndroid Build Coastguard Worker } 529*4bdc9457SAndroid Build Coastguard Worker } 530*4bdc9457SAndroid Build Coastguard Worker } 531*4bdc9457SAndroid Build Coastguard Worker } 532*4bdc9457SAndroid Build Coastguard Worker } 533*4bdc9457SAndroid Build Coastguard Worker } 534*4bdc9457SAndroid Build Coastguard Worker TestU8()535*4bdc9457SAndroid Build Coastguard Worker void TestU8() const { 536*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 537*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 538*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 539*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 540*4bdc9457SAndroid Build Coastguard Worker 541*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 542*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 543*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(batch_size() * output_height() * output_width() * channels()); 544*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 545*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 546*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 547*4bdc9457SAndroid Build Coastguard Worker 548*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 549*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 550*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 551*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 552*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 553*4bdc9457SAndroid Build Coastguard Worker uint8_t max_value = 0; 554*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 555*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 556*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 557*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 558*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 559*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 560*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 561*4bdc9457SAndroid Build Coastguard Worker } 562*4bdc9457SAndroid Build Coastguard Worker } 563*4bdc9457SAndroid Build Coastguard Worker } 564*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, qmax()); 565*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, qmin()); 566*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 567*4bdc9457SAndroid Build Coastguard Worker } 568*4bdc9457SAndroid Build Coastguard Worker } 569*4bdc9457SAndroid Build Coastguard Worker } 570*4bdc9457SAndroid Build Coastguard Worker } 571*4bdc9457SAndroid Build Coastguard Worker 572*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Max Pooling operator. 573*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 574*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 575*4bdc9457SAndroid Build Coastguard Worker 576*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 577*4bdc9457SAndroid Build Coastguard Worker xnn_create_max_pooling2d_nhwc_u8( 578*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 579*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 580*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 581*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 582*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 583*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 584*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 585*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0, 586*4bdc9457SAndroid Build Coastguard Worker &max_pooling_op)); 587*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 588*4bdc9457SAndroid Build Coastguard Worker 589*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 590*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 591*4bdc9457SAndroid Build Coastguard Worker 592*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 593*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_u8( 594*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 595*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 596*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 597*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 598*4bdc9457SAndroid Build Coastguard Worker 599*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 600*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 601*4bdc9457SAndroid Build Coastguard Worker 602*4bdc9457SAndroid Build Coastguard Worker // Verify results. 603*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 604*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 605*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 606*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 607*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax())); 608*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin())); 609*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]), 610*4bdc9457SAndroid Build Coastguard Worker uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])) << 611*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 612*4bdc9457SAndroid Build Coastguard Worker } 613*4bdc9457SAndroid Build Coastguard Worker } 614*4bdc9457SAndroid Build Coastguard Worker } 615*4bdc9457SAndroid Build Coastguard Worker } 616*4bdc9457SAndroid Build Coastguard Worker } 617*4bdc9457SAndroid Build Coastguard Worker } 618*4bdc9457SAndroid Build Coastguard Worker TestF16()619*4bdc9457SAndroid Build Coastguard Worker void TestF16() const { 620*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 621*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 622*4bdc9457SAndroid Build Coastguard Worker // Note: we need to avoid FP16 denormals in the generated tensor because they might be processed differently in 623*4bdc9457SAndroid Build Coastguard Worker // native vs emulated arithmetics, and we use exact comparison to verify the results against reference. 624*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.001f, 1.0f); 625*4bdc9457SAndroid Build Coastguard Worker 626*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 627*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 628*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 629*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 630*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 631*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 632*4bdc9457SAndroid Build Coastguard Worker 633*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 634*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 635*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 636*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 637*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 638*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 639*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 640*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 641*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 642*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 643*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 644*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 645*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c])); 646*4bdc9457SAndroid Build Coastguard Worker } 647*4bdc9457SAndroid Build Coastguard Worker } 648*4bdc9457SAndroid Build Coastguard Worker } 649*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 650*4bdc9457SAndroid Build Coastguard Worker } 651*4bdc9457SAndroid Build Coastguard Worker } 652*4bdc9457SAndroid Build Coastguard Worker } 653*4bdc9457SAndroid Build Coastguard Worker } 654*4bdc9457SAndroid Build Coastguard Worker 655*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 656*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 657*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 658*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 659*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin()); 660*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 661*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)); 662*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)); 663*4bdc9457SAndroid Build Coastguard Worker if (accumulated_range == 0.0f) { 664*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 665*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 666*4bdc9457SAndroid Build Coastguard Worker } 667*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<uint8_t>::min()) { 668*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 669*4bdc9457SAndroid Build Coastguard Worker } 670*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<uint8_t>::max()) { 671*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 672*4bdc9457SAndroid Build Coastguard Worker } 673*4bdc9457SAndroid Build Coastguard Worker 674*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 675*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 676*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 677*4bdc9457SAndroid Build Coastguard Worker } 678*4bdc9457SAndroid Build Coastguard Worker 679*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Max Pooling operator. 680*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 681*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 682*4bdc9457SAndroid Build Coastguard Worker 683*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_max_pooling2d_nhwc_f16( 684*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 685*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 686*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 687*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 688*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 689*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 690*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 691*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0, 692*4bdc9457SAndroid Build Coastguard Worker &max_pooling_op); 693*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 694*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 695*4bdc9457SAndroid Build Coastguard Worker } 696*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 697*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 698*4bdc9457SAndroid Build Coastguard Worker 699*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 700*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 701*4bdc9457SAndroid Build Coastguard Worker 702*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 703*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_f16( 704*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 705*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 706*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 707*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 708*4bdc9457SAndroid Build Coastguard Worker 709*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 710*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 711*4bdc9457SAndroid Build Coastguard Worker 712*4bdc9457SAndroid Build Coastguard Worker // Verify results. 713*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 714*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 715*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 716*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 717*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_max); 718*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_min); 719*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 720*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), 721*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) << 722*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c 723*4bdc9457SAndroid Build Coastguard Worker << ", min = " << output_min << ", max = " << output_max; 724*4bdc9457SAndroid Build Coastguard Worker } 725*4bdc9457SAndroid Build Coastguard Worker } 726*4bdc9457SAndroid Build Coastguard Worker } 727*4bdc9457SAndroid Build Coastguard Worker } 728*4bdc9457SAndroid Build Coastguard Worker } 729*4bdc9457SAndroid Build Coastguard Worker } 730*4bdc9457SAndroid Build Coastguard Worker TestF32()731*4bdc9457SAndroid Build Coastguard Worker void TestF32() const { 732*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 733*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 734*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 735*4bdc9457SAndroid Build Coastguard Worker 736*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 737*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 738*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 739*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 740*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 741*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 742*4bdc9457SAndroid Build Coastguard Worker 743*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 744*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 745*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 746*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 747*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 748*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 749*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 750*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 751*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 752*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 753*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 754*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 755*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 756*4bdc9457SAndroid Build Coastguard Worker } 757*4bdc9457SAndroid Build Coastguard Worker } 758*4bdc9457SAndroid Build Coastguard Worker } 759*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 760*4bdc9457SAndroid Build Coastguard Worker } 761*4bdc9457SAndroid Build Coastguard Worker } 762*4bdc9457SAndroid Build Coastguard Worker } 763*4bdc9457SAndroid Build Coastguard Worker } 764*4bdc9457SAndroid Build Coastguard Worker 765*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 766*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 767*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 768*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 769*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_range == 0.0f ? 770*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity() : 771*4bdc9457SAndroid Build Coastguard Worker accumulated_min + accumulated_range / 255.0f * float(qmin()); 772*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_range == 0.0f ? 773*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity() : 774*4bdc9457SAndroid Build Coastguard Worker accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 775*4bdc9457SAndroid Build Coastguard Worker 776*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 777*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 778*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 779*4bdc9457SAndroid Build Coastguard Worker } 780*4bdc9457SAndroid Build Coastguard Worker 781*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Max Pooling operator. 782*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 783*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 784*4bdc9457SAndroid Build Coastguard Worker 785*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 786*4bdc9457SAndroid Build Coastguard Worker xnn_create_max_pooling2d_nhwc_f32( 787*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 788*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 789*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 790*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 791*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 792*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 793*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 794*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0, 795*4bdc9457SAndroid Build Coastguard Worker &max_pooling_op)); 796*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 797*4bdc9457SAndroid Build Coastguard Worker 798*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 799*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 800*4bdc9457SAndroid Build Coastguard Worker 801*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 802*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_f32( 803*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 804*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 805*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 806*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 807*4bdc9457SAndroid Build Coastguard Worker 808*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 809*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 810*4bdc9457SAndroid Build Coastguard Worker 811*4bdc9457SAndroid Build Coastguard Worker // Verify results. 812*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 813*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 814*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 815*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 816*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max); 817*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min); 818*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 819*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) << 820*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c 821*4bdc9457SAndroid Build Coastguard Worker << ", min = " << output_min << ", max = " << output_max; 822*4bdc9457SAndroid Build Coastguard Worker } 823*4bdc9457SAndroid Build Coastguard Worker } 824*4bdc9457SAndroid Build Coastguard Worker } 825*4bdc9457SAndroid Build Coastguard Worker } 826*4bdc9457SAndroid Build Coastguard Worker } 827*4bdc9457SAndroid Build Coastguard Worker } 828*4bdc9457SAndroid Build Coastguard Worker TestSetupS8()829*4bdc9457SAndroid Build Coastguard Worker void TestSetupS8() const { 830*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 831*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 832*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 833*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 834*4bdc9457SAndroid Build Coastguard Worker 835*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max<size_t>( 836*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 837*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 838*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max<size_t>( 839*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 840*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 841*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 842*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 843*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 844*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 845*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 846*4bdc9457SAndroid Build Coastguard Worker 847*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 848*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 849*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 850*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 851*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 852*4bdc9457SAndroid Build Coastguard Worker int8_t max_value = std::numeric_limits<int8_t>::min(); 853*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 854*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 855*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 856*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 857*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 858*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 859*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 860*4bdc9457SAndroid Build Coastguard Worker } 861*4bdc9457SAndroid Build Coastguard Worker } 862*4bdc9457SAndroid Build Coastguard Worker } 863*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, int8_t(qmax() - 0x80)); 864*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, int8_t(qmin() - 0x80)); 865*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 866*4bdc9457SAndroid Build Coastguard Worker } 867*4bdc9457SAndroid Build Coastguard Worker } 868*4bdc9457SAndroid Build Coastguard Worker } 869*4bdc9457SAndroid Build Coastguard Worker } 870*4bdc9457SAndroid Build Coastguard Worker 871*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Max Pooling operator once. 872*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 873*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 874*4bdc9457SAndroid Build Coastguard Worker 875*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 876*4bdc9457SAndroid Build Coastguard Worker xnn_create_max_pooling2d_nhwc_s8( 877*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 878*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 879*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 880*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 881*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 882*4bdc9457SAndroid Build Coastguard Worker int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 883*4bdc9457SAndroid Build Coastguard Worker 0, &max_pooling_op)); 884*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 885*4bdc9457SAndroid Build Coastguard Worker 886*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 887*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 888*4bdc9457SAndroid Build Coastguard Worker 889*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 890*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_s8( 891*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 892*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 893*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 894*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 895*4bdc9457SAndroid Build Coastguard Worker 896*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 897*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 898*4bdc9457SAndroid Build Coastguard Worker 899*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 900*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 901*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 902*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 903*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 904*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), int32_t(qmax() - 0x80)); 905*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), int32_t(qmin() - 0x80)); 906*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]), 907*4bdc9457SAndroid Build Coastguard Worker int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])) << 908*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 909*4bdc9457SAndroid Build Coastguard Worker } 910*4bdc9457SAndroid Build Coastguard Worker } 911*4bdc9457SAndroid Build Coastguard Worker } 912*4bdc9457SAndroid Build Coastguard Worker } 913*4bdc9457SAndroid Build Coastguard Worker 914*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 915*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 916*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 917*4bdc9457SAndroid Build Coastguard Worker 918*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run. 919*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 920*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 921*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 922*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 923*4bdc9457SAndroid Build Coastguard Worker int8_t max_value = std::numeric_limits<int8_t>::min(); 924*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 925*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 926*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 927*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 928*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 929*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 930*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]); 931*4bdc9457SAndroid Build Coastguard Worker } 932*4bdc9457SAndroid Build Coastguard Worker } 933*4bdc9457SAndroid Build Coastguard Worker } 934*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, int8_t(qmax() - 0x80)); 935*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, int8_t(qmin() - 0x80)); 936*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_value; 937*4bdc9457SAndroid Build Coastguard Worker } 938*4bdc9457SAndroid Build Coastguard Worker } 939*4bdc9457SAndroid Build Coastguard Worker } 940*4bdc9457SAndroid Build Coastguard Worker } 941*4bdc9457SAndroid Build Coastguard Worker 942*4bdc9457SAndroid Build Coastguard Worker // Setup and run Max Pooling operator the second time, and destroy the operator. 943*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 944*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_s8( 945*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 946*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 947*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 948*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 949*4bdc9457SAndroid Build Coastguard Worker 950*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 951*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 952*4bdc9457SAndroid Build Coastguard Worker 953*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 954*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 955*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 956*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 957*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 958*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), int32_t(qmax() - 0x80)); 959*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), int32_t(qmin() - 0x80)); 960*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]), 961*4bdc9457SAndroid Build Coastguard Worker int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c])) << 962*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 963*4bdc9457SAndroid Build Coastguard Worker } 964*4bdc9457SAndroid Build Coastguard Worker } 965*4bdc9457SAndroid Build Coastguard Worker } 966*4bdc9457SAndroid Build Coastguard Worker } 967*4bdc9457SAndroid Build Coastguard Worker } 968*4bdc9457SAndroid Build Coastguard Worker } 969*4bdc9457SAndroid Build Coastguard Worker TestSetupU8()970*4bdc9457SAndroid Build Coastguard Worker void TestSetupU8() const { 971*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 972*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 973*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 974*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 975*4bdc9457SAndroid Build Coastguard Worker 976*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max<size_t>( 977*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 978*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 979*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max<size_t>( 980*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 981*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 982*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 983*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 984*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 985*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 986*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 987*4bdc9457SAndroid Build Coastguard Worker 988*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 989*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 990*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 991*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 992*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 993*4bdc9457SAndroid Build Coastguard Worker uint8_t max_value = 0; 994*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 995*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 996*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 997*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 998*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 999*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 1000*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 1001*4bdc9457SAndroid Build Coastguard Worker } 1002*4bdc9457SAndroid Build Coastguard Worker } 1003*4bdc9457SAndroid Build Coastguard Worker } 1004*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, qmax()); 1005*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, qmin()); 1006*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 1007*4bdc9457SAndroid Build Coastguard Worker } 1008*4bdc9457SAndroid Build Coastguard Worker } 1009*4bdc9457SAndroid Build Coastguard Worker } 1010*4bdc9457SAndroid Build Coastguard Worker } 1011*4bdc9457SAndroid Build Coastguard Worker 1012*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Max Pooling operator once. 1013*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1014*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 1015*4bdc9457SAndroid Build Coastguard Worker 1016*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1017*4bdc9457SAndroid Build Coastguard Worker xnn_create_max_pooling2d_nhwc_u8( 1018*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 1019*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 1020*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 1021*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1022*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 1023*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 1024*4bdc9457SAndroid Build Coastguard Worker 0, &max_pooling_op)); 1025*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 1026*4bdc9457SAndroid Build Coastguard Worker 1027*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 1028*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 1029*4bdc9457SAndroid Build Coastguard Worker 1030*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1031*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_u8( 1032*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 1033*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1034*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1035*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1036*4bdc9457SAndroid Build Coastguard Worker 1037*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1038*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 1039*4bdc9457SAndroid Build Coastguard Worker 1040*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 1041*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1042*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1043*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1044*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1045*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax())); 1046*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin())); 1047*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]), 1048*4bdc9457SAndroid Build Coastguard Worker uint32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c])) << 1049*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1050*4bdc9457SAndroid Build Coastguard Worker } 1051*4bdc9457SAndroid Build Coastguard Worker } 1052*4bdc9457SAndroid Build Coastguard Worker } 1053*4bdc9457SAndroid Build Coastguard Worker } 1054*4bdc9457SAndroid Build Coastguard Worker 1055*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 1056*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 1057*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 1058*4bdc9457SAndroid Build Coastguard Worker 1059*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run. 1060*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1061*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 1062*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 1063*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1064*4bdc9457SAndroid Build Coastguard Worker uint8_t max_value = 0; 1065*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1066*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 1067*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1068*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 1069*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 1070*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 1071*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]); 1072*4bdc9457SAndroid Build Coastguard Worker } 1073*4bdc9457SAndroid Build Coastguard Worker } 1074*4bdc9457SAndroid Build Coastguard Worker } 1075*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, qmax()); 1076*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, qmin()); 1077*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_value; 1078*4bdc9457SAndroid Build Coastguard Worker } 1079*4bdc9457SAndroid Build Coastguard Worker } 1080*4bdc9457SAndroid Build Coastguard Worker } 1081*4bdc9457SAndroid Build Coastguard Worker } 1082*4bdc9457SAndroid Build Coastguard Worker 1083*4bdc9457SAndroid Build Coastguard Worker // Setup and run Max Pooling operator the second time, and destroy the operator. 1084*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1085*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_u8( 1086*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 1087*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 1088*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1089*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1090*4bdc9457SAndroid Build Coastguard Worker 1091*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1092*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 1093*4bdc9457SAndroid Build Coastguard Worker 1094*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 1095*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1096*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 1097*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 1098*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1099*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), uint32_t(qmax())); 1100*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), uint32_t(qmin())); 1101*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]), 1102*4bdc9457SAndroid Build Coastguard Worker uint32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c])) << 1103*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1104*4bdc9457SAndroid Build Coastguard Worker } 1105*4bdc9457SAndroid Build Coastguard Worker } 1106*4bdc9457SAndroid Build Coastguard Worker } 1107*4bdc9457SAndroid Build Coastguard Worker } 1108*4bdc9457SAndroid Build Coastguard Worker } 1109*4bdc9457SAndroid Build Coastguard Worker } 1110*4bdc9457SAndroid Build Coastguard Worker TestSetupF16()1111*4bdc9457SAndroid Build Coastguard Worker void TestSetupF16() const { 1112*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1113*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1114*4bdc9457SAndroid Build Coastguard Worker // Note: we need to avoid FP16 denormals in the generated tensor because they might be processed differently in 1115*4bdc9457SAndroid Build Coastguard Worker // native vs emulated arithmetics, and we use exact comparison to verify the results against reference. 1116*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.001f, 1.0f); 1117*4bdc9457SAndroid Build Coastguard Worker 1118*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max<size_t>( 1119*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 1120*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 1121*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max<size_t>( 1122*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 1123*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 1124*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 1125*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 1126*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1127*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 1128*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 1129*4bdc9457SAndroid Build Coastguard Worker 1130*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 1131*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1132*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1133*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1134*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1135*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 1136*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1137*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 1138*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1139*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 1140*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 1141*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 1142*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c])); 1143*4bdc9457SAndroid Build Coastguard Worker } 1144*4bdc9457SAndroid Build Coastguard Worker } 1145*4bdc9457SAndroid Build Coastguard Worker } 1146*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 1147*4bdc9457SAndroid Build Coastguard Worker } 1148*4bdc9457SAndroid Build Coastguard Worker } 1149*4bdc9457SAndroid Build Coastguard Worker } 1150*4bdc9457SAndroid Build Coastguard Worker } 1151*4bdc9457SAndroid Build Coastguard Worker 1152*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 1153*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 1154*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 1155*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 1156*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin()); 1157*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 1158*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)); 1159*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)); 1160*4bdc9457SAndroid Build Coastguard Worker if (accumulated_range == 0.0f) { 1161*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 1162*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 1163*4bdc9457SAndroid Build Coastguard Worker } 1164*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<uint8_t>::min()) { 1165*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 1166*4bdc9457SAndroid Build Coastguard Worker } 1167*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<uint8_t>::max()) { 1168*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 1169*4bdc9457SAndroid Build Coastguard Worker } 1170*4bdc9457SAndroid Build Coastguard Worker 1171*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 1172*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 1173*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 1174*4bdc9457SAndroid Build Coastguard Worker } 1175*4bdc9457SAndroid Build Coastguard Worker 1176*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Max Pooling operator once. 1177*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1178*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 1179*4bdc9457SAndroid Build Coastguard Worker 1180*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_max_pooling2d_nhwc_f16( 1181*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 1182*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 1183*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 1184*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1185*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 1186*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1187*4bdc9457SAndroid Build Coastguard Worker 0, &max_pooling_op); 1188*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 1189*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 1190*4bdc9457SAndroid Build Coastguard Worker } 1191*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 1192*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 1193*4bdc9457SAndroid Build Coastguard Worker 1194*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 1195*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 1196*4bdc9457SAndroid Build Coastguard Worker 1197*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1198*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_f16( 1199*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 1200*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1201*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1202*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1203*4bdc9457SAndroid Build Coastguard Worker 1204*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1205*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 1206*4bdc9457SAndroid Build Coastguard Worker 1207*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 1208*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1209*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1210*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1211*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1212*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_max); 1213*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), output_min); 1214*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 1215*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]), 1216*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + y) * output_width() + x) * channels() + c]) << 1217*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c 1218*4bdc9457SAndroid Build Coastguard Worker << ", min = " << output_min << ", max = " << output_max; 1219*4bdc9457SAndroid Build Coastguard Worker } 1220*4bdc9457SAndroid Build Coastguard Worker } 1221*4bdc9457SAndroid Build Coastguard Worker } 1222*4bdc9457SAndroid Build Coastguard Worker } 1223*4bdc9457SAndroid Build Coastguard Worker 1224*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 1225*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 1226*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 1227*4bdc9457SAndroid Build Coastguard Worker 1228*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping. 1229*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1230*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 1231*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 1232*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1233*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 1234*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1235*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 1236*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1237*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 1238*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 1239*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 1240*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c])); 1241*4bdc9457SAndroid Build Coastguard Worker } 1242*4bdc9457SAndroid Build Coastguard Worker } 1243*4bdc9457SAndroid Build Coastguard Worker } 1244*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, output_max); 1245*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, output_min); 1246*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_value; 1247*4bdc9457SAndroid Build Coastguard Worker } 1248*4bdc9457SAndroid Build Coastguard Worker } 1249*4bdc9457SAndroid Build Coastguard Worker } 1250*4bdc9457SAndroid Build Coastguard Worker } 1251*4bdc9457SAndroid Build Coastguard Worker 1252*4bdc9457SAndroid Build Coastguard Worker // Setup and run Max Pooling operator the second time, and destroy the operator. 1253*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1254*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_f16( 1255*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 1256*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 1257*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1258*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1259*4bdc9457SAndroid Build Coastguard Worker 1260*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1261*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 1262*4bdc9457SAndroid Build Coastguard Worker 1263*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 1264*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1265*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 1266*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 1267*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1268*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), output_max); 1269*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), output_min); 1270*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 1271*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]), 1272*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c]) << 1273*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c 1274*4bdc9457SAndroid Build Coastguard Worker << ", min = " << output_min << ", max = " << output_max; 1275*4bdc9457SAndroid Build Coastguard Worker } 1276*4bdc9457SAndroid Build Coastguard Worker } 1277*4bdc9457SAndroid Build Coastguard Worker } 1278*4bdc9457SAndroid Build Coastguard Worker } 1279*4bdc9457SAndroid Build Coastguard Worker } 1280*4bdc9457SAndroid Build Coastguard Worker } 1281*4bdc9457SAndroid Build Coastguard Worker TestSetupF32()1282*4bdc9457SAndroid Build Coastguard Worker void TestSetupF32() const { 1283*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1284*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1285*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 1286*4bdc9457SAndroid Build Coastguard Worker 1287*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max<size_t>( 1288*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + channels(), 1289*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + channels())); 1290*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(XNN_EXTRA_BYTES / sizeof(float) + std::max<size_t>( 1291*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + channels(), 1292*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + channels())); 1293*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * channels()); 1294*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * channels()); 1295*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1296*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 1297*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 1298*4bdc9457SAndroid Build Coastguard Worker 1299*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 1300*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1301*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1302*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1303*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1304*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 1305*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1306*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 1307*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1308*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 1309*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width() && iy < input_height()) { 1310*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 1311*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + c]); 1312*4bdc9457SAndroid Build Coastguard Worker } 1313*4bdc9457SAndroid Build Coastguard Worker } 1314*4bdc9457SAndroid Build Coastguard Worker } 1315*4bdc9457SAndroid Build Coastguard Worker output_ref[((i * output_height() + oy) * output_width() + ox) * channels() + c] = max_value; 1316*4bdc9457SAndroid Build Coastguard Worker } 1317*4bdc9457SAndroid Build Coastguard Worker } 1318*4bdc9457SAndroid Build Coastguard Worker } 1319*4bdc9457SAndroid Build Coastguard Worker } 1320*4bdc9457SAndroid Build Coastguard Worker 1321*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 1322*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 1323*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 1324*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 1325*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_range == 0.0f ? 1326*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity() : 1327*4bdc9457SAndroid Build Coastguard Worker accumulated_min + accumulated_range / 255.0f * float(qmin()); 1328*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_range == 0.0f ? 1329*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity() : 1330*4bdc9457SAndroid Build Coastguard Worker accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 1331*4bdc9457SAndroid Build Coastguard Worker 1332*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 1333*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 1334*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 1335*4bdc9457SAndroid Build Coastguard Worker } 1336*4bdc9457SAndroid Build Coastguard Worker 1337*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Max Pooling operator once. 1338*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1339*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t max_pooling_op = nullptr; 1340*4bdc9457SAndroid Build Coastguard Worker 1341*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1342*4bdc9457SAndroid Build Coastguard Worker xnn_create_max_pooling2d_nhwc_f32( 1343*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 1344*4bdc9457SAndroid Build Coastguard Worker pooling_height(), pooling_width(), 1345*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(), 1346*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1347*4bdc9457SAndroid Build Coastguard Worker channels(), input_pixel_stride(), output_pixel_stride(), 1348*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1349*4bdc9457SAndroid Build Coastguard Worker 0, &max_pooling_op)); 1350*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, max_pooling_op); 1351*4bdc9457SAndroid Build Coastguard Worker 1352*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete max_pooling_op. 1353*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_max_pooling_op(max_pooling_op, xnn_delete_operator); 1354*4bdc9457SAndroid Build Coastguard Worker 1355*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1356*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_f32( 1357*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 1358*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1359*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1360*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1361*4bdc9457SAndroid Build Coastguard Worker 1362*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1363*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 1364*4bdc9457SAndroid Build Coastguard Worker 1365*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 1366*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1367*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1368*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1369*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1370*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_max); 1371*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c], output_min); 1372*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[((i * output_height() + y) * output_width() + x) * channels() + c], 1373*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + c]) << 1374*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1375*4bdc9457SAndroid Build Coastguard Worker } 1376*4bdc9457SAndroid Build Coastguard Worker } 1377*4bdc9457SAndroid Build Coastguard Worker } 1378*4bdc9457SAndroid Build Coastguard Worker } 1379*4bdc9457SAndroid Build Coastguard Worker 1380*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 1381*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 1382*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 1383*4bdc9457SAndroid Build Coastguard Worker 1384*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping. 1385*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1386*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 1387*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 1388*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1389*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 1390*4bdc9457SAndroid Build Coastguard Worker for (size_t py = 0; py < pooling_height(); py++) { 1391*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * stride_height() + py * dilation_height() - padding_top(); 1392*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < pooling_width(); px++) { 1393*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * stride_width() + px * dilation_width() - padding_left(); 1394*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width() && iy < next_input_height()) { 1395*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, 1396*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + c]); 1397*4bdc9457SAndroid Build Coastguard Worker } 1398*4bdc9457SAndroid Build Coastguard Worker } 1399*4bdc9457SAndroid Build Coastguard Worker } 1400*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, output_max); 1401*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, output_min); 1402*4bdc9457SAndroid Build Coastguard Worker next_output_ref[((i * next_output_height() + oy) * next_output_width() + ox) * channels() + c] = max_value; 1403*4bdc9457SAndroid Build Coastguard Worker } 1404*4bdc9457SAndroid Build Coastguard Worker } 1405*4bdc9457SAndroid Build Coastguard Worker } 1406*4bdc9457SAndroid Build Coastguard Worker } 1407*4bdc9457SAndroid Build Coastguard Worker 1408*4bdc9457SAndroid Build Coastguard Worker // Setup and run Max Pooling operator the second time, and destroy the operator. 1409*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1410*4bdc9457SAndroid Build Coastguard Worker xnn_setup_max_pooling2d_nhwc_f32( 1411*4bdc9457SAndroid Build Coastguard Worker max_pooling_op, 1412*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 1413*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1414*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1415*4bdc9457SAndroid Build Coastguard Worker 1416*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1417*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(max_pooling_op, nullptr /* thread pool */)); 1418*4bdc9457SAndroid Build Coastguard Worker 1419*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 1420*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 1421*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 1422*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 1423*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 1424*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], output_max); 1425*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c], output_min); 1426*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(next_output_ref[((i * next_output_height() + y) * next_output_width() + x) * channels() + c], 1427*4bdc9457SAndroid Build Coastguard Worker output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + c]) << 1428*4bdc9457SAndroid Build Coastguard Worker "in batch index " << i << ", pixel (" << y << ", " << x << "), channel " << c; 1429*4bdc9457SAndroid Build Coastguard Worker } 1430*4bdc9457SAndroid Build Coastguard Worker } 1431*4bdc9457SAndroid Build Coastguard Worker } 1432*4bdc9457SAndroid Build Coastguard Worker } 1433*4bdc9457SAndroid Build Coastguard Worker } 1434*4bdc9457SAndroid Build Coastguard Worker } 1435*4bdc9457SAndroid Build Coastguard Worker 1436*4bdc9457SAndroid Build Coastguard Worker private: 1437*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 1438*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 1439*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 1440*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 1441*4bdc9457SAndroid Build Coastguard Worker bool padding_tf_same_{false}; 1442*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 1443*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 1444*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 1445*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 1446*4bdc9457SAndroid Build Coastguard Worker size_t input_pixel_stride_{0}; 1447*4bdc9457SAndroid Build Coastguard Worker size_t output_pixel_stride_{0}; 1448*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_height_{1}; 1449*4bdc9457SAndroid Build Coastguard Worker uint32_t pooling_width_{1}; 1450*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_height_{1}; 1451*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_width_{1}; 1452*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_height_{1}; 1453*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_width_{1}; 1454*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0}; 1455*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0}; 1456*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0}; 1457*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 1458*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 1459*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 1460*4bdc9457SAndroid Build Coastguard Worker }; 1461