xref: /aosp_15_r20/external/XNNPACK/test/binary-elementwise-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
11*4bdc9457SAndroid Build Coastguard Worker #include <array>
12*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
15*4bdc9457SAndroid Build Coastguard Worker #include <initializer_list>
16*4bdc9457SAndroid Build Coastguard Worker #include <limits>
17*4bdc9457SAndroid Build Coastguard Worker #include <numeric>
18*4bdc9457SAndroid Build Coastguard Worker #include <random>
19*4bdc9457SAndroid Build Coastguard Worker #include <vector>
20*4bdc9457SAndroid Build Coastguard Worker 
21*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
24*4bdc9457SAndroid Build Coastguard Worker 
25*4bdc9457SAndroid Build Coastguard Worker 
26*4bdc9457SAndroid Build Coastguard Worker class BinaryElementwiseOperatorTester {
27*4bdc9457SAndroid Build Coastguard Worker  public:
// Binary operation exercised by the tester. Compute() maps each enumerator
// to its scalar reference formula; Unknown is rejected by the Test* methods.
enum class OperationType {
  Unknown,            // No operation selected.
  Add,                // a + b
  Divide,             // a / b
  Maximum,            // max(a, b)
  Minimum,            // min(a, b)
  Multiply,           // a * b
  Subtract,           // a - b
  SquaredDifference,  // (a - b) * (a - b)
};
38*4bdc9457SAndroid Build Coastguard Worker 
input1_shape(std::initializer_list<size_t> input1_shape)39*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& input1_shape(std::initializer_list<size_t> input1_shape) {
40*4bdc9457SAndroid Build Coastguard Worker     assert(input1_shape.size() <= XNN_MAX_TENSOR_DIMS);
41*4bdc9457SAndroid Build Coastguard Worker     this->input1_shape_ = std::vector<size_t>(input1_shape);
42*4bdc9457SAndroid Build Coastguard Worker     return *this;
43*4bdc9457SAndroid Build Coastguard Worker   }
44*4bdc9457SAndroid Build Coastguard Worker 
input1_shape()45*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& input1_shape() const {
46*4bdc9457SAndroid Build Coastguard Worker     return this->input1_shape_;
47*4bdc9457SAndroid Build Coastguard Worker   }
48*4bdc9457SAndroid Build Coastguard Worker 
input1_dim(size_t i)49*4bdc9457SAndroid Build Coastguard Worker   inline size_t input1_dim(size_t i) const {
50*4bdc9457SAndroid Build Coastguard Worker     return i < num_input1_dims() ? this->input1_shape_[i] : 1;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
num_input1_dims()53*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_input1_dims() const {
54*4bdc9457SAndroid Build Coastguard Worker     return this->input1_shape_.size();
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
num_input1_elements()57*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_input1_elements() const {
58*4bdc9457SAndroid Build Coastguard Worker     return std::accumulate(
59*4bdc9457SAndroid Build Coastguard Worker       this->input1_shape_.begin(), this->input1_shape_.end(), size_t(1), std::multiplies<size_t>());
60*4bdc9457SAndroid Build Coastguard Worker   }
61*4bdc9457SAndroid Build Coastguard Worker 
input1_zero_point(int16_t input1_zero_point)62*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& input1_zero_point(int16_t input1_zero_point) {
63*4bdc9457SAndroid Build Coastguard Worker     this->input1_zero_point_ = input1_zero_point;
64*4bdc9457SAndroid Build Coastguard Worker     return *this;
65*4bdc9457SAndroid Build Coastguard Worker   }
66*4bdc9457SAndroid Build Coastguard Worker 
input1_zero_point()67*4bdc9457SAndroid Build Coastguard Worker   inline int16_t input1_zero_point() const {
68*4bdc9457SAndroid Build Coastguard Worker     return this->input1_zero_point_;
69*4bdc9457SAndroid Build Coastguard Worker   }
70*4bdc9457SAndroid Build Coastguard Worker 
input1_scale(float input1_scale)71*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& input1_scale(float input1_scale) {
72*4bdc9457SAndroid Build Coastguard Worker     assert(std::isfinite(input1_scale));
73*4bdc9457SAndroid Build Coastguard Worker     this->input1_scale_ = input1_scale;
74*4bdc9457SAndroid Build Coastguard Worker     return *this;
75*4bdc9457SAndroid Build Coastguard Worker   }
76*4bdc9457SAndroid Build Coastguard Worker 
input1_scale()77*4bdc9457SAndroid Build Coastguard Worker   inline float input1_scale() const {
78*4bdc9457SAndroid Build Coastguard Worker     return this->input1_scale_;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
input2_shape(std::initializer_list<size_t> input2_shape)81*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& input2_shape(std::initializer_list<size_t> input2_shape) {
82*4bdc9457SAndroid Build Coastguard Worker     assert(input2_shape.size() <= XNN_MAX_TENSOR_DIMS);
83*4bdc9457SAndroid Build Coastguard Worker     this->input2_shape_ = std::vector<size_t>(input2_shape);
84*4bdc9457SAndroid Build Coastguard Worker     return *this;
85*4bdc9457SAndroid Build Coastguard Worker   }
86*4bdc9457SAndroid Build Coastguard Worker 
input2_shape()87*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& input2_shape() const {
88*4bdc9457SAndroid Build Coastguard Worker     return this->input2_shape_;
89*4bdc9457SAndroid Build Coastguard Worker   }
90*4bdc9457SAndroid Build Coastguard Worker 
input2_dim(size_t i)91*4bdc9457SAndroid Build Coastguard Worker   inline size_t input2_dim(size_t i) const {
92*4bdc9457SAndroid Build Coastguard Worker     return i < num_input2_dims() ? this->input2_shape_[i] : 1;
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
num_input2_dims()95*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_input2_dims() const {
96*4bdc9457SAndroid Build Coastguard Worker     return this->input2_shape_.size();
97*4bdc9457SAndroid Build Coastguard Worker   }
98*4bdc9457SAndroid Build Coastguard Worker 
num_input2_elements()99*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_input2_elements() const {
100*4bdc9457SAndroid Build Coastguard Worker     return std::accumulate(
101*4bdc9457SAndroid Build Coastguard Worker       this->input2_shape_.begin(), this->input2_shape_.end(), size_t(1), std::multiplies<size_t>());
102*4bdc9457SAndroid Build Coastguard Worker   }
103*4bdc9457SAndroid Build Coastguard Worker 
input2_zero_point(int16_t input2_zero_point)104*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& input2_zero_point(int16_t input2_zero_point) {
105*4bdc9457SAndroid Build Coastguard Worker     this->input2_zero_point_ = input2_zero_point;
106*4bdc9457SAndroid Build Coastguard Worker     return *this;
107*4bdc9457SAndroid Build Coastguard Worker   }
108*4bdc9457SAndroid Build Coastguard Worker 
input2_zero_point()109*4bdc9457SAndroid Build Coastguard Worker   inline int16_t input2_zero_point() const {
110*4bdc9457SAndroid Build Coastguard Worker     return this->input2_zero_point_;
111*4bdc9457SAndroid Build Coastguard Worker   }
112*4bdc9457SAndroid Build Coastguard Worker 
input2_scale(float input2_scale)113*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& input2_scale(float input2_scale) {
114*4bdc9457SAndroid Build Coastguard Worker     assert(std::isfinite(input2_scale));
115*4bdc9457SAndroid Build Coastguard Worker     this->input2_scale_ = input2_scale;
116*4bdc9457SAndroid Build Coastguard Worker     return *this;
117*4bdc9457SAndroid Build Coastguard Worker   }
118*4bdc9457SAndroid Build Coastguard Worker 
input2_scale()119*4bdc9457SAndroid Build Coastguard Worker   inline float input2_scale() const {
120*4bdc9457SAndroid Build Coastguard Worker     return this->input2_scale_;
121*4bdc9457SAndroid Build Coastguard Worker   }
122*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point(int16_t output_zero_point)123*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& output_zero_point(int16_t output_zero_point) {
124*4bdc9457SAndroid Build Coastguard Worker     this->output_zero_point_ = output_zero_point;
125*4bdc9457SAndroid Build Coastguard Worker     return *this;
126*4bdc9457SAndroid Build Coastguard Worker   }
127*4bdc9457SAndroid Build Coastguard Worker 
output_zero_point()128*4bdc9457SAndroid Build Coastguard Worker   inline int16_t output_zero_point() const {
129*4bdc9457SAndroid Build Coastguard Worker     return this->output_zero_point_;
130*4bdc9457SAndroid Build Coastguard Worker   }
131*4bdc9457SAndroid Build Coastguard Worker 
output_scale(float output_scale)132*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& output_scale(float output_scale) {
133*4bdc9457SAndroid Build Coastguard Worker     assert(std::isfinite(output_scale));
134*4bdc9457SAndroid Build Coastguard Worker     this->output_scale_ = output_scale;
135*4bdc9457SAndroid Build Coastguard Worker     return *this;
136*4bdc9457SAndroid Build Coastguard Worker   }
137*4bdc9457SAndroid Build Coastguard Worker 
output_scale()138*4bdc9457SAndroid Build Coastguard Worker   inline float output_scale() const {
139*4bdc9457SAndroid Build Coastguard Worker     return this->output_scale_;
140*4bdc9457SAndroid Build Coastguard Worker   }
141*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)142*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& qmin(uint8_t qmin) {
143*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
144*4bdc9457SAndroid Build Coastguard Worker     return *this;
145*4bdc9457SAndroid Build Coastguard Worker   }
146*4bdc9457SAndroid Build Coastguard Worker 
qmin()147*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
148*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
149*4bdc9457SAndroid Build Coastguard Worker   }
150*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)151*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& qmax(uint8_t qmax) {
152*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
153*4bdc9457SAndroid Build Coastguard Worker     return *this;
154*4bdc9457SAndroid Build Coastguard Worker   }
155*4bdc9457SAndroid Build Coastguard Worker 
qmax()156*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
157*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
158*4bdc9457SAndroid Build Coastguard Worker   }
159*4bdc9457SAndroid Build Coastguard Worker 
operation_type(OperationType operation_type)160*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& operation_type(OperationType operation_type) {
161*4bdc9457SAndroid Build Coastguard Worker     this->operation_type_ = operation_type;
162*4bdc9457SAndroid Build Coastguard Worker     return *this;
163*4bdc9457SAndroid Build Coastguard Worker   }
164*4bdc9457SAndroid Build Coastguard Worker 
operation_type()165*4bdc9457SAndroid Build Coastguard Worker   inline OperationType operation_type() const {
166*4bdc9457SAndroid Build Coastguard Worker     return this->operation_type_;
167*4bdc9457SAndroid Build Coastguard Worker   }
168*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)169*4bdc9457SAndroid Build Coastguard Worker   inline BinaryElementwiseOperatorTester& iterations(size_t iterations) {
170*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
171*4bdc9457SAndroid Build Coastguard Worker     return *this;
172*4bdc9457SAndroid Build Coastguard Worker   }
173*4bdc9457SAndroid Build Coastguard Worker 
iterations()174*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
175*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
176*4bdc9457SAndroid Build Coastguard Worker   }
177*4bdc9457SAndroid Build Coastguard Worker 
Compute(float a,float b)178*4bdc9457SAndroid Build Coastguard Worker   float Compute(float a, float b) const {
179*4bdc9457SAndroid Build Coastguard Worker     switch (operation_type()) {
180*4bdc9457SAndroid Build Coastguard Worker       case OperationType::Add:
181*4bdc9457SAndroid Build Coastguard Worker         return a + b;
182*4bdc9457SAndroid Build Coastguard Worker       case OperationType::Divide:
183*4bdc9457SAndroid Build Coastguard Worker         return a / b;
184*4bdc9457SAndroid Build Coastguard Worker       case OperationType::Maximum:
185*4bdc9457SAndroid Build Coastguard Worker         return std::max<float>(a, b);
186*4bdc9457SAndroid Build Coastguard Worker       case OperationType::Minimum:
187*4bdc9457SAndroid Build Coastguard Worker         return std::min<float>(a, b);
188*4bdc9457SAndroid Build Coastguard Worker       case OperationType::Multiply:
189*4bdc9457SAndroid Build Coastguard Worker         return a * b;
190*4bdc9457SAndroid Build Coastguard Worker       case OperationType::Subtract:
191*4bdc9457SAndroid Build Coastguard Worker         return a - b;
192*4bdc9457SAndroid Build Coastguard Worker       case OperationType::SquaredDifference:
193*4bdc9457SAndroid Build Coastguard Worker         return (a - b) * (a - b);
194*4bdc9457SAndroid Build Coastguard Worker       default:
195*4bdc9457SAndroid Build Coastguard Worker         return std::nanf("");
196*4bdc9457SAndroid Build Coastguard Worker     }
197*4bdc9457SAndroid Build Coastguard Worker   }
198*4bdc9457SAndroid Build Coastguard Worker 
TestQS8()199*4bdc9457SAndroid Build Coastguard Worker   void TestQS8() const {
200*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(operation_type(), OperationType::Unknown);
201*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input1_zero_point(), std::numeric_limits<int8_t>::min());
202*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input1_zero_point(), std::numeric_limits<int8_t>::max());
203*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input2_zero_point(), std::numeric_limits<int8_t>::min());
204*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input2_zero_point(), std::numeric_limits<int8_t>::max());
205*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(output_zero_point(), std::numeric_limits<int8_t>::min());
206*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(output_zero_point(), std::numeric_limits<int8_t>::max());
207*4bdc9457SAndroid Build Coastguard Worker 
208*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
209*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
210*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
211*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
212*4bdc9457SAndroid Build Coastguard Worker 
213*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
214*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
215*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
216*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
217*4bdc9457SAndroid Build Coastguard Worker     std::fill(input1_dims.begin(), input1_dims.end(), 1);
218*4bdc9457SAndroid Build Coastguard Worker     std::fill(input2_dims.begin(), input2_dims.end(), 1);
219*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
220*4bdc9457SAndroid Build Coastguard Worker     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
221*4bdc9457SAndroid Build Coastguard Worker     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
222*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
223*4bdc9457SAndroid Build Coastguard Worker       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
224*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(input1_dims[i], input2_dims[i]);
225*4bdc9457SAndroid Build Coastguard Worker       }
226*4bdc9457SAndroid Build Coastguard Worker       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
227*4bdc9457SAndroid Build Coastguard Worker     }
228*4bdc9457SAndroid Build Coastguard Worker     const size_t num_output_elements =
229*4bdc9457SAndroid Build Coastguard Worker       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
230*4bdc9457SAndroid Build Coastguard Worker 
231*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
232*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
233*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
234*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
235*4bdc9457SAndroid Build Coastguard Worker     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
236*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
237*4bdc9457SAndroid Build Coastguard Worker       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
238*4bdc9457SAndroid Build Coastguard Worker       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
239*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
240*4bdc9457SAndroid Build Coastguard Worker       input1_stride *= input1_dims[i - 1];
241*4bdc9457SAndroid Build Coastguard Worker       input2_stride *= input2_dims[i - 1];
242*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
243*4bdc9457SAndroid Build Coastguard Worker     }
244*4bdc9457SAndroid Build Coastguard Worker 
245*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
246*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
247*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output(num_output_elements);
248*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(num_output_elements);
249*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
250*4bdc9457SAndroid Build Coastguard Worker       std::generate(input1.begin(), input1.end(), [&]() { return i8dist(rng); });
251*4bdc9457SAndroid Build Coastguard Worker       std::generate(input2.begin(), input2.end(), [&]() { return i8dist(rng); });
252*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), 0xAA);
253*4bdc9457SAndroid Build Coastguard Worker 
254*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
255*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
256*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
257*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
258*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
259*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
260*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
261*4bdc9457SAndroid Build Coastguard Worker                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
262*4bdc9457SAndroid Build Coastguard Worker                     input1_scale() * (int32_t(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]) - input1_zero_point()),
263*4bdc9457SAndroid Build Coastguard Worker                     input2_scale() * (int32_t(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]) - input2_zero_point())) /
264*4bdc9457SAndroid Build Coastguard Worker                       output_scale() + float(output_zero_point());
265*4bdc9457SAndroid Build Coastguard Worker                 }
266*4bdc9457SAndroid Build Coastguard Worker               }
267*4bdc9457SAndroid Build Coastguard Worker             }
268*4bdc9457SAndroid Build Coastguard Worker           }
269*4bdc9457SAndroid Build Coastguard Worker         }
270*4bdc9457SAndroid Build Coastguard Worker       }
271*4bdc9457SAndroid Build Coastguard Worker 
272*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
273*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min(std::max(output_value, float(int8_t(qmin() - 0x80))), float(int8_t(qmax() - 0x80)));
274*4bdc9457SAndroid Build Coastguard Worker       }
275*4bdc9457SAndroid Build Coastguard Worker 
276*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy a binary elementwise operator.
277*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
278*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t binary_elementwise_op = nullptr;
279*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_status_unsupported_parameter;
280*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
281*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
282*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_add_nd_qs8(
283*4bdc9457SAndroid Build Coastguard Worker             input1_zero_point(), input1_scale(),
284*4bdc9457SAndroid Build Coastguard Worker             input2_zero_point(), input2_scale(),
285*4bdc9457SAndroid Build Coastguard Worker             output_zero_point(), output_scale(),
286*4bdc9457SAndroid Build Coastguard Worker             int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
287*4bdc9457SAndroid Build Coastguard Worker             0, &binary_elementwise_op);
288*4bdc9457SAndroid Build Coastguard Worker           break;
289*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
290*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_multiply_nd_qs8(
291*4bdc9457SAndroid Build Coastguard Worker             input1_zero_point(), input1_scale(),
292*4bdc9457SAndroid Build Coastguard Worker             input2_zero_point(), input2_scale(),
293*4bdc9457SAndroid Build Coastguard Worker             output_zero_point(), output_scale(),
294*4bdc9457SAndroid Build Coastguard Worker             int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
295*4bdc9457SAndroid Build Coastguard Worker             0, &binary_elementwise_op);
296*4bdc9457SAndroid Build Coastguard Worker           break;
297*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
298*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_subtract_nd_qs8(
299*4bdc9457SAndroid Build Coastguard Worker             input1_zero_point(), input1_scale(),
300*4bdc9457SAndroid Build Coastguard Worker             input2_zero_point(), input2_scale(),
301*4bdc9457SAndroid Build Coastguard Worker             output_zero_point(), output_scale(),
302*4bdc9457SAndroid Build Coastguard Worker             int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
303*4bdc9457SAndroid Build Coastguard Worker             0, &binary_elementwise_op);
304*4bdc9457SAndroid Build Coastguard Worker           break;
305*4bdc9457SAndroid Build Coastguard Worker         default:
306*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
307*4bdc9457SAndroid Build Coastguard Worker       }
308*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
309*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
310*4bdc9457SAndroid Build Coastguard Worker       }
311*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
312*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, binary_elementwise_op);
313*4bdc9457SAndroid Build Coastguard Worker 
314*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete binary_elementwise_op.
315*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
316*4bdc9457SAndroid Build Coastguard Worker 
317*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
318*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
319*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
320*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_add_nd_qs8(
321*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
322*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
323*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
324*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
325*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
326*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
327*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
328*4bdc9457SAndroid Build Coastguard Worker           break;
329*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
330*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
331*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_multiply_nd_qs8(
332*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
333*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
334*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
335*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
336*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
337*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
338*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
339*4bdc9457SAndroid Build Coastguard Worker           break;
340*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
341*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
342*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_subtract_nd_qs8(
343*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
344*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
345*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
346*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
347*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
348*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
349*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
350*4bdc9457SAndroid Build Coastguard Worker           break;
351*4bdc9457SAndroid Build Coastguard Worker         default:
352*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
353*4bdc9457SAndroid Build Coastguard Worker       }
354*4bdc9457SAndroid Build Coastguard Worker 
355*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
356*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
357*4bdc9457SAndroid Build Coastguard Worker 
358*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
359*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
360*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
361*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
362*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
363*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
364*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
365*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
366*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
367*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_NEAR(float(output[index]), output_ref[index], 0.6f)
368*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
369*4bdc9457SAndroid Build Coastguard Worker                     << ", input1 zero point = " << input1_zero_point() << ", input1 scale = " << input1_scale()
370*4bdc9457SAndroid Build Coastguard Worker                     << ", input2 zero point = " << input2_zero_point() << ", input2 scale = " << input2_scale()
371*4bdc9457SAndroid Build Coastguard Worker                     << ", output zero point = " << output_zero_point() << ", output scale = " << output_scale();
372*4bdc9457SAndroid Build Coastguard Worker                 }
373*4bdc9457SAndroid Build Coastguard Worker               }
374*4bdc9457SAndroid Build Coastguard Worker             }
375*4bdc9457SAndroid Build Coastguard Worker           }
376*4bdc9457SAndroid Build Coastguard Worker         }
377*4bdc9457SAndroid Build Coastguard Worker       }
378*4bdc9457SAndroid Build Coastguard Worker     }
379*4bdc9457SAndroid Build Coastguard Worker   }
380*4bdc9457SAndroid Build Coastguard Worker 
TestQU8()381*4bdc9457SAndroid Build Coastguard Worker   void TestQU8() const {
382*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(operation_type(), OperationType::Unknown);
383*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input1_zero_point(), std::numeric_limits<uint8_t>::min());
384*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input1_zero_point(), std::numeric_limits<uint8_t>::max());
385*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(input2_zero_point(), std::numeric_limits<uint8_t>::min());
386*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(input2_zero_point(), std::numeric_limits<uint8_t>::max());
387*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GE(output_zero_point(), std::numeric_limits<uint8_t>::min());
388*4bdc9457SAndroid Build Coastguard Worker     ASSERT_LE(output_zero_point(), std::numeric_limits<uint8_t>::max());
389*4bdc9457SAndroid Build Coastguard Worker 
390*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
391*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
392*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
393*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
394*4bdc9457SAndroid Build Coastguard Worker 
395*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
396*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
397*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
398*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
399*4bdc9457SAndroid Build Coastguard Worker     std::fill(input1_dims.begin(), input1_dims.end(), 1);
400*4bdc9457SAndroid Build Coastguard Worker     std::fill(input2_dims.begin(), input2_dims.end(), 1);
401*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
402*4bdc9457SAndroid Build Coastguard Worker     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
403*4bdc9457SAndroid Build Coastguard Worker     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
404*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
405*4bdc9457SAndroid Build Coastguard Worker       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
406*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(input1_dims[i], input2_dims[i]);
407*4bdc9457SAndroid Build Coastguard Worker       }
408*4bdc9457SAndroid Build Coastguard Worker       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
409*4bdc9457SAndroid Build Coastguard Worker     }
410*4bdc9457SAndroid Build Coastguard Worker     const size_t num_output_elements =
411*4bdc9457SAndroid Build Coastguard Worker       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
412*4bdc9457SAndroid Build Coastguard Worker 
413*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
414*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
415*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
416*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
417*4bdc9457SAndroid Build Coastguard Worker     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
418*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
419*4bdc9457SAndroid Build Coastguard Worker       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
420*4bdc9457SAndroid Build Coastguard Worker       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
421*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
422*4bdc9457SAndroid Build Coastguard Worker       input1_stride *= input1_dims[i - 1];
423*4bdc9457SAndroid Build Coastguard Worker       input2_stride *= input2_dims[i - 1];
424*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
425*4bdc9457SAndroid Build Coastguard Worker     }
426*4bdc9457SAndroid Build Coastguard Worker 
427*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
428*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
429*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output(num_output_elements);
430*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(num_output_elements);
431*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
432*4bdc9457SAndroid Build Coastguard Worker       std::generate(input1.begin(), input1.end(), [&]() { return u8dist(rng); });
433*4bdc9457SAndroid Build Coastguard Worker       std::generate(input2.begin(), input2.end(), [&]() { return u8dist(rng); });
434*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), 0xAA);
435*4bdc9457SAndroid Build Coastguard Worker 
436*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
437*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
438*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
439*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
440*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
441*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
442*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
443*4bdc9457SAndroid Build Coastguard Worker                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
444*4bdc9457SAndroid Build Coastguard Worker                     input1_scale() * (int32_t(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]) - input1_zero_point()),
445*4bdc9457SAndroid Build Coastguard Worker                     input2_scale() * (int32_t(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]) - input2_zero_point())) /
446*4bdc9457SAndroid Build Coastguard Worker                       output_scale() + float(output_zero_point());
447*4bdc9457SAndroid Build Coastguard Worker                 }
448*4bdc9457SAndroid Build Coastguard Worker               }
449*4bdc9457SAndroid Build Coastguard Worker             }
450*4bdc9457SAndroid Build Coastguard Worker           }
451*4bdc9457SAndroid Build Coastguard Worker         }
452*4bdc9457SAndroid Build Coastguard Worker       }
453*4bdc9457SAndroid Build Coastguard Worker 
454*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
455*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min(std::max(output_value, float(int32_t(qmin()))), float(int32_t(qmax())));
456*4bdc9457SAndroid Build Coastguard Worker       }
457*4bdc9457SAndroid Build Coastguard Worker 
458*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy a binary elementwise operator.
459*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
460*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t binary_elementwise_op = nullptr;
461*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_status_unsupported_parameter;
462*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
463*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
464*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_add_nd_qu8(
465*4bdc9457SAndroid Build Coastguard Worker             input1_zero_point(), input1_scale(),
466*4bdc9457SAndroid Build Coastguard Worker             input2_zero_point(), input2_scale(),
467*4bdc9457SAndroid Build Coastguard Worker             output_zero_point(), output_scale(),
468*4bdc9457SAndroid Build Coastguard Worker             qmin(), qmax(),
469*4bdc9457SAndroid Build Coastguard Worker             0, &binary_elementwise_op);
470*4bdc9457SAndroid Build Coastguard Worker           break;
471*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
472*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_multiply_nd_qu8(
473*4bdc9457SAndroid Build Coastguard Worker             input1_zero_point(), input1_scale(),
474*4bdc9457SAndroid Build Coastguard Worker             input2_zero_point(), input2_scale(),
475*4bdc9457SAndroid Build Coastguard Worker             output_zero_point(), output_scale(),
476*4bdc9457SAndroid Build Coastguard Worker             qmin(), qmax(),
477*4bdc9457SAndroid Build Coastguard Worker             0, &binary_elementwise_op);
478*4bdc9457SAndroid Build Coastguard Worker           break;
479*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
480*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_subtract_nd_qu8(
481*4bdc9457SAndroid Build Coastguard Worker             input1_zero_point(), input1_scale(),
482*4bdc9457SAndroid Build Coastguard Worker             input2_zero_point(), input2_scale(),
483*4bdc9457SAndroid Build Coastguard Worker             output_zero_point(), output_scale(),
484*4bdc9457SAndroid Build Coastguard Worker             qmin(), qmax(),
485*4bdc9457SAndroid Build Coastguard Worker             0, &binary_elementwise_op);
486*4bdc9457SAndroid Build Coastguard Worker           break;
487*4bdc9457SAndroid Build Coastguard Worker         default:
488*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
489*4bdc9457SAndroid Build Coastguard Worker       }
490*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
491*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
492*4bdc9457SAndroid Build Coastguard Worker       }
493*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
494*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, binary_elementwise_op);
495*4bdc9457SAndroid Build Coastguard Worker 
496*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete binary_elementwise_op.
497*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
498*4bdc9457SAndroid Build Coastguard Worker 
499*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
500*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
501*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
502*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_add_nd_qu8(
503*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
504*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
505*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
506*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
507*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
508*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
509*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
510*4bdc9457SAndroid Build Coastguard Worker           break;
511*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
512*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
513*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_multiply_nd_qu8(
514*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
515*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
516*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
517*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
518*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
519*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
520*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
521*4bdc9457SAndroid Build Coastguard Worker           break;
522*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
523*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
524*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_subtract_nd_qu8(
525*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
526*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
527*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
528*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
529*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
530*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
531*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
532*4bdc9457SAndroid Build Coastguard Worker           break;
533*4bdc9457SAndroid Build Coastguard Worker         default:
534*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
535*4bdc9457SAndroid Build Coastguard Worker       }
536*4bdc9457SAndroid Build Coastguard Worker 
537*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
538*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
539*4bdc9457SAndroid Build Coastguard Worker 
540*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
541*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
542*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
543*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
544*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
545*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
546*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
547*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
548*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
549*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_NEAR(float(int32_t(output[index])), output_ref[index], 0.6f)
550*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")"
551*4bdc9457SAndroid Build Coastguard Worker                     << ", input1 zero point = " << input1_zero_point() << ", input1 scale = " << input1_scale()
552*4bdc9457SAndroid Build Coastguard Worker                     << ", input2 zero point = " << input2_zero_point() << ", input2 scale = " << input2_scale()
553*4bdc9457SAndroid Build Coastguard Worker                     << ", output zero point = " << output_zero_point() << ", output scale = " << output_scale();
554*4bdc9457SAndroid Build Coastguard Worker                 }
555*4bdc9457SAndroid Build Coastguard Worker               }
556*4bdc9457SAndroid Build Coastguard Worker             }
557*4bdc9457SAndroid Build Coastguard Worker           }
558*4bdc9457SAndroid Build Coastguard Worker         }
559*4bdc9457SAndroid Build Coastguard Worker       }
560*4bdc9457SAndroid Build Coastguard Worker     }
561*4bdc9457SAndroid Build Coastguard Worker   }
562*4bdc9457SAndroid Build Coastguard Worker 
TestF16()563*4bdc9457SAndroid Build Coastguard Worker   void TestF16() const {
564*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(operation_type(), OperationType::Unknown);
565*4bdc9457SAndroid Build Coastguard Worker 
566*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
567*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
568*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.01f, 1.0f);
569*4bdc9457SAndroid Build Coastguard Worker 
570*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
571*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
572*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
573*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
574*4bdc9457SAndroid Build Coastguard Worker     std::fill(input1_dims.begin(), input1_dims.end(), 1);
575*4bdc9457SAndroid Build Coastguard Worker     std::fill(input2_dims.begin(), input2_dims.end(), 1);
576*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
577*4bdc9457SAndroid Build Coastguard Worker     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
578*4bdc9457SAndroid Build Coastguard Worker     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
579*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
580*4bdc9457SAndroid Build Coastguard Worker       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
581*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(input1_dims[i], input2_dims[i]);
582*4bdc9457SAndroid Build Coastguard Worker       }
583*4bdc9457SAndroid Build Coastguard Worker       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
584*4bdc9457SAndroid Build Coastguard Worker     }
585*4bdc9457SAndroid Build Coastguard Worker     const size_t num_output_elements =
586*4bdc9457SAndroid Build Coastguard Worker       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
587*4bdc9457SAndroid Build Coastguard Worker 
588*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
589*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
590*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
591*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
592*4bdc9457SAndroid Build Coastguard Worker     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
593*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
594*4bdc9457SAndroid Build Coastguard Worker       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
595*4bdc9457SAndroid Build Coastguard Worker       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
596*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
597*4bdc9457SAndroid Build Coastguard Worker       input1_stride *= input1_dims[i - 1];
598*4bdc9457SAndroid Build Coastguard Worker       input2_stride *= input2_dims[i - 1];
599*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
600*4bdc9457SAndroid Build Coastguard Worker     }
601*4bdc9457SAndroid Build Coastguard Worker 
602*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input1(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input1_elements());
603*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input2(XNN_EXTRA_BYTES / sizeof(uint16_t) + num_input2_elements());
604*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output(num_output_elements);
605*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(num_output_elements);
606*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
607*4bdc9457SAndroid Build Coastguard Worker       std::generate(input1.begin(), input1.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
608*4bdc9457SAndroid Build Coastguard Worker       std::generate(input2.begin(), input2.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
609*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
610*4bdc9457SAndroid Build Coastguard Worker 
611*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
612*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
613*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
614*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
615*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
616*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
617*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
618*4bdc9457SAndroid Build Coastguard Worker                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
619*4bdc9457SAndroid Build Coastguard Worker                     fp16_ieee_to_fp32_value(input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]]),
620*4bdc9457SAndroid Build Coastguard Worker                     fp16_ieee_to_fp32_value(input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]));
621*4bdc9457SAndroid Build Coastguard Worker                 }
622*4bdc9457SAndroid Build Coastguard Worker               }
623*4bdc9457SAndroid Build Coastguard Worker             }
624*4bdc9457SAndroid Build Coastguard Worker           }
625*4bdc9457SAndroid Build Coastguard Worker         }
626*4bdc9457SAndroid Build Coastguard Worker       }
627*4bdc9457SAndroid Build Coastguard Worker 
628*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
629*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
630*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
631*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
632*4bdc9457SAndroid Build Coastguard Worker       const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
633*4bdc9457SAndroid Build Coastguard Worker       const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
634*4bdc9457SAndroid Build Coastguard Worker       const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
635*4bdc9457SAndroid Build Coastguard Worker       const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;
636*4bdc9457SAndroid Build Coastguard Worker 
637*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
638*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min(std::max(output_value, output_min), output_max);
639*4bdc9457SAndroid Build Coastguard Worker       }
640*4bdc9457SAndroid Build Coastguard Worker 
641*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy a binary elementwise operator.
642*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
643*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t binary_elementwise_op = nullptr;
644*4bdc9457SAndroid Build Coastguard Worker       xnn_status status = xnn_status_unsupported_parameter;
645*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
646*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
647*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_add_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
648*4bdc9457SAndroid Build Coastguard Worker           break;
649*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Divide:
650*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_divide_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
651*4bdc9457SAndroid Build Coastguard Worker           break;
652*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Maximum:
653*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_maximum_nd_f16(0, &binary_elementwise_op);
654*4bdc9457SAndroid Build Coastguard Worker           break;
655*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Minimum:
656*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_minimum_nd_f16(0, &binary_elementwise_op);
657*4bdc9457SAndroid Build Coastguard Worker           break;
658*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
659*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_multiply_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
660*4bdc9457SAndroid Build Coastguard Worker           break;
661*4bdc9457SAndroid Build Coastguard Worker         case OperationType::SquaredDifference:
662*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_squared_difference_nd_f16(0, &binary_elementwise_op);
663*4bdc9457SAndroid Build Coastguard Worker           break;
664*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
665*4bdc9457SAndroid Build Coastguard Worker           status = xnn_create_subtract_nd_f16(output_min, output_max, 0, &binary_elementwise_op);
666*4bdc9457SAndroid Build Coastguard Worker           break;
667*4bdc9457SAndroid Build Coastguard Worker         default:
668*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
669*4bdc9457SAndroid Build Coastguard Worker       }
670*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
671*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
672*4bdc9457SAndroid Build Coastguard Worker       }
673*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
674*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, binary_elementwise_op);
675*4bdc9457SAndroid Build Coastguard Worker 
676*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete binary_elementwise_op.
677*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
678*4bdc9457SAndroid Build Coastguard Worker 
679*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
680*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
681*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
682*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_add_nd_f16(
683*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
684*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
685*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
686*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
687*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
688*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
689*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
690*4bdc9457SAndroid Build Coastguard Worker           break;
691*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Divide:
692*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
693*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_divide_nd_f16(
694*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
695*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
696*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
697*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
698*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
699*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
700*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
701*4bdc9457SAndroid Build Coastguard Worker           break;
702*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Maximum:
703*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
704*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_maximum_nd_f16(
705*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
706*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
707*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
708*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
709*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
710*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
711*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
712*4bdc9457SAndroid Build Coastguard Worker           break;
713*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Minimum:
714*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
715*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_minimum_nd_f16(
716*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
717*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
718*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
719*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
720*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
721*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
722*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
723*4bdc9457SAndroid Build Coastguard Worker           break;
724*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
725*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
726*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_multiply_nd_f16(
727*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
728*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
729*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
730*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
731*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
732*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
733*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
734*4bdc9457SAndroid Build Coastguard Worker           break;
735*4bdc9457SAndroid Build Coastguard Worker         case OperationType::SquaredDifference:
736*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
737*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_squared_difference_nd_f16(
738*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
739*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
740*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
741*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
742*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
743*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
744*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
745*4bdc9457SAndroid Build Coastguard Worker           break;
746*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
747*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
748*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_subtract_nd_f16(
749*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
750*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
751*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
752*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
753*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
754*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
755*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
756*4bdc9457SAndroid Build Coastguard Worker           break;
757*4bdc9457SAndroid Build Coastguard Worker         default:
758*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
759*4bdc9457SAndroid Build Coastguard Worker       }
760*4bdc9457SAndroid Build Coastguard Worker 
761*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
762*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
763*4bdc9457SAndroid Build Coastguard Worker 
764*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
765*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
766*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
767*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
768*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
769*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
770*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
771*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
772*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
773*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_NEAR(fp16_ieee_to_fp32_value(output[index]), output_ref[index], std::max(1.0e-4f, std::abs(output_ref[index]) * 1.0e-2f))
774*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
775*4bdc9457SAndroid Build Coastguard Worker                 }
776*4bdc9457SAndroid Build Coastguard Worker               }
777*4bdc9457SAndroid Build Coastguard Worker             }
778*4bdc9457SAndroid Build Coastguard Worker           }
779*4bdc9457SAndroid Build Coastguard Worker         }
780*4bdc9457SAndroid Build Coastguard Worker       }
781*4bdc9457SAndroid Build Coastguard Worker     }
782*4bdc9457SAndroid Build Coastguard Worker   }
783*4bdc9457SAndroid Build Coastguard Worker 
TestF32()784*4bdc9457SAndroid Build Coastguard Worker   void TestF32() const {
785*4bdc9457SAndroid Build Coastguard Worker     ASSERT_NE(operation_type(), OperationType::Unknown);
786*4bdc9457SAndroid Build Coastguard Worker 
787*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
788*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
789*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.01f, 1.0f);
790*4bdc9457SAndroid Build Coastguard Worker 
791*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized shapes.
792*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_dims;
793*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_dims;
794*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_dims;
795*4bdc9457SAndroid Build Coastguard Worker     std::fill(input1_dims.begin(), input1_dims.end(), 1);
796*4bdc9457SAndroid Build Coastguard Worker     std::fill(input2_dims.begin(), input2_dims.end(), 1);
797*4bdc9457SAndroid Build Coastguard Worker     std::fill(output_dims.begin(), output_dims.end(), 1);
798*4bdc9457SAndroid Build Coastguard Worker     std::copy(input1_shape().cbegin(), input1_shape().cend(), input1_dims.end() - num_input1_dims());
799*4bdc9457SAndroid Build Coastguard Worker     std::copy(input2_shape().cbegin(), input2_shape().cend(), input2_dims.end() - num_input2_dims());
800*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < XNN_MAX_TENSOR_DIMS; i++) {
801*4bdc9457SAndroid Build Coastguard Worker       if (input1_dims[i] != 1 && input2_dims[i] != 1) {
802*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(input1_dims[i], input2_dims[i]);
803*4bdc9457SAndroid Build Coastguard Worker       }
804*4bdc9457SAndroid Build Coastguard Worker       output_dims[i] = std::max(input1_dims[i], input2_dims[i]);
805*4bdc9457SAndroid Build Coastguard Worker     }
806*4bdc9457SAndroid Build Coastguard Worker     const size_t num_output_elements =
807*4bdc9457SAndroid Build Coastguard Worker       std::accumulate(output_dims.begin(), output_dims.end(), size_t(1), std::multiplies<size_t>());
808*4bdc9457SAndroid Build Coastguard Worker 
809*4bdc9457SAndroid Build Coastguard Worker     // Compute generalized strides.
810*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input1_strides;
811*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> input2_strides;
812*4bdc9457SAndroid Build Coastguard Worker     std::array<size_t, XNN_MAX_TENSOR_DIMS> output_strides;
813*4bdc9457SAndroid Build Coastguard Worker     size_t input1_stride = 1, input2_stride = 1, output_stride = 1;
814*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = XNN_MAX_TENSOR_DIMS; i != 0; i--) {
815*4bdc9457SAndroid Build Coastguard Worker       input1_strides[i - 1] = input1_dims[i - 1] == 1 ? 0 : input1_stride;
816*4bdc9457SAndroid Build Coastguard Worker       input2_strides[i - 1] = input2_dims[i - 1] == 1 ? 0 : input2_stride;
817*4bdc9457SAndroid Build Coastguard Worker       output_strides[i - 1] = output_stride;
818*4bdc9457SAndroid Build Coastguard Worker       input1_stride *= input1_dims[i - 1];
819*4bdc9457SAndroid Build Coastguard Worker       input2_stride *= input2_dims[i - 1];
820*4bdc9457SAndroid Build Coastguard Worker       output_stride *= output_dims[i - 1];
821*4bdc9457SAndroid Build Coastguard Worker     }
822*4bdc9457SAndroid Build Coastguard Worker 
823*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input1(XNN_EXTRA_BYTES / sizeof(float) + num_input1_elements());
824*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input2(XNN_EXTRA_BYTES / sizeof(float) + num_input2_elements());
825*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output(num_output_elements);
826*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(num_output_elements);
827*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
828*4bdc9457SAndroid Build Coastguard Worker       std::generate(input1.begin(), input1.end(), [&]() { return f32dist(rng); });
829*4bdc9457SAndroid Build Coastguard Worker       std::generate(input2.begin(), input2.end(), [&]() { return f32dist(rng); });
830*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
831*4bdc9457SAndroid Build Coastguard Worker 
832*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
833*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
834*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
835*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
836*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
837*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
838*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
839*4bdc9457SAndroid Build Coastguard Worker                   output_ref[i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5]] = Compute(
840*4bdc9457SAndroid Build Coastguard Worker                     input1[i * input1_strides[0] + j * input1_strides[1] + k * input1_strides[2] + l * input1_strides[3] + m * input1_strides[4] + n * input1_strides[5]],
841*4bdc9457SAndroid Build Coastguard Worker                     input2[i * input2_strides[0] + j * input2_strides[1] + k * input2_strides[2] + l * input2_strides[3] + m * input2_strides[4] + n * input2_strides[5]]);
842*4bdc9457SAndroid Build Coastguard Worker                 }
843*4bdc9457SAndroid Build Coastguard Worker               }
844*4bdc9457SAndroid Build Coastguard Worker             }
845*4bdc9457SAndroid Build Coastguard Worker           }
846*4bdc9457SAndroid Build Coastguard Worker         }
847*4bdc9457SAndroid Build Coastguard Worker       }
848*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
849*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
850*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
851*4bdc9457SAndroid Build Coastguard Worker       const float output_min = num_output_elements == 1 ?
852*4bdc9457SAndroid Build Coastguard Worker         -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
853*4bdc9457SAndroid Build Coastguard Worker       const float output_max = num_output_elements == 1 ?
854*4bdc9457SAndroid Build Coastguard Worker         +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
855*4bdc9457SAndroid Build Coastguard Worker       for (float& output_value : output_ref) {
856*4bdc9457SAndroid Build Coastguard Worker         output_value = std::min(std::max(output_value, output_min), output_max);
857*4bdc9457SAndroid Build Coastguard Worker       }
858*4bdc9457SAndroid Build Coastguard Worker 
859*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy a binary elementwise operator.
860*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
861*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t binary_elementwise_op = nullptr;
862*4bdc9457SAndroid Build Coastguard Worker 
863*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
864*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
865*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
866*4bdc9457SAndroid Build Coastguard Worker             xnn_create_add_nd_f32(
867*4bdc9457SAndroid Build Coastguard Worker               output_min, output_max,
868*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
869*4bdc9457SAndroid Build Coastguard Worker           break;
870*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Divide:
871*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
872*4bdc9457SAndroid Build Coastguard Worker             xnn_create_divide_nd_f32(
873*4bdc9457SAndroid Build Coastguard Worker               output_min, output_max,
874*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
875*4bdc9457SAndroid Build Coastguard Worker           break;
876*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Maximum:
877*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
878*4bdc9457SAndroid Build Coastguard Worker             xnn_create_maximum_nd_f32(
879*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
880*4bdc9457SAndroid Build Coastguard Worker           break;
881*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Minimum:
882*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
883*4bdc9457SAndroid Build Coastguard Worker             xnn_create_minimum_nd_f32(
884*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
885*4bdc9457SAndroid Build Coastguard Worker           break;
886*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
887*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
888*4bdc9457SAndroid Build Coastguard Worker             xnn_create_multiply_nd_f32(
889*4bdc9457SAndroid Build Coastguard Worker               output_min, output_max,
890*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
891*4bdc9457SAndroid Build Coastguard Worker           break;
892*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
893*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
894*4bdc9457SAndroid Build Coastguard Worker             xnn_create_subtract_nd_f32(
895*4bdc9457SAndroid Build Coastguard Worker               output_min, output_max,
896*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
897*4bdc9457SAndroid Build Coastguard Worker           break;
898*4bdc9457SAndroid Build Coastguard Worker         case OperationType::SquaredDifference:
899*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
900*4bdc9457SAndroid Build Coastguard Worker             xnn_create_squared_difference_nd_f32(
901*4bdc9457SAndroid Build Coastguard Worker               0, &binary_elementwise_op));
902*4bdc9457SAndroid Build Coastguard Worker           break;
903*4bdc9457SAndroid Build Coastguard Worker         default:
904*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
905*4bdc9457SAndroid Build Coastguard Worker       }
906*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, binary_elementwise_op);
907*4bdc9457SAndroid Build Coastguard Worker 
908*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete binary_elementwise_op.
909*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_binary_elementwise_op(binary_elementwise_op, xnn_delete_operator);
910*4bdc9457SAndroid Build Coastguard Worker 
911*4bdc9457SAndroid Build Coastguard Worker       switch (operation_type()) {
912*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Add:
913*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
914*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_add_nd_f32(
915*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
916*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
917*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
918*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
919*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
920*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
921*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
922*4bdc9457SAndroid Build Coastguard Worker           break;
923*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Divide:
924*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
925*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_divide_nd_f32(
926*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
927*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
928*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
929*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
930*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
931*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
932*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
933*4bdc9457SAndroid Build Coastguard Worker           break;
934*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Maximum:
935*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
936*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_maximum_nd_f32(
937*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
938*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
939*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
940*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
941*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
942*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
943*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
944*4bdc9457SAndroid Build Coastguard Worker           break;
945*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Minimum:
946*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
947*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_minimum_nd_f32(
948*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
949*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
950*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
951*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
952*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
953*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
954*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
955*4bdc9457SAndroid Build Coastguard Worker           break;
956*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Multiply:
957*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
958*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_multiply_nd_f32(
959*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
960*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
961*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
962*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
963*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
964*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
965*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
966*4bdc9457SAndroid Build Coastguard Worker           break;
967*4bdc9457SAndroid Build Coastguard Worker         case OperationType::Subtract:
968*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
969*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_subtract_nd_f32(
970*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
971*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
972*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
973*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
974*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
975*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
976*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
977*4bdc9457SAndroid Build Coastguard Worker           break;
978*4bdc9457SAndroid Build Coastguard Worker         case OperationType::SquaredDifference:
979*4bdc9457SAndroid Build Coastguard Worker           ASSERT_EQ(xnn_status_success,
980*4bdc9457SAndroid Build Coastguard Worker             xnn_setup_squared_difference_nd_f32(
981*4bdc9457SAndroid Build Coastguard Worker               binary_elementwise_op,
982*4bdc9457SAndroid Build Coastguard Worker               num_input1_dims(),
983*4bdc9457SAndroid Build Coastguard Worker               input1_shape().data(),
984*4bdc9457SAndroid Build Coastguard Worker               num_input2_dims(),
985*4bdc9457SAndroid Build Coastguard Worker               input2_shape().data(),
986*4bdc9457SAndroid Build Coastguard Worker               input1.data(), input2.data(), output.data(),
987*4bdc9457SAndroid Build Coastguard Worker               nullptr /* thread pool */));
988*4bdc9457SAndroid Build Coastguard Worker           break;
989*4bdc9457SAndroid Build Coastguard Worker         default:
990*4bdc9457SAndroid Build Coastguard Worker           FAIL() << "Unsupported operation type";
991*4bdc9457SAndroid Build Coastguard Worker       }
992*4bdc9457SAndroid Build Coastguard Worker 
993*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
994*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(binary_elementwise_op, nullptr /* thread pool */));
995*4bdc9457SAndroid Build Coastguard Worker 
996*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
997*4bdc9457SAndroid Build Coastguard Worker       for (size_t i = 0; i < output_dims[0]; i++) {
998*4bdc9457SAndroid Build Coastguard Worker         for (size_t j = 0; j < output_dims[1]; j++) {
999*4bdc9457SAndroid Build Coastguard Worker           for (size_t k = 0; k < output_dims[2]; k++) {
1000*4bdc9457SAndroid Build Coastguard Worker             for (size_t l = 0; l < output_dims[3]; l++) {
1001*4bdc9457SAndroid Build Coastguard Worker               for (size_t m = 0; m < output_dims[4]; m++) {
1002*4bdc9457SAndroid Build Coastguard Worker                 for (size_t n = 0; n < output_dims[5]; n++) {
1003*4bdc9457SAndroid Build Coastguard Worker                   const size_t index =
1004*4bdc9457SAndroid Build Coastguard Worker                     i * output_strides[0] + j * output_strides[1] + k * output_strides[2] + l * output_strides[3] + m * output_strides[4] + n * output_strides[5];
1005*4bdc9457SAndroid Build Coastguard Worker                   ASSERT_NEAR(output[index], output_ref[index], 1.0e-6f * std::abs(output_ref[index]))
1006*4bdc9457SAndroid Build Coastguard Worker                     << "(i, j, k, l, m, n) = (" << i << ", " << j << ", " << k << ", " << l << ", " << m << ", " << n << ")";
1007*4bdc9457SAndroid Build Coastguard Worker                 }
1008*4bdc9457SAndroid Build Coastguard Worker               }
1009*4bdc9457SAndroid Build Coastguard Worker             }
1010*4bdc9457SAndroid Build Coastguard Worker           }
1011*4bdc9457SAndroid Build Coastguard Worker         }
1012*4bdc9457SAndroid Build Coastguard Worker       }
1013*4bdc9457SAndroid Build Coastguard Worker     }
1014*4bdc9457SAndroid Build Coastguard Worker   }
1015*4bdc9457SAndroid Build Coastguard Worker 
1016*4bdc9457SAndroid Build Coastguard Worker  private:
1017*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input1_shape_;  // Shape of the first operand; presumably what input1_shape() returns (accessor not shown here).
1018*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input2_shape_;  // Shape of the second operand; presumably what input2_shape() returns (accessor not shown here).
1019*4bdc9457SAndroid Build Coastguard Worker   int16_t input1_zero_point_{0};  // Not referenced by TestF32; presumably a quantization parameter for other tests -- TODO confirm.
1020*4bdc9457SAndroid Build Coastguard Worker   float input1_scale_{1.0f};  // Not referenced by TestF32; presumably a quantization parameter for other tests -- TODO confirm.
1021*4bdc9457SAndroid Build Coastguard Worker   int16_t input2_zero_point_{0};  // Second-operand counterpart of input1_zero_point_.
1022*4bdc9457SAndroid Build Coastguard Worker   float input2_scale_{1.0f};  // Second-operand counterpart of input1_scale_.
1023*4bdc9457SAndroid Build Coastguard Worker   int16_t output_zero_point_{0};  // Output counterpart of input1_zero_point_.
1024*4bdc9457SAndroid Build Coastguard Worker   float output_scale_{1.0f};  // Output counterpart of input1_scale_.
1025*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmin_{0};  // Presumably backs qmin(), which TestF32 uses to raise the lower clamping bound; 0 leaves it at the reference minimum.
1026*4bdc9457SAndroid Build Coastguard Worker   uint8_t qmax_{255};  // Presumably backs qmax(), which TestF32 uses to lower the upper clamping bound; 255 leaves it at the reference maximum.
1027*4bdc9457SAndroid Build Coastguard Worker   OperationType operation_type_{OperationType::Unknown};  // Defaults to Unknown; TestF32 asserts operation_type() is not Unknown before running.
1028*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{3};  // Presumably backs iterations(), the number of randomized rounds each test runs.
1029*4bdc9457SAndroid Build Coastguard Worker };
1030