xref: /aosp_15_r20/external/XNNPACK/test/fully-connected-operator-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #pragma once
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
14*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
15*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
16*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
17*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
18*4bdc9457SAndroid Build Coastguard Worker #include <limits>
19*4bdc9457SAndroid Build Coastguard Worker #include <random>
20*4bdc9457SAndroid Build Coastguard Worker #include <vector>
21*4bdc9457SAndroid Build Coastguard Worker 
22*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
25*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/cache.h>
26*4bdc9457SAndroid Build Coastguard Worker 
27*4bdc9457SAndroid Build Coastguard Worker 
28*4bdc9457SAndroid Build Coastguard Worker class FullyConnectedOperatorTester {
29*4bdc9457SAndroid Build Coastguard Worker  public:
// Selects the datatype of the weights handed to the operator under test.
// TestQS8 and TestQU8 only accept WeightsType::Default.
// NOTE(review): FP32 presumably requests fp32 weights for an operator whose
// native weight type differs; confirm against the test that sets it.
enum class WeightsType { Default, FP32 };
34*4bdc9457SAndroid Build Coastguard Worker 
input_channels(size_t input_channels)35*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& input_channels(size_t input_channels) {
36*4bdc9457SAndroid Build Coastguard Worker     assert(input_channels >= 1);
37*4bdc9457SAndroid Build Coastguard Worker     this->input_channels_ = input_channels;
38*4bdc9457SAndroid Build Coastguard Worker     return *this;
39*4bdc9457SAndroid Build Coastguard Worker   }
40*4bdc9457SAndroid Build Coastguard Worker 
input_channels()41*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_channels() const {
42*4bdc9457SAndroid Build Coastguard Worker     return this->input_channels_;
43*4bdc9457SAndroid Build Coastguard Worker   }
44*4bdc9457SAndroid Build Coastguard Worker 
output_channels(size_t output_channels)45*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& output_channels(size_t output_channels) {
46*4bdc9457SAndroid Build Coastguard Worker     assert(output_channels >= 1);
47*4bdc9457SAndroid Build Coastguard Worker     this->output_channels_ = output_channels;
48*4bdc9457SAndroid Build Coastguard Worker     return *this;
49*4bdc9457SAndroid Build Coastguard Worker   }
50*4bdc9457SAndroid Build Coastguard Worker 
output_channels()51*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_channels() const {
52*4bdc9457SAndroid Build Coastguard Worker     return this->output_channels_;
53*4bdc9457SAndroid Build Coastguard Worker   }
54*4bdc9457SAndroid Build Coastguard Worker 
batch_size(size_t batch_size)55*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& batch_size(size_t batch_size) {
56*4bdc9457SAndroid Build Coastguard Worker     assert(batch_size >= 1);
57*4bdc9457SAndroid Build Coastguard Worker     this->batch_size_ = batch_size;
58*4bdc9457SAndroid Build Coastguard Worker     return *this;
59*4bdc9457SAndroid Build Coastguard Worker   }
60*4bdc9457SAndroid Build Coastguard Worker 
batch_size()61*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch_size() const {
62*4bdc9457SAndroid Build Coastguard Worker     return this->batch_size_;
63*4bdc9457SAndroid Build Coastguard Worker   }
64*4bdc9457SAndroid Build Coastguard Worker 
input_stride(size_t input_stride)65*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& input_stride(size_t input_stride) {
66*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride >= 1);
67*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
68*4bdc9457SAndroid Build Coastguard Worker     return *this;
69*4bdc9457SAndroid Build Coastguard Worker   }
70*4bdc9457SAndroid Build Coastguard Worker 
input_stride()71*4bdc9457SAndroid Build Coastguard Worker   inline size_t input_stride() const {
72*4bdc9457SAndroid Build Coastguard Worker     if (this->input_stride_ == 0) {
73*4bdc9457SAndroid Build Coastguard Worker       return input_channels();
74*4bdc9457SAndroid Build Coastguard Worker     } else {
75*4bdc9457SAndroid Build Coastguard Worker       assert(this->input_stride_ >= input_channels());
76*4bdc9457SAndroid Build Coastguard Worker       return this->input_stride_;
77*4bdc9457SAndroid Build Coastguard Worker     }
78*4bdc9457SAndroid Build Coastguard Worker   }
79*4bdc9457SAndroid Build Coastguard Worker 
output_stride(size_t output_stride)80*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& output_stride(size_t output_stride) {
81*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride >= 1);
82*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
83*4bdc9457SAndroid Build Coastguard Worker     return *this;
84*4bdc9457SAndroid Build Coastguard Worker   }
85*4bdc9457SAndroid Build Coastguard Worker 
output_stride()86*4bdc9457SAndroid Build Coastguard Worker   inline size_t output_stride() const {
87*4bdc9457SAndroid Build Coastguard Worker     if (this->output_stride_ == 0) {
88*4bdc9457SAndroid Build Coastguard Worker       return output_channels();
89*4bdc9457SAndroid Build Coastguard Worker     } else {
90*4bdc9457SAndroid Build Coastguard Worker       assert(this->output_stride_ >= output_channels());
91*4bdc9457SAndroid Build Coastguard Worker       return this->output_stride_;
92*4bdc9457SAndroid Build Coastguard Worker     }
93*4bdc9457SAndroid Build Coastguard Worker   }
94*4bdc9457SAndroid Build Coastguard Worker 
qmin(uint8_t qmin)95*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& qmin(uint8_t qmin) {
96*4bdc9457SAndroid Build Coastguard Worker     this->qmin_ = qmin;
97*4bdc9457SAndroid Build Coastguard Worker     return *this;
98*4bdc9457SAndroid Build Coastguard Worker   }
99*4bdc9457SAndroid Build Coastguard Worker 
qmin()100*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmin() const {
101*4bdc9457SAndroid Build Coastguard Worker     return this->qmin_;
102*4bdc9457SAndroid Build Coastguard Worker   }
103*4bdc9457SAndroid Build Coastguard Worker 
qmax(uint8_t qmax)104*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& qmax(uint8_t qmax) {
105*4bdc9457SAndroid Build Coastguard Worker     this->qmax_ = qmax;
106*4bdc9457SAndroid Build Coastguard Worker     return *this;
107*4bdc9457SAndroid Build Coastguard Worker   }
108*4bdc9457SAndroid Build Coastguard Worker 
qmax()109*4bdc9457SAndroid Build Coastguard Worker   inline uint8_t qmax() const {
110*4bdc9457SAndroid Build Coastguard Worker     return this->qmax_;
111*4bdc9457SAndroid Build Coastguard Worker   }
112*4bdc9457SAndroid Build Coastguard Worker 
transpose_weights(bool transpose_weights)113*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& transpose_weights(bool transpose_weights) {
114*4bdc9457SAndroid Build Coastguard Worker     this->transpose_weights_ = transpose_weights;
115*4bdc9457SAndroid Build Coastguard Worker     return *this;
116*4bdc9457SAndroid Build Coastguard Worker   }
117*4bdc9457SAndroid Build Coastguard Worker 
transpose_weights()118*4bdc9457SAndroid Build Coastguard Worker   inline bool transpose_weights() const {
119*4bdc9457SAndroid Build Coastguard Worker     return this->transpose_weights_;
120*4bdc9457SAndroid Build Coastguard Worker   }
121*4bdc9457SAndroid Build Coastguard Worker 
has_bias(bool has_bias)122*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& has_bias(bool has_bias) {
123*4bdc9457SAndroid Build Coastguard Worker     this->has_bias_ = has_bias;
124*4bdc9457SAndroid Build Coastguard Worker     return *this;
125*4bdc9457SAndroid Build Coastguard Worker   }
126*4bdc9457SAndroid Build Coastguard Worker 
has_bias()127*4bdc9457SAndroid Build Coastguard Worker   inline bool has_bias() const {
128*4bdc9457SAndroid Build Coastguard Worker     return this->has_bias_;
129*4bdc9457SAndroid Build Coastguard Worker   }
130*4bdc9457SAndroid Build Coastguard Worker 
weights_type(WeightsType weights_type)131*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& weights_type(WeightsType weights_type) {
132*4bdc9457SAndroid Build Coastguard Worker     this->weights_type_ = weights_type;
133*4bdc9457SAndroid Build Coastguard Worker     return *this;
134*4bdc9457SAndroid Build Coastguard Worker   }
135*4bdc9457SAndroid Build Coastguard Worker 
weights_type()136*4bdc9457SAndroid Build Coastguard Worker   inline WeightsType weights_type() const {
137*4bdc9457SAndroid Build Coastguard Worker     return this->weights_type_;
138*4bdc9457SAndroid Build Coastguard Worker   }
139*4bdc9457SAndroid Build Coastguard Worker 
use_weights_cache(bool use_weights_cache)140*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& use_weights_cache(bool use_weights_cache) {
141*4bdc9457SAndroid Build Coastguard Worker     this->use_weights_cache_ = use_weights_cache;
142*4bdc9457SAndroid Build Coastguard Worker     return *this;
143*4bdc9457SAndroid Build Coastguard Worker   }
144*4bdc9457SAndroid Build Coastguard Worker 
use_weights_cache()145*4bdc9457SAndroid Build Coastguard Worker   inline bool use_weights_cache() const {
146*4bdc9457SAndroid Build Coastguard Worker     return this->use_weights_cache_;
147*4bdc9457SAndroid Build Coastguard Worker   }
148*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)149*4bdc9457SAndroid Build Coastguard Worker   inline FullyConnectedOperatorTester& iterations(size_t iterations) {
150*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
151*4bdc9457SAndroid Build Coastguard Worker     return *this;
152*4bdc9457SAndroid Build Coastguard Worker   }
153*4bdc9457SAndroid Build Coastguard Worker 
iterations()154*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
155*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
156*4bdc9457SAndroid Build Coastguard Worker   }
157*4bdc9457SAndroid Build Coastguard Worker 
TestQS8()158*4bdc9457SAndroid Build Coastguard Worker   void TestQS8() const {
159*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
160*4bdc9457SAndroid Build Coastguard Worker 
161*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
162*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
163*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
164*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i8dist(
165*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max());
166*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> w8dist(
167*4bdc9457SAndroid Build Coastguard Worker       -std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max());
168*4bdc9457SAndroid Build Coastguard Worker 
169*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> input(XNN_EXTRA_BYTES / sizeof(int8_t) +
170*4bdc9457SAndroid Build Coastguard Worker       (batch_size() - 1) * input_stride() + input_channels());
171*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> kernel(output_channels() * input_channels());
172*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(output_channels());
173*4bdc9457SAndroid Build Coastguard Worker     std::vector<int8_t> output((batch_size() - 1) * output_stride() + output_channels());
174*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_channels());
175*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_channels());
176*4bdc9457SAndroid Build Coastguard Worker 
177*4bdc9457SAndroid Build Coastguard Worker     const int8_t input_zero_point = 127;
178*4bdc9457SAndroid Build Coastguard Worker 
179*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
180*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return i8dist(rng); });
181*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return w8dist(rng); });
182*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
183*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), INT8_C(0xA5));
184*4bdc9457SAndroid Build Coastguard Worker 
185*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
186*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
187*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
188*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
189*4bdc9457SAndroid Build Coastguard Worker             accumulators[i * output_channels() + oc] = bias[oc];
190*4bdc9457SAndroid Build Coastguard Worker           }
191*4bdc9457SAndroid Build Coastguard Worker         }
192*4bdc9457SAndroid Build Coastguard Worker       } else {
193*4bdc9457SAndroid Build Coastguard Worker         std::fill(accumulators.begin(), accumulators.end(), 0);
194*4bdc9457SAndroid Build Coastguard Worker       }
195*4bdc9457SAndroid Build Coastguard Worker       if (transpose_weights()) {
196*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
197*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
198*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
199*4bdc9457SAndroid Build Coastguard Worker               accumulators[i * output_channels() + oc] +=
200*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
201*4bdc9457SAndroid Build Coastguard Worker                 int32_t(kernel[ic * output_channels() + oc]);
202*4bdc9457SAndroid Build Coastguard Worker             }
203*4bdc9457SAndroid Build Coastguard Worker           }
204*4bdc9457SAndroid Build Coastguard Worker         }
205*4bdc9457SAndroid Build Coastguard Worker       } else {
206*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
207*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
208*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
209*4bdc9457SAndroid Build Coastguard Worker               accumulators[i * output_channels() + oc] +=
210*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
211*4bdc9457SAndroid Build Coastguard Worker                 int32_t(kernel[oc * input_channels() + ic]);
212*4bdc9457SAndroid Build Coastguard Worker             }
213*4bdc9457SAndroid Build Coastguard Worker           }
214*4bdc9457SAndroid Build Coastguard Worker         }
215*4bdc9457SAndroid Build Coastguard Worker       }
216*4bdc9457SAndroid Build Coastguard Worker 
217*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
218*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
219*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
220*4bdc9457SAndroid Build Coastguard Worker 
221*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
222*4bdc9457SAndroid Build Coastguard Worker       const int8_t output_zero_point = int8_t(std::max(std::min(
223*4bdc9457SAndroid Build Coastguard Worker         lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
224*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
225*4bdc9457SAndroid Build Coastguard Worker 
226*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
227*4bdc9457SAndroid Build Coastguard Worker       std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
228*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
229*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax() - 0x80) - output_zero_point), double(qmin() - 0x80) - output_zero_point);
230*4bdc9457SAndroid Build Coastguard Worker         });
231*4bdc9457SAndroid Build Coastguard Worker 
232*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Fully Connected operator.
233*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
234*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t fully_connected_op = nullptr;
235*4bdc9457SAndroid Build Coastguard Worker 
236*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
237*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
238*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
239*4bdc9457SAndroid Build Coastguard Worker       };
240*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
241*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
242*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
243*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
244*4bdc9457SAndroid Build Coastguard Worker       }
245*4bdc9457SAndroid Build Coastguard Worker 
246*4bdc9457SAndroid Build Coastguard Worker       const xnn_status status = xnn_create_fully_connected_nc_qs8(
247*4bdc9457SAndroid Build Coastguard Worker           input_channels(), output_channels(),
248*4bdc9457SAndroid Build Coastguard Worker           input_stride(), output_stride(),
249*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */,
250*4bdc9457SAndroid Build Coastguard Worker           1.0f /* kernel scale */,
251*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
252*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, output_scale, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80),
253*4bdc9457SAndroid Build Coastguard Worker           transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
254*4bdc9457SAndroid Build Coastguard Worker           &caches,
255*4bdc9457SAndroid Build Coastguard Worker           &fully_connected_op);
256*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
257*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
258*4bdc9457SAndroid Build Coastguard Worker       }
259*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
260*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, fully_connected_op);
261*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
262*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
263*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
264*4bdc9457SAndroid Build Coastguard Worker       }
265*4bdc9457SAndroid Build Coastguard Worker 
266*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete fully_connected_op.
267*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
268*4bdc9457SAndroid Build Coastguard Worker 
269*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
270*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_fully_connected_nc_qs8(
271*4bdc9457SAndroid Build Coastguard Worker           fully_connected_op,
272*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
273*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
274*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
275*4bdc9457SAndroid Build Coastguard Worker 
276*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
277*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
278*4bdc9457SAndroid Build Coastguard Worker 
279*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
280*4bdc9457SAndroid Build Coastguard Worker       VerifyQS8(output, output_ref, double(output_zero_point));
281*4bdc9457SAndroid Build Coastguard Worker 
282*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
283*4bdc9457SAndroid Build Coastguard Worker         // Create another operator with the same weights cache.
284*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t fully_connected_op2 = nullptr;
285*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
286*4bdc9457SAndroid Build Coastguard Worker 
287*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
288*4bdc9457SAndroid Build Coastguard Worker                   xnn_create_fully_connected_nc_qs8(
289*4bdc9457SAndroid Build Coastguard Worker                       input_channels(), output_channels(), input_stride(),
290*4bdc9457SAndroid Build Coastguard Worker                       output_stride(), input_zero_point, 1.0f /* input scale */,
291*4bdc9457SAndroid Build Coastguard Worker                       1.0f /* kernel scale */, kernel.data(),
292*4bdc9457SAndroid Build Coastguard Worker                       has_bias() ? bias.data() : nullptr, output_zero_point,
293*4bdc9457SAndroid Build Coastguard Worker                       output_scale, int8_t(qmin() - 0x80),
294*4bdc9457SAndroid Build Coastguard Worker                       int8_t(qmax() - 0x80),
295*4bdc9457SAndroid Build Coastguard Worker                       transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
296*4bdc9457SAndroid Build Coastguard Worker                       &caches, &fully_connected_op2));
297*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, fully_connected_op2);
298*4bdc9457SAndroid Build Coastguard Worker 
299*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete fully_connected_op.
300*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
301*4bdc9457SAndroid Build Coastguard Worker             auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
302*4bdc9457SAndroid Build Coastguard Worker         std::vector<int8_t> output2(output.size(), INT8_C(0xA5));
303*4bdc9457SAndroid Build Coastguard Worker 
304*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
305*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_fully_connected_nc_qs8(
306*4bdc9457SAndroid Build Coastguard Worker                       fully_connected_op2,
307*4bdc9457SAndroid Build Coastguard Worker                       batch_size(),
308*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
309*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
310*4bdc9457SAndroid Build Coastguard Worker 
311*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(
312*4bdc9457SAndroid Build Coastguard Worker             xnn_status_success,
313*4bdc9457SAndroid Build Coastguard Worker             xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));
314*4bdc9457SAndroid Build Coastguard Worker 
315*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
316*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
317*4bdc9457SAndroid Build Coastguard Worker 
318*4bdc9457SAndroid Build Coastguard Worker         VerifyQS8(output, output_ref, double(output_zero_point));
319*4bdc9457SAndroid Build Coastguard Worker       }
320*4bdc9457SAndroid Build Coastguard Worker     }
321*4bdc9457SAndroid Build Coastguard Worker   }
322*4bdc9457SAndroid Build Coastguard Worker 
VerifyQS8(const std::vector<int8_t> & output,const std::vector<double> & output_ref,double output_zero_point)323*4bdc9457SAndroid Build Coastguard Worker   void VerifyQS8(const std::vector<int8_t>& output,
324*4bdc9457SAndroid Build Coastguard Worker                  const std::vector<double>& output_ref,
325*4bdc9457SAndroid Build Coastguard Worker                  double output_zero_point) const {
326*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
327*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < output_channels(); c++) {
328*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax() - 0x80))
329*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
330*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin() - 0x80))
331*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
332*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(output_ref[i * output_channels() + c],
333*4bdc9457SAndroid Build Coastguard Worker                     double(output[i * output_stride() + c]) - output_zero_point,
334*4bdc9457SAndroid Build Coastguard Worker                     0.9)
335*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
336*4bdc9457SAndroid Build Coastguard Worker       }
337*4bdc9457SAndroid Build Coastguard Worker     }
338*4bdc9457SAndroid Build Coastguard Worker   }
339*4bdc9457SAndroid Build Coastguard Worker 
TestQU8()340*4bdc9457SAndroid Build Coastguard Worker   void TestQU8() const {
341*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
342*4bdc9457SAndroid Build Coastguard Worker 
343*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
344*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
345*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> i32dist(-10000, 10000);
346*4bdc9457SAndroid Build Coastguard Worker     std::uniform_int_distribution<int32_t> u8dist(
347*4bdc9457SAndroid Build Coastguard Worker       std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
348*4bdc9457SAndroid Build Coastguard Worker 
349*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) +
350*4bdc9457SAndroid Build Coastguard Worker       (batch_size() - 1) * input_stride() + input_channels());
351*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> kernel(output_channels() * input_channels());
352*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> bias(output_channels());
353*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint8_t> output((batch_size() - 1) * output_stride() + output_channels());
354*4bdc9457SAndroid Build Coastguard Worker     std::vector<int32_t> accumulators(batch_size() * output_channels());
355*4bdc9457SAndroid Build Coastguard Worker     std::vector<double> output_ref(batch_size() * output_channels());
356*4bdc9457SAndroid Build Coastguard Worker 
357*4bdc9457SAndroid Build Coastguard Worker     const uint8_t input_zero_point = 127;
358*4bdc9457SAndroid Build Coastguard Worker     const uint8_t kernel_zero_point = 127;
359*4bdc9457SAndroid Build Coastguard Worker 
360*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
361*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return u8dist(rng); });
362*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return u8dist(rng); });
363*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return i32dist(rng); });
364*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT8_C(0xA5));
365*4bdc9457SAndroid Build Coastguard Worker 
366*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
367*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
368*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
369*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
370*4bdc9457SAndroid Build Coastguard Worker             accumulators[i * output_channels() + oc] = bias[oc];
371*4bdc9457SAndroid Build Coastguard Worker           }
372*4bdc9457SAndroid Build Coastguard Worker         }
373*4bdc9457SAndroid Build Coastguard Worker       } else {
374*4bdc9457SAndroid Build Coastguard Worker         std::fill(accumulators.begin(), accumulators.end(), 0);
375*4bdc9457SAndroid Build Coastguard Worker       }
376*4bdc9457SAndroid Build Coastguard Worker       if (transpose_weights()) {
377*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
378*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
379*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
380*4bdc9457SAndroid Build Coastguard Worker               accumulators[i * output_channels() + oc] +=
381*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
382*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(kernel[ic * output_channels() + oc]) - int32_t(kernel_zero_point));
383*4bdc9457SAndroid Build Coastguard Worker             }
384*4bdc9457SAndroid Build Coastguard Worker           }
385*4bdc9457SAndroid Build Coastguard Worker         }
386*4bdc9457SAndroid Build Coastguard Worker       } else {
387*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
388*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
389*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
390*4bdc9457SAndroid Build Coastguard Worker               accumulators[i * output_channels() + oc] +=
391*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(input[i * input_stride() + ic]) - int32_t(input_zero_point)) *
392*4bdc9457SAndroid Build Coastguard Worker                 (int32_t(kernel[oc * input_channels() + ic]) - int32_t(kernel_zero_point));
393*4bdc9457SAndroid Build Coastguard Worker             }
394*4bdc9457SAndroid Build Coastguard Worker           }
395*4bdc9457SAndroid Build Coastguard Worker         }
396*4bdc9457SAndroid Build Coastguard Worker       }
397*4bdc9457SAndroid Build Coastguard Worker 
398*4bdc9457SAndroid Build Coastguard Worker       // Compute renormalization parameters.
399*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_min = *std::min_element(accumulators.cbegin(), accumulators.cend());
400*4bdc9457SAndroid Build Coastguard Worker       const int32_t accumulated_max = *std::max_element(accumulators.cbegin(), accumulators.cend());
401*4bdc9457SAndroid Build Coastguard Worker 
402*4bdc9457SAndroid Build Coastguard Worker       const double output_scale = double(uint32_t(accumulated_max - accumulated_min)) / 255.0;
403*4bdc9457SAndroid Build Coastguard Worker       const uint8_t output_zero_point = uint8_t(std::max(std::min(
404*4bdc9457SAndroid Build Coastguard Worker         lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / output_scale),
405*4bdc9457SAndroid Build Coastguard Worker         long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
406*4bdc9457SAndroid Build Coastguard Worker 
407*4bdc9457SAndroid Build Coastguard Worker       // Renormalize reference results.
408*4bdc9457SAndroid Build Coastguard Worker       std::transform(accumulators.cbegin(), accumulators.cend(), output_ref.begin(),
409*4bdc9457SAndroid Build Coastguard Worker         [this, output_scale, output_zero_point](int32_t x) -> double {
410*4bdc9457SAndroid Build Coastguard Worker           return std::max<double>(std::min<double>(double(x) / output_scale, double(qmax()) - output_zero_point), double(qmin()) - output_zero_point);
411*4bdc9457SAndroid Build Coastguard Worker         });
412*4bdc9457SAndroid Build Coastguard Worker 
413*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Fully Connected operator.
414*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
415*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t fully_connected_op = nullptr;
416*4bdc9457SAndroid Build Coastguard Worker 
417*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
418*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
419*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
420*4bdc9457SAndroid Build Coastguard Worker       };
421*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
422*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
423*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
424*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
425*4bdc9457SAndroid Build Coastguard Worker       }
426*4bdc9457SAndroid Build Coastguard Worker 
427*4bdc9457SAndroid Build Coastguard Worker       const xnn_status status = xnn_create_fully_connected_nc_qu8(
428*4bdc9457SAndroid Build Coastguard Worker           input_channels(), output_channels(),
429*4bdc9457SAndroid Build Coastguard Worker           input_stride(), output_stride(),
430*4bdc9457SAndroid Build Coastguard Worker           input_zero_point, 1.0f /* input scale */,
431*4bdc9457SAndroid Build Coastguard Worker           kernel_zero_point, 1.0f /* kernel scale */,
432*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
433*4bdc9457SAndroid Build Coastguard Worker           output_zero_point, output_scale, qmin(), qmax(),
434*4bdc9457SAndroid Build Coastguard Worker           transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
435*4bdc9457SAndroid Build Coastguard Worker           &caches,
436*4bdc9457SAndroid Build Coastguard Worker           &fully_connected_op);
437*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
438*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
439*4bdc9457SAndroid Build Coastguard Worker       }
440*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
441*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, fully_connected_op);
442*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
443*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
444*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
445*4bdc9457SAndroid Build Coastguard Worker       }
446*4bdc9457SAndroid Build Coastguard Worker 
447*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete fully_connected_op.
448*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
449*4bdc9457SAndroid Build Coastguard Worker 
450*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
451*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_fully_connected_nc_qu8(
452*4bdc9457SAndroid Build Coastguard Worker           fully_connected_op,
453*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
454*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
455*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
456*4bdc9457SAndroid Build Coastguard Worker 
457*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
458*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
459*4bdc9457SAndroid Build Coastguard Worker 
460*4bdc9457SAndroid Build Coastguard Worker       VerifyQU8(output, output_ref, double(output_zero_point));
461*4bdc9457SAndroid Build Coastguard Worker 
462*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
463*4bdc9457SAndroid Build Coastguard Worker         // Create another operator with the same weights cache.
464*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t fully_connected_op2 = nullptr;
465*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
466*4bdc9457SAndroid Build Coastguard Worker 
467*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
468*4bdc9457SAndroid Build Coastguard Worker                   xnn_create_fully_connected_nc_qu8(
469*4bdc9457SAndroid Build Coastguard Worker                       input_channels(), output_channels(), input_stride(),
470*4bdc9457SAndroid Build Coastguard Worker                       output_stride(), input_zero_point, 1.0f /* input scale */,
471*4bdc9457SAndroid Build Coastguard Worker                       kernel_zero_point, 1.0f /* kernel scale */, kernel.data(),
472*4bdc9457SAndroid Build Coastguard Worker                       has_bias() ? bias.data() : nullptr, output_zero_point,
473*4bdc9457SAndroid Build Coastguard Worker                       output_scale, qmin(), qmax(),
474*4bdc9457SAndroid Build Coastguard Worker                       transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
475*4bdc9457SAndroid Build Coastguard Worker                       &caches, &fully_connected_op2));
476*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, fully_connected_op2);
477*4bdc9457SAndroid Build Coastguard Worker 
478*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete fully_connected_op.
479*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)>
480*4bdc9457SAndroid Build Coastguard Worker             auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
481*4bdc9457SAndroid Build Coastguard Worker         std::vector<uint8_t> output2(output.size(), UINT8_C(0xA5));
482*4bdc9457SAndroid Build Coastguard Worker 
483*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
484*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_fully_connected_nc_qu8(
485*4bdc9457SAndroid Build Coastguard Worker                       fully_connected_op2, batch_size(), input.data(),
486*4bdc9457SAndroid Build Coastguard Worker                       output2.data(), nullptr /* thread pool */));
487*4bdc9457SAndroid Build Coastguard Worker 
488*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(
489*4bdc9457SAndroid Build Coastguard Worker             xnn_status_success,
490*4bdc9457SAndroid Build Coastguard Worker             xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));
491*4bdc9457SAndroid Build Coastguard Worker 
492*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
493*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
494*4bdc9457SAndroid Build Coastguard Worker 
495*4bdc9457SAndroid Build Coastguard Worker         VerifyQU8(output2, output_ref, double(output_zero_point));
496*4bdc9457SAndroid Build Coastguard Worker       }
497*4bdc9457SAndroid Build Coastguard Worker 
498*4bdc9457SAndroid Build Coastguard Worker     }
499*4bdc9457SAndroid Build Coastguard Worker   }
500*4bdc9457SAndroid Build Coastguard Worker 
VerifyQU8(const std::vector<uint8_t> & output,const std::vector<double> & output_ref,double output_zero_point)501*4bdc9457SAndroid Build Coastguard Worker   void VerifyQU8(const std::vector<uint8_t>& output,
502*4bdc9457SAndroid Build Coastguard Worker                  const std::vector<double>& output_ref,
503*4bdc9457SAndroid Build Coastguard Worker                  double output_zero_point) const {
504*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
505*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < output_channels(); c++) {
506*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(int32_t(output[i * output_stride() + c]), int32_t(qmax()))
507*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
508*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(int32_t(output[i * output_stride() + c]), int32_t(qmin()))
509*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
510*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(output_ref[i * output_channels() + c],
511*4bdc9457SAndroid Build Coastguard Worker                     double(output[i * output_stride() + c]) - output_zero_point,
512*4bdc9457SAndroid Build Coastguard Worker                     0.9)
513*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
514*4bdc9457SAndroid Build Coastguard Worker       }
515*4bdc9457SAndroid Build Coastguard Worker     }
516*4bdc9457SAndroid Build Coastguard Worker   }
517*4bdc9457SAndroid Build Coastguard Worker 
TestF32()518*4bdc9457SAndroid Build Coastguard Worker   void TestF32() const {
519*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_type(), WeightsType::Default);
520*4bdc9457SAndroid Build Coastguard Worker 
521*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
522*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
523*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
524*4bdc9457SAndroid Build Coastguard Worker 
525*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
526*4bdc9457SAndroid Build Coastguard Worker       (batch_size() - 1) * input_stride() + input_channels());
527*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel(output_channels() * input_channels());
528*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias(output_channels());
529*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output((batch_size() - 1) * output_stride() + output_channels());
530*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_channels());
531*4bdc9457SAndroid Build Coastguard Worker 
532*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
533*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); });
534*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); });
535*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
536*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), nanf(""));
537*4bdc9457SAndroid Build Coastguard Worker 
538*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
539*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
540*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
541*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
542*4bdc9457SAndroid Build Coastguard Worker             output_ref[i * output_channels() + oc] = bias[oc];
543*4bdc9457SAndroid Build Coastguard Worker           }
544*4bdc9457SAndroid Build Coastguard Worker         }
545*4bdc9457SAndroid Build Coastguard Worker       } else {
546*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
547*4bdc9457SAndroid Build Coastguard Worker       }
548*4bdc9457SAndroid Build Coastguard Worker       if (transpose_weights()) {
549*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
550*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
551*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
552*4bdc9457SAndroid Build Coastguard Worker               output_ref[i * output_channels() + oc] +=
553*4bdc9457SAndroid Build Coastguard Worker                 input[i * input_stride() + ic] * kernel[ic * output_channels() + oc];
554*4bdc9457SAndroid Build Coastguard Worker             }
555*4bdc9457SAndroid Build Coastguard Worker           }
556*4bdc9457SAndroid Build Coastguard Worker         }
557*4bdc9457SAndroid Build Coastguard Worker       } else {
558*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
559*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
560*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
561*4bdc9457SAndroid Build Coastguard Worker               output_ref[i * output_channels() + oc] +=
562*4bdc9457SAndroid Build Coastguard Worker                 input[i * input_stride() + ic] * kernel[oc * input_channels() + ic];
563*4bdc9457SAndroid Build Coastguard Worker             }
564*4bdc9457SAndroid Build Coastguard Worker           }
565*4bdc9457SAndroid Build Coastguard Worker         }
566*4bdc9457SAndroid Build Coastguard Worker       }
567*4bdc9457SAndroid Build Coastguard Worker 
568*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
569*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
570*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
571*4bdc9457SAndroid Build Coastguard Worker 
572*4bdc9457SAndroid Build Coastguard Worker       const float output_min = qmin() == 0 ? -std::numeric_limits<float>::infinity() :
573*4bdc9457SAndroid Build Coastguard Worker         accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
574*4bdc9457SAndroid Build Coastguard Worker       const float output_max = qmax() == 255 ? std::numeric_limits<float>::infinity() :
575*4bdc9457SAndroid Build Coastguard Worker         accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
576*4bdc9457SAndroid Build Coastguard Worker 
577*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
578*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
579*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
580*4bdc9457SAndroid Build Coastguard Worker       }
581*4bdc9457SAndroid Build Coastguard Worker 
582*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Fully Connected operator.
583*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
584*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t fully_connected_op = nullptr;
585*4bdc9457SAndroid Build Coastguard Worker 
586*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
587*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
588*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
589*4bdc9457SAndroid Build Coastguard Worker       };
590*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
591*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
592*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
593*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
594*4bdc9457SAndroid Build Coastguard Worker       }
595*4bdc9457SAndroid Build Coastguard Worker 
596*4bdc9457SAndroid Build Coastguard Worker       const xnn_status status = xnn_create_fully_connected_nc_f32(
597*4bdc9457SAndroid Build Coastguard Worker           input_channels(), output_channels(),
598*4bdc9457SAndroid Build Coastguard Worker           input_stride(), output_stride(),
599*4bdc9457SAndroid Build Coastguard Worker           kernel.data(), has_bias() ? bias.data() : nullptr,
600*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
601*4bdc9457SAndroid Build Coastguard Worker           transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
602*4bdc9457SAndroid Build Coastguard Worker           &caches,
603*4bdc9457SAndroid Build Coastguard Worker           &fully_connected_op);
604*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
605*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
606*4bdc9457SAndroid Build Coastguard Worker       }
607*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
608*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, fully_connected_op);
609*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
610*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
611*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
612*4bdc9457SAndroid Build Coastguard Worker       }
613*4bdc9457SAndroid Build Coastguard Worker 
614*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete fully_connected_op.
615*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
616*4bdc9457SAndroid Build Coastguard Worker 
617*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
618*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_fully_connected_nc_f32(
619*4bdc9457SAndroid Build Coastguard Worker           fully_connected_op,
620*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
621*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
622*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
623*4bdc9457SAndroid Build Coastguard Worker 
624*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
625*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
626*4bdc9457SAndroid Build Coastguard Worker 
627*4bdc9457SAndroid Build Coastguard Worker       VerifyF32(output, output_ref, output_max, output_min);
628*4bdc9457SAndroid Build Coastguard Worker 
629*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
630*4bdc9457SAndroid Build Coastguard Worker         // Create another operator with the same weights cache.
631*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t fully_connected_op2 = nullptr;
632*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
633*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
634*4bdc9457SAndroid Build Coastguard Worker                   xnn_create_fully_connected_nc_f32(
635*4bdc9457SAndroid Build Coastguard Worker                       input_channels(), output_channels(), input_stride(),
636*4bdc9457SAndroid Build Coastguard Worker                       output_stride(), kernel.data(),
637*4bdc9457SAndroid Build Coastguard Worker                       has_bias() ? bias.data() : nullptr, output_min,
638*4bdc9457SAndroid Build Coastguard Worker                       output_max,
639*4bdc9457SAndroid Build Coastguard Worker                       transpose_weights() ? XNN_FLAG_TRANSPOSE_WEIGHTS : 0,
640*4bdc9457SAndroid Build Coastguard Worker                       &caches, &fully_connected_op2));
641*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, fully_connected_op2);
642*4bdc9457SAndroid Build Coastguard Worker 
643*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
644*4bdc9457SAndroid Build Coastguard Worker 
645*4bdc9457SAndroid Build Coastguard Worker         std::vector<float> output2(output.size(), nanf(""));
646*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
647*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_fully_connected_nc_f32(
648*4bdc9457SAndroid Build Coastguard Worker                       fully_connected_op2,
649*4bdc9457SAndroid Build Coastguard Worker                       batch_size(),
650*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
651*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
652*4bdc9457SAndroid Build Coastguard Worker 
653*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
654*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));
655*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
656*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
657*4bdc9457SAndroid Build Coastguard Worker 
658*4bdc9457SAndroid Build Coastguard Worker         VerifyF32(output, output_ref, output_max, output_min);
659*4bdc9457SAndroid Build Coastguard Worker       }
660*4bdc9457SAndroid Build Coastguard Worker     }
661*4bdc9457SAndroid Build Coastguard Worker   }
662*4bdc9457SAndroid Build Coastguard Worker 
VerifyF32(const std::vector<float> & output,const std::vector<float> & output_ref,float output_max,float output_min)663*4bdc9457SAndroid Build Coastguard Worker   void VerifyF32(const std::vector<float>& output,
664*4bdc9457SAndroid Build Coastguard Worker                  const std::vector<float>& output_ref,
665*4bdc9457SAndroid Build Coastguard Worker                  float output_max,
666*4bdc9457SAndroid Build Coastguard Worker                  float output_min) const {
667*4bdc9457SAndroid Build Coastguard Worker     // Verify results.
668*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
669*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < output_channels(); c++) {
670*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(output[i * output_stride() + c], output_max)
671*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
672*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(output[i * output_stride() + c], output_min)
673*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
674*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(output_ref[i * output_channels() + c],
675*4bdc9457SAndroid Build Coastguard Worker                     output[i * output_stride() + c],
676*4bdc9457SAndroid Build Coastguard Worker                     1.0e-4 * std::abs(output_ref[i * output_channels() + c]))
677*4bdc9457SAndroid Build Coastguard Worker             << "batch index = " << i << ", channel = " << c;
678*4bdc9457SAndroid Build Coastguard Worker       }
679*4bdc9457SAndroid Build Coastguard Worker     }
680*4bdc9457SAndroid Build Coastguard Worker   }
681*4bdc9457SAndroid Build Coastguard Worker 
TestF16()682*4bdc9457SAndroid Build Coastguard Worker   void TestF16() const {
683*4bdc9457SAndroid Build Coastguard Worker     switch (weights_type()) {
684*4bdc9457SAndroid Build Coastguard Worker       case WeightsType::Default:
685*4bdc9457SAndroid Build Coastguard Worker         break;
686*4bdc9457SAndroid Build Coastguard Worker       case WeightsType::FP32:
687*4bdc9457SAndroid Build Coastguard Worker         break;
688*4bdc9457SAndroid Build Coastguard Worker       default:
689*4bdc9457SAndroid Build Coastguard Worker         GTEST_FAIL() << "unexpected weights type";
690*4bdc9457SAndroid Build Coastguard Worker     }
691*4bdc9457SAndroid Build Coastguard Worker 
692*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
693*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
694*4bdc9457SAndroid Build Coastguard Worker     std::uniform_real_distribution<float> f32dist(0.1f, 1.0f);
695*4bdc9457SAndroid Build Coastguard Worker 
696*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
697*4bdc9457SAndroid Build Coastguard Worker       (batch_size() - 1) * input_stride() + input_channels());
698*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> kernel(output_channels() * input_channels());
699*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> kernel_as_float(kernel.size());
700*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> bias(output_channels());
701*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> bias_as_float(bias.size());
702*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> output((batch_size() - 1) * output_stride() + output_channels());
703*4bdc9457SAndroid Build Coastguard Worker     std::vector<float> output_ref(batch_size() * output_channels());
704*4bdc9457SAndroid Build Coastguard Worker 
705*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
706*4bdc9457SAndroid Build Coastguard Worker       std::generate(input.begin(), input.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
707*4bdc9457SAndroid Build Coastguard Worker       std::generate(kernel.begin(), kernel.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
708*4bdc9457SAndroid Build Coastguard Worker       std::transform(kernel.cbegin(), kernel.cend(), kernel_as_float.begin(), fp16_ieee_to_fp32_value);
709*4bdc9457SAndroid Build Coastguard Worker       std::generate(bias.begin(), bias.end(), [&]() { return fp16_ieee_from_fp32_value(f32dist(rng)); });
710*4bdc9457SAndroid Build Coastguard Worker       std::transform(bias.cbegin(), bias.cend(), bias_as_float.begin(), fp16_ieee_to_fp32_value);
711*4bdc9457SAndroid Build Coastguard Worker       std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);
712*4bdc9457SAndroid Build Coastguard Worker 
713*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results, without renormalization.
714*4bdc9457SAndroid Build Coastguard Worker       if (has_bias()) {
715*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
716*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
717*4bdc9457SAndroid Build Coastguard Worker             output_ref[i * output_channels() + oc] = fp16_ieee_to_fp32_value(bias[oc]);
718*4bdc9457SAndroid Build Coastguard Worker           }
719*4bdc9457SAndroid Build Coastguard Worker         }
720*4bdc9457SAndroid Build Coastguard Worker       } else {
721*4bdc9457SAndroid Build Coastguard Worker         std::fill(output_ref.begin(), output_ref.end(), 0.0f);
722*4bdc9457SAndroid Build Coastguard Worker       }
723*4bdc9457SAndroid Build Coastguard Worker       if (transpose_weights()) {
724*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
725*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
726*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
727*4bdc9457SAndroid Build Coastguard Worker               output_ref[i * output_channels() + oc] +=
728*4bdc9457SAndroid Build Coastguard Worker                 fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[ic * output_channels() + oc]);
729*4bdc9457SAndroid Build Coastguard Worker             }
730*4bdc9457SAndroid Build Coastguard Worker           }
731*4bdc9457SAndroid Build Coastguard Worker         }
732*4bdc9457SAndroid Build Coastguard Worker       } else {
733*4bdc9457SAndroid Build Coastguard Worker         for (size_t i = 0; i < batch_size(); i++) {
734*4bdc9457SAndroid Build Coastguard Worker           for (size_t oc = 0; oc < output_channels(); oc++) {
735*4bdc9457SAndroid Build Coastguard Worker             for (size_t ic = 0; ic < input_channels(); ic++) {
736*4bdc9457SAndroid Build Coastguard Worker               output_ref[i * output_channels() + oc] +=
737*4bdc9457SAndroid Build Coastguard Worker                 fp16_ieee_to_fp32_value(input[i * input_stride() + ic]) * fp16_ieee_to_fp32_value(kernel[oc * input_channels() + ic]);
738*4bdc9457SAndroid Build Coastguard Worker             }
739*4bdc9457SAndroid Build Coastguard Worker           }
740*4bdc9457SAndroid Build Coastguard Worker         }
741*4bdc9457SAndroid Build Coastguard Worker       }
742*4bdc9457SAndroid Build Coastguard Worker 
743*4bdc9457SAndroid Build Coastguard Worker       // Compute clamping parameters.
744*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_min = *std::min_element(output_ref.cbegin(), output_ref.cend());
745*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_max = *std::max_element(output_ref.cbegin(), output_ref.cend());
746*4bdc9457SAndroid Build Coastguard Worker       const float accumulated_range = accumulated_max - accumulated_min;
747*4bdc9457SAndroid Build Coastguard Worker       const float scaled_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + accumulated_range / 255.0f * float(qmin())));
748*4bdc9457SAndroid Build Coastguard Worker       const float scaled_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - accumulated_range / 255.0f * float(255 - qmax())));
749*4bdc9457SAndroid Build Coastguard Worker       const float output_min = scaled_min == scaled_max ? -std::numeric_limits<float>::infinity() : scaled_min;
750*4bdc9457SAndroid Build Coastguard Worker       const float output_max = scaled_min == scaled_max ? +std::numeric_limits<float>::infinity() : scaled_max;
751*4bdc9457SAndroid Build Coastguard Worker 
752*4bdc9457SAndroid Build Coastguard Worker       // Clamp reference results.
753*4bdc9457SAndroid Build Coastguard Worker       for (float& value : output_ref) {
754*4bdc9457SAndroid Build Coastguard Worker         value = std::max(std::min(value, output_max), output_min);
755*4bdc9457SAndroid Build Coastguard Worker       }
756*4bdc9457SAndroid Build Coastguard Worker 
757*4bdc9457SAndroid Build Coastguard Worker       // Create, setup, run, and destroy Fully Connected operator.
758*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
759*4bdc9457SAndroid Build Coastguard Worker       xnn_operator_t fully_connected_op = nullptr;
760*4bdc9457SAndroid Build Coastguard Worker 
761*4bdc9457SAndroid Build Coastguard Worker       xnn_caches caches = {
762*4bdc9457SAndroid Build Coastguard Worker         .code_cache = NULL,
763*4bdc9457SAndroid Build Coastguard Worker         .weights_cache = NULL,
764*4bdc9457SAndroid Build Coastguard Worker       };
765*4bdc9457SAndroid Build Coastguard Worker       xnn_weights_cache weights_cache;
766*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
767*4bdc9457SAndroid Build Coastguard Worker         xnn_init_weights_cache(&weights_cache);
768*4bdc9457SAndroid Build Coastguard Worker         caches.weights_cache = &weights_cache;
769*4bdc9457SAndroid Build Coastguard Worker       }
770*4bdc9457SAndroid Build Coastguard Worker 
771*4bdc9457SAndroid Build Coastguard Worker       const void* kernel_data = kernel.data();
772*4bdc9457SAndroid Build Coastguard Worker       const void* bias_data = bias.data();
773*4bdc9457SAndroid Build Coastguard Worker       if (weights_type() == WeightsType::FP32) {
774*4bdc9457SAndroid Build Coastguard Worker         kernel_data = kernel_as_float.data();
775*4bdc9457SAndroid Build Coastguard Worker         bias_data = bias_as_float.data();
776*4bdc9457SAndroid Build Coastguard Worker       }
777*4bdc9457SAndroid Build Coastguard Worker       uint32_t flags = 0;
778*4bdc9457SAndroid Build Coastguard Worker       if (transpose_weights()) {
779*4bdc9457SAndroid Build Coastguard Worker         flags |= XNN_FLAG_TRANSPOSE_WEIGHTS;
780*4bdc9457SAndroid Build Coastguard Worker       }
781*4bdc9457SAndroid Build Coastguard Worker       if (weights_type() == WeightsType::FP32) {
782*4bdc9457SAndroid Build Coastguard Worker         flags |= XNN_FLAG_FP32_STATIC_WEIGHTS;
783*4bdc9457SAndroid Build Coastguard Worker       }
784*4bdc9457SAndroid Build Coastguard Worker       const xnn_status status = xnn_create_fully_connected_nc_f16(
785*4bdc9457SAndroid Build Coastguard Worker           input_channels(), output_channels(),
786*4bdc9457SAndroid Build Coastguard Worker           input_stride(), output_stride(),
787*4bdc9457SAndroid Build Coastguard Worker           kernel_data, has_bias() ? bias_data : nullptr,
788*4bdc9457SAndroid Build Coastguard Worker           output_min, output_max,
789*4bdc9457SAndroid Build Coastguard Worker           flags,
790*4bdc9457SAndroid Build Coastguard Worker           &caches,
791*4bdc9457SAndroid Build Coastguard Worker           &fully_connected_op);
792*4bdc9457SAndroid Build Coastguard Worker       if (status == xnn_status_unsupported_hardware) {
793*4bdc9457SAndroid Build Coastguard Worker         GTEST_SKIP();
794*4bdc9457SAndroid Build Coastguard Worker       }
795*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success, status);
796*4bdc9457SAndroid Build Coastguard Worker       ASSERT_NE(nullptr, fully_connected_op);
797*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
798*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
799*4bdc9457SAndroid Build Coastguard Worker                   xnn_finalize_weights_cache(&weights_cache, xnn_weights_cache_finalization_kind_soft));
800*4bdc9457SAndroid Build Coastguard Worker       }
801*4bdc9457SAndroid Build Coastguard Worker 
802*4bdc9457SAndroid Build Coastguard Worker       // Smart pointer to automatically delete fully_connected_op.
803*4bdc9457SAndroid Build Coastguard Worker       std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op, xnn_delete_operator);
804*4bdc9457SAndroid Build Coastguard Worker 
805*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
806*4bdc9457SAndroid Build Coastguard Worker         xnn_setup_fully_connected_nc_f16(
807*4bdc9457SAndroid Build Coastguard Worker           fully_connected_op,
808*4bdc9457SAndroid Build Coastguard Worker           batch_size(),
809*4bdc9457SAndroid Build Coastguard Worker           input.data(), output.data(),
810*4bdc9457SAndroid Build Coastguard Worker           nullptr /* thread pool */));
811*4bdc9457SAndroid Build Coastguard Worker 
812*4bdc9457SAndroid Build Coastguard Worker       ASSERT_EQ(xnn_status_success,
813*4bdc9457SAndroid Build Coastguard Worker         xnn_run_operator(fully_connected_op, nullptr /* thread pool */));
814*4bdc9457SAndroid Build Coastguard Worker 
815*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
816*4bdc9457SAndroid Build Coastguard Worker       VerifyF16(output, output_ref, output_max, output_min);
817*4bdc9457SAndroid Build Coastguard Worker 
818*4bdc9457SAndroid Build Coastguard Worker       if (use_weights_cache()) {
819*4bdc9457SAndroid Build Coastguard Worker         xnn_operator_t fully_connected_op2 = nullptr;
820*4bdc9457SAndroid Build Coastguard Worker         size_t old_weights_cache_size = weights_cache.cache.weights.size;
821*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
822*4bdc9457SAndroid Build Coastguard Worker                   xnn_create_fully_connected_nc_f16(
823*4bdc9457SAndroid Build Coastguard Worker                       input_channels(), output_channels(), input_stride(),
824*4bdc9457SAndroid Build Coastguard Worker                       output_stride(), kernel_data,
825*4bdc9457SAndroid Build Coastguard Worker                       has_bias() ? bias_data : nullptr, output_min, output_max,
826*4bdc9457SAndroid Build Coastguard Worker                       flags, &caches, &fully_connected_op2));
827*4bdc9457SAndroid Build Coastguard Worker         if (status == xnn_status_unsupported_hardware) {
828*4bdc9457SAndroid Build Coastguard Worker           GTEST_SKIP();
829*4bdc9457SAndroid Build Coastguard Worker         }
830*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success, status);
831*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NE(nullptr, fully_connected_op2);
832*4bdc9457SAndroid Build Coastguard Worker 
833*4bdc9457SAndroid Build Coastguard Worker         // Smart pointer to automatically delete fully_connected_op2.
834*4bdc9457SAndroid Build Coastguard Worker         std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_fully_connected_op(fully_connected_op2, xnn_delete_operator);
835*4bdc9457SAndroid Build Coastguard Worker         std::vector<uint16_t> output2(output.size(), UINT16_C(0x7E00) /* NaN */);
836*4bdc9457SAndroid Build Coastguard Worker 
837*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
838*4bdc9457SAndroid Build Coastguard Worker                   xnn_setup_fully_connected_nc_f16(
839*4bdc9457SAndroid Build Coastguard Worker                       fully_connected_op2,
840*4bdc9457SAndroid Build Coastguard Worker                       batch_size(),
841*4bdc9457SAndroid Build Coastguard Worker                       input.data(), output2.data(),
842*4bdc9457SAndroid Build Coastguard Worker                       nullptr /* thread pool */));
843*4bdc9457SAndroid Build Coastguard Worker 
844*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(xnn_status_success,
845*4bdc9457SAndroid Build Coastguard Worker                   xnn_run_operator(fully_connected_op2, nullptr /* thread pool */));
846*4bdc9457SAndroid Build Coastguard Worker 
847*4bdc9457SAndroid Build Coastguard Worker         // Verify results.
848*4bdc9457SAndroid Build Coastguard Worker         VerifyF16(output2, output_ref, output_max, output_min);
849*4bdc9457SAndroid Build Coastguard Worker         VerifyWeightsCache(weights_cache, old_weights_cache_size);
850*4bdc9457SAndroid Build Coastguard Worker         xnn_release_weights_cache(&weights_cache);
851*4bdc9457SAndroid Build Coastguard Worker       }
852*4bdc9457SAndroid Build Coastguard Worker     }
853*4bdc9457SAndroid Build Coastguard Worker   }
854*4bdc9457SAndroid Build Coastguard Worker 
VerifyF16(const std::vector<uint16_t> & output,const std::vector<float> & output_ref,const float output_max,const float output_min)855*4bdc9457SAndroid Build Coastguard Worker   void VerifyF16(const std::vector<uint16_t>& output,
856*4bdc9457SAndroid Build Coastguard Worker                  const std::vector<float>& output_ref,
857*4bdc9457SAndroid Build Coastguard Worker                  const float output_max,
858*4bdc9457SAndroid Build Coastguard Worker                  const float output_min) const {
859*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < batch_size(); i++) {
860*4bdc9457SAndroid Build Coastguard Worker       for (size_t c = 0; c < output_channels(); c++) {
861*4bdc9457SAndroid Build Coastguard Worker         ASSERT_LE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_max)
862*4bdc9457SAndroid Build Coastguard Worker           << "batch index = " << i << ", channel = " << c;
863*4bdc9457SAndroid Build Coastguard Worker         ASSERT_GE(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_min)
864*4bdc9457SAndroid Build Coastguard Worker           << "batch index = " << i << ", channel = " << c;
865*4bdc9457SAndroid Build Coastguard Worker         ASSERT_NEAR(
866*4bdc9457SAndroid Build Coastguard Worker             output_ref[i * output_channels() + c],
867*4bdc9457SAndroid Build Coastguard Worker             fp16_ieee_to_fp32_value(output[i * output_stride() + c]),
868*4bdc9457SAndroid Build Coastguard Worker             1.0e-2f * std::abs(output_ref[i * output_channels() + c]))
869*4bdc9457SAndroid Build Coastguard Worker           << "batch index = " << i << ", channel = " << c;
870*4bdc9457SAndroid Build Coastguard Worker       }
871*4bdc9457SAndroid Build Coastguard Worker     }
872*4bdc9457SAndroid Build Coastguard Worker   }
873*4bdc9457SAndroid Build Coastguard Worker 
VerifyWeightsCache(const xnn_weights_cache & weights_cache,size_t old_size)874*4bdc9457SAndroid Build Coastguard Worker   void VerifyWeightsCache(const xnn_weights_cache& weights_cache, size_t old_size) const {
875*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(weights_cache.cache.hits, 1);
876*4bdc9457SAndroid Build Coastguard Worker     // Ensure that we did not write more weights to the cache because it was a cache hit.
877*4bdc9457SAndroid Build Coastguard Worker     ASSERT_EQ(old_size, weights_cache.cache.weights.size);
878*4bdc9457SAndroid Build Coastguard Worker   };
879*4bdc9457SAndroid Build Coastguard Worker 
 private:
  // Number of input channels: the reduction dimension of the kernel in the
  // reference computation.
  size_t input_channels_{1};
  // Element distance between consecutive batch rows of the input buffer.
  // NOTE(review): 0 presumably means "use input_channels()" — confirm in the
  // input_stride() accessor.
  size_t input_stride_{0};
  // Number of output channels produced per batch row.
  size_t output_channels_{1};
  // Element distance between consecutive batch rows of the output buffer.
  // NOTE(review): 0 presumably means "use output_channels()" — confirm in the
  // output_stride() accessor.
  size_t output_stride_{0};
  // Number of rows (batch entries) passed through the operator.
  size_t batch_size_{1};
  // Clamping controls on a 0..255 scale: they select what fraction of the
  // accumulated reference range (divided by 255) is cut off at the bottom
  // (qmin) and top (255 - qmax) when deriving output_min / output_max.
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  // When true, the kernel is indexed as [input_channels][output_channels] in
  // the reference and XNN_FLAG_TRANSPOSE_WEIGHTS is passed to the operator.
  bool transpose_weights_{false};
  // When false, nullptr is passed instead of the bias buffer.
  bool has_bias_{true};
  // Type of the static weights handed to the operator; for FP32, the float
  // copies of kernel/bias are passed together with XNN_FLAG_FP32_STATIC_WEIGHTS.
  WeightsType weights_type_{WeightsType::Default};
  // When true, operators are created with a weights cache and a second
  // operator with the same weights is expected to hit that cache.
  bool use_weights_cache_{false};
  // Test iteration count (presumably the number of randomized repetitions;
  // the loop that reads it is outside this view — confirm).
  size_t iterations_{1};
893*4bdc9457SAndroid Build Coastguard Worker };
894