xref: /aosp_15_r20/external/XNNPACK/test/convolution-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
14*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
15*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
18*4bdc9457SAndroid Build Coastguard Worker #include <limits>
19*4bdc9457SAndroid Build Coastguard Worker #include <random>
20*4bdc9457SAndroid Build Coastguard Worker #include <vector>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker #include "convolution-test-helpers.h"
23*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
24*4bdc9457SAndroid Build Coastguard Worker 
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
26*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h>
27*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
28*4bdc9457SAndroid Build Coastguard Worker 
29*4bdc9457SAndroid Build Coastguard Worker 
30*4bdc9457SAndroid Build Coastguard Worker class ConvolutionOperatorTester {
31*4bdc9457SAndroid Build Coastguard Worker  public:
32*4bdc9457SAndroid Build Coastguard Worker   enum class WeightsType {
33*4bdc9457SAndroid Build Coastguard Worker     Default,
34*4bdc9457SAndroid Build Coastguard Worker     FP32,
35*4bdc9457SAndroid Build Coastguard Worker   };
36*4bdc9457SAndroid Build Coastguard Worker 
padding_tf_same(bool padding_same)37*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_tf_same(bool padding_same) {
38*4bdc9457SAndroid Build Coastguard Worker     if (padding_same) {
39*4bdc9457SAndroid Build Coastguard Worker       assert(padding_top() == 0);
40*4bdc9457SAndroid Build Coastguard Worker       assert(padding_left() == 0);
41*4bdc9457SAndroid Build Coastguard Worker       assert(padding_bottom() == 0);
42*4bdc9457SAndroid Build Coastguard Worker       assert(padding_right() == 0);
43*4bdc9457SAndroid Build Coastguard Worker     }
44*4bdc9457SAndroid Build Coastguard Worker     this->padding_tf_same_ = padding_same;
45*4bdc9457SAndroid Build Coastguard Worker     return *this;
46*4bdc9457SAndroid Build Coastguard Worker   }
47*4bdc9457SAndroid Build Coastguard Worker 
padding_tf_same()48*4bdc9457SAndroid Build Coastguard Worker   inline bool padding_tf_same() const {
49*4bdc9457SAndroid Build Coastguard Worker     return this->padding_tf_same_;
50*4bdc9457SAndroid Build Coastguard Worker   }
51*4bdc9457SAndroid Build Coastguard Worker 
padding(uint32_t padding)52*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding(uint32_t padding) {
53*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
54*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding;
55*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding;
56*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding;
57*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding;
58*4bdc9457SAndroid Build Coastguard Worker     return *this;
59*4bdc9457SAndroid Build Coastguard Worker   }
60*4bdc9457SAndroid Build Coastguard Worker 
padding(uint32_t padding_height,uint32_t padding_width)61*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding(uint32_t padding_height, uint32_t padding_width) {
62*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
63*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding_height;
64*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding_width;
65*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding_height;
66*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding_width;
67*4bdc9457SAndroid Build Coastguard Worker     return *this;
68*4bdc9457SAndroid Build Coastguard Worker   }
69*4bdc9457SAndroid Build Coastguard Worker 
padding_height(uint32_t padding_height)70*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_height(uint32_t padding_height) {
71*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
72*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding_height;
73*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding_height;
74*4bdc9457SAndroid Build Coastguard Worker     return *this;
75*4bdc9457SAndroid Build Coastguard Worker   }
76*4bdc9457SAndroid Build Coastguard Worker 
padding_width(uint32_t padding_width)77*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_width(uint32_t padding_width) {
78*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
79*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding_width;
80*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding_width;
81*4bdc9457SAndroid Build Coastguard Worker     return *this;
82*4bdc9457SAndroid Build Coastguard Worker   }
83*4bdc9457SAndroid Build Coastguard Worker 
padding_top(uint32_t padding_top)84*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_top(uint32_t padding_top) {
85*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
86*4bdc9457SAndroid Build Coastguard Worker     this->padding_top_ = padding_top;
87*4bdc9457SAndroid Build Coastguard Worker     return *this;
88*4bdc9457SAndroid Build Coastguard Worker   }
89*4bdc9457SAndroid Build Coastguard Worker 
padding_top()90*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_top() const {
91*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
92*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_height =
93*4bdc9457SAndroid Build Coastguard Worker         (output_height() - 1) * subsampling_height() + dilated_kernel_height() - input_height();
94*4bdc9457SAndroid Build Coastguard Worker       return total_padding_height / 2;
95*4bdc9457SAndroid Build Coastguard Worker     } else {
96*4bdc9457SAndroid Build Coastguard Worker       return this->padding_top_;
97*4bdc9457SAndroid Build Coastguard Worker     }
98*4bdc9457SAndroid Build Coastguard Worker   }
99*4bdc9457SAndroid Build Coastguard Worker 
padding_left(uint32_t padding_left)100*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_left(uint32_t padding_left) {
101*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
102*4bdc9457SAndroid Build Coastguard Worker     this->padding_left_ = padding_left;
103*4bdc9457SAndroid Build Coastguard Worker     return *this;
104*4bdc9457SAndroid Build Coastguard Worker   }
105*4bdc9457SAndroid Build Coastguard Worker 
padding_left()106*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_left() const {
107*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
108*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_width =
109*4bdc9457SAndroid Build Coastguard Worker         (output_width() - 1) * subsampling_width() + dilated_kernel_width() - input_width();
110*4bdc9457SAndroid Build Coastguard Worker       return total_padding_width / 2;
111*4bdc9457SAndroid Build Coastguard Worker     } else {
112*4bdc9457SAndroid Build Coastguard Worker       return this->padding_left_;
113*4bdc9457SAndroid Build Coastguard Worker     }
114*4bdc9457SAndroid Build Coastguard Worker   }
115*4bdc9457SAndroid Build Coastguard Worker 
padding_bottom(uint32_t padding_bottom)116*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_bottom(uint32_t padding_bottom) {
117*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
118*4bdc9457SAndroid Build Coastguard Worker     this->padding_bottom_ = padding_bottom;
119*4bdc9457SAndroid Build Coastguard Worker     return *this;
120*4bdc9457SAndroid Build Coastguard Worker   }
121*4bdc9457SAndroid Build Coastguard Worker 
padding_bottom()122*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_bottom() const {
123*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
124*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_height =
125*4bdc9457SAndroid Build Coastguard Worker         (output_height() - 1) * subsampling_height() + dilated_kernel_height() - input_height();
126*4bdc9457SAndroid Build Coastguard Worker       return total_padding_height - total_padding_height / 2;
127*4bdc9457SAndroid Build Coastguard Worker     } else {
128*4bdc9457SAndroid Build Coastguard Worker       return this->padding_bottom_;
129*4bdc9457SAndroid Build Coastguard Worker     }
130*4bdc9457SAndroid Build Coastguard Worker   }
131*4bdc9457SAndroid Build Coastguard Worker 
padding_right(uint32_t padding_right)132*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& padding_right(uint32_t padding_right) {
133*4bdc9457SAndroid Build Coastguard Worker     assert(!padding_tf_same());
134*4bdc9457SAndroid Build Coastguard Worker     this->padding_right_ = padding_right;
135*4bdc9457SAndroid Build Coastguard Worker     return *this;
136*4bdc9457SAndroid Build Coastguard Worker   }
137*4bdc9457SAndroid Build Coastguard Worker 
padding_right()138*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t padding_right() const {
139*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
140*4bdc9457SAndroid Build Coastguard Worker       const uint32_t total_padding_width =
141*4bdc9457SAndroid Build Coastguard Worker         (output_width() - 1) * subsampling_width() + dilated_kernel_width() - input_width();
142*4bdc9457SAndroid Build Coastguard Worker       return total_padding_width - total_padding_width / 2;
143*4bdc9457SAndroid Build Coastguard Worker     } else {
144*4bdc9457SAndroid Build Coastguard Worker       return this->padding_right_;
145*4bdc9457SAndroid Build Coastguard Worker     }
146*4bdc9457SAndroid Build Coastguard Worker   }
147*4bdc9457SAndroid Build Coastguard Worker 
input_size(uint32_t input_height,uint32_t input_width)148*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& input_size(uint32_t input_height, uint32_t input_width) {
149*4bdc9457SAndroid Build Coastguard Worker     assert(input_height >= 1);
150*4bdc9457SAndroid Build Coastguard Worker     assert(input_width >= 1);
151*4bdc9457SAndroid Build Coastguard Worker     this->input_height_ = input_height;
152*4bdc9457SAndroid Build Coastguard Worker     this->input_width_ = input_width;
153*4bdc9457SAndroid Build Coastguard Worker     return *this;
154*4bdc9457SAndroid Build Coastguard Worker   }
155*4bdc9457SAndroid Build Coastguard Worker 
input_height(uint32_t input_height)156*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& input_height(uint32_t input_height) {
157*4bdc9457SAndroid Build Coastguard Worker     assert(input_height >= 1);
158*4bdc9457SAndroid Build Coastguard Worker     this->input_height_ = input_height;
159*4bdc9457SAndroid Build Coastguard Worker     return *this;
160*4bdc9457SAndroid Build Coastguard Worker   }
161*4bdc9457SAndroid Build Coastguard Worker 
input_height()162*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t input_height() const {
163*4bdc9457SAndroid Build Coastguard Worker     return this->input_height_;
164*4bdc9457SAndroid Build Coastguard Worker   }
165*4bdc9457SAndroid Build Coastguard Worker 
input_width(uint32_t input_width)166*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& input_width(uint32_t input_width) {
167*4bdc9457SAndroid Build Coastguard Worker     assert(input_width >= 1);
168*4bdc9457SAndroid Build Coastguard Worker     this->input_width_ = input_width;
169*4bdc9457SAndroid Build Coastguard Worker     return *this;
170*4bdc9457SAndroid Build Coastguard Worker   }
171*4bdc9457SAndroid Build Coastguard Worker 
input_width()172*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t input_width() const {
173*4bdc9457SAndroid Build Coastguard Worker     return this->input_width_;
174*4bdc9457SAndroid Build Coastguard Worker   }
175*4bdc9457SAndroid Build Coastguard Worker 
groups(uint32_t groups)176*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& groups(uint32_t groups) {
177*4bdc9457SAndroid Build Coastguard Worker     assert(groups >= 1);
178*4bdc9457SAndroid Build Coastguard Worker     this->groups_ = groups;
179*4bdc9457SAndroid Build Coastguard Worker     return *this;
180*4bdc9457SAndroid Build Coastguard Worker   }
181*4bdc9457SAndroid Build Coastguard Worker 
groups()182*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t groups() const {
183*4bdc9457SAndroid Build Coastguard Worker     return this->groups_;
184*4bdc9457SAndroid Build Coastguard Worker   }
185*4bdc9457SAndroid Build Coastguard Worker 
group_input_channels(size_t group_input_channels)186*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& group_input_channels(size_t group_input_channels) {
187*4bdc9457SAndroid Build Coastguard Worker     assert(group_input_channels >= 1);
188*4bdc9457SAndroid Build Coastguard Worker     this->group_input_channels_ = group_input_channels;
189*4bdc9457SAndroid Build Coastguard Worker     return *this;
190*4bdc9457SAndroid Build Coastguard Worker   }
191*4bdc9457SAndroid Build Coastguard Worker 
group_input_channels()192*4bdc9457SAndroid Build Coastguard Worker   inline size_t group_input_channels() const {
193*4bdc9457SAndroid Build Coastguard Worker     return this->group_input_channels_;
194*4bdc9457SAndroid Build Coastguard Worker   }
195*4bdc9457SAndroid Build Coastguard Worker 
group_output_channels(size_t group_output_channels)196*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& group_output_channels(size_t group_output_channels) {
197*4bdc9457SAndroid Build Coastguard Worker     assert(group_output_channels >= 1);
198*4bdc9457SAndroid Build Coastguard Worker     this->group_output_channels_ = group_output_channels;
199*4bdc9457SAndroid Build Coastguard Worker     return *this;
200*4bdc9457SAndroid Build Coastguard Worker   }
201*4bdc9457SAndroid Build Coastguard Worker 
group_output_channels()202*4bdc9457SAndroid Build Coastguard Worker   inline size_t group_output_channels() const {
203*4bdc9457SAndroid Build Coastguard Worker     return this->group_output_channels_;
204*4bdc9457SAndroid Build Coastguard Worker   }
205*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)206*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& batch_size(size_t batch_size) {
207*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size >= 1);
208*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
209*4bdc9457SAndroid Build Coastguard Worker     return *this;
210*4bdc9457SAndroid Build Coastguard Worker   }
211*4bdc9457SAndroid Build Coastguard Worker 
batch_size()212*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
213*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
214*4bdc9457SAndroid Build Coastguard Worker   }
215*4bdc9457SAndroid Build Coastguard Worker 
kernel_size(uint32_t kernel_size)216*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& kernel_size(uint32_t kernel_size) {
217*4bdc9457SAndroid Build Coastguard Worker     assert(kernel_size >= 1);
218*4bdc9457SAndroid Build Coastguard Worker     this->kernel_height_ = kernel_size;
219*4bdc9457SAndroid Build Coastguard Worker     this->kernel_width_ = kernel_size;
220*4bdc9457SAndroid Build Coastguard Worker     return *this;
221*4bdc9457SAndroid Build Coastguard Worker   }
222*4bdc9457SAndroid Build Coastguard Worker 
kernel_size(uint32_t kernel_height,uint32_t kernel_width)223*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& kernel_size(uint32_t kernel_height, uint32_t kernel_width) {
224*4bdc9457SAndroid Build Coastguard Worker     assert(kernel_height >= 1);
225*4bdc9457SAndroid Build Coastguard Worker     assert(kernel_width >= 1);
226*4bdc9457SAndroid Build Coastguard Worker     this->kernel_height_ = kernel_height;
227*4bdc9457SAndroid Build Coastguard Worker     this->kernel_width_ = kernel_width;
228*4bdc9457SAndroid Build Coastguard Worker     return *this;
229*4bdc9457SAndroid Build Coastguard Worker   }
230*4bdc9457SAndroid Build Coastguard Worker 
kernel_height(uint32_t kernel_height)231*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& kernel_height(uint32_t kernel_height) {
232*4bdc9457SAndroid Build Coastguard Worker     assert(kernel_height >= 1);
233*4bdc9457SAndroid Build Coastguard Worker     this->kernel_height_ = kernel_height;
234*4bdc9457SAndroid Build Coastguard Worker     return *this;
235*4bdc9457SAndroid Build Coastguard Worker   }
236*4bdc9457SAndroid Build Coastguard Worker 
kernel_height()237*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t kernel_height() const {
238*4bdc9457SAndroid Build Coastguard Worker     return this->kernel_height_;
239*4bdc9457SAndroid Build Coastguard Worker   }
240*4bdc9457SAndroid Build Coastguard Worker 
kernel_width(uint32_t kernel_width)241*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& kernel_width(uint32_t kernel_width) {
242*4bdc9457SAndroid Build Coastguard Worker     assert(kernel_width >= 1);
243*4bdc9457SAndroid Build Coastguard Worker     this->kernel_width_ = kernel_width;
244*4bdc9457SAndroid Build Coastguard Worker     return *this;
245*4bdc9457SAndroid Build Coastguard Worker   }
246*4bdc9457SAndroid Build Coastguard Worker 
kernel_width()247*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t kernel_width() const {
248*4bdc9457SAndroid Build Coastguard Worker     return this->kernel_width_;
249*4bdc9457SAndroid Build Coastguard Worker   }
250*4bdc9457SAndroid Build Coastguard Worker 
dilation(uint32_t dilation)251*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& dilation(uint32_t dilation) {
252*4bdc9457SAndroid Build Coastguard Worker     assert(dilation >= 1);
253*4bdc9457SAndroid Build Coastguard Worker     this->dilation_height_ = dilation;
254*4bdc9457SAndroid Build Coastguard Worker     this->dilation_width_ = dilation;
255*4bdc9457SAndroid Build Coastguard Worker     return *this;
256*4bdc9457SAndroid Build Coastguard Worker   }
257*4bdc9457SAndroid Build Coastguard Worker 
dilation(uint32_t dilation_height,uint32_t dilation_width)258*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& dilation(uint32_t dilation_height, uint32_t dilation_width) {
259*4bdc9457SAndroid Build Coastguard Worker     assert(dilation_height >= 1);
260*4bdc9457SAndroid Build Coastguard Worker     assert(dilation_width >= 1);
261*4bdc9457SAndroid Build Coastguard Worker     this->dilation_height_ = dilation_height;
262*4bdc9457SAndroid Build Coastguard Worker     this->dilation_width_ = dilation_width;
263*4bdc9457SAndroid Build Coastguard Worker     return *this;
264*4bdc9457SAndroid Build Coastguard Worker   }
265*4bdc9457SAndroid Build Coastguard Worker 
dilation_height(uint32_t dilation_height)266*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& dilation_height(uint32_t dilation_height) {
267*4bdc9457SAndroid Build Coastguard Worker     assert(dilation_height >= 1);
268*4bdc9457SAndroid Build Coastguard Worker     this->dilation_height_ = dilation_height;
269*4bdc9457SAndroid Build Coastguard Worker     return *this;
270*4bdc9457SAndroid Build Coastguard Worker   }
271*4bdc9457SAndroid Build Coastguard Worker 
dilation_height()272*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t dilation_height() const {
273*4bdc9457SAndroid Build Coastguard Worker     return this->dilation_height_;
274*4bdc9457SAndroid Build Coastguard Worker   }
275*4bdc9457SAndroid Build Coastguard Worker 
dilation_width(uint32_t dilation_width)276*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& dilation_width(uint32_t dilation_width) {
277*4bdc9457SAndroid Build Coastguard Worker     assert(dilation_width >= 1);
278*4bdc9457SAndroid Build Coastguard Worker     this->dilation_width_ = dilation_width;
279*4bdc9457SAndroid Build Coastguard Worker     return *this;
280*4bdc9457SAndroid Build Coastguard Worker   }
281*4bdc9457SAndroid Build Coastguard Worker 
dilation_width()282*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t dilation_width() const {
283*4bdc9457SAndroid Build Coastguard Worker     return this->dilation_width_;
284*4bdc9457SAndroid Build Coastguard Worker   }
285*4bdc9457SAndroid Build Coastguard Worker 
subsampling(uint32_t subsampling)286*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& subsampling(uint32_t subsampling) {
287*4bdc9457SAndroid Build Coastguard Worker     assert(subsampling >= 1);
288*4bdc9457SAndroid Build Coastguard Worker     this->subsampling_height_ = subsampling;
289*4bdc9457SAndroid Build Coastguard Worker     this->subsampling_width_ = subsampling;
290*4bdc9457SAndroid Build Coastguard Worker     return *this;
291*4bdc9457SAndroid Build Coastguard Worker   }
292*4bdc9457SAndroid Build Coastguard Worker 
subsampling(uint32_t subsampling_height,uint32_t subsampling_width)293*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& subsampling(uint32_t subsampling_height, uint32_t subsampling_width) {
294*4bdc9457SAndroid Build Coastguard Worker     assert(subsampling_height >= 1);
295*4bdc9457SAndroid Build Coastguard Worker     assert(subsampling_width >= 1);
296*4bdc9457SAndroid Build Coastguard Worker     this->subsampling_height_ = subsampling_height;
297*4bdc9457SAndroid Build Coastguard Worker     this->subsampling_width_ = subsampling_width;
298*4bdc9457SAndroid Build Coastguard Worker     return *this;
299*4bdc9457SAndroid Build Coastguard Worker   }
300*4bdc9457SAndroid Build Coastguard Worker 
subsampling_height(uint32_t subsampling_height)301*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& subsampling_height(uint32_t subsampling_height) {
302*4bdc9457SAndroid Build Coastguard Worker     assert(subsampling_height >= 1);
303*4bdc9457SAndroid Build Coastguard Worker     this->subsampling_height_ = subsampling_height;
304*4bdc9457SAndroid Build Coastguard Worker     return *this;
305*4bdc9457SAndroid Build Coastguard Worker   }
306*4bdc9457SAndroid Build Coastguard Worker 
subsampling_height()307*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t subsampling_height() const {
308*4bdc9457SAndroid Build Coastguard Worker     return this->subsampling_height_;
309*4bdc9457SAndroid Build Coastguard Worker   }
310*4bdc9457SAndroid Build Coastguard Worker 
subsampling_width(uint32_t subsampling_width)311*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& subsampling_width(uint32_t subsampling_width) {
312*4bdc9457SAndroid Build Coastguard Worker     assert(subsampling_width >= 1);
313*4bdc9457SAndroid Build Coastguard Worker     this->subsampling_width_ = subsampling_width;
314*4bdc9457SAndroid Build Coastguard Worker     return *this;
315*4bdc9457SAndroid Build Coastguard Worker   }
316*4bdc9457SAndroid Build Coastguard Worker 
subsampling_width()317*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t subsampling_width() const {
318*4bdc9457SAndroid Build Coastguard Worker     return this->subsampling_width_;
319*4bdc9457SAndroid Build Coastguard Worker   }
320*4bdc9457SAndroid Build Coastguard Worker 
input_channel_stride(size_t input_channel_stride)321*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& input_channel_stride(size_t input_channel_stride) {
322*4bdc9457SAndroid Build Coastguard Worker     assert(input_channel_stride >= 1);
323*4bdc9457SAndroid Build Coastguard Worker     this->input_channel_stride_ = input_channel_stride;
324*4bdc9457SAndroid Build Coastguard Worker     return *this;
325*4bdc9457SAndroid Build Coastguard Worker   }
326*4bdc9457SAndroid Build Coastguard Worker 
input_channel_stride()327*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_channel_stride() const {
328*4bdc9457SAndroid Build Coastguard Worker     if (this->input_channel_stride_ == 0) {
329*4bdc9457SAndroid Build Coastguard Worker       return group_input_channels() * groups();
330*4bdc9457SAndroid Build Coastguard Worker     } else {
331*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_channel_stride_ >= group_input_channels() * groups());
332*4bdc9457SAndroid Build Coastguard Worker       return this->input_channel_stride_;
333*4bdc9457SAndroid Build Coastguard Worker     }
334*4bdc9457SAndroid Build Coastguard Worker   }
335*4bdc9457SAndroid Build Coastguard Worker 
output_channel_stride(size_t output_channel_stride)336*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& output_channel_stride(size_t output_channel_stride) {
337*4bdc9457SAndroid Build Coastguard Worker     assert(output_channel_stride >= 1);
338*4bdc9457SAndroid Build Coastguard Worker     this->output_channel_stride_ = output_channel_stride;
339*4bdc9457SAndroid Build Coastguard Worker     return *this;
340*4bdc9457SAndroid Build Coastguard Worker   }
341*4bdc9457SAndroid Build Coastguard Worker 
output_channel_stride()342*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_channel_stride() const {
343*4bdc9457SAndroid Build Coastguard Worker     if (this->output_channel_stride_ == 0) {
344*4bdc9457SAndroid Build Coastguard Worker       return group_output_channels() * groups();
345*4bdc9457SAndroid Build Coastguard Worker     } else {
346*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_channel_stride_ >= group_output_channels() * groups());
347*4bdc9457SAndroid Build Coastguard Worker       return this->output_channel_stride_;
348*4bdc9457SAndroid Build Coastguard Worker     }
349*4bdc9457SAndroid Build Coastguard Worker   }
350*4bdc9457SAndroid Build Coastguard Worker 
dilated_kernel_height()351*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t dilated_kernel_height() const {
352*4bdc9457SAndroid Build Coastguard Worker     return (kernel_height() - 1) * dilation_height() + 1;
353*4bdc9457SAndroid Build Coastguard Worker   }
354*4bdc9457SAndroid Build Coastguard Worker 
dilated_kernel_width()355*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t dilated_kernel_width() const {
356*4bdc9457SAndroid Build Coastguard Worker     return (kernel_width() - 1) * dilation_width() + 1;
357*4bdc9457SAndroid Build Coastguard Worker   }
358*4bdc9457SAndroid Build Coastguard Worker 
output_height()359*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_height() const {
360*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
361*4bdc9457SAndroid Build Coastguard Worker       return (input_height() + subsampling_height() - 1) / subsampling_height();
362*4bdc9457SAndroid Build Coastguard Worker     } else {
363*4bdc9457SAndroid Build Coastguard Worker       const size_t padded_input_height = padding_top() + input_height() + padding_bottom();
364*4bdc9457SAndroid Build Coastguard Worker       if (padded_input_height <= dilated_kernel_height()) {
365*4bdc9457SAndroid Build Coastguard Worker         return 1;
366*4bdc9457SAndroid Build Coastguard Worker       } else {
367*4bdc9457SAndroid Build Coastguard Worker         return (padded_input_height - dilated_kernel_height()) / subsampling_height() + 1;
368*4bdc9457SAndroid Build Coastguard Worker       }
369*4bdc9457SAndroid Build Coastguard Worker     }
370*4bdc9457SAndroid Build Coastguard Worker   }
371*4bdc9457SAndroid Build Coastguard Worker 
output_width()372*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_width() const {
373*4bdc9457SAndroid Build Coastguard Worker     if (padding_tf_same()) {
374*4bdc9457SAndroid Build Coastguard Worker       return (input_width() + subsampling_width() - 1) / subsampling_width();
375*4bdc9457SAndroid Build Coastguard Worker     } else {
376*4bdc9457SAndroid Build Coastguard Worker       const size_t padded_input_width = padding_left() + input_width() + padding_right();
377*4bdc9457SAndroid Build Coastguard Worker       if (padded_input_width <= dilated_kernel_width()) {
378*4bdc9457SAndroid Build Coastguard Worker         return 1;
379*4bdc9457SAndroid Build Coastguard Worker       } else {
380*4bdc9457SAndroid Build Coastguard Worker         return (padded_input_width - dilated_kernel_width()) / subsampling_width() + 1;
381*4bdc9457SAndroid Build Coastguard Worker       }
382*4bdc9457SAndroid Build Coastguard Worker     }
383*4bdc9457SAndroid Build Coastguard Worker   }
384*4bdc9457SAndroid Build Coastguard Worker 
next_input_size(uint32_t next_input_height,uint32_t next_input_width)385*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) {
386*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_height >= 1);
387*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_width >= 1);
388*4bdc9457SAndroid Build Coastguard Worker     this->next_input_height_ = next_input_height;
389*4bdc9457SAndroid Build Coastguard Worker     this->next_input_width_ = next_input_width;
390*4bdc9457SAndroid Build Coastguard Worker     return *this;
391*4bdc9457SAndroid Build Coastguard Worker   }
392*4bdc9457SAndroid Build Coastguard Worker 
next_input_height(uint32_t next_input_height)393*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& next_input_height(uint32_t next_input_height) {
394*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_height >= 1);
395*4bdc9457SAndroid Build Coastguard Worker     this->next_input_height_ = next_input_height;
396*4bdc9457SAndroid Build Coastguard Worker     return *this;
397*4bdc9457SAndroid Build Coastguard Worker   }
398*4bdc9457SAndroid Build Coastguard Worker 
next_input_height()399*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t next_input_height() const {
400*4bdc9457SAndroid Build Coastguard Worker     if (this->next_input_height_ == 0) {
401*4bdc9457SAndroid Build Coastguard Worker       return input_height();
402*4bdc9457SAndroid Build Coastguard Worker     } else {
403*4bdc9457SAndroid Build Coastguard Worker       return this->next_input_height_;
404*4bdc9457SAndroid Build Coastguard Worker     }
405*4bdc9457SAndroid Build Coastguard Worker   }
406*4bdc9457SAndroid Build Coastguard Worker 
next_input_width(uint32_t next_input_width)407*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& next_input_width(uint32_t next_input_width) {
408*4bdc9457SAndroid Build Coastguard Worker     assert(next_input_width >= 1);
409*4bdc9457SAndroid Build Coastguard Worker     this->next_input_width_ = next_input_width;
410*4bdc9457SAndroid Build Coastguard Worker     return *this;
411*4bdc9457SAndroid Build Coastguard Worker   }
412*4bdc9457SAndroid Build Coastguard Worker 
next_input_width()413*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t next_input_width() const {
414*4bdc9457SAndroid Build Coastguard Worker     if (this->next_input_width_ == 0) {
415*4bdc9457SAndroid Build Coastguard Worker       return input_width();
416*4bdc9457SAndroid Build Coastguard Worker     } else {
417*4bdc9457SAndroid Build Coastguard Worker       return this->next_input_width_;
418*4bdc9457SAndroid Build Coastguard Worker     }
419*4bdc9457SAndroid Build Coastguard Worker   }
420*4bdc9457SAndroid Build Coastguard Worker 
next_output_height()421*4bdc9457SAndroid Build Coastguard Worker   inline size_t next_output_height() const {
422*4bdc9457SAndroid Build Coastguard Worker     const size_t padded_input_height = padding_top() + next_input_height() + padding_bottom();
423*4bdc9457SAndroid Build Coastguard Worker     if (padded_input_height <= dilated_kernel_height()) {
424*4bdc9457SAndroid Build Coastguard Worker       return 1;
425*4bdc9457SAndroid Build Coastguard Worker     } else {
426*4bdc9457SAndroid Build Coastguard Worker       return (padded_input_height - dilated_kernel_height()) / subsampling_height() + 1;
427*4bdc9457SAndroid Build Coastguard Worker     }
428*4bdc9457SAndroid Build Coastguard Worker   }
429*4bdc9457SAndroid Build Coastguard Worker 
next_output_width()430*4bdc9457SAndroid Build Coastguard Worker   inline size_t next_output_width() const {
431*4bdc9457SAndroid Build Coastguard Worker     const size_t padded_input_width = padding_left() + next_input_width() + padding_right();
432*4bdc9457SAndroid Build Coastguard Worker     if (padded_input_width <= dilated_kernel_width()) {
433*4bdc9457SAndroid Build Coastguard Worker       return 1;
434*4bdc9457SAndroid Build Coastguard Worker     } else {
435*4bdc9457SAndroid Build Coastguard Worker       return (padded_input_width - dilated_kernel_width()) / subsampling_width() + 1;
436*4bdc9457SAndroid Build Coastguard Worker     }
437*4bdc9457SAndroid Build Coastguard Worker   }
438*4bdc9457SAndroid Build Coastguard Worker 
next_batch_size(size_t next_batch_size)439*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& next_batch_size(size_t next_batch_size) {
440*4bdc9457SAndroid Build Coastguard Worker     assert(next_batch_size >= 1);
441*4bdc9457SAndroid Build Coastguard Worker     this->next_batch_size_ = next_batch_size;
442*4bdc9457SAndroid Build Coastguard Worker     return *this;
443*4bdc9457SAndroid Build Coastguard Worker   }
444*4bdc9457SAndroid Build Coastguard Worker 
next_batch_size()445*4bdc9457SAndroid Build Coastguard Worker   inline size_t next_batch_size() const {
446*4bdc9457SAndroid Build Coastguard Worker     if (this->next_batch_size_ == 0) {
447*4bdc9457SAndroid Build Coastguard Worker       return batch_size();
448*4bdc9457SAndroid Build Coastguard Worker     } else {
449*4bdc9457SAndroid Build Coastguard Worker       return this->next_batch_size_;
450*4bdc9457SAndroid Build Coastguard Worker     }
451*4bdc9457SAndroid Build Coastguard Worker   }
452*4bdc9457SAndroid Build Coastguard Worker 
sparsity(float sparsity)453*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& sparsity(float sparsity) {
454*4bdc9457SAndroid Build Coastguard Worker     this->sparsity_ = sparsity;
455*4bdc9457SAndroid Build Coastguard Worker     return *this;
456*4bdc9457SAndroid Build Coastguard Worker   }
457*4bdc9457SAndroid Build Coastguard Worker 
sparsity()458*4bdc9457SAndroid Build Coastguard Worker   inline float sparsity() const {
459*4bdc9457SAndroid Build Coastguard Worker     return this->sparsity_;
460*4bdc9457SAndroid Build Coastguard Worker   }
461*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)462*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& qmin(uint8_t qmin) {
463*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
464*4bdc9457SAndroid Build Coastguard Worker     return *this;
465*4bdc9457SAndroid Build Coastguard Worker   }
466*4bdc9457SAndroid Build Coastguard Worker 
qmin()467*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
468*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
469*4bdc9457SAndroid Build Coastguard Worker   }
470*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)471*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& qmax(uint8_t qmax) {
472*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
473*4bdc9457SAndroid Build Coastguard Worker     return *this;
474*4bdc9457SAndroid Build Coastguard Worker   }
475*4bdc9457SAndroid Build Coastguard Worker 
qmax()476*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
477*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
478*4bdc9457SAndroid Build Coastguard Worker   }
479*4bdc9457SAndroid Build Coastguard Worker 
force_nhwc_input(bool force_nhwc_input)480*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& force_nhwc_input(bool force_nhwc_input) {
481*4bdc9457SAndroid Build Coastguard Worker     this->force_nhwc_input_ = force_nhwc_input;
482*4bdc9457SAndroid Build Coastguard Worker     return *this;
483*4bdc9457SAndroid Build Coastguard Worker   }
484*4bdc9457SAndroid Build Coastguard Worker 
force_nhwc_input()485*4bdc9457SAndroid Build Coastguard Worker   inline bool force_nhwc_input() const {
486*4bdc9457SAndroid Build Coastguard Worker     return this->force_nhwc_input_;
487*4bdc9457SAndroid Build Coastguard Worker   }
488*4bdc9457SAndroid Build Coastguard Worker 
depthwise_layout(bool depthwise_layout)489*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& depthwise_layout(bool depthwise_layout) {
490*4bdc9457SAndroid Build Coastguard Worker     this->depthwise_layout_ = depthwise_layout;
491*4bdc9457SAndroid Build Coastguard Worker     return *this;
492*4bdc9457SAndroid Build Coastguard Worker   }
493*4bdc9457SAndroid Build Coastguard Worker 
depthwise_layout()494*4bdc9457SAndroid Build Coastguard Worker   inline bool depthwise_layout() const {
495*4bdc9457SAndroid Build Coastguard Worker     return this->depthwise_layout_;
496*4bdc9457SAndroid Build Coastguard Worker   }
497*4bdc9457SAndroid Build Coastguard Worker 
has_bias(bool has_bias)498*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& has_bias(bool has_bias) {
499*4bdc9457SAndroid Build Coastguard Worker     this->has_bias_ = has_bias;
500*4bdc9457SAndroid Build Coastguard Worker     return *this;
501*4bdc9457SAndroid Build Coastguard Worker   }
502*4bdc9457SAndroid Build Coastguard Worker 
has_bias()503*4bdc9457SAndroid Build Coastguard Worker   inline bool has_bias() const {
504*4bdc9457SAndroid Build Coastguard Worker     return this->has_bias_;
505*4bdc9457SAndroid Build Coastguard Worker   }
506*4bdc9457SAndroid Build Coastguard Worker 
weights_type(WeightsType weights_type)507*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& weights_type(WeightsType weights_type) {
508*4bdc9457SAndroid Build Coastguard Worker     this->weights_type_ = weights_type;
509*4bdc9457SAndroid Build Coastguard Worker     return *this;
510*4bdc9457SAndroid Build Coastguard Worker   }
511*4bdc9457SAndroid Build Coastguard Worker 
weights_type()512*4bdc9457SAndroid Build Coastguard Worker   inline WeightsType weights_type() const {
513*4bdc9457SAndroid Build Coastguard Worker     return this->weights_type_;
514*4bdc9457SAndroid Build Coastguard Worker   }
515*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)516*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& iterations(size_t iterations) {
517*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
518*4bdc9457SAndroid Build Coastguard Worker     return *this;
519*4bdc9457SAndroid Build Coastguard Worker   }
520*4bdc9457SAndroid Build Coastguard Worker 
iterations()521*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
522*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
523*4bdc9457SAndroid Build Coastguard Worker   }
524*4bdc9457SAndroid Build Coastguard Worker 
#if XNN_PLATFORM_JIT
// Stores the use-JIT flag; these accessors exist only when XNN_PLATFORM_JIT
// is enabled (the flag is not read by the code visible in this section).
inline ConvolutionOperatorTester& use_jit(bool use_jit) {
  use_jit_ = use_jit;
  return *this;
}

// Returns the use-JIT flag.
inline bool use_jit() const { return use_jit_; }
#endif
535*4bdc9457SAndroid Build Coastguard Worker 
use_weights_cache(bool use_weights_cache)536*4bdc9457SAndroid Build Coastguard Worker   inline ConvolutionOperatorTester& use_weights_cache(bool use_weights_cache) {
537*4bdc9457SAndroid Build Coastguard Worker     this->use_weights_cache_ = use_weights_cache;
538*4bdc9457SAndroid Build Coastguard Worker     return *this;
539*4bdc9457SAndroid Build Coastguard Worker   }
540*4bdc9457SAndroid Build Coastguard Worker 
use_weights_cache()541*4bdc9457SAndroid Build Coastguard Worker   inline bool use_weights_cache() const {
542*4bdc9457SAndroid Build Coastguard Worker     return this->use_weights_cache_;
543*4bdc9457SAndroid Build Coastguard Worker   }
544*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxQC8()545*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxQC8() const {
546*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
547*4bdc9457SAndroid Build Coastguard Worker 
548*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
549*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
550*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
551*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
552*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
553*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
554*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
555*4bdc9457SAndroid Build Coastguard Worker 
556*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
557*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()));
558*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
559*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(groups() * group_output_channels());
560*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()));
561*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
562*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
563*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> requantization_scales(groups() * group_output_channels());
564*4bdc9457SAndroid Build Coastguard Worker 
565*4bdc9457SAndroid Build Coastguard Worker     const int8_t input_zero_point = -1;
566*4bdc9457SAndroid Build Coastguard Worker     const int8_t output_zero_point = -1;
567*4bdc9457SAndroid Build Coastguard Worker 
568*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
569*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
570*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
571*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
572*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
573*4bdc9457SAndroid Build Coastguard Worker 
574*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
575*4bdc9457SAndroid Build Coastguard Worker       if (depthwise_layout()) {
576*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(group_input_channels(), 1);
577*4bdc9457SAndroid Build Coastguard Worker         xnnpack::compute_depthwise_convolution_qs8_reference_results(
578*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
579*4bdc9457SAndroid Build Coastguard Worker           output_height(),
580*4bdc9457SAndroid Build Coastguard Worker           output_width(),
581*4bdc9457SAndroid Build Coastguard Worker           input_height(),
582*4bdc9457SAndroid Build Coastguard Worker           input_width(),
583*4bdc9457SAndroid Build Coastguard Worker           padding_top(),
584*4bdc9457SAndroid Build Coastguard Worker           padding_right(),
585*4bdc9457SAndroid Build Coastguard Worker           padding_bottom(),
586*4bdc9457SAndroid Build Coastguard Worker           padding_left(),
587*4bdc9457SAndroid Build Coastguard Worker           kernel_height(),
588*4bdc9457SAndroid Build Coastguard Worker           kernel_width(),
589*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(),
590*4bdc9457SAndroid Build Coastguard Worker           subsampling_width(),
591*4bdc9457SAndroid Build Coastguard Worker           dilation_height(),
592*4bdc9457SAndroid Build Coastguard Worker           dilation_width(),
593*4bdc9457SAndroid Build Coastguard Worker           groups(),
594*4bdc9457SAndroid Build Coastguard Worker           group_output_channels(),
595*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(),
596*4bdc9457SAndroid Build Coastguard Worker           input_zero_point,
597*4bdc9457SAndroid Build Coastguard Worker           input,
598*4bdc9457SAndroid Build Coastguard Worker           kernel,
599*4bdc9457SAndroid Build Coastguard Worker           accumulators,
600*4bdc9457SAndroid Build Coastguard Worker           has_bias(),
601*4bdc9457SAndroid Build Coastguard Worker           bias);
602*4bdc9457SAndroid Build Coastguard Worker       } else {
603*4bdc9457SAndroid Build Coastguard Worker         xnnpack::compute_convolution_qs8_reference_results(
604*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
605*4bdc9457SAndroid Build Coastguard Worker           output_height(),
606*4bdc9457SAndroid Build Coastguard Worker           output_width(),
607*4bdc9457SAndroid Build Coastguard Worker           input_height(),
608*4bdc9457SAndroid Build Coastguard Worker           input_width(),
609*4bdc9457SAndroid Build Coastguard Worker           padding_top(),
610*4bdc9457SAndroid Build Coastguard Worker           padding_right(),
611*4bdc9457SAndroid Build Coastguard Worker           padding_bottom(),
612*4bdc9457SAndroid Build Coastguard Worker           padding_left(),
613*4bdc9457SAndroid Build Coastguard Worker           kernel_height(),
614*4bdc9457SAndroid Build Coastguard Worker           kernel_width(),
615*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(),
616*4bdc9457SAndroid Build Coastguard Worker           subsampling_width(),
617*4bdc9457SAndroid Build Coastguard Worker           dilation_height(),
618*4bdc9457SAndroid Build Coastguard Worker           dilation_width(),
619*4bdc9457SAndroid Build Coastguard Worker           groups(),
620*4bdc9457SAndroid Build Coastguard Worker           group_input_channels(),
621*4bdc9457SAndroid Build Coastguard Worker           group_output_channels(),
622*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(),
623*4bdc9457SAndroid Build Coastguard Worker           input_zero_point,
624*4bdc9457SAndroid Build Coastguard Worker           input,
625*4bdc9457SAndroid Build Coastguard Worker           kernel,
626*4bdc9457SAndroid Build Coastguard Worker           accumulators,
627*4bdc9457SAndroid Build Coastguard Worker           has_bias(),
628*4bdc9457SAndroid Build Coastguard Worker           bias);
629*4bdc9457SAndroid Build Coastguard Worker       }
630*4bdc9457SAndroid Build Coastguard Worker 
631*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
632*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < groups() * group_output_channels(); c++) {
633*4bdc9457SAndroid Build Coastguard Worker         int32_t accumulated_min = accumulators[c];
634*4bdc9457SAndroid Build Coastguard Worker         int32_t accumulated_max = accumulators[c];
635*4bdc9457SAndroid Build Coastguard Worker         for (size_t px = 0; px < batch_size() * output_height() * output_width(); px++) {
636*4bdc9457SAndroid Build Coastguard Worker           accumulated_min = std::min(accumulated_min, accumulators[px * groups() * group_output_channels() + c]);
637*4bdc9457SAndroid Build Coastguard Worker           accumulated_max = std::max(accumulated_max, accumulators[px * groups() * group_output_channels() + c]);
638*4bdc9457SAndroid Build Coastguard Worker         }
639*4bdc9457SAndroid Build Coastguard Worker 
640*4bdc9457SAndroid Build Coastguard Worker         float requantization_scale = 0x1.0p-32f;
641*4bdc9457SAndroid Build Coastguard Worker         if (accumulated_max != 0) {
642*4bdc9457SAndroid Build Coastguard Worker           requantization_scale = std::max(requantization_scale,
643*4bdc9457SAndroid Build Coastguard Worker             float(int32_t(std::numeric_limits<int8_t>::max()) - int32_t(output_zero_point)) / float(accumulated_max));
644*4bdc9457SAndroid Build Coastguard Worker         }
645*4bdc9457SAndroid Build Coastguard Worker         if (accumulated_min != 0) {
646*4bdc9457SAndroid Build Coastguard Worker           requantization_scale = std::max(requantization_scale,
647*4bdc9457SAndroid Build Coastguard Worker             float(int32_t(std::numeric_limits<int8_t>::min()) - int32_t(output_zero_point)) / float(accumulated_min));
648*4bdc9457SAndroid Build Coastguard Worker         }
649*4bdc9457SAndroid Build Coastguard Worker         requantization_scale = std::min(requantization_scale, 0x1.FFFFFEp-1f);
650*4bdc9457SAndroid Build Coastguard Worker 
651*4bdc9457SAndroid Build Coastguard Worker         requantization_scales[c] = requantization_scale;
652*4bdc9457SAndroid Build Coastguard Worker       }
653*4bdc9457SAndroid Build Coastguard Worker 
654*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
655*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < groups() * group_output_channels(); c++) {
656*4bdc9457SAndroid Build Coastguard Worker         for (size_t px = 0; px < batch_size() * output_height() * output_width(); px++) {
657*4bdc9457SAndroid Build Coastguard Worker           output_ref[px * groups() * group_output_channels() + c] = double(int32_t(output_zero_point)) +
658*4bdc9457SAndroid Build Coastguard Worker             double(accumulators[px * groups() * group_output_channels() + c]) * double(requantization_scales[c]);
659*4bdc9457SAndroid Build Coastguard Worker         }
660*4bdc9457SAndroid Build Coastguard Worker       }
661*4bdc9457SAndroid Build Coastguard Worker       std::transform(output_ref.cbegin(), output_ref.cend(), output_ref.begin(),
662*4bdc9457SAndroid Build Coastguard Worker         [this](double x) -> double {
663*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(x, double(qmax() - 0x80)), double(qmin() - 0x80));
664*4bdc9457SAndroid Build Coastguard Worker         });
665*4bdc9457SAndroid Build Coastguard Worker 
666*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Convolution operator.
667*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
668*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
669*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
670*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
671*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
672*4bdc9457SAndroid Build Coastguard Worker       };
673*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
674*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
675*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
676*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
677*4bdc9457SAndroid Build Coastguard Worker       }
678*4bdc9457SAndroid Build Coastguard Worker 
679*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_qc8(
680*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
681*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
682*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
683*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
684*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
685*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
686*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
687*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */, requantization_scales.data(),
688*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
689*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, 1.0f /* output scale */, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
690*4bdc9457SAndroid Build Coastguard Worker           (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
691*4bdc9457SAndroid Build Coastguard Worker           &caches,
692*4bdc9457SAndroid Build Coastguard Worker           &convolution_op);
693*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
694*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
695*4bdc9457SAndroid Build Coastguard Worker       }
696*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
697*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
698*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
699*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
700*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
701*4bdc9457SAndroid Build Coastguard Worker       }
702*4bdc9457SAndroid Build Coastguard Worker 
703*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
704*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
705*4bdc9457SAndroid Build Coastguard Worker 
706*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
707*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qc8(
708*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
709*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
710*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
711*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
712*4bdc9457SAndroid Build Coastguard Worker 
713*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
714*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
715*4bdc9457SAndroid Build Coastguard Worker 
716*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
717*4bdc9457SAndroid Build Coastguard Worker       VerifyNHWCxQC8(output, output_ref);
718*4bdc9457SAndroid Build Coastguard Worker 
719*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
720*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t convolution_op2 = nullptr;
721*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
722*4bdc9457SAndroid Build Coastguard Worker 
723*4bdc9457SAndroid Build Coastguard Worker         xnn_status status = xnn_create_convolution2d_nhwc_qc8(
724*4bdc9457SAndroid Build Coastguard Worker             padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
725*4bdc9457SAndroid Build Coastguard Worker             padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
726*4bdc9457SAndroid Build Coastguard Worker             kernel_height(), kernel_width(),
727*4bdc9457SAndroid Build Coastguard Worker             subsampling_height(), subsampling_width(),
728*4bdc9457SAndroid Build Coastguard Worker             dilation_height(), dilation_width(),
729*4bdc9457SAndroid Build Coastguard Worker             groups(), group_input_channels(), group_output_channels(),
730*4bdc9457SAndroid Build Coastguard Worker             input_channel_stride(), output_channel_stride(),
731*4bdc9457SAndroid Build Coastguard Worker             input_zero_point, 1.0f /* input scale */, requantization_scales.data(),
732*4bdc9457SAndroid Build Coastguard Worker             kernel.data(), has_bias() ? bias.data() : nullptr,
733*4bdc9457SAndroid Build Coastguard Worker             output_zero_point, 1.0f /* output scale */, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
734*4bdc9457SAndroid Build Coastguard Worker             (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
735*4bdc9457SAndroid Build Coastguard Worker             &caches,
736*4bdc9457SAndroid Build Coastguard Worker             &convolution_op2);
737*4bdc9457SAndroid Build Coastguard Worker         (void) status;
738*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, convolution_op2);
739*4bdc9457SAndroid Build Coastguard Worker 
740*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete convolution_op2.
741*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op2, xnn_delete_operator);
742*4bdc9457SAndroid Build Coastguard Worker         std::vector<int8_t> output2(output.size(), INT8_C(0xA5));
743*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
744*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_convolution2d_nhwc_qc8(
745*4bdc9457SAndroid Build Coastguard Worker                       convolution_op2,
746*4bdc9457SAndroid Build Coastguard Worker                       batch_size(), input_height(), input_width(),
747*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
748*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
749*4bdc9457SAndroid Build Coastguard Worker 
750*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
751*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(convolution_op2, nullptr /* thread pool */));
752*4bdc9457SAndroid Build Coastguard Worker 
753*4bdc9457SAndroid Build Coastguard Worker         VerifyNHWCxQC8(output2, output_ref);
754*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
755*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
756*4bdc9457SAndroid Build Coastguard Worker       }
757*4bdc9457SAndroid Build Coastguard Worker     }
758*4bdc9457SAndroid Build Coastguard Worker   }
759*4bdc9457SAndroid Build Coastguard Worker 
VerifyNHWCxQC8(const std::vector<int8_t> & output,const std::vector<double> & output_ref)760*4bdc9457SAndroid Build Coastguard Worker   void VerifyNHWCxQC8(const std::vector<int8_t> &output,
761*4bdc9457SAndroid Build Coastguard Worker                       const std::vector<double> &output_ref) const {
762*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
763*4bdc9457SAndroid Build Coastguard Worker       for (size_t y = 0; y < output_height(); y++) {
764*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < output_width(); x++) {
765*4bdc9457SAndroid Build Coastguard Worker           for (size_t g = 0; g < groups(); g++) {
766*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < group_output_channels(); c++) {
767*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
768*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
769*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
770*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
771*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
772*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
773*4bdc9457SAndroid Build Coastguard Worker                   double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]),
774*4bdc9457SAndroid Build Coastguard Worker                   0.9)
775*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
776*4bdc9457SAndroid Build Coastguard Worker             }
777*4bdc9457SAndroid Build Coastguard Worker           }
778*4bdc9457SAndroid Build Coastguard Worker         }
779*4bdc9457SAndroid Build Coastguard Worker       }
780*4bdc9457SAndroid Build Coastguard Worker     }
781*4bdc9457SAndroid Build Coastguard Worker   }
782*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxQS8()783*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxQS8() const {
784*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
785*4bdc9457SAndroid Build Coastguard Worker 
786*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
787*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
788*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
789*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
790*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
791*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
792*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
793*4bdc9457SAndroid Build Coastguard Worker 
794*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
795*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()));
796*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
797*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(groups() * group_output_channels());
798*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()));
799*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
800*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
801*4bdc9457SAndroid Build Coastguard Worker 
802*4bdc9457SAndroid Build Coastguard Worker     const int8_t input_zero_point = -1;
803*4bdc9457SAndroid Build Coastguard Worker 
804*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
805*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
806*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
807*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
808*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
809*4bdc9457SAndroid Build Coastguard Worker 
810*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
811*4bdc9457SAndroid Build Coastguard Worker       if (depthwise_layout()) {
812*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(group_input_channels(), 1);
813*4bdc9457SAndroid Build Coastguard Worker         xnnpack::compute_depthwise_convolution_qs8_reference_results(
814*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
815*4bdc9457SAndroid Build Coastguard Worker           output_height(),
816*4bdc9457SAndroid Build Coastguard Worker           output_width(),
817*4bdc9457SAndroid Build Coastguard Worker           input_height(),
818*4bdc9457SAndroid Build Coastguard Worker           input_width(),
819*4bdc9457SAndroid Build Coastguard Worker           padding_top(),
820*4bdc9457SAndroid Build Coastguard Worker           padding_right(),
821*4bdc9457SAndroid Build Coastguard Worker           padding_bottom(),
822*4bdc9457SAndroid Build Coastguard Worker           padding_left(),
823*4bdc9457SAndroid Build Coastguard Worker           kernel_height(),
824*4bdc9457SAndroid Build Coastguard Worker           kernel_width(),
825*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(),
826*4bdc9457SAndroid Build Coastguard Worker           subsampling_width(),
827*4bdc9457SAndroid Build Coastguard Worker           dilation_height(),
828*4bdc9457SAndroid Build Coastguard Worker           dilation_width(),
829*4bdc9457SAndroid Build Coastguard Worker           groups(),
830*4bdc9457SAndroid Build Coastguard Worker           group_output_channels(),
831*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(),
832*4bdc9457SAndroid Build Coastguard Worker           input_zero_point,
833*4bdc9457SAndroid Build Coastguard Worker           input,
834*4bdc9457SAndroid Build Coastguard Worker           kernel,
835*4bdc9457SAndroid Build Coastguard Worker           accumulators,
836*4bdc9457SAndroid Build Coastguard Worker           has_bias(),
837*4bdc9457SAndroid Build Coastguard Worker           bias);
838*4bdc9457SAndroid Build Coastguard Worker       } else {
839*4bdc9457SAndroid Build Coastguard Worker         xnnpack::compute_convolution_qs8_reference_results(
840*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
841*4bdc9457SAndroid Build Coastguard Worker           output_height(),
842*4bdc9457SAndroid Build Coastguard Worker           output_width(),
843*4bdc9457SAndroid Build Coastguard Worker           input_height(),
844*4bdc9457SAndroid Build Coastguard Worker           input_width(),
845*4bdc9457SAndroid Build Coastguard Worker           padding_top(),
846*4bdc9457SAndroid Build Coastguard Worker           padding_right(),
847*4bdc9457SAndroid Build Coastguard Worker           padding_bottom(),
848*4bdc9457SAndroid Build Coastguard Worker           padding_left(),
849*4bdc9457SAndroid Build Coastguard Worker           kernel_height(),
850*4bdc9457SAndroid Build Coastguard Worker           kernel_width(),
851*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(),
852*4bdc9457SAndroid Build Coastguard Worker           subsampling_width(),
853*4bdc9457SAndroid Build Coastguard Worker           dilation_height(),
854*4bdc9457SAndroid Build Coastguard Worker           dilation_width(),
855*4bdc9457SAndroid Build Coastguard Worker           groups(),
856*4bdc9457SAndroid Build Coastguard Worker           group_input_channels(),
857*4bdc9457SAndroid Build Coastguard Worker           group_output_channels(),
858*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(),
859*4bdc9457SAndroid Build Coastguard Worker           input_zero_point,
860*4bdc9457SAndroid Build Coastguard Worker           input,
861*4bdc9457SAndroid Build Coastguard Worker           kernel,
862*4bdc9457SAndroid Build Coastguard Worker           accumulators,
863*4bdc9457SAndroid Build Coastguard Worker           has_bias(),
864*4bdc9457SAndroid Build Coastguard Worker           bias);
865*4bdc9457SAndroid Build Coastguard Worker       }
866*4bdc9457SAndroid Build Coastguard Worker 
867*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
868*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
869*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
870*4bdc9457SAndroid Build Coastguard Worker 
871*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
872*4bdc9457SAndroid Build Coastguard Worker       const int8_t output_zero_point = int8_t(std::max(std::min(
873*4bdc9457SAndroid Build Coastguard Worker         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
874*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
875*4bdc9457SAndroid Build Coastguard Worker 
876*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
877*4bdc9457SAndroid Build Coastguard Worker       std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
878*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
879*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
880*4bdc9457SAndroid Build Coastguard Worker         });
881*4bdc9457SAndroid Build Coastguard Worker 
882*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Convolution operator.
883*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
884*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
885*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
886*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
887*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
888*4bdc9457SAndroid Build Coastguard Worker       };
889*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
890*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
891*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
892*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
893*4bdc9457SAndroid Build Coastguard Worker       }
894*4bdc9457SAndroid Build Coastguard Worker 
895*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_qs8(
896*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
897*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
898*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
899*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
900*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
901*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
902*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
903*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */, 1.0f /* kernel scale */,
904*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
905*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
906*4bdc9457SAndroid Build Coastguard Worker           (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
907*4bdc9457SAndroid Build Coastguard Worker           &caches,
908*4bdc9457SAndroid Build Coastguard Worker           &convolution_op);
909*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
910*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
911*4bdc9457SAndroid Build Coastguard Worker       }
912*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
913*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
914*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
915*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
916*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
917*4bdc9457SAndroid Build Coastguard Worker       }
918*4bdc9457SAndroid Build Coastguard Worker 
919*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
920*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
921*4bdc9457SAndroid Build Coastguard Worker 
922*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
923*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qs8(
924*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
925*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
926*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
927*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
928*4bdc9457SAndroid Build Coastguard Worker 
929*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
930*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
931*4bdc9457SAndroid Build Coastguard Worker 
932*4bdc9457SAndroid Build Coastguard Worker       VerifyNHWCxQS8(output, output_ref, output_zero_point);
933*4bdc9457SAndroid Build Coastguard Worker 
934*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
935*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t convolution_op2 = nullptr;
936*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
937*4bdc9457SAndroid Build Coastguard Worker 
938*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(
939*4bdc9457SAndroid Build Coastguard Worker             xnn_status_success,
940*4bdc9457SAndroid Build Coastguard Worker             xnn_create_convolution2d_nhwc_qs8(
941*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_top(),
942*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_right(),
943*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_bottom(),
944*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_left(), kernel_height(),
945*4bdc9457SAndroid Build Coastguard Worker                 kernel_width(), subsampling_height(), subsampling_width(),
946*4bdc9457SAndroid Build Coastguard Worker                 dilation_height(), dilation_width(), groups(),
947*4bdc9457SAndroid Build Coastguard Worker                 group_input_channels(), group_output_channels(),
948*4bdc9457SAndroid Build Coastguard Worker                 input_channel_stride(), output_channel_stride(),
949*4bdc9457SAndroid Build Coastguard Worker                 input_zero_point, 1.0f /* input scale */,
950*4bdc9457SAndroid Build Coastguard Worker                 1.0f /* kernel scale */, kernel.data(),
951*4bdc9457SAndroid Build Coastguard Worker                 has_bias() ? bias.data() : nullptr, output_zero_point,
952*4bdc9457SAndroid Build Coastguard Worker                 output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
953*4bdc9457SAndroid Build Coastguard Worker                 (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) |
954*4bdc9457SAndroid Build Coastguard Worker                     (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
955*4bdc9457SAndroid Build Coastguard Worker                 &caches, &convolution_op2));
956*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, convolution_op2);
957*4bdc9457SAndroid Build Coastguard Worker 
958*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete convolution_op.
959*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
960*4bdc9457SAndroid Build Coastguard Worker             auto_convolution_op(convolution_op2, xnn_delete_operator);
961*4bdc9457SAndroid Build Coastguard Worker 
962*4bdc9457SAndroid Build Coastguard Worker         std::vector<int8_t> output2(output.size(), INT8_C(0xA5));
963*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
964*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_convolution2d_nhwc_qs8(
965*4bdc9457SAndroid Build Coastguard Worker                       convolution_op2, batch_size(), input_height(),
966*4bdc9457SAndroid Build Coastguard Worker                       input_width(), input.data(), output2.data(),
967*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
968*4bdc9457SAndroid Build Coastguard Worker 
969*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
970*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(convolution_op2, nullptr /* thread pool */));
971*4bdc9457SAndroid Build Coastguard Worker 
972*4bdc9457SAndroid Build Coastguard Worker         VerifyNHWCxQS8(output2, output_ref, output_zero_point);
973*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
974*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
975*4bdc9457SAndroid Build Coastguard Worker       }
976*4bdc9457SAndroid Build Coastguard Worker     }
977*4bdc9457SAndroid Build Coastguard Worker   }
978*4bdc9457SAndroid Build Coastguard Worker 
VerifyNHWCxQS8(const std::vector<int8_t> & output,const std::vector<double> & output_ref,const int8_t output_zero_point)979*4bdc9457SAndroid Build Coastguard Worker   void VerifyNHWCxQS8(const std::vector<int8_t> &output,
980*4bdc9457SAndroid Build Coastguard Worker                       const std::vector<double> &output_ref,
981*4bdc9457SAndroid Build Coastguard Worker                       const int8_t output_zero_point) const {
982*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
983*4bdc9457SAndroid Build Coastguard Worker       for (size_t y = 0; y < output_height(); y++) {
984*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < output_width(); x++) {
985*4bdc9457SAndroid Build Coastguard Worker           for (size_t g = 0; g < groups(); g++) {
986*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < group_output_channels(); c++) {
987*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
988*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
989*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
990*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
991*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
992*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
993*4bdc9457SAndroid Build Coastguard Worker                   double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
994*4bdc9457SAndroid Build Coastguard Worker                   0.9)
995*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
996*4bdc9457SAndroid Build Coastguard Worker             }
997*4bdc9457SAndroid Build Coastguard Worker           }
998*4bdc9457SAndroid Build Coastguard Worker         }
999*4bdc9457SAndroid Build Coastguard Worker       }
1000*4bdc9457SAndroid Build Coastguard Worker     }
1001*4bdc9457SAndroid Build Coastguard Worker   }
1002*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxQU8()1003*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxQU8() const {
1004*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
1005*4bdc9457SAndroid Build Coastguard Worker 
1006*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1007*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1008*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
1009*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
1010*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
1011*4bdc9457SAndroid Build Coastguard Worker 
1012*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
1013*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()));
1014*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1015*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(groups() * group_output_channels());
1016*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()));
1017*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1018*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1019*4bdc9457SAndroid Build Coastguard Worker 
1020*4bdc9457SAndroid Build Coastguard Worker     const uint8_t input_zero_point = 127;
1021*4bdc9457SAndroid Build Coastguard Worker     const uint8_t kernel_zero_point = 127;
1022*4bdc9457SAndroid Build Coastguard Worker 
1023*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1024*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
1025*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
1026*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
1027*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
1028*4bdc9457SAndroid Build Coastguard Worker 
1029*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
1030*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
1031*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1032*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1033*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1034*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
1035*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
1036*4bdc9457SAndroid Build Coastguard Worker                   accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1037*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
1038*4bdc9457SAndroid Build Coastguard Worker                 }
1039*4bdc9457SAndroid Build Coastguard Worker               }
1040*4bdc9457SAndroid Build Coastguard Worker             }
1041*4bdc9457SAndroid Build Coastguard Worker           }
1042*4bdc9457SAndroid Build Coastguard Worker         }
1043*4bdc9457SAndroid Build Coastguard Worker       } else {
1044*4bdc9457SAndroid Build Coastguard Worker         std::fill(accumulators.begin(), accumulators.end(), 0);
1045*4bdc9457SAndroid Build Coastguard Worker       }
1046*4bdc9457SAndroid Build Coastguard Worker       if (depthwise_layout()) {
1047*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(group_input_channels(), 1);
1048*4bdc9457SAndroid Build Coastguard Worker         xnnpack::compute_depthwise_convolution_qu8_reference_results(
1049*4bdc9457SAndroid Build Coastguard Worker             batch_size(),
1050*4bdc9457SAndroid Build Coastguard Worker             output_height(),
1051*4bdc9457SAndroid Build Coastguard Worker             output_width(),
1052*4bdc9457SAndroid Build Coastguard Worker             input_height(),
1053*4bdc9457SAndroid Build Coastguard Worker             input_width(),
1054*4bdc9457SAndroid Build Coastguard Worker             padding_top(),
1055*4bdc9457SAndroid Build Coastguard Worker             padding_right(),
1056*4bdc9457SAndroid Build Coastguard Worker             padding_bottom(),
1057*4bdc9457SAndroid Build Coastguard Worker             padding_left(),
1058*4bdc9457SAndroid Build Coastguard Worker             kernel_height(),
1059*4bdc9457SAndroid Build Coastguard Worker             kernel_width(),
1060*4bdc9457SAndroid Build Coastguard Worker             subsampling_height(),
1061*4bdc9457SAndroid Build Coastguard Worker             subsampling_width(),
1062*4bdc9457SAndroid Build Coastguard Worker             dilation_height(),
1063*4bdc9457SAndroid Build Coastguard Worker             dilation_width(),
1064*4bdc9457SAndroid Build Coastguard Worker             groups(),
1065*4bdc9457SAndroid Build Coastguard Worker             group_output_channels(),
1066*4bdc9457SAndroid Build Coastguard Worker             input_channel_stride(),
1067*4bdc9457SAndroid Build Coastguard Worker             input_zero_point,
1068*4bdc9457SAndroid Build Coastguard Worker             kernel_zero_point,
1069*4bdc9457SAndroid Build Coastguard Worker             input,
1070*4bdc9457SAndroid Build Coastguard Worker             kernel,
1071*4bdc9457SAndroid Build Coastguard Worker             accumulators,
1072*4bdc9457SAndroid Build Coastguard Worker             has_bias(),
1073*4bdc9457SAndroid Build Coastguard Worker             bias);
1074*4bdc9457SAndroid Build Coastguard Worker       } else {
1075*4bdc9457SAndroid Build Coastguard Worker         xnnpack::compute_convolution_qu8_reference_results(
1076*4bdc9457SAndroid Build Coastguard Worker             batch_size(),
1077*4bdc9457SAndroid Build Coastguard Worker             output_height(),
1078*4bdc9457SAndroid Build Coastguard Worker             output_width(),
1079*4bdc9457SAndroid Build Coastguard Worker             input_height(),
1080*4bdc9457SAndroid Build Coastguard Worker             input_width(),
1081*4bdc9457SAndroid Build Coastguard Worker             padding_top(),
1082*4bdc9457SAndroid Build Coastguard Worker             padding_right(),
1083*4bdc9457SAndroid Build Coastguard Worker             padding_bottom(),
1084*4bdc9457SAndroid Build Coastguard Worker             padding_left(),
1085*4bdc9457SAndroid Build Coastguard Worker             kernel_height(),
1086*4bdc9457SAndroid Build Coastguard Worker             kernel_width(),
1087*4bdc9457SAndroid Build Coastguard Worker             subsampling_height(),
1088*4bdc9457SAndroid Build Coastguard Worker             subsampling_width(),
1089*4bdc9457SAndroid Build Coastguard Worker             dilation_height(),
1090*4bdc9457SAndroid Build Coastguard Worker             dilation_width(),
1091*4bdc9457SAndroid Build Coastguard Worker             groups(),
1092*4bdc9457SAndroid Build Coastguard Worker             group_input_channels(),
1093*4bdc9457SAndroid Build Coastguard Worker             group_output_channels(),
1094*4bdc9457SAndroid Build Coastguard Worker             input_channel_stride(),
1095*4bdc9457SAndroid Build Coastguard Worker             input_zero_point,
1096*4bdc9457SAndroid Build Coastguard Worker             kernel_zero_point,
1097*4bdc9457SAndroid Build Coastguard Worker             input,
1098*4bdc9457SAndroid Build Coastguard Worker             kernel,
1099*4bdc9457SAndroid Build Coastguard Worker             accumulators,
1100*4bdc9457SAndroid Build Coastguard Worker             has_bias(),
1101*4bdc9457SAndroid Build Coastguard Worker             bias);
1102*4bdc9457SAndroid Build Coastguard Worker       }
1103*4bdc9457SAndroid Build Coastguard Worker 
1104*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
1105*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
1106*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
1107*4bdc9457SAndroid Build Coastguard Worker 
1108*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
1109*4bdc9457SAndroid Build Coastguard Worker       const uint8_t output_zero_point = uint8_t(std::max(std::min(
1110*4bdc9457SAndroid Build Coastguard Worker         lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
1111*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
1112*4bdc9457SAndroid Build Coastguard Worker 
1113*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
1114*4bdc9457SAndroid Build Coastguard Worker       std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
1115*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
1116*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
1117*4bdc9457SAndroid Build Coastguard Worker         });
1118*4bdc9457SAndroid Build Coastguard Worker 
1119*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Convolution operator.
1120*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1121*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
1122*4bdc9457SAndroid Build Coastguard Worker 
1123*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
1124*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
1125*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
1126*4bdc9457SAndroid Build Coastguard Worker       };
1127*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
1128*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1129*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
1130*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
1131*4bdc9457SAndroid Build Coastguard Worker       }
1132*4bdc9457SAndroid Build Coastguard Worker 
1133*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_qu8(
1134*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
1135*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
1136*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
1137*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
1138*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
1139*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
1140*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
1141*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */,
1142*4bdc9457SAndroid Build Coastguard Worker           kernel_zero_point, 1.0f /* kernel scale */,
1143*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
1144*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, output_scale, qmin(), qmax(),
1145*4bdc9457SAndroid Build Coastguard Worker           (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
1146*4bdc9457SAndroid Build Coastguard Worker           &caches,
1147*4bdc9457SAndroid Build Coastguard Worker           &convolution_op);
1148*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
1149*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
1150*4bdc9457SAndroid Build Coastguard Worker       }
1151*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
1152*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
1153*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1154*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1155*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
1156*4bdc9457SAndroid Build Coastguard Worker       }
1157*4bdc9457SAndroid Build Coastguard Worker 
1158*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
1159*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
1160*4bdc9457SAndroid Build Coastguard Worker 
1161*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1162*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qu8(
1163*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
1164*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
1165*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1166*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1167*4bdc9457SAndroid Build Coastguard Worker 
1168*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1169*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
1170*4bdc9457SAndroid Build Coastguard Worker 
1171*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
1172*4bdc9457SAndroid Build Coastguard Worker       VerifyNHWCxQU8(output, output_ref, output_zero_point);
1173*4bdc9457SAndroid Build Coastguard Worker 
1174*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1175*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t convolution_op2 = nullptr;
1176*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
1177*4bdc9457SAndroid Build Coastguard Worker 
1178*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(
1179*4bdc9457SAndroid Build Coastguard Worker             xnn_status_success,
1180*4bdc9457SAndroid Build Coastguard Worker             xnn_create_convolution2d_nhwc_qu8(
1181*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_top(),
1182*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_right(),
1183*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_bottom(),
1184*4bdc9457SAndroid Build Coastguard Worker                 padding_tf_same() ? 0 : padding_left(), kernel_height(),
1185*4bdc9457SAndroid Build Coastguard Worker                 kernel_width(), subsampling_height(), subsampling_width(),
1186*4bdc9457SAndroid Build Coastguard Worker                 dilation_height(), dilation_width(), groups(),
1187*4bdc9457SAndroid Build Coastguard Worker                 group_input_channels(), group_output_channels(),
1188*4bdc9457SAndroid Build Coastguard Worker                 input_channel_stride(), output_channel_stride(),
1189*4bdc9457SAndroid Build Coastguard Worker                 input_zero_point, 1.0f /* input scale */, kernel_zero_point,
1190*4bdc9457SAndroid Build Coastguard Worker                 1.0f /* kernel scale */, kernel.data(),
1191*4bdc9457SAndroid Build Coastguard Worker                 has_bias() ? bias.data() : nullptr, output_zero_point,
1192*4bdc9457SAndroid Build Coastguard Worker                 output_scale, qmin(), qmax(),
1193*4bdc9457SAndroid Build Coastguard Worker                 (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) |
1194*4bdc9457SAndroid Build Coastguard Worker                     (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
1195*4bdc9457SAndroid Build Coastguard Worker                 &caches, &convolution_op2));
1196*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, convolution_op2);
1197*4bdc9457SAndroid Build Coastguard Worker 
1198*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete convolution_op2.
1199*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
1200*4bdc9457SAndroid Build Coastguard Worker             auto_convolution_op2(convolution_op2, xnn_delete_operator);
1201*4bdc9457SAndroid Build Coastguard Worker         std::vector<uint8_t> output2(output.size(), UINT8_C(0xA5));
1202*4bdc9457SAndroid Build Coastguard Worker 
1203*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1204*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_convolution2d_nhwc_qu8(
1205*4bdc9457SAndroid Build Coastguard Worker                       convolution_op2, batch_size(), input_height(),
1206*4bdc9457SAndroid Build Coastguard Worker                       input_width(), input.data(), output2.data(),
1207*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
1208*4bdc9457SAndroid Build Coastguard Worker 
1209*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1210*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(convolution_op2, nullptr /* thread pool */));
1211*4bdc9457SAndroid Build Coastguard Worker 
1212*4bdc9457SAndroid Build Coastguard Worker         // Verify results.
1213*4bdc9457SAndroid Build Coastguard Worker         VerifyNHWCxQU8(output2, output_ref, output_zero_point);
1214*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
1215*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
1216*4bdc9457SAndroid Build Coastguard Worker       }
1217*4bdc9457SAndroid Build Coastguard Worker     }
1218*4bdc9457SAndroid Build Coastguard Worker   }
1219*4bdc9457SAndroid Build Coastguard Worker 
VerifyNHWCxQU8(const std::vector<uint8_t> & output,const std::vector<double> & output_ref,const uint8_t output_zero_point)1220*4bdc9457SAndroid Build Coastguard Worker   void VerifyNHWCxQU8(const std::vector<uint8_t> &output,
1221*4bdc9457SAndroid Build Coastguard Worker                       const std::vector<double> &output_ref,
1222*4bdc9457SAndroid Build Coastguard Worker                       const uint8_t output_zero_point) const {
1223*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
1224*4bdc9457SAndroid Build Coastguard Worker       for (size_t y = 0; y < output_height(); y++) {
1225*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < output_width(); x++) {
1226*4bdc9457SAndroid Build Coastguard Worker           for (size_t g = 0; g < groups(); g++) {
1227*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < group_output_channels(); c++) {
1228*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax()))
1229*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1230*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin()))
1231*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1232*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
1233*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
1234*4bdc9457SAndroid Build Coastguard Worker                   double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
1235*4bdc9457SAndroid Build Coastguard Worker                   0.9)
1236*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1237*4bdc9457SAndroid Build Coastguard Worker             }
1238*4bdc9457SAndroid Build Coastguard Worker           }
1239*4bdc9457SAndroid Build Coastguard Worker         }
1240*4bdc9457SAndroid Build Coastguard Worker       }
1241*4bdc9457SAndroid Build Coastguard Worker     }
1242*4bdc9457SAndroid Build Coastguard Worker   }
1243*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxF32()1244*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxF32() const {
1245*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
1246*4bdc9457SAndroid Build Coastguard Worker 
1247*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1248*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1249*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
1250*4bdc9457SAndroid Build Coastguard Worker 
1251*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
1252*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()));
1253*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1254*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias(groups() * group_output_channels());
1255*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()));
1256*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1257*4bdc9457SAndroid Build Coastguard Worker 
1258*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1259*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
1260*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
1261*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1262*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
1263*4bdc9457SAndroid Build Coastguard Worker 
1264*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
1265*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
1266*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1267*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1268*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1269*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
1270*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
1271*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1272*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
1273*4bdc9457SAndroid Build Coastguard Worker                 }
1274*4bdc9457SAndroid Build Coastguard Worker               }
1275*4bdc9457SAndroid Build Coastguard Worker             }
1276*4bdc9457SAndroid Build Coastguard Worker           }
1277*4bdc9457SAndroid Build Coastguard Worker         }
1278*4bdc9457SAndroid Build Coastguard Worker       } else {
1279*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
1280*4bdc9457SAndroid Build Coastguard Worker       }
1281*4bdc9457SAndroid Build Coastguard Worker       if (depthwise_layout()) {
1282*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(group_input_channels(), 1);
1283*4bdc9457SAndroid Build Coastguard Worker 
1284*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1285*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1286*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1287*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1288*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1289*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1290*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1291*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1292*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1293*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1294*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1295*4bdc9457SAndroid Build Coastguard Worker                           output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1296*4bdc9457SAndroid Build Coastguard Worker                             input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g] *
1297*4bdc9457SAndroid Build Coastguard Worker                             kernel[((ky * kernel_width() + kx) * groups() + g) * group_output_channels() + oc];
1298*4bdc9457SAndroid Build Coastguard Worker                         }
1299*4bdc9457SAndroid Build Coastguard Worker                       }
1300*4bdc9457SAndroid Build Coastguard Worker                     }
1301*4bdc9457SAndroid Build Coastguard Worker                   }
1302*4bdc9457SAndroid Build Coastguard Worker                 }
1303*4bdc9457SAndroid Build Coastguard Worker               }
1304*4bdc9457SAndroid Build Coastguard Worker             }
1305*4bdc9457SAndroid Build Coastguard Worker           }
1306*4bdc9457SAndroid Build Coastguard Worker         }
1307*4bdc9457SAndroid Build Coastguard Worker       } else {
1308*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1309*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1310*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1311*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1312*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1313*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1314*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1315*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1316*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1317*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1318*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1319*4bdc9457SAndroid Build Coastguard Worker                           for (size_t ic = 0; ic < group_input_channels(); ic++) {
1320*4bdc9457SAndroid Build Coastguard Worker                             output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1321*4bdc9457SAndroid Build Coastguard Worker                               input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic] *
1322*4bdc9457SAndroid Build Coastguard Worker                               kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
1323*4bdc9457SAndroid Build Coastguard Worker                           }
1324*4bdc9457SAndroid Build Coastguard Worker                         }
1325*4bdc9457SAndroid Build Coastguard Worker                       }
1326*4bdc9457SAndroid Build Coastguard Worker                     }
1327*4bdc9457SAndroid Build Coastguard Worker                   }
1328*4bdc9457SAndroid Build Coastguard Worker                 }
1329*4bdc9457SAndroid Build Coastguard Worker               }
1330*4bdc9457SAndroid Build Coastguard Worker             }
1331*4bdc9457SAndroid Build Coastguard Worker           }
1332*4bdc9457SAndroid Build Coastguard Worker         }
1333*4bdc9457SAndroid Build Coastguard Worker       }
1334*4bdc9457SAndroid Build Coastguard Worker 
1335*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
1336*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1337*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1338*4bdc9457SAndroid Build Coastguard Worker 
1339*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1340*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1341*4bdc9457SAndroid Build Coastguard Worker 
1342*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
1343*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
1344*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
1345*4bdc9457SAndroid Build Coastguard Worker       }
1346*4bdc9457SAndroid Build Coastguard Worker 
1347*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Convolution operator.
1348*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1349*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
1350*4bdc9457SAndroid Build Coastguard Worker 
1351*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
1352*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
1353*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
1354*4bdc9457SAndroid Build Coastguard Worker       };
1355*4bdc9457SAndroid Build Coastguard Worker       #if XNN_PLATFORM_JIT
1356*4bdc9457SAndroid Build Coastguard Worker         xnn_code_cache code_cache;
1357*4bdc9457SAndroid Build Coastguard Worker         if (use_jit()) {
1358*4bdc9457SAndroid Build Coastguard Worker           xnn_init_code_cache(&code_cache);
1359*4bdc9457SAndroid Build Coastguard Worker           caches.code_cache = &code_cache;
1360*4bdc9457SAndroid Build Coastguard Worker         }
1361*4bdc9457SAndroid Build Coastguard Worker       #endif
1362*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
1363*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1364*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
1365*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
1366*4bdc9457SAndroid Build Coastguard Worker       }
1367*4bdc9457SAndroid Build Coastguard Worker 
1368*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_f32(
1369*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
1370*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
1371*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
1372*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
1373*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
1374*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
1375*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
1376*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
1377*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
1378*4bdc9457SAndroid Build Coastguard Worker           (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
1379*4bdc9457SAndroid Build Coastguard Worker           &caches,
1380*4bdc9457SAndroid Build Coastguard Worker           &convolution_op);
1381*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
1382*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
1383*4bdc9457SAndroid Build Coastguard Worker       }
1384*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
1385*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
1386*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1387*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1388*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
1389*4bdc9457SAndroid Build Coastguard Worker       }
1390*4bdc9457SAndroid Build Coastguard Worker 
1391*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
1392*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
1393*4bdc9457SAndroid Build Coastguard Worker 
1394*4bdc9457SAndroid Build Coastguard Worker       #if XNN_PLATFORM_JIT
1395*4bdc9457SAndroid Build Coastguard Worker         if (use_jit()) {
1396*4bdc9457SAndroid Build Coastguard Worker           // Check that we actually generated code.
1397*4bdc9457SAndroid Build Coastguard Worker           ASSERT_GT(code_cache.cache.code.size, 0);
1398*4bdc9457SAndroid Build Coastguard Worker           xnn_finalize_code_memory(&code_cache.cache.code);
1399*4bdc9457SAndroid Build Coastguard Worker         }
1400*4bdc9457SAndroid Build Coastguard Worker       #endif
1401*4bdc9457SAndroid Build Coastguard Worker 
1402*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1403*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_f32(
1404*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
1405*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
1406*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1407*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1408*4bdc9457SAndroid Build Coastguard Worker 
1409*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1410*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
1411*4bdc9457SAndroid Build Coastguard Worker 
1412*4bdc9457SAndroid Build Coastguard Worker       VerifyNHWCxF32(output, output_ref, output_min, output_max);
1413*4bdc9457SAndroid Build Coastguard Worker 
1414*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1415*4bdc9457SAndroid Build Coastguard Worker         // We already finalized the code cache, so create a new code cache if we are testing JIT.
1416*4bdc9457SAndroid Build Coastguard Worker         #if XNN_PLATFORM_JIT
1417*4bdc9457SAndroid Build Coastguard Worker           xnn_code_cache inner_code_cache;
1418*4bdc9457SAndroid Build Coastguard Worker           if (use_jit()) {
1419*4bdc9457SAndroid Build Coastguard Worker             xnn_init_code_cache(&inner_code_cache);
1420*4bdc9457SAndroid Build Coastguard Worker             caches.code_cache = &inner_code_cache;
1421*4bdc9457SAndroid Build Coastguard Worker           }
1422*4bdc9457SAndroid Build Coastguard Worker         #endif
1423*4bdc9457SAndroid Build Coastguard Worker         // To test weights cache, we create the operator with the same parameters, and setup with a different output.
1424*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t convolution_op2 = nullptr;
1425*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
1426*4bdc9457SAndroid Build Coastguard Worker 
1427*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success, xnn_create_convolution2d_nhwc_f32(
1428*4bdc9457SAndroid Build Coastguard Worker             padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
1429*4bdc9457SAndroid Build Coastguard Worker             padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
1430*4bdc9457SAndroid Build Coastguard Worker             kernel_height(), kernel_width(),
1431*4bdc9457SAndroid Build Coastguard Worker             subsampling_height(), subsampling_width(),
1432*4bdc9457SAndroid Build Coastguard Worker             dilation_height(), dilation_width(),
1433*4bdc9457SAndroid Build Coastguard Worker             groups(), group_input_channels(), group_output_channels(),
1434*4bdc9457SAndroid Build Coastguard Worker             input_channel_stride(), output_channel_stride(),
1435*4bdc9457SAndroid Build Coastguard Worker             kernel.data(), has_bias() ? bias.data() : nullptr,
1436*4bdc9457SAndroid Build Coastguard Worker             output_min, output_max,
1437*4bdc9457SAndroid Build Coastguard Worker             (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (padding_tf_same() ? XNN_FLAG_TENSORFLOW_SAME_PADDING : 0),
1438*4bdc9457SAndroid Build Coastguard Worker             &caches,
1439*4bdc9457SAndroid Build Coastguard Worker             &convolution_op2));
1440*4bdc9457SAndroid Build Coastguard Worker 
1441*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, convolution_op2);
1442*4bdc9457SAndroid Build Coastguard Worker 
1443*4bdc9457SAndroid Build Coastguard Worker         #if XNN_PLATFORM_JIT
1444*4bdc9457SAndroid Build Coastguard Worker           if (use_jit()) {
1445*4bdc9457SAndroid Build Coastguard Worker             // Check that we actually generated code.
1446*4bdc9457SAndroid Build Coastguard Worker             ASSERT_GT(inner_code_cache.cache.code.size, 0);
1447*4bdc9457SAndroid Build Coastguard Worker             xnn_finalize_code_memory(&inner_code_cache.cache.code);
1448*4bdc9457SAndroid Build Coastguard Worker           }
1449*4bdc9457SAndroid Build Coastguard Worker         #endif
1450*4bdc9457SAndroid Build Coastguard Worker 
1451*4bdc9457SAndroid Build Coastguard Worker         std::vector<float> output2(output.size(), nanf(""));
1452*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1453*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_convolution2d_nhwc_f32(
1454*4bdc9457SAndroid Build Coastguard Worker                       convolution_op2,
1455*4bdc9457SAndroid Build Coastguard Worker                       batch_size(), input_height(), input_width(),
1456*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
1457*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
1458*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1459*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(convolution_op2, nullptr /* thread pool */));
1460*4bdc9457SAndroid Build Coastguard Worker 
1461*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op2(convolution_op2, xnn_delete_operator);
1462*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(weights_cache.cache.hits, 1);
1463*4bdc9457SAndroid Build Coastguard Worker         // Ensure that we did not write more weights to the cache because it was a cache hit.
1464*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(old_weights_cache_size, weights_cache.cache.weights.size);
1465*4bdc9457SAndroid Build Coastguard Worker 
1466*4bdc9457SAndroid Build Coastguard Worker         VerifyNHWCxF32(output2, output_ref, output_min, output_max);
1467*4bdc9457SAndroid Build Coastguard Worker         #if XNN_PLATFORM_JIT
1468*4bdc9457SAndroid Build Coastguard Worker           if (use_jit()) {
1469*4bdc9457SAndroid Build Coastguard Worker             xnn_release_code_cache(&inner_code_cache);
1470*4bdc9457SAndroid Build Coastguard Worker           }
1471*4bdc9457SAndroid Build Coastguard Worker         #endif
1472*4bdc9457SAndroid Build Coastguard Worker       }
1473*4bdc9457SAndroid Build Coastguard Worker 
1474*4bdc9457SAndroid Build Coastguard Worker       #if XNN_PLATFORM_JIT
1475*4bdc9457SAndroid Build Coastguard Worker         if (use_jit()) {
1476*4bdc9457SAndroid Build Coastguard Worker           xnn_release_code_cache(&code_cache);
1477*4bdc9457SAndroid Build Coastguard Worker         }
1478*4bdc9457SAndroid Build Coastguard Worker       #endif
1479*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1480*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
1481*4bdc9457SAndroid Build Coastguard Worker       }
1482*4bdc9457SAndroid Build Coastguard Worker     }
1483*4bdc9457SAndroid Build Coastguard Worker   }
1484*4bdc9457SAndroid Build Coastguard Worker 
VerifyNHWCxF32(const std::vector<float> & output,const std::vector<float> & output_ref,const float output_min,const float output_max)1485*4bdc9457SAndroid Build Coastguard Worker   void VerifyNHWCxF32(const std::vector<float>& output, const std::vector<float>& output_ref, const float output_min, const float output_max) const {
1486*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
1487*4bdc9457SAndroid Build Coastguard Worker       for (size_t y = 0; y < output_height(); y++) {
1488*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < output_width(); x++) {
1489*4bdc9457SAndroid Build Coastguard Worker           for (size_t g = 0; g < groups(); g++) {
1490*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < group_output_channels(); c++) {
1491*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_min)
1492*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1493*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_max)
1494*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1495*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
1496*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
1497*4bdc9457SAndroid Build Coastguard Worker                   output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c],
1498*4bdc9457SAndroid Build Coastguard Worker                   1.0e-4 * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]))
1499*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1500*4bdc9457SAndroid Build Coastguard Worker             }
1501*4bdc9457SAndroid Build Coastguard Worker           }
1502*4bdc9457SAndroid Build Coastguard Worker         }
1503*4bdc9457SAndroid Build Coastguard Worker       }
1504*4bdc9457SAndroid Build Coastguard Worker     }
1505*4bdc9457SAndroid Build Coastguard Worker   }
1506*4bdc9457SAndroid Build Coastguard Worker 
TestNHWCxF16()1507*4bdc9457SAndroid Build Coastguard Worker   void TestNHWCxF16() const {
1508*4bdc9457SAndroid Build Coastguard Worker     switch (weights_type()) {
1509*4bdc9457SAndroid Build Coastguard Worker       case WeightsType::Default:
1510*4bdc9457SAndroid Build Coastguard Worker         break;
1511*4bdc9457SAndroid Build Coastguard Worker       case WeightsType::FP32:
1512*4bdc9457SAndroid Build Coastguard Worker         break;
1513*4bdc9457SAndroid Build Coastguard Worker       default:
1514*4bdc9457SAndroid Build Coastguard Worker         GTEST_FAIL() << "unexpected weights type";
1515*4bdc9457SAndroid Build Coastguard Worker     }
1516*4bdc9457SAndroid Build Coastguard Worker 
1517*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1518*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1519*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
1520*4bdc9457SAndroid Build Coastguard Worker 
1521*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
1522*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()));
1523*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1524*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel_as_float(kernel.size());
1525*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> bias(groups() * group_output_channels());
1526*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias_as_float(bias.size());
1527*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()));
1528*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1529*4bdc9457SAndroid Build Coastguard Worker 
1530*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1531*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
1532*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
1533*4bdc9457SAndroid Build Coastguard Worker       std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value);
1534*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
1535*4bdc9457SAndroid Build Coastguard Worker       std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value);
1536*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
1537*4bdc9457SAndroid Build Coastguard Worker 
1538*4bdc9457SAndroid Build Coastguard Worker 
1539*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
1540*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
1541*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1542*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1543*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1544*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
1545*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
1546*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1547*4bdc9457SAndroid Build Coastguard Worker                     fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]);
1548*4bdc9457SAndroid Build Coastguard Worker                 }
1549*4bdc9457SAndroid Build Coastguard Worker               }
1550*4bdc9457SAndroid Build Coastguard Worker             }
1551*4bdc9457SAndroid Build Coastguard Worker           }
1552*4bdc9457SAndroid Build Coastguard Worker         }
1553*4bdc9457SAndroid Build Coastguard Worker       } else {
1554*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
1555*4bdc9457SAndroid Build Coastguard Worker       }
1556*4bdc9457SAndroid Build Coastguard Worker       if (depthwise_layout()) {
1557*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(group_input_channels(), 1);
1558*4bdc9457SAndroid Build Coastguard Worker 
1559*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1560*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1561*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1562*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1563*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1564*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1565*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1566*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1567*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1568*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1569*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1570*4bdc9457SAndroid Build Coastguard Worker                           output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1571*4bdc9457SAndroid Build Coastguard Worker                             fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g]) *
1572*4bdc9457SAndroid Build Coastguard Worker                             fp16_ieee_to_fp32_value(kernel[((ky * kernel_width() + kx) * groups() + g) * group_output_channels() + oc]);
1573*4bdc9457SAndroid Build Coastguard Worker                         }
1574*4bdc9457SAndroid Build Coastguard Worker                       }
1575*4bdc9457SAndroid Build Coastguard Worker                     }
1576*4bdc9457SAndroid Build Coastguard Worker                   }
1577*4bdc9457SAndroid Build Coastguard Worker                 }
1578*4bdc9457SAndroid Build Coastguard Worker               }
1579*4bdc9457SAndroid Build Coastguard Worker             }
1580*4bdc9457SAndroid Build Coastguard Worker           }
1581*4bdc9457SAndroid Build Coastguard Worker         }
1582*4bdc9457SAndroid Build Coastguard Worker       } else {
1583*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1584*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1585*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1586*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1587*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1588*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1589*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1590*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1591*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1592*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1593*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1594*4bdc9457SAndroid Build Coastguard Worker                           for (size_t ic = 0; ic < group_input_channels(); ic++) {
1595*4bdc9457SAndroid Build Coastguard Worker                             output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1596*4bdc9457SAndroid Build Coastguard Worker                               fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) *
1597*4bdc9457SAndroid Build Coastguard Worker                               fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
1598*4bdc9457SAndroid Build Coastguard Worker                           }
1599*4bdc9457SAndroid Build Coastguard Worker                         }
1600*4bdc9457SAndroid Build Coastguard Worker                       }
1601*4bdc9457SAndroid Build Coastguard Worker                     }
1602*4bdc9457SAndroid Build Coastguard Worker                   }
1603*4bdc9457SAndroid Build Coastguard Worker                 }
1604*4bdc9457SAndroid Build Coastguard Worker               }
1605*4bdc9457SAndroid Build Coastguard Worker             }
1606*4bdc9457SAndroid Build Coastguard Worker           }
1607*4bdc9457SAndroid Build Coastguard Worker         }
1608*4bdc9457SAndroid Build Coastguard Worker       }
1609*4bdc9457SAndroid Build Coastguard Worker 
1610*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
1611*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1612*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1613*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
1614*4bdc9457SAndroid Build Coastguard Worker       const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
1615*4bdc9457SAndroid Build Coastguard Worker       const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
1616*4bdc9457SAndroid Build Coastguard Worker       const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
1617*4bdc9457SAndroid Build Coastguard Worker       const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;
1618*4bdc9457SAndroid Build Coastguard Worker 
1619*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
1620*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
1621*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
1622*4bdc9457SAndroid Build Coastguard Worker       }
1623*4bdc9457SAndroid Build Coastguard Worker 
1624*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Convolution operator.
1625*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1626*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
1627*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
1628*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
1629*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
1630*4bdc9457SAndroid Build Coastguard Worker       };
1631*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
1632*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1633*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
1634*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
1635*4bdc9457SAndroid Build Coastguard Worker       }
1636*4bdc9457SAndroid Build Coastguard Worker 
1637*4bdc9457SAndroid Build Coastguard Worker       const void* kernel_data = kernel.data();
1638*4bdc9457SAndroid Build Coastguard Worker       const void* bias_data = bias.data();
1639*4bdc9457SAndroid Build Coastguard Worker       if (weights_type() == WeightsType::FP32) {
1640*4bdc9457SAndroid Build Coastguard Worker         kernel_data = kernel_as_float.data();
1641*4bdc9457SAndroid Build Coastguard Worker         bias_data = bias_as_float.data();
1642*4bdc9457SAndroid Build Coastguard Worker       }
1643*4bdc9457SAndroid Build Coastguard Worker       uint32_t flags = 0;
1644*4bdc9457SAndroid Build Coastguard Worker       if (depthwise_layout()) {
1645*4bdc9457SAndroid Build Coastguard Worker         flags |= XNN_FLAG_DEPTHWISE_CONVOLUTION;
1646*4bdc9457SAndroid Build Coastguard Worker       }
1647*4bdc9457SAndroid Build Coastguard Worker       if (padding_tf_same()) {
1648*4bdc9457SAndroid Build Coastguard Worker         flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING;
1649*4bdc9457SAndroid Build Coastguard Worker       }
1650*4bdc9457SAndroid Build Coastguard Worker       if (weights_type() == WeightsType::FP32) {
1651*4bdc9457SAndroid Build Coastguard Worker         flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
1652*4bdc9457SAndroid Build Coastguard Worker       }
1653*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_f16(
1654*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
1655*4bdc9457SAndroid Build Coastguard Worker           padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
1656*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
1657*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
1658*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
1659*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
1660*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
1661*4bdc9457SAndroid Build Coastguard Worker           kernel_data, has_bias() ? bias_data : nullptr,
1662*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
1663*4bdc9457SAndroid Build Coastguard Worker           flags,
1664*4bdc9457SAndroid Build Coastguard Worker           &caches,
1665*4bdc9457SAndroid Build Coastguard Worker           &convolution_op);
1666*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
1667*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
1668*4bdc9457SAndroid Build Coastguard Worker       }
1669*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
1670*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
1671*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1672*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1673*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
1674*4bdc9457SAndroid Build Coastguard Worker       }
1675*4bdc9457SAndroid Build Coastguard Worker 
1676*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
1677*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
1678*4bdc9457SAndroid Build Coastguard Worker 
1679*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1680*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_f16(
1681*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
1682*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
1683*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1684*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1685*4bdc9457SAndroid Build Coastguard Worker 
1686*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1687*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
1688*4bdc9457SAndroid Build Coastguard Worker 
1689*4bdc9457SAndroid Build Coastguard Worker       VerifyNHWCxF16(output, output_ref, output_min, output_max);
1690*4bdc9457SAndroid Build Coastguard Worker 
1691*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1692*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t convolution_op2 = nullptr;
1693*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
1694*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success, xnn_create_convolution2d_nhwc_f16(
1695*4bdc9457SAndroid Build Coastguard Worker             padding_tf_same() ? 0 : padding_top(), padding_tf_same() ? 0 : padding_right(),
1696*4bdc9457SAndroid Build Coastguard Worker             padding_tf_same() ? 0 : padding_bottom(), padding_tf_same() ? 0 : padding_left(),
1697*4bdc9457SAndroid Build Coastguard Worker             kernel_height(), kernel_width(),
1698*4bdc9457SAndroid Build Coastguard Worker             subsampling_height(), subsampling_width(),
1699*4bdc9457SAndroid Build Coastguard Worker             dilation_height(), dilation_width(),
1700*4bdc9457SAndroid Build Coastguard Worker             groups(), group_input_channels(), group_output_channels(),
1701*4bdc9457SAndroid Build Coastguard Worker             input_channel_stride(), output_channel_stride(),
1702*4bdc9457SAndroid Build Coastguard Worker             kernel_data, has_bias() ? bias_data : nullptr,
1703*4bdc9457SAndroid Build Coastguard Worker             output_min, output_max,
1704*4bdc9457SAndroid Build Coastguard Worker             flags,
1705*4bdc9457SAndroid Build Coastguard Worker             &caches,
1706*4bdc9457SAndroid Build Coastguard Worker             &convolution_op2));
1707*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, convolution_op2);
1708*4bdc9457SAndroid Build Coastguard Worker 
1709*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete convolution_op.
1710*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op2, xnn_delete_operator);
1711*4bdc9457SAndroid Build Coastguard Worker 
1712*4bdc9457SAndroid Build Coastguard Worker         std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */);
1713*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1714*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_convolution2d_nhwc_f16(
1715*4bdc9457SAndroid Build Coastguard Worker                       convolution_op2,
1716*4bdc9457SAndroid Build Coastguard Worker                       batch_size(), input_height(), input_width(),
1717*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
1718*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
1719*4bdc9457SAndroid Build Coastguard Worker 
1720*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1721*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(convolution_op2, nullptr /* thread pool */));
1722*4bdc9457SAndroid Build Coastguard Worker 
1723*4bdc9457SAndroid Build Coastguard Worker         VerifyNHWCxF16(output2, output_ref, output_min, output_max);
1724*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
1725*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
1726*4bdc9457SAndroid Build Coastguard Worker       }
1727*4bdc9457SAndroid Build Coastguard Worker     }
1728*4bdc9457SAndroid Build Coastguard Worker   }
1729*4bdc9457SAndroid Build Coastguard Worker 
VerifyNHWCxF16(const std::vector<uint16_t> & output,const std::vector<float> & output_ref,const float output_min,const float output_max)1730*4bdc9457SAndroid Build Coastguard Worker   void VerifyNHWCxF16(const std::vector<uint16_t> &output,
1731*4bdc9457SAndroid Build Coastguard Worker                       const std::vector<float> &output_ref,
1732*4bdc9457SAndroid Build Coastguard Worker                       const float output_min, const float output_max) const {
1733*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
1734*4bdc9457SAndroid Build Coastguard Worker       for (size_t y = 0; y < output_height(); y++) {
1735*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < output_width(); x++) {
1736*4bdc9457SAndroid Build Coastguard Worker           for (size_t g = 0; g < groups(); g++) {
1737*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < group_output_channels(); c++) {
1738*4bdc9457SAndroid Build Coastguard Worker              ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_min)
1739*4bdc9457SAndroid Build Coastguard Worker                << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1740*4bdc9457SAndroid Build Coastguard Worker              ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_max)
1741*4bdc9457SAndroid Build Coastguard Worker                << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1742*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), std::max(1.0e-4f, std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]) * 1.0e-2f))
1743*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1744*4bdc9457SAndroid Build Coastguard Worker             }
1745*4bdc9457SAndroid Build Coastguard Worker           }
1746*4bdc9457SAndroid Build Coastguard Worker         }
1747*4bdc9457SAndroid Build Coastguard Worker       }
1748*4bdc9457SAndroid Build Coastguard Worker     }
1749*4bdc9457SAndroid Build Coastguard Worker   }
1750*4bdc9457SAndroid Build Coastguard Worker 
TestNCHWxF32()1751*4bdc9457SAndroid Build Coastguard Worker   void TestNCHWxF32() {
1752*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
1753*4bdc9457SAndroid Build Coastguard Worker 
1754*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
1755*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
1756*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
1757*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> pdist;
1758*4bdc9457SAndroid Build Coastguard Worker 
1759*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(2 * XNN_EXTRA_BYTES / sizeof(float) +
1760*4bdc9457SAndroid Build Coastguard Worker       ((batch_size() - 1) * input_channel_stride() + groups() * group_input_channels()) * input_height() * input_width());
1761*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel(
1762*4bdc9457SAndroid Build Coastguard Worker       groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1763*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias(groups() * group_output_channels());
1764*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(
1765*4bdc9457SAndroid Build Coastguard Worker       ((batch_size() - 1) * output_channel_stride() + groups() * group_output_channels()) * output_height() * output_width());
1766*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * groups() * group_output_channels() * output_height() * output_width());
1767*4bdc9457SAndroid Build Coastguard Worker 
1768*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
1769*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
1770*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
1771*4bdc9457SAndroid Build Coastguard Worker       for (float& k : kernel) {
1772*4bdc9457SAndroid Build Coastguard Worker         if (pdist(rng) <= sparsity()) {
1773*4bdc9457SAndroid Build Coastguard Worker           k = 0.0f;
1774*4bdc9457SAndroid Build Coastguard Worker         }
1775*4bdc9457SAndroid Build Coastguard Worker       }
1776*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1777*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
1778*4bdc9457SAndroid Build Coastguard Worker 
1779*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
1780*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
1781*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1782*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1783*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1784*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
1785*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
1786*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] =
1787*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
1788*4bdc9457SAndroid Build Coastguard Worker                 }
1789*4bdc9457SAndroid Build Coastguard Worker               }
1790*4bdc9457SAndroid Build Coastguard Worker             }
1791*4bdc9457SAndroid Build Coastguard Worker           }
1792*4bdc9457SAndroid Build Coastguard Worker         }
1793*4bdc9457SAndroid Build Coastguard Worker       } else {
1794*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
1795*4bdc9457SAndroid Build Coastguard Worker       }
1796*4bdc9457SAndroid Build Coastguard Worker       if (force_nhwc_input()) {
1797*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1798*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1799*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1800*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1801*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1802*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1803*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1804*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1805*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1806*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1807*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1808*4bdc9457SAndroid Build Coastguard Worker                           for (size_t ic = 0; ic < group_input_channels(); ic++) {
1809*4bdc9457SAndroid Build Coastguard Worker                             output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] +=
1810*4bdc9457SAndroid Build Coastguard Worker                               input[((((i * input_height() + iy) * input_width() + ix) * groups() + g) * group_input_channels() + ic)] *
1811*4bdc9457SAndroid Build Coastguard Worker                               kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
1812*4bdc9457SAndroid Build Coastguard Worker                           }
1813*4bdc9457SAndroid Build Coastguard Worker                         }
1814*4bdc9457SAndroid Build Coastguard Worker                       }
1815*4bdc9457SAndroid Build Coastguard Worker                     }
1816*4bdc9457SAndroid Build Coastguard Worker                   }
1817*4bdc9457SAndroid Build Coastguard Worker                 }
1818*4bdc9457SAndroid Build Coastguard Worker               }
1819*4bdc9457SAndroid Build Coastguard Worker             }
1820*4bdc9457SAndroid Build Coastguard Worker           }
1821*4bdc9457SAndroid Build Coastguard Worker         }
1822*4bdc9457SAndroid Build Coastguard Worker       } else if (depthwise_layout()) {
1823*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(group_input_channels(), 1);
1824*4bdc9457SAndroid Build Coastguard Worker 
1825*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1826*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1827*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1828*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1829*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1830*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1831*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1832*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1833*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1834*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1835*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1836*4bdc9457SAndroid Build Coastguard Worker                           output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] +=
1837*4bdc9457SAndroid Build Coastguard Worker                             input[((i * input_channel_stride() + g) * input_height() + iy) * input_width() + ix] *
1838*4bdc9457SAndroid Build Coastguard Worker                             kernel[((ky * kernel_width() + kx) * groups() + g) * group_output_channels() + oc];
1839*4bdc9457SAndroid Build Coastguard Worker                         }
1840*4bdc9457SAndroid Build Coastguard Worker                       }
1841*4bdc9457SAndroid Build Coastguard Worker                     }
1842*4bdc9457SAndroid Build Coastguard Worker                   }
1843*4bdc9457SAndroid Build Coastguard Worker                 }
1844*4bdc9457SAndroid Build Coastguard Worker               }
1845*4bdc9457SAndroid Build Coastguard Worker             }
1846*4bdc9457SAndroid Build Coastguard Worker           }
1847*4bdc9457SAndroid Build Coastguard Worker         }
1848*4bdc9457SAndroid Build Coastguard Worker       } else {
1849*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
1850*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
1851*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
1852*4bdc9457SAndroid Build Coastguard Worker               for (size_t ky = 0; ky < kernel_height(); ky++) {
1853*4bdc9457SAndroid Build Coastguard Worker                 const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
1854*4bdc9457SAndroid Build Coastguard Worker                 if (iy < input_height()) {
1855*4bdc9457SAndroid Build Coastguard Worker                   for (size_t kx = 0; kx < kernel_width(); kx++) {
1856*4bdc9457SAndroid Build Coastguard Worker                     const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
1857*4bdc9457SAndroid Build Coastguard Worker                     if (ix < input_width()) {
1858*4bdc9457SAndroid Build Coastguard Worker                       for (size_t g = 0; g < groups(); g++) {
1859*4bdc9457SAndroid Build Coastguard Worker                         for (size_t oc = 0; oc < group_output_channels(); oc++) {
1860*4bdc9457SAndroid Build Coastguard Worker                           for (size_t ic = 0; ic < group_input_channels(); ic++) {
1861*4bdc9457SAndroid Build Coastguard Worker                             output_ref[(((i * groups() + g) * group_output_channels() + oc) * output_height() + oy) * output_width() + ox] +=
1862*4bdc9457SAndroid Build Coastguard Worker                               input[((i * input_channel_stride() + g * group_input_channels() + ic) * input_height() + iy) * input_width() + ix] *
1863*4bdc9457SAndroid Build Coastguard Worker                               kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
1864*4bdc9457SAndroid Build Coastguard Worker                           }
1865*4bdc9457SAndroid Build Coastguard Worker                         }
1866*4bdc9457SAndroid Build Coastguard Worker                       }
1867*4bdc9457SAndroid Build Coastguard Worker                     }
1868*4bdc9457SAndroid Build Coastguard Worker                   }
1869*4bdc9457SAndroid Build Coastguard Worker                 }
1870*4bdc9457SAndroid Build Coastguard Worker               }
1871*4bdc9457SAndroid Build Coastguard Worker             }
1872*4bdc9457SAndroid Build Coastguard Worker           }
1873*4bdc9457SAndroid Build Coastguard Worker         }
1874*4bdc9457SAndroid Build Coastguard Worker       }
1875*4bdc9457SAndroid Build Coastguard Worker 
1876*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
1877*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1878*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1879*4bdc9457SAndroid Build Coastguard Worker 
1880*4bdc9457SAndroid Build Coastguard Worker       const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
1881*4bdc9457SAndroid Build Coastguard Worker         accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1882*4bdc9457SAndroid Build Coastguard Worker       const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
1883*4bdc9457SAndroid Build Coastguard Worker         accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1884*4bdc9457SAndroid Build Coastguard Worker 
1885*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
1886*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
1887*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
1888*4bdc9457SAndroid Build Coastguard Worker       }
1889*4bdc9457SAndroid Build Coastguard Worker 
1890*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Convolution operator.
1891*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1892*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
1893*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
1894*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
1895*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
1896*4bdc9457SAndroid Build Coastguard Worker       };
1897*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
1898*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1899*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
1900*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
1901*4bdc9457SAndroid Build Coastguard Worker       }
1902*4bdc9457SAndroid Build Coastguard Worker 
1903*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nchw_f32(
1904*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
1905*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
1906*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
1907*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
1908*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
1909*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
1910*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
1911*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
1912*4bdc9457SAndroid Build Coastguard Worker           (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) | (force_nhwc_input() ? XNN_FLAG_INPUT_NHWC : 0),
1913*4bdc9457SAndroid Build Coastguard Worker           &caches,
1914*4bdc9457SAndroid Build Coastguard Worker           &convolution_op);
1915*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_parameter) {
1916*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
1917*4bdc9457SAndroid Build Coastguard Worker       }
1918*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
1919*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
1920*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1921*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1922*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
1923*4bdc9457SAndroid Build Coastguard Worker       }
1924*4bdc9457SAndroid Build Coastguard Worker 
1925*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
1926*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
1927*4bdc9457SAndroid Build Coastguard Worker 
1928*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1929*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nchw_f32(
1930*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
1931*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
1932*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
1933*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
1934*4bdc9457SAndroid Build Coastguard Worker 
1935*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
1936*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
1937*4bdc9457SAndroid Build Coastguard Worker 
1938*4bdc9457SAndroid Build Coastguard Worker       VerifyNCHWxF32(output, output_ref, output_min, output_max);
1939*4bdc9457SAndroid Build Coastguard Worker 
1940*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
1941*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t convolution_op2 = nullptr;
1942*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
1943*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(
1944*4bdc9457SAndroid Build Coastguard Worker             xnn_status_success,
1945*4bdc9457SAndroid Build Coastguard Worker             xnn_create_convolution2d_nchw_f32(
1946*4bdc9457SAndroid Build Coastguard Worker                 padding_top(), padding_right(), padding_bottom(),
1947*4bdc9457SAndroid Build Coastguard Worker                 padding_left(), kernel_height(), kernel_width(),
1948*4bdc9457SAndroid Build Coastguard Worker                 subsampling_height(), subsampling_width(), dilation_height(),
1949*4bdc9457SAndroid Build Coastguard Worker                 dilation_width(), groups(), group_input_channels(),
1950*4bdc9457SAndroid Build Coastguard Worker                 group_output_channels(), input_channel_stride(),
1951*4bdc9457SAndroid Build Coastguard Worker                 output_channel_stride(), kernel.data(),
1952*4bdc9457SAndroid Build Coastguard Worker                 has_bias() ? bias.data() : nullptr, output_min, output_max,
1953*4bdc9457SAndroid Build Coastguard Worker                 (depthwise_layout() ? XNN_FLAG_DEPTHWISE_CONVOLUTION : 0) |
1954*4bdc9457SAndroid Build Coastguard Worker                     (force_nhwc_input() ? XNN_FLAG_INPUT_NHWC : 0),
1955*4bdc9457SAndroid Build Coastguard Worker                 &caches, &convolution_op2));
1956*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, convolution_op2);
1957*4bdc9457SAndroid Build Coastguard Worker 
1958*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete convolution_op2.
1959*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op2, xnn_delete_operator);
1960*4bdc9457SAndroid Build Coastguard Worker         std::vector<float> output2(output.size(), nanf(""));
1961*4bdc9457SAndroid Build Coastguard Worker 
1962*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1963*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_convolution2d_nchw_f32(
1964*4bdc9457SAndroid Build Coastguard Worker                       convolution_op2,
1965*4bdc9457SAndroid Build Coastguard Worker                       batch_size(), input_height(), input_width(),
1966*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
1967*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
1968*4bdc9457SAndroid Build Coastguard Worker 
1969*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
1970*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(convolution_op2, nullptr /* thread pool */));
1971*4bdc9457SAndroid Build Coastguard Worker 
1972*4bdc9457SAndroid Build Coastguard Worker         VerifyNCHWxF32(output2, output_ref, output_min, output_max);
1973*4bdc9457SAndroid Build Coastguard Worker         if (IsSpmm()) {
1974*4bdc9457SAndroid Build Coastguard Worker           VerifyWeightsCacheUnused(weights_cache);
1975*4bdc9457SAndroid Build Coastguard Worker         } else {
1976*4bdc9457SAndroid Build Coastguard Worker           VerifyWeightsCache(weights_cache, old_weights_cache_size);
1977*4bdc9457SAndroid Build Coastguard Worker         }
1978*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
1979*4bdc9457SAndroid Build Coastguard Worker       }
1980*4bdc9457SAndroid Build Coastguard Worker     }
1981*4bdc9457SAndroid Build Coastguard Worker   }
1982*4bdc9457SAndroid Build Coastguard Worker 
VerifyNCHWxF32(const std::vector<float> & output,const std::vector<float> & output_ref,const float output_min,const float output_max)1983*4bdc9457SAndroid Build Coastguard Worker   void VerifyNCHWxF32(const std::vector<float> &output,
1984*4bdc9457SAndroid Build Coastguard Worker                       const std::vector<float> &output_ref,
1985*4bdc9457SAndroid Build Coastguard Worker                       const float output_min, const float output_max) const {
1986*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
1987*4bdc9457SAndroid Build Coastguard Worker       for (size_t y = 0; y < output_height(); y++) {
1988*4bdc9457SAndroid Build Coastguard Worker         for (size_t x = 0; x < output_width(); x++) {
1989*4bdc9457SAndroid Build Coastguard Worker           for (size_t g = 0; g < groups(); g++) {
1990*4bdc9457SAndroid Build Coastguard Worker             for (size_t c = 0; c < group_output_channels(); c++) {
1991*4bdc9457SAndroid Build Coastguard Worker               ASSERT_GE(output[((i * output_channel_stride() + g * group_output_channels() + c) * output_height() + y) * output_width() + x], output_min)
1992*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c << ", image = " << i;
1993*4bdc9457SAndroid Build Coastguard Worker               ASSERT_LE(output[((i * output_channel_stride() + g * group_output_channels() + c) * output_height() + y) * output_width() + x], output_max)
1994*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c << ", image = " << i;
1995*4bdc9457SAndroid Build Coastguard Worker               ASSERT_NEAR(
1996*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * groups() + g) * group_output_channels() + c) * output_height() + y) * output_width() + x],
1997*4bdc9457SAndroid Build Coastguard Worker                   output[((i * output_channel_stride() + g * group_output_channels() + c) * output_height() + y) * output_width() + x],
1998*4bdc9457SAndroid Build Coastguard Worker                   1.0e-4 * std::abs(output_ref[(((i * groups() + g) * group_output_channels() + c) * output_height() + y) * output_width() + x]))
1999*4bdc9457SAndroid Build Coastguard Worker                 << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c << ", image = " << i;
2000*4bdc9457SAndroid Build Coastguard Worker             }
2001*4bdc9457SAndroid Build Coastguard Worker           }
2002*4bdc9457SAndroid Build Coastguard Worker         }
2003*4bdc9457SAndroid Build Coastguard Worker       }
2004*4bdc9457SAndroid Build Coastguard Worker     }
2005*4bdc9457SAndroid Build Coastguard Worker   }
2006*4bdc9457SAndroid Build Coastguard Worker 
TestSetupNHWCxQC8()2007*4bdc9457SAndroid Build Coastguard Worker   void TestSetupNHWCxQC8() const {
2008*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
2009*4bdc9457SAndroid Build Coastguard Worker 
2010*4bdc9457SAndroid Build Coastguard Worker     ASSERT_FALSE(depthwise_layout());
2011*4bdc9457SAndroid Build Coastguard Worker 
2012*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
2013*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
2014*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
2015*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
2016*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
2017*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
2018*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
2019*4bdc9457SAndroid Build Coastguard Worker 
2020*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max(
2021*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()),
2022*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels())));
2023*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
2024*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(groups() * group_output_channels());
2025*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(std::max(
2026*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()),
2027*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels())));
2028*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2029*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2030*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> requantization_scales(groups() * group_output_channels());
2031*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2032*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2033*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> next_requantization_scales(groups() * group_output_channels());
2034*4bdc9457SAndroid Build Coastguard Worker 
2035*4bdc9457SAndroid Build Coastguard Worker     const int8_t input_zero_point = -1;
2036*4bdc9457SAndroid Build Coastguard Worker     const int8_t output_zero_point = -1;
2037*4bdc9457SAndroid Build Coastguard Worker 
2038*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
2039*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
2040*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
2041*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
2042*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
2043*4bdc9457SAndroid Build Coastguard Worker 
2044*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
2045*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2046*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
2047*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
2048*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
2049*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2050*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2051*4bdc9457SAndroid Build Coastguard Worker                   accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2052*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2053*4bdc9457SAndroid Build Coastguard Worker                 }
2054*4bdc9457SAndroid Build Coastguard Worker               }
2055*4bdc9457SAndroid Build Coastguard Worker             }
2056*4bdc9457SAndroid Build Coastguard Worker           }
2057*4bdc9457SAndroid Build Coastguard Worker         }
2058*4bdc9457SAndroid Build Coastguard Worker       } else {
2059*4bdc9457SAndroid Build Coastguard Worker         std::fill(accumulators.begin(), accumulators.end(), 0);
2060*4bdc9457SAndroid Build Coastguard Worker       }
2061*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2062*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
2063*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
2064*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2065*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2066*4bdc9457SAndroid Build Coastguard Worker               if (iy < input_height()) {
2067*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2068*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2069*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width()) {
2070*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2071*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2072*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2073*4bdc9457SAndroid Build Coastguard Worker                           accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2074*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
2075*4bdc9457SAndroid Build Coastguard Worker                             int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2076*4bdc9457SAndroid Build Coastguard Worker                         }
2077*4bdc9457SAndroid Build Coastguard Worker                       }
2078*4bdc9457SAndroid Build Coastguard Worker                     }
2079*4bdc9457SAndroid Build Coastguard Worker                   }
2080*4bdc9457SAndroid Build Coastguard Worker                 }
2081*4bdc9457SAndroid Build Coastguard Worker               }
2082*4bdc9457SAndroid Build Coastguard Worker             }
2083*4bdc9457SAndroid Build Coastguard Worker           }
2084*4bdc9457SAndroid Build Coastguard Worker         }
2085*4bdc9457SAndroid Build Coastguard Worker       }
2086*4bdc9457SAndroid Build Coastguard Worker 
2087*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
2088*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < groups() * group_output_channels(); c++) {
2089*4bdc9457SAndroid Build Coastguard Worker         int32_t accumulated_min = accumulators[c];
2090*4bdc9457SAndroid Build Coastguard Worker         int32_t accumulated_max = accumulators[c];
2091*4bdc9457SAndroid Build Coastguard Worker         for (size_t px = 0; px < batch_size() * output_height() * output_width(); px++) {
2092*4bdc9457SAndroid Build Coastguard Worker           accumulated_min = std::min(accumulated_min, accumulators[px * groups() * group_output_channels() + c]);
2093*4bdc9457SAndroid Build Coastguard Worker           accumulated_max = std::max(accumulated_max, accumulators[px * groups() * group_output_channels() + c]);
2094*4bdc9457SAndroid Build Coastguard Worker         }
2095*4bdc9457SAndroid Build Coastguard Worker 
2096*4bdc9457SAndroid Build Coastguard Worker         float requantization_scale = 0x1.0p-32f;
2097*4bdc9457SAndroid Build Coastguard Worker         if (accumulated_max != 0) {
2098*4bdc9457SAndroid Build Coastguard Worker           requantization_scale = std::max(requantization_scale,
2099*4bdc9457SAndroid Build Coastguard Worker             float(int32_t(std::numeric_limits<int8_t>::max()) - int32_t(output_zero_point)) / float(accumulated_max));
2100*4bdc9457SAndroid Build Coastguard Worker         }
2101*4bdc9457SAndroid Build Coastguard Worker         if (accumulated_min != 0) {
2102*4bdc9457SAndroid Build Coastguard Worker           requantization_scale = std::max(requantization_scale,
2103*4bdc9457SAndroid Build Coastguard Worker             float(int32_t(std::numeric_limits<int8_t>::min()) - int32_t(output_zero_point)) / float(accumulated_min));
2104*4bdc9457SAndroid Build Coastguard Worker         }
2105*4bdc9457SAndroid Build Coastguard Worker         requantization_scale = std::min(requantization_scale, 0x1.FFFFFEp-1f);
2106*4bdc9457SAndroid Build Coastguard Worker 
2107*4bdc9457SAndroid Build Coastguard Worker         requantization_scales[c] = requantization_scale;
2108*4bdc9457SAndroid Build Coastguard Worker       }
2109*4bdc9457SAndroid Build Coastguard Worker 
2110*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
2111*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < groups() * group_output_channels(); c++) {
2112*4bdc9457SAndroid Build Coastguard Worker         for (size_t px = 0; px < batch_size() * output_height() * output_width(); px++) {
2113*4bdc9457SAndroid Build Coastguard Worker           output_ref[px * groups() * group_output_channels() + c] = double(int32_t(output_zero_point)) +
2114*4bdc9457SAndroid Build Coastguard Worker             double(accumulators[px * groups() * group_output_channels() + c]) * double(requantization_scales[c]);
2115*4bdc9457SAndroid Build Coastguard Worker         }
2116*4bdc9457SAndroid Build Coastguard Worker       }
2117*4bdc9457SAndroid Build Coastguard Worker       std::transform(output_ref.cbegin(), output_ref.cend(), output_ref.begin(),
2118*4bdc9457SAndroid Build Coastguard Worker         [this](double x) -> double {
2119*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(x, double(qmax() - 0x80)), double(qmin() - 0x80));
2120*4bdc9457SAndroid Build Coastguard Worker         });
2121*4bdc9457SAndroid Build Coastguard Worker 
2122*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Convolution operator once.
2123*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
2124*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
2125*4bdc9457SAndroid Build Coastguard Worker 
2126*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_qc8(
2127*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
2128*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
2129*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
2130*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
2131*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
2132*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
2133*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */, requantization_scales.data(),
2134*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
2135*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, 1.0f /* output scale */, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
2136*4bdc9457SAndroid Build Coastguard Worker           0, NULL, &convolution_op);
2137*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
2138*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
2139*4bdc9457SAndroid Build Coastguard Worker       }
2140*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
2141*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
2142*4bdc9457SAndroid Build Coastguard Worker 
2143*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
2144*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
2145*4bdc9457SAndroid Build Coastguard Worker 
2146*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2147*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qc8(
2148*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2149*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
2150*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2151*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2152*4bdc9457SAndroid Build Coastguard Worker 
2153*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2154*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2155*4bdc9457SAndroid Build Coastguard Worker 
2156*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
2157*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2158*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
2159*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
2160*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2161*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2162*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
2163*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2164*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
2165*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2166*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
2167*4bdc9457SAndroid Build Coastguard Worker                     output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
2168*4bdc9457SAndroid Build Coastguard Worker                     double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]),
2169*4bdc9457SAndroid Build Coastguard Worker                     0.9)
2170*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2171*4bdc9457SAndroid Build Coastguard Worker               }
2172*4bdc9457SAndroid Build Coastguard Worker             }
2173*4bdc9457SAndroid Build Coastguard Worker           }
2174*4bdc9457SAndroid Build Coastguard Worker         }
2175*4bdc9457SAndroid Build Coastguard Worker       }
2176*4bdc9457SAndroid Build Coastguard Worker 
2177*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
2178*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
2179*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
2180*4bdc9457SAndroid Build Coastguard Worker 
2181*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run, including renormalization.
2182*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2183*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < next_batch_size(); i++) {
2184*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < next_output_height(); oy++) {
2185*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < next_output_width(); ox++) {
2186*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2187*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2188*4bdc9457SAndroid Build Coastguard Worker                   next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2189*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2190*4bdc9457SAndroid Build Coastguard Worker                 }
2191*4bdc9457SAndroid Build Coastguard Worker               }
2192*4bdc9457SAndroid Build Coastguard Worker             }
2193*4bdc9457SAndroid Build Coastguard Worker           }
2194*4bdc9457SAndroid Build Coastguard Worker         }
2195*4bdc9457SAndroid Build Coastguard Worker       } else {
2196*4bdc9457SAndroid Build Coastguard Worker         std::fill(next_accumulators.begin(), next_accumulators.end(), 0);
2197*4bdc9457SAndroid Build Coastguard Worker       }
2198*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2199*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
2200*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
2201*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2202*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2203*4bdc9457SAndroid Build Coastguard Worker               if (iy < next_input_height()) {
2204*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2205*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2206*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width()) {
2207*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2208*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2209*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2210*4bdc9457SAndroid Build Coastguard Worker                           next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2211*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
2212*4bdc9457SAndroid Build Coastguard Worker                             int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2213*4bdc9457SAndroid Build Coastguard Worker                         }
2214*4bdc9457SAndroid Build Coastguard Worker                       }
2215*4bdc9457SAndroid Build Coastguard Worker                     }
2216*4bdc9457SAndroid Build Coastguard Worker                   }
2217*4bdc9457SAndroid Build Coastguard Worker                 }
2218*4bdc9457SAndroid Build Coastguard Worker               }
2219*4bdc9457SAndroid Build Coastguard Worker             }
2220*4bdc9457SAndroid Build Coastguard Worker           }
2221*4bdc9457SAndroid Build Coastguard Worker         }
2222*4bdc9457SAndroid Build Coastguard Worker       }
2223*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < groups() * group_output_channels(); c++) {
2224*4bdc9457SAndroid Build Coastguard Worker         for (size_t px = 0; px < next_batch_size() * next_output_height() * next_output_width(); px++) {
2225*4bdc9457SAndroid Build Coastguard Worker           next_output_ref[px * groups() * group_output_channels() + c] = double(int32_t(output_zero_point)) +
2226*4bdc9457SAndroid Build Coastguard Worker             double(next_accumulators[px * groups() * group_output_channels() + c]) * double(requantization_scales[c]);
2227*4bdc9457SAndroid Build Coastguard Worker         }
2228*4bdc9457SAndroid Build Coastguard Worker       }
2229*4bdc9457SAndroid Build Coastguard Worker       std::transform(next_output_ref.cbegin(), next_output_ref.cend(), next_output_ref.begin(),
2230*4bdc9457SAndroid Build Coastguard Worker         [this](double x) -> double {
2231*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(x, double(qmax() - 0x80)), double(qmin() - 0x80));
2232*4bdc9457SAndroid Build Coastguard Worker         });
2233*4bdc9457SAndroid Build Coastguard Worker 
2234*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Convolution operator the second time, and destroy the operator.
2235*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2236*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qc8(
2237*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2238*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
2239*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2240*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2241*4bdc9457SAndroid Build Coastguard Worker 
2242*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2243*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2244*4bdc9457SAndroid Build Coastguard Worker 
2245*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
2246*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2247*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
2248*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
2249*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2250*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2251*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
2252*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2253*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
2254*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2255*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
2256*4bdc9457SAndroid Build Coastguard Worker                     next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
2257*4bdc9457SAndroid Build Coastguard Worker                     double(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]),
2258*4bdc9457SAndroid Build Coastguard Worker                     0.9)
2259*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2260*4bdc9457SAndroid Build Coastguard Worker               }
2261*4bdc9457SAndroid Build Coastguard Worker             }
2262*4bdc9457SAndroid Build Coastguard Worker           }
2263*4bdc9457SAndroid Build Coastguard Worker         }
2264*4bdc9457SAndroid Build Coastguard Worker       }
2265*4bdc9457SAndroid Build Coastguard Worker     }
2266*4bdc9457SAndroid Build Coastguard Worker   }
2267*4bdc9457SAndroid Build Coastguard Worker 
TestSetupNHWCxQS8()2268*4bdc9457SAndroid Build Coastguard Worker   void TestSetupNHWCxQS8() const {
2269*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
2270*4bdc9457SAndroid Build Coastguard Worker 
2271*4bdc9457SAndroid Build Coastguard Worker     ASSERT_FALSE(depthwise_layout());
2272*4bdc9457SAndroid Build Coastguard Worker 
2273*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
2274*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
2275*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
2276*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
2277*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
2278*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
2279*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
2280*4bdc9457SAndroid Build Coastguard Worker 
2281*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max(
2282*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()),
2283*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels())));
2284*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
2285*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(groups() * group_output_channels());
2286*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(std::max(
2287*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()),
2288*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels())));
2289*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2290*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2291*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2292*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2293*4bdc9457SAndroid Build Coastguard Worker 
2294*4bdc9457SAndroid Build Coastguard Worker     const int8_t input_zero_point = -1;
2295*4bdc9457SAndroid Build Coastguard Worker 
2296*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
2297*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
2298*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
2299*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
2300*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
2301*4bdc9457SAndroid Build Coastguard Worker 
2302*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
2303*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2304*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
2305*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
2306*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
2307*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2308*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2309*4bdc9457SAndroid Build Coastguard Worker                   accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2310*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2311*4bdc9457SAndroid Build Coastguard Worker                 }
2312*4bdc9457SAndroid Build Coastguard Worker               }
2313*4bdc9457SAndroid Build Coastguard Worker             }
2314*4bdc9457SAndroid Build Coastguard Worker           }
2315*4bdc9457SAndroid Build Coastguard Worker         }
2316*4bdc9457SAndroid Build Coastguard Worker       } else {
2317*4bdc9457SAndroid Build Coastguard Worker         std::fill(accumulators.begin(), accumulators.end(), 0);
2318*4bdc9457SAndroid Build Coastguard Worker       }
2319*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2320*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
2321*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
2322*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2323*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2324*4bdc9457SAndroid Build Coastguard Worker               if (iy < input_height()) {
2325*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2326*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2327*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width()) {
2328*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2329*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2330*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2331*4bdc9457SAndroid Build Coastguard Worker                           accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2332*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
2333*4bdc9457SAndroid Build Coastguard Worker                             int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2334*4bdc9457SAndroid Build Coastguard Worker                         }
2335*4bdc9457SAndroid Build Coastguard Worker                       }
2336*4bdc9457SAndroid Build Coastguard Worker                     }
2337*4bdc9457SAndroid Build Coastguard Worker                   }
2338*4bdc9457SAndroid Build Coastguard Worker                 }
2339*4bdc9457SAndroid Build Coastguard Worker               }
2340*4bdc9457SAndroid Build Coastguard Worker             }
2341*4bdc9457SAndroid Build Coastguard Worker           }
2342*4bdc9457SAndroid Build Coastguard Worker         }
2343*4bdc9457SAndroid Build Coastguard Worker       }
2344*4bdc9457SAndroid Build Coastguard Worker 
2345*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
2346*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
2347*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
2348*4bdc9457SAndroid Build Coastguard Worker 
2349*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
2350*4bdc9457SAndroid Build Coastguard Worker       const int8_t output_zero_point = int8_t(std::max(std::min(
2351*4bdc9457SAndroid Build Coastguard Worker         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
2352*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
2353*4bdc9457SAndroid Build Coastguard Worker 
2354*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
2355*4bdc9457SAndroid Build Coastguard Worker       std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
2356*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
2357*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
2358*4bdc9457SAndroid Build Coastguard Worker         });
2359*4bdc9457SAndroid Build Coastguard Worker 
2360*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Convolution operator once.
2361*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
2362*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
2363*4bdc9457SAndroid Build Coastguard Worker 
2364*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_qs8(
2365*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
2366*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
2367*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
2368*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
2369*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
2370*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
2371*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */, 1.0f /* kernel scale */,
2372*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
2373*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
2374*4bdc9457SAndroid Build Coastguard Worker           0, NULL, &convolution_op);
2375*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
2376*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
2377*4bdc9457SAndroid Build Coastguard Worker       }
2378*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
2379*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
2380*4bdc9457SAndroid Build Coastguard Worker 
2381*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
2382*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
2383*4bdc9457SAndroid Build Coastguard Worker 
2384*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2385*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qs8(
2386*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2387*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
2388*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2389*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2390*4bdc9457SAndroid Build Coastguard Worker 
2391*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2392*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2393*4bdc9457SAndroid Build Coastguard Worker 
2394*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
2395*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2396*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
2397*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
2398*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2399*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2400*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
2401*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2402*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
2403*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2404*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
2405*4bdc9457SAndroid Build Coastguard Worker                     output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
2406*4bdc9457SAndroid Build Coastguard Worker                     double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
2407*4bdc9457SAndroid Build Coastguard Worker                     0.9)
2408*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2409*4bdc9457SAndroid Build Coastguard Worker               }
2410*4bdc9457SAndroid Build Coastguard Worker             }
2411*4bdc9457SAndroid Build Coastguard Worker           }
2412*4bdc9457SAndroid Build Coastguard Worker         }
2413*4bdc9457SAndroid Build Coastguard Worker       }
2414*4bdc9457SAndroid Build Coastguard Worker 
2415*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
2416*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
2417*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
2418*4bdc9457SAndroid Build Coastguard Worker 
2419*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run, including renormalization.
2420*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2421*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < next_batch_size(); i++) {
2422*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < next_output_height(); oy++) {
2423*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < next_output_width(); ox++) {
2424*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2425*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2426*4bdc9457SAndroid Build Coastguard Worker                   next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2427*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2428*4bdc9457SAndroid Build Coastguard Worker                 }
2429*4bdc9457SAndroid Build Coastguard Worker               }
2430*4bdc9457SAndroid Build Coastguard Worker             }
2431*4bdc9457SAndroid Build Coastguard Worker           }
2432*4bdc9457SAndroid Build Coastguard Worker         }
2433*4bdc9457SAndroid Build Coastguard Worker       } else {
2434*4bdc9457SAndroid Build Coastguard Worker         std::fill(next_accumulators.begin(), next_accumulators.end(), 0);
2435*4bdc9457SAndroid Build Coastguard Worker       }
2436*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2437*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
2438*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
2439*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2440*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2441*4bdc9457SAndroid Build Coastguard Worker               if (iy < next_input_height()) {
2442*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2443*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2444*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width()) {
2445*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2446*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2447*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2448*4bdc9457SAndroid Build Coastguard Worker                           next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2449*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
2450*4bdc9457SAndroid Build Coastguard Worker                             int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2451*4bdc9457SAndroid Build Coastguard Worker                         }
2452*4bdc9457SAndroid Build Coastguard Worker                       }
2453*4bdc9457SAndroid Build Coastguard Worker                     }
2454*4bdc9457SAndroid Build Coastguard Worker                   }
2455*4bdc9457SAndroid Build Coastguard Worker                 }
2456*4bdc9457SAndroid Build Coastguard Worker               }
2457*4bdc9457SAndroid Build Coastguard Worker             }
2458*4bdc9457SAndroid Build Coastguard Worker           }
2459*4bdc9457SAndroid Build Coastguard Worker         }
2460*4bdc9457SAndroid Build Coastguard Worker       }
2461*4bdc9457SAndroid Build Coastguard Worker       std::transform(next_accumulators.cbegin(), next_accumulators.cend(), next_output_ref.begin(),
2462*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
2463*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
2464*4bdc9457SAndroid Build Coastguard Worker         });
2465*4bdc9457SAndroid Build Coastguard Worker 
2466*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Convolution operator the second time, and destroy the operator.
2467*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2468*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qs8(
2469*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2470*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
2471*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2472*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2473*4bdc9457SAndroid Build Coastguard Worker 
2474*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2475*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2476*4bdc9457SAndroid Build Coastguard Worker 
2477*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
2478*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2479*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
2480*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
2481*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2482*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2483*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
2484*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2485*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
2486*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2487*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
2488*4bdc9457SAndroid Build Coastguard Worker                     next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
2489*4bdc9457SAndroid Build Coastguard Worker                     double(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
2490*4bdc9457SAndroid Build Coastguard Worker                     0.9)
2491*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2492*4bdc9457SAndroid Build Coastguard Worker               }
2493*4bdc9457SAndroid Build Coastguard Worker             }
2494*4bdc9457SAndroid Build Coastguard Worker           }
2495*4bdc9457SAndroid Build Coastguard Worker         }
2496*4bdc9457SAndroid Build Coastguard Worker       }
2497*4bdc9457SAndroid Build Coastguard Worker     }
2498*4bdc9457SAndroid Build Coastguard Worker   }
2499*4bdc9457SAndroid Build Coastguard Worker 
TestSetupNHWCxQU8()2500*4bdc9457SAndroid Build Coastguard Worker   void TestSetupNHWCxQU8() const {
2501*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
2502*4bdc9457SAndroid Build Coastguard Worker 
2503*4bdc9457SAndroid Build Coastguard Worker     ASSERT_FALSE(depthwise_layout());
2504*4bdc9457SAndroid Build Coastguard Worker 
2505*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
2506*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
2507*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
2508*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
2509*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
2510*4bdc9457SAndroid Build Coastguard Worker 
2511*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max(
2512*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()),
2513*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels())));
2514*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
2515*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(groups() * group_output_channels());
2516*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(std::max(
2517*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()),
2518*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels())));
2519*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2520*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2521*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2522*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2523*4bdc9457SAndroid Build Coastguard Worker 
2524*4bdc9457SAndroid Build Coastguard Worker     const uint8_t input_zero_point = 127;
2525*4bdc9457SAndroid Build Coastguard Worker     const uint8_t kernel_zero_point = 127;
2526*4bdc9457SAndroid Build Coastguard Worker 
2527*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
2528*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
2529*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
2530*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
2531*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
2532*4bdc9457SAndroid Build Coastguard Worker 
2533*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
2534*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2535*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
2536*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
2537*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
2538*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2539*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2540*4bdc9457SAndroid Build Coastguard Worker                   accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2541*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2542*4bdc9457SAndroid Build Coastguard Worker                 }
2543*4bdc9457SAndroid Build Coastguard Worker               }
2544*4bdc9457SAndroid Build Coastguard Worker             }
2545*4bdc9457SAndroid Build Coastguard Worker           }
2546*4bdc9457SAndroid Build Coastguard Worker         }
2547*4bdc9457SAndroid Build Coastguard Worker       } else {
2548*4bdc9457SAndroid Build Coastguard Worker         std::fill(accumulators.begin(), accumulators.end(), 0);
2549*4bdc9457SAndroid Build Coastguard Worker       }
2550*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2551*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
2552*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
2553*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2554*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2555*4bdc9457SAndroid Build Coastguard Worker               if (iy < input_height()) {
2556*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2557*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2558*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width()) {
2559*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2560*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2561*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2562*4bdc9457SAndroid Build Coastguard Worker                           accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2563*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
2564*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point));
2565*4bdc9457SAndroid Build Coastguard Worker                         }
2566*4bdc9457SAndroid Build Coastguard Worker                       }
2567*4bdc9457SAndroid Build Coastguard Worker                     }
2568*4bdc9457SAndroid Build Coastguard Worker                   }
2569*4bdc9457SAndroid Build Coastguard Worker                 }
2570*4bdc9457SAndroid Build Coastguard Worker               }
2571*4bdc9457SAndroid Build Coastguard Worker             }
2572*4bdc9457SAndroid Build Coastguard Worker           }
2573*4bdc9457SAndroid Build Coastguard Worker         }
2574*4bdc9457SAndroid Build Coastguard Worker       }
2575*4bdc9457SAndroid Build Coastguard Worker 
2576*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
2577*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
2578*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
2579*4bdc9457SAndroid Build Coastguard Worker 
2580*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
2581*4bdc9457SAndroid Build Coastguard Worker       const uint8_t output_zero_point = uint8_t(std::max(std::min(
2582*4bdc9457SAndroid Build Coastguard Worker         lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
2583*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
2584*4bdc9457SAndroid Build Coastguard Worker 
2585*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
2586*4bdc9457SAndroid Build Coastguard Worker       std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
2587*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
2588*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
2589*4bdc9457SAndroid Build Coastguard Worker         });
2590*4bdc9457SAndroid Build Coastguard Worker 
2591*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Convolution operator once.
2592*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
2593*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
2594*4bdc9457SAndroid Build Coastguard Worker 
2595*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_qu8(
2596*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
2597*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
2598*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
2599*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
2600*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
2601*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
2602*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */,
2603*4bdc9457SAndroid Build Coastguard Worker           kernel_zero_point, 1.0f /* kernel scale */,
2604*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
2605*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, output_scale, qmin(), qmax(),
2606*4bdc9457SAndroid Build Coastguard Worker           0, NULL, &convolution_op);
2607*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
2608*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
2609*4bdc9457SAndroid Build Coastguard Worker       }
2610*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
2611*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
2612*4bdc9457SAndroid Build Coastguard Worker 
2613*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
2614*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
2615*4bdc9457SAndroid Build Coastguard Worker 
2616*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2617*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qu8(
2618*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2619*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
2620*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2621*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2622*4bdc9457SAndroid Build Coastguard Worker 
2623*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2624*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2625*4bdc9457SAndroid Build Coastguard Worker 
2626*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
2627*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2628*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
2629*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
2630*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2631*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2632*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax()))
2633*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2634*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin()))
2635*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2636*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
2637*4bdc9457SAndroid Build Coastguard Worker                     output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
2638*4bdc9457SAndroid Build Coastguard Worker                     double(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
2639*4bdc9457SAndroid Build Coastguard Worker                     0.9)
2640*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2641*4bdc9457SAndroid Build Coastguard Worker               }
2642*4bdc9457SAndroid Build Coastguard Worker             }
2643*4bdc9457SAndroid Build Coastguard Worker           }
2644*4bdc9457SAndroid Build Coastguard Worker         }
2645*4bdc9457SAndroid Build Coastguard Worker       }
2646*4bdc9457SAndroid Build Coastguard Worker 
2647*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
2648*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
2649*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), 0xA5);
2650*4bdc9457SAndroid Build Coastguard Worker 
2651*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run, including renormalization.
2652*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2653*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < next_batch_size(); i++) {
2654*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < next_output_height(); oy++) {
2655*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < next_output_width(); ox++) {
2656*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2657*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2658*4bdc9457SAndroid Build Coastguard Worker                   next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2659*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2660*4bdc9457SAndroid Build Coastguard Worker                 }
2661*4bdc9457SAndroid Build Coastguard Worker               }
2662*4bdc9457SAndroid Build Coastguard Worker             }
2663*4bdc9457SAndroid Build Coastguard Worker           }
2664*4bdc9457SAndroid Build Coastguard Worker         }
2665*4bdc9457SAndroid Build Coastguard Worker       } else {
2666*4bdc9457SAndroid Build Coastguard Worker         std::fill(next_accumulators.begin(), next_accumulators.end(), 0);
2667*4bdc9457SAndroid Build Coastguard Worker       }
2668*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2669*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
2670*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
2671*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2672*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2673*4bdc9457SAndroid Build Coastguard Worker               if (iy < next_input_height()) {
2674*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2675*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2676*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width()) {
2677*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2678*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2679*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2680*4bdc9457SAndroid Build Coastguard Worker                           next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2681*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
2682*4bdc9457SAndroid Build Coastguard Worker                             (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point));
2683*4bdc9457SAndroid Build Coastguard Worker                         }
2684*4bdc9457SAndroid Build Coastguard Worker                       }
2685*4bdc9457SAndroid Build Coastguard Worker                     }
2686*4bdc9457SAndroid Build Coastguard Worker                   }
2687*4bdc9457SAndroid Build Coastguard Worker                 }
2688*4bdc9457SAndroid Build Coastguard Worker               }
2689*4bdc9457SAndroid Build Coastguard Worker             }
2690*4bdc9457SAndroid Build Coastguard Worker           }
2691*4bdc9457SAndroid Build Coastguard Worker         }
2692*4bdc9457SAndroid Build Coastguard Worker       }
2693*4bdc9457SAndroid Build Coastguard Worker       std::transform(next_accumulators.cbegin(), next_accumulators.cend(), next_output_ref.begin(),
2694*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
2695*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
2696*4bdc9457SAndroid Build Coastguard Worker         });
2697*4bdc9457SAndroid Build Coastguard Worker 
2698*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Convolution operator the second time, and destroy the operator.
2699*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2700*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_qu8(
2701*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2702*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
2703*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2704*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2705*4bdc9457SAndroid Build Coastguard Worker 
2706*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2707*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2708*4bdc9457SAndroid Build Coastguard Worker 
2709*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
2710*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2711*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
2712*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
2713*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2714*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2715*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmax()))
2716*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2717*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), int32_t(qmin()))
2718*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2719*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
2720*4bdc9457SAndroid Build Coastguard Worker                     next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
2721*4bdc9457SAndroid Build Coastguard Worker                     double(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
2722*4bdc9457SAndroid Build Coastguard Worker                     0.9)
2723*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2724*4bdc9457SAndroid Build Coastguard Worker               }
2725*4bdc9457SAndroid Build Coastguard Worker             }
2726*4bdc9457SAndroid Build Coastguard Worker           }
2727*4bdc9457SAndroid Build Coastguard Worker         }
2728*4bdc9457SAndroid Build Coastguard Worker       }
2729*4bdc9457SAndroid Build Coastguard Worker     }
2730*4bdc9457SAndroid Build Coastguard Worker   }
2731*4bdc9457SAndroid Build Coastguard Worker 
TestSetupNHWCxF16()2732*4bdc9457SAndroid Build Coastguard Worker   void TestSetupNHWCxF16() const {
2733*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
2734*4bdc9457SAndroid Build Coastguard Worker 
2735*4bdc9457SAndroid Build Coastguard Worker     ASSERT_FALSE(depthwise_layout());
2736*4bdc9457SAndroid Build Coastguard Worker 
2737*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
2738*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
2739*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
2740*4bdc9457SAndroid Build Coastguard Worker 
2741*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max(
2742*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()),
2743*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels())));
2744*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
2745*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> bias(groups() * group_output_channels());
2746*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(std::max(
2747*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()),
2748*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels())));
2749*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2750*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2751*4bdc9457SAndroid Build Coastguard Worker 
2752*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
2753*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
2754*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
2755*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
2756*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
2757*4bdc9457SAndroid Build Coastguard Worker 
2758*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
2759*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2760*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
2761*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
2762*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
2763*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2764*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2765*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2766*4bdc9457SAndroid Build Coastguard Worker                     fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]);
2767*4bdc9457SAndroid Build Coastguard Worker                 }
2768*4bdc9457SAndroid Build Coastguard Worker               }
2769*4bdc9457SAndroid Build Coastguard Worker             }
2770*4bdc9457SAndroid Build Coastguard Worker           }
2771*4bdc9457SAndroid Build Coastguard Worker         }
2772*4bdc9457SAndroid Build Coastguard Worker       } else {
2773*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
2774*4bdc9457SAndroid Build Coastguard Worker       }
2775*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2776*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
2777*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
2778*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2779*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2780*4bdc9457SAndroid Build Coastguard Worker               if (iy < input_height()) {
2781*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2782*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2783*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width()) {
2784*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2785*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2786*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2787*4bdc9457SAndroid Build Coastguard Worker                           output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2788*4bdc9457SAndroid Build Coastguard Worker                             fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) *
2789*4bdc9457SAndroid Build Coastguard Worker                             fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2790*4bdc9457SAndroid Build Coastguard Worker                         }
2791*4bdc9457SAndroid Build Coastguard Worker                       }
2792*4bdc9457SAndroid Build Coastguard Worker                     }
2793*4bdc9457SAndroid Build Coastguard Worker                   }
2794*4bdc9457SAndroid Build Coastguard Worker                 }
2795*4bdc9457SAndroid Build Coastguard Worker               }
2796*4bdc9457SAndroid Build Coastguard Worker             }
2797*4bdc9457SAndroid Build Coastguard Worker           }
2798*4bdc9457SAndroid Build Coastguard Worker         }
2799*4bdc9457SAndroid Build Coastguard Worker       }
2800*4bdc9457SAndroid Build Coastguard Worker 
2801*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
2802*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
2803*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
2804*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
2805*4bdc9457SAndroid Build Coastguard Worker       const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
2806*4bdc9457SAndroid Build Coastguard Worker       const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
2807*4bdc9457SAndroid Build Coastguard Worker       const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
2808*4bdc9457SAndroid Build Coastguard Worker       const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;
2809*4bdc9457SAndroid Build Coastguard Worker 
2810*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
2811*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min(std::max(output_value, output_min), output_max);
2812*4bdc9457SAndroid Build Coastguard Worker       }
2813*4bdc9457SAndroid Build Coastguard Worker 
2814*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Convolution operator once.
2815*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
2816*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
2817*4bdc9457SAndroid Build Coastguard Worker 
2818*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_f16(
2819*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
2820*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
2821*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
2822*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
2823*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
2824*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
2825*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
2826*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
2827*4bdc9457SAndroid Build Coastguard Worker           0, NULL, &convolution_op);
2828*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
2829*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
2830*4bdc9457SAndroid Build Coastguard Worker       }
2831*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
2832*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
2833*4bdc9457SAndroid Build Coastguard Worker 
2834*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
2835*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
2836*4bdc9457SAndroid Build Coastguard Worker 
2837*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2838*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_f16(
2839*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2840*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
2841*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2842*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2843*4bdc9457SAndroid Build Coastguard Worker 
2844*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2845*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2846*4bdc9457SAndroid Build Coastguard Worker 
2847*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
2848*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2849*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
2850*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
2851*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2852*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2853*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_min)
2854*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2855*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_max)
2856*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2857*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c], fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), std::max(1.0e-4f, std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]) * 1.0e-2f))
2858*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2859*4bdc9457SAndroid Build Coastguard Worker               }
2860*4bdc9457SAndroid Build Coastguard Worker             }
2861*4bdc9457SAndroid Build Coastguard Worker           }
2862*4bdc9457SAndroid Build Coastguard Worker         }
2863*4bdc9457SAndroid Build Coastguard Worker       }
2864*4bdc9457SAndroid Build Coastguard Worker 
2865*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
2866*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
2867*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
2868*4bdc9457SAndroid Build Coastguard Worker 
2869*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run, including clamping.
2870*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2871*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < next_batch_size(); i++) {
2872*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < next_output_height(); oy++) {
2873*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < next_output_width(); ox++) {
2874*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2875*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2876*4bdc9457SAndroid Build Coastguard Worker                   next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2877*4bdc9457SAndroid Build Coastguard Worker                     fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]);
2878*4bdc9457SAndroid Build Coastguard Worker                 }
2879*4bdc9457SAndroid Build Coastguard Worker               }
2880*4bdc9457SAndroid Build Coastguard Worker             }
2881*4bdc9457SAndroid Build Coastguard Worker           }
2882*4bdc9457SAndroid Build Coastguard Worker         }
2883*4bdc9457SAndroid Build Coastguard Worker       } else {
2884*4bdc9457SAndroid Build Coastguard Worker         std::fill(next_output_ref.begin(), next_output_ref.end(), 0.0f);
2885*4bdc9457SAndroid Build Coastguard Worker       }
2886*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2887*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
2888*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
2889*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2890*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2891*4bdc9457SAndroid Build Coastguard Worker               if (iy < next_input_height()) {
2892*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2893*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2894*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width()) {
2895*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2896*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
2897*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
2898*4bdc9457SAndroid Build Coastguard Worker                           next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2899*4bdc9457SAndroid Build Coastguard Worker                             fp16_ieee_to_fp32_value(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic]) *
2900*4bdc9457SAndroid Build Coastguard Worker                             fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2901*4bdc9457SAndroid Build Coastguard Worker                         }
2902*4bdc9457SAndroid Build Coastguard Worker                       }
2903*4bdc9457SAndroid Build Coastguard Worker                     }
2904*4bdc9457SAndroid Build Coastguard Worker                   }
2905*4bdc9457SAndroid Build Coastguard Worker                 }
2906*4bdc9457SAndroid Build Coastguard Worker               }
2907*4bdc9457SAndroid Build Coastguard Worker             }
2908*4bdc9457SAndroid Build Coastguard Worker           }
2909*4bdc9457SAndroid Build Coastguard Worker         }
2910*4bdc9457SAndroid Build Coastguard Worker       }
2911*4bdc9457SAndroid Build Coastguard Worker       for (float& value : next_output_ref) {
2912*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
2913*4bdc9457SAndroid Build Coastguard Worker       }
2914*4bdc9457SAndroid Build Coastguard Worker 
2915*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Convolution operator the second time, and destroy the operator.
2916*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2917*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_f16(
2918*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
2919*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
2920*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
2921*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
2922*4bdc9457SAndroid Build Coastguard Worker 
2923*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
2924*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
2925*4bdc9457SAndroid Build Coastguard Worker 
2926*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
2927*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
2928*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
2929*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
2930*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
2931*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
2932*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_min)
2933*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2934*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), output_max)
2935*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2936*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c], fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c]), std::max(1.0e-4f, std::abs(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c]) * 1.0e-2f))
2937*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2938*4bdc9457SAndroid Build Coastguard Worker               }
2939*4bdc9457SAndroid Build Coastguard Worker             }
2940*4bdc9457SAndroid Build Coastguard Worker           }
2941*4bdc9457SAndroid Build Coastguard Worker         }
2942*4bdc9457SAndroid Build Coastguard Worker       }
2943*4bdc9457SAndroid Build Coastguard Worker     }
2944*4bdc9457SAndroid Build Coastguard Worker   }
2945*4bdc9457SAndroid Build Coastguard Worker 
TestSetupNHWCxF32()2946*4bdc9457SAndroid Build Coastguard Worker   void TestSetupNHWCxF32() const {
2947*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
2948*4bdc9457SAndroid Build Coastguard Worker 
2949*4bdc9457SAndroid Build Coastguard Worker     ASSERT_FALSE(depthwise_layout());
2950*4bdc9457SAndroid Build Coastguard Worker 
2951*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
2952*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
2953*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
2954*4bdc9457SAndroid Build Coastguard Worker 
2955*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max(
2956*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((input_height() * input_width() - 1) * input_channel_stride() + groups() * group_input_channels()),
2957*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_input_height() * next_input_width() - 1) * input_channel_stride() + groups() * group_input_channels())));
2958*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
2959*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias(groups() * group_output_channels());
2960*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(std::max(
2961*4bdc9457SAndroid Build Coastguard Worker       batch_size() * ((output_height() * output_width() - 1) * output_channel_stride() + groups() * group_output_channels()),
2962*4bdc9457SAndroid Build Coastguard Worker       next_batch_size() * ((next_output_height() * next_output_width() - 1) * output_channel_stride() + groups() * group_output_channels())));
2963*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2964*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2965*4bdc9457SAndroid Build Coastguard Worker 
2966*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
2967*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
2968*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
2969*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
2970*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
2971*4bdc9457SAndroid Build Coastguard Worker 
2972*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without clamping.
2973*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
2974*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
2975*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < output_height(); oy++) {
2976*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < output_width(); ox++) {
2977*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
2978*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
2979*4bdc9457SAndroid Build Coastguard Worker                   output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2980*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
2981*4bdc9457SAndroid Build Coastguard Worker                 }
2982*4bdc9457SAndroid Build Coastguard Worker               }
2983*4bdc9457SAndroid Build Coastguard Worker             }
2984*4bdc9457SAndroid Build Coastguard Worker           }
2985*4bdc9457SAndroid Build Coastguard Worker         }
2986*4bdc9457SAndroid Build Coastguard Worker       } else {
2987*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
2988*4bdc9457SAndroid Build Coastguard Worker       }
2989*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
2990*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < output_height(); oy++) {
2991*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < output_width(); ox++) {
2992*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
2993*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
2994*4bdc9457SAndroid Build Coastguard Worker               if (iy < input_height()) {
2995*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
2996*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
2997*4bdc9457SAndroid Build Coastguard Worker                   if (ix < input_width()) {
2998*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
2999*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
3000*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
3001*4bdc9457SAndroid Build Coastguard Worker                           output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
3002*4bdc9457SAndroid Build Coastguard Worker                             input[((i * input_height() + iy) * input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic] *
3003*4bdc9457SAndroid Build Coastguard Worker                             kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
3004*4bdc9457SAndroid Build Coastguard Worker                         }
3005*4bdc9457SAndroid Build Coastguard Worker                       }
3006*4bdc9457SAndroid Build Coastguard Worker                     }
3007*4bdc9457SAndroid Build Coastguard Worker                   }
3008*4bdc9457SAndroid Build Coastguard Worker                 }
3009*4bdc9457SAndroid Build Coastguard Worker               }
3010*4bdc9457SAndroid Build Coastguard Worker             }
3011*4bdc9457SAndroid Build Coastguard Worker           }
3012*4bdc9457SAndroid Build Coastguard Worker         }
3013*4bdc9457SAndroid Build Coastguard Worker       }
3014*4bdc9457SAndroid Build Coastguard Worker 
3015*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
3016*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
3017*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
3018*4bdc9457SAndroid Build Coastguard Worker 
3019*4bdc9457SAndroid Build Coastguard Worker       const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
3020*4bdc9457SAndroid Build Coastguard Worker       const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
3021*4bdc9457SAndroid Build Coastguard Worker 
3022*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
3023*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
3024*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
3025*4bdc9457SAndroid Build Coastguard Worker       }
3026*4bdc9457SAndroid Build Coastguard Worker 
3027*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, and run Convolution operator once.
3028*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
3029*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t convolution_op = nullptr;
3030*4bdc9457SAndroid Build Coastguard Worker 
3031*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_create_convolution2d_nhwc_f32(
3032*4bdc9457SAndroid Build Coastguard Worker           padding_top(), padding_right(), padding_bottom(), padding_left(),
3033*4bdc9457SAndroid Build Coastguard Worker           kernel_height(), kernel_width(),
3034*4bdc9457SAndroid Build Coastguard Worker           subsampling_height(), subsampling_width(),
3035*4bdc9457SAndroid Build Coastguard Worker           dilation_height(), dilation_width(),
3036*4bdc9457SAndroid Build Coastguard Worker           groups(), group_input_channels(), group_output_channels(),
3037*4bdc9457SAndroid Build Coastguard Worker           input_channel_stride(), output_channel_stride(),
3038*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
3039*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
3040*4bdc9457SAndroid Build Coastguard Worker           0, NULL, &convolution_op);
3041*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
3042*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
3043*4bdc9457SAndroid Build Coastguard Worker       }
3044*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
3045*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, convolution_op);
3046*4bdc9457SAndroid Build Coastguard Worker 
3047*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete convolution_op.
3048*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_convolution_op(convolution_op, xnn_delete_operator);
3049*4bdc9457SAndroid Build Coastguard Worker 
3050*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
3051*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_f32(
3052*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
3053*4bdc9457SAndroid Build Coastguard Worker           batch_size(), input_height(), input_width(),
3054*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
3055*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
3056*4bdc9457SAndroid Build Coastguard Worker 
3057*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
3058*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
3059*4bdc9457SAndroid Build Coastguard Worker 
3060*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the first run.
3061*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < batch_size(); i++) {
3062*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < output_height(); y++) {
3063*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < output_width(); x++) {
3064*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
3065*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
3066*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_min)
3067*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
3068*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_max)
3069*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
3070*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
3071*4bdc9457SAndroid Build Coastguard Worker                     output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
3072*4bdc9457SAndroid Build Coastguard Worker                     output[((i * output_height() + y) * output_width() + x) * output_channel_stride() + g * group_output_channels() + c],
3073*4bdc9457SAndroid Build Coastguard Worker                     1.0e-4 * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]))
3074*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
3075*4bdc9457SAndroid Build Coastguard Worker               }
3076*4bdc9457SAndroid Build Coastguard Worker             }
3077*4bdc9457SAndroid Build Coastguard Worker           }
3078*4bdc9457SAndroid Build Coastguard Worker         }
3079*4bdc9457SAndroid Build Coastguard Worker       }
3080*4bdc9457SAndroid Build Coastguard Worker 
3081*4bdc9457SAndroid Build Coastguard Worker       // Re-generate data for the second run.
3082*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
3083*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
3084*4bdc9457SAndroid Build Coastguard Worker 
3085*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results for the second run, including clamping.
3086*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
3087*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < next_batch_size(); i++) {
3088*4bdc9457SAndroid Build Coastguard Worker           for (size_t oy = 0; oy < next_output_height(); oy++) {
3089*4bdc9457SAndroid Build Coastguard Worker             for (size_t ox = 0; ox < next_output_width(); ox++) {
3090*4bdc9457SAndroid Build Coastguard Worker               for (size_t g = 0; g < groups(); g++) {
3091*4bdc9457SAndroid Build Coastguard Worker                 for (size_t oc = 0; oc < group_output_channels(); oc++) {
3092*4bdc9457SAndroid Build Coastguard Worker                   next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
3093*4bdc9457SAndroid Build Coastguard Worker                     bias[g * group_output_channels() + oc];
3094*4bdc9457SAndroid Build Coastguard Worker                 }
3095*4bdc9457SAndroid Build Coastguard Worker               }
3096*4bdc9457SAndroid Build Coastguard Worker             }
3097*4bdc9457SAndroid Build Coastguard Worker           }
3098*4bdc9457SAndroid Build Coastguard Worker         }
3099*4bdc9457SAndroid Build Coastguard Worker       } else {
3100*4bdc9457SAndroid Build Coastguard Worker         std::fill(next_output_ref.begin(), next_output_ref.end(), 0.0f);
3101*4bdc9457SAndroid Build Coastguard Worker       }
3102*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
3103*4bdc9457SAndroid Build Coastguard Worker         for (size_t oy = 0; oy < next_output_height(); oy++) {
3104*4bdc9457SAndroid Build Coastguard Worker           for (size_t ox = 0; ox < next_output_width(); ox++) {
3105*4bdc9457SAndroid Build Coastguard Worker             for (size_t ky = 0; ky < kernel_height(); ky++) {
3106*4bdc9457SAndroid Build Coastguard Worker               const size_t iy = oy * subsampling_height() + ky * dilation_height() - padding_top();
3107*4bdc9457SAndroid Build Coastguard Worker               if (iy < next_input_height()) {
3108*4bdc9457SAndroid Build Coastguard Worker                 for (size_t kx = 0; kx < kernel_width(); kx++) {
3109*4bdc9457SAndroid Build Coastguard Worker                   const size_t ix = ox * subsampling_width() + kx * dilation_width() - padding_left();
3110*4bdc9457SAndroid Build Coastguard Worker                   if (ix < next_input_width()) {
3111*4bdc9457SAndroid Build Coastguard Worker                     for (size_t g = 0; g < groups(); g++) {
3112*4bdc9457SAndroid Build Coastguard Worker                       for (size_t oc = 0; oc < group_output_channels(); oc++) {
3113*4bdc9457SAndroid Build Coastguard Worker                         for (size_t ic = 0; ic < group_input_channels(); ic++) {
3114*4bdc9457SAndroid Build Coastguard Worker                           next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
3115*4bdc9457SAndroid Build Coastguard Worker                             input[((i * next_input_height() + iy) * next_input_width() + ix) * input_channel_stride() + g * group_input_channels() + ic] *
3116*4bdc9457SAndroid Build Coastguard Worker                             kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
3117*4bdc9457SAndroid Build Coastguard Worker                         }
3118*4bdc9457SAndroid Build Coastguard Worker                       }
3119*4bdc9457SAndroid Build Coastguard Worker                     }
3120*4bdc9457SAndroid Build Coastguard Worker                   }
3121*4bdc9457SAndroid Build Coastguard Worker                 }
3122*4bdc9457SAndroid Build Coastguard Worker               }
3123*4bdc9457SAndroid Build Coastguard Worker             }
3124*4bdc9457SAndroid Build Coastguard Worker           }
3125*4bdc9457SAndroid Build Coastguard Worker         }
3126*4bdc9457SAndroid Build Coastguard Worker       }
3127*4bdc9457SAndroid Build Coastguard Worker       for (float& value : next_output_ref) {
3128*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
3129*4bdc9457SAndroid Build Coastguard Worker       }
3130*4bdc9457SAndroid Build Coastguard Worker 
3131*4bdc9457SAndroid Build Coastguard Worker       // Setup and run Convolution operator the second time, and destroy the operator.
3132*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
3133*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_convolution2d_nhwc_f32(
3134*4bdc9457SAndroid Build Coastguard Worker           convolution_op,
3135*4bdc9457SAndroid Build Coastguard Worker           next_batch_size(), next_input_height(), next_input_width(),
3136*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
3137*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
3138*4bdc9457SAndroid Build Coastguard Worker 
3139*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
3140*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(convolution_op, nullptr /* thread pool */));
3141*4bdc9457SAndroid Build Coastguard Worker 
3142*4bdc9457SAndroid Build Coastguard Worker       // Verify results of the second run.
3143*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < next_batch_size(); i++) {
3144*4bdc9457SAndroid Build Coastguard Worker         for (size_t y = 0; y < next_output_height(); y++) {
3145*4bdc9457SAndroid Build Coastguard Worker           for (size_t x = 0; x < next_output_width(); x++) {
3146*4bdc9457SAndroid Build Coastguard Worker             for (size_t g = 0; g < groups(); g++) {
3147*4bdc9457SAndroid Build Coastguard Worker               for (size_t c = 0; c < group_output_channels(); c++) {
3148*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_GE(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_min)
3149*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
3150*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_LE(output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c], output_max)
3151*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
3152*4bdc9457SAndroid Build Coastguard Worker                 ASSERT_NEAR(
3153*4bdc9457SAndroid Build Coastguard Worker                     next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
3154*4bdc9457SAndroid Build Coastguard Worker                     output[((i * next_output_height() + y) * next_output_width() + x) * output_channel_stride() + g * group_output_channels() + c],
3155*4bdc9457SAndroid Build Coastguard Worker                     1.0e-4 * std::abs(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c]))
3156*4bdc9457SAndroid Build Coastguard Worker                   << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
3157*4bdc9457SAndroid Build Coastguard Worker               }
3158*4bdc9457SAndroid Build Coastguard Worker             }
3159*4bdc9457SAndroid Build Coastguard Worker           }
3160*4bdc9457SAndroid Build Coastguard Worker         }
3161*4bdc9457SAndroid Build Coastguard Worker       }
3162*4bdc9457SAndroid Build Coastguard Worker     }
3163*4bdc9457SAndroid Build Coastguard Worker   }
3164*4bdc9457SAndroid Build Coastguard Worker 
VerifyWeightsCache(const xnn_weights_cache & weights_cache,size_t old_size)3165*4bdc9457SAndroid Build Coastguard Worker   void VerifyWeightsCache(const xnn_weights_cache &weights_cache, size_t old_size) const {
3166*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_cache.cache.hits, 1);
3167*4bdc9457SAndroid Build Coastguard Worker     // Ensure that we did not write more weights to the cache because it was a
3168*4bdc9457SAndroid Build Coastguard Worker     // cache hit.
3169*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(old_size, weights_cache.cache.weights.size);
3170*4bdc9457SAndroid Build Coastguard Worker   };
3171*4bdc9457SAndroid Build Coastguard Worker 
VerifyWeightsCacheUnused(const xnn_weights_cache & weights_cache)3172*4bdc9457SAndroid Build Coastguard Worker   void VerifyWeightsCacheUnused(const xnn_weights_cache &weights_cache) const {
3173*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_cache.cache.hits, 0);
3174*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(0, weights_cache.cache.weights.size);
3175*4bdc9457SAndroid Build Coastguard Worker   }
3176*4bdc9457SAndroid Build Coastguard Worker 
IsSpmm()3177*4bdc9457SAndroid Build Coastguard Worker   bool IsSpmm() const {
3178*4bdc9457SAndroid Build Coastguard Worker     const bool is_1x1 = kernel_width() == 1 && kernel_height() == 1 &&
3179*4bdc9457SAndroid Build Coastguard Worker         subsampling_height() == 1 && subsampling_width() == 1;
3180*4bdc9457SAndroid Build Coastguard Worker     const bool any_padding = (padding_left() | padding_top() | padding_right() | padding_bottom()) != 0;
3181*4bdc9457SAndroid Build Coastguard Worker     return is_1x1 && !any_padding && !force_nhwc_input() && groups() == 1;
3182*4bdc9457SAndroid Build Coastguard Worker   }
3183*4bdc9457SAndroid Build Coastguard Worker 
3184*4bdc9457SAndroid Build Coastguard Worker  private:
3185*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_top_{0};
3186*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_right_{0};
3187*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_bottom_{0};
3188*4bdc9457SAndroid Build Coastguard Worker   uint32_t padding_left_{0};
3189*4bdc9457SAndroid Build Coastguard Worker   bool padding_tf_same_{false};
3190*4bdc9457SAndroid Build Coastguard Worker   size_t input_height_{1};
3191*4bdc9457SAndroid Build Coastguard Worker   size_t input_width_{1};
3192*4bdc9457SAndroid Build Coastguard Worker   uint32_t groups_{1};
3193*4bdc9457SAndroid Build Coastguard Worker   size_t group_input_channels_{1};
3194*4bdc9457SAndroid Build Coastguard Worker   size_t input_channel_stride_{0};
3195*4bdc9457SAndroid Build Coastguard Worker   size_t group_output_channels_{1};
3196*4bdc9457SAndroid Build Coastguard Worker   size_t output_channel_stride_{0};
3197*4bdc9457SAndroid Build Coastguard Worker   size_t batch_size_{1};
3198*4bdc9457SAndroid Build Coastguard Worker   uint32_t kernel_height_{1};
3199*4bdc9457SAndroid Build Coastguard Worker   uint32_t kernel_width_{1};
3200*4bdc9457SAndroid Build Coastguard Worker   uint32_t dilation_height_{1};
3201*4bdc9457SAndroid Build Coastguard Worker   uint32_t dilation_width_{1};
3202*4bdc9457SAndroid Build Coastguard Worker   uint32_t subsampling_height_{1};
3203*4bdc9457SAndroid Build Coastguard Worker   uint32_t subsampling_width_{1};
3204*4bdc9457SAndroid Build Coastguard Worker   size_t next_input_height_{0};
3205*4bdc9457SAndroid Build Coastguard Worker   size_t next_input_width_{0};
3206*4bdc9457SAndroid Build Coastguard Worker   size_t next_batch_size_{0};
3207*4bdc9457SAndroid Build Coastguard Worker   float sparsity_{0.0f};
3208*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};
3209*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};
3210*4bdc9457SAndroid Build Coastguard Worker   bool depthwise_layout_{false};
3211*4bdc9457SAndroid Build Coastguard Worker   bool force_nhwc_input_{false};
3212*4bdc9457SAndroid Build Coastguard Worker   bool has_bias_{true};
3213*4bdc9457SAndroid Build Coastguard Worker   WeightsType weights_type_{WeightsType::Default};
3214*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{1};
3215*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
3216*4bdc9457SAndroid Build Coastguard Worker   bool use_jit_{false};
3217*4bdc9457SAndroid Build Coastguard Worker #endif
3218*4bdc9457SAndroid Build Coastguard Worker   bool use_weights_cache_{false};
3219*4bdc9457SAndroid Build Coastguard Worker };
3220