1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 14*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 15*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 18*4bdc9457SAndroid Build Coastguard Worker #include <limits> 19*4bdc9457SAndroid Build Coastguard Worker #include <random> 20*4bdc9457SAndroid Build Coastguard Worker #include <vector> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 23*4bdc9457SAndroid Build Coastguard Worker 24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h> 26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h> 27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h> 28*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h> 29*4bdc9457SAndroid Build Coastguard Worker 30*4bdc9457SAndroid Build Coastguard Worker 31*4bdc9457SAndroid Build Coastguard Worker class 
GAvgPoolMicrokernelTester { 32*4bdc9457SAndroid Build Coastguard Worker public: rows(size_t rows)33*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& rows(size_t rows) { 34*4bdc9457SAndroid Build Coastguard Worker assert(rows != 0); 35*4bdc9457SAndroid Build Coastguard Worker this->rows_ = rows; 36*4bdc9457SAndroid Build Coastguard Worker return *this; 37*4bdc9457SAndroid Build Coastguard Worker } 38*4bdc9457SAndroid Build Coastguard Worker rows()39*4bdc9457SAndroid Build Coastguard Worker inline size_t rows() const { 40*4bdc9457SAndroid Build Coastguard Worker return this->rows_; 41*4bdc9457SAndroid Build Coastguard Worker } 42*4bdc9457SAndroid Build Coastguard Worker channels(size_t channels)43*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& channels(size_t channels) { 44*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 45*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 46*4bdc9457SAndroid Build Coastguard Worker return *this; 47*4bdc9457SAndroid Build Coastguard Worker } 48*4bdc9457SAndroid Build Coastguard Worker channels()49*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 50*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 51*4bdc9457SAndroid Build Coastguard Worker } 52*4bdc9457SAndroid Build Coastguard Worker channel_tile(size_t channel_tile)53*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& channel_tile(size_t channel_tile) { 54*4bdc9457SAndroid Build Coastguard Worker assert(channel_tile != 0); 55*4bdc9457SAndroid Build Coastguard Worker this->channel_tile_ = channel_tile; 56*4bdc9457SAndroid Build Coastguard Worker return *this; 57*4bdc9457SAndroid Build Coastguard Worker } 58*4bdc9457SAndroid Build Coastguard Worker channel_tile()59*4bdc9457SAndroid Build Coastguard Worker inline size_t channel_tile() const { 60*4bdc9457SAndroid Build Coastguard Worker return this->channel_tile_; 61*4bdc9457SAndroid 
Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)63*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& input_stride(size_t input_stride) { 64*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 65*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 66*4bdc9457SAndroid Build Coastguard Worker return *this; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker input_stride()69*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 70*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 71*4bdc9457SAndroid Build Coastguard Worker return channels(); 72*4bdc9457SAndroid Build Coastguard Worker } else { 73*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= channels()); 74*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 75*4bdc9457SAndroid Build Coastguard Worker } 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker input_scale(float input_scale)78*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& input_scale(float input_scale) { 79*4bdc9457SAndroid Build Coastguard Worker assert(input_scale > 0.0f); 80*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(input_scale)); 81*4bdc9457SAndroid Build Coastguard Worker this->input_scale_ = input_scale; 82*4bdc9457SAndroid Build Coastguard Worker return *this; 83*4bdc9457SAndroid Build Coastguard Worker } 84*4bdc9457SAndroid Build Coastguard Worker input_scale()85*4bdc9457SAndroid Build Coastguard Worker inline float input_scale() const { 86*4bdc9457SAndroid Build Coastguard Worker return this->input_scale_; 87*4bdc9457SAndroid Build Coastguard Worker } 88*4bdc9457SAndroid Build Coastguard Worker input_zero_point(uint8_t input_zero_point)89*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& 
input_zero_point(uint8_t input_zero_point) { 90*4bdc9457SAndroid Build Coastguard Worker this->input_zero_point_ = input_zero_point; 91*4bdc9457SAndroid Build Coastguard Worker return *this; 92*4bdc9457SAndroid Build Coastguard Worker } 93*4bdc9457SAndroid Build Coastguard Worker input_zero_point()94*4bdc9457SAndroid Build Coastguard Worker inline uint8_t input_zero_point() const { 95*4bdc9457SAndroid Build Coastguard Worker return this->input_zero_point_; 96*4bdc9457SAndroid Build Coastguard Worker } 97*4bdc9457SAndroid Build Coastguard Worker output_scale(float output_scale)98*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& output_scale(float output_scale) { 99*4bdc9457SAndroid Build Coastguard Worker assert(output_scale > 0.0f); 100*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(output_scale)); 101*4bdc9457SAndroid Build Coastguard Worker this->output_scale_ = output_scale; 102*4bdc9457SAndroid Build Coastguard Worker return *this; 103*4bdc9457SAndroid Build Coastguard Worker } 104*4bdc9457SAndroid Build Coastguard Worker output_scale()105*4bdc9457SAndroid Build Coastguard Worker inline float output_scale() const { 106*4bdc9457SAndroid Build Coastguard Worker return this->output_scale_; 107*4bdc9457SAndroid Build Coastguard Worker } 108*4bdc9457SAndroid Build Coastguard Worker output_zero_point(uint8_t output_zero_point)109*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& output_zero_point(uint8_t output_zero_point) { 110*4bdc9457SAndroid Build Coastguard Worker this->output_zero_point_ = output_zero_point; 111*4bdc9457SAndroid Build Coastguard Worker return *this; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker output_zero_point()114*4bdc9457SAndroid Build Coastguard Worker inline uint8_t output_zero_point() const { 115*4bdc9457SAndroid Build Coastguard Worker return this->output_zero_point_; 116*4bdc9457SAndroid Build Coastguard Worker } 
117*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)118*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& qmin(uint8_t qmin) { 119*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 120*4bdc9457SAndroid Build Coastguard Worker return *this; 121*4bdc9457SAndroid Build Coastguard Worker } 122*4bdc9457SAndroid Build Coastguard Worker qmin()123*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 124*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 125*4bdc9457SAndroid Build Coastguard Worker } 126*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)127*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& qmax(uint8_t qmax) { 128*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 129*4bdc9457SAndroid Build Coastguard Worker return *this; 130*4bdc9457SAndroid Build Coastguard Worker } 131*4bdc9457SAndroid Build Coastguard Worker qmax()132*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 133*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 134*4bdc9457SAndroid Build Coastguard Worker } 135*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)136*4bdc9457SAndroid Build Coastguard Worker inline GAvgPoolMicrokernelTester& iterations(size_t iterations) { 137*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 138*4bdc9457SAndroid Build Coastguard Worker return *this; 139*4bdc9457SAndroid Build Coastguard Worker } 140*4bdc9457SAndroid Build Coastguard Worker iterations()141*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 142*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 143*4bdc9457SAndroid Build Coastguard Worker } 144*4bdc9457SAndroid Build Coastguard Worker Test(xnn_qu8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_qu8_avgpool_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize)145*4bdc9457SAndroid Build 
Coastguard Worker void Test( 146*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax, 147*4bdc9457SAndroid Build Coastguard Worker xnn_init_qu8_avgpool_minmax_params_fn init_params, 148*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_requantize_fn requantize) const 149*4bdc9457SAndroid Build Coastguard Worker { 150*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 151*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 152*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 153*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 154*4bdc9457SAndroid Build Coastguard Worker 155*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 156*4bdc9457SAndroid Build Coastguard Worker (rows() - 1) * input_stride() + channels()); 157*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 158*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(channels()); 159*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(channels()); 160*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_fp(channels()); 161*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(channels()); 162*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 163*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 164*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 165*4bdc9457SAndroid Build Coastguard Worker 166*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 
167*4bdc9457SAndroid Build Coastguard Worker union xnn_qu8_avgpool_minmax_params params; 168*4bdc9457SAndroid Build Coastguard Worker init_params( 169*4bdc9457SAndroid Build Coastguard Worker ¶ms, 170*4bdc9457SAndroid Build Coastguard Worker -int32_t(input_zero_point()) * int32_t(rows()), 171*4bdc9457SAndroid Build Coastguard Worker input_scale() / (output_scale() * float(rows())), 172*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), qmin(), qmax()); 173*4bdc9457SAndroid Build Coastguard Worker 174*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 175*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 176*4bdc9457SAndroid Build Coastguard Worker int32_t acc = 0; 177*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 178*4bdc9457SAndroid Build Coastguard Worker acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point()); 179*4bdc9457SAndroid Build Coastguard Worker } 180*4bdc9457SAndroid Build Coastguard Worker accumulators[c] = acc; 181*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = requantize( 182*4bdc9457SAndroid Build Coastguard Worker acc, input_scale() / (output_scale() * float(rows())), output_zero_point(), qmin(), qmax()); 183*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point()); 184*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::min<float>(output_fp[c], float(qmax())); 185*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::max<float>(output_fp[c], float(qmin())); 186*4bdc9457SAndroid Build Coastguard Worker } 187*4bdc9457SAndroid Build Coastguard Worker 188*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 
189*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 190*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(uint8_t), 191*4bdc9457SAndroid Build Coastguard Worker zero.data(), 192*4bdc9457SAndroid Build Coastguard Worker output.data(), 193*4bdc9457SAndroid Build Coastguard Worker ¶ms); 194*4bdc9457SAndroid Build Coastguard Worker 195*4bdc9457SAndroid Build Coastguard Worker // Verify results. 196*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 197*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[c]), uint32_t(qmax())) 198*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 199*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[c]), uint32_t(qmin())) 200*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 201*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f) 202*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels() 203*4bdc9457SAndroid Build Coastguard Worker << ", acc = " << accumulators[c]; 204*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(output_ref[c]), uint32_t(output[c])) 205*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels() 206*4bdc9457SAndroid Build Coastguard Worker << ", acc = " << accumulators[c]; 207*4bdc9457SAndroid Build Coastguard Worker } 208*4bdc9457SAndroid Build Coastguard Worker } 209*4bdc9457SAndroid Build Coastguard Worker } 210*4bdc9457SAndroid Build Coastguard Worker Test(xnn_qu8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_qu8_avgpool_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize)211*4bdc9457SAndroid Build Coastguard Worker void Test( 
212*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax, 213*4bdc9457SAndroid Build Coastguard Worker xnn_init_qu8_avgpool_minmax_params_fn init_params, 214*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_requantize_fn requantize) const 215*4bdc9457SAndroid Build Coastguard Worker { 216*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 217*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 218*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 219*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 220*4bdc9457SAndroid Build Coastguard Worker 221*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 222*4bdc9457SAndroid Build Coastguard Worker (rows() - 1) * input_stride() + channels()); 223*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t, AlignedAllocator<int32_t, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 224*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 225*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(channels()); 226*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output_ref(channels()); 227*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_fp(channels()); 228*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(channels()); 229*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 230*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 231*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 232*4bdc9457SAndroid Build Coastguard Worker 233*4bdc9457SAndroid Build 
Coastguard Worker // Prepare parameters. 234*4bdc9457SAndroid Build Coastguard Worker union xnn_qu8_avgpool_minmax_params params; 235*4bdc9457SAndroid Build Coastguard Worker init_params( 236*4bdc9457SAndroid Build Coastguard Worker ¶ms, 237*4bdc9457SAndroid Build Coastguard Worker -int32_t(input_zero_point()) * int32_t(rows()), 238*4bdc9457SAndroid Build Coastguard Worker input_scale() / (output_scale() * float(rows())), 239*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), qmin(), qmax()); 240*4bdc9457SAndroid Build Coastguard Worker 241*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 242*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 243*4bdc9457SAndroid Build Coastguard Worker int32_t acc = 0; 244*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 245*4bdc9457SAndroid Build Coastguard Worker acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point()); 246*4bdc9457SAndroid Build Coastguard Worker } 247*4bdc9457SAndroid Build Coastguard Worker 248*4bdc9457SAndroid Build Coastguard Worker accumulators[c] = acc; 249*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = requantize( 250*4bdc9457SAndroid Build Coastguard Worker acc, input_scale() / (output_scale() * float(rows())), output_zero_point(), qmin(), qmax()); 251*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point()); 252*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::min<float>(output_fp[c], float(qmax())); 253*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::max<float>(output_fp[c], float(qmin())); 254*4bdc9457SAndroid Build Coastguard Worker } 255*4bdc9457SAndroid Build Coastguard Worker 256*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 
257*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 258*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(uint8_t), 259*4bdc9457SAndroid Build Coastguard Worker zero.data(), 260*4bdc9457SAndroid Build Coastguard Worker buffer.data(), 261*4bdc9457SAndroid Build Coastguard Worker output.data(), 262*4bdc9457SAndroid Build Coastguard Worker ¶ms); 263*4bdc9457SAndroid Build Coastguard Worker 264*4bdc9457SAndroid Build Coastguard Worker // Verify results. 265*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 266*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[c]), uint32_t(qmax())) 267*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 268*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[c]), uint32_t(qmin())) 269*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 270*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f) 271*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels() 272*4bdc9457SAndroid Build Coastguard Worker << ", acc = " << accumulators[c]; 273*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(output_ref[c]), uint32_t(output[c])) 274*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels() 275*4bdc9457SAndroid Build Coastguard Worker << ", acc = " << accumulators[c]; 276*4bdc9457SAndroid Build Coastguard Worker } 277*4bdc9457SAndroid Build Coastguard Worker } 278*4bdc9457SAndroid Build Coastguard Worker } 279*4bdc9457SAndroid Build Coastguard Worker Test(xnn_qs8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_qs8_avgpool_minmax_params_fn init_params,xnn_qs8_requantize_fn 
requantize)280*4bdc9457SAndroid Build Coastguard Worker void Test( 281*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax, 282*4bdc9457SAndroid Build Coastguard Worker xnn_init_qs8_avgpool_minmax_params_fn init_params, 283*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const 284*4bdc9457SAndroid Build Coastguard Worker { 285*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 286*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 287*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 288*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 289*4bdc9457SAndroid Build Coastguard Worker 290*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 291*4bdc9457SAndroid Build Coastguard Worker (rows() - 1) * input_stride() + channels()); 292*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 293*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(channels()); 294*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output_ref(channels()); 295*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_fp(channels()); 296*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(channels()); 297*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 298*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 299*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 300*4bdc9457SAndroid Build Coastguard Worker 301*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 
302*4bdc9457SAndroid Build Coastguard Worker union xnn_qs8_avgpool_minmax_params params; 303*4bdc9457SAndroid Build Coastguard Worker init_params( 304*4bdc9457SAndroid Build Coastguard Worker ¶ms, 305*4bdc9457SAndroid Build Coastguard Worker -int32_t(input_zero_point() - 0x80) * int32_t(rows()), 306*4bdc9457SAndroid Build Coastguard Worker input_scale() / (output_scale() * float(rows())), 307*4bdc9457SAndroid Build Coastguard Worker int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80)); 308*4bdc9457SAndroid Build Coastguard Worker 309*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 310*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 311*4bdc9457SAndroid Build Coastguard Worker int32_t acc = 0; 312*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 313*4bdc9457SAndroid Build Coastguard Worker acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point() - 0x80); 314*4bdc9457SAndroid Build Coastguard Worker } 315*4bdc9457SAndroid Build Coastguard Worker accumulators[c] = acc; 316*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = requantize( 317*4bdc9457SAndroid Build Coastguard Worker acc, input_scale() / (output_scale() * float(rows())), int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80)); 318*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point() - 0x80); 319*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::min<float>(output_fp[c], float(qmax() - 0x80)); 320*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::max<float>(output_fp[c], float(qmin() - 0x80)); 321*4bdc9457SAndroid Build Coastguard Worker } 322*4bdc9457SAndroid Build Coastguard Worker 323*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 
324*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 325*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(int8_t), 326*4bdc9457SAndroid Build Coastguard Worker zero.data(), 327*4bdc9457SAndroid Build Coastguard Worker output.data(), 328*4bdc9457SAndroid Build Coastguard Worker ¶ms); 329*4bdc9457SAndroid Build Coastguard Worker 330*4bdc9457SAndroid Build Coastguard Worker // Verify results. 331*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 332*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[c]), int32_t(qmax() - 0x80)) 333*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows(); 334*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[c]), int32_t(qmin() - 0x80)) 335*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows(); 336*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f) 337*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows() 338*4bdc9457SAndroid Build Coastguard Worker << ", accumulator = " << accumulators[c]; 339*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[c]), int32_t(output[c])) 340*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows() 341*4bdc9457SAndroid Build Coastguard Worker << ", accumulator = " << accumulators[c]; 342*4bdc9457SAndroid Build Coastguard Worker } 343*4bdc9457SAndroid Build Coastguard Worker } 344*4bdc9457SAndroid Build Coastguard Worker } 345*4bdc9457SAndroid Build Coastguard Worker Test(xnn_qs8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_qs8_avgpool_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize)346*4bdc9457SAndroid Build Coastguard Worker void Test( 347*4bdc9457SAndroid Build 
Coastguard Worker xnn_qs8_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax, 348*4bdc9457SAndroid Build Coastguard Worker xnn_init_qs8_avgpool_minmax_params_fn init_params, 349*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const 350*4bdc9457SAndroid Build Coastguard Worker { 351*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 352*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 353*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 354*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 355*4bdc9457SAndroid Build Coastguard Worker 356*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 357*4bdc9457SAndroid Build Coastguard Worker (rows() - 1) * input_stride() + channels()); 358*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t, AlignedAllocator<int32_t, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 359*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 360*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(channels()); 361*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output_ref(channels()); 362*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_fp(channels()); 363*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(channels()); 364*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 365*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 366*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 367*4bdc9457SAndroid Build Coastguard Worker 368*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 
369*4bdc9457SAndroid Build Coastguard Worker union xnn_qs8_avgpool_minmax_params params; 370*4bdc9457SAndroid Build Coastguard Worker init_params( 371*4bdc9457SAndroid Build Coastguard Worker ¶ms, 372*4bdc9457SAndroid Build Coastguard Worker -int32_t(input_zero_point() - 0x80) * int32_t(rows()), 373*4bdc9457SAndroid Build Coastguard Worker input_scale() / (output_scale() * float(rows())), 374*4bdc9457SAndroid Build Coastguard Worker int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80)); 375*4bdc9457SAndroid Build Coastguard Worker 376*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 377*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 378*4bdc9457SAndroid Build Coastguard Worker int32_t acc = 0; 379*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 380*4bdc9457SAndroid Build Coastguard Worker acc += int32_t(input[n * input_stride() + c]) - int32_t(input_zero_point() - 0x80); 381*4bdc9457SAndroid Build Coastguard Worker } 382*4bdc9457SAndroid Build Coastguard Worker accumulators[c] = acc; 383*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = requantize( 384*4bdc9457SAndroid Build Coastguard Worker acc, input_scale() / (output_scale() * float(rows())), int8_t(output_zero_point() - 0x80), int8_t(qmin() - 0x80), int8_t(qmax() - 0x80)); 385*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = float(acc) * (input_scale() / (output_scale() * float(rows()))) + float(output_zero_point() - 0x80); 386*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::min<float>(output_fp[c], float(qmax() - 0x80)); 387*4bdc9457SAndroid Build Coastguard Worker output_fp[c] = std::max<float>(output_fp[c], float(qmin() - 0x80)); 388*4bdc9457SAndroid Build Coastguard Worker } 389*4bdc9457SAndroid Build Coastguard Worker 390*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 
391*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 392*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(int8_t), 393*4bdc9457SAndroid Build Coastguard Worker zero.data(), 394*4bdc9457SAndroid Build Coastguard Worker buffer.data(), 395*4bdc9457SAndroid Build Coastguard Worker output.data(), 396*4bdc9457SAndroid Build Coastguard Worker ¶ms); 397*4bdc9457SAndroid Build Coastguard Worker 398*4bdc9457SAndroid Build Coastguard Worker // Verify results. 399*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 400*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[c]), int32_t(qmax() - 0x80)) 401*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows(); 402*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[c]), int32_t(qmin() - 0x80)) 403*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows(); 404*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[c])), output_fp[c], 0.55f) 405*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows() 406*4bdc9457SAndroid Build Coastguard Worker << ", accumulator = " << accumulators[c]; 407*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(output_ref[c]), int32_t(output[c])) 408*4bdc9457SAndroid Build Coastguard Worker << "at channel " << c << " / " << channels() << ", rows = " << rows() 409*4bdc9457SAndroid Build Coastguard Worker << ", accumulator = " << accumulators[c]; 410*4bdc9457SAndroid Build Coastguard Worker } 411*4bdc9457SAndroid Build Coastguard Worker } 412*4bdc9457SAndroid Build Coastguard Worker } 413*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f16_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_f16_scaleminmax_params_fn init_params)414*4bdc9457SAndroid Build Coastguard Worker void 
Test(xnn_f16_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const {
    // Exercises a single-pass f16 global-average-pooling micro-kernel.
    //
    // Each iteration does the following:
    //  * fills rows() rows of channels() half-precision values (row pitch is
    //    input_stride() elements) from U[0, 1);
    //  * computes every channel's mean in float as the reference;
    //  * derives [output_min, output_max] by trimming qmin()/255 of the range of
    //    those means from the bottom and (255 - qmax())/255 from the top, then
    //    rounds both bounds through fp16;
    //  * runs the kernel with scale = fp16(1 / rows()) and those bounds.
    // Every output must lie inside the bounds and be within
    // max(1e-4, 1% of |reference|) of the clamped reference.
    std::random_device seed_source;
    auto prng = std::mt19937(seed_source());
    std::uniform_real_distribution<float> unit_dist;

    const size_t padding = XNN_EXTRA_BYTES / sizeof(uint16_t);
    std::vector<uint16_t> input((rows() - 1) * input_stride() + channels() + padding);
    // All-zero bit patterns, i.e. +0.0 in IEEE half precision.
    std::vector<uint16_t> zero(channels() + padding, UINT16_C(0));
    std::vector<uint16_t> output(channels());
    std::vector<float> output_ref(channels());

    const float row_count = float(rows());
    for (size_t pass = 0; pass < iterations(); pass++) {
      std::generate(input.begin(), input.end(),
        [&]() { return fp16_ieee_from_fp32_value(unit_dist(prng)); });
      // Poison the destination so any lane the kernel skips is detected.
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Unclamped reference: per-channel mean, accumulated in float.
      for (size_t ch = 0; ch < channels(); ch++) {
        const uint16_t* column = input.data() + ch;
        float sum = 0.0f;
        for (size_t r = 0; r < rows(); r++) {
          sum += fp16_ieee_to_fp32_value(column[r * input_stride()]);
        }
        output_ref[ch] = sum / row_count;
      }

      // Clamping window, snapped to values representable in half precision.
      const auto extremes = std::minmax_element(output_ref.cbegin(), output_ref.cend());
      const float lowest = *extremes.first;
      const float highest = *extremes.second;
      const float spread = highest - lowest;
      const float output_min = fp16_ieee_to_fp32_value(
        fp16_ieee_from_fp32_value(lowest + float(qmin()) / 255.0f * spread));
      const float output_max = fp16_ieee_to_fp32_value(
        fp16_ieee_from_fp32_value(highest - float(255 - qmax()) / 255.0f * spread));

      // Apply the same window to the reference.
      for (size_t ch = 0; ch < output_ref.size(); ch++) {
        output_ref[ch] = std::max(std::min(output_ref[ch], output_max), output_min);
      }

      xnn_f16_scaleminmax_params params;
      init_params(&params,
        fp16_ieee_from_fp32_value(1.0f / row_count),
        fp16_ieee_from_fp32_value(output_min),
        fp16_ieee_from_fp32_value(output_max));

      gavgpool_minmax(
        rows(), channels(),
        input.data(), input_stride() * sizeof(uint16_t),
        zero.data(),
        output.data(),
        &params);

      // Compare against the reference (assertion expressions kept verbatim so
      // gtest failure text is unchanged).
      for (size_t c = 0; c < channels(); c++) {
        ASSERT_LE(fp16_ieee_to_fp32_value(output[c]), output_max)
          << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
        ASSERT_GE(fp16_ieee_to_fp32_value(output[c]), output_min)
          << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
        ASSERT_NEAR(fp16_ieee_to_fp32_value(output[c]), output_ref[c], std::max(1.0e-4f, std::abs(output_ref[c]) * 1.0e-2f))
          << "at position " << c << ", rows = " << rows() << ", channels = " << channels();
      }
    }
  }
Worker Test(xnn_f16_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_f16_scaleminmax_params_fn init_params)476*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f16_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax, xnn_init_f16_scaleminmax_params_fn init_params) const { 477*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 478*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 479*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 480*4bdc9457SAndroid Build Coastguard Worker 481*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 482*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 483*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> zero(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 484*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(channels()); 485*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(channels()); 486*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 487*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 488*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 489*4bdc9457SAndroid Build Coastguard Worker 490*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 
491*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 492*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 493*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 494*4bdc9457SAndroid Build Coastguard Worker acc += fp16_ieee_to_fp32_value(input[n * input_stride() + c]); 495*4bdc9457SAndroid Build Coastguard Worker } 496*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = acc / float(rows()); 497*4bdc9457SAndroid Build Coastguard Worker } 498*4bdc9457SAndroid Build Coastguard Worker 499*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 500*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 501*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 502*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 503*4bdc9457SAndroid Build Coastguard Worker const float output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + float(qmin()) / 255.0f * accumulated_range)); 504*4bdc9457SAndroid Build Coastguard Worker const float output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range)); 505*4bdc9457SAndroid Build Coastguard Worker 506*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 
507*4bdc9457SAndroid Build Coastguard Worker xnn_f16_scaleminmax_params params; 508*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, 509*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(1.0f / float(rows())), 510*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(output_min), 511*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(output_max)); 512*4bdc9457SAndroid Build Coastguard Worker 513*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 514*4bdc9457SAndroid Build Coastguard Worker for (float& output_values : output_ref) { 515*4bdc9457SAndroid Build Coastguard Worker output_values = std::max(std::min(output_values, output_max), output_min); 516*4bdc9457SAndroid Build Coastguard Worker } 517*4bdc9457SAndroid Build Coastguard Worker 518*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 519*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 520*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(uint16_t), 521*4bdc9457SAndroid Build Coastguard Worker zero.data(), 522*4bdc9457SAndroid Build Coastguard Worker buffer.data(), 523*4bdc9457SAndroid Build Coastguard Worker output.data(), 524*4bdc9457SAndroid Build Coastguard Worker ¶ms); 525*4bdc9457SAndroid Build Coastguard Worker 526*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
527*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 528*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[c]), output_max) 529*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 530*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[c]), output_min) 531*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 532*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(output[c]), output_ref[c], std::abs(output_ref[c]) * 1.0e-0f) 533*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 534*4bdc9457SAndroid Build Coastguard Worker } 535*4bdc9457SAndroid Build Coastguard Worker } 536*4bdc9457SAndroid Build Coastguard Worker } 537*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax,xnn_init_f32_scaleminmax_params_fn init_params)538*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_gavgpool_minmax_unipass_ukernel_function gavgpool_minmax, xnn_init_f32_scaleminmax_params_fn init_params) const { 539*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 540*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 541*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 542*4bdc9457SAndroid Build Coastguard Worker 543*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 544*4bdc9457SAndroid Build Coastguard Worker std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float)); 545*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(channels()); 546*4bdc9457SAndroid Build Coastguard Worker 
std::vector<float> output_ref(channels()); 547*4bdc9457SAndroid Build Coastguard Worker 548*4bdc9457SAndroid Build Coastguard Worker std::fill(zero.begin(), zero.end(), 0.0f); 549*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 550*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 551*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 552*4bdc9457SAndroid Build Coastguard Worker 553*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 554*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 555*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 556*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 557*4bdc9457SAndroid Build Coastguard Worker acc += input[n * input_stride() + c]; 558*4bdc9457SAndroid Build Coastguard Worker } 559*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = acc / float(rows()); 560*4bdc9457SAndroid Build Coastguard Worker } 561*4bdc9457SAndroid Build Coastguard Worker 562*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 
563*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 564*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 565*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 566*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range; 567*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range; 568*4bdc9457SAndroid Build Coastguard Worker 569*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 570*4bdc9457SAndroid Build Coastguard Worker for (float& output_values : output_ref) { 571*4bdc9457SAndroid Build Coastguard Worker output_values = std::max(std::min(output_values, output_max), output_min); 572*4bdc9457SAndroid Build Coastguard Worker } 573*4bdc9457SAndroid Build Coastguard Worker 574*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 575*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_scaleminmax_params params; 576*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, 1.0f / float(rows()), output_min, output_max); 577*4bdc9457SAndroid Build Coastguard Worker 578*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 579*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 580*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(float), 581*4bdc9457SAndroid Build Coastguard Worker zero.data(), 582*4bdc9457SAndroid Build Coastguard Worker output.data(), 583*4bdc9457SAndroid Build Coastguard Worker ¶ms); 584*4bdc9457SAndroid Build Coastguard Worker 585*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
586*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 587*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[c], output_max) 588*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 589*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[c], output_min) 590*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 591*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[c], output_ref[c], std::abs(output_ref[c]) * 1.0e-6f) 592*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 593*4bdc9457SAndroid Build Coastguard Worker } 594*4bdc9457SAndroid Build Coastguard Worker } 595*4bdc9457SAndroid Build Coastguard Worker } 596*4bdc9457SAndroid Build Coastguard Worker Test(xnn_f32_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax,xnn_init_f32_scaleminmax_params_fn init_params)597*4bdc9457SAndroid Build Coastguard Worker void Test(xnn_f32_gavgpool_minmax_multipass_ukernel_function gavgpool_minmax, xnn_init_f32_scaleminmax_params_fn init_params) const { 598*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 599*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 600*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 601*4bdc9457SAndroid Build Coastguard Worker 602*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((rows() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 603*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> buffer(channels() + XNN_EXTRA_BYTES / sizeof(float)); 604*4bdc9457SAndroid Build Coastguard Worker std::vector<float> zero(channels() + XNN_EXTRA_BYTES / sizeof(float)); 605*4bdc9457SAndroid Build Coastguard Worker 
std::vector<float> output(channels()); 606*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(channels()); 607*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 608*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 609*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 610*4bdc9457SAndroid Build Coastguard Worker 611*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 612*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 613*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 614*4bdc9457SAndroid Build Coastguard Worker for (size_t n = 0; n < rows(); n++) { 615*4bdc9457SAndroid Build Coastguard Worker acc += input[n * input_stride() + c]; 616*4bdc9457SAndroid Build Coastguard Worker } 617*4bdc9457SAndroid Build Coastguard Worker output_ref[c] = acc / float(rows()); 618*4bdc9457SAndroid Build Coastguard Worker } 619*4bdc9457SAndroid Build Coastguard Worker 620*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 621*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 622*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 623*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 624*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + float(qmin()) / 255.0f * accumulated_range; 625*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - float(255 - qmax()) / 255.0f * accumulated_range; 626*4bdc9457SAndroid Build Coastguard Worker 627*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters. 
628*4bdc9457SAndroid Build Coastguard Worker union xnn_f32_scaleminmax_params params; 629*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, 1.0f / float(rows()), output_min, output_max); 630*4bdc9457SAndroid Build Coastguard Worker 631*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 632*4bdc9457SAndroid Build Coastguard Worker for (float& output_values : output_ref) { 633*4bdc9457SAndroid Build Coastguard Worker output_values = std::max(std::min(output_values, output_max), output_min); 634*4bdc9457SAndroid Build Coastguard Worker } 635*4bdc9457SAndroid Build Coastguard Worker 636*4bdc9457SAndroid Build Coastguard Worker // Call optimized micro-kernel. 637*4bdc9457SAndroid Build Coastguard Worker gavgpool_minmax(rows(), channels(), 638*4bdc9457SAndroid Build Coastguard Worker input.data(), input_stride() * sizeof(float), 639*4bdc9457SAndroid Build Coastguard Worker zero.data(), 640*4bdc9457SAndroid Build Coastguard Worker buffer.data(), 641*4bdc9457SAndroid Build Coastguard Worker output.data(), 642*4bdc9457SAndroid Build Coastguard Worker ¶ms); 643*4bdc9457SAndroid Build Coastguard Worker 644*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
645*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 646*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[c], output_max) 647*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 648*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[c], output_min) 649*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 650*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[c], output_ref[c], std::abs(output_ref[c]) * 1.0e-6f) 651*4bdc9457SAndroid Build Coastguard Worker << "at position " << c << ", rows = " << rows() << ", channels = " << channels(); 652*4bdc9457SAndroid Build Coastguard Worker } 653*4bdc9457SAndroid Build Coastguard Worker } 654*4bdc9457SAndroid Build Coastguard Worker } 655*4bdc9457SAndroid Build Coastguard Worker 656*4bdc9457SAndroid Build Coastguard Worker private: 657*4bdc9457SAndroid Build Coastguard Worker size_t rows_{1}; 658*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 659*4bdc9457SAndroid Build Coastguard Worker size_t channel_tile_{1}; 660*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 661*4bdc9457SAndroid Build Coastguard Worker float input_scale_{1.25f}; 662*4bdc9457SAndroid Build Coastguard Worker float output_scale_{0.75f}; 663*4bdc9457SAndroid Build Coastguard Worker uint8_t input_zero_point_{121}; 664*4bdc9457SAndroid Build Coastguard Worker uint8_t output_zero_point_{133}; 665*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 666*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 667*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{15}; 668*4bdc9457SAndroid Build Coastguard Worker }; 669