1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 15*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 16*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 17*4bdc9457SAndroid Build Coastguard Worker #include <limits> 18*4bdc9457SAndroid Build Coastguard Worker #include <random> 19*4bdc9457SAndroid Build Coastguard Worker #include <vector> 20*4bdc9457SAndroid Build Coastguard Worker 21*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 22*4bdc9457SAndroid Build Coastguard Worker 23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker 26*4bdc9457SAndroid Build Coastguard Worker class GlobalAveragePoolingOperatorTester { 27*4bdc9457SAndroid Build Coastguard Worker public: channels(size_t channels)28*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& channels(size_t channels) { 29*4bdc9457SAndroid Build Coastguard Worker assert(channels != 0); 30*4bdc9457SAndroid Build Coastguard Worker this->channels_ = channels; 31*4bdc9457SAndroid Build Coastguard Worker return *this; 32*4bdc9457SAndroid Build Coastguard Worker } 33*4bdc9457SAndroid Build Coastguard Worker channels()34*4bdc9457SAndroid Build Coastguard Worker inline size_t channels() const { 35*4bdc9457SAndroid Build Coastguard Worker return this->channels_; 36*4bdc9457SAndroid Build Coastguard Worker } 37*4bdc9457SAndroid Build Coastguard Worker width(size_t width)38*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& width(size_t width) { 39*4bdc9457SAndroid Build Coastguard Worker assert(width != 0); 40*4bdc9457SAndroid Build Coastguard Worker this->width_ = width; 41*4bdc9457SAndroid Build Coastguard Worker return *this; 42*4bdc9457SAndroid Build Coastguard Worker } 43*4bdc9457SAndroid Build Coastguard Worker width()44*4bdc9457SAndroid Build Coastguard Worker inline size_t width() const { 45*4bdc9457SAndroid Build Coastguard Worker return this->width_; 46*4bdc9457SAndroid Build Coastguard Worker } 47*4bdc9457SAndroid Build Coastguard Worker input_stride(size_t input_stride)48*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& input_stride(size_t input_stride) { 49*4bdc9457SAndroid Build Coastguard Worker assert(input_stride != 0); 50*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 51*4bdc9457SAndroid Build Coastguard Worker return *this; 52*4bdc9457SAndroid Build Coastguard Worker } 53*4bdc9457SAndroid Build Coastguard Worker input_stride()54*4bdc9457SAndroid Build Coastguard Worker inline size_t input_stride() const { 55*4bdc9457SAndroid Build Coastguard Worker if (this->input_stride_ == 0) { 56*4bdc9457SAndroid Build Coastguard Worker return channels(); 57*4bdc9457SAndroid Build Coastguard Worker } else { 58*4bdc9457SAndroid Build Coastguard Worker assert(this->input_stride_ >= channels()); 59*4bdc9457SAndroid Build Coastguard Worker return this->input_stride_; 60*4bdc9457SAndroid Build Coastguard Worker } 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker output_stride(size_t output_stride)63*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& output_stride(size_t output_stride) { 64*4bdc9457SAndroid Build Coastguard Worker assert(output_stride != 0); 65*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 66*4bdc9457SAndroid Build Coastguard Worker return *this; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker output_stride()69*4bdc9457SAndroid Build Coastguard Worker inline size_t output_stride() const { 70*4bdc9457SAndroid Build Coastguard Worker if (this->output_stride_ == 0) { 71*4bdc9457SAndroid Build Coastguard Worker return channels(); 72*4bdc9457SAndroid Build Coastguard Worker } else { 73*4bdc9457SAndroid Build Coastguard Worker assert(this->output_stride_ >= channels()); 74*4bdc9457SAndroid Build Coastguard Worker return this->output_stride_; 75*4bdc9457SAndroid Build Coastguard Worker } 76*4bdc9457SAndroid Build Coastguard Worker } 77*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)78*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& batch_size(size_t batch_size) { 79*4bdc9457SAndroid Build Coastguard Worker assert(batch_size != 0); 80*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 81*4bdc9457SAndroid Build Coastguard Worker return *this; 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker batch_size()84*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 85*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 86*4bdc9457SAndroid Build Coastguard Worker } 87*4bdc9457SAndroid Build Coastguard Worker input_scale(float input_scale)88*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& input_scale(float input_scale) { 89*4bdc9457SAndroid Build Coastguard Worker assert(input_scale > 0.0f); 90*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(input_scale)); 91*4bdc9457SAndroid Build Coastguard Worker this->input_scale_ = input_scale; 92*4bdc9457SAndroid Build Coastguard Worker return *this; 93*4bdc9457SAndroid Build Coastguard Worker } 94*4bdc9457SAndroid Build Coastguard Worker input_scale()95*4bdc9457SAndroid Build Coastguard Worker inline float input_scale() const { 96*4bdc9457SAndroid Build Coastguard Worker return this->input_scale_; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker input_zero_point(uint8_t input_zero_point)99*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& input_zero_point(uint8_t input_zero_point) { 100*4bdc9457SAndroid Build Coastguard Worker this->input_zero_point_ = input_zero_point; 101*4bdc9457SAndroid Build Coastguard Worker return *this; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker input_zero_point()104*4bdc9457SAndroid Build Coastguard Worker inline uint8_t input_zero_point() const { 105*4bdc9457SAndroid Build Coastguard Worker return this->input_zero_point_; 106*4bdc9457SAndroid Build Coastguard Worker } 107*4bdc9457SAndroid Build Coastguard Worker output_scale(float output_scale)108*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& output_scale(float output_scale) { 109*4bdc9457SAndroid Build Coastguard Worker assert(output_scale > 0.0f); 110*4bdc9457SAndroid Build Coastguard Worker assert(std::isnormal(output_scale)); 111*4bdc9457SAndroid Build Coastguard Worker this->output_scale_ = output_scale; 112*4bdc9457SAndroid Build Coastguard Worker return *this; 113*4bdc9457SAndroid Build Coastguard Worker } 114*4bdc9457SAndroid Build Coastguard Worker output_scale()115*4bdc9457SAndroid Build Coastguard Worker inline float output_scale() const { 116*4bdc9457SAndroid Build Coastguard Worker return this->output_scale_; 117*4bdc9457SAndroid Build Coastguard Worker } 118*4bdc9457SAndroid Build Coastguard Worker output_zero_point(uint8_t output_zero_point)119*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& output_zero_point(uint8_t output_zero_point) { 120*4bdc9457SAndroid Build Coastguard Worker this->output_zero_point_ = output_zero_point; 121*4bdc9457SAndroid Build Coastguard Worker return *this; 122*4bdc9457SAndroid Build Coastguard Worker } 123*4bdc9457SAndroid Build Coastguard Worker output_zero_point()124*4bdc9457SAndroid Build Coastguard Worker inline uint8_t output_zero_point() const { 125*4bdc9457SAndroid Build Coastguard Worker return this->output_zero_point_; 126*4bdc9457SAndroid Build Coastguard Worker } 127*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)128*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& qmin(uint8_t qmin) { 129*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 130*4bdc9457SAndroid Build Coastguard Worker return *this; 131*4bdc9457SAndroid Build Coastguard Worker } 132*4bdc9457SAndroid Build Coastguard Worker qmin()133*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 134*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 135*4bdc9457SAndroid Build Coastguard Worker } 136*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)137*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& qmax(uint8_t qmax) { 138*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 139*4bdc9457SAndroid Build Coastguard Worker return *this; 140*4bdc9457SAndroid Build Coastguard Worker } 141*4bdc9457SAndroid Build Coastguard Worker qmax()142*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 143*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 144*4bdc9457SAndroid Build Coastguard Worker } 145*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)146*4bdc9457SAndroid Build Coastguard Worker inline GlobalAveragePoolingOperatorTester& iterations(size_t iterations) { 147*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 148*4bdc9457SAndroid Build Coastguard Worker return *this; 149*4bdc9457SAndroid Build Coastguard Worker } 150*4bdc9457SAndroid Build Coastguard Worker iterations()151*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 152*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 153*4bdc9457SAndroid Build Coastguard Worker } 154*4bdc9457SAndroid Build Coastguard Worker TestNWCxQU8()155*4bdc9457SAndroid Build Coastguard Worker void TestNWCxQU8() const { 156*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 157*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 158*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 159*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 160*4bdc9457SAndroid Build Coastguard Worker 161*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint8_t)); 162*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(batch_size() * output_stride()); 163*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 164*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 165*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 166*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 167*4bdc9457SAndroid Build Coastguard Worker 168*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 169*4bdc9457SAndroid Build Coastguard Worker const double scale = double(input_scale()) / (double(width()) * double(output_scale())); 170*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 171*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < channels(); j++) { 172*4bdc9457SAndroid Build Coastguard Worker double acc = 0.0f; 173*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < width(); k++) { 174*4bdc9457SAndroid Build Coastguard Worker acc += double(int32_t(input[(i * width() + k) * input_stride() + j]) - int32_t(input_zero_point())); 175*4bdc9457SAndroid Build Coastguard Worker } 176*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = float(acc * scale + double(output_zero_point())); 177*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = std::min<float>(output_ref[i * channels() + j], float(qmax())); 178*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = std::max<float>(output_ref[i * channels() + j], float(qmin())); 179*4bdc9457SAndroid Build Coastguard Worker } 180*4bdc9457SAndroid Build Coastguard Worker } 181*4bdc9457SAndroid Build Coastguard Worker 182*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Global Average Pooling operator. 183*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 184*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t global_average_pooling_op = nullptr; 185*4bdc9457SAndroid Build Coastguard Worker 186*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_global_average_pooling_nwc_qu8( 187*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 188*4bdc9457SAndroid Build Coastguard Worker input_zero_point(), input_scale(), 189*4bdc9457SAndroid Build Coastguard Worker output_zero_point(), output_scale(), 190*4bdc9457SAndroid Build Coastguard Worker qmin(), qmax(), 191*4bdc9457SAndroid Build Coastguard Worker 0, &global_average_pooling_op); 192*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 193*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 194*4bdc9457SAndroid Build Coastguard Worker } 195*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 196*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, global_average_pooling_op); 197*4bdc9457SAndroid Build Coastguard Worker 198*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete global_average_pooling_op. 199*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator); 200*4bdc9457SAndroid Build Coastguard Worker 201*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 202*4bdc9457SAndroid Build Coastguard Worker xnn_setup_global_average_pooling_nwc_qu8( 203*4bdc9457SAndroid Build Coastguard Worker global_average_pooling_op, 204*4bdc9457SAndroid Build Coastguard Worker batch_size(), width(), 205*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 206*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 207*4bdc9457SAndroid Build Coastguard Worker 208*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 209*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */)); 210*4bdc9457SAndroid Build Coastguard Worker 211*4bdc9457SAndroid Build Coastguard Worker // Verify results. 212*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 213*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 214*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(output[i * output_stride() + c]), uint32_t(qmax())); 215*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(output[i * output_stride() + c]), uint32_t(qmin())); 216*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.80f) 217*4bdc9457SAndroid Build Coastguard Worker << "at batch index " << i << " / " << batch_size() 218*4bdc9457SAndroid Build Coastguard Worker << ", channel " << c << " / " << channels(); 219*4bdc9457SAndroid Build Coastguard Worker } 220*4bdc9457SAndroid Build Coastguard Worker } 221*4bdc9457SAndroid Build Coastguard Worker } 222*4bdc9457SAndroid Build Coastguard Worker } 223*4bdc9457SAndroid Build Coastguard Worker TestNWCxQS8()224*4bdc9457SAndroid Build Coastguard Worker void TestNWCxQS8() const { 225*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 226*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 227*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 228*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 229*4bdc9457SAndroid Build Coastguard Worker 230*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(int8_t)); 231*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(batch_size() * output_stride()); 232*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 233*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 234*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 235*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 236*4bdc9457SAndroid Build Coastguard Worker 237*4bdc9457SAndroid Build Coastguard Worker // Compute reference results. 238*4bdc9457SAndroid Build Coastguard Worker const double scale = double(input_scale()) / (double(width()) * double(output_scale())); 239*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 240*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < channels(); j++) { 241*4bdc9457SAndroid Build Coastguard Worker double acc = 0.0f; 242*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < width(); k++) { 243*4bdc9457SAndroid Build Coastguard Worker acc += double(int32_t(input[(i * width() + k) * input_stride() + j]) - int32_t(input_zero_point() - 0x80)); 244*4bdc9457SAndroid Build Coastguard Worker } 245*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = float(acc * scale + double(output_zero_point() - 0x80)); 246*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = std::min<float>(output_ref[i * channels() + j], float(qmax() - 0x80)); 247*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = std::max<float>(output_ref[i * channels() + j], float(qmin() - 0x80)); 248*4bdc9457SAndroid Build Coastguard Worker } 249*4bdc9457SAndroid Build Coastguard Worker } 250*4bdc9457SAndroid Build Coastguard Worker 251*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Global Average Pooling operator. 252*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 253*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t global_average_pooling_op = nullptr; 254*4bdc9457SAndroid Build Coastguard Worker 255*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_global_average_pooling_nwc_qs8( 256*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 257*4bdc9457SAndroid Build Coastguard Worker int8_t(input_zero_point() - 0x80), input_scale(), 258*4bdc9457SAndroid Build Coastguard Worker int8_t(output_zero_point() - 0x80), output_scale(), 259*4bdc9457SAndroid Build Coastguard Worker int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 260*4bdc9457SAndroid Build Coastguard Worker 0, &global_average_pooling_op); 261*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 262*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 263*4bdc9457SAndroid Build Coastguard Worker } 264*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 265*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, global_average_pooling_op); 266*4bdc9457SAndroid Build Coastguard Worker 267*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete global_average_pooling_op. 268*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator); 269*4bdc9457SAndroid Build Coastguard Worker 270*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 271*4bdc9457SAndroid Build Coastguard Worker xnn_setup_global_average_pooling_nwc_qs8( 272*4bdc9457SAndroid Build Coastguard Worker global_average_pooling_op, 273*4bdc9457SAndroid Build Coastguard Worker batch_size(), width(), 274*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 275*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 276*4bdc9457SAndroid Build Coastguard Worker 277*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 278*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */)); 279*4bdc9457SAndroid Build Coastguard Worker 280*4bdc9457SAndroid Build Coastguard Worker // Verify results. 281*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 282*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 283*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80)); 284*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80)); 285*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(float(int32_t(output[i * output_stride() + c])), output_ref[i * channels() + c], 0.80f) 286*4bdc9457SAndroid Build Coastguard Worker << "at batch index " << i << " / " << batch_size() 287*4bdc9457SAndroid Build Coastguard Worker << ", channel " << c << " / " << channels(); 288*4bdc9457SAndroid Build Coastguard Worker } 289*4bdc9457SAndroid Build Coastguard Worker } 290*4bdc9457SAndroid Build Coastguard Worker } 291*4bdc9457SAndroid Build Coastguard Worker } 292*4bdc9457SAndroid Build Coastguard Worker TestNWCxF16()293*4bdc9457SAndroid Build Coastguard Worker void TestNWCxF16() const { 294*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 295*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 296*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(1.0e-3f, 1.0f); 297*4bdc9457SAndroid Build Coastguard Worker 298*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(uint16_t)); 299*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(batch_size() * output_stride()); 300*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 301*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 302*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 303*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 304*4bdc9457SAndroid Build Coastguard Worker 305*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 306*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 307*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < channels(); j++) { 308*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 309*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < width(); k++) { 310*4bdc9457SAndroid Build Coastguard Worker acc += fp16_ieee_to_fp32_value(input[(i * width() + k) * input_stride() + j]); 311*4bdc9457SAndroid Build Coastguard Worker } 312*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = acc / float(width()); 313*4bdc9457SAndroid Build Coastguard Worker } 314*4bdc9457SAndroid Build Coastguard Worker } 315*4bdc9457SAndroid Build Coastguard Worker 316*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 317*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 318*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 319*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 320*4bdc9457SAndroid Build Coastguard Worker const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin()))); 321*4bdc9457SAndroid Build Coastguard Worker const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax()))); 322*4bdc9457SAndroid Build Coastguard Worker const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min; 323*4bdc9457SAndroid Build Coastguard Worker const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max; 324*4bdc9457SAndroid Build Coastguard Worker 325*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 326*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 327*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 328*4bdc9457SAndroid Build Coastguard Worker } 329*4bdc9457SAndroid Build Coastguard Worker 330*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Global Average Pooling operator. 331*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 332*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t global_average_pooling_op = nullptr; 333*4bdc9457SAndroid Build Coastguard Worker 334*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_global_average_pooling_nwc_f16( 335*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 336*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 337*4bdc9457SAndroid Build Coastguard Worker 0, &global_average_pooling_op); 338*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 339*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 340*4bdc9457SAndroid Build Coastguard Worker } 341*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 342*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, global_average_pooling_op); 343*4bdc9457SAndroid Build Coastguard Worker 344*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete global_average_pooling_op. 345*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator); 346*4bdc9457SAndroid Build Coastguard Worker 347*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 348*4bdc9457SAndroid Build Coastguard Worker xnn_setup_global_average_pooling_nwc_f16( 349*4bdc9457SAndroid Build Coastguard Worker global_average_pooling_op, 350*4bdc9457SAndroid Build Coastguard Worker batch_size(), width(), 351*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 352*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 353*4bdc9457SAndroid Build Coastguard Worker 354*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 355*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */)); 356*4bdc9457SAndroid Build Coastguard Worker 357*4bdc9457SAndroid Build Coastguard Worker // Verify results. 358*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 359*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 360*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max); 361*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min); 362*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-4f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f)) 363*4bdc9457SAndroid Build Coastguard Worker << "at batch index " << i << " / " << batch_size() 364*4bdc9457SAndroid Build Coastguard Worker << ", channel " << c << " / " << channels(); 365*4bdc9457SAndroid Build Coastguard Worker } 366*4bdc9457SAndroid Build Coastguard Worker } 367*4bdc9457SAndroid Build Coastguard Worker } 368*4bdc9457SAndroid Build Coastguard Worker } 369*4bdc9457SAndroid Build Coastguard Worker TestNWCxF32()370*4bdc9457SAndroid Build Coastguard Worker void TestNWCxF32() const { 371*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 372*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 373*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 374*4bdc9457SAndroid Build Coastguard Worker 375*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input((batch_size() * width() - 1) * input_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float)); 376*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(batch_size() * output_stride()); 377*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 378*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 379*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 380*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 381*4bdc9457SAndroid Build Coastguard Worker 382*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 383*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 384*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < channels(); j++) { 385*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 386*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < width(); k++) { 387*4bdc9457SAndroid Build Coastguard Worker acc += input[(i * width() + k) * input_stride() + j]; 388*4bdc9457SAndroid Build Coastguard Worker } 389*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = acc / float(width()); 390*4bdc9457SAndroid Build Coastguard Worker } 391*4bdc9457SAndroid Build Coastguard Worker } 392*4bdc9457SAndroid Build Coastguard Worker 393*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 394*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 395*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 396*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 397*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_range == 0.0f ? 398*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity() : 399*4bdc9457SAndroid Build Coastguard Worker accumulated_min + accumulated_range / 255.0f * float(qmin()); 400*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_range == 0.0f ? 401*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity() : 402*4bdc9457SAndroid Build Coastguard Worker accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 403*4bdc9457SAndroid Build Coastguard Worker 404*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 405*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 406*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 407*4bdc9457SAndroid Build Coastguard Worker } 408*4bdc9457SAndroid Build Coastguard Worker 409*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Global Average Pooling operator. 410*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 411*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t global_average_pooling_op = nullptr; 412*4bdc9457SAndroid Build Coastguard Worker 413*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_global_average_pooling_nwc_f32( 414*4bdc9457SAndroid Build Coastguard Worker channels(), input_stride(), output_stride(), 415*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 416*4bdc9457SAndroid Build Coastguard Worker 0, &global_average_pooling_op); 417*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 418*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 419*4bdc9457SAndroid Build Coastguard Worker } 420*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 421*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, global_average_pooling_op); 422*4bdc9457SAndroid Build Coastguard Worker 423*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete global_average_pooling_op. 424*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator); 425*4bdc9457SAndroid Build Coastguard Worker 426*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 427*4bdc9457SAndroid Build Coastguard Worker xnn_setup_global_average_pooling_nwc_f32( 428*4bdc9457SAndroid Build Coastguard Worker global_average_pooling_op, 429*4bdc9457SAndroid Build Coastguard Worker batch_size(), width(), 430*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 431*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 432*4bdc9457SAndroid Build Coastguard Worker 433*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 434*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */)); 435*4bdc9457SAndroid Build Coastguard Worker 436*4bdc9457SAndroid Build Coastguard Worker // Verify results. 437*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 438*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 439*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[i * output_stride() + c], output_max); 440*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[i * output_stride() + c], output_min); 441*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[i * output_stride() + c], output_ref[i * channels() + c], std::abs(output_ref[i * channels() + c]) * 1.0e-6f) 442*4bdc9457SAndroid Build Coastguard Worker << "at batch index " << i << " / " << batch_size() 443*4bdc9457SAndroid Build Coastguard Worker << ", channel " << c << " / " << channels(); 444*4bdc9457SAndroid Build Coastguard Worker } 445*4bdc9457SAndroid Build Coastguard Worker } 446*4bdc9457SAndroid Build Coastguard Worker } 447*4bdc9457SAndroid Build Coastguard Worker } 448*4bdc9457SAndroid Build Coastguard Worker TestNCWxF32()449*4bdc9457SAndroid Build Coastguard Worker void TestNCWxF32() const { 450*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 451*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 452*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist; 453*4bdc9457SAndroid Build Coastguard Worker 454*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(batch_size() * channels() * width() + XNN_EXTRA_BYTES / sizeof(float)); 455*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(batch_size() * channels()); 456*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * channels()); 457*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 458*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 459*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), std::nanf("")); 460*4bdc9457SAndroid Build Coastguard Worker 461*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 462*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 463*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < channels(); j++) { 464*4bdc9457SAndroid Build Coastguard Worker float acc = 0.0f; 465*4bdc9457SAndroid Build Coastguard Worker for (size_t k = 0; k < width(); k++) { 466*4bdc9457SAndroid Build Coastguard Worker acc += input[(i * channels() + j) * width() + k]; 467*4bdc9457SAndroid Build Coastguard Worker } 468*4bdc9457SAndroid Build Coastguard Worker output_ref[i * channels() + j] = acc / float(width()); 469*4bdc9457SAndroid Build Coastguard Worker } 470*4bdc9457SAndroid Build Coastguard Worker } 471*4bdc9457SAndroid Build Coastguard Worker 472*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 473*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 474*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 475*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 476*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_range == 0.0f ? 477*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<float>::infinity() : 478*4bdc9457SAndroid Build Coastguard Worker accumulated_min + accumulated_range / 255.0f * float(qmin()); 479*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_range == 0.0f ? 480*4bdc9457SAndroid Build Coastguard Worker +std::numeric_limits<float>::infinity() : 481*4bdc9457SAndroid Build Coastguard Worker accumulated_max - accumulated_range / 255.0f * float(255 - qmax()); 482*4bdc9457SAndroid Build Coastguard Worker 483*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 484*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 485*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 486*4bdc9457SAndroid Build Coastguard Worker } 487*4bdc9457SAndroid Build Coastguard Worker 488*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Global Average Pooling operator. 489*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 490*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t global_average_pooling_op = nullptr; 491*4bdc9457SAndroid Build Coastguard Worker 492*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_global_average_pooling_ncw_f32( 493*4bdc9457SAndroid Build Coastguard Worker channels(), output_min, output_max, 494*4bdc9457SAndroid Build Coastguard Worker 0, &global_average_pooling_op); 495*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_parameter) { 496*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 497*4bdc9457SAndroid Build Coastguard Worker } 498*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 499*4bdc9457SAndroid Build Coastguard Worker 500*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete global_average_pooling_op. 501*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_global_average_pooling_op(global_average_pooling_op, xnn_delete_operator); 502*4bdc9457SAndroid Build Coastguard Worker 503*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 504*4bdc9457SAndroid Build Coastguard Worker xnn_setup_global_average_pooling_ncw_f32( 505*4bdc9457SAndroid Build Coastguard Worker global_average_pooling_op, 506*4bdc9457SAndroid Build Coastguard Worker batch_size(), width(), 507*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 508*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 509*4bdc9457SAndroid Build Coastguard Worker 510*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 511*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(global_average_pooling_op, nullptr /* thread pool */)); 512*4bdc9457SAndroid Build Coastguard Worker 513*4bdc9457SAndroid Build Coastguard Worker // Verify results. 514*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 515*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < channels(); c++) { 516*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[i * channels() + c], output_max); 517*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[i * channels() + c], output_min); 518*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output[i * channels() + c], output_ref[i * channels() + c], std::abs(output_ref[i * channels() + c]) * 1.0e-5f) 519*4bdc9457SAndroid Build Coastguard Worker << "at batch index " << i << " / " << batch_size() 520*4bdc9457SAndroid Build Coastguard Worker << ", channel " << c << " / " << channels(); 521*4bdc9457SAndroid Build Coastguard Worker } 522*4bdc9457SAndroid Build Coastguard Worker } 523*4bdc9457SAndroid Build Coastguard Worker } 524*4bdc9457SAndroid Build Coastguard Worker } 525*4bdc9457SAndroid Build Coastguard Worker 526*4bdc9457SAndroid Build Coastguard Worker private: 527*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 528*4bdc9457SAndroid Build Coastguard Worker size_t width_{1}; 529*4bdc9457SAndroid Build Coastguard Worker size_t channels_{1}; 530*4bdc9457SAndroid Build Coastguard Worker size_t input_stride_{0}; 531*4bdc9457SAndroid Build Coastguard Worker size_t output_stride_{0}; 532*4bdc9457SAndroid Build Coastguard Worker float input_scale_{1.0f}; 533*4bdc9457SAndroid Build Coastguard Worker float output_scale_{1.0f}; 534*4bdc9457SAndroid Build Coastguard Worker uint8_t input_zero_point_{121}; 535*4bdc9457SAndroid Build Coastguard Worker uint8_t output_zero_point_{133}; 536*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 537*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 538*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 539*4bdc9457SAndroid Build Coastguard Worker }; 540