1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates. 2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved. 3*4bdc9457SAndroid Build Coastguard Worker // 4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC 5*4bdc9457SAndroid Build Coastguard Worker // 6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 8*4bdc9457SAndroid Build Coastguard Worker 9*4bdc9457SAndroid Build Coastguard Worker #pragma once 10*4bdc9457SAndroid Build Coastguard Worker 11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm> 14*4bdc9457SAndroid Build Coastguard Worker #include <cassert> 15*4bdc9457SAndroid Build Coastguard Worker #include <cmath> 16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef> 17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib> 18*4bdc9457SAndroid Build Coastguard Worker #include <limits> 19*4bdc9457SAndroid Build Coastguard Worker #include <random> 20*4bdc9457SAndroid Build Coastguard Worker #include <vector> 21*4bdc9457SAndroid Build Coastguard Worker 22*4bdc9457SAndroid Build Coastguard Worker #include "convolution-test-helpers.h" 23*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h> 24*4bdc9457SAndroid Build Coastguard Worker 25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h> 27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h> 28*4bdc9457SAndroid Build Coastguard Worker 29*4bdc9457SAndroid Build Coastguard Worker 30*4bdc9457SAndroid Build Coastguard Worker class ConvolutionOperatorTester { 31*4bdc9457SAndroid Build Coastguard Worker public: 32*4bdc9457SAndroid Build Coastguard 
Worker enum class WeightsType { 33*4bdc9457SAndroid Build Coastguard Worker Default, 34*4bdc9457SAndroid Build Coastguard Worker FP32, 35*4bdc9457SAndroid Build Coastguard Worker }; 36*4bdc9457SAndroid Build Coastguard Worker padding_tf_same(bool padding_same)37*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_tf_same(bool padding_same) { 38*4bdc9457SAndroid Build Coastguard Worker if (padding_same) { 39*4bdc9457SAndroid Build Coastguard Worker assert(padding_top() == 0); 40*4bdc9457SAndroid Build Coastguard Worker assert(padding_left() == 0); 41*4bdc9457SAndroid Build Coastguard Worker assert(padding_bottom() == 0); 42*4bdc9457SAndroid Build Coastguard Worker assert(padding_right() == 0); 43*4bdc9457SAndroid Build Coastguard Worker } 44*4bdc9457SAndroid Build Coastguard Worker this->padding_tf_same_ = padding_same; 45*4bdc9457SAndroid Build Coastguard Worker return *this; 46*4bdc9457SAndroid Build Coastguard Worker } 47*4bdc9457SAndroid Build Coastguard Worker padding_tf_same()48*4bdc9457SAndroid Build Coastguard Worker inline bool padding_tf_same() const { 49*4bdc9457SAndroid Build Coastguard Worker return this->padding_tf_same_; 50*4bdc9457SAndroid Build Coastguard Worker } 51*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding)52*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding(uint32_t padding) { 53*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 54*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding; 55*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding; 56*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding; 57*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding; 58*4bdc9457SAndroid Build Coastguard Worker return *this; 59*4bdc9457SAndroid Build Coastguard Worker } 60*4bdc9457SAndroid Build Coastguard Worker padding(uint32_t padding_height,uint32_t 
padding_width)61*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) { 62*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 63*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 64*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 65*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 66*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 67*4bdc9457SAndroid Build Coastguard Worker return *this; 68*4bdc9457SAndroid Build Coastguard Worker } 69*4bdc9457SAndroid Build Coastguard Worker padding_height(uint32_t padding_height)70*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_height(uint32_t padding_height) { 71*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 72*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height; 73*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height; 74*4bdc9457SAndroid Build Coastguard Worker return *this; 75*4bdc9457SAndroid Build Coastguard Worker } 76*4bdc9457SAndroid Build Coastguard Worker padding_width(uint32_t padding_width)77*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_width(uint32_t padding_width) { 78*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 79*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width; 80*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width; 81*4bdc9457SAndroid Build Coastguard Worker return *this; 82*4bdc9457SAndroid Build Coastguard Worker } 83*4bdc9457SAndroid Build Coastguard Worker padding_top(uint32_t padding_top)84*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_top(uint32_t padding_top) { 85*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 
86*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top; 87*4bdc9457SAndroid Build Coastguard Worker return *this; 88*4bdc9457SAndroid Build Coastguard Worker } 89*4bdc9457SAndroid Build Coastguard Worker padding_top()90*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { 91*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 92*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = 93*4bdc9457SAndroid Build Coastguard Worker (output_height() - 1) * subsampling_height() + dilated_kernel_height() - input_height(); 94*4bdc9457SAndroid Build Coastguard Worker return total_padding_height / 2; 95*4bdc9457SAndroid Build Coastguard Worker } else { 96*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_; 97*4bdc9457SAndroid Build Coastguard Worker } 98*4bdc9457SAndroid Build Coastguard Worker } 99*4bdc9457SAndroid Build Coastguard Worker padding_left(uint32_t padding_left)100*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_left(uint32_t padding_left) { 101*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 102*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left; 103*4bdc9457SAndroid Build Coastguard Worker return *this; 104*4bdc9457SAndroid Build Coastguard Worker } 105*4bdc9457SAndroid Build Coastguard Worker padding_left()106*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { 107*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 108*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = 109*4bdc9457SAndroid Build Coastguard Worker (output_width() - 1) * subsampling_width() + dilated_kernel_width() - input_width(); 110*4bdc9457SAndroid Build Coastguard Worker return total_padding_width / 2; 111*4bdc9457SAndroid Build Coastguard Worker } else { 112*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_; 
113*4bdc9457SAndroid Build Coastguard Worker } 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker padding_bottom(uint32_t padding_bottom)116*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_bottom(uint32_t padding_bottom) { 117*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 118*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom; 119*4bdc9457SAndroid Build Coastguard Worker return *this; 120*4bdc9457SAndroid Build Coastguard Worker } 121*4bdc9457SAndroid Build Coastguard Worker padding_bottom()122*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { 123*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 124*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_height = 125*4bdc9457SAndroid Build Coastguard Worker (output_height() - 1) * subsampling_height() + dilated_kernel_height() - input_height(); 126*4bdc9457SAndroid Build Coastguard Worker return total_padding_height - total_padding_height / 2; 127*4bdc9457SAndroid Build Coastguard Worker } else { 128*4bdc9457SAndroid Build Coastguard Worker return this->padding_bottom_; 129*4bdc9457SAndroid Build Coastguard Worker } 130*4bdc9457SAndroid Build Coastguard Worker } 131*4bdc9457SAndroid Build Coastguard Worker padding_right(uint32_t padding_right)132*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& padding_right(uint32_t padding_right) { 133*4bdc9457SAndroid Build Coastguard Worker assert(!padding_tf_same()); 134*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right; 135*4bdc9457SAndroid Build Coastguard Worker return *this; 136*4bdc9457SAndroid Build Coastguard Worker } 137*4bdc9457SAndroid Build Coastguard Worker padding_right()138*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { 139*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 
140*4bdc9457SAndroid Build Coastguard Worker const uint32_t total_padding_width = 141*4bdc9457SAndroid Build Coastguard Worker (output_width() - 1) * subsampling_width() + dilated_kernel_width() - input_width(); 142*4bdc9457SAndroid Build Coastguard Worker return total_padding_width - total_padding_width / 2; 143*4bdc9457SAndroid Build Coastguard Worker } else { 144*4bdc9457SAndroid Build Coastguard Worker return this->padding_right_; 145*4bdc9457SAndroid Build Coastguard Worker } 146*4bdc9457SAndroid Build Coastguard Worker } 147*4bdc9457SAndroid Build Coastguard Worker input_size(uint32_t input_height,uint32_t input_width)148*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& input_size(uint32_t input_height, uint32_t input_width) { 149*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 150*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 151*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 152*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 153*4bdc9457SAndroid Build Coastguard Worker return *this; 154*4bdc9457SAndroid Build Coastguard Worker } 155*4bdc9457SAndroid Build Coastguard Worker input_height(uint32_t input_height)156*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& input_height(uint32_t input_height) { 157*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1); 158*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height; 159*4bdc9457SAndroid Build Coastguard Worker return *this; 160*4bdc9457SAndroid Build Coastguard Worker } 161*4bdc9457SAndroid Build Coastguard Worker input_height()162*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_height() const { 163*4bdc9457SAndroid Build Coastguard Worker return this->input_height_; 164*4bdc9457SAndroid Build Coastguard Worker } 165*4bdc9457SAndroid Build Coastguard Worker input_width(uint32_t input_width)166*4bdc9457SAndroid 
Build Coastguard Worker inline ConvolutionOperatorTester& input_width(uint32_t input_width) { 167*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1); 168*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width; 169*4bdc9457SAndroid Build Coastguard Worker return *this; 170*4bdc9457SAndroid Build Coastguard Worker } 171*4bdc9457SAndroid Build Coastguard Worker input_width()172*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_width() const { 173*4bdc9457SAndroid Build Coastguard Worker return this->input_width_; 174*4bdc9457SAndroid Build Coastguard Worker } 175*4bdc9457SAndroid Build Coastguard Worker groups(uint32_t groups)176*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& groups(uint32_t groups) { 177*4bdc9457SAndroid Build Coastguard Worker assert(groups >= 1); 178*4bdc9457SAndroid Build Coastguard Worker this->groups_ = groups; 179*4bdc9457SAndroid Build Coastguard Worker return *this; 180*4bdc9457SAndroid Build Coastguard Worker } 181*4bdc9457SAndroid Build Coastguard Worker groups()182*4bdc9457SAndroid Build Coastguard Worker inline uint32_t groups() const { 183*4bdc9457SAndroid Build Coastguard Worker return this->groups_; 184*4bdc9457SAndroid Build Coastguard Worker } 185*4bdc9457SAndroid Build Coastguard Worker group_input_channels(size_t group_input_channels)186*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& group_input_channels(size_t group_input_channels) { 187*4bdc9457SAndroid Build Coastguard Worker assert(group_input_channels >= 1); 188*4bdc9457SAndroid Build Coastguard Worker this->group_input_channels_ = group_input_channels; 189*4bdc9457SAndroid Build Coastguard Worker return *this; 190*4bdc9457SAndroid Build Coastguard Worker } 191*4bdc9457SAndroid Build Coastguard Worker group_input_channels()192*4bdc9457SAndroid Build Coastguard Worker inline size_t group_input_channels() const { 193*4bdc9457SAndroid Build Coastguard Worker return 
this->group_input_channels_; 194*4bdc9457SAndroid Build Coastguard Worker } 195*4bdc9457SAndroid Build Coastguard Worker group_output_channels(size_t group_output_channels)196*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& group_output_channels(size_t group_output_channels) { 197*4bdc9457SAndroid Build Coastguard Worker assert(group_output_channels >= 1); 198*4bdc9457SAndroid Build Coastguard Worker this->group_output_channels_ = group_output_channels; 199*4bdc9457SAndroid Build Coastguard Worker return *this; 200*4bdc9457SAndroid Build Coastguard Worker } 201*4bdc9457SAndroid Build Coastguard Worker group_output_channels()202*4bdc9457SAndroid Build Coastguard Worker inline size_t group_output_channels() const { 203*4bdc9457SAndroid Build Coastguard Worker return this->group_output_channels_; 204*4bdc9457SAndroid Build Coastguard Worker } 205*4bdc9457SAndroid Build Coastguard Worker batch_size(size_t batch_size)206*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& batch_size(size_t batch_size) { 207*4bdc9457SAndroid Build Coastguard Worker assert(batch_size >= 1); 208*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size; 209*4bdc9457SAndroid Build Coastguard Worker return *this; 210*4bdc9457SAndroid Build Coastguard Worker } 211*4bdc9457SAndroid Build Coastguard Worker batch_size()212*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const { 213*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_; 214*4bdc9457SAndroid Build Coastguard Worker } 215*4bdc9457SAndroid Build Coastguard Worker kernel_size(uint32_t kernel_size)216*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& kernel_size(uint32_t kernel_size) { 217*4bdc9457SAndroid Build Coastguard Worker assert(kernel_size >= 1); 218*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_size; 219*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_size; 
220*4bdc9457SAndroid Build Coastguard Worker return *this; 221*4bdc9457SAndroid Build Coastguard Worker } 222*4bdc9457SAndroid Build Coastguard Worker kernel_size(uint32_t kernel_height,uint32_t kernel_width)223*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& kernel_size(uint32_t kernel_height, uint32_t kernel_width) { 224*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height >= 1); 225*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width >= 1); 226*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_height; 227*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_width; 228*4bdc9457SAndroid Build Coastguard Worker return *this; 229*4bdc9457SAndroid Build Coastguard Worker } 230*4bdc9457SAndroid Build Coastguard Worker kernel_height(uint32_t kernel_height)231*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& kernel_height(uint32_t kernel_height) { 232*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height >= 1); 233*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_height; 234*4bdc9457SAndroid Build Coastguard Worker return *this; 235*4bdc9457SAndroid Build Coastguard Worker } 236*4bdc9457SAndroid Build Coastguard Worker kernel_height()237*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_height() const { 238*4bdc9457SAndroid Build Coastguard Worker return this->kernel_height_; 239*4bdc9457SAndroid Build Coastguard Worker } 240*4bdc9457SAndroid Build Coastguard Worker kernel_width(uint32_t kernel_width)241*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& kernel_width(uint32_t kernel_width) { 242*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width >= 1); 243*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_width; 244*4bdc9457SAndroid Build Coastguard Worker return *this; 245*4bdc9457SAndroid Build Coastguard Worker } 246*4bdc9457SAndroid Build Coastguard Worker 
kernel_width()247*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_width() const { 248*4bdc9457SAndroid Build Coastguard Worker return this->kernel_width_; 249*4bdc9457SAndroid Build Coastguard Worker } 250*4bdc9457SAndroid Build Coastguard Worker dilation(uint32_t dilation)251*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& dilation(uint32_t dilation) { 252*4bdc9457SAndroid Build Coastguard Worker assert(dilation >= 1); 253*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation; 254*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation; 255*4bdc9457SAndroid Build Coastguard Worker return *this; 256*4bdc9457SAndroid Build Coastguard Worker } 257*4bdc9457SAndroid Build Coastguard Worker dilation(uint32_t dilation_height,uint32_t dilation_width)258*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& dilation(uint32_t dilation_height, uint32_t dilation_width) { 259*4bdc9457SAndroid Build Coastguard Worker assert(dilation_height >= 1); 260*4bdc9457SAndroid Build Coastguard Worker assert(dilation_width >= 1); 261*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation_height; 262*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation_width; 263*4bdc9457SAndroid Build Coastguard Worker return *this; 264*4bdc9457SAndroid Build Coastguard Worker } 265*4bdc9457SAndroid Build Coastguard Worker dilation_height(uint32_t dilation_height)266*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& dilation_height(uint32_t dilation_height) { 267*4bdc9457SAndroid Build Coastguard Worker assert(dilation_height >= 1); 268*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation_height; 269*4bdc9457SAndroid Build Coastguard Worker return *this; 270*4bdc9457SAndroid Build Coastguard Worker } 271*4bdc9457SAndroid Build Coastguard Worker dilation_height()272*4bdc9457SAndroid Build Coastguard Worker inline 
uint32_t dilation_height() const { 273*4bdc9457SAndroid Build Coastguard Worker return this->dilation_height_; 274*4bdc9457SAndroid Build Coastguard Worker } 275*4bdc9457SAndroid Build Coastguard Worker dilation_width(uint32_t dilation_width)276*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& dilation_width(uint32_t dilation_width) { 277*4bdc9457SAndroid Build Coastguard Worker assert(dilation_width >= 1); 278*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation_width; 279*4bdc9457SAndroid Build Coastguard Worker return *this; 280*4bdc9457SAndroid Build Coastguard Worker } 281*4bdc9457SAndroid Build Coastguard Worker dilation_width()282*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilation_width() const { 283*4bdc9457SAndroid Build Coastguard Worker return this->dilation_width_; 284*4bdc9457SAndroid Build Coastguard Worker } 285*4bdc9457SAndroid Build Coastguard Worker subsampling(uint32_t subsampling)286*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& subsampling(uint32_t subsampling) { 287*4bdc9457SAndroid Build Coastguard Worker assert(subsampling >= 1); 288*4bdc9457SAndroid Build Coastguard Worker this->subsampling_height_ = subsampling; 289*4bdc9457SAndroid Build Coastguard Worker this->subsampling_width_ = subsampling; 290*4bdc9457SAndroid Build Coastguard Worker return *this; 291*4bdc9457SAndroid Build Coastguard Worker } 292*4bdc9457SAndroid Build Coastguard Worker subsampling(uint32_t subsampling_height,uint32_t subsampling_width)293*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& subsampling(uint32_t subsampling_height, uint32_t subsampling_width) { 294*4bdc9457SAndroid Build Coastguard Worker assert(subsampling_height >= 1); 295*4bdc9457SAndroid Build Coastguard Worker assert(subsampling_width >= 1); 296*4bdc9457SAndroid Build Coastguard Worker this->subsampling_height_ = subsampling_height; 297*4bdc9457SAndroid Build Coastguard Worker 
this->subsampling_width_ = subsampling_width; 298*4bdc9457SAndroid Build Coastguard Worker return *this; 299*4bdc9457SAndroid Build Coastguard Worker } 300*4bdc9457SAndroid Build Coastguard Worker subsampling_height(uint32_t subsampling_height)301*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& subsampling_height(uint32_t subsampling_height) { 302*4bdc9457SAndroid Build Coastguard Worker assert(subsampling_height >= 1); 303*4bdc9457SAndroid Build Coastguard Worker this->subsampling_height_ = subsampling_height; 304*4bdc9457SAndroid Build Coastguard Worker return *this; 305*4bdc9457SAndroid Build Coastguard Worker } 306*4bdc9457SAndroid Build Coastguard Worker subsampling_height()307*4bdc9457SAndroid Build Coastguard Worker inline uint32_t subsampling_height() const { 308*4bdc9457SAndroid Build Coastguard Worker return this->subsampling_height_; 309*4bdc9457SAndroid Build Coastguard Worker } 310*4bdc9457SAndroid Build Coastguard Worker subsampling_width(uint32_t subsampling_width)311*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& subsampling_width(uint32_t subsampling_width) { 312*4bdc9457SAndroid Build Coastguard Worker assert(subsampling_width >= 1); 313*4bdc9457SAndroid Build Coastguard Worker this->subsampling_width_ = subsampling_width; 314*4bdc9457SAndroid Build Coastguard Worker return *this; 315*4bdc9457SAndroid Build Coastguard Worker } 316*4bdc9457SAndroid Build Coastguard Worker subsampling_width()317*4bdc9457SAndroid Build Coastguard Worker inline uint32_t subsampling_width() const { 318*4bdc9457SAndroid Build Coastguard Worker return this->subsampling_width_; 319*4bdc9457SAndroid Build Coastguard Worker } 320*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(size_t input_channel_stride)321*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& input_channel_stride(size_t input_channel_stride) { 322*4bdc9457SAndroid Build Coastguard Worker assert(input_channel_stride 
>= 1); 323*4bdc9457SAndroid Build Coastguard Worker this->input_channel_stride_ = input_channel_stride; 324*4bdc9457SAndroid Build Coastguard Worker return *this; 325*4bdc9457SAndroid Build Coastguard Worker } 326*4bdc9457SAndroid Build Coastguard Worker input_channel_stride()327*4bdc9457SAndroid Build Coastguard Worker inline size_t input_channel_stride() const { 328*4bdc9457SAndroid Build Coastguard Worker if (this->input_channel_stride_ == 0) { 329*4bdc9457SAndroid Build Coastguard Worker return group_input_channels() * groups(); 330*4bdc9457SAndroid Build Coastguard Worker } else { 331*4bdc9457SAndroid Build Coastguard Worker assert(this->input_channel_stride_ >= group_input_channels() * groups()); 332*4bdc9457SAndroid Build Coastguard Worker return this->input_channel_stride_; 333*4bdc9457SAndroid Build Coastguard Worker } 334*4bdc9457SAndroid Build Coastguard Worker } 335*4bdc9457SAndroid Build Coastguard Worker output_channel_stride(size_t output_channel_stride)336*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& output_channel_stride(size_t output_channel_stride) { 337*4bdc9457SAndroid Build Coastguard Worker assert(output_channel_stride >= 1); 338*4bdc9457SAndroid Build Coastguard Worker this->output_channel_stride_ = output_channel_stride; 339*4bdc9457SAndroid Build Coastguard Worker return *this; 340*4bdc9457SAndroid Build Coastguard Worker } 341*4bdc9457SAndroid Build Coastguard Worker output_channel_stride()342*4bdc9457SAndroid Build Coastguard Worker inline size_t output_channel_stride() const { 343*4bdc9457SAndroid Build Coastguard Worker if (this->output_channel_stride_ == 0) { 344*4bdc9457SAndroid Build Coastguard Worker return group_output_channels() * groups(); 345*4bdc9457SAndroid Build Coastguard Worker } else { 346*4bdc9457SAndroid Build Coastguard Worker assert(this->output_channel_stride_ >= group_output_channels() * groups()); 347*4bdc9457SAndroid Build Coastguard Worker return this->output_channel_stride_; 
348*4bdc9457SAndroid Build Coastguard Worker } 349*4bdc9457SAndroid Build Coastguard Worker } 350*4bdc9457SAndroid Build Coastguard Worker dilated_kernel_height()351*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilated_kernel_height() const { 352*4bdc9457SAndroid Build Coastguard Worker return (kernel_height() - 1) * dilation_height() + 1; 353*4bdc9457SAndroid Build Coastguard Worker } 354*4bdc9457SAndroid Build Coastguard Worker dilated_kernel_width()355*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilated_kernel_width() const { 356*4bdc9457SAndroid Build Coastguard Worker return (kernel_width() - 1) * dilation_width() + 1; 357*4bdc9457SAndroid Build Coastguard Worker } 358*4bdc9457SAndroid Build Coastguard Worker output_height()359*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const { 360*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 361*4bdc9457SAndroid Build Coastguard Worker return (input_height() + subsampling_height() - 1) / subsampling_height(); 362*4bdc9457SAndroid Build Coastguard Worker } else { 363*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_height = padding_top() + input_height() + padding_bottom(); 364*4bdc9457SAndroid Build Coastguard Worker if (padded_input_height <= dilated_kernel_height()) { 365*4bdc9457SAndroid Build Coastguard Worker return 1; 366*4bdc9457SAndroid Build Coastguard Worker } else { 367*4bdc9457SAndroid Build Coastguard Worker return (padded_input_height - dilated_kernel_height()) / subsampling_height() + 1; 368*4bdc9457SAndroid Build Coastguard Worker } 369*4bdc9457SAndroid Build Coastguard Worker } 370*4bdc9457SAndroid Build Coastguard Worker } 371*4bdc9457SAndroid Build Coastguard Worker output_width()372*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const { 373*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 374*4bdc9457SAndroid Build Coastguard Worker return (input_width() + subsampling_width() 
- 1) / subsampling_width(); 375*4bdc9457SAndroid Build Coastguard Worker } else { 376*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_width = padding_left() + input_width() + padding_right(); 377*4bdc9457SAndroid Build Coastguard Worker if (padded_input_width <= dilated_kernel_width()) { 378*4bdc9457SAndroid Build Coastguard Worker return 1; 379*4bdc9457SAndroid Build Coastguard Worker } else { 380*4bdc9457SAndroid Build Coastguard Worker return (padded_input_width - dilated_kernel_width()) / subsampling_width() + 1; 381*4bdc9457SAndroid Build Coastguard Worker } 382*4bdc9457SAndroid Build Coastguard Worker } 383*4bdc9457SAndroid Build Coastguard Worker } 384*4bdc9457SAndroid Build Coastguard Worker next_input_size(uint32_t next_input_height,uint32_t next_input_width)385*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) { 386*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 387*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 388*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 389*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 390*4bdc9457SAndroid Build Coastguard Worker return *this; 391*4bdc9457SAndroid Build Coastguard Worker } 392*4bdc9457SAndroid Build Coastguard Worker next_input_height(uint32_t next_input_height)393*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& next_input_height(uint32_t next_input_height) { 394*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1); 395*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height; 396*4bdc9457SAndroid Build Coastguard Worker return *this; 397*4bdc9457SAndroid Build Coastguard Worker } 398*4bdc9457SAndroid Build Coastguard Worker next_input_height()399*4bdc9457SAndroid Build Coastguard Worker inline 
uint32_t next_input_height() const { 400*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) { 401*4bdc9457SAndroid Build Coastguard Worker return input_height(); 402*4bdc9457SAndroid Build Coastguard Worker } else { 403*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_; 404*4bdc9457SAndroid Build Coastguard Worker } 405*4bdc9457SAndroid Build Coastguard Worker } 406*4bdc9457SAndroid Build Coastguard Worker next_input_width(uint32_t next_input_width)407*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& next_input_width(uint32_t next_input_width) { 408*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1); 409*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width; 410*4bdc9457SAndroid Build Coastguard Worker return *this; 411*4bdc9457SAndroid Build Coastguard Worker } 412*4bdc9457SAndroid Build Coastguard Worker next_input_width()413*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const { 414*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) { 415*4bdc9457SAndroid Build Coastguard Worker return input_width(); 416*4bdc9457SAndroid Build Coastguard Worker } else { 417*4bdc9457SAndroid Build Coastguard Worker return this->next_input_width_; 418*4bdc9457SAndroid Build Coastguard Worker } 419*4bdc9457SAndroid Build Coastguard Worker } 420*4bdc9457SAndroid Build Coastguard Worker next_output_height()421*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_height() const { 422*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_height = padding_top() + next_input_height() + padding_bottom(); 423*4bdc9457SAndroid Build Coastguard Worker if (padded_input_height <= dilated_kernel_height()) { 424*4bdc9457SAndroid Build Coastguard Worker return 1; 425*4bdc9457SAndroid Build Coastguard Worker } else { 426*4bdc9457SAndroid Build Coastguard Worker return (padded_input_height - 
dilated_kernel_height()) / subsampling_height() + 1; 427*4bdc9457SAndroid Build Coastguard Worker } 428*4bdc9457SAndroid Build Coastguard Worker } 429*4bdc9457SAndroid Build Coastguard Worker next_output_width()430*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_width() const { 431*4bdc9457SAndroid Build Coastguard Worker const size_t padded_input_width = padding_left() + next_input_width() + padding_right(); 432*4bdc9457SAndroid Build Coastguard Worker if (padded_input_width <= dilated_kernel_width()) { 433*4bdc9457SAndroid Build Coastguard Worker return 1; 434*4bdc9457SAndroid Build Coastguard Worker } else { 435*4bdc9457SAndroid Build Coastguard Worker return (padded_input_width - dilated_kernel_width()) / subsampling_width() + 1; 436*4bdc9457SAndroid Build Coastguard Worker } 437*4bdc9457SAndroid Build Coastguard Worker } 438*4bdc9457SAndroid Build Coastguard Worker next_batch_size(size_t next_batch_size)439*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& next_batch_size(size_t next_batch_size) { 440*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1); 441*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size; 442*4bdc9457SAndroid Build Coastguard Worker return *this; 443*4bdc9457SAndroid Build Coastguard Worker } 444*4bdc9457SAndroid Build Coastguard Worker next_batch_size()445*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const { 446*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) { 447*4bdc9457SAndroid Build Coastguard Worker return batch_size(); 448*4bdc9457SAndroid Build Coastguard Worker } else { 449*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_; 450*4bdc9457SAndroid Build Coastguard Worker } 451*4bdc9457SAndroid Build Coastguard Worker } 452*4bdc9457SAndroid Build Coastguard Worker sparsity(float sparsity)453*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& 
sparsity(float sparsity) { 454*4bdc9457SAndroid Build Coastguard Worker this->sparsity_ = sparsity; 455*4bdc9457SAndroid Build Coastguard Worker return *this; 456*4bdc9457SAndroid Build Coastguard Worker } 457*4bdc9457SAndroid Build Coastguard Worker sparsity()458*4bdc9457SAndroid Build Coastguard Worker inline float sparsity() const { 459*4bdc9457SAndroid Build Coastguard Worker return this->sparsity_; 460*4bdc9457SAndroid Build Coastguard Worker } 461*4bdc9457SAndroid Build Coastguard Worker qmin(uint8_t qmin)462*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& qmin(uint8_t qmin) { 463*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin; 464*4bdc9457SAndroid Build Coastguard Worker return *this; 465*4bdc9457SAndroid Build Coastguard Worker } 466*4bdc9457SAndroid Build Coastguard Worker qmin()467*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const { 468*4bdc9457SAndroid Build Coastguard Worker return this->qmin_; 469*4bdc9457SAndroid Build Coastguard Worker } 470*4bdc9457SAndroid Build Coastguard Worker qmax(uint8_t qmax)471*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& qmax(uint8_t qmax) { 472*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax; 473*4bdc9457SAndroid Build Coastguard Worker return *this; 474*4bdc9457SAndroid Build Coastguard Worker } 475*4bdc9457SAndroid Build Coastguard Worker qmax()476*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const { 477*4bdc9457SAndroid Build Coastguard Worker return this->qmax_; 478*4bdc9457SAndroid Build Coastguard Worker } 479*4bdc9457SAndroid Build Coastguard Worker force_nhwc_input(bool force_nhwc_input)480*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& force_nhwc_input(bool force_nhwc_input) { 481*4bdc9457SAndroid Build Coastguard Worker this->force_nhwc_input_ = force_nhwc_input; 482*4bdc9457SAndroid Build Coastguard Worker return *this; 483*4bdc9457SAndroid Build Coastguard Worker } 
484*4bdc9457SAndroid Build Coastguard Worker force_nhwc_input()485*4bdc9457SAndroid Build Coastguard Worker inline bool force_nhwc_input() const { 486*4bdc9457SAndroid Build Coastguard Worker return this->force_nhwc_input_; 487*4bdc9457SAndroid Build Coastguard Worker } 488*4bdc9457SAndroid Build Coastguard Worker depthwise_layout(bool depthwise_layout)489*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& depthwise_layout(bool depthwise_layout) { 490*4bdc9457SAndroid Build Coastguard Worker this->depthwise_layout_ = depthwise_layout; 491*4bdc9457SAndroid Build Coastguard Worker return *this; 492*4bdc9457SAndroid Build Coastguard Worker } 493*4bdc9457SAndroid Build Coastguard Worker depthwise_layout()494*4bdc9457SAndroid Build Coastguard Worker inline bool depthwise_layout() const { 495*4bdc9457SAndroid Build Coastguard Worker return this->depthwise_layout_; 496*4bdc9457SAndroid Build Coastguard Worker } 497*4bdc9457SAndroid Build Coastguard Worker has_bias(bool has_bias)498*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& has_bias(bool has_bias) { 499*4bdc9457SAndroid Build Coastguard Worker this->has_bias_ = has_bias; 500*4bdc9457SAndroid Build Coastguard Worker return *this; 501*4bdc9457SAndroid Build Coastguard Worker } 502*4bdc9457SAndroid Build Coastguard Worker has_bias()503*4bdc9457SAndroid Build Coastguard Worker inline bool has_bias() const { 504*4bdc9457SAndroid Build Coastguard Worker return this->has_bias_; 505*4bdc9457SAndroid Build Coastguard Worker } 506*4bdc9457SAndroid Build Coastguard Worker weights_type(WeightsType weights_type)507*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& weights_type(WeightsType weights_type) { 508*4bdc9457SAndroid Build Coastguard Worker this->weights_type_ = weights_type; 509*4bdc9457SAndroid Build Coastguard Worker return *this; 510*4bdc9457SAndroid Build Coastguard Worker } 511*4bdc9457SAndroid Build Coastguard Worker 
weights_type()512*4bdc9457SAndroid Build Coastguard Worker inline WeightsType weights_type() const { 513*4bdc9457SAndroid Build Coastguard Worker return this->weights_type_; 514*4bdc9457SAndroid Build Coastguard Worker } 515*4bdc9457SAndroid Build Coastguard Worker iterations(size_t iterations)516*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& iterations(size_t iterations) { 517*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations; 518*4bdc9457SAndroid Build Coastguard Worker return *this; 519*4bdc9457SAndroid Build Coastguard Worker } 520*4bdc9457SAndroid Build Coastguard Worker iterations()521*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const { 522*4bdc9457SAndroid Build Coastguard Worker return this->iterations_; 523*4bdc9457SAndroid Build Coastguard Worker } 524*4bdc9457SAndroid Build Coastguard Worker 525*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT use_jit(bool use_jit)526*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& use_jit(bool use_jit) { 527*4bdc9457SAndroid Build Coastguard Worker this->use_jit_ = use_jit; 528*4bdc9457SAndroid Build Coastguard Worker return *this; 529*4bdc9457SAndroid Build Coastguard Worker } 530*4bdc9457SAndroid Build Coastguard Worker use_jit()531*4bdc9457SAndroid Build Coastguard Worker inline bool use_jit() const { 532*4bdc9457SAndroid Build Coastguard Worker return this->use_jit_; 533*4bdc9457SAndroid Build Coastguard Worker } 534*4bdc9457SAndroid Build Coastguard Worker #endif 535*4bdc9457SAndroid Build Coastguard Worker use_weights_cache(bool use_weights_cache)536*4bdc9457SAndroid Build Coastguard Worker inline ConvolutionOperatorTester& use_weights_cache(bool use_weights_cache) { 537*4bdc9457SAndroid Build Coastguard Worker this->use_weights_cache_ = use_weights_cache; 538*4bdc9457SAndroid Build Coastguard Worker return *this; 539*4bdc9457SAndroid Build Coastguard Worker } 540*4bdc9457SAndroid Build Coastguard 
Worker use_weights_cache()541*4bdc9457SAndroid Build Coastguard Worker inline bool use_weights_cache() const { 542*4bdc9457SAndroid Build Coastguard Worker return this->use_weights_cache_; 543*4bdc9457SAndroid Build Coastguard Worker } 544*4bdc9457SAndroid Build Coastguard Worker TestNHWCxQC8()545*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxQC8() const { 546*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 547*4bdc9457SAndroid Build Coastguard Worker 548*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 549*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 550*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 551*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 552*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 553*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist( 554*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()); 555*4bdc9457SAndroid Build Coastguard Worker 556*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 557*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels())); 558*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 559*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels()); 560*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels())); 
561*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 562*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 563*4bdc9457SAndroid Build Coastguard Worker std::vector<float> requantization_scales(groups() * group_output_channels()); 564*4bdc9457SAndroid Build Coastguard Worker 565*4bdc9457SAndroid Build Coastguard Worker const int8_t input_zero_point = -1; 566*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point = -1; 567*4bdc9457SAndroid Build Coastguard Worker 568*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 569*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 570*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); 571*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 572*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 573*4bdc9457SAndroid Build Coastguard Worker 574*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 
575*4bdc9457SAndroid Build Coastguard Worker if (depthwise_layout()) { 576*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(group_input_channels(), 1); 577*4bdc9457SAndroid Build Coastguard Worker xnnpack::compute_depthwise_convolution_qs8_reference_results( 578*4bdc9457SAndroid Build Coastguard Worker batch_size(), 579*4bdc9457SAndroid Build Coastguard Worker output_height(), 580*4bdc9457SAndroid Build Coastguard Worker output_width(), 581*4bdc9457SAndroid Build Coastguard Worker input_height(), 582*4bdc9457SAndroid Build Coastguard Worker input_width(), 583*4bdc9457SAndroid Build Coastguard Worker padding_top(), 584*4bdc9457SAndroid Build Coastguard Worker padding_right(), 585*4bdc9457SAndroid Build Coastguard Worker padding_bottom(), 586*4bdc9457SAndroid Build Coastguard Worker padding_left(), 587*4bdc9457SAndroid Build Coastguard Worker kernel_height(), 588*4bdc9457SAndroid Build Coastguard Worker kernel_width(), 589*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), 590*4bdc9457SAndroid Build Coastguard Worker subsampling_width(), 591*4bdc9457SAndroid Build Coastguard Worker dilation_height(), 592*4bdc9457SAndroid Build Coastguard Worker dilation_width(), 593*4bdc9457SAndroid Build Coastguard Worker groups(), 594*4bdc9457SAndroid Build Coastguard Worker group_output_channels(), 595*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), 596*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 597*4bdc9457SAndroid Build Coastguard Worker input, 598*4bdc9457SAndroid Build Coastguard Worker kernel, 599*4bdc9457SAndroid Build Coastguard Worker accumulators, 600*4bdc9457SAndroid Build Coastguard Worker has_bias(), 601*4bdc9457SAndroid Build Coastguard Worker bias); 602*4bdc9457SAndroid Build Coastguard Worker } else { 603*4bdc9457SAndroid Build Coastguard Worker xnnpack::compute_convolution_qs8_reference_results( 604*4bdc9457SAndroid Build Coastguard Worker batch_size(), 605*4bdc9457SAndroid Build Coastguard Worker output_height(), 
606*4bdc9457SAndroid Build Coastguard Worker output_width(), 607*4bdc9457SAndroid Build Coastguard Worker input_height(), 608*4bdc9457SAndroid Build Coastguard Worker input_width(), 609*4bdc9457SAndroid Build Coastguard Worker padding_top(), 610*4bdc9457SAndroid Build Coastguard Worker padding_right(), 611*4bdc9457SAndroid Build Coastguard Worker padding_bottom(), 612*4bdc9457SAndroid Build Coastguard Worker padding_left(), 613*4bdc9457SAndroid Build Coastguard Worker kernel_height(), 614*4bdc9457SAndroid Build Coastguard Worker kernel_width(), 615*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), 616*4bdc9457SAndroid Build Coastguard Worker subsampling_width(), 617*4bdc9457SAndroid Build Coastguard Worker dilation_height(), 618*4bdc9457SAndroid Build Coastguard Worker dilation_width(), 619*4bdc9457SAndroid Build Coastguard Worker groups(), 620*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), 621*4bdc9457SAndroid Build Coastguard Worker group_output_channels(), 622*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), 623*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 624*4bdc9457SAndroid Build Coastguard Worker input, 625*4bdc9457SAndroid Build Coastguard Worker kernel, 626*4bdc9457SAndroid Build Coastguard Worker accumulators, 627*4bdc9457SAndroid Build Coastguard Worker has_bias(), 628*4bdc9457SAndroid Build Coastguard Worker bias); 629*4bdc9457SAndroid Build Coastguard Worker } 630*4bdc9457SAndroid Build Coastguard Worker 631*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 
632*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < groups() * group_output_channels(); c++) { 633*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_min = accumulators[c]; 634*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_max = accumulators[c]; 635*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < batch_size() * output_height() * output_width(); px++) { 636*4bdc9457SAndroid Build Coastguard Worker accumulated_min = std::min(accumulated_min, accumulators[px * groups() * group_output_channels() + c]); 637*4bdc9457SAndroid Build Coastguard Worker accumulated_max = std::max(accumulated_max, accumulators[px * groups() * group_output_channels() + c]); 638*4bdc9457SAndroid Build Coastguard Worker } 639*4bdc9457SAndroid Build Coastguard Worker 640*4bdc9457SAndroid Build Coastguard Worker float requantization_scale = 0x1.0p-32f; 641*4bdc9457SAndroid Build Coastguard Worker if (accumulated_max != 0) { 642*4bdc9457SAndroid Build Coastguard Worker requantization_scale = std::max(requantization_scale, 643*4bdc9457SAndroid Build Coastguard Worker float(int32_t(std::numeric_limits<int8_t>::max()) - int32_t(output_zero_point)) / float(accumulated_max)); 644*4bdc9457SAndroid Build Coastguard Worker } 645*4bdc9457SAndroid Build Coastguard Worker if (accumulated_min != 0) { 646*4bdc9457SAndroid Build Coastguard Worker requantization_scale = std::max(requantization_scale, 647*4bdc9457SAndroid Build Coastguard Worker float(int32_t(std::numeric_limits<int8_t>::min()) - int32_t(output_zero_point)) / float(accumulated_min)); 648*4bdc9457SAndroid Build Coastguard Worker } 649*4bdc9457SAndroid Build Coastguard Worker requantization_scale = std::min(requantization_scale, 0x1.FFFFFEp-1f); 650*4bdc9457SAndroid Build Coastguard Worker 651*4bdc9457SAndroid Build Coastguard Worker requantization_scales[c] = requantization_scale; 652*4bdc9457SAndroid Build Coastguard Worker } 653*4bdc9457SAndroid Build Coastguard Worker 654*4bdc9457SAndroid 
Build Coastguard Worker // Renormalize reference results. 655*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < groups() * group_output_channels(); c++) { 656*4bdc9457SAndroid Build Coastguard Worker for (size_t px = 0; px < batch_size() * output_height() * output_width(); px++) { 657*4bdc9457SAndroid Build Coastguard Worker output_ref[px * groups() * group_output_channels() + c] = double(int32_t(output_zero_point)) + 658*4bdc9457SAndroid Build Coastguard Worker double(accumulators[px * groups() * group_output_channels() + c]) * double(requantization_scales[c]); 659*4bdc9457SAndroid Build Coastguard Worker } 660*4bdc9457SAndroid Build Coastguard Worker } 661*4bdc9457SAndroid Build Coastguard Worker std::transform(output_ref.cbegin(), output_ref.cend(), output_ref.begin(), 662*4bdc9457SAndroid Build Coastguard Worker [this](double x) -> double { 663*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(x, double(qmax() - 0x80)), double(qmin() - 0x80)); 664*4bdc9457SAndroid Build Coastguard Worker }); 665*4bdc9457SAndroid Build Coastguard Worker 666*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convolution operator. 
667*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 668*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 669*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 670*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 671*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 672*4bdc9457SAndroid Build Coastguard Worker }; 673*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 674*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 675*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 676*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 677*4bdc9457SAndroid Build Coastguard Worker } 678*4bdc9457SAndroid Build Coastguard Worker 679*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_qc8( 680*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 681*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 682*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 683*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 684*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 685*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 686*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 687*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, requantization_scales.data(), 688*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? 
bias.data() : nullptr, 689*4bdc9457SAndroid Build Coastguard Worker output_zero_point, 1.0f /* output scale */, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 690*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 691*4bdc9457SAndroid Build Coastguard Worker &caches, 692*4bdc9457SAndroid Build Coastguard Worker &convolution_op); 693*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 694*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 695*4bdc9457SAndroid Build Coastguard Worker } 696*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 697*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 698*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 699*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 700*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 701*4bdc9457SAndroid Build Coastguard Worker } 702*4bdc9457SAndroid Build Coastguard Worker 703*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
704*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 705*4bdc9457SAndroid Build Coastguard Worker 706*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 707*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qc8( 708*4bdc9457SAndroid Build Coastguard Worker convolution_op, 709*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 710*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 711*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 712*4bdc9457SAndroid Build Coastguard Worker 713*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 714*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 715*4bdc9457SAndroid Build Coastguard Worker 716*4bdc9457SAndroid Build Coastguard Worker // Verify results. 717*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQC8(output, output_ref); 718*4bdc9457SAndroid Build Coastguard Worker 719*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 720*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op2 = nullptr; 721*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 722*4bdc9457SAndroid Build Coastguard Worker 723*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_qc8( 724*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 725*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 
0 : padding_left(), 726*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 727*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 728*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 729*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 730*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 731*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, requantization_scales.data(), 732*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 733*4bdc9457SAndroid Build Coastguard Worker output_zero_point, 1.0f /* output scale */, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 734*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 735*4bdc9457SAndroid Build Coastguard Worker &caches, 736*4bdc9457SAndroid Build Coastguard Worker &convolution_op2); 737*4bdc9457SAndroid Build Coastguard Worker (void) status; 738*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op2); 739*4bdc9457SAndroid Build Coastguard Worker 740*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op2. 
741*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op2, xnn_delete_operator); 742*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output2(output.size(), INT8_C(0xA5)); 743*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 744*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qc8( 745*4bdc9457SAndroid Build Coastguard Worker convolution_op2, 746*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 747*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 748*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 749*4bdc9457SAndroid Build Coastguard Worker 750*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 751*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op2, nullptr /* thread pool */)); 752*4bdc9457SAndroid Build Coastguard Worker 753*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQC8(output2, output_ref); 754*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 755*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 756*4bdc9457SAndroid Build Coastguard Worker } 757*4bdc9457SAndroid Build Coastguard Worker } 758*4bdc9457SAndroid Build Coastguard Worker } 759*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQC8(const std::vector<int8_t> & output,const std::vector<double> & output_ref)760*4bdc9457SAndroid Build Coastguard Worker void VerifyNHWCxQC8(const std::vector<int8_t> &output, 761*4bdc9457SAndroid Build Coastguard Worker const std::vector<double> &output_ref) const { 762*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 763*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 764*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < 
output_width(); x++) { 765*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 766*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 767*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80)) 768*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 769*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80)) 770*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 771*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 772*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 773*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), 774*4bdc9457SAndroid Build Coastguard Worker 0.9) 775*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 776*4bdc9457SAndroid Build Coastguard Worker } 777*4bdc9457SAndroid Build Coastguard Worker } 778*4bdc9457SAndroid Build Coastguard Worker } 779*4bdc9457SAndroid Build Coastguard Worker } 780*4bdc9457SAndroid Build Coastguard Worker } 781*4bdc9457SAndroid Build Coastguard Worker } 782*4bdc9457SAndroid Build Coastguard Worker TestNHWCxQS8()783*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxQS8() const { 784*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 785*4bdc9457SAndroid Build Coastguard Worker 
786*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 787*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 788*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 789*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 790*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 791*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist( 792*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()); 793*4bdc9457SAndroid Build Coastguard Worker 794*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + 795*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels())); 796*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 797*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels()); 798*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels())); 799*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 800*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 801*4bdc9457SAndroid Build Coastguard Worker 802*4bdc9457SAndroid Build Coastguard Worker const int8_t input_zero_point = -1; 803*4bdc9457SAndroid Build Coastguard Worker 804*4bdc9457SAndroid 
Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 805*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 806*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); 807*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 808*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 809*4bdc9457SAndroid Build Coastguard Worker 810*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 811*4bdc9457SAndroid Build Coastguard Worker if (depthwise_layout()) { 812*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(group_input_channels(), 1); 813*4bdc9457SAndroid Build Coastguard Worker xnnpack::compute_depthwise_convolution_qs8_reference_results( 814*4bdc9457SAndroid Build Coastguard Worker batch_size(), 815*4bdc9457SAndroid Build Coastguard Worker output_height(), 816*4bdc9457SAndroid Build Coastguard Worker output_width(), 817*4bdc9457SAndroid Build Coastguard Worker input_height(), 818*4bdc9457SAndroid Build Coastguard Worker input_width(), 819*4bdc9457SAndroid Build Coastguard Worker padding_top(), 820*4bdc9457SAndroid Build Coastguard Worker padding_right(), 821*4bdc9457SAndroid Build Coastguard Worker padding_bottom(), 822*4bdc9457SAndroid Build Coastguard Worker padding_left(), 823*4bdc9457SAndroid Build Coastguard Worker kernel_height(), 824*4bdc9457SAndroid Build Coastguard Worker kernel_width(), 825*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), 826*4bdc9457SAndroid Build Coastguard Worker subsampling_width(), 827*4bdc9457SAndroid Build Coastguard Worker dilation_height(), 828*4bdc9457SAndroid Build Coastguard Worker dilation_width(), 829*4bdc9457SAndroid Build Coastguard Worker groups(), 830*4bdc9457SAndroid Build Coastguard Worker 
group_output_channels(), 831*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), 832*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 833*4bdc9457SAndroid Build Coastguard Worker input, 834*4bdc9457SAndroid Build Coastguard Worker kernel, 835*4bdc9457SAndroid Build Coastguard Worker accumulators, 836*4bdc9457SAndroid Build Coastguard Worker has_bias(), 837*4bdc9457SAndroid Build Coastguard Worker bias); 838*4bdc9457SAndroid Build Coastguard Worker } else { 839*4bdc9457SAndroid Build Coastguard Worker xnnpack::compute_convolution_qs8_reference_results( 840*4bdc9457SAndroid Build Coastguard Worker batch_size(), 841*4bdc9457SAndroid Build Coastguard Worker output_height(), 842*4bdc9457SAndroid Build Coastguard Worker output_width(), 843*4bdc9457SAndroid Build Coastguard Worker input_height(), 844*4bdc9457SAndroid Build Coastguard Worker input_width(), 845*4bdc9457SAndroid Build Coastguard Worker padding_top(), 846*4bdc9457SAndroid Build Coastguard Worker padding_right(), 847*4bdc9457SAndroid Build Coastguard Worker padding_bottom(), 848*4bdc9457SAndroid Build Coastguard Worker padding_left(), 849*4bdc9457SAndroid Build Coastguard Worker kernel_height(), 850*4bdc9457SAndroid Build Coastguard Worker kernel_width(), 851*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), 852*4bdc9457SAndroid Build Coastguard Worker subsampling_width(), 853*4bdc9457SAndroid Build Coastguard Worker dilation_height(), 854*4bdc9457SAndroid Build Coastguard Worker dilation_width(), 855*4bdc9457SAndroid Build Coastguard Worker groups(), 856*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), 857*4bdc9457SAndroid Build Coastguard Worker group_output_channels(), 858*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), 859*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 860*4bdc9457SAndroid Build Coastguard Worker input, 861*4bdc9457SAndroid Build Coastguard Worker kernel, 862*4bdc9457SAndroid Build Coastguard Worker 
accumulators, 863*4bdc9457SAndroid Build Coastguard Worker has_bias(), 864*4bdc9457SAndroid Build Coastguard Worker bias); 865*4bdc9457SAndroid Build Coastguard Worker } 866*4bdc9457SAndroid Build Coastguard Worker 867*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 868*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); 869*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); 870*4bdc9457SAndroid Build Coastguard Worker 871*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; 872*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point = int8_t(std::max(std::min( 873*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), 874*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min()))); 875*4bdc9457SAndroid Build Coastguard Worker 876*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results. 877*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(), 878*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 879*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point); 880*4bdc9457SAndroid Build Coastguard Worker }); 881*4bdc9457SAndroid Build Coastguard Worker 882*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convolution operator. 
883*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 884*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 885*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 886*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 887*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 888*4bdc9457SAndroid Build Coastguard Worker }; 889*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 890*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 891*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 892*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 893*4bdc9457SAndroid Build Coastguard Worker } 894*4bdc9457SAndroid Build Coastguard Worker 895*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_qs8( 896*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 897*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 898*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 899*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 900*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 901*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 902*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 903*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 1.0f /* kernel scale */, 904*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? 
bias.data() : nullptr, 905*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 906*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 907*4bdc9457SAndroid Build Coastguard Worker &caches, 908*4bdc9457SAndroid Build Coastguard Worker &convolution_op); 909*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 910*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 911*4bdc9457SAndroid Build Coastguard Worker } 912*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 913*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 914*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 915*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 916*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 917*4bdc9457SAndroid Build Coastguard Worker } 918*4bdc9457SAndroid Build Coastguard Worker 919*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
920*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 921*4bdc9457SAndroid Build Coastguard Worker 922*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 923*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qs8( 924*4bdc9457SAndroid Build Coastguard Worker convolution_op, 925*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 926*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 927*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 928*4bdc9457SAndroid Build Coastguard Worker 929*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 930*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 931*4bdc9457SAndroid Build Coastguard Worker 932*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQS8(output, output_ref, output_zero_point); 933*4bdc9457SAndroid Build Coastguard Worker 934*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 935*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op2 = nullptr; 936*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 937*4bdc9457SAndroid Build Coastguard Worker 938*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 939*4bdc9457SAndroid Build Coastguard Worker xnn_status_success, 940*4bdc9457SAndroid Build Coastguard Worker xnn_create_convolution2d_nhwc_qs8( 941*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), 942*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_right(), 943*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), 944*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 
0 : padding_left(), kernel_height(), 945*4bdc9457SAndroid Build Coastguard Worker kernel_width(), subsampling_height(), subsampling_width(), 946*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(), 947*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(), 948*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 949*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 950*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */, kernel.data(), 951*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point, 952*4bdc9457SAndroid Build Coastguard Worker output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 953*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | 954*4bdc9457SAndroid Build Coastguard Worker (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 955*4bdc9457SAndroid Build Coastguard Worker &caches, &convolution_op2)); 956*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op2); 957*4bdc9457SAndroid Build Coastguard Worker 958*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
959*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> 960*4bdc9457SAndroid Build Coastguard Worker auto_convolution_op(convolution_op2, xnn_delete_operator); 961*4bdc9457SAndroid Build Coastguard Worker 962*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output2(output.size(), INT8_C(0xA5)); 963*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 964*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qs8( 965*4bdc9457SAndroid Build Coastguard Worker convolution_op2, batch_size(), input_height(), 966*4bdc9457SAndroid Build Coastguard Worker input_width(), input.data(), output2.data(), 967*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 968*4bdc9457SAndroid Build Coastguard Worker 969*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 970*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op2, nullptr /* thread pool */)); 971*4bdc9457SAndroid Build Coastguard Worker 972*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQS8(output2, output_ref, output_zero_point); 973*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 974*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 975*4bdc9457SAndroid Build Coastguard Worker } 976*4bdc9457SAndroid Build Coastguard Worker } 977*4bdc9457SAndroid Build Coastguard Worker } 978*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQS8(const std::vector<int8_t> & output,const std::vector<double> & output_ref,const int8_t output_zero_point)979*4bdc9457SAndroid Build Coastguard Worker void VerifyNHWCxQS8(const std::vector<int8_t> &output, 980*4bdc9457SAndroid Build Coastguard Worker const std::vector<double> &output_ref, 981*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point) const { 982*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 
983*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 984*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 985*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 986*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 987*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80)) 988*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 989*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80)) 990*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 991*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 992*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 993*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point), 994*4bdc9457SAndroid Build Coastguard Worker 0.9) 995*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 996*4bdc9457SAndroid Build Coastguard Worker } 997*4bdc9457SAndroid Build Coastguard Worker } 998*4bdc9457SAndroid Build Coastguard Worker } 999*4bdc9457SAndroid Build Coastguard Worker } 1000*4bdc9457SAndroid Build Coastguard Worker } 1001*4bdc9457SAndroid Build Coastguard Worker } 1002*4bdc9457SAndroid Build Coastguard Worker TestNHWCxQU8()1003*4bdc9457SAndroid Build Coastguard 
Worker void TestNHWCxQU8() const { 1004*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 1005*4bdc9457SAndroid Build Coastguard Worker 1006*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1007*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1008*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 1009*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 1010*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 1011*4bdc9457SAndroid Build Coastguard Worker 1012*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + 1013*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels())); 1014*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 1015*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels()); 1016*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels())); 1017*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 1018*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 1019*4bdc9457SAndroid Build Coastguard Worker 1020*4bdc9457SAndroid Build Coastguard Worker const uint8_t input_zero_point = 127; 1021*4bdc9457SAndroid Build Coastguard Worker const uint8_t kernel_zero_point 
= 127; 1022*4bdc9457SAndroid Build Coastguard Worker 1023*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1024*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 1025*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); }); 1026*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 1027*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 1028*4bdc9457SAndroid Build Coastguard Worker 1029*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 1030*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 1031*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1032*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1033*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1034*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1035*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1036*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 1037*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 1038*4bdc9457SAndroid Build Coastguard Worker } 1039*4bdc9457SAndroid Build Coastguard Worker } 1040*4bdc9457SAndroid Build Coastguard Worker } 1041*4bdc9457SAndroid Build Coastguard Worker } 1042*4bdc9457SAndroid Build Coastguard Worker } 1043*4bdc9457SAndroid Build Coastguard Worker } else { 1044*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0); 1045*4bdc9457SAndroid Build Coastguard 
Worker } 1046*4bdc9457SAndroid Build Coastguard Worker if (depthwise_layout()) { 1047*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(group_input_channels(), 1); 1048*4bdc9457SAndroid Build Coastguard Worker xnnpack::compute_depthwise_convolution_qu8_reference_results( 1049*4bdc9457SAndroid Build Coastguard Worker batch_size(), 1050*4bdc9457SAndroid Build Coastguard Worker output_height(), 1051*4bdc9457SAndroid Build Coastguard Worker output_width(), 1052*4bdc9457SAndroid Build Coastguard Worker input_height(), 1053*4bdc9457SAndroid Build Coastguard Worker input_width(), 1054*4bdc9457SAndroid Build Coastguard Worker padding_top(), 1055*4bdc9457SAndroid Build Coastguard Worker padding_right(), 1056*4bdc9457SAndroid Build Coastguard Worker padding_bottom(), 1057*4bdc9457SAndroid Build Coastguard Worker padding_left(), 1058*4bdc9457SAndroid Build Coastguard Worker kernel_height(), 1059*4bdc9457SAndroid Build Coastguard Worker kernel_width(), 1060*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), 1061*4bdc9457SAndroid Build Coastguard Worker subsampling_width(), 1062*4bdc9457SAndroid Build Coastguard Worker dilation_height(), 1063*4bdc9457SAndroid Build Coastguard Worker dilation_width(), 1064*4bdc9457SAndroid Build Coastguard Worker groups(), 1065*4bdc9457SAndroid Build Coastguard Worker group_output_channels(), 1066*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), 1067*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1068*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1069*4bdc9457SAndroid Build Coastguard Worker input, 1070*4bdc9457SAndroid Build Coastguard Worker kernel, 1071*4bdc9457SAndroid Build Coastguard Worker accumulators, 1072*4bdc9457SAndroid Build Coastguard Worker has_bias(), 1073*4bdc9457SAndroid Build Coastguard Worker bias); 1074*4bdc9457SAndroid Build Coastguard Worker } else { 1075*4bdc9457SAndroid Build Coastguard Worker xnnpack::compute_convolution_qu8_reference_results( 1076*4bdc9457SAndroid 
Build Coastguard Worker batch_size(), 1077*4bdc9457SAndroid Build Coastguard Worker output_height(), 1078*4bdc9457SAndroid Build Coastguard Worker output_width(), 1079*4bdc9457SAndroid Build Coastguard Worker input_height(), 1080*4bdc9457SAndroid Build Coastguard Worker input_width(), 1081*4bdc9457SAndroid Build Coastguard Worker padding_top(), 1082*4bdc9457SAndroid Build Coastguard Worker padding_right(), 1083*4bdc9457SAndroid Build Coastguard Worker padding_bottom(), 1084*4bdc9457SAndroid Build Coastguard Worker padding_left(), 1085*4bdc9457SAndroid Build Coastguard Worker kernel_height(), 1086*4bdc9457SAndroid Build Coastguard Worker kernel_width(), 1087*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), 1088*4bdc9457SAndroid Build Coastguard Worker subsampling_width(), 1089*4bdc9457SAndroid Build Coastguard Worker dilation_height(), 1090*4bdc9457SAndroid Build Coastguard Worker dilation_width(), 1091*4bdc9457SAndroid Build Coastguard Worker groups(), 1092*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), 1093*4bdc9457SAndroid Build Coastguard Worker group_output_channels(), 1094*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), 1095*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1096*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1097*4bdc9457SAndroid Build Coastguard Worker input, 1098*4bdc9457SAndroid Build Coastguard Worker kernel, 1099*4bdc9457SAndroid Build Coastguard Worker accumulators, 1100*4bdc9457SAndroid Build Coastguard Worker has_bias(), 1101*4bdc9457SAndroid Build Coastguard Worker bias); 1102*4bdc9457SAndroid Build Coastguard Worker } 1103*4bdc9457SAndroid Build Coastguard Worker 1104*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 
1105*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); 1106*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); 1107*4bdc9457SAndroid Build Coastguard Worker 1108*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; 1109*4bdc9457SAndroid Build Coastguard Worker const uint8_t output_zero_point = uint8_t(std::max(std::min( 1110*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), 1111*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min()))); 1112*4bdc9457SAndroid Build Coastguard Worker 1113*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results. 1114*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(), 1115*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 1116*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point); 1117*4bdc9457SAndroid Build Coastguard Worker }); 1118*4bdc9457SAndroid Build Coastguard Worker 1119*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convolution operator. 
1120*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1121*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 1122*4bdc9457SAndroid Build Coastguard Worker 1123*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 1124*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 1125*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 1126*4bdc9457SAndroid Build Coastguard Worker }; 1127*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 1128*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1129*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 1130*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 1131*4bdc9457SAndroid Build Coastguard Worker } 1132*4bdc9457SAndroid Build Coastguard Worker 1133*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_qu8( 1134*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 1135*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 1136*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 1137*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 1138*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1139*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 1140*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1141*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 1142*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1.0f /* kernel scale */, 1143*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? 
bias.data() : nullptr, 1144*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, qmin(), qmax(), 1145*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 1146*4bdc9457SAndroid Build Coastguard Worker &caches, 1147*4bdc9457SAndroid Build Coastguard Worker &convolution_op); 1148*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 1149*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 1150*4bdc9457SAndroid Build Coastguard Worker } 1151*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 1152*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 1153*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1154*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1155*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 1156*4bdc9457SAndroid Build Coastguard Worker } 1157*4bdc9457SAndroid Build Coastguard Worker 1158*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
1159*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 1160*4bdc9457SAndroid Build Coastguard Worker 1161*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1162*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qu8( 1163*4bdc9457SAndroid Build Coastguard Worker convolution_op, 1164*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1165*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1166*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1167*4bdc9457SAndroid Build Coastguard Worker 1168*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1169*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 1170*4bdc9457SAndroid Build Coastguard Worker 1171*4bdc9457SAndroid Build Coastguard Worker // Verify results. 1172*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQU8(output, output_ref, output_zero_point); 1173*4bdc9457SAndroid Build Coastguard Worker 1174*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1175*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op2 = nullptr; 1176*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 1177*4bdc9457SAndroid Build Coastguard Worker 1178*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 1179*4bdc9457SAndroid Build Coastguard Worker xnn_status_success, 1180*4bdc9457SAndroid Build Coastguard Worker xnn_create_convolution2d_nhwc_qu8( 1181*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), 1182*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_right(), 1183*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 
0 : padding_bottom(), 1184*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_left(), kernel_height(), 1185*4bdc9457SAndroid Build Coastguard Worker kernel_width(), subsampling_height(), subsampling_width(), 1186*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(), 1187*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(), 1188*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1189*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, kernel_zero_point, 1190*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */, kernel.data(), 1191*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point, 1192*4bdc9457SAndroid Build Coastguard Worker output_scale, qmin(), qmax(), 1193*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | 1194*4bdc9457SAndroid Build Coastguard Worker (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 1195*4bdc9457SAndroid Build Coastguard Worker &caches, &convolution_op2)); 1196*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op2); 1197*4bdc9457SAndroid Build Coastguard Worker 1198*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op2. 
1199*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> 1200*4bdc9457SAndroid Build Coastguard Worker auto_convolution_op2(convolution_op2, xnn_delete_operator); 1201*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output2(output.size(), UINT8_C(0xA5)); 1202*4bdc9457SAndroid Build Coastguard Worker 1203*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1204*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qu8( 1205*4bdc9457SAndroid Build Coastguard Worker convolution_op2, batch_size(), input_height(), 1206*4bdc9457SAndroid Build Coastguard Worker input_width(), input.data(), output2.data(), 1207*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1208*4bdc9457SAndroid Build Coastguard Worker 1209*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1210*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op2, nullptr /* thread pool */)); 1211*4bdc9457SAndroid Build Coastguard Worker 1212*4bdc9457SAndroid Build Coastguard Worker // Verify results. 
1213*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQU8(output2, output_ref, output_zero_point); 1214*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 1215*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 1216*4bdc9457SAndroid Build Coastguard Worker } 1217*4bdc9457SAndroid Build Coastguard Worker } 1218*4bdc9457SAndroid Build Coastguard Worker } 1219*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxQU8(const std::vector<uint8_t> & output,const std::vector<double> & output_ref,const uint8_t output_zero_point)1220*4bdc9457SAndroid Build Coastguard Worker void VerifyNHWCxQU8(const std::vector<uint8_t> &output, 1221*4bdc9457SAndroid Build Coastguard Worker const std::vector<double> &output_ref, 1222*4bdc9457SAndroid Build Coastguard Worker const uint8_t output_zero_point) const { 1223*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1224*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1225*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1226*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1227*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 1228*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax())) 1229*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1230*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin())) 1231*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 
1232*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 1233*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 1234*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point), 1235*4bdc9457SAndroid Build Coastguard Worker 0.9) 1236*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1237*4bdc9457SAndroid Build Coastguard Worker } 1238*4bdc9457SAndroid Build Coastguard Worker } 1239*4bdc9457SAndroid Build Coastguard Worker } 1240*4bdc9457SAndroid Build Coastguard Worker } 1241*4bdc9457SAndroid Build Coastguard Worker } 1242*4bdc9457SAndroid Build Coastguard Worker } 1243*4bdc9457SAndroid Build Coastguard Worker TestNHWCxF32()1244*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxF32() const { 1245*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 1246*4bdc9457SAndroid Build Coastguard Worker 1247*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1248*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1249*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 1250*4bdc9457SAndroid Build Coastguard Worker 1251*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + 1252*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels())); 1253*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 1254*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(groups() * 
group_output_channels()); 1255*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels())); 1256*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 1257*4bdc9457SAndroid Build Coastguard Worker 1258*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1259*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 1260*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); 1261*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); 1262*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 1263*4bdc9457SAndroid Build Coastguard Worker 1264*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 
1265*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 1266*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1267*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1268*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1269*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1270*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1271*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 1272*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 1273*4bdc9457SAndroid Build Coastguard Worker } 1274*4bdc9457SAndroid Build Coastguard Worker } 1275*4bdc9457SAndroid Build Coastguard Worker } 1276*4bdc9457SAndroid Build Coastguard Worker } 1277*4bdc9457SAndroid Build Coastguard Worker } 1278*4bdc9457SAndroid Build Coastguard Worker } else { 1279*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 1280*4bdc9457SAndroid Build Coastguard Worker } 1281*4bdc9457SAndroid Build Coastguard Worker if (depthwise_layout()) { 1282*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(group_input_channels(), 1); 1283*4bdc9457SAndroid Build Coastguard Worker 1284*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1285*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1286*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1287*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1288*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 1289*4bdc9457SAndroid Build Coastguard Worker if (iy < 
input_height()) { 1290*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1291*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1292*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1293*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1294*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1295*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 1296*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g] * 1297*4bdc9457SAndroid Build Coastguard Worker kernel[((ky * kernel_width() + kx) * groups() + g) * group_output_channels() + oc]; 1298*4bdc9457SAndroid Build Coastguard Worker } 1299*4bdc9457SAndroid Build Coastguard Worker } 1300*4bdc9457SAndroid Build Coastguard Worker } 1301*4bdc9457SAndroid Build Coastguard Worker } 1302*4bdc9457SAndroid Build Coastguard Worker } 1303*4bdc9457SAndroid Build Coastguard Worker } 1304*4bdc9457SAndroid Build Coastguard Worker } 1305*4bdc9457SAndroid Build Coastguard Worker } 1306*4bdc9457SAndroid Build Coastguard Worker } 1307*4bdc9457SAndroid Build Coastguard Worker } else { 1308*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1309*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1310*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1311*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1312*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 1313*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 
1314*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1315*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1316*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1317*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1318*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1319*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 1320*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 1321*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic] * 1322*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]; 1323*4bdc9457SAndroid Build Coastguard Worker } 1324*4bdc9457SAndroid Build Coastguard Worker } 1325*4bdc9457SAndroid Build Coastguard Worker } 1326*4bdc9457SAndroid Build Coastguard Worker } 1327*4bdc9457SAndroid Build Coastguard Worker } 1328*4bdc9457SAndroid Build Coastguard Worker } 1329*4bdc9457SAndroid Build Coastguard Worker } 1330*4bdc9457SAndroid Build Coastguard Worker } 1331*4bdc9457SAndroid Build Coastguard Worker } 1332*4bdc9457SAndroid Build Coastguard Worker } 1333*4bdc9457SAndroid Build Coastguard Worker } 1334*4bdc9457SAndroid Build Coastguard Worker 1335*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 
1336*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 1337*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 1338*4bdc9457SAndroid Build Coastguard Worker 1339*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin()); 1340*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax()); 1341*4bdc9457SAndroid Build Coastguard Worker 1342*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 1343*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 1344*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 1345*4bdc9457SAndroid Build Coastguard Worker } 1346*4bdc9457SAndroid Build Coastguard Worker 1347*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convolution operator. 
1348*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1349*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 1350*4bdc9457SAndroid Build Coastguard Worker 1351*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 1352*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 1353*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 1354*4bdc9457SAndroid Build Coastguard Worker }; 1355*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 1356*4bdc9457SAndroid Build Coastguard Worker xnn_code_cache code_cache; 1357*4bdc9457SAndroid Build Coastguard Worker if (use_jit()) { 1358*4bdc9457SAndroid Build Coastguard Worker xnn_init_code_cache(&code_cache); 1359*4bdc9457SAndroid Build Coastguard Worker caches.code_cache = &code_cache; 1360*4bdc9457SAndroid Build Coastguard Worker } 1361*4bdc9457SAndroid Build Coastguard Worker #endif 1362*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 1363*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1364*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 1365*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 1366*4bdc9457SAndroid Build Coastguard Worker } 1367*4bdc9457SAndroid Build Coastguard Worker 1368*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_f32( 1369*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 1370*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 
0 : padding_left(), 1371*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 1372*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 1373*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1374*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 1375*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1376*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 1377*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1378*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 1379*4bdc9457SAndroid Build Coastguard Worker &caches, 1380*4bdc9457SAndroid Build Coastguard Worker &convolution_op); 1381*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 1382*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 1383*4bdc9457SAndroid Build Coastguard Worker } 1384*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 1385*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 1386*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1387*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1388*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 1389*4bdc9457SAndroid Build Coastguard Worker } 1390*4bdc9457SAndroid Build Coastguard Worker 1391*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
1392*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 1393*4bdc9457SAndroid Build Coastguard Worker 1394*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 1395*4bdc9457SAndroid Build Coastguard Worker if (use_jit()) { 1396*4bdc9457SAndroid Build Coastguard Worker // Check that we actually generated code. 1397*4bdc9457SAndroid Build Coastguard Worker ASSERT_GT(code_cache.cache.code.size, 0); 1398*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_code_memory(&code_cache.cache.code); 1399*4bdc9457SAndroid Build Coastguard Worker } 1400*4bdc9457SAndroid Build Coastguard Worker #endif 1401*4bdc9457SAndroid Build Coastguard Worker 1402*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1403*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f32( 1404*4bdc9457SAndroid Build Coastguard Worker convolution_op, 1405*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1406*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1407*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1408*4bdc9457SAndroid Build Coastguard Worker 1409*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1410*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 1411*4bdc9457SAndroid Build Coastguard Worker 1412*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxF32(output, output_ref, output_min, output_max); 1413*4bdc9457SAndroid Build Coastguard Worker 1414*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1415*4bdc9457SAndroid Build Coastguard Worker // We already finalized the code cache, so create a new code cache if we are testing JIT. 
1416*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 1417*4bdc9457SAndroid Build Coastguard Worker xnn_code_cache inner_code_cache; 1418*4bdc9457SAndroid Build Coastguard Worker if (use_jit()) { 1419*4bdc9457SAndroid Build Coastguard Worker xnn_init_code_cache(&inner_code_cache); 1420*4bdc9457SAndroid Build Coastguard Worker caches.code_cache = &inner_code_cache; 1421*4bdc9457SAndroid Build Coastguard Worker } 1422*4bdc9457SAndroid Build Coastguard Worker #endif 1423*4bdc9457SAndroid Build Coastguard Worker // To test weights cache, we create the operator with the same parameters, and setup with a different output. 1424*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op2 = nullptr; 1425*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 1426*4bdc9457SAndroid Build Coastguard Worker 1427*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_create_convolution2d_nhwc_f32( 1428*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 1429*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 1430*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 1431*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 1432*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1433*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 1434*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1435*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 1436*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1437*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? 
XNN_FLAG_TENSORFLOW_SAME_PADDING : 0), 1438*4bdc9457SAndroid Build Coastguard Worker &caches, 1439*4bdc9457SAndroid Build Coastguard Worker &convolution_op2)); 1440*4bdc9457SAndroid Build Coastguard Worker 1441*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op2); 1442*4bdc9457SAndroid Build Coastguard Worker 1443*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 1444*4bdc9457SAndroid Build Coastguard Worker if (use_jit()) { 1445*4bdc9457SAndroid Build Coastguard Worker // Check that we actually generated code. 1446*4bdc9457SAndroid Build Coastguard Worker ASSERT_GT(inner_code_cache.cache.code.size, 0); 1447*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_code_memory(&inner_code_cache.cache.code); 1448*4bdc9457SAndroid Build Coastguard Worker } 1449*4bdc9457SAndroid Build Coastguard Worker #endif 1450*4bdc9457SAndroid Build Coastguard Worker 1451*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output2(output.size(), nanf("")); 1452*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1453*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f32( 1454*4bdc9457SAndroid Build Coastguard Worker convolution_op2, 1455*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1456*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 1457*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1458*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1459*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op2, nullptr /* thread pool */)); 1460*4bdc9457SAndroid Build Coastguard Worker 1461*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op2(convolution_op2, xnn_delete_operator); 1462*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_cache.cache.hits, 1); 1463*4bdc9457SAndroid Build Coastguard Worker // Ensure 
that we did not write more weights to the cache because it was a cache hit. 1464*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(old_weights_cache_size, weights_cache.cache.weights.size); 1465*4bdc9457SAndroid Build Coastguard Worker 1466*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxF32(output2, output_ref, output_min, output_max); 1467*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 1468*4bdc9457SAndroid Build Coastguard Worker if (use_jit()) { 1469*4bdc9457SAndroid Build Coastguard Worker xnn_release_code_cache(&inner_code_cache); 1470*4bdc9457SAndroid Build Coastguard Worker } 1471*4bdc9457SAndroid Build Coastguard Worker #endif 1472*4bdc9457SAndroid Build Coastguard Worker } 1473*4bdc9457SAndroid Build Coastguard Worker 1474*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 1475*4bdc9457SAndroid Build Coastguard Worker if (use_jit()) { 1476*4bdc9457SAndroid Build Coastguard Worker xnn_release_code_cache(&code_cache); 1477*4bdc9457SAndroid Build Coastguard Worker } 1478*4bdc9457SAndroid Build Coastguard Worker #endif 1479*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1480*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 1481*4bdc9457SAndroid Build Coastguard Worker } 1482*4bdc9457SAndroid Build Coastguard Worker } 1483*4bdc9457SAndroid Build Coastguard Worker } 1484*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxF32(const std::vector<float> & output,const std::vector<float> & output_ref,const float output_min,const float output_max)1485*4bdc9457SAndroid Build Coastguard Worker void VerifyNHWCxF32(const std::vector<float>& output, const std::vector<float>& output_ref, const float output_min, const float output_max) const { 1486*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1487*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1488*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; 
x < output_width(); x++) { 1489*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1490*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 1491*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_min) 1492*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1493*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_max) 1494*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1495*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 1496*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 1497*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], 1498*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c])) 1499*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1500*4bdc9457SAndroid Build Coastguard Worker } 1501*4bdc9457SAndroid Build Coastguard Worker } 1502*4bdc9457SAndroid Build Coastguard Worker } 1503*4bdc9457SAndroid Build Coastguard Worker } 1504*4bdc9457SAndroid Build Coastguard Worker } 1505*4bdc9457SAndroid Build Coastguard Worker } 1506*4bdc9457SAndroid Build Coastguard Worker TestNHWCxF16()1507*4bdc9457SAndroid Build Coastguard Worker void TestNHWCxF16() const { 1508*4bdc9457SAndroid Build Coastguard Worker switch 
(weights_type()) { 1509*4bdc9457SAndroid Build Coastguard Worker case WeightsType::Default: 1510*4bdc9457SAndroid Build Coastguard Worker break; 1511*4bdc9457SAndroid Build Coastguard Worker case WeightsType::FP32: 1512*4bdc9457SAndroid Build Coastguard Worker break; 1513*4bdc9457SAndroid Build Coastguard Worker default: 1514*4bdc9457SAndroid Build Coastguard Worker GTEST_FAIL() << "unexpected weights type"; 1515*4bdc9457SAndroid Build Coastguard Worker } 1516*4bdc9457SAndroid Build Coastguard Worker 1517*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1518*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1519*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 1520*4bdc9457SAndroid Build Coastguard Worker 1521*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + 1522*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels())); 1523*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 1524*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel_as_float(kernel.size()); 1525*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(groups() * group_output_channels()); 1526*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias_as_float(bias.size()); 1527*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels())); 1528*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 1529*4bdc9457SAndroid Build Coastguard Worker 1530*4bdc9457SAndroid 
Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1531*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 1532*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 1533*4bdc9457SAndroid Build Coastguard Worker std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value); 1534*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 1535*4bdc9457SAndroid Build Coastguard Worker std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value); 1536*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 1537*4bdc9457SAndroid Build Coastguard Worker 1538*4bdc9457SAndroid Build Coastguard Worker 1539*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 
1540*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 1541*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1542*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1543*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1544*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1545*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1546*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 1547*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]); 1548*4bdc9457SAndroid Build Coastguard Worker } 1549*4bdc9457SAndroid Build Coastguard Worker } 1550*4bdc9457SAndroid Build Coastguard Worker } 1551*4bdc9457SAndroid Build Coastguard Worker } 1552*4bdc9457SAndroid Build Coastguard Worker } 1553*4bdc9457SAndroid Build Coastguard Worker } else { 1554*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 1555*4bdc9457SAndroid Build Coastguard Worker } 1556*4bdc9457SAndroid Build Coastguard Worker if (depthwise_layout()) { 1557*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(group_input_channels(), 1); 1558*4bdc9457SAndroid Build Coastguard Worker 1559*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1560*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1561*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1562*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1563*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 1564*4bdc9457SAndroid Build Coastguard 
Worker if (iy < input_height()) { 1565*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1566*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1567*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1568*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1569*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1570*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 1571*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g]) * 1572*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(kernel[((ky * kernel_width() + kx) * groups() + g) * group_output_channels() + oc]); 1573*4bdc9457SAndroid Build Coastguard Worker } 1574*4bdc9457SAndroid Build Coastguard Worker } 1575*4bdc9457SAndroid Build Coastguard Worker } 1576*4bdc9457SAndroid Build Coastguard Worker } 1577*4bdc9457SAndroid Build Coastguard Worker } 1578*4bdc9457SAndroid Build Coastguard Worker } 1579*4bdc9457SAndroid Build Coastguard Worker } 1580*4bdc9457SAndroid Build Coastguard Worker } 1581*4bdc9457SAndroid Build Coastguard Worker } 1582*4bdc9457SAndroid Build Coastguard Worker } else { 1583*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1584*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1585*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1586*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1587*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 
1588*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 1589*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1590*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1591*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1592*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1593*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1594*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 1595*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 1596*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) * 1597*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]); 1598*4bdc9457SAndroid Build Coastguard Worker } 1599*4bdc9457SAndroid Build Coastguard Worker } 1600*4bdc9457SAndroid Build Coastguard Worker } 1601*4bdc9457SAndroid Build Coastguard Worker } 1602*4bdc9457SAndroid Build Coastguard Worker } 1603*4bdc9457SAndroid Build Coastguard Worker } 1604*4bdc9457SAndroid Build Coastguard Worker } 1605*4bdc9457SAndroid Build Coastguard Worker } 1606*4bdc9457SAndroid Build Coastguard Worker } 1607*4bdc9457SAndroid Build Coastguard Worker } 1608*4bdc9457SAndroid Build Coastguard Worker } 1609*4bdc9457SAndroid Build Coastguard Worker 1610*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 
1611*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 1612*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 1613*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 1614*4bdc9457SAndroid Build Coastguard Worker const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin()))); 1615*4bdc9457SAndroid Build Coastguard Worker const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax()))); 1616*4bdc9457SAndroid Build Coastguard Worker const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min; 1617*4bdc9457SAndroid Build Coastguard Worker const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max; 1618*4bdc9457SAndroid Build Coastguard Worker 1619*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 1620*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 1621*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 1622*4bdc9457SAndroid Build Coastguard Worker } 1623*4bdc9457SAndroid Build Coastguard Worker 1624*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convolution operator. 
1625*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1626*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 1627*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 1628*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 1629*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 1630*4bdc9457SAndroid Build Coastguard Worker }; 1631*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 1632*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1633*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 1634*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 1635*4bdc9457SAndroid Build Coastguard Worker } 1636*4bdc9457SAndroid Build Coastguard Worker 1637*4bdc9457SAndroid Build Coastguard Worker const void* kernel_data = kernel.data(); 1638*4bdc9457SAndroid Build Coastguard Worker const void* bias_data = bias.data(); 1639*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) { 1640*4bdc9457SAndroid Build Coastguard Worker kernel_data = kernel_as_float.data(); 1641*4bdc9457SAndroid Build Coastguard Worker bias_data = bias_as_float.data(); 1642*4bdc9457SAndroid Build Coastguard Worker } 1643*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = 0; 1644*4bdc9457SAndroid Build Coastguard Worker if (depthwise_layout()) { 1645*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_DEPTHWISE_CONVOLUTION; 1646*4bdc9457SAndroid Build Coastguard Worker } 1647*4bdc9457SAndroid Build Coastguard Worker if (padding_tf_same()) { 1648*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING; 1649*4bdc9457SAndroid Build Coastguard Worker } 1650*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) { 1651*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_FP32_STATIC_WEIGHTS; 
1652*4bdc9457SAndroid Build Coastguard Worker } 1653*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_f16( 1654*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 1655*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(), 1656*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 1657*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 1658*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1659*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 1660*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1661*4bdc9457SAndroid Build Coastguard Worker kernel_data, has_bias() ? bias_data : nullptr, 1662*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1663*4bdc9457SAndroid Build Coastguard Worker flags, 1664*4bdc9457SAndroid Build Coastguard Worker &caches, 1665*4bdc9457SAndroid Build Coastguard Worker &convolution_op); 1666*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 1667*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 1668*4bdc9457SAndroid Build Coastguard Worker } 1669*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 1670*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 1671*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1672*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1673*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 1674*4bdc9457SAndroid Build Coastguard Worker } 1675*4bdc9457SAndroid Build Coastguard Worker 1676*4bdc9457SAndroid Build Coastguard Worker // Smart pointer 
to automatically delete convolution_op. 1677*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 1678*4bdc9457SAndroid Build Coastguard Worker 1679*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1680*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f16( 1681*4bdc9457SAndroid Build Coastguard Worker convolution_op, 1682*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1683*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1684*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1685*4bdc9457SAndroid Build Coastguard Worker 1686*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1687*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 1688*4bdc9457SAndroid Build Coastguard Worker 1689*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxF16(output, output_ref, output_min, output_max); 1690*4bdc9457SAndroid Build Coastguard Worker 1691*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1692*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op2 = nullptr; 1693*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 1694*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_create_convolution2d_nhwc_f16( 1695*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(), 1696*4bdc9457SAndroid Build Coastguard Worker padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 
0 : padding_left(), 1697*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 1698*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 1699*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1700*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 1701*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1702*4bdc9457SAndroid Build Coastguard Worker kernel_data, has_bias() ? bias_data : nullptr, 1703*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1704*4bdc9457SAndroid Build Coastguard Worker flags, 1705*4bdc9457SAndroid Build Coastguard Worker &caches, 1706*4bdc9457SAndroid Build Coastguard Worker &convolution_op2)); 1707*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op2); 1708*4bdc9457SAndroid Build Coastguard Worker 1709*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
1710*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op2, xnn_delete_operator); 1711*4bdc9457SAndroid Build Coastguard Worker 1712*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */); 1713*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1714*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f16( 1715*4bdc9457SAndroid Build Coastguard Worker convolution_op2, 1716*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1717*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 1718*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1719*4bdc9457SAndroid Build Coastguard Worker 1720*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1721*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op2, nullptr /* thread pool */)); 1722*4bdc9457SAndroid Build Coastguard Worker 1723*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxF16(output2, output_ref, output_min, output_max); 1724*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 1725*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 1726*4bdc9457SAndroid Build Coastguard Worker } 1727*4bdc9457SAndroid Build Coastguard Worker } 1728*4bdc9457SAndroid Build Coastguard Worker } 1729*4bdc9457SAndroid Build Coastguard Worker VerifyNHWCxF16(const std::vector<uint16_t> & output,const std::vector<float> & output_ref,const float output_min,const float output_max)1730*4bdc9457SAndroid Build Coastguard Worker void VerifyNHWCxF16(const std::vector<uint16_t> &output, 1731*4bdc9457SAndroid Build Coastguard Worker const std::vector<float> &output_ref, 1732*4bdc9457SAndroid Build Coastguard Worker const float output_min, const float output_max) const { 
1733*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1734*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1735*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1736*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1737*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 1738*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_min) 1739*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1740*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_max) 1741*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1742*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), std::max(1.0e-4f, std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]) * 1.0e-2f)) 1743*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 1744*4bdc9457SAndroid Build Coastguard Worker } 1745*4bdc9457SAndroid Build Coastguard Worker } 1746*4bdc9457SAndroid Build Coastguard Worker } 1747*4bdc9457SAndroid Build Coastguard Worker } 1748*4bdc9457SAndroid Build Coastguard Worker } 1749*4bdc9457SAndroid Build Coastguard 
Worker } 1750*4bdc9457SAndroid Build Coastguard Worker TestNCHWxF32()1751*4bdc9457SAndroid Build Coastguard Worker void TestNCHWxF32() { 1752*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 1753*4bdc9457SAndroid Build Coastguard Worker 1754*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 1755*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 1756*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 1757*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> pdist; 1758*4bdc9457SAndroid Build Coastguard Worker 1759*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(2 * XNN_EXTRA_BYTES / sizeof(float) + 1760*4bdc9457SAndroid Build Coastguard Worker ((batch_size() - 1) * input_channel_stride() + groups() * group_input_channels()) * input_height() * input_width()); 1761*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel( 1762*4bdc9457SAndroid Build Coastguard Worker groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 1763*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(groups() * group_output_channels()); 1764*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output( 1765*4bdc9457SAndroid Build Coastguard Worker ((batch_size() - 1) * output_channel_stride() + groups() * group_output_channels()) * output_height() * output_width()); 1766*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * groups() * group_output_channels() * output_height() * output_width()); 1767*4bdc9457SAndroid Build Coastguard Worker 1768*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 1769*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 1770*4bdc9457SAndroid Build 
Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); 1771*4bdc9457SAndroid Build Coastguard Worker for (float& k : kernel) { 1772*4bdc9457SAndroid Build Coastguard Worker if (pdist(rng) <= sparsity()) { 1773*4bdc9457SAndroid Build Coastguard Worker k = 0.0f; 1774*4bdc9457SAndroid Build Coastguard Worker } 1775*4bdc9457SAndroid Build Coastguard Worker } 1776*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); 1777*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 1778*4bdc9457SAndroid Build Coastguard Worker 1779*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 1780*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 1781*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1782*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1783*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1784*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1785*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1786*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] = 1787*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 1788*4bdc9457SAndroid Build Coastguard Worker } 1789*4bdc9457SAndroid Build Coastguard Worker } 1790*4bdc9457SAndroid Build Coastguard Worker } 1791*4bdc9457SAndroid Build Coastguard Worker } 1792*4bdc9457SAndroid Build Coastguard Worker } 1793*4bdc9457SAndroid Build Coastguard Worker } else { 1794*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 1795*4bdc9457SAndroid Build Coastguard Worker } 1796*4bdc9457SAndroid 
Build Coastguard Worker if (force_nhwc_input()) { 1797*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1798*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1799*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1800*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1801*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 1802*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 1803*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1804*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1805*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1806*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1807*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1808*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 1809*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] += 1810*4bdc9457SAndroid Build Coastguard Worker input[((((i * input_height() + iy) * input_width() + ix) * groups() + g) * group_input_channels() + ic)] * 1811*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]; 1812*4bdc9457SAndroid Build Coastguard Worker } 1813*4bdc9457SAndroid Build Coastguard Worker } 1814*4bdc9457SAndroid Build Coastguard Worker } 1815*4bdc9457SAndroid Build Coastguard Worker } 1816*4bdc9457SAndroid Build Coastguard Worker } 1817*4bdc9457SAndroid Build Coastguard Worker } 
1818*4bdc9457SAndroid Build Coastguard Worker } 1819*4bdc9457SAndroid Build Coastguard Worker } 1820*4bdc9457SAndroid Build Coastguard Worker } 1821*4bdc9457SAndroid Build Coastguard Worker } 1822*4bdc9457SAndroid Build Coastguard Worker } else if (depthwise_layout()) { 1823*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(group_input_channels(), 1); 1824*4bdc9457SAndroid Build Coastguard Worker 1825*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1826*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1827*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1828*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1829*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 1830*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 1831*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1832*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1833*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1834*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1835*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1836*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] += 1837*4bdc9457SAndroid Build Coastguard Worker input[((i * input_channel_stride() + g) * input_height() + iy) * input_width() + ix] * 1838*4bdc9457SAndroid Build Coastguard Worker kernel[((ky * kernel_width() + kx) * groups() + g) * group_output_channels() + oc]; 1839*4bdc9457SAndroid Build Coastguard Worker } 1840*4bdc9457SAndroid Build Coastguard Worker } 
1841*4bdc9457SAndroid Build Coastguard Worker } 1842*4bdc9457SAndroid Build Coastguard Worker } 1843*4bdc9457SAndroid Build Coastguard Worker } 1844*4bdc9457SAndroid Build Coastguard Worker } 1845*4bdc9457SAndroid Build Coastguard Worker } 1846*4bdc9457SAndroid Build Coastguard Worker } 1847*4bdc9457SAndroid Build Coastguard Worker } 1848*4bdc9457SAndroid Build Coastguard Worker } else { 1849*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1850*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 1851*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 1852*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 1853*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 1854*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 1855*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 1856*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 1857*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 1858*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1859*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 1860*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 1861*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] += 1862*4bdc9457SAndroid Build Coastguard Worker input[((i * input_channel_stride() + g * group_input_channels() + ic) * input_height() + iy) * input_width() + ix] * 1863*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + 
ky) * kernel_width() + kx) * group_input_channels() + ic]; 1864*4bdc9457SAndroid Build Coastguard Worker } 1865*4bdc9457SAndroid Build Coastguard Worker } 1866*4bdc9457SAndroid Build Coastguard Worker } 1867*4bdc9457SAndroid Build Coastguard Worker } 1868*4bdc9457SAndroid Build Coastguard Worker } 1869*4bdc9457SAndroid Build Coastguard Worker } 1870*4bdc9457SAndroid Build Coastguard Worker } 1871*4bdc9457SAndroid Build Coastguard Worker } 1872*4bdc9457SAndroid Build Coastguard Worker } 1873*4bdc9457SAndroid Build Coastguard Worker } 1874*4bdc9457SAndroid Build Coastguard Worker } 1875*4bdc9457SAndroid Build Coastguard Worker 1876*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 1877*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 1878*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 1879*4bdc9457SAndroid Build Coastguard Worker 1880*4bdc9457SAndroid Build Coastguard Worker const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() : 1881*4bdc9457SAndroid Build Coastguard Worker accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin()); 1882*4bdc9457SAndroid Build Coastguard Worker const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() : 1883*4bdc9457SAndroid Build Coastguard Worker accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax()); 1884*4bdc9457SAndroid Build Coastguard Worker 1885*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 
1886*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 1887*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 1888*4bdc9457SAndroid Build Coastguard Worker } 1889*4bdc9457SAndroid Build Coastguard Worker 1890*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Convolution operator. 1891*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 1892*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 1893*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = { 1894*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL, 1895*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL, 1896*4bdc9457SAndroid Build Coastguard Worker }; 1897*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache; 1898*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1899*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache); 1900*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache; 1901*4bdc9457SAndroid Build Coastguard Worker } 1902*4bdc9457SAndroid Build Coastguard Worker 1903*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nchw_f32( 1904*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 1905*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 1906*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 1907*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 1908*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 1909*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 1910*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? 
bias.data() : nullptr, 1911*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 1912*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (force_nhwc_input() ? XNN_FLAG_INPUT_NHWC : 0), 1913*4bdc9457SAndroid Build Coastguard Worker &caches, 1914*4bdc9457SAndroid Build Coastguard Worker &convolution_op); 1915*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_parameter) { 1916*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 1917*4bdc9457SAndroid Build Coastguard Worker } 1918*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 1919*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 1920*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1921*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1922*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft)); 1923*4bdc9457SAndroid Build Coastguard Worker } 1924*4bdc9457SAndroid Build Coastguard Worker 1925*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
1926*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 1927*4bdc9457SAndroid Build Coastguard Worker 1928*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1929*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nchw_f32( 1930*4bdc9457SAndroid Build Coastguard Worker convolution_op, 1931*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1932*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 1933*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1934*4bdc9457SAndroid Build Coastguard Worker 1935*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1936*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 1937*4bdc9457SAndroid Build Coastguard Worker 1938*4bdc9457SAndroid Build Coastguard Worker VerifyNCHWxF32(output, output_ref, output_min, output_max); 1939*4bdc9457SAndroid Build Coastguard Worker 1940*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) { 1941*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op2 = nullptr; 1942*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size; 1943*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ( 1944*4bdc9457SAndroid Build Coastguard Worker xnn_status_success, 1945*4bdc9457SAndroid Build Coastguard Worker xnn_create_convolution2d_nchw_f32( 1946*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), 1947*4bdc9457SAndroid Build Coastguard Worker padding_left(), kernel_height(), kernel_width(), 1948*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), dilation_height(), 1949*4bdc9457SAndroid Build Coastguard Worker dilation_width(), groups(), group_input_channels(), 1950*4bdc9457SAndroid 
Build Coastguard Worker group_output_channels(), input_channel_stride(), 1951*4bdc9457SAndroid Build Coastguard Worker output_channel_stride(), kernel.data(), 1952*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_min, output_max, 1953*4bdc9457SAndroid Build Coastguard Worker (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | 1954*4bdc9457SAndroid Build Coastguard Worker (force_nhwc_input() ? XNN_FLAG_INPUT_NHWC : 0), 1955*4bdc9457SAndroid Build Coastguard Worker &caches, &convolution_op2)); 1956*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op2); 1957*4bdc9457SAndroid Build Coastguard Worker 1958*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op2. 1959*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op2, xnn_delete_operator); 1960*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output2(output.size(), nanf("")); 1961*4bdc9457SAndroid Build Coastguard Worker 1962*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1963*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nchw_f32( 1964*4bdc9457SAndroid Build Coastguard Worker convolution_op2, 1965*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 1966*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(), 1967*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 1968*4bdc9457SAndroid Build Coastguard Worker 1969*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 1970*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op2, nullptr /* thread pool */)); 1971*4bdc9457SAndroid Build Coastguard Worker 1972*4bdc9457SAndroid Build Coastguard Worker VerifyNCHWxF32(output2, output_ref, output_min, output_max); 1973*4bdc9457SAndroid Build Coastguard Worker if (IsSpmm()) { 
1974*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCacheUnused(weights_cache); 1975*4bdc9457SAndroid Build Coastguard Worker } else { 1976*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(weights_cache, old_weights_cache_size); 1977*4bdc9457SAndroid Build Coastguard Worker } 1978*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache); 1979*4bdc9457SAndroid Build Coastguard Worker } 1980*4bdc9457SAndroid Build Coastguard Worker } 1981*4bdc9457SAndroid Build Coastguard Worker } 1982*4bdc9457SAndroid Build Coastguard Worker VerifyNCHWxF32(const std::vector<float> & output,const std::vector<float> & output_ref,const float output_min,const float output_max)1983*4bdc9457SAndroid Build Coastguard Worker void VerifyNCHWxF32(const std::vector<float> &output, 1984*4bdc9457SAndroid Build Coastguard Worker const std::vector<float> &output_ref, 1985*4bdc9457SAndroid Build Coastguard Worker const float output_min, const float output_max) const { 1986*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 1987*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 1988*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 1989*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 1990*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 1991*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_channel_stride() + g * group_output_channels() + c) * output_height() + y) * output_width() + x], output_min) 1992*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c << ", image = " << i; 1993*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_channel_stride() + g * group_output_channels() + c) * output_height() + y) * output_width() + x], output_max) 1994*4bdc9457SAndroid 
Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c << ", image = " << i; 1995*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 1996*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * groups() + g) * group_output_channels() + c) * output_height() + y) * output_width() + x], 1997*4bdc9457SAndroid Build Coastguard Worker output[((i * output_channel_stride() + g * group_output_channels() + c) * output_height() + y) * output_width() + x], 1998*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[(((i * groups() + g) * group_output_channels() + c) * output_height() + y) * output_width() + x])) 1999*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c << ", image = " << i; 2000*4bdc9457SAndroid Build Coastguard Worker } 2001*4bdc9457SAndroid Build Coastguard Worker } 2002*4bdc9457SAndroid Build Coastguard Worker } 2003*4bdc9457SAndroid Build Coastguard Worker } 2004*4bdc9457SAndroid Build Coastguard Worker } 2005*4bdc9457SAndroid Build Coastguard Worker } 2006*4bdc9457SAndroid Build Coastguard Worker
  // Exercises re-setup of a signed 8-bit NHWC convolution operator that uses one
  // requantization scale per output channel (the xnn_*_convolution2d_nhwc_qc8
  // entry points).
  //
  // For every iteration the test:
  //   1. fills input, kernel and bias with random values and computes int32
  //      reference accumulators for the first shape
  //      (batch_size x input_height x input_width);
  //   2. derives one requantization scale per output channel from those
  //      accumulators and builds the clamped reference output;
  //   3. creates the operator, sets it up for the first shape, runs it and
  //      compares every output element with the reference;
  //   4. regenerates the input, sets the same operator up for the second shape
  //      (next_batch_size x next_input_height x next_input_width), runs it again
  //      and compares with a second reference built from the same scales.
  //
  // The test is skipped when operator creation reports
  // xnn_status_unsupported_hardware.
  void TestSetupNHWCxQC8() const {
    ASSERT_EQ(weights_type(), WeightsType::Default);

    ASSERT_FALSE(depthwise_layout());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
    std::uniform_int_distribution<int32_t> i8dist(
      std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
    std::uniform_int_distribution<int32_t> w8dist(
      -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());

    const size_t input_channels = groups() * group_input_channels();
    const size_t output_channels = groups() * group_output_channels();
    const size_t output_pixels = batch_size() * output_height() * output_width();
    const size_t next_output_pixels = next_batch_size() * next_output_height() * next_output_width();

    // Number of elements needed to address `pixels` pixels that start `stride`
    // elements apart when each pixel holds `channels` elements. This method
    // indexes the tensors as ((image * height + y) * width + x) * stride + channel,
    // so the images of a batch are NOT individually trimmed: only the very last
    // pixel of the whole batch may omit its stride padding.
    const auto strided_size = [](size_t pixels, size_t stride, size_t channels) -> size_t {
      return pixels == 0 ? 0 : (pixels - 1) * stride + channels;
    };

    // Both buffers are shared by the two runs, so they are sized for the larger one.
    std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max(
      strided_size(batch_size() * input_height() * input_width(), input_channel_stride(), input_channels),
      strided_size(next_batch_size() * next_input_height() * next_input_width(), input_channel_stride(), input_channels)));
    std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
    std::vector<int32_t> bias(output_channels);
    std::vector<int8_t> output(std::max(
      strided_size(output_pixels, output_channel_stride(), output_channels),
      strided_size(next_output_pixels, output_channel_stride(), output_channels)));
    std::vector<int32_t> accumulators(output_pixels * output_channels);
    std::vector<double> output_ref(output_pixels * output_channels);
    std::vector<float> requantization_scales(output_channels);
    std::vector<int32_t> next_accumulators(next_output_pixels * output_channels);
    std::vector<double> next_output_ref(next_output_pixels * output_channels);

    const int8_t input_zero_point = -1;
    const int8_t output_zero_point = -1;

    // Fills `acc` (densely packed, image/y/x/group/channel order) with the int32
    // reference convolution of the current `input`, `kernel` and `bias` contents
    // for the given tensor shape.
    const auto compute_accumulators = [&](
        std::vector<int32_t>& acc, size_t images,
        size_t in_height, size_t in_width,
        size_t out_height, size_t out_width)
    {
      if (has_bias()) {
        // Every output pixel starts from a copy of the per-channel bias.
        for (size_t px = 0; px < images * out_height * out_width; px++) {
          std::copy(bias.cbegin(), bias.cend(), acc.begin() + px * output_channels);
        }
      } else {
        std::fill(acc.begin(), acc.end(), 0);
      }
      for (size_t i = 0; i < images; i++) {
        for (size_t oy = 0; oy < out_height; oy++) {
          for (size_t ox = 0; ox < out_width; ox++) {
            for (size_t ky = 0; ky < kernel_height(); ky++) {
              // iy/ix are unsigned: a position that falls into the top/left padding
              // wraps around to a huge value and is rejected by the same bound check
              // that rejects positions past the bottom/right edge.
              const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
              if (iy < in_height) {
                for (size_t kx = 0; kx < kernel_width(); kx++) {
                  const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
                  if (ix < in_width) {
                    for (size_t g = 0; g < groups(); g++) {
                      for (size_t oc = 0; oc < group_output_channels(); oc++) {
                        for (size_t ic = 0; ic < group_input_channels(); ic++) {
                          acc[(((i * out_height + oy) * out_width + ox) * groups() + g) * group_output_channels() + oc] +=
                            (int32_t(input[((i * in_height + iy) * in_width + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
                            int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    };

    // Converts accumulators into reference outputs: scale by the channel's
    // requantization scale, add the output zero point, then clamp to the
    // [qmin() - 0x80, qmax() - 0x80] range.
    const auto requantize = [&](const std::vector<int32_t>& acc, std::vector<double>& ref) {
      for (size_t idx = 0; idx < acc.size(); idx++) {
        const double scaled = double(int32_t(output_zero_point)) +
          double(acc[idx]) * double(requantization_scales[idx % output_channels]);
        ref[idx] = std::max<double>(std::min<double>(scaled, double(qmax() - 0x80)), double(qmin() - 0x80));
      }
    };

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
      std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
      std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
      // Pre-fill the output with a fixed byte pattern before the operator runs.
      std::fill(output.begin(), output.end(), INT8_C(0xA5));

      // Compute reference results, without renormalization.
      compute_accumulators(accumulators, batch_size(),
        input_height(), input_width(), output_height(), output_width());

      // Compute renormalization parameters: for each output channel pick the
      // scale that maps its extreme accumulators onto the int8 limits (relative
      // to the output zero point), kept within [2^-32, 0x1.FFFFFEp-1].
      for (size_t c = 0; c < output_channels; c++) {
        int32_t accumulated_min = accumulators[c];
        int32_t accumulated_max = accumulators[c];
        for (size_t px = 0; px < output_pixels; px++) {
          accumulated_min = std::min(accumulated_min, accumulators[px * output_channels + c]);
          accumulated_max = std::max(accumulated_max, accumulators[px * output_channels + c]);
        }

        float requantization_scale = 0x1.0p-32f;
        if (accumulated_max != 0) {
          requantization_scale = std::max(requantization_scale,
            float(int32_t(std::numeric_limits<int8_t>::max()) - int32_t(output_zero_point)) / float(accumulated_max));
        }
        if (accumulated_min != 0) {
          requantization_scale = std::max(requantization_scale,
            float(int32_t(std::numeric_limits<int8_t>::min()) - int32_t(output_zero_point)) / float(accumulated_min));
        }
        requantization_scale = std::min(requantization_scale, 0x1.FFFFFEp-1f);

        requantization_scales[c] = requantization_scale;
      }

      // Renormalize reference results.
      requantize(accumulators, output_ref);

      // Create, setup, and run Convolution operator once.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t convolution_op = nullptr;

      xnn_status status = xnn_create_convolution2d_nhwc_qc8(
          padding_top(), padding_right(), padding_bottom(), padding_left(),
          kernel_height(), kernel_width(),
          subsampling_height(), subsampling_width(),
          dilation_height(), dilation_width(),
          groups(), group_input_channels(), group_output_channels(),
          input_channel_stride(), output_channel_stride(),
          input_zero_point, 1.0f /* input scale */, requantization_scales.data(),
          kernel.data(), has_bias() ? bias.data() : nullptr,
          output_zero_point, 1.0f /* output scale */, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
          0, nullptr, &convolution_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, convolution_op);

      // Smart pointer to automatically delete convolution_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_convolution2d_nhwc_qc8(
          convolution_op,
          batch_size(), input_height(), input_width(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(convolution_op, nullptr /* thread pool */));

      // Verify results of the first run. The checks stay inline (not in a helper)
      // so that a fatal assertion returns from the test method itself.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t y = 0; y < output_height(); y++) {
          for (size_t x = 0; x < output_width(); x++) {
            for (size_t g = 0; g < groups(); g++) {
              for (size_t c = 0; c < group_output_channels(); c++) {
                // The operator output honours the channel stride; the reference is dense.
                const size_t output_index =
                  ((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c;
                const size_t ref_index =
                  (((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c;
                ASSERT_LE(int32_t(output[output_index]), int32_t(qmax() - 0x80))
                  << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
                ASSERT_GE(int32_t(output[output_index]), int32_t(qmin() - 0x80))
                  << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
                ASSERT_NEAR(
                    output_ref[ref_index],
                    double(output[output_index]),
                    0.9)
                  << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
              }
            }
          }
        }
      }

      // Re-generate data for the second run.
      std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
      std::fill(output.begin(), output.end(), INT8_C(0xA5));

      // Compute reference results for the second run, including renormalization.
      // The scales from the first run are reused because they are the ones the
      // operator was created with.
      compute_accumulators(next_accumulators, next_batch_size(),
        next_input_height(), next_input_width(), next_output_height(), next_output_width());
      requantize(next_accumulators, next_output_ref);

      // Setup and run Convolution operator the second time, and destroy the operator.
      ASSERT_EQ(xnn_status_success,
        xnn_setup_convolution2d_nhwc_qc8(
          convolution_op,
          next_batch_size(), next_input_height(), next_input_width(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(convolution_op, nullptr /* thread pool */));

      // Verify results of the second run.
      for (size_t i = 0; i < next_batch_size(); i++) {
        for (size_t y = 0; y < next_output_height(); y++) {
          for (size_t x = 0; x < next_output_width(); x++) {
            for (size_t g = 0; g < groups(); g++) {
              for (size_t c = 0; c < group_output_channels(); c++) {
                const size_t output_index =
                  ((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c;
                const size_t ref_index =
                  (((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c;
                ASSERT_LE(int32_t(output[output_index]), int32_t(qmax() - 0x80))
                  << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
                ASSERT_GE(int32_t(output[output_index]), int32_t(qmin() - 0x80))
                  << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
                ASSERT_NEAR(
                    next_output_ref[ref_index],
                    double(output[output_index]),
                    0.9)
                  << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
              }
            }
          }
        }
      }
    }
  }
2267*4bdc9457SAndroid Build Coastguard Worker TestSetupNHWCxQS8()2268*4bdc9457SAndroid Build Coastguard Worker void TestSetupNHWCxQS8() const { 2269*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 2270*4bdc9457SAndroid Build Coastguard Worker 2271*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(depthwise_layout()); 2272*4bdc9457SAndroid Build Coastguard Worker 2273*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 2274*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 2275*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 2276*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist( 2277*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()); 2278*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist( 2279*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<int8_t>::max(),
std::numeric_limits<int8_t>::max()); 2280*4bdc9457SAndroid Build Coastguard Worker 2281*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max( 2282*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()), 2283*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels()))); 2284*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 2285*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels()); 2286*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(std::max( 2287*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()), 2288*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels()))); 2289*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 2290*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 2291*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels()); 2292*4bdc9457SAndroid Build Coastguard Worker std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels()); 2293*4bdc9457SAndroid Build Coastguard Worker 
2294*4bdc9457SAndroid Build Coastguard Worker const int8_t input_zero_point = -1; 2295*4bdc9457SAndroid Build Coastguard Worker 2296*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 2297*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 2298*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); }); 2299*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 2300*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 2301*4bdc9457SAndroid Build Coastguard Worker 2302*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 2303*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2304*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2305*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2306*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2307*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2308*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2309*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2310*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 2311*4bdc9457SAndroid Build Coastguard Worker } 2312*4bdc9457SAndroid Build Coastguard Worker } 2313*4bdc9457SAndroid Build Coastguard Worker } 2314*4bdc9457SAndroid Build Coastguard Worker } 2315*4bdc9457SAndroid Build Coastguard Worker } 2316*4bdc9457SAndroid Build Coastguard Worker } else { 2317*4bdc9457SAndroid Build Coastguard Worker 
std::fill(accumulators.begin(), accumulators.end(), 0); 2318*4bdc9457SAndroid Build Coastguard Worker } 2319*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2320*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2321*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2322*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2323*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2324*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 2325*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2326*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 2327*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 2328*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2329*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2330*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 2331*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 2332*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) * 2333*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]); 2334*4bdc9457SAndroid Build Coastguard Worker } 2335*4bdc9457SAndroid Build Coastguard Worker } 2336*4bdc9457SAndroid Build Coastguard Worker } 2337*4bdc9457SAndroid Build 
Coastguard Worker } 2338*4bdc9457SAndroid Build Coastguard Worker } 2339*4bdc9457SAndroid Build Coastguard Worker } 2340*4bdc9457SAndroid Build Coastguard Worker } 2341*4bdc9457SAndroid Build Coastguard Worker } 2342*4bdc9457SAndroid Build Coastguard Worker } 2343*4bdc9457SAndroid Build Coastguard Worker } 2344*4bdc9457SAndroid Build Coastguard Worker 2345*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 2346*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); 2347*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); 2348*4bdc9457SAndroid Build Coastguard Worker 2349*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; 2350*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point = int8_t(std::max(std::min( 2351*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), 2352*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min()))); 2353*4bdc9457SAndroid Build Coastguard Worker 2354*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results. 
2355*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(), 2356*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 2357*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point); 2358*4bdc9457SAndroid Build Coastguard Worker }); 2359*4bdc9457SAndroid Build Coastguard Worker 2360*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Convolution operator once. 2361*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 2362*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 2363*4bdc9457SAndroid Build Coastguard Worker 2364*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_qs8( 2365*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 2366*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 2367*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 2368*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 2369*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 2370*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 2371*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 1.0f /* kernel scale */, 2372*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? 
bias.data() : nullptr, 2373*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80), 2374*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &convolution_op); 2375*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 2376*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 2377*4bdc9457SAndroid Build Coastguard Worker } 2378*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 2379*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 2380*4bdc9457SAndroid Build Coastguard Worker 2381*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 2382*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 2383*4bdc9457SAndroid Build Coastguard Worker 2384*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2385*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qs8( 2386*4bdc9457SAndroid Build Coastguard Worker convolution_op, 2387*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 2388*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 2389*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 2390*4bdc9457SAndroid Build Coastguard Worker 2391*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2392*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 2393*4bdc9457SAndroid Build Coastguard Worker 2394*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 
2395*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2396*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 2397*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 2398*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2399*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 2400*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80)) 2401*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2402*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80)) 2403*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2404*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 2405*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 2406*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point), 2407*4bdc9457SAndroid Build Coastguard Worker 0.9) 2408*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2409*4bdc9457SAndroid Build Coastguard Worker } 2410*4bdc9457SAndroid Build Coastguard Worker } 2411*4bdc9457SAndroid Build Coastguard Worker } 2412*4bdc9457SAndroid Build Coastguard Worker } 2413*4bdc9457SAndroid Build Coastguard Worker } 2414*4bdc9457SAndroid Build Coastguard 
Worker 2415*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 2416*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); }); 2417*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5)); 2418*4bdc9457SAndroid Build Coastguard Worker 2419*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including renormalization. 2420*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2421*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2422*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 2423*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 2424*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2425*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2426*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2427*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 2428*4bdc9457SAndroid Build Coastguard Worker } 2429*4bdc9457SAndroid Build Coastguard Worker } 2430*4bdc9457SAndroid Build Coastguard Worker } 2431*4bdc9457SAndroid Build Coastguard Worker } 2432*4bdc9457SAndroid Build Coastguard Worker } 2433*4bdc9457SAndroid Build Coastguard Worker } else { 2434*4bdc9457SAndroid Build Coastguard Worker std::fill(next_accumulators.begin(), next_accumulators.end(), 0); 2435*4bdc9457SAndroid Build Coastguard Worker } 2436*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2437*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 2438*4bdc9457SAndroid Build Coastguard Worker for 
(size_t ox = 0; ox < next_output_width(); ox++) { 2439*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2440*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2441*4bdc9457SAndroid Build Coastguard Worker if (iy < next_input_height()) { 2442*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2443*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 2444*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width()) { 2445*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2446*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2447*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 2448*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] += 2449*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) * 2450*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]); 2451*4bdc9457SAndroid Build Coastguard Worker } 2452*4bdc9457SAndroid Build Coastguard Worker } 2453*4bdc9457SAndroid Build Coastguard Worker } 2454*4bdc9457SAndroid Build Coastguard Worker } 2455*4bdc9457SAndroid Build Coastguard Worker } 2456*4bdc9457SAndroid Build Coastguard Worker } 2457*4bdc9457SAndroid Build Coastguard Worker } 2458*4bdc9457SAndroid Build Coastguard Worker } 2459*4bdc9457SAndroid Build Coastguard Worker } 2460*4bdc9457SAndroid Build Coastguard Worker } 
2461*4bdc9457SAndroid Build Coastguard Worker std::transform(next_accumulators.cbegin(), next_accumulators.cend(), next_output_ref.begin(), 2462*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 2463*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point); 2464*4bdc9457SAndroid Build Coastguard Worker }); 2465*4bdc9457SAndroid Build Coastguard Worker 2466*4bdc9457SAndroid Build Coastguard Worker // Setup and run Convolution operator the second time, and destroy the operator. 2467*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2468*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qs8( 2469*4bdc9457SAndroid Build Coastguard Worker convolution_op, 2470*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 2471*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 2472*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 2473*4bdc9457SAndroid Build Coastguard Worker 2474*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2475*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 2476*4bdc9457SAndroid Build Coastguard Worker 2477*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 
2478*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2479*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 2480*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 2481*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2482*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 2483*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80)) 2484*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2485*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80)) 2486*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2487*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 2488*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c], 2489*4bdc9457SAndroid Build Coastguard Worker double(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point), 2490*4bdc9457SAndroid Build Coastguard Worker 0.9) 2491*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2492*4bdc9457SAndroid Build Coastguard Worker } 2493*4bdc9457SAndroid Build Coastguard Worker } 2494*4bdc9457SAndroid Build Coastguard Worker } 2495*4bdc9457SAndroid Build Coastguard Worker } 2496*4bdc9457SAndroid Build 
Coastguard Worker } 2497*4bdc9457SAndroid Build Coastguard Worker } 2498*4bdc9457SAndroid Build Coastguard Worker } 2499*4bdc9457SAndroid Build Coastguard Worker TestSetupNHWCxQU8()2500*4bdc9457SAndroid Build Coastguard Worker void TestSetupNHWCxQU8() const { 2501*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 2502*4bdc9457SAndroid Build Coastguard Worker 2503*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(depthwise_layout()); 2504*4bdc9457SAndroid Build Coastguard Worker 2505*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 2506*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 2507*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000); 2508*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist( 2509*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max()); 2510*4bdc9457SAndroid Build Coastguard Worker 2511*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max( 2512*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()), 2513*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels()))); 2514*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 2515*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels()); 2516*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(std::max( 2517*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((output_height() * output_width() - 1) * 
output_channel_stride() + groups() * group_output_channels()), 2518*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels()))); 2519*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 2520*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 2521*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels()); 2522*4bdc9457SAndroid Build Coastguard Worker std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels()); 2523*4bdc9457SAndroid Build Coastguard Worker 2524*4bdc9457SAndroid Build Coastguard Worker const uint8_t input_zero_point = 127; 2525*4bdc9457SAndroid Build Coastguard Worker const uint8_t kernel_zero_point = 127; 2526*4bdc9457SAndroid Build Coastguard Worker 2527*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 2528*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 2529*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); }); 2530*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); }); 2531*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5)); 2532*4bdc9457SAndroid Build Coastguard Worker 2533*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization. 
2534*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2535*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2536*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2537*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2538*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2539*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2540*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2541*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 2542*4bdc9457SAndroid Build Coastguard Worker } 2543*4bdc9457SAndroid Build Coastguard Worker } 2544*4bdc9457SAndroid Build Coastguard Worker } 2545*4bdc9457SAndroid Build Coastguard Worker } 2546*4bdc9457SAndroid Build Coastguard Worker } 2547*4bdc9457SAndroid Build Coastguard Worker } else { 2548*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0); 2549*4bdc9457SAndroid Build Coastguard Worker } 2550*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2551*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2552*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2553*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2554*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2555*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 2556*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2557*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * 
dilation_width() - padding_left(); 2558*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 2559*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2560*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2561*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 2562*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 2563*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) * 2564*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point)); 2565*4bdc9457SAndroid Build Coastguard Worker } 2566*4bdc9457SAndroid Build Coastguard Worker } 2567*4bdc9457SAndroid Build Coastguard Worker } 2568*4bdc9457SAndroid Build Coastguard Worker } 2569*4bdc9457SAndroid Build Coastguard Worker } 2570*4bdc9457SAndroid Build Coastguard Worker } 2571*4bdc9457SAndroid Build Coastguard Worker } 2572*4bdc9457SAndroid Build Coastguard Worker } 2573*4bdc9457SAndroid Build Coastguard Worker } 2574*4bdc9457SAndroid Build Coastguard Worker } 2575*4bdc9457SAndroid Build Coastguard Worker 2576*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters. 
2577*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend()); 2578*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend()); 2579*4bdc9457SAndroid Build Coastguard Worker 2580*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0; 2581*4bdc9457SAndroid Build Coastguard Worker const uint8_t output_zero_point = uint8_t(std::max(std::min( 2582*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale), 2583*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min()))); 2584*4bdc9457SAndroid Build Coastguard Worker 2585*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results. 2586*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(), 2587*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 2588*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point); 2589*4bdc9457SAndroid Build Coastguard Worker }); 2590*4bdc9457SAndroid Build Coastguard Worker 2591*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Convolution operator once. 
2592*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 2593*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 2594*4bdc9457SAndroid Build Coastguard Worker 2595*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_qu8( 2596*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 2597*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 2598*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 2599*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 2600*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 2601*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 2602*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */, 2603*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1.0f /* kernel scale */, 2604*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 2605*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, qmin(), qmax(), 2606*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &convolution_op); 2607*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 2608*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 2609*4bdc9457SAndroid Build Coastguard Worker } 2610*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 2611*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 2612*4bdc9457SAndroid Build Coastguard Worker 2613*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
2614*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 2615*4bdc9457SAndroid Build Coastguard Worker 2616*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2617*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qu8( 2618*4bdc9457SAndroid Build Coastguard Worker convolution_op, 2619*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 2620*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 2621*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 2622*4bdc9457SAndroid Build Coastguard Worker 2623*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2624*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 2625*4bdc9457SAndroid Build Coastguard Worker 2626*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 
2627*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2628*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 2629*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 2630*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2631*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 2632*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax())) 2633*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2634*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin())) 2635*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2636*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 2637*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 2638*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point), 2639*4bdc9457SAndroid Build Coastguard Worker 0.9) 2640*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2641*4bdc9457SAndroid Build Coastguard Worker } 2642*4bdc9457SAndroid Build Coastguard Worker } 2643*4bdc9457SAndroid Build Coastguard Worker } 2644*4bdc9457SAndroid Build Coastguard Worker } 2645*4bdc9457SAndroid Build Coastguard Worker } 2646*4bdc9457SAndroid Build Coastguard Worker 
2647*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 2648*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); }); 2649*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5); 2650*4bdc9457SAndroid Build Coastguard Worker 2651*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including renormalization. 2652*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2653*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2654*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 2655*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 2656*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2657*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2658*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2659*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 2660*4bdc9457SAndroid Build Coastguard Worker } 2661*4bdc9457SAndroid Build Coastguard Worker } 2662*4bdc9457SAndroid Build Coastguard Worker } 2663*4bdc9457SAndroid Build Coastguard Worker } 2664*4bdc9457SAndroid Build Coastguard Worker } 2665*4bdc9457SAndroid Build Coastguard Worker } else { 2666*4bdc9457SAndroid Build Coastguard Worker std::fill(next_accumulators.begin(), next_accumulators.end(), 0); 2667*4bdc9457SAndroid Build Coastguard Worker } 2668*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2669*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 2670*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox 
< next_output_width(); ox++) { 2671*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2672*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2673*4bdc9457SAndroid Build Coastguard Worker if (iy < next_input_height()) { 2674*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2675*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 2676*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width()) { 2677*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2678*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2679*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 2680*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] += 2681*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) * 2682*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point)); 2683*4bdc9457SAndroid Build Coastguard Worker } 2684*4bdc9457SAndroid Build Coastguard Worker } 2685*4bdc9457SAndroid Build Coastguard Worker } 2686*4bdc9457SAndroid Build Coastguard Worker } 2687*4bdc9457SAndroid Build Coastguard Worker } 2688*4bdc9457SAndroid Build Coastguard Worker } 2689*4bdc9457SAndroid Build Coastguard Worker } 2690*4bdc9457SAndroid Build Coastguard Worker } 2691*4bdc9457SAndroid Build Coastguard Worker } 2692*4bdc9457SAndroid Build 
Coastguard Worker } 2693*4bdc9457SAndroid Build Coastguard Worker std::transform(next_accumulators.cbegin(), next_accumulators.cend(), next_output_ref.begin(), 2694*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double { 2695*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point); 2696*4bdc9457SAndroid Build Coastguard Worker }); 2697*4bdc9457SAndroid Build Coastguard Worker 2698*4bdc9457SAndroid Build Coastguard Worker // Setup and run Convolution operator the second time, and destroy the operator. 2699*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2700*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_qu8( 2701*4bdc9457SAndroid Build Coastguard Worker convolution_op, 2702*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 2703*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 2704*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 2705*4bdc9457SAndroid Build Coastguard Worker 2706*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2707*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 2708*4bdc9457SAndroid Build Coastguard Worker 2709*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 
2710*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2711*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 2712*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 2713*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2714*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 2715*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax())) 2716*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2717*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin())) 2718*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2719*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 2720*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c], 2721*4bdc9457SAndroid Build Coastguard Worker double(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point), 2722*4bdc9457SAndroid Build Coastguard Worker 0.9) 2723*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2724*4bdc9457SAndroid Build Coastguard Worker } 2725*4bdc9457SAndroid Build Coastguard Worker } 2726*4bdc9457SAndroid Build Coastguard Worker } 2727*4bdc9457SAndroid Build Coastguard Worker } 2728*4bdc9457SAndroid Build Coastguard 
Worker } 2729*4bdc9457SAndroid Build Coastguard Worker } 2730*4bdc9457SAndroid Build Coastguard Worker } 2731*4bdc9457SAndroid Build Coastguard Worker TestSetupNHWCxF16()2732*4bdc9457SAndroid Build Coastguard Worker void TestSetupNHWCxF16() const { 2733*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 2734*4bdc9457SAndroid Build Coastguard Worker 2735*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(depthwise_layout()); 2736*4bdc9457SAndroid Build Coastguard Worker 2737*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 2738*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 2739*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 2740*4bdc9457SAndroid Build Coastguard Worker 2741*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max( 2742*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()), 2743*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels()))); 2744*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 2745*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(groups() * group_output_channels()); 2746*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(std::max( 2747*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()), 2748*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * 
group_output_channels()))); 2749*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 2750*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels()); 2751*4bdc9457SAndroid Build Coastguard Worker 2752*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 2753*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 2754*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 2755*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 2756*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 2757*4bdc9457SAndroid Build Coastguard Worker 2758*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 
2759*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2760*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2761*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2762*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2763*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2764*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2765*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2766*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]); 2767*4bdc9457SAndroid Build Coastguard Worker } 2768*4bdc9457SAndroid Build Coastguard Worker } 2769*4bdc9457SAndroid Build Coastguard Worker } 2770*4bdc9457SAndroid Build Coastguard Worker } 2771*4bdc9457SAndroid Build Coastguard Worker } 2772*4bdc9457SAndroid Build Coastguard Worker } else { 2773*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 2774*4bdc9457SAndroid Build Coastguard Worker } 2775*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2776*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2777*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2778*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2779*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2780*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 2781*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2782*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * 
subsampling_width() + kx * dilation_width() - padding_left(); 2783*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 2784*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2785*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2786*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 2787*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 2788*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) * 2789*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]); 2790*4bdc9457SAndroid Build Coastguard Worker } 2791*4bdc9457SAndroid Build Coastguard Worker } 2792*4bdc9457SAndroid Build Coastguard Worker } 2793*4bdc9457SAndroid Build Coastguard Worker } 2794*4bdc9457SAndroid Build Coastguard Worker } 2795*4bdc9457SAndroid Build Coastguard Worker } 2796*4bdc9457SAndroid Build Coastguard Worker } 2797*4bdc9457SAndroid Build Coastguard Worker } 2798*4bdc9457SAndroid Build Coastguard Worker } 2799*4bdc9457SAndroid Build Coastguard Worker } 2800*4bdc9457SAndroid Build Coastguard Worker 2801*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 
2802*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 2803*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 2804*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min; 2805*4bdc9457SAndroid Build Coastguard Worker const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin()))); 2806*4bdc9457SAndroid Build Coastguard Worker const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax()))); 2807*4bdc9457SAndroid Build Coastguard Worker const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min; 2808*4bdc9457SAndroid Build Coastguard Worker const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max; 2809*4bdc9457SAndroid Build Coastguard Worker 2810*4bdc9457SAndroid Build Coastguard Worker for (float& output_value : output_ref) { 2811*4bdc9457SAndroid Build Coastguard Worker output_value = std::min(std::max(output_value, output_min), output_max); 2812*4bdc9457SAndroid Build Coastguard Worker } 2813*4bdc9457SAndroid Build Coastguard Worker 2814*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Convolution operator once. 
2815*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 2816*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 2817*4bdc9457SAndroid Build Coastguard Worker 2818*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_f16( 2819*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 2820*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 2821*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 2822*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 2823*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 2824*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 2825*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 2826*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 2827*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &convolution_op); 2828*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 2829*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 2830*4bdc9457SAndroid Build Coastguard Worker } 2831*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 2832*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 2833*4bdc9457SAndroid Build Coastguard Worker 2834*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
2835*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 2836*4bdc9457SAndroid Build Coastguard Worker 2837*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2838*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f16( 2839*4bdc9457SAndroid Build Coastguard Worker convolution_op, 2840*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 2841*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 2842*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 2843*4bdc9457SAndroid Build Coastguard Worker 2844*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2845*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 2846*4bdc9457SAndroid Build Coastguard Worker 2847*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 
2848*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2849*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 2850*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 2851*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2852*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 2853*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_min) 2854*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2855*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_max) 2856*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2857*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), std::max(1.0e-4f, std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]) * 1.0e-2f)) 2858*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2859*4bdc9457SAndroid Build Coastguard Worker } 2860*4bdc9457SAndroid Build Coastguard Worker } 2861*4bdc9457SAndroid Build Coastguard Worker } 2862*4bdc9457SAndroid Build Coastguard Worker } 2863*4bdc9457SAndroid Build Coastguard Worker } 2864*4bdc9457SAndroid Build Coastguard 
Worker 2865*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 2866*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); }); 2867*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */); 2868*4bdc9457SAndroid Build Coastguard Worker 2869*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping. 2870*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2871*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2872*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 2873*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 2874*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2875*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2876*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2877*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]); 2878*4bdc9457SAndroid Build Coastguard Worker } 2879*4bdc9457SAndroid Build Coastguard Worker } 2880*4bdc9457SAndroid Build Coastguard Worker } 2881*4bdc9457SAndroid Build Coastguard Worker } 2882*4bdc9457SAndroid Build Coastguard Worker } 2883*4bdc9457SAndroid Build Coastguard Worker } else { 2884*4bdc9457SAndroid Build Coastguard Worker std::fill(next_output_ref.begin(), next_output_ref.end(), 0.0f); 2885*4bdc9457SAndroid Build Coastguard Worker } 2886*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2887*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) 
{ 2888*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 2889*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2890*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2891*4bdc9457SAndroid Build Coastguard Worker if (iy < next_input_height()) { 2892*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2893*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 2894*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width()) { 2895*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2896*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2897*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 2898*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] += 2899*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) * 2900*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]); 2901*4bdc9457SAndroid Build Coastguard Worker } 2902*4bdc9457SAndroid Build Coastguard Worker } 2903*4bdc9457SAndroid Build Coastguard Worker } 2904*4bdc9457SAndroid Build Coastguard Worker } 2905*4bdc9457SAndroid Build Coastguard Worker } 2906*4bdc9457SAndroid Build Coastguard Worker } 2907*4bdc9457SAndroid Build Coastguard Worker } 2908*4bdc9457SAndroid Build Coastguard Worker } 2909*4bdc9457SAndroid Build Coastguard 
Worker } 2910*4bdc9457SAndroid Build Coastguard Worker } 2911*4bdc9457SAndroid Build Coastguard Worker for (float& value : next_output_ref) { 2912*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 2913*4bdc9457SAndroid Build Coastguard Worker } 2914*4bdc9457SAndroid Build Coastguard Worker 2915*4bdc9457SAndroid Build Coastguard Worker // Setup and run Convolution operator the second time, and destroy the operator. 2916*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2917*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f16( 2918*4bdc9457SAndroid Build Coastguard Worker convolution_op, 2919*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 2920*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 2921*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 2922*4bdc9457SAndroid Build Coastguard Worker 2923*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 2924*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 2925*4bdc9457SAndroid Build Coastguard Worker 2926*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 
2927*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 2928*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 2929*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 2930*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2931*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 2932*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_min) 2933*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2934*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_max) 2935*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2936*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c], fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), std::max(1.0e-4f, std::abs(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c]) * 1.0e-2f)) 2937*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 2938*4bdc9457SAndroid Build Coastguard Worker } 2939*4bdc9457SAndroid Build Coastguard Worker } 2940*4bdc9457SAndroid Build Coastguard Worker } 2941*4bdc9457SAndroid Build Coastguard Worker } 
2942*4bdc9457SAndroid Build Coastguard Worker } 2943*4bdc9457SAndroid Build Coastguard Worker } 2944*4bdc9457SAndroid Build Coastguard Worker } 2945*4bdc9457SAndroid Build Coastguard Worker TestSetupNHWCxF32()2946*4bdc9457SAndroid Build Coastguard Worker void TestSetupNHWCxF32() const { 2947*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default); 2948*4bdc9457SAndroid Build Coastguard Worker 2949*4bdc9457SAndroid Build Coastguard Worker ASSERT_FALSE(depthwise_layout()); 2950*4bdc9457SAndroid Build Coastguard Worker 2951*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device; 2952*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device()); 2953*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f); 2954*4bdc9457SAndroid Build Coastguard Worker 2955*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max( 2956*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()), 2957*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels()))); 2958*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels()); 2959*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(groups() * group_output_channels()); 2960*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(std::max( 2961*4bdc9457SAndroid Build Coastguard Worker batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()), 2962*4bdc9457SAndroid Build Coastguard Worker next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + 
groups() * group_output_channels()))); 2963*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels()); 2964*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels()); 2965*4bdc9457SAndroid Build Coastguard Worker 2966*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) { 2967*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 2968*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); 2969*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); 2970*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 2971*4bdc9457SAndroid Build Coastguard Worker 2972*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping. 
2973*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 2974*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2975*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2976*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2977*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2978*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 2979*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] = 2980*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 2981*4bdc9457SAndroid Build Coastguard Worker } 2982*4bdc9457SAndroid Build Coastguard Worker } 2983*4bdc9457SAndroid Build Coastguard Worker } 2984*4bdc9457SAndroid Build Coastguard Worker } 2985*4bdc9457SAndroid Build Coastguard Worker } 2986*4bdc9457SAndroid Build Coastguard Worker } else { 2987*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f); 2988*4bdc9457SAndroid Build Coastguard Worker } 2989*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 2990*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) { 2991*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) { 2992*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 2993*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 2994*4bdc9457SAndroid Build Coastguard Worker if (iy < input_height()) { 2995*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 2996*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * 
dilation_width() - padding_left(); 2997*4bdc9457SAndroid Build Coastguard Worker if (ix < input_width()) { 2998*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 2999*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 3000*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 3001*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] += 3002*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic] * 3003*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]; 3004*4bdc9457SAndroid Build Coastguard Worker } 3005*4bdc9457SAndroid Build Coastguard Worker } 3006*4bdc9457SAndroid Build Coastguard Worker } 3007*4bdc9457SAndroid Build Coastguard Worker } 3008*4bdc9457SAndroid Build Coastguard Worker } 3009*4bdc9457SAndroid Build Coastguard Worker } 3010*4bdc9457SAndroid Build Coastguard Worker } 3011*4bdc9457SAndroid Build Coastguard Worker } 3012*4bdc9457SAndroid Build Coastguard Worker } 3013*4bdc9457SAndroid Build Coastguard Worker } 3014*4bdc9457SAndroid Build Coastguard Worker 3015*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters. 
3016*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend()); 3017*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend()); 3018*4bdc9457SAndroid Build Coastguard Worker 3019*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin()); 3020*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax()); 3021*4bdc9457SAndroid Build Coastguard Worker 3022*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results. 3023*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) { 3024*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 3025*4bdc9457SAndroid Build Coastguard Worker } 3026*4bdc9457SAndroid Build Coastguard Worker 3027*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Convolution operator once. 
3028*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */)); 3029*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t convolution_op = nullptr; 3030*4bdc9457SAndroid Build Coastguard Worker 3031*4bdc9457SAndroid Build Coastguard Worker xnn_status status = xnn_create_convolution2d_nhwc_f32( 3032*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(), 3033*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), 3034*4bdc9457SAndroid Build Coastguard Worker subsampling_height(), subsampling_width(), 3035*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), 3036*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(), 3037*4bdc9457SAndroid Build Coastguard Worker input_channel_stride(), output_channel_stride(), 3038*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr, 3039*4bdc9457SAndroid Build Coastguard Worker output_min, output_max, 3040*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &convolution_op); 3041*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) { 3042*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP(); 3043*4bdc9457SAndroid Build Coastguard Worker } 3044*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status); 3045*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, convolution_op); 3046*4bdc9457SAndroid Build Coastguard Worker 3047*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete convolution_op. 
3048*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator); 3049*4bdc9457SAndroid Build Coastguard Worker 3050*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 3051*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f32( 3052*4bdc9457SAndroid Build Coastguard Worker convolution_op, 3053*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(), 3054*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 3055*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 3056*4bdc9457SAndroid Build Coastguard Worker 3057*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 3058*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 3059*4bdc9457SAndroid Build Coastguard Worker 3060*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run. 
3061*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) { 3062*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) { 3063*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) { 3064*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 3065*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 3066*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_min) 3067*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 3068*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_max) 3069*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 3070*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 3071*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], 3072*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], 3073*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c])) 3074*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 3075*4bdc9457SAndroid Build Coastguard Worker } 3076*4bdc9457SAndroid Build Coastguard Worker } 3077*4bdc9457SAndroid Build Coastguard Worker } 3078*4bdc9457SAndroid Build Coastguard Worker } 3079*4bdc9457SAndroid Build Coastguard 
Worker } 3080*4bdc9457SAndroid Build Coastguard Worker 3081*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run. 3082*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); 3083*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf("")); 3084*4bdc9457SAndroid Build Coastguard Worker 3085*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping. 3086*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) { 3087*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 3088*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 3089*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 3090*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 3091*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 3092*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] = 3093*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc]; 3094*4bdc9457SAndroid Build Coastguard Worker } 3095*4bdc9457SAndroid Build Coastguard Worker } 3096*4bdc9457SAndroid Build Coastguard Worker } 3097*4bdc9457SAndroid Build Coastguard Worker } 3098*4bdc9457SAndroid Build Coastguard Worker } 3099*4bdc9457SAndroid Build Coastguard Worker } else { 3100*4bdc9457SAndroid Build Coastguard Worker std::fill(next_output_ref.begin(), next_output_ref.end(), 0.0f); 3101*4bdc9457SAndroid Build Coastguard Worker } 3102*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 3103*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) { 
3104*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) { 3105*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) { 3106*4bdc9457SAndroid Build Coastguard Worker const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top(); 3107*4bdc9457SAndroid Build Coastguard Worker if (iy < next_input_height()) { 3108*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) { 3109*4bdc9457SAndroid Build Coastguard Worker const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left(); 3110*4bdc9457SAndroid Build Coastguard Worker if (ix < next_input_width()) { 3111*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 3112*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) { 3113*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) { 3114*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] += 3115*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic] * 3116*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]; 3117*4bdc9457SAndroid Build Coastguard Worker } 3118*4bdc9457SAndroid Build Coastguard Worker } 3119*4bdc9457SAndroid Build Coastguard Worker } 3120*4bdc9457SAndroid Build Coastguard Worker } 3121*4bdc9457SAndroid Build Coastguard Worker } 3122*4bdc9457SAndroid Build Coastguard Worker } 3123*4bdc9457SAndroid Build Coastguard Worker } 3124*4bdc9457SAndroid Build Coastguard Worker } 3125*4bdc9457SAndroid Build Coastguard Worker } 3126*4bdc9457SAndroid Build Coastguard Worker } 
3127*4bdc9457SAndroid Build Coastguard Worker for (float& value : next_output_ref) { 3128*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min); 3129*4bdc9457SAndroid Build Coastguard Worker } 3130*4bdc9457SAndroid Build Coastguard Worker 3131*4bdc9457SAndroid Build Coastguard Worker // Setup and run Convolution operator the second time, and destroy the operator. 3132*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 3133*4bdc9457SAndroid Build Coastguard Worker xnn_setup_convolution2d_nhwc_f32( 3134*4bdc9457SAndroid Build Coastguard Worker convolution_op, 3135*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(), 3136*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(), 3137*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */)); 3138*4bdc9457SAndroid Build Coastguard Worker 3139*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, 3140*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(convolution_op, nullptr /* thread pool */)); 3141*4bdc9457SAndroid Build Coastguard Worker 3142*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run. 
3143*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) { 3144*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) { 3145*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) { 3146*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) { 3147*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) { 3148*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_min) 3149*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 3150*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_max) 3151*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 3152*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR( 3153*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c], 3154*4bdc9457SAndroid Build Coastguard Worker output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c], 3155*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c])) 3156*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c; 3157*4bdc9457SAndroid Build Coastguard Worker } 3158*4bdc9457SAndroid Build Coastguard Worker } 3159*4bdc9457SAndroid Build Coastguard Worker } 
3160*4bdc9457SAndroid Build Coastguard Worker } 3161*4bdc9457SAndroid Build Coastguard Worker } 3162*4bdc9457SAndroid Build Coastguard Worker } 3163*4bdc9457SAndroid Build Coastguard Worker } 3164*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(const xnn_weights_cache & weights_cache,size_t old_size)3165*4bdc9457SAndroid Build Coastguard Worker void VerifyWeightsCache(const xnn_weights_cache &weights_cache, size_t old_size) const { 3166*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_cache.cache.hits, 1); 3167*4bdc9457SAndroid Build Coastguard Worker // Ensure that we did not write more weights to the cache because it was a 3168*4bdc9457SAndroid Build Coastguard Worker // cache hit. 3169*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(old_size, weights_cache.cache.weights.size); 3170*4bdc9457SAndroid Build Coastguard Worker }; 3171*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCacheUnused(const xnn_weights_cache & weights_cache)3172*4bdc9457SAndroid Build Coastguard Worker void VerifyWeightsCacheUnused(const xnn_weights_cache &weights_cache) const { 3173*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_cache.cache.hits, 0); 3174*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(0, weights_cache.cache.weights.size); 3175*4bdc9457SAndroid Build Coastguard Worker } 3176*4bdc9457SAndroid Build Coastguard Worker IsSpmm()3177*4bdc9457SAndroid Build Coastguard Worker bool IsSpmm() const { 3178*4bdc9457SAndroid Build Coastguard Worker const bool is_1x1 = kernel_width() == 1 && kernel_height() == 1 && 3179*4bdc9457SAndroid Build Coastguard Worker subsampling_height() == 1 && subsampling_width() == 1; 3180*4bdc9457SAndroid Build Coastguard Worker const bool any_padding = (padding_left() | padding_top() | padding_right() | padding_bottom()) != 0; 3181*4bdc9457SAndroid Build Coastguard Worker return is_1x1 && !any_padding && !force_nhwc_input() && groups() == 1; 3182*4bdc9457SAndroid Build Coastguard Worker } 3183*4bdc9457SAndroid 
Build Coastguard Worker 3184*4bdc9457SAndroid Build Coastguard Worker private: 3185*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0}; 3186*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0}; 3187*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0}; 3188*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0}; 3189*4bdc9457SAndroid Build Coastguard Worker bool padding_tf_same_{false}; 3190*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1}; 3191*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1}; 3192*4bdc9457SAndroid Build Coastguard Worker uint32_t groups_{1}; 3193*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels_{1}; 3194*4bdc9457SAndroid Build Coastguard Worker size_t input_channel_stride_{0}; 3195*4bdc9457SAndroid Build Coastguard Worker size_t group_output_channels_{1}; 3196*4bdc9457SAndroid Build Coastguard Worker size_t output_channel_stride_{0}; 3197*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1}; 3198*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height_{1}; 3199*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_width_{1}; 3200*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_height_{1}; 3201*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_width_{1}; 3202*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_height_{1}; 3203*4bdc9457SAndroid Build Coastguard Worker uint32_t subsampling_width_{1}; 3204*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0}; 3205*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0}; 3206*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0}; 3207*4bdc9457SAndroid Build Coastguard Worker float sparsity_{0.0f}; 3208*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0}; 3209*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255}; 3210*4bdc9457SAndroid Build Coastguard Worker bool depthwise_layout_{false}; 
3211*4bdc9457SAndroid Build Coastguard Worker bool force_nhwc_input_{false}; 3212*4bdc9457SAndroid Build Coastguard Worker bool has_bias_{true}; 3213*4bdc9457SAndroid Build Coastguard Worker WeightsType weights_type_{WeightsType::Default}; 3214*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1}; 3215*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT 3216*4bdc9457SAndroid Build Coastguard Worker bool use_jit_{false}; 3217*4bdc9457SAndroid Build Coastguard Worker #endif 3218*4bdc9457SAndroid Build Coastguard Worker bool use_weights_cache_{false}; 3219*4bdc9457SAndroid Build Coastguard Worker }; 3220