// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker
13*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
14*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
15*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
16*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
17*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
18*4bdc9457SAndroid Build Coastguard Worker #include <limits>
19*4bdc9457SAndroid Build Coastguard Worker #include <random>
20*4bdc9457SAndroid Build Coastguard Worker #include <vector>
21*4bdc9457SAndroid Build Coastguard Worker
22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
23*4bdc9457SAndroid Build Coastguard Worker
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h>
26*4bdc9457SAndroid Build Coastguard Worker
27*4bdc9457SAndroid Build Coastguard Worker namespace {
28*4bdc9457SAndroid Build Coastguard Worker
// "Difference or zero": returns a - b when a is strictly greater than b,
// otherwise a zero of type T. For unsigned T this is a subtraction that
// never wraps around.
template<class T>
inline T doz(T a, T b) {
  if (a > b) {
    return a - b;
  }
  return T(0);
}
33*4bdc9457SAndroid Build Coastguard Worker
34*4bdc9457SAndroid Build Coastguard Worker } // namespace
35*4bdc9457SAndroid Build Coastguard Worker
36*4bdc9457SAndroid Build Coastguard Worker class DeconvolutionOperatorTester {
37*4bdc9457SAndroid Build Coastguard Worker public:
// Selects the weights representation a test should use.
enum class WeightsType {
  // Default choice; TestQS8() requires this value.
  Default,
  // NOTE(review): presumably requests 32-bit float weights - confirm against
  // the test methods that read weights_type().
  FP32,
};
42*4bdc9457SAndroid Build Coastguard Worker
padding(uint32_t padding)43*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding(uint32_t padding) {
44*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding;
45*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding;
46*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding;
47*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding;
48*4bdc9457SAndroid Build Coastguard Worker return *this;
49*4bdc9457SAndroid Build Coastguard Worker }
50*4bdc9457SAndroid Build Coastguard Worker
padding_height(uint32_t padding_height)51*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding_height(uint32_t padding_height) {
52*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_height;
53*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_height;
54*4bdc9457SAndroid Build Coastguard Worker return *this;
55*4bdc9457SAndroid Build Coastguard Worker }
56*4bdc9457SAndroid Build Coastguard Worker
padding_height()57*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_height() const {
58*4bdc9457SAndroid Build Coastguard Worker return this->padding_top_ + this->padding_bottom_;
59*4bdc9457SAndroid Build Coastguard Worker }
60*4bdc9457SAndroid Build Coastguard Worker
padding_width(uint32_t padding_width)61*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding_width(uint32_t padding_width) {
62*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_width;
63*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_width;
64*4bdc9457SAndroid Build Coastguard Worker return *this;
65*4bdc9457SAndroid Build Coastguard Worker }
66*4bdc9457SAndroid Build Coastguard Worker
padding_width()67*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_width() const {
68*4bdc9457SAndroid Build Coastguard Worker return this->padding_left_ + this->padding_right_;
69*4bdc9457SAndroid Build Coastguard Worker }
70*4bdc9457SAndroid Build Coastguard Worker
padding_top(uint32_t padding_top)71*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding_top(uint32_t padding_top) {
72*4bdc9457SAndroid Build Coastguard Worker this->padding_top_ = padding_top;
73*4bdc9457SAndroid Build Coastguard Worker return *this;
74*4bdc9457SAndroid Build Coastguard Worker }
75*4bdc9457SAndroid Build Coastguard Worker
padding_top()76*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_top() const { return this->padding_top_; }
77*4bdc9457SAndroid Build Coastguard Worker
padding_right(uint32_t padding_right)78*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding_right(uint32_t padding_right) {
79*4bdc9457SAndroid Build Coastguard Worker this->padding_right_ = padding_right;
80*4bdc9457SAndroid Build Coastguard Worker return *this;
81*4bdc9457SAndroid Build Coastguard Worker }
82*4bdc9457SAndroid Build Coastguard Worker
padding_right()83*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_right() const { return this->padding_right_; }
84*4bdc9457SAndroid Build Coastguard Worker
padding_bottom(uint32_t padding_bottom)85*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding_bottom(uint32_t padding_bottom) {
86*4bdc9457SAndroid Build Coastguard Worker this->padding_bottom_ = padding_bottom;
87*4bdc9457SAndroid Build Coastguard Worker return *this;
88*4bdc9457SAndroid Build Coastguard Worker }
89*4bdc9457SAndroid Build Coastguard Worker
padding_bottom()90*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_bottom() const { return this->padding_bottom_; }
91*4bdc9457SAndroid Build Coastguard Worker
padding_left(uint32_t padding_left)92*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& padding_left(uint32_t padding_left) {
93*4bdc9457SAndroid Build Coastguard Worker this->padding_left_ = padding_left;
94*4bdc9457SAndroid Build Coastguard Worker return *this;
95*4bdc9457SAndroid Build Coastguard Worker }
96*4bdc9457SAndroid Build Coastguard Worker
padding_left()97*4bdc9457SAndroid Build Coastguard Worker inline uint32_t padding_left() const { return this->padding_left_; }
98*4bdc9457SAndroid Build Coastguard Worker
adjustment_height(uint32_t adjustment_height)99*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& adjustment_height(uint32_t adjustment_height) {
100*4bdc9457SAndroid Build Coastguard Worker this->adjustment_height_ = adjustment_height;
101*4bdc9457SAndroid Build Coastguard Worker return *this;
102*4bdc9457SAndroid Build Coastguard Worker }
103*4bdc9457SAndroid Build Coastguard Worker
adjustment_height()104*4bdc9457SAndroid Build Coastguard Worker inline uint32_t adjustment_height() const {
105*4bdc9457SAndroid Build Coastguard Worker return this->adjustment_height_;
106*4bdc9457SAndroid Build Coastguard Worker }
107*4bdc9457SAndroid Build Coastguard Worker
adjustment_width(uint32_t adjustment_width)108*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& adjustment_width(uint32_t adjustment_width) {
109*4bdc9457SAndroid Build Coastguard Worker this->adjustment_width_ = adjustment_width;
110*4bdc9457SAndroid Build Coastguard Worker return *this;
111*4bdc9457SAndroid Build Coastguard Worker }
112*4bdc9457SAndroid Build Coastguard Worker
adjustment_width()113*4bdc9457SAndroid Build Coastguard Worker inline uint32_t adjustment_width() const {
114*4bdc9457SAndroid Build Coastguard Worker return this->adjustment_width_;
115*4bdc9457SAndroid Build Coastguard Worker }
116*4bdc9457SAndroid Build Coastguard Worker
input_size(uint32_t input_height,uint32_t input_width)117*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& input_size(uint32_t input_height, uint32_t input_width) {
118*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1);
119*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1);
120*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height;
121*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width;
122*4bdc9457SAndroid Build Coastguard Worker return *this;
123*4bdc9457SAndroid Build Coastguard Worker }
124*4bdc9457SAndroid Build Coastguard Worker
input_height(uint32_t input_height)125*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& input_height(uint32_t input_height) {
126*4bdc9457SAndroid Build Coastguard Worker assert(input_height >= 1);
127*4bdc9457SAndroid Build Coastguard Worker this->input_height_ = input_height;
128*4bdc9457SAndroid Build Coastguard Worker return *this;
129*4bdc9457SAndroid Build Coastguard Worker }
130*4bdc9457SAndroid Build Coastguard Worker
input_height()131*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_height() const {
132*4bdc9457SAndroid Build Coastguard Worker return this->input_height_;
133*4bdc9457SAndroid Build Coastguard Worker }
134*4bdc9457SAndroid Build Coastguard Worker
input_width(uint32_t input_width)135*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& input_width(uint32_t input_width) {
136*4bdc9457SAndroid Build Coastguard Worker assert(input_width >= 1);
137*4bdc9457SAndroid Build Coastguard Worker this->input_width_ = input_width;
138*4bdc9457SAndroid Build Coastguard Worker return *this;
139*4bdc9457SAndroid Build Coastguard Worker }
140*4bdc9457SAndroid Build Coastguard Worker
input_width()141*4bdc9457SAndroid Build Coastguard Worker inline uint32_t input_width() const {
142*4bdc9457SAndroid Build Coastguard Worker return this->input_width_;
143*4bdc9457SAndroid Build Coastguard Worker }
144*4bdc9457SAndroid Build Coastguard Worker
groups(uint32_t groups)145*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& groups(uint32_t groups) {
146*4bdc9457SAndroid Build Coastguard Worker assert(groups >= 1);
147*4bdc9457SAndroid Build Coastguard Worker this->groups_ = groups;
148*4bdc9457SAndroid Build Coastguard Worker return *this;
149*4bdc9457SAndroid Build Coastguard Worker }
150*4bdc9457SAndroid Build Coastguard Worker
groups()151*4bdc9457SAndroid Build Coastguard Worker inline uint32_t groups() const {
152*4bdc9457SAndroid Build Coastguard Worker return this->groups_;
153*4bdc9457SAndroid Build Coastguard Worker }
154*4bdc9457SAndroid Build Coastguard Worker
group_input_channels(size_t group_input_channels)155*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& group_input_channels(size_t group_input_channels) {
156*4bdc9457SAndroid Build Coastguard Worker assert(group_input_channels >= 1);
157*4bdc9457SAndroid Build Coastguard Worker this->group_input_channels_ = group_input_channels;
158*4bdc9457SAndroid Build Coastguard Worker return *this;
159*4bdc9457SAndroid Build Coastguard Worker }
160*4bdc9457SAndroid Build Coastguard Worker
group_input_channels()161*4bdc9457SAndroid Build Coastguard Worker inline size_t group_input_channels() const {
162*4bdc9457SAndroid Build Coastguard Worker return this->group_input_channels_;
163*4bdc9457SAndroid Build Coastguard Worker }
164*4bdc9457SAndroid Build Coastguard Worker
group_output_channels(size_t group_output_channels)165*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& group_output_channels(size_t group_output_channels) {
166*4bdc9457SAndroid Build Coastguard Worker assert(group_output_channels >= 1);
167*4bdc9457SAndroid Build Coastguard Worker this->group_output_channels_ = group_output_channels;
168*4bdc9457SAndroid Build Coastguard Worker return *this;
169*4bdc9457SAndroid Build Coastguard Worker }
170*4bdc9457SAndroid Build Coastguard Worker
group_output_channels()171*4bdc9457SAndroid Build Coastguard Worker inline size_t group_output_channels() const {
172*4bdc9457SAndroid Build Coastguard Worker return this->group_output_channels_;
173*4bdc9457SAndroid Build Coastguard Worker }
174*4bdc9457SAndroid Build Coastguard Worker
batch_size(size_t batch_size)175*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& batch_size(size_t batch_size) {
176*4bdc9457SAndroid Build Coastguard Worker assert(batch_size >= 1);
177*4bdc9457SAndroid Build Coastguard Worker this->batch_size_ = batch_size;
178*4bdc9457SAndroid Build Coastguard Worker return *this;
179*4bdc9457SAndroid Build Coastguard Worker }
180*4bdc9457SAndroid Build Coastguard Worker
batch_size()181*4bdc9457SAndroid Build Coastguard Worker inline size_t batch_size() const {
182*4bdc9457SAndroid Build Coastguard Worker return this->batch_size_;
183*4bdc9457SAndroid Build Coastguard Worker }
184*4bdc9457SAndroid Build Coastguard Worker
kernel_size(uint32_t kernel_size)185*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& kernel_size(uint32_t kernel_size) {
186*4bdc9457SAndroid Build Coastguard Worker assert(kernel_size >= 1);
187*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_size;
188*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_size;
189*4bdc9457SAndroid Build Coastguard Worker return *this;
190*4bdc9457SAndroid Build Coastguard Worker }
191*4bdc9457SAndroid Build Coastguard Worker
kernel_size(uint32_t kernel_height,uint32_t kernel_width)192*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& kernel_size(uint32_t kernel_height, uint32_t kernel_width) {
193*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height >= 1);
194*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width >= 1);
195*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_height;
196*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_width;
197*4bdc9457SAndroid Build Coastguard Worker return *this;
198*4bdc9457SAndroid Build Coastguard Worker }
199*4bdc9457SAndroid Build Coastguard Worker
kernel_height(uint32_t kernel_height)200*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& kernel_height(uint32_t kernel_height) {
201*4bdc9457SAndroid Build Coastguard Worker assert(kernel_height >= 1);
202*4bdc9457SAndroid Build Coastguard Worker this->kernel_height_ = kernel_height;
203*4bdc9457SAndroid Build Coastguard Worker return *this;
204*4bdc9457SAndroid Build Coastguard Worker }
205*4bdc9457SAndroid Build Coastguard Worker
kernel_height()206*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_height() const {
207*4bdc9457SAndroid Build Coastguard Worker return this->kernel_height_;
208*4bdc9457SAndroid Build Coastguard Worker }
209*4bdc9457SAndroid Build Coastguard Worker
kernel_width(uint32_t kernel_width)210*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& kernel_width(uint32_t kernel_width) {
211*4bdc9457SAndroid Build Coastguard Worker assert(kernel_width >= 1);
212*4bdc9457SAndroid Build Coastguard Worker this->kernel_width_ = kernel_width;
213*4bdc9457SAndroid Build Coastguard Worker return *this;
214*4bdc9457SAndroid Build Coastguard Worker }
215*4bdc9457SAndroid Build Coastguard Worker
kernel_width()216*4bdc9457SAndroid Build Coastguard Worker inline uint32_t kernel_width() const {
217*4bdc9457SAndroid Build Coastguard Worker return this->kernel_width_;
218*4bdc9457SAndroid Build Coastguard Worker }
219*4bdc9457SAndroid Build Coastguard Worker
dilation(uint32_t dilation)220*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& dilation(uint32_t dilation) {
221*4bdc9457SAndroid Build Coastguard Worker assert(dilation >= 1);
222*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation;
223*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation;
224*4bdc9457SAndroid Build Coastguard Worker return *this;
225*4bdc9457SAndroid Build Coastguard Worker }
226*4bdc9457SAndroid Build Coastguard Worker
dilation(uint32_t dilation_height,uint32_t dilation_width)227*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& dilation(uint32_t dilation_height, uint32_t dilation_width) {
228*4bdc9457SAndroid Build Coastguard Worker assert(dilation_height >= 1);
229*4bdc9457SAndroid Build Coastguard Worker assert(dilation_width >= 1);
230*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation_height;
231*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation_width;
232*4bdc9457SAndroid Build Coastguard Worker return *this;
233*4bdc9457SAndroid Build Coastguard Worker }
234*4bdc9457SAndroid Build Coastguard Worker
dilation_height(uint32_t dilation_height)235*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& dilation_height(uint32_t dilation_height) {
236*4bdc9457SAndroid Build Coastguard Worker assert(dilation_height >= 1);
237*4bdc9457SAndroid Build Coastguard Worker this->dilation_height_ = dilation_height;
238*4bdc9457SAndroid Build Coastguard Worker return *this;
239*4bdc9457SAndroid Build Coastguard Worker }
240*4bdc9457SAndroid Build Coastguard Worker
dilation_height()241*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilation_height() const {
242*4bdc9457SAndroid Build Coastguard Worker return this->dilation_height_;
243*4bdc9457SAndroid Build Coastguard Worker }
244*4bdc9457SAndroid Build Coastguard Worker
dilation_width(uint32_t dilation_width)245*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& dilation_width(uint32_t dilation_width) {
246*4bdc9457SAndroid Build Coastguard Worker assert(dilation_width >= 1);
247*4bdc9457SAndroid Build Coastguard Worker this->dilation_width_ = dilation_width;
248*4bdc9457SAndroid Build Coastguard Worker return *this;
249*4bdc9457SAndroid Build Coastguard Worker }
250*4bdc9457SAndroid Build Coastguard Worker
dilation_width()251*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilation_width() const {
252*4bdc9457SAndroid Build Coastguard Worker return this->dilation_width_;
253*4bdc9457SAndroid Build Coastguard Worker }
254*4bdc9457SAndroid Build Coastguard Worker
stride(uint32_t stride)255*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& stride(uint32_t stride) {
256*4bdc9457SAndroid Build Coastguard Worker assert(stride >= 1);
257*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride;
258*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride;
259*4bdc9457SAndroid Build Coastguard Worker return *this;
260*4bdc9457SAndroid Build Coastguard Worker }
261*4bdc9457SAndroid Build Coastguard Worker
stride(uint32_t stride_height,uint32_t stride_width)262*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& stride(uint32_t stride_height, uint32_t stride_width) {
263*4bdc9457SAndroid Build Coastguard Worker assert(stride_height >= 1);
264*4bdc9457SAndroid Build Coastguard Worker assert(stride_width >= 1);
265*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride_height;
266*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride_width;
267*4bdc9457SAndroid Build Coastguard Worker return *this;
268*4bdc9457SAndroid Build Coastguard Worker }
269*4bdc9457SAndroid Build Coastguard Worker
stride_height(uint32_t stride_height)270*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& stride_height(uint32_t stride_height) {
271*4bdc9457SAndroid Build Coastguard Worker assert(stride_height >= 1);
272*4bdc9457SAndroid Build Coastguard Worker this->stride_height_ = stride_height;
273*4bdc9457SAndroid Build Coastguard Worker return *this;
274*4bdc9457SAndroid Build Coastguard Worker }
275*4bdc9457SAndroid Build Coastguard Worker
stride_height()276*4bdc9457SAndroid Build Coastguard Worker inline uint32_t stride_height() const {
277*4bdc9457SAndroid Build Coastguard Worker return this->stride_height_;
278*4bdc9457SAndroid Build Coastguard Worker }
279*4bdc9457SAndroid Build Coastguard Worker
stride_width(uint32_t stride_width)280*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& stride_width(uint32_t stride_width) {
281*4bdc9457SAndroid Build Coastguard Worker assert(stride_width >= 1);
282*4bdc9457SAndroid Build Coastguard Worker this->stride_width_ = stride_width;
283*4bdc9457SAndroid Build Coastguard Worker return *this;
284*4bdc9457SAndroid Build Coastguard Worker }
285*4bdc9457SAndroid Build Coastguard Worker
stride_width()286*4bdc9457SAndroid Build Coastguard Worker inline uint32_t stride_width() const {
287*4bdc9457SAndroid Build Coastguard Worker return this->stride_width_;
288*4bdc9457SAndroid Build Coastguard Worker }
289*4bdc9457SAndroid Build Coastguard Worker
input_pixel_stride(size_t input_pixel_stride)290*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& input_pixel_stride(size_t input_pixel_stride) {
291*4bdc9457SAndroid Build Coastguard Worker assert(input_pixel_stride >= 1);
292*4bdc9457SAndroid Build Coastguard Worker this->input_pixel_stride_ = input_pixel_stride;
293*4bdc9457SAndroid Build Coastguard Worker return *this;
294*4bdc9457SAndroid Build Coastguard Worker }
295*4bdc9457SAndroid Build Coastguard Worker
input_pixel_stride()296*4bdc9457SAndroid Build Coastguard Worker inline size_t input_pixel_stride() const {
297*4bdc9457SAndroid Build Coastguard Worker if (this->input_pixel_stride_ == 0) {
298*4bdc9457SAndroid Build Coastguard Worker return group_input_channels() * groups();
299*4bdc9457SAndroid Build Coastguard Worker } else {
300*4bdc9457SAndroid Build Coastguard Worker assert(this->input_pixel_stride_ >= group_input_channels() * groups());
301*4bdc9457SAndroid Build Coastguard Worker return this->input_pixel_stride_;
302*4bdc9457SAndroid Build Coastguard Worker }
303*4bdc9457SAndroid Build Coastguard Worker }
304*4bdc9457SAndroid Build Coastguard Worker
output_pixel_stride(size_t output_pixel_stride)305*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& output_pixel_stride(size_t output_pixel_stride) {
306*4bdc9457SAndroid Build Coastguard Worker assert(output_pixel_stride >= 1);
307*4bdc9457SAndroid Build Coastguard Worker this->output_pixel_stride_ = output_pixel_stride;
308*4bdc9457SAndroid Build Coastguard Worker return *this;
309*4bdc9457SAndroid Build Coastguard Worker }
310*4bdc9457SAndroid Build Coastguard Worker
output_pixel_stride()311*4bdc9457SAndroid Build Coastguard Worker inline size_t output_pixel_stride() const {
312*4bdc9457SAndroid Build Coastguard Worker if (this->output_pixel_stride_ == 0) {
313*4bdc9457SAndroid Build Coastguard Worker return group_output_channels() * groups();
314*4bdc9457SAndroid Build Coastguard Worker } else {
315*4bdc9457SAndroid Build Coastguard Worker assert(this->output_pixel_stride_ >= group_output_channels() * groups());
316*4bdc9457SAndroid Build Coastguard Worker return this->output_pixel_stride_;
317*4bdc9457SAndroid Build Coastguard Worker }
318*4bdc9457SAndroid Build Coastguard Worker }
319*4bdc9457SAndroid Build Coastguard Worker
dilated_kernel_height()320*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilated_kernel_height() const {
321*4bdc9457SAndroid Build Coastguard Worker return (kernel_height() - 1) * dilation_height() + 1;
322*4bdc9457SAndroid Build Coastguard Worker }
323*4bdc9457SAndroid Build Coastguard Worker
dilated_kernel_width()324*4bdc9457SAndroid Build Coastguard Worker inline uint32_t dilated_kernel_width() const {
325*4bdc9457SAndroid Build Coastguard Worker return (kernel_width() - 1) * dilation_width() + 1;
326*4bdc9457SAndroid Build Coastguard Worker }
327*4bdc9457SAndroid Build Coastguard Worker
output_height()328*4bdc9457SAndroid Build Coastguard Worker inline size_t output_height() const {
329*4bdc9457SAndroid Build Coastguard Worker return stride_height() * (input_height() - 1) + adjustment_height() + dilated_kernel_height() - padding_height();
330*4bdc9457SAndroid Build Coastguard Worker }
331*4bdc9457SAndroid Build Coastguard Worker
output_width()332*4bdc9457SAndroid Build Coastguard Worker inline size_t output_width() const {
333*4bdc9457SAndroid Build Coastguard Worker return stride_width() * (input_width() - 1) + adjustment_width() + dilated_kernel_width() - padding_width();
334*4bdc9457SAndroid Build Coastguard Worker }
335*4bdc9457SAndroid Build Coastguard Worker
next_input_size(uint32_t next_input_height,uint32_t next_input_width)336*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& next_input_size(uint32_t next_input_height, uint32_t next_input_width) {
337*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1);
338*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1);
339*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height;
340*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width;
341*4bdc9457SAndroid Build Coastguard Worker return *this;
342*4bdc9457SAndroid Build Coastguard Worker }
343*4bdc9457SAndroid Build Coastguard Worker
next_input_height(uint32_t next_input_height)344*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& next_input_height(uint32_t next_input_height) {
345*4bdc9457SAndroid Build Coastguard Worker assert(next_input_height >= 1);
346*4bdc9457SAndroid Build Coastguard Worker this->next_input_height_ = next_input_height;
347*4bdc9457SAndroid Build Coastguard Worker return *this;
348*4bdc9457SAndroid Build Coastguard Worker }
349*4bdc9457SAndroid Build Coastguard Worker
next_input_height()350*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_height() const {
351*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_height_ == 0) {
352*4bdc9457SAndroid Build Coastguard Worker return input_height();
353*4bdc9457SAndroid Build Coastguard Worker } else {
354*4bdc9457SAndroid Build Coastguard Worker return this->next_input_height_;
355*4bdc9457SAndroid Build Coastguard Worker }
356*4bdc9457SAndroid Build Coastguard Worker }
357*4bdc9457SAndroid Build Coastguard Worker
next_input_width(uint32_t next_input_width)358*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& next_input_width(uint32_t next_input_width) {
359*4bdc9457SAndroid Build Coastguard Worker assert(next_input_width >= 1);
360*4bdc9457SAndroid Build Coastguard Worker this->next_input_width_ = next_input_width;
361*4bdc9457SAndroid Build Coastguard Worker return *this;
362*4bdc9457SAndroid Build Coastguard Worker }
363*4bdc9457SAndroid Build Coastguard Worker
next_input_width()364*4bdc9457SAndroid Build Coastguard Worker inline uint32_t next_input_width() const {
365*4bdc9457SAndroid Build Coastguard Worker if (this->next_input_width_ == 0) {
366*4bdc9457SAndroid Build Coastguard Worker return input_width();
367*4bdc9457SAndroid Build Coastguard Worker } else {
368*4bdc9457SAndroid Build Coastguard Worker return this->next_input_width_;
369*4bdc9457SAndroid Build Coastguard Worker }
370*4bdc9457SAndroid Build Coastguard Worker }
371*4bdc9457SAndroid Build Coastguard Worker
next_output_height()372*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_height() const {
373*4bdc9457SAndroid Build Coastguard Worker return stride_height() * (next_input_height() - 1) + adjustment_height() + dilated_kernel_height() - padding_height();
374*4bdc9457SAndroid Build Coastguard Worker }
375*4bdc9457SAndroid Build Coastguard Worker
next_output_width()376*4bdc9457SAndroid Build Coastguard Worker inline size_t next_output_width() const {
377*4bdc9457SAndroid Build Coastguard Worker return stride_width() * (next_input_width() - 1) + adjustment_width() + dilated_kernel_width() - padding_width();
378*4bdc9457SAndroid Build Coastguard Worker }
379*4bdc9457SAndroid Build Coastguard Worker
next_batch_size(size_t next_batch_size)380*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& next_batch_size(size_t next_batch_size) {
381*4bdc9457SAndroid Build Coastguard Worker assert(next_batch_size >= 1);
382*4bdc9457SAndroid Build Coastguard Worker this->next_batch_size_ = next_batch_size;
383*4bdc9457SAndroid Build Coastguard Worker return *this;
384*4bdc9457SAndroid Build Coastguard Worker }
385*4bdc9457SAndroid Build Coastguard Worker
next_batch_size()386*4bdc9457SAndroid Build Coastguard Worker inline size_t next_batch_size() const {
387*4bdc9457SAndroid Build Coastguard Worker if (this->next_batch_size_ == 0) {
388*4bdc9457SAndroid Build Coastguard Worker return batch_size();
389*4bdc9457SAndroid Build Coastguard Worker } else {
390*4bdc9457SAndroid Build Coastguard Worker return this->next_batch_size_;
391*4bdc9457SAndroid Build Coastguard Worker }
392*4bdc9457SAndroid Build Coastguard Worker }
393*4bdc9457SAndroid Build Coastguard Worker
qmin(uint8_t qmin)394*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& qmin(uint8_t qmin) {
395*4bdc9457SAndroid Build Coastguard Worker this->qmin_ = qmin;
396*4bdc9457SAndroid Build Coastguard Worker return *this;
397*4bdc9457SAndroid Build Coastguard Worker }
398*4bdc9457SAndroid Build Coastguard Worker
qmin()399*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmin() const {
400*4bdc9457SAndroid Build Coastguard Worker return this->qmin_;
401*4bdc9457SAndroid Build Coastguard Worker }
402*4bdc9457SAndroid Build Coastguard Worker
qmax(uint8_t qmax)403*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& qmax(uint8_t qmax) {
404*4bdc9457SAndroid Build Coastguard Worker this->qmax_ = qmax;
405*4bdc9457SAndroid Build Coastguard Worker return *this;
406*4bdc9457SAndroid Build Coastguard Worker }
407*4bdc9457SAndroid Build Coastguard Worker
qmax()408*4bdc9457SAndroid Build Coastguard Worker inline uint8_t qmax() const {
409*4bdc9457SAndroid Build Coastguard Worker return this->qmax_;
410*4bdc9457SAndroid Build Coastguard Worker }
411*4bdc9457SAndroid Build Coastguard Worker
has_bias(bool has_bias)412*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& has_bias(bool has_bias) {
413*4bdc9457SAndroid Build Coastguard Worker this->has_bias_ = has_bias;
414*4bdc9457SAndroid Build Coastguard Worker return *this;
415*4bdc9457SAndroid Build Coastguard Worker }
416*4bdc9457SAndroid Build Coastguard Worker
has_bias()417*4bdc9457SAndroid Build Coastguard Worker inline bool has_bias() const {
418*4bdc9457SAndroid Build Coastguard Worker return this->has_bias_;
419*4bdc9457SAndroid Build Coastguard Worker }
420*4bdc9457SAndroid Build Coastguard Worker
weights_type(WeightsType weights_type)421*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& weights_type(WeightsType weights_type) {
422*4bdc9457SAndroid Build Coastguard Worker this->weights_type_ = weights_type;
423*4bdc9457SAndroid Build Coastguard Worker return *this;
424*4bdc9457SAndroid Build Coastguard Worker }
425*4bdc9457SAndroid Build Coastguard Worker
weights_type()426*4bdc9457SAndroid Build Coastguard Worker inline WeightsType weights_type() const {
427*4bdc9457SAndroid Build Coastguard Worker return this->weights_type_;
428*4bdc9457SAndroid Build Coastguard Worker }
429*4bdc9457SAndroid Build Coastguard Worker
use_weights_cache(bool use_weights_cache)430*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& use_weights_cache(bool use_weights_cache) {
431*4bdc9457SAndroid Build Coastguard Worker this->use_weights_cache_ = use_weights_cache;
432*4bdc9457SAndroid Build Coastguard Worker return *this;
433*4bdc9457SAndroid Build Coastguard Worker }
434*4bdc9457SAndroid Build Coastguard Worker
use_weights_cache()435*4bdc9457SAndroid Build Coastguard Worker inline bool use_weights_cache() const {
436*4bdc9457SAndroid Build Coastguard Worker return this->use_weights_cache_;
437*4bdc9457SAndroid Build Coastguard Worker }
438*4bdc9457SAndroid Build Coastguard Worker
iterations(size_t iterations)439*4bdc9457SAndroid Build Coastguard Worker inline DeconvolutionOperatorTester& iterations(size_t iterations) {
440*4bdc9457SAndroid Build Coastguard Worker this->iterations_ = iterations;
441*4bdc9457SAndroid Build Coastguard Worker return *this;
442*4bdc9457SAndroid Build Coastguard Worker }
443*4bdc9457SAndroid Build Coastguard Worker
iterations()444*4bdc9457SAndroid Build Coastguard Worker inline size_t iterations() const {
445*4bdc9457SAndroid Build Coastguard Worker return this->iterations_;
446*4bdc9457SAndroid Build Coastguard Worker }
447*4bdc9457SAndroid Build Coastguard Worker
TestQS8()448*4bdc9457SAndroid Build Coastguard Worker void TestQS8() const {
449*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
450*4bdc9457SAndroid Build Coastguard Worker
451*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
452*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
453*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
454*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist(
455*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
456*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist(
457*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
458*4bdc9457SAndroid Build Coastguard Worker
459*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
460*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels());
461*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
462*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels());
463*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels());
464*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
465*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
466*4bdc9457SAndroid Build Coastguard Worker
467*4bdc9457SAndroid Build Coastguard Worker const int8_t input_zero_point = 1;
468*4bdc9457SAndroid Build Coastguard Worker
469*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
470*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
471*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
472*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
473*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5));
474*4bdc9457SAndroid Build Coastguard Worker
475*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization.
476*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
477*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
478*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
479*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
480*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
481*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
482*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
483*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
484*4bdc9457SAndroid Build Coastguard Worker }
485*4bdc9457SAndroid Build Coastguard Worker }
486*4bdc9457SAndroid Build Coastguard Worker }
487*4bdc9457SAndroid Build Coastguard Worker }
488*4bdc9457SAndroid Build Coastguard Worker }
489*4bdc9457SAndroid Build Coastguard Worker } else {
490*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0);
491*4bdc9457SAndroid Build Coastguard Worker }
492*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
493*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
494*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
495*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
496*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
497*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
498*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
499*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
500*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
501*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
502*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
503*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
504*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
505*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
506*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
507*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
508*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
509*4bdc9457SAndroid Build Coastguard Worker }
510*4bdc9457SAndroid Build Coastguard Worker }
511*4bdc9457SAndroid Build Coastguard Worker }
512*4bdc9457SAndroid Build Coastguard Worker }
513*4bdc9457SAndroid Build Coastguard Worker }
514*4bdc9457SAndroid Build Coastguard Worker }
515*4bdc9457SAndroid Build Coastguard Worker }
516*4bdc9457SAndroid Build Coastguard Worker }
517*4bdc9457SAndroid Build Coastguard Worker }
518*4bdc9457SAndroid Build Coastguard Worker }
519*4bdc9457SAndroid Build Coastguard Worker
520*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters.
521*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
522*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
523*4bdc9457SAndroid Build Coastguard Worker
524*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
525*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point = int8_t(std::max(std::min(
526*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
527*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
528*4bdc9457SAndroid Build Coastguard Worker
529*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results.
530*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
531*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double {
532*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
533*4bdc9457SAndroid Build Coastguard Worker });
534*4bdc9457SAndroid Build Coastguard Worker
535*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Deconvolution operator.
536*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
537*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
538*4bdc9457SAndroid Build Coastguard Worker
539*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = {
540*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL,
541*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL,
542*4bdc9457SAndroid Build Coastguard Worker };
543*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache;
544*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
545*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache);
546*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache;
547*4bdc9457SAndroid Build Coastguard Worker }
548*4bdc9457SAndroid Build Coastguard Worker
549*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
550*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
551*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_qs8(
552*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
553*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
554*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
555*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
556*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), input_zero_point,
557*4bdc9457SAndroid Build Coastguard Worker 1.0f /* input scale */, 1.0f /* kernel scale */, kernel.data(),
558*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point,
559*4bdc9457SAndroid Build Coastguard Worker output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
560*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op));
561*4bdc9457SAndroid Build Coastguard Worker
562*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
563*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
564*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
565*4bdc9457SAndroid Build Coastguard Worker }
566*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
567*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
568*4bdc9457SAndroid Build Coastguard Worker
569*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
570*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qs8(
571*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
572*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
573*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
574*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
575*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
576*4bdc9457SAndroid Build Coastguard Worker
577*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
578*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
579*4bdc9457SAndroid Build Coastguard Worker
580*4bdc9457SAndroid Build Coastguard Worker VerifyQS8(output, output_ref, output_zero_point);
581*4bdc9457SAndroid Build Coastguard Worker
582*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
583*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op2 = nullptr;
584*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size;
585*4bdc9457SAndroid Build Coastguard Worker
586*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
587*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
588*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_qs8(
589*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
590*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
591*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
592*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
593*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), input_zero_point,
594*4bdc9457SAndroid Build Coastguard Worker 1.0f /* input scale */, 1.0f /* kernel scale */, kernel.data(),
595*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point,
596*4bdc9457SAndroid Build Coastguard Worker output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
597*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op2));
598*4bdc9457SAndroid Build Coastguard Worker
599*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op2.
600*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op2, xnn_delete_operator);
601*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output2(output.size(), INT8_C(0xA5));
602*4bdc9457SAndroid Build Coastguard Worker
603*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
604*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qs8(
605*4bdc9457SAndroid Build Coastguard Worker deconvolution_op2,
606*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
607*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
608*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(),
609*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
610*4bdc9457SAndroid Build Coastguard Worker
611*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
612*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op2, nullptr /* thread pool */));
613*4bdc9457SAndroid Build Coastguard Worker
614*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(&weights_cache, old_weights_cache_size);
615*4bdc9457SAndroid Build Coastguard Worker VerifyQS8(output2, output_ref, output_zero_point);
616*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache);
617*4bdc9457SAndroid Build Coastguard Worker }
618*4bdc9457SAndroid Build Coastguard Worker
619*4bdc9457SAndroid Build Coastguard Worker }
620*4bdc9457SAndroid Build Coastguard Worker }
621*4bdc9457SAndroid Build Coastguard Worker
VerifyQS8(const std::vector<int8_t> & output,const std::vector<double> & output_ref,int8_t output_zero_point)622*4bdc9457SAndroid Build Coastguard Worker void VerifyQS8(const std::vector<int8_t> &output,
623*4bdc9457SAndroid Build Coastguard Worker const std::vector<double> &output_ref,
624*4bdc9457SAndroid Build Coastguard Worker int8_t output_zero_point) const {
625*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
626*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
627*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
628*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
629*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
630*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
631*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
632*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
633*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
634*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
635*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
636*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
637*4bdc9457SAndroid Build Coastguard Worker 0.9)
638*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
639*4bdc9457SAndroid Build Coastguard Worker }
640*4bdc9457SAndroid Build Coastguard Worker }
641*4bdc9457SAndroid Build Coastguard Worker }
642*4bdc9457SAndroid Build Coastguard Worker }
643*4bdc9457SAndroid Build Coastguard Worker }
644*4bdc9457SAndroid Build Coastguard Worker }
645*4bdc9457SAndroid Build Coastguard Worker
VerifyWeightsCache(xnn_weights_cache * weights_cache,size_t old_size)646*4bdc9457SAndroid Build Coastguard Worker void VerifyWeightsCache(xnn_weights_cache* weights_cache, size_t old_size) const {
647*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_cache->cache.hits, 1);
648*4bdc9457SAndroid Build Coastguard Worker // Ensure that we did not write more weights to the cache because it was a cache hit.
649*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(old_size, weights_cache->cache.weights.size);
650*4bdc9457SAndroid Build Coastguard Worker };
651*4bdc9457SAndroid Build Coastguard Worker
TestQU8()652*4bdc9457SAndroid Build Coastguard Worker void TestQU8() const {
653*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
654*4bdc9457SAndroid Build Coastguard Worker
655*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
656*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
657*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
658*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist(
659*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
660*4bdc9457SAndroid Build Coastguard Worker
661*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
662*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels());
663*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
664*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels());
665*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels());
666*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
667*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
668*4bdc9457SAndroid Build Coastguard Worker
669*4bdc9457SAndroid Build Coastguard Worker const uint8_t input_zero_point = 127;
670*4bdc9457SAndroid Build Coastguard Worker const uint8_t kernel_zero_point = 127;
671*4bdc9457SAndroid Build Coastguard Worker
672*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
673*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
674*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
675*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
676*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5));
677*4bdc9457SAndroid Build Coastguard Worker
678*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization.
679*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
680*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
681*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
682*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
683*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
684*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
685*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
686*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
687*4bdc9457SAndroid Build Coastguard Worker }
688*4bdc9457SAndroid Build Coastguard Worker }
689*4bdc9457SAndroid Build Coastguard Worker }
690*4bdc9457SAndroid Build Coastguard Worker }
691*4bdc9457SAndroid Build Coastguard Worker }
692*4bdc9457SAndroid Build Coastguard Worker } else {
693*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0);
694*4bdc9457SAndroid Build Coastguard Worker }
695*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
696*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
697*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
698*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
699*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
700*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
701*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
702*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
703*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
704*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
705*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
706*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
707*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
708*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
709*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
710*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
711*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point));
712*4bdc9457SAndroid Build Coastguard Worker }
713*4bdc9457SAndroid Build Coastguard Worker }
714*4bdc9457SAndroid Build Coastguard Worker }
715*4bdc9457SAndroid Build Coastguard Worker }
716*4bdc9457SAndroid Build Coastguard Worker }
717*4bdc9457SAndroid Build Coastguard Worker }
718*4bdc9457SAndroid Build Coastguard Worker }
719*4bdc9457SAndroid Build Coastguard Worker }
720*4bdc9457SAndroid Build Coastguard Worker }
721*4bdc9457SAndroid Build Coastguard Worker }
722*4bdc9457SAndroid Build Coastguard Worker
723*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters.
724*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
725*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
726*4bdc9457SAndroid Build Coastguard Worker
727*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
728*4bdc9457SAndroid Build Coastguard Worker const uint8_t output_zero_point = uint8_t(std::max(std::min(
729*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
730*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
731*4bdc9457SAndroid Build Coastguard Worker
732*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results.
733*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
734*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double {
735*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
736*4bdc9457SAndroid Build Coastguard Worker });
737*4bdc9457SAndroid Build Coastguard Worker
738*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Deconvolution operator.
739*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
740*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
741*4bdc9457SAndroid Build Coastguard Worker
742*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = {
743*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL,
744*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL,
745*4bdc9457SAndroid Build Coastguard Worker };
746*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache;
747*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
748*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache);
749*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache;
750*4bdc9457SAndroid Build Coastguard Worker }
751*4bdc9457SAndroid Build Coastguard Worker
752*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
753*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
754*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_qu8(
755*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
756*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
757*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
758*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
759*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), input_zero_point,
760*4bdc9457SAndroid Build Coastguard Worker 1.0f /* input scale */, kernel_zero_point,
761*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */, kernel.data(),
762*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point,
763*4bdc9457SAndroid Build Coastguard Worker output_scale, qmin(), qmax(),
764*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op));
765*4bdc9457SAndroid Build Coastguard Worker
766*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
767*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
768*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
769*4bdc9457SAndroid Build Coastguard Worker }
770*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
771*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
772*4bdc9457SAndroid Build Coastguard Worker
773*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
774*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qu8(
775*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
776*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
777*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
778*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
779*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
780*4bdc9457SAndroid Build Coastguard Worker
781*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
782*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
783*4bdc9457SAndroid Build Coastguard Worker
784*4bdc9457SAndroid Build Coastguard Worker // Verify results.
785*4bdc9457SAndroid Build Coastguard Worker VerifyQU8(output, output_ref, output_zero_point);
786*4bdc9457SAndroid Build Coastguard Worker
787*4bdc9457SAndroid Build Coastguard Worker
788*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
789*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op2 = nullptr;
790*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size;
791*4bdc9457SAndroid Build Coastguard Worker
792*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
793*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
794*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_qu8(
795*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
796*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
797*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
798*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
799*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), input_zero_point,
800*4bdc9457SAndroid Build Coastguard Worker 1.0f /* input scale */, kernel_zero_point,
801*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */, kernel.data(),
802*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_zero_point,
803*4bdc9457SAndroid Build Coastguard Worker output_scale, qmin(), qmax(),
804*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op2));
805*4bdc9457SAndroid Build Coastguard Worker
806*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op2.
807*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op2, xnn_delete_operator);
808*4bdc9457SAndroid Build Coastguard Worker
809*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
810*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qu8(
811*4bdc9457SAndroid Build Coastguard Worker deconvolution_op2,
812*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
813*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
814*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
815*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
816*4bdc9457SAndroid Build Coastguard Worker
817*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
818*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op2, nullptr /* thread pool */));
819*4bdc9457SAndroid Build Coastguard Worker
820*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(&weights_cache, old_weights_cache_size);
821*4bdc9457SAndroid Build Coastguard Worker VerifyQU8(output, output_ref, output_zero_point);
822*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache);
823*4bdc9457SAndroid Build Coastguard Worker }
824*4bdc9457SAndroid Build Coastguard Worker }
825*4bdc9457SAndroid Build Coastguard Worker }
826*4bdc9457SAndroid Build Coastguard Worker
VerifyQU8(const std::vector<uint8_t> & output,const std::vector<double> & output_ref,uint8_t output_zero_point)827*4bdc9457SAndroid Build Coastguard Worker void VerifyQU8(const std::vector<uint8_t> &output,
828*4bdc9457SAndroid Build Coastguard Worker const std::vector<double> &output_ref,
829*4bdc9457SAndroid Build Coastguard Worker uint8_t output_zero_point) const {
830*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
831*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
832*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
833*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
834*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
835*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmax()))
836*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
837*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmin()))
838*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
839*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
840*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
841*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
842*4bdc9457SAndroid Build Coastguard Worker 0.9)
843*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
844*4bdc9457SAndroid Build Coastguard Worker }
845*4bdc9457SAndroid Build Coastguard Worker }
846*4bdc9457SAndroid Build Coastguard Worker }
847*4bdc9457SAndroid Build Coastguard Worker }
848*4bdc9457SAndroid Build Coastguard Worker }
849*4bdc9457SAndroid Build Coastguard Worker }
850*4bdc9457SAndroid Build Coastguard Worker
TestF16()851*4bdc9457SAndroid Build Coastguard Worker void TestF16() const {
852*4bdc9457SAndroid Build Coastguard Worker switch (weights_type()) {
853*4bdc9457SAndroid Build Coastguard Worker case WeightsType::Default:
854*4bdc9457SAndroid Build Coastguard Worker break;
855*4bdc9457SAndroid Build Coastguard Worker case WeightsType::FP32:
856*4bdc9457SAndroid Build Coastguard Worker break;
857*4bdc9457SAndroid Build Coastguard Worker default:
858*4bdc9457SAndroid Build Coastguard Worker GTEST_FAIL() << "unexpected weights type";
859*4bdc9457SAndroid Build Coastguard Worker }
860*4bdc9457SAndroid Build Coastguard Worker
861*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
862*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
863*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
864*4bdc9457SAndroid Build Coastguard Worker
865*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
866*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels());
867*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
868*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel_as_float(kernel.size());
869*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(groups() * group_output_channels());
870*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias_as_float(bias.size());
871*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels());
872*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
873*4bdc9457SAndroid Build Coastguard Worker
874*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
875*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
876*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
877*4bdc9457SAndroid Build Coastguard Worker std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value);
878*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
879*4bdc9457SAndroid Build Coastguard Worker std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value);
880*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
881*4bdc9457SAndroid Build Coastguard Worker
882*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
883*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
884*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
885*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
886*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
887*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
888*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
889*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
890*4bdc9457SAndroid Build Coastguard Worker bias_as_float[g * group_output_channels() + oc];
891*4bdc9457SAndroid Build Coastguard Worker }
892*4bdc9457SAndroid Build Coastguard Worker }
893*4bdc9457SAndroid Build Coastguard Worker }
894*4bdc9457SAndroid Build Coastguard Worker }
895*4bdc9457SAndroid Build Coastguard Worker }
896*4bdc9457SAndroid Build Coastguard Worker } else {
897*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f);
898*4bdc9457SAndroid Build Coastguard Worker }
899*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
900*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
901*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
902*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
903*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
904*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
905*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
906*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
907*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
908*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
909*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
910*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
911*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
912*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
913*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
914*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) *
915*4bdc9457SAndroid Build Coastguard Worker kernel_as_float[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
916*4bdc9457SAndroid Build Coastguard Worker }
917*4bdc9457SAndroid Build Coastguard Worker }
918*4bdc9457SAndroid Build Coastguard Worker }
919*4bdc9457SAndroid Build Coastguard Worker }
920*4bdc9457SAndroid Build Coastguard Worker }
921*4bdc9457SAndroid Build Coastguard Worker }
922*4bdc9457SAndroid Build Coastguard Worker }
923*4bdc9457SAndroid Build Coastguard Worker }
924*4bdc9457SAndroid Build Coastguard Worker }
925*4bdc9457SAndroid Build Coastguard Worker }
926*4bdc9457SAndroid Build Coastguard Worker
927*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
928*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
929*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
930*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min;
931*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin());
932*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
933*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min));
934*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max));
935*4bdc9457SAndroid Build Coastguard Worker if (accumulated_range == 0.0f) {
936*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity();
937*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity();
938*4bdc9457SAndroid Build Coastguard Worker }
939*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<uint8_t>::min()) {
940*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity();
941*4bdc9457SAndroid Build Coastguard Worker }
942*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<uint8_t>::max()) {
943*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity();
944*4bdc9457SAndroid Build Coastguard Worker }
945*4bdc9457SAndroid Build Coastguard Worker
946*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
947*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) {
948*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
949*4bdc9457SAndroid Build Coastguard Worker }
950*4bdc9457SAndroid Build Coastguard Worker
951*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Deconvolution operator.
952*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
953*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
954*4bdc9457SAndroid Build Coastguard Worker
955*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = {
956*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL,
957*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL,
958*4bdc9457SAndroid Build Coastguard Worker };
959*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache;
960*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
961*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache);
962*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache;
963*4bdc9457SAndroid Build Coastguard Worker }
964*4bdc9457SAndroid Build Coastguard Worker
965*4bdc9457SAndroid Build Coastguard Worker const void* kernel_data = kernel.data();
966*4bdc9457SAndroid Build Coastguard Worker const void* bias_data = bias.data();
967*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) {
968*4bdc9457SAndroid Build Coastguard Worker kernel_data = kernel_as_float.data();
969*4bdc9457SAndroid Build Coastguard Worker bias_data = bias_as_float.data();
970*4bdc9457SAndroid Build Coastguard Worker }
971*4bdc9457SAndroid Build Coastguard Worker uint32_t flags = 0;
972*4bdc9457SAndroid Build Coastguard Worker if (weights_type() == WeightsType::FP32) {
973*4bdc9457SAndroid Build Coastguard Worker flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
974*4bdc9457SAndroid Build Coastguard Worker }
975*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_deconvolution2d_nhwc_f16(
976*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
977*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
978*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
979*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
980*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(),
981*4bdc9457SAndroid Build Coastguard Worker kernel_data, has_bias() ? bias_data : nullptr,
982*4bdc9457SAndroid Build Coastguard Worker output_min, output_max,
983*4bdc9457SAndroid Build Coastguard Worker flags, &caches, &deconvolution_op);
984*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) {
985*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP();
986*4bdc9457SAndroid Build Coastguard Worker }
987*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status);
988*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, deconvolution_op);
989*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
990*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
991*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
992*4bdc9457SAndroid Build Coastguard Worker }
993*4bdc9457SAndroid Build Coastguard Worker
994*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
995*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
996*4bdc9457SAndroid Build Coastguard Worker
997*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
998*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f16(
999*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1000*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1001*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1002*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
1003*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1004*4bdc9457SAndroid Build Coastguard Worker
1005*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1006*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1007*4bdc9457SAndroid Build Coastguard Worker
1008*4bdc9457SAndroid Build Coastguard Worker VerifyF16(output, output_ref, output_max, output_min);
1009*4bdc9457SAndroid Build Coastguard Worker
1010*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
1011*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op2 = nullptr;
1012*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size;
1013*4bdc9457SAndroid Build Coastguard Worker
1014*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1015*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_f16(
1016*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1017*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
1018*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
1019*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
1020*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(),
1021*4bdc9457SAndroid Build Coastguard Worker kernel_data, has_bias() ? bias_data : nullptr,
1022*4bdc9457SAndroid Build Coastguard Worker output_min, output_max,
1023*4bdc9457SAndroid Build Coastguard Worker flags, &caches, &deconvolution_op2));
1024*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, deconvolution_op2);
1025*4bdc9457SAndroid Build Coastguard Worker
1026*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op2.
1027*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op2, xnn_delete_operator);
1028*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */);
1029*4bdc9457SAndroid Build Coastguard Worker
1030*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1031*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f16(
1032*4bdc9457SAndroid Build Coastguard Worker deconvolution_op2,
1033*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1034*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1035*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(),
1036*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1037*4bdc9457SAndroid Build Coastguard Worker
1038*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1039*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op2, nullptr /* thread pool */));
1040*4bdc9457SAndroid Build Coastguard Worker
1041*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(&weights_cache, old_weights_cache_size);
1042*4bdc9457SAndroid Build Coastguard Worker VerifyF16(output2, output_ref, output_max, output_min);
1043*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache);
1044*4bdc9457SAndroid Build Coastguard Worker }
1045*4bdc9457SAndroid Build Coastguard Worker }
1046*4bdc9457SAndroid Build Coastguard Worker }
1047*4bdc9457SAndroid Build Coastguard Worker
VerifyF16(const std::vector<uint16_t> & output,const std::vector<float> & output_ref,float output_max,float output_min)1048*4bdc9457SAndroid Build Coastguard Worker void VerifyF16(const std::vector<uint16_t> &output,
1049*4bdc9457SAndroid Build Coastguard Worker const std::vector<float> &output_ref,
1050*4bdc9457SAndroid Build Coastguard Worker float output_max,
1051*4bdc9457SAndroid Build Coastguard Worker float output_min) const {
1052*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1053*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
1054*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
1055*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1056*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
1057*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), output_min)
1058*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1059*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), output_max)
1060*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1061*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1062*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]),
1063*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
1064*4bdc9457SAndroid Build Coastguard Worker 1.0e-2f * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]))
1065*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1066*4bdc9457SAndroid Build Coastguard Worker }
1067*4bdc9457SAndroid Build Coastguard Worker }
1068*4bdc9457SAndroid Build Coastguard Worker }
1069*4bdc9457SAndroid Build Coastguard Worker }
1070*4bdc9457SAndroid Build Coastguard Worker }
1071*4bdc9457SAndroid Build Coastguard Worker }
1072*4bdc9457SAndroid Build Coastguard Worker
TestF32()1073*4bdc9457SAndroid Build Coastguard Worker void TestF32() const {
1074*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
1075*4bdc9457SAndroid Build Coastguard Worker
1076*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1077*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1078*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
1079*4bdc9457SAndroid Build Coastguard Worker
1080*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
1081*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels());
1082*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1083*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(groups() * group_output_channels());
1084*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels());
1085*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1086*4bdc9457SAndroid Build Coastguard Worker
1087*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1088*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
1089*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
1090*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1091*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf(""));
1092*4bdc9457SAndroid Build Coastguard Worker
1093*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
1094*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1095*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1096*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1097*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1098*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1099*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1100*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1101*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
1102*4bdc9457SAndroid Build Coastguard Worker }
1103*4bdc9457SAndroid Build Coastguard Worker }
1104*4bdc9457SAndroid Build Coastguard Worker }
1105*4bdc9457SAndroid Build Coastguard Worker }
1106*4bdc9457SAndroid Build Coastguard Worker }
1107*4bdc9457SAndroid Build Coastguard Worker } else {
1108*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f);
1109*4bdc9457SAndroid Build Coastguard Worker }
1110*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1111*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1112*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1113*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1114*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1115*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1116*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
1117*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1118*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1119*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1120*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
1121*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1122*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1123*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1124*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1125*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic] *
1126*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
1127*4bdc9457SAndroid Build Coastguard Worker }
1128*4bdc9457SAndroid Build Coastguard Worker }
1129*4bdc9457SAndroid Build Coastguard Worker }
1130*4bdc9457SAndroid Build Coastguard Worker }
1131*4bdc9457SAndroid Build Coastguard Worker }
1132*4bdc9457SAndroid Build Coastguard Worker }
1133*4bdc9457SAndroid Build Coastguard Worker }
1134*4bdc9457SAndroid Build Coastguard Worker }
1135*4bdc9457SAndroid Build Coastguard Worker }
1136*4bdc9457SAndroid Build Coastguard Worker }
1137*4bdc9457SAndroid Build Coastguard Worker
1138*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
1139*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1140*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1141*4bdc9457SAndroid Build Coastguard Worker
1142*4bdc9457SAndroid Build Coastguard Worker const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
1143*4bdc9457SAndroid Build Coastguard Worker accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1144*4bdc9457SAndroid Build Coastguard Worker const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
1145*4bdc9457SAndroid Build Coastguard Worker accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1146*4bdc9457SAndroid Build Coastguard Worker
1147*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
1148*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) {
1149*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
1150*4bdc9457SAndroid Build Coastguard Worker }
1151*4bdc9457SAndroid Build Coastguard Worker
1152*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Deconvolution operator.
1153*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1154*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
1155*4bdc9457SAndroid Build Coastguard Worker
1156*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = {
1157*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL,
1158*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL,
1159*4bdc9457SAndroid Build Coastguard Worker };
1160*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache;
1161*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
1162*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache);
1163*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache;
1164*4bdc9457SAndroid Build Coastguard Worker }
1165*4bdc9457SAndroid Build Coastguard Worker
1166*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
1167*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
1168*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_f32(
1169*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1170*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
1171*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
1172*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
1173*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), kernel.data(),
1174*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_min, output_max,
1175*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op));
1176*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
1177*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1178*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
1179*4bdc9457SAndroid Build Coastguard Worker }
1180*4bdc9457SAndroid Build Coastguard Worker
1181*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
1182*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
1183*4bdc9457SAndroid Build Coastguard Worker
1184*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1185*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f32(
1186*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1187*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1188*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1189*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
1190*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1191*4bdc9457SAndroid Build Coastguard Worker
1192*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1193*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1194*4bdc9457SAndroid Build Coastguard Worker
1195*4bdc9457SAndroid Build Coastguard Worker VerifyF32(output, output_ref, output_max, output_min);
1196*4bdc9457SAndroid Build Coastguard Worker
1197*4bdc9457SAndroid Build Coastguard Worker if (use_weights_cache()) {
1198*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op2 = nullptr;
1199*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size;
1200*4bdc9457SAndroid Build Coastguard Worker
1201*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
1202*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
1203*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_f32(
1204*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1205*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
1206*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
1207*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
1208*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), kernel.data(),
1209*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_min, output_max,
1210*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op2));
1211*4bdc9457SAndroid Build Coastguard Worker
1212*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op2.
1213*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op2, xnn_delete_operator);
1214*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output2(output.size(), nanf(""));
1215*4bdc9457SAndroid Build Coastguard Worker
1216*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1217*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f32(
1218*4bdc9457SAndroid Build Coastguard Worker deconvolution_op2,
1219*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1220*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1221*4bdc9457SAndroid Build Coastguard Worker input.data(), output2.data(),
1222*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1223*4bdc9457SAndroid Build Coastguard Worker
1224*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1225*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op2, nullptr /* thread pool */));
1226*4bdc9457SAndroid Build Coastguard Worker
1227*4bdc9457SAndroid Build Coastguard Worker VerifyWeightsCache(&weights_cache, old_weights_cache_size);
1228*4bdc9457SAndroid Build Coastguard Worker VerifyF32(output2, output_ref, output_max, output_min);
1229*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache);
1230*4bdc9457SAndroid Build Coastguard Worker }
1231*4bdc9457SAndroid Build Coastguard Worker }
1232*4bdc9457SAndroid Build Coastguard Worker }
1233*4bdc9457SAndroid Build Coastguard Worker
1234*4bdc9457SAndroid Build Coastguard Worker // A variation of TestF32 that stresses the weights cache. All the operator creation needs to happen before
1235*4bdc9457SAndroid Build Coastguard Worker // finalization and setup.
StressWeightsCacheTestF32()1236*4bdc9457SAndroid Build Coastguard Worker void StressWeightsCacheTestF32() const {
1237*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
1238*4bdc9457SAndroid Build Coastguard Worker
1239*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1240*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1241*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
1242*4bdc9457SAndroid Build Coastguard Worker
1243*4bdc9457SAndroid Build Coastguard Worker xnn_caches caches = {
1244*4bdc9457SAndroid Build Coastguard Worker .code_cache = NULL,
1245*4bdc9457SAndroid Build Coastguard Worker .weights_cache = NULL,
1246*4bdc9457SAndroid Build Coastguard Worker };
1247*4bdc9457SAndroid Build Coastguard Worker xnn_weights_cache weights_cache;
1248*4bdc9457SAndroid Build Coastguard Worker xnn_init_weights_cache(&weights_cache);
1249*4bdc9457SAndroid Build Coastguard Worker caches.weights_cache = &weights_cache;
1250*4bdc9457SAndroid Build Coastguard Worker void* old_weights_cache_start = weights_cache.cache.weights.start;
1251*4bdc9457SAndroid Build Coastguard Worker size_t old_weights_cache_size = weights_cache.cache.weights.size;
1252*4bdc9457SAndroid Build Coastguard Worker
1253*4bdc9457SAndroid Build Coastguard Worker std::vector<xnn_operator_t> operators;
1254*4bdc9457SAndroid Build Coastguard Worker operators.reserve(iterations());
1255*4bdc9457SAndroid Build Coastguard Worker std::vector<std::vector<float>> inputs;
1256*4bdc9457SAndroid Build Coastguard Worker inputs.reserve(iterations());
1257*4bdc9457SAndroid Build Coastguard Worker std::vector<std::vector<float>> outputs;
1258*4bdc9457SAndroid Build Coastguard Worker outputs.reserve(iterations());
1259*4bdc9457SAndroid Build Coastguard Worker std::vector<std::vector<float>> output_refs;
1260*4bdc9457SAndroid Build Coastguard Worker output_refs.reserve(iterations());
1261*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_mins;
1262*4bdc9457SAndroid Build Coastguard Worker output_mins.reserve(iterations());
1263*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_maxs;
1264*4bdc9457SAndroid Build Coastguard Worker output_maxs.reserve(iterations());
1265*4bdc9457SAndroid Build Coastguard Worker
1266*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1267*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
1268*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels());
1269*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1270*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(groups() * group_output_channels());
1271*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output((batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels());
1272*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1273*4bdc9457SAndroid Build Coastguard Worker
1274*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
1275*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
1276*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1277*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf(""));
1278*4bdc9457SAndroid Build Coastguard Worker
1279*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
1280*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1281*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1282*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1283*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1284*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1285*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1286*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1287*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
1288*4bdc9457SAndroid Build Coastguard Worker }
1289*4bdc9457SAndroid Build Coastguard Worker }
1290*4bdc9457SAndroid Build Coastguard Worker }
1291*4bdc9457SAndroid Build Coastguard Worker }
1292*4bdc9457SAndroid Build Coastguard Worker }
1293*4bdc9457SAndroid Build Coastguard Worker } else {
1294*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f);
1295*4bdc9457SAndroid Build Coastguard Worker }
1296*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1297*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1298*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1299*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1300*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1301*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1302*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
1303*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1304*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1305*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1306*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
1307*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1308*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1309*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1310*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1311*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic] *
1312*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
1313*4bdc9457SAndroid Build Coastguard Worker }
1314*4bdc9457SAndroid Build Coastguard Worker }
1315*4bdc9457SAndroid Build Coastguard Worker }
1316*4bdc9457SAndroid Build Coastguard Worker }
1317*4bdc9457SAndroid Build Coastguard Worker }
1318*4bdc9457SAndroid Build Coastguard Worker }
1319*4bdc9457SAndroid Build Coastguard Worker }
1320*4bdc9457SAndroid Build Coastguard Worker }
1321*4bdc9457SAndroid Build Coastguard Worker }
1322*4bdc9457SAndroid Build Coastguard Worker }
1323*4bdc9457SAndroid Build Coastguard Worker
1324*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
1325*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1326*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1327*4bdc9457SAndroid Build Coastguard Worker
1328*4bdc9457SAndroid Build Coastguard Worker const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
1329*4bdc9457SAndroid Build Coastguard Worker accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1330*4bdc9457SAndroid Build Coastguard Worker const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
1331*4bdc9457SAndroid Build Coastguard Worker accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1332*4bdc9457SAndroid Build Coastguard Worker output_mins.push_back(output_min);
1333*4bdc9457SAndroid Build Coastguard Worker output_maxs.push_back(output_max);
1334*4bdc9457SAndroid Build Coastguard Worker
1335*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
1336*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) {
1337*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
1338*4bdc9457SAndroid Build Coastguard Worker }
1339*4bdc9457SAndroid Build Coastguard Worker
1340*4bdc9457SAndroid Build Coastguard Worker // Create, setup, run, and destroy Deconvolution operator.
1341*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1342*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
1343*4bdc9457SAndroid Build Coastguard Worker
1344*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(
1345*4bdc9457SAndroid Build Coastguard Worker xnn_status_success,
1346*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_f32(
1347*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1348*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(), stride_height(), stride_width(),
1349*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(), groups(),
1350*4bdc9457SAndroid Build Coastguard Worker group_input_channels(), group_output_channels(),
1351*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(), kernel.data(),
1352*4bdc9457SAndroid Build Coastguard Worker has_bias() ? bias.data() : nullptr, output_min, output_max,
1353*4bdc9457SAndroid Build Coastguard Worker /*flags=*/0, &caches, &deconvolution_op));
1354*4bdc9457SAndroid Build Coastguard Worker
1355*4bdc9457SAndroid Build Coastguard Worker operators.push_back(std::move(deconvolution_op));
1356*4bdc9457SAndroid Build Coastguard Worker inputs.push_back(std::move(input));
1357*4bdc9457SAndroid Build Coastguard Worker outputs.push_back(std::move(output));
1358*4bdc9457SAndroid Build Coastguard Worker output_refs.push_back(std::move(output_ref));
1359*4bdc9457SAndroid Build Coastguard Worker }
1360*4bdc9457SAndroid Build Coastguard Worker
1361*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1362*4bdc9457SAndroid Build Coastguard Worker xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
1363*4bdc9457SAndroid Build Coastguard Worker
1364*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1365*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = operators[iteration];
1366*4bdc9457SAndroid Build Coastguard Worker
1367*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1368*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f32(
1369*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1370*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1371*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1372*4bdc9457SAndroid Build Coastguard Worker inputs[iteration].data(), outputs[iteration].data(),
1373*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1374*4bdc9457SAndroid Build Coastguard Worker
1375*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1376*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1377*4bdc9457SAndroid Build Coastguard Worker
1378*4bdc9457SAndroid Build Coastguard Worker VerifyF32(outputs[iteration],
1379*4bdc9457SAndroid Build Coastguard Worker output_refs[iteration],
1380*4bdc9457SAndroid Build Coastguard Worker output_maxs[iteration],
1381*4bdc9457SAndroid Build Coastguard Worker output_mins[iteration]);
1382*4bdc9457SAndroid Build Coastguard Worker xnn_delete_operator(deconvolution_op);
1383*4bdc9457SAndroid Build Coastguard Worker }
1384*4bdc9457SAndroid Build Coastguard Worker
1385*4bdc9457SAndroid Build Coastguard Worker // Check that the weights cache grew and moved. If these assertion fails,
1386*4bdc9457SAndroid Build Coastguard Worker // might have to increase the number of test iterations.
1387*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(old_weights_cache_start, weights_cache.cache.weights.start);
1388*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(old_weights_cache_size, weights_cache.cache.weights.size);
1389*4bdc9457SAndroid Build Coastguard Worker // Since the weights are randomized, it is very unlikely to have any hits.
1390*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(iterations(), weights_cache.cache.misses);
1391*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(0, weights_cache.cache.hits);
1392*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(iterations(), weights_cache.cache.num_entries);
1393*4bdc9457SAndroid Build Coastguard Worker xnn_release_weights_cache(&weights_cache);
1394*4bdc9457SAndroid Build Coastguard Worker }
1395*4bdc9457SAndroid Build Coastguard Worker
VerifyF32(const std::vector<float> & output,const std::vector<float> & output_ref,float output_max,float output_min)1396*4bdc9457SAndroid Build Coastguard Worker void VerifyF32(const std::vector<float> &output,
1397*4bdc9457SAndroid Build Coastguard Worker const std::vector<float> &output_ref,
1398*4bdc9457SAndroid Build Coastguard Worker float output_max,
1399*4bdc9457SAndroid Build Coastguard Worker float output_min) const {
1400*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1401*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
1402*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
1403*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1404*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
1405*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c], output_min)
1406*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1407*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c], output_max)
1408*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1409*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1410*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
1411*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c],
1412*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]))
1413*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1414*4bdc9457SAndroid Build Coastguard Worker }
1415*4bdc9457SAndroid Build Coastguard Worker }
1416*4bdc9457SAndroid Build Coastguard Worker }
1417*4bdc9457SAndroid Build Coastguard Worker }
1418*4bdc9457SAndroid Build Coastguard Worker }
1419*4bdc9457SAndroid Build Coastguard Worker }
1420*4bdc9457SAndroid Build Coastguard Worker
TestSetupQS8()1421*4bdc9457SAndroid Build Coastguard Worker void TestSetupQS8() const {
1422*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
1423*4bdc9457SAndroid Build Coastguard Worker
1424*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1425*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1426*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
1427*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i8dist(
1428*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
1429*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> w8dist(
1430*4bdc9457SAndroid Build Coastguard Worker -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
1431*4bdc9457SAndroid Build Coastguard Worker
1432*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) + std::max(
1433*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels(),
1434*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + groups() * group_input_channels()));
1435*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1436*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels());
1437*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> output(std::max(
1438*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels(),
1439*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + groups() * group_output_channels()));
1440*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1441*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1442*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
1443*4bdc9457SAndroid Build Coastguard Worker std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
1444*4bdc9457SAndroid Build Coastguard Worker
1445*4bdc9457SAndroid Build Coastguard Worker const int8_t input_zero_point = 127;
1446*4bdc9457SAndroid Build Coastguard Worker
1447*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1448*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
1449*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
1450*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
1451*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5));
1452*4bdc9457SAndroid Build Coastguard Worker
1453*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization.
1454*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1455*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1456*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1457*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1458*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1459*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1460*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1461*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
1462*4bdc9457SAndroid Build Coastguard Worker }
1463*4bdc9457SAndroid Build Coastguard Worker }
1464*4bdc9457SAndroid Build Coastguard Worker }
1465*4bdc9457SAndroid Build Coastguard Worker }
1466*4bdc9457SAndroid Build Coastguard Worker }
1467*4bdc9457SAndroid Build Coastguard Worker } else {
1468*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0);
1469*4bdc9457SAndroid Build Coastguard Worker }
1470*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1471*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1472*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1473*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1474*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1475*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1476*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
1477*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1478*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1479*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1480*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
1481*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1482*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1483*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1484*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1485*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
1486*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
1487*4bdc9457SAndroid Build Coastguard Worker }
1488*4bdc9457SAndroid Build Coastguard Worker }
1489*4bdc9457SAndroid Build Coastguard Worker }
1490*4bdc9457SAndroid Build Coastguard Worker }
1491*4bdc9457SAndroid Build Coastguard Worker }
1492*4bdc9457SAndroid Build Coastguard Worker }
1493*4bdc9457SAndroid Build Coastguard Worker }
1494*4bdc9457SAndroid Build Coastguard Worker }
1495*4bdc9457SAndroid Build Coastguard Worker }
1496*4bdc9457SAndroid Build Coastguard Worker }
1497*4bdc9457SAndroid Build Coastguard Worker
1498*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters.
1499*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
1500*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
1501*4bdc9457SAndroid Build Coastguard Worker
1502*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
1503*4bdc9457SAndroid Build Coastguard Worker const int8_t output_zero_point = int8_t(std::max(std::min(
1504*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
1505*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
1506*4bdc9457SAndroid Build Coastguard Worker
1507*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results.
1508*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
1509*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double {
1510*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
1511*4bdc9457SAndroid Build Coastguard Worker });
1512*4bdc9457SAndroid Build Coastguard Worker
1513*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Deconvolution operator once.
1514*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1515*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
1516*4bdc9457SAndroid Build Coastguard Worker
1517*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1518*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_qs8(
1519*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1520*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(),
1521*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(),
1522*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(),
1523*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(),
1524*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(),
1525*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */,
1526*4bdc9457SAndroid Build Coastguard Worker 1.0f /* kernel scale */,
1527*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr,
1528*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
1529*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &deconvolution_op));
1530*4bdc9457SAndroid Build Coastguard Worker
1531*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
1532*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
1533*4bdc9457SAndroid Build Coastguard Worker
1534*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1535*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qs8(
1536*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1537*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1538*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1539*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
1540*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1541*4bdc9457SAndroid Build Coastguard Worker
1542*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1543*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1544*4bdc9457SAndroid Build Coastguard Worker
1545*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run.
1546*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1547*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
1548*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
1549*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1550*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
1551*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
1552*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1553*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
1554*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1555*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1556*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
1557*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
1558*4bdc9457SAndroid Build Coastguard Worker 0.9)
1559*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1560*4bdc9457SAndroid Build Coastguard Worker }
1561*4bdc9457SAndroid Build Coastguard Worker }
1562*4bdc9457SAndroid Build Coastguard Worker }
1563*4bdc9457SAndroid Build Coastguard Worker }
1564*4bdc9457SAndroid Build Coastguard Worker }
1565*4bdc9457SAndroid Build Coastguard Worker
1566*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run.
1567*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
1568*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), INT8_C(0xA5));
1569*4bdc9457SAndroid Build Coastguard Worker
1570*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including renormalization.
1571*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1572*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
1573*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
1574*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
1575*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1576*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1577*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1578*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
1579*4bdc9457SAndroid Build Coastguard Worker }
1580*4bdc9457SAndroid Build Coastguard Worker }
1581*4bdc9457SAndroid Build Coastguard Worker }
1582*4bdc9457SAndroid Build Coastguard Worker }
1583*4bdc9457SAndroid Build Coastguard Worker }
1584*4bdc9457SAndroid Build Coastguard Worker } else {
1585*4bdc9457SAndroid Build Coastguard Worker std::fill(next_accumulators.begin(), next_accumulators.end(), 0);
1586*4bdc9457SAndroid Build Coastguard Worker }
1587*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
1588*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
1589*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
1590*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1591*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1592*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1593*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < next_input_height()) {
1594*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1595*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1596*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1597*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < next_input_width()) {
1598*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1599*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1600*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1601*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1602*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
1603*4bdc9457SAndroid Build Coastguard Worker int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
1604*4bdc9457SAndroid Build Coastguard Worker }
1605*4bdc9457SAndroid Build Coastguard Worker }
1606*4bdc9457SAndroid Build Coastguard Worker }
1607*4bdc9457SAndroid Build Coastguard Worker }
1608*4bdc9457SAndroid Build Coastguard Worker }
1609*4bdc9457SAndroid Build Coastguard Worker }
1610*4bdc9457SAndroid Build Coastguard Worker }
1611*4bdc9457SAndroid Build Coastguard Worker }
1612*4bdc9457SAndroid Build Coastguard Worker }
1613*4bdc9457SAndroid Build Coastguard Worker }
1614*4bdc9457SAndroid Build Coastguard Worker std::transform(next_accumulators.cbegin(), next_accumulators.cend(), next_output_ref.begin(),
1615*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double {
1616*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
1617*4bdc9457SAndroid Build Coastguard Worker });
1618*4bdc9457SAndroid Build Coastguard Worker
1619*4bdc9457SAndroid Build Coastguard Worker // Setup and run Deconvolution operator the second time, and destroy the operator.
1620*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1621*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qs8(
1622*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1623*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(),
1624*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1625*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
1626*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1627*4bdc9457SAndroid Build Coastguard Worker
1628*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1629*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1630*4bdc9457SAndroid Build Coastguard Worker
1631*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run.
1632*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
1633*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) {
1634*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) {
1635*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1636*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
1637*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmax() - 0x80))
1638*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1639*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmin() - 0x80))
1640*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1641*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1642*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
1643*4bdc9457SAndroid Build Coastguard Worker double(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
1644*4bdc9457SAndroid Build Coastguard Worker 0.9)
1645*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1646*4bdc9457SAndroid Build Coastguard Worker }
1647*4bdc9457SAndroid Build Coastguard Worker }
1648*4bdc9457SAndroid Build Coastguard Worker }
1649*4bdc9457SAndroid Build Coastguard Worker }
1650*4bdc9457SAndroid Build Coastguard Worker }
1651*4bdc9457SAndroid Build Coastguard Worker }
1652*4bdc9457SAndroid Build Coastguard Worker }
1653*4bdc9457SAndroid Build Coastguard Worker
TestSetupQU8()1654*4bdc9457SAndroid Build Coastguard Worker void TestSetupQU8() const {
1655*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
1656*4bdc9457SAndroid Build Coastguard Worker
1657*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1658*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1659*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
1660*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t> u8dist(
1661*4bdc9457SAndroid Build Coastguard Worker std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
1662*4bdc9457SAndroid Build Coastguard Worker
1663*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + std::max(
1664*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels(),
1665*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + groups() * group_input_channels()));
1666*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1667*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(groups() * group_output_channels());
1668*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> output(std::max(
1669*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels(),
1670*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + groups() * group_output_channels()));
1671*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> accumulators(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1672*4bdc9457SAndroid Build Coastguard Worker std::vector<double> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1673*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> next_accumulators(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
1674*4bdc9457SAndroid Build Coastguard Worker std::vector<double> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
1675*4bdc9457SAndroid Build Coastguard Worker
1676*4bdc9457SAndroid Build Coastguard Worker const uint8_t input_zero_point = 127;
1677*4bdc9457SAndroid Build Coastguard Worker const uint8_t kernel_zero_point = 127;
1678*4bdc9457SAndroid Build Coastguard Worker
1679*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1680*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
1681*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
1682*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
1683*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT8_C(0xA5));
1684*4bdc9457SAndroid Build Coastguard Worker
1685*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without renormalization.
1686*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1687*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1688*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1689*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1690*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1691*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1692*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1693*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
1694*4bdc9457SAndroid Build Coastguard Worker }
1695*4bdc9457SAndroid Build Coastguard Worker }
1696*4bdc9457SAndroid Build Coastguard Worker }
1697*4bdc9457SAndroid Build Coastguard Worker }
1698*4bdc9457SAndroid Build Coastguard Worker }
1699*4bdc9457SAndroid Build Coastguard Worker } else {
1700*4bdc9457SAndroid Build Coastguard Worker std::fill(accumulators.begin(), accumulators.end(), 0);
1701*4bdc9457SAndroid Build Coastguard Worker }
1702*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1703*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1704*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1705*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1706*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1707*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1708*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
1709*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1710*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1711*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1712*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
1713*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1714*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1715*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1716*4bdc9457SAndroid Build Coastguard Worker accumulators[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1717*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
1718*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point));
1719*4bdc9457SAndroid Build Coastguard Worker }
1720*4bdc9457SAndroid Build Coastguard Worker }
1721*4bdc9457SAndroid Build Coastguard Worker }
1722*4bdc9457SAndroid Build Coastguard Worker }
1723*4bdc9457SAndroid Build Coastguard Worker }
1724*4bdc9457SAndroid Build Coastguard Worker }
1725*4bdc9457SAndroid Build Coastguard Worker }
1726*4bdc9457SAndroid Build Coastguard Worker }
1727*4bdc9457SAndroid Build Coastguard Worker }
1728*4bdc9457SAndroid Build Coastguard Worker }
1729*4bdc9457SAndroid Build Coastguard Worker
1730*4bdc9457SAndroid Build Coastguard Worker // Compute renormalization parameters.
1731*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
1732*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
1733*4bdc9457SAndroid Build Coastguard Worker
1734*4bdc9457SAndroid Build Coastguard Worker const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
1735*4bdc9457SAndroid Build Coastguard Worker const uint8_t output_zero_point = uint8_t(std::max(std::min(
1736*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
1737*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
1738*4bdc9457SAndroid Build Coastguard Worker
1739*4bdc9457SAndroid Build Coastguard Worker // Renormalize reference results.
1740*4bdc9457SAndroid Build Coastguard Worker std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
1741*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double {
1742*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
1743*4bdc9457SAndroid Build Coastguard Worker });
1744*4bdc9457SAndroid Build Coastguard Worker
1745*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Deconvolution operator once.
1746*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1747*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
1748*4bdc9457SAndroid Build Coastguard Worker
1749*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1750*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_qu8(
1751*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1752*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(),
1753*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(),
1754*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(),
1755*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(),
1756*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(),
1757*4bdc9457SAndroid Build Coastguard Worker input_zero_point, 1.0f /* input scale */,
1758*4bdc9457SAndroid Build Coastguard Worker kernel_zero_point, 1.0f /* kernel scale */,
1759*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr,
1760*4bdc9457SAndroid Build Coastguard Worker output_zero_point, output_scale, qmin(), qmax(),
1761*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &deconvolution_op));
1762*4bdc9457SAndroid Build Coastguard Worker
1763*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
1764*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
1765*4bdc9457SAndroid Build Coastguard Worker
1766*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1767*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qu8(
1768*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1769*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
1770*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1771*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
1772*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1773*4bdc9457SAndroid Build Coastguard Worker
1774*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1775*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1776*4bdc9457SAndroid Build Coastguard Worker
1777*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run.
1778*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1779*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
1780*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
1781*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1782*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
1783*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmax()))
1784*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1785*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmin()))
1786*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1787*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1788*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
1789*4bdc9457SAndroid Build Coastguard Worker double(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
1790*4bdc9457SAndroid Build Coastguard Worker 0.9)
1791*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1792*4bdc9457SAndroid Build Coastguard Worker }
1793*4bdc9457SAndroid Build Coastguard Worker }
1794*4bdc9457SAndroid Build Coastguard Worker }
1795*4bdc9457SAndroid Build Coastguard Worker }
1796*4bdc9457SAndroid Build Coastguard Worker }
1797*4bdc9457SAndroid Build Coastguard Worker
1798*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run.
1799*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
1800*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), 0xA5);
1801*4bdc9457SAndroid Build Coastguard Worker
1802*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including renormalization.
1803*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1804*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
1805*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
1806*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
1807*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1808*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1809*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1810*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
1811*4bdc9457SAndroid Build Coastguard Worker }
1812*4bdc9457SAndroid Build Coastguard Worker }
1813*4bdc9457SAndroid Build Coastguard Worker }
1814*4bdc9457SAndroid Build Coastguard Worker }
1815*4bdc9457SAndroid Build Coastguard Worker }
1816*4bdc9457SAndroid Build Coastguard Worker } else {
1817*4bdc9457SAndroid Build Coastguard Worker std::fill(next_accumulators.begin(), next_accumulators.end(), 0);
1818*4bdc9457SAndroid Build Coastguard Worker }
1819*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
1820*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
1821*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
1822*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1823*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1824*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1825*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < next_input_height()) {
1826*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1827*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1828*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1829*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < next_input_width()) {
1830*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1831*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1832*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1833*4bdc9457SAndroid Build Coastguard Worker next_accumulators[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1834*4bdc9457SAndroid Build Coastguard Worker (int32_t(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) - int32_t(input_zero_point)) *
1835*4bdc9457SAndroid Build Coastguard Worker (int32_t(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]) - int32_t(kernel_zero_point));
1836*4bdc9457SAndroid Build Coastguard Worker }
1837*4bdc9457SAndroid Build Coastguard Worker }
1838*4bdc9457SAndroid Build Coastguard Worker }
1839*4bdc9457SAndroid Build Coastguard Worker }
1840*4bdc9457SAndroid Build Coastguard Worker }
1841*4bdc9457SAndroid Build Coastguard Worker }
1842*4bdc9457SAndroid Build Coastguard Worker }
1843*4bdc9457SAndroid Build Coastguard Worker }
1844*4bdc9457SAndroid Build Coastguard Worker }
1845*4bdc9457SAndroid Build Coastguard Worker }
1846*4bdc9457SAndroid Build Coastguard Worker std::transform(next_accumulators.cbegin(), next_accumulators.cend(), next_output_ref.begin(),
1847*4bdc9457SAndroid Build Coastguard Worker [this, output_scale, output_zero_point](int32_t x) -> double {
1848*4bdc9457SAndroid Build Coastguard Worker return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
1849*4bdc9457SAndroid Build Coastguard Worker });
1850*4bdc9457SAndroid Build Coastguard Worker
1851*4bdc9457SAndroid Build Coastguard Worker // Setup and run Deconvolution operator the second time, and destroy the operator.
1852*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1853*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_qu8(
1854*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
1855*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(),
1856*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
1857*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
1858*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
1859*4bdc9457SAndroid Build Coastguard Worker
1860*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1861*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
1862*4bdc9457SAndroid Build Coastguard Worker
1863*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run.
1864*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
1865*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) {
1866*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) {
1867*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1868*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
1869*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmax()))
1870*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1871*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), int32_t(qmin()))
1872*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1873*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1874*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
1875*4bdc9457SAndroid Build Coastguard Worker double(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]) - double(output_zero_point),
1876*4bdc9457SAndroid Build Coastguard Worker 0.9)
1877*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
1878*4bdc9457SAndroid Build Coastguard Worker }
1879*4bdc9457SAndroid Build Coastguard Worker }
1880*4bdc9457SAndroid Build Coastguard Worker }
1881*4bdc9457SAndroid Build Coastguard Worker }
1882*4bdc9457SAndroid Build Coastguard Worker }
1883*4bdc9457SAndroid Build Coastguard Worker }
1884*4bdc9457SAndroid Build Coastguard Worker }
1885*4bdc9457SAndroid Build Coastguard Worker
TestSetupF16()1886*4bdc9457SAndroid Build Coastguard Worker void TestSetupF16() const {
1887*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
1888*4bdc9457SAndroid Build Coastguard Worker
1889*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1890*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1891*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
1892*4bdc9457SAndroid Build Coastguard Worker
1893*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + std::max(
1894*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels(),
1895*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + groups() * group_input_channels()));
1896*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
1897*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(groups() * group_output_channels());
1898*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> output(std::max(
1899*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels(),
1900*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + groups() * group_output_channels()));
1901*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
1902*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
1903*4bdc9457SAndroid Build Coastguard Worker
1904*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1905*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
1906*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
1907*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
1908*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
1909*4bdc9457SAndroid Build Coastguard Worker
1910*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
1911*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
1912*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1913*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1914*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1915*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1916*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1917*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
1918*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]);
1919*4bdc9457SAndroid Build Coastguard Worker }
1920*4bdc9457SAndroid Build Coastguard Worker }
1921*4bdc9457SAndroid Build Coastguard Worker }
1922*4bdc9457SAndroid Build Coastguard Worker }
1923*4bdc9457SAndroid Build Coastguard Worker }
1924*4bdc9457SAndroid Build Coastguard Worker } else {
1925*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0);
1926*4bdc9457SAndroid Build Coastguard Worker }
1927*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
1928*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
1929*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
1930*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
1931*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
1932*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
1933*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
1934*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
1935*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
1936*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
1937*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
1938*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
1939*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
1940*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
1941*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
1942*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) *
1943*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
1944*4bdc9457SAndroid Build Coastguard Worker }
1945*4bdc9457SAndroid Build Coastguard Worker }
1946*4bdc9457SAndroid Build Coastguard Worker }
1947*4bdc9457SAndroid Build Coastguard Worker }
1948*4bdc9457SAndroid Build Coastguard Worker }
1949*4bdc9457SAndroid Build Coastguard Worker }
1950*4bdc9457SAndroid Build Coastguard Worker }
1951*4bdc9457SAndroid Build Coastguard Worker }
1952*4bdc9457SAndroid Build Coastguard Worker }
1953*4bdc9457SAndroid Build Coastguard Worker }
1954*4bdc9457SAndroid Build Coastguard Worker
1955*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
1956*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
1957*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
1958*4bdc9457SAndroid Build Coastguard Worker const float accumulated_range = accumulated_max - accumulated_min;
1959*4bdc9457SAndroid Build Coastguard Worker float output_min = accumulated_min + accumulated_range / 255.0f * float(qmin());
1960*4bdc9457SAndroid Build Coastguard Worker float output_max = accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
1961*4bdc9457SAndroid Build Coastguard Worker output_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_min));
1962*4bdc9457SAndroid Build Coastguard Worker output_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(output_max));
1963*4bdc9457SAndroid Build Coastguard Worker if (accumulated_range == 0.0f) {
1964*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity();
1965*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity();
1966*4bdc9457SAndroid Build Coastguard Worker }
1967*4bdc9457SAndroid Build Coastguard Worker if (qmin() == std::numeric_limits<uint8_t>::min()) {
1968*4bdc9457SAndroid Build Coastguard Worker output_min = -std::numeric_limits<float>::infinity();
1969*4bdc9457SAndroid Build Coastguard Worker }
1970*4bdc9457SAndroid Build Coastguard Worker if (qmax() == std::numeric_limits<uint8_t>::max()) {
1971*4bdc9457SAndroid Build Coastguard Worker output_max = +std::numeric_limits<float>::infinity();
1972*4bdc9457SAndroid Build Coastguard Worker }
1973*4bdc9457SAndroid Build Coastguard Worker
1974*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
1975*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) {
1976*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
1977*4bdc9457SAndroid Build Coastguard Worker }
1978*4bdc9457SAndroid Build Coastguard Worker
1979*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Deconvolution operator once.
1980*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
1981*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
1982*4bdc9457SAndroid Build Coastguard Worker
1983*4bdc9457SAndroid Build Coastguard Worker const xnn_status status = xnn_create_deconvolution2d_nhwc_f16(
1984*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
1985*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(),
1986*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(),
1987*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(),
1988*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(),
1989*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(),
1990*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr,
1991*4bdc9457SAndroid Build Coastguard Worker output_min, output_max,
1992*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &deconvolution_op);
1993*4bdc9457SAndroid Build Coastguard Worker if (status == xnn_status_unsupported_hardware) {
1994*4bdc9457SAndroid Build Coastguard Worker GTEST_SKIP();
1995*4bdc9457SAndroid Build Coastguard Worker }
1996*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, status);
1997*4bdc9457SAndroid Build Coastguard Worker ASSERT_NE(nullptr, deconvolution_op);
1998*4bdc9457SAndroid Build Coastguard Worker
1999*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
2000*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
2001*4bdc9457SAndroid Build Coastguard Worker
2002*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2003*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f16(
2004*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
2005*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
2006*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
2007*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
2008*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
2009*4bdc9457SAndroid Build Coastguard Worker
2010*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2011*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
2012*4bdc9457SAndroid Build Coastguard Worker
2013*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run.
2014*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
2015*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
2016*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
2017*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2018*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
2019*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), output_min)
2020*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2021*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), output_max)
2022*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2023*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
2024*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]),
2025*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
2026*4bdc9457SAndroid Build Coastguard Worker 1.0e-2f * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]))
2027*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2028*4bdc9457SAndroid Build Coastguard Worker }
2029*4bdc9457SAndroid Build Coastguard Worker }
2030*4bdc9457SAndroid Build Coastguard Worker }
2031*4bdc9457SAndroid Build Coastguard Worker }
2032*4bdc9457SAndroid Build Coastguard Worker }
2033*4bdc9457SAndroid Build Coastguard Worker
2034*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run.
2035*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
2036*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
2037*4bdc9457SAndroid Build Coastguard Worker
2038*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping.
2039*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
2040*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
2041*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
2042*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
2043*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2044*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
2045*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2046*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(bias[g * group_output_channels() + oc]);
2047*4bdc9457SAndroid Build Coastguard Worker }
2048*4bdc9457SAndroid Build Coastguard Worker }
2049*4bdc9457SAndroid Build Coastguard Worker }
2050*4bdc9457SAndroid Build Coastguard Worker }
2051*4bdc9457SAndroid Build Coastguard Worker }
2052*4bdc9457SAndroid Build Coastguard Worker } else {
2053*4bdc9457SAndroid Build Coastguard Worker std::fill(next_output_ref.begin(), next_output_ref.end(), 0);
2054*4bdc9457SAndroid Build Coastguard Worker }
2055*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
2056*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
2057*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
2058*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
2059*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
2060*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
2061*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < next_input_height()) {
2062*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
2063*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
2064*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
2065*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < next_input_width()) {
2066*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2067*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
2068*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
2069*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2070*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic]) *
2071*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic]);
2072*4bdc9457SAndroid Build Coastguard Worker }
2073*4bdc9457SAndroid Build Coastguard Worker }
2074*4bdc9457SAndroid Build Coastguard Worker }
2075*4bdc9457SAndroid Build Coastguard Worker }
2076*4bdc9457SAndroid Build Coastguard Worker }
2077*4bdc9457SAndroid Build Coastguard Worker }
2078*4bdc9457SAndroid Build Coastguard Worker }
2079*4bdc9457SAndroid Build Coastguard Worker }
2080*4bdc9457SAndroid Build Coastguard Worker }
2081*4bdc9457SAndroid Build Coastguard Worker }
2082*4bdc9457SAndroid Build Coastguard Worker for (float& value : next_output_ref) {
2083*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
2084*4bdc9457SAndroid Build Coastguard Worker }
2085*4bdc9457SAndroid Build Coastguard Worker
2086*4bdc9457SAndroid Build Coastguard Worker // Setup and run Deconvolution operator the second time, and destroy the operator.
2087*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2088*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f16(
2089*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
2090*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(),
2091*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
2092*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
2093*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
2094*4bdc9457SAndroid Build Coastguard Worker
2095*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2096*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
2097*4bdc9457SAndroid Build Coastguard Worker
2098*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run.
2099*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
2100*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) {
2101*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) {
2102*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2103*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
2104*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), output_min)
2105*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2106*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]), output_max)
2107*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2108*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
2109*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c]),
2110*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
2111*4bdc9457SAndroid Build Coastguard Worker 1.0e-2f * std::abs(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c]))
2112*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2113*4bdc9457SAndroid Build Coastguard Worker }
2114*4bdc9457SAndroid Build Coastguard Worker }
2115*4bdc9457SAndroid Build Coastguard Worker }
2116*4bdc9457SAndroid Build Coastguard Worker }
2117*4bdc9457SAndroid Build Coastguard Worker }
2118*4bdc9457SAndroid Build Coastguard Worker }
2119*4bdc9457SAndroid Build Coastguard Worker }
2120*4bdc9457SAndroid Build Coastguard Worker
TestSetupF32()2121*4bdc9457SAndroid Build Coastguard Worker void TestSetupF32() const {
2122*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(weights_type(), WeightsType::Default);
2123*4bdc9457SAndroid Build Coastguard Worker
2124*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
2125*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
2126*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
2127*4bdc9457SAndroid Build Coastguard Worker
2128*4bdc9457SAndroid Build Coastguard Worker std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + std::max(
2129*4bdc9457SAndroid Build Coastguard Worker (batch_size() * input_height() * input_width() - 1) * input_pixel_stride() + groups() * group_input_channels(),
2130*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_input_height() * next_input_width() - 1) * input_pixel_stride() + groups() * group_input_channels()));
2131*4bdc9457SAndroid Build Coastguard Worker std::vector<float> kernel(groups() * group_output_channels() * kernel_height() * kernel_width() * group_input_channels());
2132*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(groups() * group_output_channels());
2133*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output(std::max(
2134*4bdc9457SAndroid Build Coastguard Worker (batch_size() * output_height() * output_width() - 1) * output_pixel_stride() + groups() * group_output_channels(),
2135*4bdc9457SAndroid Build Coastguard Worker (next_batch_size() * next_output_height() * next_output_width() - 1) * output_pixel_stride() + groups() * group_output_channels()));
2136*4bdc9457SAndroid Build Coastguard Worker std::vector<float> output_ref(batch_size() * output_height() * output_width() * groups() * group_output_channels());
2137*4bdc9457SAndroid Build Coastguard Worker std::vector<float> next_output_ref(next_batch_size() * next_output_height() * next_output_width() * groups() * group_output_channels());
2138*4bdc9457SAndroid Build Coastguard Worker
2139*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
2140*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
2141*4bdc9457SAndroid Build Coastguard Worker std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
2142*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
2143*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf(""));
2144*4bdc9457SAndroid Build Coastguard Worker
2145*4bdc9457SAndroid Build Coastguard Worker // Compute reference results, without clamping.
2146*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
2147*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
2148*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
2149*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
2150*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2151*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
2152*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2153*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
2154*4bdc9457SAndroid Build Coastguard Worker }
2155*4bdc9457SAndroid Build Coastguard Worker }
2156*4bdc9457SAndroid Build Coastguard Worker }
2157*4bdc9457SAndroid Build Coastguard Worker }
2158*4bdc9457SAndroid Build Coastguard Worker }
2159*4bdc9457SAndroid Build Coastguard Worker } else {
2160*4bdc9457SAndroid Build Coastguard Worker std::fill(output_ref.begin(), output_ref.end(), 0.0f);
2161*4bdc9457SAndroid Build Coastguard Worker }
2162*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
2163*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < output_height(); oy++) {
2164*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < output_width(); ox++) {
2165*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
2166*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
2167*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
2168*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < input_height()) {
2169*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
2170*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
2171*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
2172*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < input_width()) {
2173*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2174*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
2175*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
2176*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + oy) * output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2177*4bdc9457SAndroid Build Coastguard Worker input[((i * input_height() + iy) * input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic] *
2178*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
2179*4bdc9457SAndroid Build Coastguard Worker }
2180*4bdc9457SAndroid Build Coastguard Worker }
2181*4bdc9457SAndroid Build Coastguard Worker }
2182*4bdc9457SAndroid Build Coastguard Worker }
2183*4bdc9457SAndroid Build Coastguard Worker }
2184*4bdc9457SAndroid Build Coastguard Worker }
2185*4bdc9457SAndroid Build Coastguard Worker }
2186*4bdc9457SAndroid Build Coastguard Worker }
2187*4bdc9457SAndroid Build Coastguard Worker }
2188*4bdc9457SAndroid Build Coastguard Worker }
2189*4bdc9457SAndroid Build Coastguard Worker
2190*4bdc9457SAndroid Build Coastguard Worker // Compute clamping parameters.
2191*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
2192*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
2193*4bdc9457SAndroid Build Coastguard Worker
2194*4bdc9457SAndroid Build Coastguard Worker const float output_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
2195*4bdc9457SAndroid Build Coastguard Worker const float output_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
2196*4bdc9457SAndroid Build Coastguard Worker
2197*4bdc9457SAndroid Build Coastguard Worker // Clamp reference results.
2198*4bdc9457SAndroid Build Coastguard Worker for (float& value : output_ref) {
2199*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
2200*4bdc9457SAndroid Build Coastguard Worker }
2201*4bdc9457SAndroid Build Coastguard Worker
2202*4bdc9457SAndroid Build Coastguard Worker // Create, setup, and run Deconvolution operator once.
2203*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
2204*4bdc9457SAndroid Build Coastguard Worker xnn_operator_t deconvolution_op = nullptr;
2205*4bdc9457SAndroid Build Coastguard Worker
2206*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2207*4bdc9457SAndroid Build Coastguard Worker xnn_create_deconvolution2d_nhwc_f32(
2208*4bdc9457SAndroid Build Coastguard Worker padding_top(), padding_right(), padding_bottom(), padding_left(),
2209*4bdc9457SAndroid Build Coastguard Worker kernel_height(), kernel_width(),
2210*4bdc9457SAndroid Build Coastguard Worker stride_height(), stride_width(),
2211*4bdc9457SAndroid Build Coastguard Worker dilation_height(), dilation_width(),
2212*4bdc9457SAndroid Build Coastguard Worker groups(), group_input_channels(), group_output_channels(),
2213*4bdc9457SAndroid Build Coastguard Worker input_pixel_stride(), output_pixel_stride(),
2214*4bdc9457SAndroid Build Coastguard Worker kernel.data(), has_bias() ? bias.data() : nullptr,
2215*4bdc9457SAndroid Build Coastguard Worker output_min, output_max,
2216*4bdc9457SAndroid Build Coastguard Worker 0, NULL, &deconvolution_op));
2217*4bdc9457SAndroid Build Coastguard Worker
2218*4bdc9457SAndroid Build Coastguard Worker // Smart pointer to automatically delete deconvolution_op.
2219*4bdc9457SAndroid Build Coastguard Worker std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_deconvolution_op(deconvolution_op, xnn_delete_operator);
2220*4bdc9457SAndroid Build Coastguard Worker
2221*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2222*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f32(
2223*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
2224*4bdc9457SAndroid Build Coastguard Worker batch_size(), input_height(), input_width(),
2225*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
2226*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
2227*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
2228*4bdc9457SAndroid Build Coastguard Worker
2229*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2230*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
2231*4bdc9457SAndroid Build Coastguard Worker
2232*4bdc9457SAndroid Build Coastguard Worker // Verify results of the first run.
2233*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < batch_size(); i++) {
2234*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < output_height(); y++) {
2235*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < output_width(); x++) {
2236*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2237*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
2238*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c], output_min)
2239*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2240*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c], output_max)
2241*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2242*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
2243*4bdc9457SAndroid Build Coastguard Worker output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c],
2244*4bdc9457SAndroid Build Coastguard Worker output[((i * output_height() + y) * output_width() + x) * output_pixel_stride() + g * group_output_channels() + c],
2245*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(output_ref[(((i * output_height() + y) * output_width() + x) * groups() + g) * group_output_channels() + c]))
2246*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2247*4bdc9457SAndroid Build Coastguard Worker }
2248*4bdc9457SAndroid Build Coastguard Worker }
2249*4bdc9457SAndroid Build Coastguard Worker }
2250*4bdc9457SAndroid Build Coastguard Worker }
2251*4bdc9457SAndroid Build Coastguard Worker }
2252*4bdc9457SAndroid Build Coastguard Worker
2253*4bdc9457SAndroid Build Coastguard Worker // Re-generate data for the second run.
2254*4bdc9457SAndroid Build Coastguard Worker std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
2255*4bdc9457SAndroid Build Coastguard Worker std::fill(output.begin(), output.end(), nanf(""));
2256*4bdc9457SAndroid Build Coastguard Worker
2257*4bdc9457SAndroid Build Coastguard Worker // Compute reference results for the second run, including clamping.
2258*4bdc9457SAndroid Build Coastguard Worker if (has_bias()) {
2259*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
2260*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
2261*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
2262*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2263*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
2264*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] =
2265*4bdc9457SAndroid Build Coastguard Worker bias[g * group_output_channels() + oc];
2266*4bdc9457SAndroid Build Coastguard Worker }
2267*4bdc9457SAndroid Build Coastguard Worker }
2268*4bdc9457SAndroid Build Coastguard Worker }
2269*4bdc9457SAndroid Build Coastguard Worker }
2270*4bdc9457SAndroid Build Coastguard Worker }
2271*4bdc9457SAndroid Build Coastguard Worker } else {
2272*4bdc9457SAndroid Build Coastguard Worker std::fill(next_output_ref.begin(), next_output_ref.end(), 0.0f);
2273*4bdc9457SAndroid Build Coastguard Worker }
2274*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
2275*4bdc9457SAndroid Build Coastguard Worker for (size_t oy = 0; oy < next_output_height(); oy++) {
2276*4bdc9457SAndroid Build Coastguard Worker for (size_t ox = 0; ox < next_output_width(); ox++) {
2277*4bdc9457SAndroid Build Coastguard Worker for (size_t ky = 0; ky < kernel_height(); ky++) {
2278*4bdc9457SAndroid Build Coastguard Worker const size_t y = oy + padding_top() - ky * dilation_height();
2279*4bdc9457SAndroid Build Coastguard Worker const size_t iy = y / stride_height();
2280*4bdc9457SAndroid Build Coastguard Worker if (iy * stride_height() == y && iy < next_input_height()) {
2281*4bdc9457SAndroid Build Coastguard Worker for (size_t kx = 0; kx < kernel_width(); kx++) {
2282*4bdc9457SAndroid Build Coastguard Worker const size_t x = ox + padding_left() - kx * dilation_width();
2283*4bdc9457SAndroid Build Coastguard Worker const size_t ix = x / stride_width();
2284*4bdc9457SAndroid Build Coastguard Worker if (ix * stride_width() == x && ix < next_input_width()) {
2285*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2286*4bdc9457SAndroid Build Coastguard Worker for (size_t oc = 0; oc < group_output_channels(); oc++) {
2287*4bdc9457SAndroid Build Coastguard Worker for (size_t ic = 0; ic < group_input_channels(); ic++) {
2288*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + oy) * next_output_width() + ox) * groups() + g) * group_output_channels() + oc] +=
2289*4bdc9457SAndroid Build Coastguard Worker input[((i * next_input_height() + iy) * next_input_width() + ix) * input_pixel_stride() + g * group_input_channels() + ic] *
2290*4bdc9457SAndroid Build Coastguard Worker kernel[(((g * group_output_channels() + oc) * kernel_height() + ky) * kernel_width() + kx) * group_input_channels() + ic];
2291*4bdc9457SAndroid Build Coastguard Worker }
2292*4bdc9457SAndroid Build Coastguard Worker }
2293*4bdc9457SAndroid Build Coastguard Worker }
2294*4bdc9457SAndroid Build Coastguard Worker }
2295*4bdc9457SAndroid Build Coastguard Worker }
2296*4bdc9457SAndroid Build Coastguard Worker }
2297*4bdc9457SAndroid Build Coastguard Worker }
2298*4bdc9457SAndroid Build Coastguard Worker }
2299*4bdc9457SAndroid Build Coastguard Worker }
2300*4bdc9457SAndroid Build Coastguard Worker }
2301*4bdc9457SAndroid Build Coastguard Worker for (float& value : next_output_ref) {
2302*4bdc9457SAndroid Build Coastguard Worker value = std::max(std::min(value, output_max), output_min);
2303*4bdc9457SAndroid Build Coastguard Worker }
2304*4bdc9457SAndroid Build Coastguard Worker
2305*4bdc9457SAndroid Build Coastguard Worker // Setup and run Deconvolution operator the second time, and destroy the operator.
2306*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2307*4bdc9457SAndroid Build Coastguard Worker xnn_setup_deconvolution2d_nhwc_f32(
2308*4bdc9457SAndroid Build Coastguard Worker deconvolution_op,
2309*4bdc9457SAndroid Build Coastguard Worker next_batch_size(), next_input_height(), next_input_width(),
2310*4bdc9457SAndroid Build Coastguard Worker adjustment_height(), adjustment_width(),
2311*4bdc9457SAndroid Build Coastguard Worker input.data(), output.data(),
2312*4bdc9457SAndroid Build Coastguard Worker nullptr /* thread pool */));
2313*4bdc9457SAndroid Build Coastguard Worker
2314*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
2315*4bdc9457SAndroid Build Coastguard Worker xnn_run_operator(deconvolution_op, nullptr /* thread pool */));
2316*4bdc9457SAndroid Build Coastguard Worker
2317*4bdc9457SAndroid Build Coastguard Worker // Verify results of the second run.
2318*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < next_batch_size(); i++) {
2319*4bdc9457SAndroid Build Coastguard Worker for (size_t y = 0; y < next_output_height(); y++) {
2320*4bdc9457SAndroid Build Coastguard Worker for (size_t x = 0; x < next_output_width(); x++) {
2321*4bdc9457SAndroid Build Coastguard Worker for (size_t g = 0; g < groups(); g++) {
2322*4bdc9457SAndroid Build Coastguard Worker for (size_t c = 0; c < group_output_channels(); c++) {
2323*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c], output_min)
2324*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2325*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c], output_max)
2326*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2327*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
2328*4bdc9457SAndroid Build Coastguard Worker next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c],
2329*4bdc9457SAndroid Build Coastguard Worker output[((i * next_output_height() + y) * next_output_width() + x) * output_pixel_stride() + g * group_output_channels() + c],
2330*4bdc9457SAndroid Build Coastguard Worker 1.0e-4 * std::abs(next_output_ref[(((i * next_output_height() + y) * next_output_width() + x) * groups() + g) * group_output_channels() + c]))
2331*4bdc9457SAndroid Build Coastguard Worker << "(x, y) = (" << x << ", " << y << "), group = " << g << ", channel = " << c;
2332*4bdc9457SAndroid Build Coastguard Worker }
2333*4bdc9457SAndroid Build Coastguard Worker }
2334*4bdc9457SAndroid Build Coastguard Worker }
2335*4bdc9457SAndroid Build Coastguard Worker }
2336*4bdc9457SAndroid Build Coastguard Worker }
2337*4bdc9457SAndroid Build Coastguard Worker }
2338*4bdc9457SAndroid Build Coastguard Worker }
2339*4bdc9457SAndroid Build Coastguard Worker
2340*4bdc9457SAndroid Build Coastguard Worker private:  // Test-case configuration. Every field below has an in-class default; the accessors/setters that read and write them are not visible in this part of the file.
2341*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_top_{0};  // The four edge paddings all default to 0.
2342*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_right_{0};
2343*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_bottom_{0};
2344*4bdc9457SAndroid Build Coastguard Worker uint32_t padding_left_{0};
2345*4bdc9457SAndroid Build Coastguard Worker size_t input_height_{1};  // Input spatial size defaults to 1x1.
2346*4bdc9457SAndroid Build Coastguard Worker size_t input_width_{1};
2347*4bdc9457SAndroid Build Coastguard Worker uint32_t groups_{1};  // Defaults to a single group.
2348*4bdc9457SAndroid Build Coastguard Worker size_t group_input_channels_{1};  // Going by the name, input channels per group; default 1.
2349*4bdc9457SAndroid Build Coastguard Worker size_t input_pixel_stride_{0};  // Default 0. NOTE(review): 0 looks like a "not set" sentinel resolved by the accessor — confirm.
2350*4bdc9457SAndroid Build Coastguard Worker size_t group_output_channels_{1};  // Going by the name, output channels per group; default 1.
2351*4bdc9457SAndroid Build Coastguard Worker size_t output_pixel_stride_{0};  // Default 0. NOTE(review): same apparent "not set" sentinel as input_pixel_stride_ — confirm.
2352*4bdc9457SAndroid Build Coastguard Worker size_t batch_size_{1};  // Default batch of 1.
2353*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_height_{1};  // Kernel size defaults to 1x1.
2354*4bdc9457SAndroid Build Coastguard Worker uint32_t kernel_width_{1};
2355*4bdc9457SAndroid Build Coastguard Worker uint32_t adjustment_height_{0};  // Output-size adjustments default to 0.
2356*4bdc9457SAndroid Build Coastguard Worker uint32_t adjustment_width_{0};
2357*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_height_{1};  // Dilation defaults to 1 in both dimensions.
2358*4bdc9457SAndroid Build Coastguard Worker uint32_t dilation_width_{1};
2359*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_height_{1};  // Stride defaults to 1 in both dimensions.
2360*4bdc9457SAndroid Build Coastguard Worker uint32_t stride_width_{1};
2361*4bdc9457SAndroid Build Coastguard Worker size_t next_input_height_{0};  // Default 0. NOTE(review): presumably the second-run input size, with 0 meaning "reuse the first-run value" — confirm in next_input_height().
2362*4bdc9457SAndroid Build Coastguard Worker size_t next_input_width_{0};  // Default 0; see the review note on next_input_height_.
2363*4bdc9457SAndroid Build Coastguard Worker size_t next_batch_size_{0};  // Default 0. NOTE(review): presumably the second-run batch size, 0 meaning "reuse batch_size_" — confirm.
2364*4bdc9457SAndroid Build Coastguard Worker uint8_t qmin_{0};  // Defaults 0 and 255 together span the whole uint8_t range.
2365*4bdc9457SAndroid Build Coastguard Worker uint8_t qmax_{255};  // NOTE(review): qmin_/qmax_ presumably drive the output clamping bounds used by the tests — confirm.
2366*4bdc9457SAndroid Build Coastguard Worker bool has_bias_{true};  // Default true. The tests pass bias data to operator creation only when has_bias() is true; presumably that accessor returns this field — confirm.
2367*4bdc9457SAndroid Build Coastguard Worker WeightsType weights_type_{WeightsType::Default};  // Defaults to WeightsType::Default.
2368*4bdc9457SAndroid Build Coastguard Worker bool use_weights_cache_{false};  // Off by default.
2369*4bdc9457SAndroid Build Coastguard Worker bool stress_weights_cache_{false};  // Off by default.
2370*4bdc9457SAndroid Build Coastguard Worker size_t iterations_{1};  // Default 1. NOTE(review): presumably the number of times each test body repeats — confirm.
2371*4bdc9457SAndroid Build Coastguard Worker };
2372