xref: /aosp_15_r20/external/XNNPACK/test/constant-pad-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <vector>
18*4bdc9457SAndroid Build Coastguard Worker 
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker class ConstantPadOperatorTester {
23*4bdc9457SAndroid Build Coastguard Worker  public:
input_shape(std::initializer_list<size_t> input_shape)24*4bdc9457SAndroid Build Coastguard Worker   inline ConstantPadOperatorTester& input_shape(std::initializer_list<size_t> input_shape) {
25*4bdc9457SAndroid Build Coastguard Worker     assert(input_shape.size() <= XNN_MAX_TENSOR_DIMS);
26*4bdc9457SAndroid Build Coastguard Worker     input_shape_ = std::vector<size_t>(input_shape);
27*4bdc9457SAndroid Build Coastguard Worker     return *this;
28*4bdc9457SAndroid Build Coastguard Worker   }
29*4bdc9457SAndroid Build Coastguard Worker 
input_shape()30*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& input_shape() const {
31*4bdc9457SAndroid Build Coastguard Worker     return input_shape_;
32*4bdc9457SAndroid Build Coastguard Worker   }
33*4bdc9457SAndroid Build Coastguard Worker 
input_dim(size_t i)34*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_dim(size_t i) const {
35*4bdc9457SAndroid Build Coastguard Worker     return i < input_shape_.size() ? input_shape_[i] : 1;
36*4bdc9457SAndroid Build Coastguard Worker   }
37*4bdc9457SAndroid Build Coastguard Worker 
num_dims()38*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_dims() const {
39*4bdc9457SAndroid Build Coastguard Worker     return input_shape_.size();
40*4bdc9457SAndroid Build Coastguard Worker   }
41*4bdc9457SAndroid Build Coastguard Worker 
num_input_elements()42*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_input_elements() const {
43*4bdc9457SAndroid Build Coastguard Worker     return std::accumulate(
44*4bdc9457SAndroid Build Coastguard Worker       input_shape_.cbegin(), input_shape_.cend(), size_t(1), std::multiplies<size_t>());
45*4bdc9457SAndroid Build Coastguard Worker   }
46*4bdc9457SAndroid Build Coastguard Worker 
pre_paddings(std::initializer_list<size_t> pre_paddings)47*4bdc9457SAndroid Build Coastguard Worker   inline ConstantPadOperatorTester& pre_paddings(std::initializer_list<size_t> pre_paddings) {
48*4bdc9457SAndroid Build Coastguard Worker     assert(pre_paddings.size() <= XNN_MAX_TENSOR_DIMS);
49*4bdc9457SAndroid Build Coastguard Worker     pre_paddings_ = std::vector<size_t>(pre_paddings);
50*4bdc9457SAndroid Build Coastguard Worker     return *this;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
pre_paddings()53*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& pre_paddings() const {
54*4bdc9457SAndroid Build Coastguard Worker     return pre_paddings_;
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
pre_padding(size_t i)57*4bdc9457SAndroid Build Coastguard Worker   inline size_t pre_padding(size_t i) const {
58*4bdc9457SAndroid Build Coastguard Worker     return i < pre_paddings_.size() ? pre_paddings_[i] : 0;
59*4bdc9457SAndroid Build Coastguard Worker   }
60*4bdc9457SAndroid Build Coastguard Worker 
num_pre_paddings()61*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_pre_paddings() const {
62*4bdc9457SAndroid Build Coastguard Worker     return pre_paddings_.size();
63*4bdc9457SAndroid Build Coastguard Worker   }
64*4bdc9457SAndroid Build Coastguard Worker 
post_paddings(std::initializer_list<size_t> post_paddings)65*4bdc9457SAndroid Build Coastguard Worker   inline ConstantPadOperatorTester& post_paddings(std::initializer_list<size_t> post_paddings) {
66*4bdc9457SAndroid Build Coastguard Worker     assert(post_paddings.size() <= XNN_MAX_TENSOR_DIMS);
67*4bdc9457SAndroid Build Coastguard Worker     post_paddings_ = std::vector<size_t>(post_paddings);
68*4bdc9457SAndroid Build Coastguard Worker     return *this;
69*4bdc9457SAndroid Build Coastguard Worker   }
70*4bdc9457SAndroid Build Coastguard Worker 
post_paddings()71*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& post_paddings() const {
72*4bdc9457SAndroid Build Coastguard Worker     return post_paddings_;
73*4bdc9457SAndroid Build Coastguard Worker   }
74*4bdc9457SAndroid Build Coastguard Worker 
post_padding(size_t i)75*4bdc9457SAndroid Build Coastguard Worker   inline size_t post_padding(size_t i) const {
76*4bdc9457SAndroid Build Coastguard Worker     return i < post_paddings_.size() ? post_paddings_[i] : 0;
77*4bdc9457SAndroid Build Coastguard Worker   }
78*4bdc9457SAndroid Build Coastguard Worker 
num_post_paddings()79*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_post_paddings() const {
80*4bdc9457SAndroid Build Coastguard Worker     return post_paddings_.size();
81*4bdc9457SAndroid Build Coastguard Worker   }
82*4bdc9457SAndroid Build Coastguard Worker 
output_dim(size_t i)83*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_dim(size_t i) const {
84*4bdc9457SAndroid Build Coastguard Worker     return pre_padding(i) + input_dim(i) + post_padding(i);
85*4bdc9457SAndroid Build Coastguard Worker   }
86*4bdc9457SAndroid Build Coastguard Worker 
num_output_elements()87*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_output_elements() const {
88*4bdc9457SAndroid Build Coastguard Worker     size_t elements = 1;
89*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < num_dims(); i++) {
90*4bdc9457SAndroid Build Coastguard Worker       elements *= output_dim(i);
91*4bdc9457SAndroid Build Coastguard Worker     }
92*4bdc9457SAndroid Build Coastguard Worker     return elements;
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)95*4bdc9457SAndroid Build Coastguard Worker   inline ConstantPadOperatorTester& iterations(size_t iterations) {
96*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
97*4bdc9457SAndroid Build Coastguard Worker     return *this;
98*4bdc9457SAndroid Build Coastguard Worker   }
99*4bdc9457SAndroid Build Coastguard Worker 
iterations()100*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
101*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
102*4bdc9457SAndroid Build Coastguard Worker   }
103*4bdc9457SAndroid Build Coastguard Worker 
TestX8()104*4bdc9457SAndroid Build Coastguard Worker   void TestX8() const {
105*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(num_dims(), num_pre_paddings());
106*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(num_dims(), num_post_paddings());
107*4bdc9457SAndroid Build Coastguard Worker 
108*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
109*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
110*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
111*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
112*4bdc9457SAndroid Build Coastguard Worker 
113*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
114*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_dims;
115*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_pre_paddings;
116*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_post_paddings;
117*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
118*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_dims.begin(), input_dims.end(), 1);
119*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_pre_paddings.begin(), input_pre_paddings.end(), 0);
120*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_post_paddings.begin(), input_post_paddings.end(), 0);
121*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
122*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < num_dims(); i++) {
123*4bdc9457SAndroid Build Coastguard Worker       input_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = input_dim(i);
124*4bdc9457SAndroid Build Coastguard Worker       input_pre_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = pre_padding(i);
125*4bdc9457SAndroid Build Coastguard Worker       input_post_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = post_padding(i);
126*4bdc9457SAndroid Build Coastguard Worker       output_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = output_dim(i);
127*4bdc9457SAndroid Build Coastguard Worker     }
128*4bdc9457SAndroid Build Coastguard Worker 
129*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
130*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_strides;
131*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
132*4bdc9457SAndroid Build Coastguard Worker     size_t input_stride = 1, output_stride = 1;
133*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
134*4bdc9457SAndroid Build Coastguard Worker       input_strides[i - 1] = input_stride;
135*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
136*4bdc9457SAndroid Build Coastguard Worker       input_stride *= input_dims[i - 1];
137*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
138*4bdc9457SAndroid Build Coastguard Worker     }
139*4bdc9457SAndroid Build Coastguard Worker 
140*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + num_input_elements());
141*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(num_output_elements());
142*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output_ref(num_output_elements());
143*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
144*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
145*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT32_C(0xAA));
146*4bdc9457SAndroid Build Coastguard Worker       const uint8_t padding_value = u8dist(rng);
147*4bdc9457SAndroid Build Coastguard Worker 
148*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
149*4bdc9457SAndroid Build Coastguard Worker       std::fill(output_ref.begin(), output_ref.end(), padding_value);
150*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < input_dims[0]; i++) {
151*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < input_dims[1]; j++) {
152*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < input_dims[2]; k++) {
153*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < input_dims[3]; l++) {
154*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < input_dims[4]; m++) {
155*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < input_dims[5]; n++) {
156*4bdc9457SAndroid Build Coastguard Worker                   const size_t output_index =
157*4bdc9457SAndroid Build Coastguard Worker                     (i + input_pre_paddings[0]) * output_strides[0] +
158*4bdc9457SAndroid Build Coastguard Worker                     (j + input_pre_paddings[1]) * output_strides[1] +
159*4bdc9457SAndroid Build Coastguard Worker                     (k + input_pre_paddings[2]) * output_strides[2] +
160*4bdc9457SAndroid Build Coastguard Worker                     (l + input_pre_paddings[3]) * output_strides[3] +
161*4bdc9457SAndroid Build Coastguard Worker                     (m + input_pre_paddings[4]) * output_strides[4] +
162*4bdc9457SAndroid Build Coastguard Worker                     (n + input_pre_paddings[5]) * output_strides[5];
163*4bdc9457SAndroid Build Coastguard Worker                   const size_t input_index =
164*4bdc9457SAndroid Build Coastguard Worker                     i * input_strides[0] + j * input_strides[1] + k * input_strides[2] +
165*4bdc9457SAndroid Build Coastguard Worker                     l * input_strides[3] + m * input_strides[4] + n * input_strides[5];
166*4bdc9457SAndroid Build Coastguard Worker                   output_ref[output_index] = input[input_index];
167*4bdc9457SAndroid Build Coastguard Worker                 }
168*4bdc9457SAndroid Build Coastguard Worker               }
169*4bdc9457SAndroid Build Coastguard Worker             }
170*4bdc9457SAndroid Build Coastguard Worker           }
171*4bdc9457SAndroid Build Coastguard Worker         }
172*4bdc9457SAndroid Build Coastguard Worker       }
173*4bdc9457SAndroid Build Coastguard Worker 
174*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy a binary elementwise operator.
175*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
176*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t pad_op = nullptr;
177*4bdc9457SAndroid Build Coastguard Worker 
178*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
179*4bdc9457SAndroid Build Coastguard Worker         xnn_create_constant_pad_nd_x8(
180*4bdc9457SAndroid Build Coastguard Worker           &padding_value, 0, &pad_op));
181*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, pad_op);
182*4bdc9457SAndroid Build Coastguard Worker 
183*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete pad_op.
184*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_pad_op(pad_op, xnn_delete_operator);
185*4bdc9457SAndroid Build Coastguard Worker 
186*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
187*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_constant_pad_nd_x8(
188*4bdc9457SAndroid Build Coastguard Worker           pad_op,
189*4bdc9457SAndroid Build Coastguard Worker           num_dims(),
190*4bdc9457SAndroid Build Coastguard Worker           input_shape().data(), pre_paddings().data(), post_paddings().data(),
191*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
192*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
193*4bdc9457SAndroid Build Coastguard Worker 
194*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
195*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(pad_op, nullptr /* thread pool */));
196*4bdc9457SAndroid Build Coastguard Worker 
197*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
198*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
199*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
200*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
201*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
202*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
203*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
204*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
205*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] +
206*4bdc9457SAndroid Build Coastguard Worker                     l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
207*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_EQ(output[index], output_ref[index])
208*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = ("
209*4bdc9457SAndroid Build Coastguard Worker                     << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
210*4bdc9457SAndroid Build Coastguard Worker                     << ", padding value = " << padding_value;
211*4bdc9457SAndroid Build Coastguard Worker                 }
212*4bdc9457SAndroid Build Coastguard Worker               }
213*4bdc9457SAndroid Build Coastguard Worker             }
214*4bdc9457SAndroid Build Coastguard Worker           }
215*4bdc9457SAndroid Build Coastguard Worker         }
216*4bdc9457SAndroid Build Coastguard Worker       }
217*4bdc9457SAndroid Build Coastguard Worker     }
218*4bdc9457SAndroid Build Coastguard Worker   }
219*4bdc9457SAndroid Build Coastguard Worker 
TestX16()220*4bdc9457SAndroid Build Coastguard Worker   void TestX16() const {
221*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(num_dims(), num_pre_paddings());
222*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(num_dims(), num_post_paddings());
223*4bdc9457SAndroid Build Coastguard Worker 
224*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
225*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
226*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<uint16_t> u16dist;
227*4bdc9457SAndroid Build Coastguard Worker 
228*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
229*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_dims;
230*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_pre_paddings;
231*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_post_paddings;
232*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
233*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_dims.begin(), input_dims.end(), 1);
234*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_pre_paddings.begin(), input_pre_paddings.end(), 0);
235*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_post_paddings.begin(), input_post_paddings.end(), 0);
236*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
237*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < num_dims(); i++) {
238*4bdc9457SAndroid Build Coastguard Worker       input_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = input_dim(i);
239*4bdc9457SAndroid Build Coastguard Worker       input_pre_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = pre_padding(i);
240*4bdc9457SAndroid Build Coastguard Worker       input_post_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = post_padding(i);
241*4bdc9457SAndroid Build Coastguard Worker       output_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = output_dim(i);
242*4bdc9457SAndroid Build Coastguard Worker     }
243*4bdc9457SAndroid Build Coastguard Worker 
244*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
245*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_strides;
246*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
247*4bdc9457SAndroid Build Coastguard Worker     size_t input_stride = 1, output_stride = 1;
248*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
249*4bdc9457SAndroid Build Coastguard Worker       input_strides[i - 1] = input_stride;
250*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
251*4bdc9457SAndroid Build Coastguard Worker       input_stride *= input_dims[i - 1];
252*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
253*4bdc9457SAndroid Build Coastguard Worker     }
254*4bdc9457SAndroid Build Coastguard Worker 
255*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input_elements());
256*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(num_output_elements());
257*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output_ref(num_output_elements());
258*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
259*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u16dist(rng); });
260*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0xDEAD));
261*4bdc9457SAndroid Build Coastguard Worker       const uint16_t padding_value = u16dist(rng);
262*4bdc9457SAndroid Build Coastguard Worker 
263*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
264*4bdc9457SAndroid Build Coastguard Worker       std::fill(output_ref.begin(), output_ref.end(), padding_value);
265*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < input_dims[0]; i++) {
266*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < input_dims[1]; j++) {
267*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < input_dims[2]; k++) {
268*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < input_dims[3]; l++) {
269*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < input_dims[4]; m++) {
270*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < input_dims[5]; n++) {
271*4bdc9457SAndroid Build Coastguard Worker                   const size_t output_index =
272*4bdc9457SAndroid Build Coastguard Worker                     (i + input_pre_paddings[0]) * output_strides[0] +
273*4bdc9457SAndroid Build Coastguard Worker                     (j + input_pre_paddings[1]) * output_strides[1] +
274*4bdc9457SAndroid Build Coastguard Worker                     (k + input_pre_paddings[2]) * output_strides[2] +
275*4bdc9457SAndroid Build Coastguard Worker                     (l + input_pre_paddings[3]) * output_strides[3] +
276*4bdc9457SAndroid Build Coastguard Worker                     (m + input_pre_paddings[4]) * output_strides[4] +
277*4bdc9457SAndroid Build Coastguard Worker                     (n + input_pre_paddings[5]) * output_strides[5];
278*4bdc9457SAndroid Build Coastguard Worker                   const size_t input_index =
279*4bdc9457SAndroid Build Coastguard Worker                     i * input_strides[0] + j * input_strides[1] + k * input_strides[2] +
280*4bdc9457SAndroid Build Coastguard Worker                     l * input_strides[3] + m * input_strides[4] + n * input_strides[5];
281*4bdc9457SAndroid Build Coastguard Worker                   output_ref[output_index] = input[input_index];
282*4bdc9457SAndroid Build Coastguard Worker                 }
283*4bdc9457SAndroid Build Coastguard Worker               }
284*4bdc9457SAndroid Build Coastguard Worker             }
285*4bdc9457SAndroid Build Coastguard Worker           }
286*4bdc9457SAndroid Build Coastguard Worker         }
287*4bdc9457SAndroid Build Coastguard Worker       }
288*4bdc9457SAndroid Build Coastguard Worker 
289*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy a binary elementwise operator.
290*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
291*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t pad_op = nullptr;
292*4bdc9457SAndroid Build Coastguard Worker 
293*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
294*4bdc9457SAndroid Build Coastguard Worker         xnn_create_constant_pad_nd_x16(
295*4bdc9457SAndroid Build Coastguard Worker           &padding_value, 0, &pad_op));
296*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, pad_op);
297*4bdc9457SAndroid Build Coastguard Worker 
298*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete pad_op.
299*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_pad_op(pad_op, xnn_delete_operator);
300*4bdc9457SAndroid Build Coastguard Worker 
301*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
302*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_constant_pad_nd_x16(
303*4bdc9457SAndroid Build Coastguard Worker           pad_op,
304*4bdc9457SAndroid Build Coastguard Worker           num_dims(),
305*4bdc9457SAndroid Build Coastguard Worker           input_shape().data(), pre_paddings().data(), post_paddings().data(),
306*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
307*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
308*4bdc9457SAndroid Build Coastguard Worker 
309*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
310*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(pad_op, nullptr /* thread pool */));
311*4bdc9457SAndroid Build Coastguard Worker 
312*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
313*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
314*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
315*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
316*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
317*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
318*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
319*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
320*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] +
321*4bdc9457SAndroid Build Coastguard Worker                     l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
322*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_EQ(output[index], output_ref[index])
323*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = ("
324*4bdc9457SAndroid Build Coastguard Worker                     << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
325*4bdc9457SAndroid Build Coastguard Worker                     << ", padding value = " << padding_value;
326*4bdc9457SAndroid Build Coastguard Worker                 }
327*4bdc9457SAndroid Build Coastguard Worker               }
328*4bdc9457SAndroid Build Coastguard Worker             }
329*4bdc9457SAndroid Build Coastguard Worker           }
330*4bdc9457SAndroid Build Coastguard Worker         }
331*4bdc9457SAndroid Build Coastguard Worker       }
332*4bdc9457SAndroid Build Coastguard Worker     }
333*4bdc9457SAndroid Build Coastguard Worker   }
334*4bdc9457SAndroid Build Coastguard Worker 
TestX32()335*4bdc9457SAndroid Build Coastguard Worker   void TestX32() const {
336*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(num_dims(), num_pre_paddings());
337*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(num_dims(), num_post_paddings());
338*4bdc9457SAndroid Build Coastguard Worker 
339*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
340*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
341*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<uint32_t> u32dist;
342*4bdc9457SAndroid Build Coastguard Worker 
343*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
344*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_dims;
345*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_pre_paddings;
346*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_post_paddings;
347*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
348*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_dims.begin(), input_dims.end(), 1);
349*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_pre_paddings.begin(), input_pre_paddings.end(), 0);
350*4bdc9457SAndroid Build Coastguard Worker     std::fill(input_post_paddings.begin(), input_post_paddings.end(), 0);
351*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
352*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < num_dims(); i++) {
353*4bdc9457SAndroid Build Coastguard Worker       input_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = input_dim(i);
354*4bdc9457SAndroid Build Coastguard Worker       input_pre_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = pre_padding(i);
355*4bdc9457SAndroid Build Coastguard Worker       input_post_paddings[XNN_MAX_TENSOR_DIMS - num_dims() + i] = post_padding(i);
356*4bdc9457SAndroid Build Coastguard Worker       output_dims[XNN_MAX_TENSOR_DIMS - num_dims() + i] = output_dim(i);
357*4bdc9457SAndroid Build Coastguard Worker     }
358*4bdc9457SAndroid Build Coastguard Worker 
359*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
360*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input_strides;
361*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
362*4bdc9457SAndroid Build Coastguard Worker     size_t input_stride = 1, output_stride = 1;
363*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
364*4bdc9457SAndroid Build Coastguard Worker       input_strides[i - 1] = input_stride;
365*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
366*4bdc9457SAndroid Build Coastguard Worker       input_stride *= input_dims[i - 1];
367*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
368*4bdc9457SAndroid Build Coastguard Worker     }
369*4bdc9457SAndroid Build Coastguard Worker 
370*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> input(XNN_EXTRA_BYTES / sizeof(uint32_t) + num_input_elements());
371*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> output(num_output_elements());
372*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> output_ref(num_output_elements());
373*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
374*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u32dist(rng); });
375*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT32_C(0xDEADBEEF));
376*4bdc9457SAndroid Build Coastguard Worker       const uint32_t padding_value = u32dist(rng);
377*4bdc9457SAndroid Build Coastguard Worker 
378*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
379*4bdc9457SAndroid Build Coastguard Worker       std::fill(output_ref.begin(), output_ref.end(), padding_value);
380*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < input_dims[0]; i++) {
381*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < input_dims[1]; j++) {
382*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < input_dims[2]; k++) {
383*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < input_dims[3]; l++) {
384*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < input_dims[4]; m++) {
385*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < input_dims[5]; n++) {
386*4bdc9457SAndroid Build Coastguard Worker                   const size_t output_index =
387*4bdc9457SAndroid Build Coastguard Worker                     (i + input_pre_paddings[0]) * output_strides[0] +
388*4bdc9457SAndroid Build Coastguard Worker                     (j + input_pre_paddings[1]) * output_strides[1] +
389*4bdc9457SAndroid Build Coastguard Worker                     (k + input_pre_paddings[2]) * output_strides[2] +
390*4bdc9457SAndroid Build Coastguard Worker                     (l + input_pre_paddings[3]) * output_strides[3] +
391*4bdc9457SAndroid Build Coastguard Worker                     (m + input_pre_paddings[4]) * output_strides[4] +
392*4bdc9457SAndroid Build Coastguard Worker                     (n + input_pre_paddings[5]) * output_strides[5];
393*4bdc9457SAndroid Build Coastguard Worker                   const size_t input_index =
394*4bdc9457SAndroid Build Coastguard Worker                     i * input_strides[0] + j * input_strides[1] + k * input_strides[2] +
395*4bdc9457SAndroid Build Coastguard Worker                     l * input_strides[3] + m * input_strides[4] + n * input_strides[5];
396*4bdc9457SAndroid Build Coastguard Worker                   output_ref[output_index] = input[input_index];
397*4bdc9457SAndroid Build Coastguard Worker                 }
398*4bdc9457SAndroid Build Coastguard Worker               }
399*4bdc9457SAndroid Build Coastguard Worker             }
400*4bdc9457SAndroid Build Coastguard Worker           }
401*4bdc9457SAndroid Build Coastguard Worker         }
402*4bdc9457SAndroid Build Coastguard Worker       }
403*4bdc9457SAndroid Build Coastguard Worker 
404*4bdc9457SAndroid Build Coastguard Worker       // Create, set up, run, and destroy a constant pad operator.
405*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
406*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t pad_op = nullptr;
407*4bdc9457SAndroid Build Coastguard Worker 
408*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
409*4bdc9457SAndroid Build Coastguard Worker         xnn_create_constant_pad_nd_x32(
410*4bdc9457SAndroid Build Coastguard Worker           &padding_value, 0, &pad_op));
411*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, pad_op);
412*4bdc9457SAndroid Build Coastguard Worker 
413*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete pad_op.
414*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_pad_op(pad_op, xnn_delete_operator);
415*4bdc9457SAndroid Build Coastguard Worker 
416*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
417*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_constant_pad_nd_x32(
418*4bdc9457SAndroid Build Coastguard Worker           pad_op,
419*4bdc9457SAndroid Build Coastguard Worker           num_dims(),
420*4bdc9457SAndroid Build Coastguard Worker           input_shape().data(), pre_paddings().data(), post_paddings().data(),
421*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
422*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
423*4bdc9457SAndroid Build Coastguard Worker 
424*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
425*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(pad_op, nullptr /* thread pool */));
426*4bdc9457SAndroid Build Coastguard Worker 
427*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
428*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
429*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
430*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
431*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
432*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
433*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
434*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
435*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] +
436*4bdc9457SAndroid Build Coastguard Worker                     l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
437*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_EQ(output[index], output_ref[index])
438*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = ("
439*4bdc9457SAndroid Build Coastguard Worker                     << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
440*4bdc9457SAndroid Build Coastguard Worker                     << ", padding value = " << padding_value;
441*4bdc9457SAndroid Build Coastguard Worker                 }
442*4bdc9457SAndroid Build Coastguard Worker               }
443*4bdc9457SAndroid Build Coastguard Worker             }
444*4bdc9457SAndroid Build Coastguard Worker           }
445*4bdc9457SAndroid Build Coastguard Worker         }
446*4bdc9457SAndroid Build Coastguard Worker       }
447*4bdc9457SAndroid Build Coastguard Worker     }
448*4bdc9457SAndroid Build Coastguard Worker   }
449*4bdc9457SAndroid Build Coastguard Worker 
451*4bdc9457SAndroid Build Coastguard Worker  private:
452*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input_shape_;  // Input dimensions; at most XNN_MAX_TENSOR_DIMS entries (asserted in input_shape()).
453*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> pre_paddings_;  // Per-dimension padding placed before the input data (presumably read via pre_padding(i) / pre_paddings()).
454*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> post_paddings_;  // Per-dimension padding placed after the input data (presumably read via post_padding(i) / post_paddings()).
455*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{3};  // Number of randomized test iterations; defaults to 3 (presumably read via iterations()).
455*4bdc9457SAndroid Build Coastguard Worker };
456