1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 14*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 15*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 16*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 17*4bdc9457SAndroid Build Coastguard Worker #include <limits> 18*4bdc9457SAndroid Build Coastguard Worker #include <random> 19*4bdc9457SAndroid Build Coastguard Worker #include <vector> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 26*4bdc9457SAndroid Build Coastguard Worker 27*4bdc9457SAndroid Build Coastguard Worker 28*4bdc9457SAndroid Build Coastguard Worker class MaxPoolMicrokernelTester { 29*4bdc9457SAndroid Build Coastguard Worker public: output_pixels(size_t output_pixels)30*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& output_pixels(size_t output_pixels) { 31*4bdc9457SAndroid Build Coastguard Worker assert(output_pixels != 0); 32*4bdc9457SAndroid Build Coastguard Worker this->output_pixels_ = output_pixels; 33*4bdc9457SAndroid Build Coastguard Worker return *this; 34*4bdc9457SAndroid Build Coastguard Worker } 35*4bdc9457SAndroid Build Coastguard Worker output_pixels()36*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixels() const { 37*4bdc9457SAndroid Build Coastguard Worker return this->output_pixels_; 38*4bdc9457SAndroid Build Coastguard Worker } 39*4bdc9457SAndroid Build Coastguard Worker step(size_t step)40*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& step(size_t step) { 41*4bdc9457SAndroid Build Coastguard Worker assert(step != 0); 42*4bdc9457SAndroid Build Coastguard Worker this->step_ = step; 43*4bdc9457SAndroid Build Coastguard Worker return *this; 44*4bdc9457SAndroid Build Coastguard Worker } 45*4bdc9457SAndroid Build Coastguard Worker step()46*4bdc9457SAndroid Build Coastguard Worker inline size_t step() const { 47*4bdc9457SAndroid Build Coastguard Worker return this->step_; 48*4bdc9457SAndroid Build Coastguard Worker } 49*4bdc9457SAndroid Build Coastguard Worker input_offset(size_t input_offset)50*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& input_offset(size_t input_offset) { 51*4bdc9457SAndroid Build Coastguard Worker assert(input_offset != 0); 52*4bdc9457SAndroid Build Coastguard Worker this->input_offset_ = input_offset; 53*4bdc9457SAndroid Build Coastguard Worker return *this; 54*4bdc9457SAndroid Build Coastguard Worker } 55*4bdc9457SAndroid Build Coastguard Worker input_offset()56*4bdc9457SAndroid Build Coastguard Worker inline size_t input_offset() const { 57*4bdc9457SAndroid Build Coastguard Worker return this->input_offset_; 58*4bdc9457SAndroid Build Coastguard Worker } 59*4bdc9457SAndroid Build Coastguard Worker pooling_elements(size_t pooling_elements)60*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& pooling_elements(size_t pooling_elements) { 61*4bdc9457SAndroid Build Coastguard Worker assert(pooling_elements != 0); 62*4bdc9457SAndroid Build Coastguard Worker this->pooling_elements_ = pooling_elements; 63*4bdc9457SAndroid Build Coastguard Worker return *this; 64*4bdc9457SAndroid Build Coastguard Worker } 65*4bdc9457SAndroid Build Coastguard Worker pooling_elements()66*4bdc9457SAndroid Build Coastguard Worker inline size_t pooling_elements() const { 67*4bdc9457SAndroid Build Coastguard Worker return this->pooling_elements_; 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker packed_pooling_elements()70*4bdc9457SAndroid Build Coastguard Worker inline size_t packed_pooling_elements() const { 71*4bdc9457SAndroid Build Coastguard Worker if (pooling_elements() <= primary_pooling_tile()) { 72*4bdc9457SAndroid Build Coastguard Worker return primary_pooling_tile(); 73*4bdc9457SAndroid Build Coastguard Worker } else { 74*4bdc9457SAndroid Build Coastguard Worker return (pooling_elements() - primary_pooling_tile()) % incremental_pooling_tile() == 0 ? pooling_elements() : ((pooling_elements() - primary_pooling_tile()) / incremental_pooling_tile() + 1) * incremental_pooling_tile() + primary_pooling_tile(); 75*4bdc9457SAndroid Build Coastguard Worker } 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker pooling_tile(size_t primary_tile,size_t incremental_tile)78*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& pooling_tile(size_t primary_tile, size_t incremental_tile) { 79*4bdc9457SAndroid Build Coastguard Worker assert(primary_tile != 0); 80*4bdc9457SAndroid Build Coastguard Worker this->primary_pooling_tile_ = primary_tile; 81*4bdc9457SAndroid Build Coastguard Worker this->incremental_pooling_tile_ = incremental_tile; 82*4bdc9457SAndroid Build Coastguard Worker return *this; 83*4bdc9457SAndroid Build Coastguard Worker } 84*4bdc9457SAndroid Build Coastguard Worker primary_pooling_tile(size_t primary_pooling_tile)85*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& primary_pooling_tile(size_t primary_pooling_tile) { 86*4bdc9457SAndroid Build Coastguard Worker assert(primary_pooling_tile != 0); 87*4bdc9457SAndroid Build Coastguard Worker this->primary_pooling_tile_ = primary_pooling_tile; 88*4bdc9457SAndroid Build Coastguard Worker return *this; 89*4bdc9457SAndroid Build Coastguard Worker } 90*4bdc9457SAndroid Build Coastguard Worker primary_pooling_tile()91*4bdc9457SAndroid Build Coastguard Worker inline size_t primary_pooling_tile() const { 92*4bdc9457SAndroid Build Coastguard Worker return this->primary_pooling_tile_; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker incremental_pooling_tile(size_t incremental_pooling_tile)95*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& incremental_pooling_tile(size_t incremental_pooling_tile) { 96*4bdc9457SAndroid Build Coastguard Worker assert(incremental_pooling_tile != 0); 97*4bdc9457SAndroid Build Coastguard Worker this->incremental_pooling_tile_ = incremental_pooling_tile; 98*4bdc9457SAndroid Build Coastguard Worker return *this; 99*4bdc9457SAndroid Build Coastguard Worker } 100*4bdc9457SAndroid Build Coastguard Worker incremental_pooling_tile()101*4bdc9457SAndroid Build Coastguard Worker inline size_t incremental_pooling_tile() const { 102*4bdc9457SAndroid Build Coastguard Worker return this->incremental_pooling_tile_; 103*4bdc9457SAndroid Build Coastguard Worker } 104*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)105*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& channels(size_t channels) { 106*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 107*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 108*4bdc9457SAndroid Build Coastguard Worker return *this; 109*4bdc9457SAndroid Build Coastguard Worker } 110*4bdc9457SAndroid Build Coastguard Worker channels()111*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 112*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 113*4bdc9457SAndroid Build Coastguard Worker } 114*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)115*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& output_stride(size_t output_stride) { 116*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 117*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 118*4bdc9457SAndroid Build Coastguard Worker return *this; 119*4bdc9457SAndroid Build Coastguard Worker } 120*4bdc9457SAndroid Build Coastguard Worker output_stride()121*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 122*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 123*4bdc9457SAndroid Build Coastguard Worker return channels(); 124*4bdc9457SAndroid Build Coastguard Worker } else { 125*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= channels()); 126*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 127*4bdc9457SAndroid Build Coastguard Worker } 128*4bdc9457SAndroid Build Coastguard Worker } 129*4bdc9457SAndroid Build Coastguard Worker qmin(int16_t qmin)130*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& qmin(int16_t qmin) { 131*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 132*4bdc9457SAndroid Build Coastguard Worker return *this; 133*4bdc9457SAndroid Build Coastguard Worker } 134*4bdc9457SAndroid Build Coastguard Worker qmin()135*4bdc9457SAndroid Build Coastguard Worker inline int16_t qmin() const { 136*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 137*4bdc9457SAndroid Build Coastguard Worker } 138*4bdc9457SAndroid Build Coastguard Worker qmax(int16_t qmax)139*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& qmax(int16_t qmax) { 140*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 141*4bdc9457SAndroid Build Coastguard Worker return *this; 142*4bdc9457SAndroid Build Coastguard Worker } 143*4bdc9457SAndroid Build Coastguard Worker qmax()144*4bdc9457SAndroid Build Coastguard Worker inline int16_t qmax() const { 145*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)148*4bdc9457SAndroid Build Coastguard Worker inline MaxPoolMicrokernelTester& iterations(size_t iterations) { 149*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 150*4bdc9457SAndroid Build Coastguard Worker return *this; 151*4bdc9457SAndroid Build Coastguard Worker } 152*4bdc9457SAndroid Build Coastguard Worker iterations()153*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 154*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 155*4bdc9457SAndroid Build Coastguard Worker } 156*4bdc9457SAndroid Build Coastguard Worker Test(xnn_s8_maxpool_ukernel_function maxpool,xnn_init_s8_minmax_params_fn init_params)157*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_s8_maxpool_ukernel_function maxpool, xnn_init_s8_minmax_params_fn init_params) const { 158*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(qmin(), std::numeric_limits<int8_t>::min()); 159*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(qmax(), std::numeric_limits<int8_t>::max()); 160*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 161*4bdc9457SAndroid Build Coastguard Worker 162*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 163*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 164*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 165*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 166*4bdc9457SAndroid Build Coastguard Worker 167*4bdc9457SAndroid Build Coastguard Worker std::vector<const int8_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements()); 168*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 169*4bdc9457SAndroid Build Coastguard Worker indirect_input.size() * channels()); 170*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(XNN_EXTRA_BYTES / sizeof(int8_t) + 171*4bdc9457SAndroid Build Coastguard Worker (output_pixels() - 1) * output_stride() + channels()); 172*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output_ref(output_pixels() * channels()); 173*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 174*4bdc9457SAndroid Build Coastguard Worker do { 175*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 176*4bdc9457SAndroid Build Coastguard Worker } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend())); 177*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 178*4bdc9457SAndroid Build Coastguard Worker 179*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) { 180*4bdc9457SAndroid Build Coastguard Worker indirect_input[i] = input.data() + i * channels() - input_offset(); 181*4bdc9457SAndroid Build Coastguard Worker } 182*4bdc9457SAndroid Build Coastguard Worker std::shuffle(indirect_input.begin(), 183*4bdc9457SAndroid Build Coastguard Worker indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng); 184*4bdc9457SAndroid Build Coastguard Worker 185*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 186*4bdc9457SAndroid Build Coastguard Worker xnn_s8_minmax_params params; 187*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, static_cast<int8_t>(qmin()), static_cast<int8_t>(qmax())); 188*4bdc9457SAndroid Build Coastguard Worker 189*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 190*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 191*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 192*4bdc9457SAndroid Build Coastguard Worker int8_t max_value = std::numeric_limits<int8_t>::min(); 193*4bdc9457SAndroid Build Coastguard Worker for (size_t p = 0; p < pooling_elements(); p++) { 194*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, indirect_input[x * step() + p][c + input_offset()]); 195*4bdc9457SAndroid Build Coastguard Worker } 196*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, static_cast<int8_t>(qmax())); 197*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, static_cast<int8_t>(qmin())); 198*4bdc9457SAndroid Build Coastguard Worker output_ref[x * channels() + c] = max_value; 199*4bdc9457SAndroid Build Coastguard Worker } 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker 202*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 203*4bdc9457SAndroid Build Coastguard Worker maxpool(output_pixels(), pooling_elements(), channels(), 204*4bdc9457SAndroid Build Coastguard Worker indirect_input.data(), input_offset() * sizeof(int8_t), output.data(), 205*4bdc9457SAndroid Build Coastguard Worker (step() - packed_pooling_elements()) * sizeof(void*), 206*4bdc9457SAndroid Build Coastguard Worker (output_stride() - channels()) * sizeof(int8_t), 207*4bdc9457SAndroid Build Coastguard Worker ¶ms); 208*4bdc9457SAndroid Build Coastguard Worker 209*4bdc9457SAndroid Build Coastguard Worker // Verify results. 210*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 211*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 212*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int16_t(output[x * output_stride() + c]), qmin()) 213*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 214*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 215*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 216*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int16_t(output[x * output_stride() + c]), qmax()) 217*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 218*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 219*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 220*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[x * channels() + c]), int32_t(output[x * output_stride() + c])) 221*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 222*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 223*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 224*4bdc9457SAndroid Build Coastguard Worker } 225*4bdc9457SAndroid Build Coastguard Worker } 226*4bdc9457SAndroid Build Coastguard Worker } 227*4bdc9457SAndroid Build Coastguard Worker } 228*4bdc9457SAndroid Build Coastguard Worker Test(xnn_u8_maxpool_ukernel_function maxpool,xnn_init_u8_minmax_params_fn init_params)229*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_u8_maxpool_ukernel_function maxpool, xnn_init_u8_minmax_params_fn init_params) const { 230*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(qmin(), std::numeric_limits<uint8_t>::min()); 231*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(qmax(), std::numeric_limits<uint8_t>::max()); 232*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 233*4bdc9457SAndroid Build Coastguard Worker 234*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 235*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 236*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 237*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 238*4bdc9457SAndroid Build Coastguard Worker 239*4bdc9457SAndroid Build Coastguard Worker std::vector<const uint8_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements()); 240*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 241*4bdc9457SAndroid Build Coastguard Worker indirect_input.size() * channels()); 242*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(XNN_EXTRA_BYTES / sizeof(uint8_t) + 243*4bdc9457SAndroid Build Coastguard Worker (output_pixels() - 1) * output_stride() + channels()); 244*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(output_pixels() * channels()); 245*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 246*4bdc9457SAndroid Build Coastguard Worker do { 247*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 248*4bdc9457SAndroid Build Coastguard Worker } while (input.size() > 1 && *std::max_element(input.cbegin(), input.cend()) == *std::min_element(input.cbegin(), input.cend())); 249*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 250*4bdc9457SAndroid Build Coastguard Worker 251*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) { 252*4bdc9457SAndroid Build Coastguard Worker indirect_input[i] = input.data() + i * channels() - input_offset(); 253*4bdc9457SAndroid Build Coastguard Worker } 254*4bdc9457SAndroid Build Coastguard Worker std::shuffle(indirect_input.begin(), 255*4bdc9457SAndroid Build Coastguard Worker indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng); 256*4bdc9457SAndroid Build Coastguard Worker 257*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 258*4bdc9457SAndroid Build Coastguard Worker xnn_u8_minmax_params params; 259*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, static_cast<uint8_t>(qmin()), static_cast<uint8_t>(qmax())); 260*4bdc9457SAndroid Build Coastguard Worker 261*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 262*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 263*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 264*4bdc9457SAndroid Build Coastguard Worker uint8_t max_value = 0; 265*4bdc9457SAndroid Build Coastguard Worker for (size_t p = 0; p < pooling_elements(); p++) { 266*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, indirect_input[x * step() + p][c + input_offset()]); 267*4bdc9457SAndroid Build Coastguard Worker } 268*4bdc9457SAndroid Build Coastguard Worker max_value = std::min(max_value, static_cast<uint8_t>(qmax())); 269*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, static_cast<uint8_t>(qmin())); 270*4bdc9457SAndroid Build Coastguard Worker output_ref[x * channels() + c] = max_value; 271*4bdc9457SAndroid Build Coastguard Worker } 272*4bdc9457SAndroid Build Coastguard Worker } 273*4bdc9457SAndroid Build Coastguard Worker 274*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 275*4bdc9457SAndroid Build Coastguard Worker maxpool(output_pixels(), pooling_elements(), channels(), 276*4bdc9457SAndroid Build Coastguard Worker indirect_input.data(), input_offset() * sizeof(uint8_t), output.data(), 277*4bdc9457SAndroid Build Coastguard Worker (step() - packed_pooling_elements()) * sizeof(void*), 278*4bdc9457SAndroid Build Coastguard Worker (output_stride() - channels()) * sizeof(uint8_t), 279*4bdc9457SAndroid Build Coastguard Worker ¶ms); 280*4bdc9457SAndroid Build Coastguard Worker 281*4bdc9457SAndroid Build Coastguard Worker // Verify results. 282*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 283*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 284*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int16_t(output[x * output_stride() + c]), qmin()) 285*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 286*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 287*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 288*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int16_t(output[x * output_stride() + c]), qmax()) 289*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 290*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 291*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 292*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[x * channels() + c]), int32_t(output[x * output_stride() + c])) 293*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 294*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 295*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 296*4bdc9457SAndroid Build Coastguard Worker } 297*4bdc9457SAndroid Build Coastguard Worker } 298*4bdc9457SAndroid Build Coastguard Worker } 299*4bdc9457SAndroid Build Coastguard Worker } 300*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_maxpool_ukernel_function maxpool,xnn_init_f16_minmax_params_fn init_params)301*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_maxpool_ukernel_function maxpool, xnn_init_f16_minmax_params_fn init_params) const { 302*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 303*4bdc9457SAndroid Build Coastguard Worker 304*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 305*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 306*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 307*4bdc9457SAndroid Build Coastguard Worker 308*4bdc9457SAndroid Build Coastguard Worker std::vector<const uint16_t*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements()); 309*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 310*4bdc9457SAndroid Build Coastguard Worker ((output_pixels() - 1) * step() + pooling_elements()) * channels()); 311*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(XNN_EXTRA_BYTES / sizeof(uint16_t) + 312*4bdc9457SAndroid Build Coastguard Worker (output_pixels() - 1) * output_stride() + channels()); 313*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(output_pixels() * channels()); 314*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 315*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 316*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 317*4bdc9457SAndroid Build Coastguard Worker 318*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) { 319*4bdc9457SAndroid Build Coastguard Worker indirect_input[i] = input.data() + i * channels() - input_offset(); 320*4bdc9457SAndroid Build Coastguard Worker } 321*4bdc9457SAndroid Build Coastguard Worker std::shuffle(indirect_input.begin(), 322*4bdc9457SAndroid Build Coastguard Worker indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng); 323*4bdc9457SAndroid Build Coastguard Worker 324*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 325*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 326*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 327*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 328*4bdc9457SAndroid Build Coastguard Worker for (size_t p = 0; p < pooling_elements(); p++) { 329*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, fp16_ieee_to_fp32_value(indirect_input[x * step() + p][c + input_offset()])); 330*4bdc9457SAndroid Build Coastguard Worker } 331*4bdc9457SAndroid Build Coastguard Worker output_ref[x * channels() + c] = max_value; 332*4bdc9457SAndroid Build Coastguard Worker } 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker 335*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 336*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 337*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 338*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 339*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range * 340*4bdc9457SAndroid Build Coastguard Worker (static_cast<float>(qmin() - std::numeric_limits<int16_t>::min()) / 341*4bdc9457SAndroid Build Coastguard Worker static_cast<float>(std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min())); 342*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<int16_t>::min()) { 343*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 344*4bdc9457SAndroid Build Coastguard Worker } 345*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range * 346*4bdc9457SAndroid Build Coastguard Worker (static_cast<float>(std::numeric_limits<int16_t>::max() - qmax()) / 347*4bdc9457SAndroid Build Coastguard Worker static_cast<float>(std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min())); 348*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<int16_t>::max()) { 349*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 350*4bdc9457SAndroid Build Coastguard Worker } 351*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min)); 352*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max)); 353*4bdc9457SAndroid Build Coastguard Worker 354*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 355*4bdc9457SAndroid Build Coastguard Worker xnn_f16_minmax_params params; 356*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, fp16_ieee_from_fp32_value(output_min), fp16_ieee_from_fp32_value(output_max)); 357*4bdc9457SAndroid Build Coastguard Worker 358*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 359*4bdc9457SAndroid Build Coastguard Worker for (float& output_value : output_ref) { 360*4bdc9457SAndroid Build Coastguard Worker output_value = std::max(std::min(output_value, output_max), output_min); 361*4bdc9457SAndroid Build Coastguard Worker } 362*4bdc9457SAndroid Build Coastguard Worker 363*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 364*4bdc9457SAndroid Build Coastguard Worker maxpool(output_pixels(), pooling_elements(), channels(), 365*4bdc9457SAndroid Build Coastguard Worker reinterpret_cast<const void**>(indirect_input.data()), input_offset() * sizeof(uint16_t), output.data(), 366*4bdc9457SAndroid Build Coastguard Worker (step() - packed_pooling_elements()) * sizeof(void*), 367*4bdc9457SAndroid Build Coastguard Worker (output_stride() - channels()) * sizeof(uint16_t), 368*4bdc9457SAndroid Build Coastguard Worker ¶ms); 369*4bdc9457SAndroid Build Coastguard Worker 370*4bdc9457SAndroid Build Coastguard Worker // Verify results. 371*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 372*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 373*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_min) 374*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 375*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 376*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 377*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_max) 378*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 379*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 380*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 381*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(fp16_ieee_to_fp32_value(output[x * output_stride() + c]), output_ref[x * channels() + c]) 382*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 383*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 384*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 385*4bdc9457SAndroid Build Coastguard Worker } 386*4bdc9457SAndroid Build Coastguard Worker } 387*4bdc9457SAndroid Build Coastguard Worker } 388*4bdc9457SAndroid Build Coastguard Worker } 389*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_maxpool_ukernel_function maxpool,xnn_init_f32_minmax_params_fn init_params)390*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_maxpool_ukernel_function maxpool, xnn_init_f32_minmax_params_fn init_params) const { 391*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(qmin(), qmax()); 392*4bdc9457SAndroid Build Coastguard Worker 393*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 394*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 395*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(-1.0f, 1.0f); 396*4bdc9457SAndroid Build Coastguard Worker 397*4bdc9457SAndroid Build Coastguard Worker std::vector<const float*> indirect_input((output_pixels() - 1) * step() + packed_pooling_elements()); 398*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 399*4bdc9457SAndroid Build Coastguard Worker ((output_pixels() - 1) * step() + pooling_elements()) * channels()); 400*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(XNN_EXTRA_BYTES / sizeof(float) + 401*4bdc9457SAndroid Build Coastguard Worker (output_pixels() - 1) * output_stride() + channels()); 402*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(output_pixels() * channels()); 403*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 404*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 405*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 406*4bdc9457SAndroid Build Coastguard Worker 407*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < (output_pixels() - 1) * step() + pooling_elements(); i++) { 408*4bdc9457SAndroid Build Coastguard Worker indirect_input[i] = input.data() + i * channels() - input_offset(); 409*4bdc9457SAndroid Build Coastguard Worker } 410*4bdc9457SAndroid Build Coastguard Worker std::shuffle(indirect_input.begin(), 411*4bdc9457SAndroid Build Coastguard Worker indirect_input.begin() + (output_pixels() - 1) * step() + pooling_elements(), rng); 412*4bdc9457SAndroid Build Coastguard Worker 413*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 414*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 415*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 416*4bdc9457SAndroid Build Coastguard Worker float max_value = -std::numeric_limits<float>::infinity(); 417*4bdc9457SAndroid Build Coastguard Worker for (size_t p = 0; p < pooling_elements(); p++) { 418*4bdc9457SAndroid Build Coastguard Worker max_value = std::max(max_value, indirect_input[x * step() + p][c + input_offset()]); 419*4bdc9457SAndroid Build Coastguard Worker } 420*4bdc9457SAndroid Build Coastguard Worker output_ref[x * channels() + c] = max_value; 421*4bdc9457SAndroid Build Coastguard Worker } 422*4bdc9457SAndroid Build Coastguard Worker } 423*4bdc9457SAndroid Build Coastguard Worker 424*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 425*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 426*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 427*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 428*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range * 429*4bdc9457SAndroid Build Coastguard Worker (static_cast<float>(qmin() - std::numeric_limits<int16_t>::min()) / 430*4bdc9457SAndroid Build Coastguard Worker static_cast<float>(std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min())); 431*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<int16_t>::min()) { 432*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity(); 433*4bdc9457SAndroid Build Coastguard Worker } 434*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range * 435*4bdc9457SAndroid Build Coastguard Worker (static_cast<float>(std::numeric_limits<int16_t>::max() - qmax()) / 436*4bdc9457SAndroid Build Coastguard Worker static_cast<float>(std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min())); 437*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<int16_t>::max()) { 438*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity(); 439*4bdc9457SAndroid Build Coastguard Worker } 440*4bdc9457SAndroid Build Coastguard Worker 441*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 442*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params; 443*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, output_min, output_max); 444*4bdc9457SAndroid Build Coastguard Worker 445*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 446*4bdc9457SAndroid Build Coastguard Worker for (float& output_value : output_ref) { 447*4bdc9457SAndroid Build Coastguard Worker output_value = std::max(std::min(output_value, output_max), output_min); 448*4bdc9457SAndroid Build Coastguard Worker } 449*4bdc9457SAndroid Build Coastguard Worker 450*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 451*4bdc9457SAndroid Build Coastguard Worker maxpool(output_pixels(), pooling_elements(), channels(), 452*4bdc9457SAndroid Build Coastguard Worker indirect_input.data(), input_offset() * sizeof(float), output.data(), 453*4bdc9457SAndroid Build Coastguard Worker (step() - packed_pooling_elements()) * sizeof(void*), 454*4bdc9457SAndroid Build Coastguard Worker (output_stride() - channels()) * sizeof(float), 455*4bdc9457SAndroid Build Coastguard Worker ¶ms); 456*4bdc9457SAndroid Build Coastguard Worker 457*4bdc9457SAndroid Build Coastguard Worker // Verify results. 458*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_pixels(); x++) { 459*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 460*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[x * output_stride() + c], output_min) 461*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 462*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 463*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 464*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[x * output_stride() + c], output_max) 465*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 466*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 467*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 468*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(output_ref[x * channels() + c], output[x * output_stride() + c]) 469*4bdc9457SAndroid Build Coastguard Worker << "at pixel " << x << " / " << output_pixels() << ", channel " << c << " / " << channels() 470*4bdc9457SAndroid Build Coastguard Worker << ", pooling elements = " << pooling_elements() << ", step = " << step() 471*4bdc9457SAndroid Build Coastguard Worker << ", input offset = " << input_offset(); 472*4bdc9457SAndroid Build Coastguard Worker } 473*4bdc9457SAndroid Build Coastguard Worker } 474*4bdc9457SAndroid Build Coastguard Worker } 475*4bdc9457SAndroid Build Coastguard Worker } 476*4bdc9457SAndroid Build Coastguard Worker 477*4bdc9457SAndroid Build Coastguard Worker private: 478*4bdc9457SAndroid Build Coastguard Worker size_t output_pixels_{1}; 479*4bdc9457SAndroid Build Coastguard Worker size_t pooling_elements_{1}; 480*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 481*4bdc9457SAndroid Build Coastguard Worker size_t input_offset_{0}; 482*4bdc9457SAndroid Build Coastguard Worker size_t step_{1}; 483*4bdc9457SAndroid Build Coastguard Worker size_t primary_pooling_tile_{1}; 484*4bdc9457SAndroid Build Coastguard Worker size_t incremental_pooling_tile_{1}; 485*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 486*4bdc9457SAndroid Build Coastguard Worker int16_t qmin_{std::numeric_limits<int16_t>::min()}; 487*4bdc9457SAndroid Build Coastguard Worker int16_t qmax_{std::numeric_limits<int16_t>::max()}; 488*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{3}; 489*4bdc9457SAndroid Build Coastguard Worker }; 490