1*4bdc9457SAndroid Build Coastguard Worker #include "gemm-microkernel-tester.h"
2*4bdc9457SAndroid Build Coastguard Worker
3*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
4*4bdc9457SAndroid Build Coastguard Worker
5*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
6*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
7*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
8*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
9*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
10*4bdc9457SAndroid Build Coastguard Worker #include <limits>
11*4bdc9457SAndroid Build Coastguard Worker #include <numeric>
12*4bdc9457SAndroid Build Coastguard Worker #include <random>
13*4bdc9457SAndroid Build Coastguard Worker #include <vector>
14*4bdc9457SAndroid Build Coastguard Worker
15*4bdc9457SAndroid Build Coastguard Worker #include <fp16.h>
16*4bdc9457SAndroid Build Coastguard Worker
17*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/allocator.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/pack.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
22*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microparams-init.h>
23*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/requantization.h>
24*4bdc9457SAndroid Build Coastguard Worker
25*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_qu8_gemm_minmax_ukernel_function gemm,xnn_init_qu8_conv_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize) const26*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
27*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_gemm_minmax_ukernel_function gemm,
28*4bdc9457SAndroid Build Coastguard Worker xnn_init_qu8_conv_minmax_params_fn init_params,
29*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_requantize_fn requantize) const
30*4bdc9457SAndroid Build Coastguard Worker {
31*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
32*4bdc9457SAndroid Build Coastguard Worker
33*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
34*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
35*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
36*4bdc9457SAndroid Build Coastguard Worker auto u8rng = std::bind(
37*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
38*4bdc9457SAndroid Build Coastguard Worker
39*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
40*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> b(n() * k());
41*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
42*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t));
43*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
44*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
45*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> c_ref(m() * n());
46*4bdc9457SAndroid Build Coastguard Worker
47*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
48*4bdc9457SAndroid Build Coastguard Worker do {
49*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(u8rng));
50*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
51*4bdc9457SAndroid Build Coastguard Worker do {
52*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(u8rng));
53*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
54*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
55*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
56*4bdc9457SAndroid Build Coastguard Worker
57*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
58*4bdc9457SAndroid Build Coastguard Worker const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
59*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qu8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
60*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0, &packing_params);
61*4bdc9457SAndroid Build Coastguard Worker
62*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
63*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
64*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
65*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
66*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
67*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
68*4bdc9457SAndroid Build Coastguard Worker (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
69*4bdc9457SAndroid Build Coastguard Worker (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
70*4bdc9457SAndroid Build Coastguard Worker }
71*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
72*4bdc9457SAndroid Build Coastguard Worker }
73*4bdc9457SAndroid Build Coastguard Worker }
74*4bdc9457SAndroid Build Coastguard Worker
75*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
76*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
77*4bdc9457SAndroid Build Coastguard Worker const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
78*4bdc9457SAndroid Build Coastguard Worker const uint8_t c_zero_point = uint8_t(std::max(std::min(
79*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
80*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
81*4bdc9457SAndroid Build Coastguard Worker
82*4bdc9457SAndroid Build Coastguard Worker const float requantization_scale = 1.0f / float(c_scale);
83*4bdc9457SAndroid Build Coastguard Worker union xnn_qu8_conv_minmax_params quantization_params;
84*4bdc9457SAndroid Build Coastguard Worker init_params(&quantization_params,
85*4bdc9457SAndroid Build Coastguard Worker b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
86*4bdc9457SAndroid Build Coastguard Worker
87*4bdc9457SAndroid Build Coastguard Worker gemm(
88*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(),
89*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(uint8_t),
90*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
91*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
92*4bdc9457SAndroid Build Coastguard Worker &quantization_params);
93*4bdc9457SAndroid Build Coastguard Worker
94*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
95*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
96*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
97*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], requantization_scale, c_zero_point, qmin(), qmax());
98*4bdc9457SAndroid Build Coastguard Worker }
99*4bdc9457SAndroid Build Coastguard Worker }
100*4bdc9457SAndroid Build Coastguard Worker
101*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
102*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
103*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
104*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
105*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
106*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j]
107*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
108*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
109*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
110*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
111*4bdc9457SAndroid Build Coastguard Worker }
112*4bdc9457SAndroid Build Coastguard Worker }
113*4bdc9457SAndroid Build Coastguard Worker }
114*4bdc9457SAndroid Build Coastguard Worker }
115*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_qu8_igemm_minmax_ukernel_function igemm,xnn_init_qu8_conv_minmax_params_fn init_params,xnn_qu8_requantize_fn requantize)116*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
117*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_igemm_minmax_ukernel_function igemm,
118*4bdc9457SAndroid Build Coastguard Worker xnn_init_qu8_conv_minmax_params_fn init_params,
119*4bdc9457SAndroid Build Coastguard Worker xnn_qu8_requantize_fn requantize)
120*4bdc9457SAndroid Build Coastguard Worker {
121*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
122*4bdc9457SAndroid Build Coastguard Worker
123*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
124*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
125*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
126*4bdc9457SAndroid Build Coastguard Worker auto u8rng = std::bind(
127*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), std::ref(rng));
128*4bdc9457SAndroid Build Coastguard Worker
129*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
130*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> b(n() * ks() * k());
131*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(uint8_t));
132*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
133*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
134*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
135*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> c_ref(m() * n());
136*4bdc9457SAndroid Build Coastguard Worker std::vector<uint8_t> junk(k() + 8);
137*4bdc9457SAndroid Build Coastguard Worker std::vector<const uint8_t*> im2col(mr() * ks());
138*4bdc9457SAndroid Build Coastguard Worker
139*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), 0xA5);
140*4bdc9457SAndroid Build Coastguard Worker
141*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
142*4bdc9457SAndroid Build Coastguard Worker do {
143*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(u8rng));
144*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
145*4bdc9457SAndroid Build Coastguard Worker do {
146*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(u8rng));
147*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
148*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
149*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
150*4bdc9457SAndroid Build Coastguard Worker
151*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
152*4bdc9457SAndroid Build Coastguard Worker const xnn_qu8_packing_params packing_params = { a_zero_point(), b_zero_point() };
153*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qu8_conv_goki_w(
154*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
155*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);
156*4bdc9457SAndroid Build Coastguard Worker
157*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
158*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
159*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
160*4bdc9457SAndroid Build Coastguard Worker }
161*4bdc9457SAndroid Build Coastguard Worker }
162*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
163*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
164*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
165*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
166*4bdc9457SAndroid Build Coastguard Worker }
167*4bdc9457SAndroid Build Coastguard Worker }
168*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
169*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
170*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
171*4bdc9457SAndroid Build Coastguard Worker }
172*4bdc9457SAndroid Build Coastguard Worker }
173*4bdc9457SAndroid Build Coastguard Worker
174*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
175*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
176*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
177*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
178*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
179*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
180*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
181*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
182*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point())) *
183*4bdc9457SAndroid Build Coastguard Worker (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
184*4bdc9457SAndroid Build Coastguard Worker } else {
185*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
186*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point())) *
187*4bdc9457SAndroid Build Coastguard Worker (int32_t(b[(n_index * ks() + ks_index) * k() + k_index]) - int32_t(b_zero_point()));
188*4bdc9457SAndroid Build Coastguard Worker }
189*4bdc9457SAndroid Build Coastguard Worker }
190*4bdc9457SAndroid Build Coastguard Worker }
191*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
192*4bdc9457SAndroid Build Coastguard Worker }
193*4bdc9457SAndroid Build Coastguard Worker }
194*4bdc9457SAndroid Build Coastguard Worker
195*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
196*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
197*4bdc9457SAndroid Build Coastguard Worker const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
198*4bdc9457SAndroid Build Coastguard Worker const uint8_t c_zero_point = uint8_t(std::max(std::min(
199*4bdc9457SAndroid Build Coastguard Worker lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
200*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));
201*4bdc9457SAndroid Build Coastguard Worker
202*4bdc9457SAndroid Build Coastguard Worker const float requantization_scale = 1.0f / float(c_scale);
203*4bdc9457SAndroid Build Coastguard Worker union xnn_qu8_conv_minmax_params quantization_params;
204*4bdc9457SAndroid Build Coastguard Worker init_params(&quantization_params,
205*4bdc9457SAndroid Build Coastguard Worker b_zero_point(), requantization_scale, c_zero_point, qmin(), qmax());
206*4bdc9457SAndroid Build Coastguard Worker
207*4bdc9457SAndroid Build Coastguard Worker const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
208*4bdc9457SAndroid Build Coastguard Worker
209*4bdc9457SAndroid Build Coastguard Worker igemm(
210*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(), ks() * mr() * sizeof(void*),
211*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
212*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
213*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(uint8_t), zero_pointer,
214*4bdc9457SAndroid Build Coastguard Worker &quantization_params);
215*4bdc9457SAndroid Build Coastguard Worker
216*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
217*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
218*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
219*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], requantization_scale, c_zero_point, qmin(), qmax());
220*4bdc9457SAndroid Build Coastguard Worker }
221*4bdc9457SAndroid Build Coastguard Worker }
222*4bdc9457SAndroid Build Coastguard Worker
223*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
224*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
225*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
226*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
227*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
228*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
229*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
230*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
231*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
232*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
233*4bdc9457SAndroid Build Coastguard Worker }
234*4bdc9457SAndroid Build Coastguard Worker }
235*4bdc9457SAndroid Build Coastguard Worker }
236*4bdc9457SAndroid Build Coastguard Worker }
237*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_qc8_gemm_minmax_ukernel_function gemm,xnn_init_qc8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const238*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
239*4bdc9457SAndroid Build Coastguard Worker xnn_qc8_gemm_minmax_ukernel_function gemm,
240*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_conv_minmax_params_fn init_params,
241*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
242*4bdc9457SAndroid Build Coastguard Worker {
243*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
244*4bdc9457SAndroid Build Coastguard Worker
245*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
246*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
247*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
248*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
249*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
250*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
251*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
252*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
253*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
254*4bdc9457SAndroid Build Coastguard Worker
255*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
256*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * k());
257*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
258*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
259*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
260*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
261*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
262*4bdc9457SAndroid Build Coastguard Worker std::vector<float> scale(n());
263*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
264*4bdc9457SAndroid Build Coastguard Worker
265*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
266*4bdc9457SAndroid Build Coastguard Worker do {
267*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
268*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
269*4bdc9457SAndroid Build Coastguard Worker do {
270*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
271*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
272*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
273*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
274*4bdc9457SAndroid Build Coastguard Worker
275*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
276*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
277*4bdc9457SAndroid Build Coastguard Worker if (extended_weights()) {
278*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
279*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_xw.data(), nr() * sizeof(float), &packing_params);
280*4bdc9457SAndroid Build Coastguard Worker } else {
281*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
282*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
283*4bdc9457SAndroid Build Coastguard Worker }
284*4bdc9457SAndroid Build Coastguard Worker
285*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
286*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
287*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
288*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
289*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
290*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
291*4bdc9457SAndroid Build Coastguard Worker (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
292*4bdc9457SAndroid Build Coastguard Worker int32_t(b[n_index * k() + k_index]);
293*4bdc9457SAndroid Build Coastguard Worker }
294*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
295*4bdc9457SAndroid Build Coastguard Worker }
296*4bdc9457SAndroid Build Coastguard Worker }
297*4bdc9457SAndroid Build Coastguard Worker
298*4bdc9457SAndroid Build Coastguard Worker const int8_t c_zero_point = -1;
299*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
300*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_min = acc[n_index];
301*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_max = acc[n_index];
302*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
303*4bdc9457SAndroid Build Coastguard Worker accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
304*4bdc9457SAndroid Build Coastguard Worker accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
305*4bdc9457SAndroid Build Coastguard Worker }
306*4bdc9457SAndroid Build Coastguard Worker const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
307*4bdc9457SAndroid Build Coastguard Worker const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
308*4bdc9457SAndroid Build Coastguard Worker scale[n_index] = 1.0f / c_scale;
309*4bdc9457SAndroid Build Coastguard Worker }
310*4bdc9457SAndroid Build Coastguard Worker
311*4bdc9457SAndroid Build Coastguard Worker if (extended_weights()) {
312*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_fp32_params(
313*4bdc9457SAndroid Build Coastguard Worker n(), nr(),
314*4bdc9457SAndroid Build Coastguard Worker nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
315*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
316*4bdc9457SAndroid Build Coastguard Worker } else {
317*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_fp32_params(
318*4bdc9457SAndroid Build Coastguard Worker n(), nr(),
319*4bdc9457SAndroid Build Coastguard Worker nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
320*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
321*4bdc9457SAndroid Build Coastguard Worker }
322*4bdc9457SAndroid Build Coastguard Worker
323*4bdc9457SAndroid Build Coastguard Worker union xnn_qc8_conv_minmax_params minmax_params;
324*4bdc9457SAndroid Build Coastguard Worker init_params(&minmax_params,
325*4bdc9457SAndroid Build Coastguard Worker c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
326*4bdc9457SAndroid Build Coastguard Worker
327*4bdc9457SAndroid Build Coastguard Worker gemm(
328*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(),
329*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(int8_t),
330*4bdc9457SAndroid Build Coastguard Worker extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
331*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
332*4bdc9457SAndroid Build Coastguard Worker &minmax_params);
333*4bdc9457SAndroid Build Coastguard Worker
334*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
335*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
336*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
337*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
338*4bdc9457SAndroid Build Coastguard Worker }
339*4bdc9457SAndroid Build Coastguard Worker }
340*4bdc9457SAndroid Build Coastguard Worker
341*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
342*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
343*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
344*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
345*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
346*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
347*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
348*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
349*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
350*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
351*4bdc9457SAndroid Build Coastguard Worker }
352*4bdc9457SAndroid Build Coastguard Worker }
353*4bdc9457SAndroid Build Coastguard Worker }
354*4bdc9457SAndroid Build Coastguard Worker }
355*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_qc8_igemm_minmax_ukernel_function igemm,xnn_init_qc8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const356*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
357*4bdc9457SAndroid Build Coastguard Worker xnn_qc8_igemm_minmax_ukernel_function igemm,
358*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_conv_minmax_params_fn init_params,
359*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
360*4bdc9457SAndroid Build Coastguard Worker {
361*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
362*4bdc9457SAndroid Build Coastguard Worker
363*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
364*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
365*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
366*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
367*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
368*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
369*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
370*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
371*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
372*4bdc9457SAndroid Build Coastguard Worker
373*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
374*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * ks() * k());
375*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
376*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
377*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
378*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
379*4bdc9457SAndroid Build Coastguard Worker std::vector<float> scale(n());
380*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
381*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> junk(k() + 8);
382*4bdc9457SAndroid Build Coastguard Worker std::vector<const int8_t*> im2col(mr() * ks());
383*4bdc9457SAndroid Build Coastguard Worker
384*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), 0xA5);
385*4bdc9457SAndroid Build Coastguard Worker
386*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
387*4bdc9457SAndroid Build Coastguard Worker do {
388*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
389*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
390*4bdc9457SAndroid Build Coastguard Worker do {
391*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
392*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
393*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
394*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
395*4bdc9457SAndroid Build Coastguard Worker
396*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
397*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
398*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_conv_goki_w(
399*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
400*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
401*4bdc9457SAndroid Build Coastguard Worker
402*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
403*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
404*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
405*4bdc9457SAndroid Build Coastguard Worker }
406*4bdc9457SAndroid Build Coastguard Worker }
407*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
408*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
409*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
410*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
411*4bdc9457SAndroid Build Coastguard Worker }
412*4bdc9457SAndroid Build Coastguard Worker }
413*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
414*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
415*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
416*4bdc9457SAndroid Build Coastguard Worker }
417*4bdc9457SAndroid Build Coastguard Worker }
418*4bdc9457SAndroid Build Coastguard Worker
419*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
420*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
421*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
422*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
423*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
424*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
425*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
426*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
427*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
428*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
429*4bdc9457SAndroid Build Coastguard Worker } else {
430*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
431*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
432*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
433*4bdc9457SAndroid Build Coastguard Worker }
434*4bdc9457SAndroid Build Coastguard Worker }
435*4bdc9457SAndroid Build Coastguard Worker }
436*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
437*4bdc9457SAndroid Build Coastguard Worker }
438*4bdc9457SAndroid Build Coastguard Worker }
439*4bdc9457SAndroid Build Coastguard Worker
440*4bdc9457SAndroid Build Coastguard Worker const int8_t c_zero_point = -1;
441*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
442*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_min = acc[n_index];
443*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_max = acc[n_index];
444*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
445*4bdc9457SAndroid Build Coastguard Worker accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
446*4bdc9457SAndroid Build Coastguard Worker accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
447*4bdc9457SAndroid Build Coastguard Worker }
448*4bdc9457SAndroid Build Coastguard Worker const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
449*4bdc9457SAndroid Build Coastguard Worker const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
450*4bdc9457SAndroid Build Coastguard Worker scale[n_index] = 1.0f / c_scale;
451*4bdc9457SAndroid Build Coastguard Worker }
452*4bdc9457SAndroid Build Coastguard Worker
453*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_fp32_params(
454*4bdc9457SAndroid Build Coastguard Worker n(), nr(),
455*4bdc9457SAndroid Build Coastguard Worker nr() * (ks() * packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
456*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k() * sizeof(int8_t) + sizeof(int32_t))));
457*4bdc9457SAndroid Build Coastguard Worker
458*4bdc9457SAndroid Build Coastguard Worker union xnn_qc8_conv_minmax_params minmax_params;
459*4bdc9457SAndroid Build Coastguard Worker init_params(&minmax_params,
460*4bdc9457SAndroid Build Coastguard Worker c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
461*4bdc9457SAndroid Build Coastguard Worker
462*4bdc9457SAndroid Build Coastguard Worker const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
463*4bdc9457SAndroid Build Coastguard Worker
464*4bdc9457SAndroid Build Coastguard Worker igemm(
465*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(), ks() * mr() * sizeof(void*),
466*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
467*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
468*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(uint8_t), zero_pointer,
469*4bdc9457SAndroid Build Coastguard Worker &minmax_params);
470*4bdc9457SAndroid Build Coastguard Worker
471*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
472*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
473*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
474*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
475*4bdc9457SAndroid Build Coastguard Worker }
476*4bdc9457SAndroid Build Coastguard Worker }
477*4bdc9457SAndroid Build Coastguard Worker
478*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
479*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
480*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
481*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
482*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
483*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
484*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
485*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
486*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
487*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
488*4bdc9457SAndroid Build Coastguard Worker }
489*4bdc9457SAndroid Build Coastguard Worker }
490*4bdc9457SAndroid Build Coastguard Worker }
491*4bdc9457SAndroid Build Coastguard Worker }
492*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_qs8_gemm_minmax_ukernel_function gemm,xnn_init_qs8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const493*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
494*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_gemm_minmax_ukernel_function gemm,
495*4bdc9457SAndroid Build Coastguard Worker xnn_init_qs8_conv_minmax_params_fn init_params,
496*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
497*4bdc9457SAndroid Build Coastguard Worker {
498*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
499*4bdc9457SAndroid Build Coastguard Worker
500*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
501*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
502*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
503*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
504*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
505*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
506*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
507*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
508*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
509*4bdc9457SAndroid Build Coastguard Worker
510*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
511*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * k());
512*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
513*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
514*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
515*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
516*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
517*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
518*4bdc9457SAndroid Build Coastguard Worker
519*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
520*4bdc9457SAndroid Build Coastguard Worker do {
521*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
522*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
523*4bdc9457SAndroid Build Coastguard Worker do {
524*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
525*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
526*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
527*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
528*4bdc9457SAndroid Build Coastguard Worker
529*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
530*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
531*4bdc9457SAndroid Build Coastguard Worker if (extended_weights()) {
532*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
533*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_xw.data(), 0, &packing_params);
534*4bdc9457SAndroid Build Coastguard Worker } else {
535*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
536*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0, &packing_params);
537*4bdc9457SAndroid Build Coastguard Worker }
538*4bdc9457SAndroid Build Coastguard Worker
539*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
540*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
541*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
542*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
543*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
544*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
545*4bdc9457SAndroid Build Coastguard Worker (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
546*4bdc9457SAndroid Build Coastguard Worker int32_t(b[n_index * k() + k_index]);
547*4bdc9457SAndroid Build Coastguard Worker }
548*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
549*4bdc9457SAndroid Build Coastguard Worker }
550*4bdc9457SAndroid Build Coastguard Worker }
551*4bdc9457SAndroid Build Coastguard Worker
552*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
553*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
554*4bdc9457SAndroid Build Coastguard Worker const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
555*4bdc9457SAndroid Build Coastguard Worker const int8_t c_zero_point = int8_t(std::max(std::min(
556*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
557*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
558*4bdc9457SAndroid Build Coastguard Worker
559*4bdc9457SAndroid Build Coastguard Worker const float requantization_scale = 1.0f / float(c_scale);
560*4bdc9457SAndroid Build Coastguard Worker union xnn_qs8_conv_minmax_params quantization_params;
561*4bdc9457SAndroid Build Coastguard Worker init_params(&quantization_params,
562*4bdc9457SAndroid Build Coastguard Worker requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
563*4bdc9457SAndroid Build Coastguard Worker
564*4bdc9457SAndroid Build Coastguard Worker gemm(
565*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(),
566*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(int8_t),
567*4bdc9457SAndroid Build Coastguard Worker extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
568*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
569*4bdc9457SAndroid Build Coastguard Worker &quantization_params);
570*4bdc9457SAndroid Build Coastguard Worker
571*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
572*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
573*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
574*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
575*4bdc9457SAndroid Build Coastguard Worker }
576*4bdc9457SAndroid Build Coastguard Worker }
577*4bdc9457SAndroid Build Coastguard Worker
578*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
579*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
580*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
581*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
582*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
583*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
584*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
585*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
586*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
587*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
588*4bdc9457SAndroid Build Coastguard Worker }
589*4bdc9457SAndroid Build Coastguard Worker }
590*4bdc9457SAndroid Build Coastguard Worker }
591*4bdc9457SAndroid Build Coastguard Worker }
592*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_qs8_igemm_minmax_ukernel_function igemm,xnn_init_qs8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const593*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
594*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_igemm_minmax_ukernel_function igemm,
595*4bdc9457SAndroid Build Coastguard Worker xnn_init_qs8_conv_minmax_params_fn init_params,
596*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
597*4bdc9457SAndroid Build Coastguard Worker {
598*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
599*4bdc9457SAndroid Build Coastguard Worker
600*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
601*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
602*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
603*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
604*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
605*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
606*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
607*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
608*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
609*4bdc9457SAndroid Build Coastguard Worker
610*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
611*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * ks() * k());
612*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
613*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
614*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
615*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
616*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
617*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> junk(k() + 8);
618*4bdc9457SAndroid Build Coastguard Worker std::vector<const int8_t*> im2col(mr() * ks());
619*4bdc9457SAndroid Build Coastguard Worker
620*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), 0xA5);
621*4bdc9457SAndroid Build Coastguard Worker
622*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
623*4bdc9457SAndroid Build Coastguard Worker do {
624*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
625*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
626*4bdc9457SAndroid Build Coastguard Worker do {
627*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
628*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
629*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
630*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
631*4bdc9457SAndroid Build Coastguard Worker
632*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
633*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
634*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_conv_goki_w(
635*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
636*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);
637*4bdc9457SAndroid Build Coastguard Worker
638*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
639*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
640*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
641*4bdc9457SAndroid Build Coastguard Worker }
642*4bdc9457SAndroid Build Coastguard Worker }
643*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
644*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
645*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
646*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
647*4bdc9457SAndroid Build Coastguard Worker }
648*4bdc9457SAndroid Build Coastguard Worker }
649*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
650*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
651*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
652*4bdc9457SAndroid Build Coastguard Worker }
653*4bdc9457SAndroid Build Coastguard Worker }
654*4bdc9457SAndroid Build Coastguard Worker
655*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
656*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
657*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
658*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
659*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
660*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
661*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
662*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
663*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
664*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
665*4bdc9457SAndroid Build Coastguard Worker } else {
666*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
667*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
668*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
669*4bdc9457SAndroid Build Coastguard Worker }
670*4bdc9457SAndroid Build Coastguard Worker }
671*4bdc9457SAndroid Build Coastguard Worker }
672*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
673*4bdc9457SAndroid Build Coastguard Worker }
674*4bdc9457SAndroid Build Coastguard Worker }
675*4bdc9457SAndroid Build Coastguard Worker
676*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
677*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
678*4bdc9457SAndroid Build Coastguard Worker const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
679*4bdc9457SAndroid Build Coastguard Worker const uint8_t c_zero_point = uint8_t(std::max(std::min(
680*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
681*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
682*4bdc9457SAndroid Build Coastguard Worker
683*4bdc9457SAndroid Build Coastguard Worker const float requantization_scale = 1.0f / float(c_scale);
684*4bdc9457SAndroid Build Coastguard Worker union xnn_qs8_conv_minmax_params quantization_params;
685*4bdc9457SAndroid Build Coastguard Worker init_params(&quantization_params,
686*4bdc9457SAndroid Build Coastguard Worker requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
687*4bdc9457SAndroid Build Coastguard Worker
688*4bdc9457SAndroid Build Coastguard Worker const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
689*4bdc9457SAndroid Build Coastguard Worker
690*4bdc9457SAndroid Build Coastguard Worker igemm(
691*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(), ks() * mr() * sizeof(void*),
692*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
693*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
694*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(uint8_t), zero_pointer,
695*4bdc9457SAndroid Build Coastguard Worker &quantization_params);
696*4bdc9457SAndroid Build Coastguard Worker
697*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
698*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
699*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
700*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
701*4bdc9457SAndroid Build Coastguard Worker }
702*4bdc9457SAndroid Build Coastguard Worker }
703*4bdc9457SAndroid Build Coastguard Worker
704*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
705*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
706*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
707*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
708*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
709*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
710*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
711*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
712*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
713*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
714*4bdc9457SAndroid Build Coastguard Worker }
715*4bdc9457SAndroid Build Coastguard Worker }
716*4bdc9457SAndroid Build Coastguard Worker }
717*4bdc9457SAndroid Build Coastguard Worker }
718*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax,xnn_init_bf16_minmax_params_fn init_params) const719*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_bf16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_bf16_minmax_params_fn init_params) const
720*4bdc9457SAndroid Build Coastguard Worker {
721*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
722*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
723*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
724*4bdc9457SAndroid Build Coastguard Worker
725*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
726*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
727*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(0.5f, 1.0f), std::ref(rng));
728*4bdc9457SAndroid Build Coastguard Worker
729*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
730*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(n() * k());
731*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + packed_n());
732*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(n());
733*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
734*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
735*4bdc9457SAndroid Build Coastguard Worker
736*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
737*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
738*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
739*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&] { return fp32_to_bits(f32rng(rng)) >> 16; });
740*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), UINT32_C(0x7FC0) /* NaN */);
741*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
742*4bdc9457SAndroid Build Coastguard Worker
743*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
744*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
745*4bdc9457SAndroid Build Coastguard Worker
746*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
747*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
748*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = fp32_from_bits(uint32_t(bias[n_index]) << 16);
749*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
750*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
751*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
752*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * k() + k_index, a.size());
753*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
754*4bdc9457SAndroid Build Coastguard Worker fp32_from_bits(uint32_t(a[m_index * a_stride() + k_index]) << 16) *
755*4bdc9457SAndroid Build Coastguard Worker fp32_from_bits(uint32_t(b[n_index * k() + k_index]) << 16);
756*4bdc9457SAndroid Build Coastguard Worker }
757*4bdc9457SAndroid Build Coastguard Worker }
758*4bdc9457SAndroid Build Coastguard Worker }
759*4bdc9457SAndroid Build Coastguard Worker
760*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
761*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
762*4bdc9457SAndroid Build Coastguard Worker const float c_min = fp32_from_bits(fp32_to_bits(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())) & UINT32_C(0xFFFF0000));
763*4bdc9457SAndroid Build Coastguard Worker const float c_max = fp32_from_bits(fp32_to_bits(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())) & UINT32_C(0xFFFF0000));
764*4bdc9457SAndroid Build Coastguard Worker
765*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
766*4bdc9457SAndroid Build Coastguard Worker xnn_bf16_minmax_params params;
767*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms,
768*4bdc9457SAndroid Build Coastguard Worker fp32_to_bits(c_min) >> 16,
769*4bdc9457SAndroid Build Coastguard Worker fp32_to_bits(c_max) >> 16);
770*4bdc9457SAndroid Build Coastguard Worker
771*4bdc9457SAndroid Build Coastguard Worker for (float& c_value : c_ref) {
772*4bdc9457SAndroid Build Coastguard Worker c_value = std::max(std::min(c_value, c_max), c_min);
773*4bdc9457SAndroid Build Coastguard Worker }
774*4bdc9457SAndroid Build Coastguard Worker
775*4bdc9457SAndroid Build Coastguard Worker gemm_minmax(m(), n(), k() * sizeof(uint16_t),
776*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(uint16_t),
777*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
778*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
779*4bdc9457SAndroid Build Coastguard Worker ¶ms);
780*4bdc9457SAndroid Build Coastguard Worker
781*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
782*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
783*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
784*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
785*4bdc9457SAndroid Build Coastguard Worker fp32_from_bits(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << 16),
786*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
787*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 3.0e-2f))
788*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": Mr x Nr x Kr = " << mr() << " x " << nr()
789*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
790*4bdc9457SAndroid Build Coastguard Worker }
791*4bdc9457SAndroid Build Coastguard Worker }
792*4bdc9457SAndroid Build Coastguard Worker }
793*4bdc9457SAndroid Build Coastguard Worker }
794*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax,xnn_init_f16_minmax_params_fn init_params) const795*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f16_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f16_minmax_params_fn init_params) const
796*4bdc9457SAndroid Build Coastguard Worker {
797*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
798*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
799*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
800*4bdc9457SAndroid Build Coastguard Worker
801*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
802*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
803*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
804*4bdc9457SAndroid Build Coastguard Worker auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
805*4bdc9457SAndroid Build Coastguard Worker
806*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
807*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(n() * k());
808*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + packed_n());
809*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(n());
810*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
811*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
812*4bdc9457SAndroid Build Coastguard Worker
813*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
814*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(f16rng));
815*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(f16rng));
816*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(f16rng));
817*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
818*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
819*4bdc9457SAndroid Build Coastguard Worker
820*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
821*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
822*4bdc9457SAndroid Build Coastguard Worker
823*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
824*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
825*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
826*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
827*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
828*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * k() + k_index, a.size());
829*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
830*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_index]) *
831*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(b[n_index * k() + k_index]);
832*4bdc9457SAndroid Build Coastguard Worker }
833*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
834*4bdc9457SAndroid Build Coastguard Worker }
835*4bdc9457SAndroid Build Coastguard Worker }
836*4bdc9457SAndroid Build Coastguard Worker
837*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
838*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
839*4bdc9457SAndroid Build Coastguard Worker const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
840*4bdc9457SAndroid Build Coastguard Worker const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));
841*4bdc9457SAndroid Build Coastguard Worker
842*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
843*4bdc9457SAndroid Build Coastguard Worker xnn_f16_minmax_params params;
844*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms,
845*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(c_min),
846*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(c_max));
847*4bdc9457SAndroid Build Coastguard Worker
848*4bdc9457SAndroid Build Coastguard Worker for (float& c_value : c_ref) {
849*4bdc9457SAndroid Build Coastguard Worker c_value = std::max(std::min(c_value, c_max), c_min);
850*4bdc9457SAndroid Build Coastguard Worker }
851*4bdc9457SAndroid Build Coastguard Worker
852*4bdc9457SAndroid Build Coastguard Worker gemm_minmax(m(), n(), k() * sizeof(uint16_t),
853*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(uint16_t),
854*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
855*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
856*4bdc9457SAndroid Build Coastguard Worker ¶ms);
857*4bdc9457SAndroid Build Coastguard Worker
858*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
859*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
860*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
861*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
862*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
863*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
864*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
865*4bdc9457SAndroid Build Coastguard Worker }
866*4bdc9457SAndroid Build Coastguard Worker }
867*4bdc9457SAndroid Build Coastguard Worker }
868*4bdc9457SAndroid Build Coastguard Worker }
869*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax,xnn_init_f16_minmax_params_fn init_params) const870*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f16_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f16_minmax_params_fn init_params) const {
871*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
872*4bdc9457SAndroid Build Coastguard Worker
873*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
874*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
875*4bdc9457SAndroid Build Coastguard Worker auto f32rng = std::bind(std::uniform_real_distribution<float>(), std::ref(rng));
876*4bdc9457SAndroid Build Coastguard Worker auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);
877*4bdc9457SAndroid Build Coastguard Worker
878*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
879*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> b(n() * ks() * k());
880*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
881*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> bias(n());
882*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
883*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
884*4bdc9457SAndroid Build Coastguard Worker std::vector<uint16_t> junk(k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
885*4bdc9457SAndroid Build Coastguard Worker std::vector<const uint16_t*> im2col(mr() * ks());
886*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), UINT16_C(0x7E00) /* NaN */);
887*4bdc9457SAndroid Build Coastguard Worker
888*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
889*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(f16rng));
890*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(f16rng));
891*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(f16rng));
892*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
893*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0);
894*4bdc9457SAndroid Build Coastguard Worker
895*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
896*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f16_conv_goki_w(
897*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
898*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
899*4bdc9457SAndroid Build Coastguard Worker
900*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
901*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
902*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
903*4bdc9457SAndroid Build Coastguard Worker }
904*4bdc9457SAndroid Build Coastguard Worker }
905*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
906*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
907*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
908*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
909*4bdc9457SAndroid Build Coastguard Worker }
910*4bdc9457SAndroid Build Coastguard Worker }
911*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
912*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
913*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
914*4bdc9457SAndroid Build Coastguard Worker }
915*4bdc9457SAndroid Build Coastguard Worker }
916*4bdc9457SAndroid Build Coastguard Worker
917*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0);
918*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
919*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
920*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
921*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
922*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(ks_index * mr() + m_index, im2col.size());
923*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, k());
924*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, a_stride());
925*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
926*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
927*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index]) *
928*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
929*4bdc9457SAndroid Build Coastguard Worker } else {
930*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
931*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
932*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_to_fp32_value(b[(n_index * ks() + ks_index) * k() + k_index]);
933*4bdc9457SAndroid Build Coastguard Worker }
934*4bdc9457SAndroid Build Coastguard Worker }
935*4bdc9457SAndroid Build Coastguard Worker }
936*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
937*4bdc9457SAndroid Build Coastguard Worker }
938*4bdc9457SAndroid Build Coastguard Worker }
939*4bdc9457SAndroid Build Coastguard Worker
940*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
941*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
942*4bdc9457SAndroid Build Coastguard Worker const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * uint16_t(qmin())));
943*4bdc9457SAndroid Build Coastguard Worker const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * uint16_t(255 - qmax())));
944*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
945*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
946*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
947*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
948*4bdc9457SAndroid Build Coastguard Worker }
949*4bdc9457SAndroid Build Coastguard Worker }
950*4bdc9457SAndroid Build Coastguard Worker
951*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
952*4bdc9457SAndroid Build Coastguard Worker xnn_f16_minmax_params params;
953*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms,
954*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(c_min),
955*4bdc9457SAndroid Build Coastguard Worker fp16_ieee_from_fp32_value(c_max));
956*4bdc9457SAndroid Build Coastguard Worker
957*4bdc9457SAndroid Build Coastguard Worker for (float& c_value : c_ref) {
958*4bdc9457SAndroid Build Coastguard Worker c_value = std::max(std::min(c_value, c_max), c_min);
959*4bdc9457SAndroid Build Coastguard Worker }
960*4bdc9457SAndroid Build Coastguard Worker
961*4bdc9457SAndroid Build Coastguard Worker const uint16_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
962*4bdc9457SAndroid Build Coastguard Worker
963*4bdc9457SAndroid Build Coastguard Worker igemm_minmax(
964*4bdc9457SAndroid Build Coastguard Worker m(), n(), k() * sizeof(uint16_t), ks() * mr() * sizeof(void*),
965*4bdc9457SAndroid Build Coastguard Worker reinterpret_cast<const void**>(im2col.data()), packed_w.data(),
966*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
967*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(uint16_t), zero_pointer,
968*4bdc9457SAndroid Build Coastguard Worker ¶ms);
969*4bdc9457SAndroid Build Coastguard Worker
970*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
971*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
972*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_max)
973*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
974*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
975*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
976*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_min)
977*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
978*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
979*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
980*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), c_ref[i * n() + j], std::max(1.0e-4f, std::abs(c_ref[i * n() + j]) * 1.0e-2f))
981*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
982*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
983*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
984*4bdc9457SAndroid Build Coastguard Worker }
985*4bdc9457SAndroid Build Coastguard Worker }
986*4bdc9457SAndroid Build Coastguard Worker }
987*4bdc9457SAndroid Build Coastguard Worker }
988*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax,xnn_init_f32_minmax_params_fn init_params) const989*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_ppmm_minmax_ukernel_function ppmm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
990*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
991*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
992*4bdc9457SAndroid Build Coastguard Worker
993*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
994*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
995*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
996*4bdc9457SAndroid Build Coastguard Worker
997*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a(packed_k() * mr());
998*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * k());
999*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1000*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1001*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1002*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1003*4bdc9457SAndroid Build Coastguard Worker
1004*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1005*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1006*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1007*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1008*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1009*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1010*4bdc9457SAndroid Build Coastguard Worker
1011*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1012*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1013*4bdc9457SAndroid Build Coastguard Worker
1014*4bdc9457SAndroid Build Coastguard Worker for (size_t i = m(); i < mr(); i++) {
1015*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < k(); l++) {
1016*4bdc9457SAndroid Build Coastguard Worker a[l * mr() + i] = a[l * mr() + m() - 1];
1017*4bdc9457SAndroid Build Coastguard Worker }
1018*4bdc9457SAndroid Build Coastguard Worker }
1019*4bdc9457SAndroid Build Coastguard Worker
1020*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1021*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1022*4bdc9457SAndroid Build Coastguard Worker for (size_t l = 0; l < k(); l++) {
1023*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j] +=
1024*4bdc9457SAndroid Build Coastguard Worker a[l * mr() + i] *
1025*4bdc9457SAndroid Build Coastguard Worker b[j * k() + l];
1026*4bdc9457SAndroid Build Coastguard Worker }
1027*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j] += bias[j];
1028*4bdc9457SAndroid Build Coastguard Worker }
1029*4bdc9457SAndroid Build Coastguard Worker }
1030*4bdc9457SAndroid Build Coastguard Worker
1031*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1032*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1033*4bdc9457SAndroid Build Coastguard Worker const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1034*4bdc9457SAndroid Build Coastguard Worker const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1035*4bdc9457SAndroid Build Coastguard Worker
1036*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
1037*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
1038*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, c_min, c_max);
1039*4bdc9457SAndroid Build Coastguard Worker
1040*4bdc9457SAndroid Build Coastguard Worker for (float& c_value : c_ref) {
1041*4bdc9457SAndroid Build Coastguard Worker c_value = std::max(std::min(c_value, c_max), c_min);
1042*4bdc9457SAndroid Build Coastguard Worker }
1043*4bdc9457SAndroid Build Coastguard Worker
1044*4bdc9457SAndroid Build Coastguard Worker ppmm_minmax(m(), n(), k() * sizeof(float),
1045*4bdc9457SAndroid Build Coastguard Worker a.data(), packed_w.data(),
1046*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1047*4bdc9457SAndroid Build Coastguard Worker ¶ms);
1048*4bdc9457SAndroid Build Coastguard Worker
1049*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
1050*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1051*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1052*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1053*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1054*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1055*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1056*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1057*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1058*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1059*4bdc9457SAndroid Build Coastguard Worker }
1060*4bdc9457SAndroid Build Coastguard Worker }
1061*4bdc9457SAndroid Build Coastguard Worker }
1062*4bdc9457SAndroid Build Coastguard Worker }
1063*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_gemm_ukernel_function gemm) const1064*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_gemm_ukernel_function gemm) const {
1065*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1066*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
1067*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
1068*4bdc9457SAndroid Build Coastguard Worker
1069*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1070*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1071*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1072*4bdc9457SAndroid Build Coastguard Worker
1073*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1074*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * k());
1075*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1076*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1077*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1078*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1079*4bdc9457SAndroid Build Coastguard Worker
1080*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1081*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1082*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1083*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1084*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1085*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1086*4bdc9457SAndroid Build Coastguard Worker
1087*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1088*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1089*4bdc9457SAndroid Build Coastguard Worker
1090*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1091*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1092*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1093*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
1094*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
1095*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1096*4bdc9457SAndroid Build Coastguard Worker a[m_index * a_stride() + k_index] *
1097*4bdc9457SAndroid Build Coastguard Worker b[n_index * k() + k_index];
1098*4bdc9457SAndroid Build Coastguard Worker }
1099*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += bias[n_index];
1100*4bdc9457SAndroid Build Coastguard Worker }
1101*4bdc9457SAndroid Build Coastguard Worker }
1102*4bdc9457SAndroid Build Coastguard Worker
1103*4bdc9457SAndroid Build Coastguard Worker gemm(m(), n(), k() * sizeof(float),
1104*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(float),
1105*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
1106*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1107*4bdc9457SAndroid Build Coastguard Worker nullptr);
1108*4bdc9457SAndroid Build Coastguard Worker
1109*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
1110*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1111*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1112*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1113*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1114*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1115*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1116*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1117*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1118*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1119*4bdc9457SAndroid Build Coastguard Worker }
1120*4bdc9457SAndroid Build Coastguard Worker }
1121*4bdc9457SAndroid Build Coastguard Worker }
1122*4bdc9457SAndroid Build Coastguard Worker }
1123*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const1124*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_gemm_relu_ukernel_function gemm_relu) const {
1125*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1126*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
1127*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
1128*4bdc9457SAndroid Build Coastguard Worker
1129*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1130*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1131*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1132*4bdc9457SAndroid Build Coastguard Worker
1133*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1134*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * k());
1135*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1136*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1137*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1138*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1139*4bdc9457SAndroid Build Coastguard Worker
1140*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1141*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1142*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1143*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1144*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1145*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1146*4bdc9457SAndroid Build Coastguard Worker
1147*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1148*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1149*4bdc9457SAndroid Build Coastguard Worker
1150*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1151*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1152*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1153*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
1154*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
1155*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1156*4bdc9457SAndroid Build Coastguard Worker a[m_index * a_stride() + k_index] *
1157*4bdc9457SAndroid Build Coastguard Worker b[n_index * k() + k_index];
1158*4bdc9457SAndroid Build Coastguard Worker }
1159*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(0.0f, c_ref[m_index * n() + n_index] + bias[n_index]);
1160*4bdc9457SAndroid Build Coastguard Worker }
1161*4bdc9457SAndroid Build Coastguard Worker }
1162*4bdc9457SAndroid Build Coastguard Worker
1163*4bdc9457SAndroid Build Coastguard Worker gemm_relu(m(), n(), k() * sizeof(float),
1164*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(float),
1165*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
1166*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1167*4bdc9457SAndroid Build Coastguard Worker nullptr);
1168*4bdc9457SAndroid Build Coastguard Worker
1169*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
1170*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1171*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1172*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1173*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1174*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1175*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1176*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1177*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1178*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1179*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1180*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1181*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1182*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1183*4bdc9457SAndroid Build Coastguard Worker }
1184*4bdc9457SAndroid Build Coastguard Worker }
1185*4bdc9457SAndroid Build Coastguard Worker }
1186*4bdc9457SAndroid Build Coastguard Worker }
1187*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax,xnn_init_f32_minmax_params_fn init_params) const1188*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_gemm_minmax_ukernel_function gemm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1189*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1190*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
1191*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
1192*4bdc9457SAndroid Build Coastguard Worker
1193*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1194*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1195*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1196*4bdc9457SAndroid Build Coastguard Worker
1197*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1198*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * k());
1199*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1200*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1201*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1202*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1203*4bdc9457SAndroid Build Coastguard Worker
1204*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1205*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1206*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1207*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1208*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1209*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1210*4bdc9457SAndroid Build Coastguard Worker
1211*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1212*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1213*4bdc9457SAndroid Build Coastguard Worker
1214*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1215*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1216*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1217*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
1218*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
1219*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1220*4bdc9457SAndroid Build Coastguard Worker a[m_index * a_stride() + k_index] *
1221*4bdc9457SAndroid Build Coastguard Worker b[n_index * k() + k_index];
1222*4bdc9457SAndroid Build Coastguard Worker }
1223*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += bias[n_index];
1224*4bdc9457SAndroid Build Coastguard Worker }
1225*4bdc9457SAndroid Build Coastguard Worker }
1226*4bdc9457SAndroid Build Coastguard Worker
1227*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1228*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1229*4bdc9457SAndroid Build Coastguard Worker const float c_min =
1230*4bdc9457SAndroid Build Coastguard Worker qmin() == std::numeric_limits<uint8_t>::min() ? -std::numeric_limits<float>::infinity()
1231*4bdc9457SAndroid Build Coastguard Worker : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1232*4bdc9457SAndroid Build Coastguard Worker const float c_max =
1233*4bdc9457SAndroid Build Coastguard Worker qmax() == std::numeric_limits<uint8_t>::max() ? +std::numeric_limits<float>::infinity()
1234*4bdc9457SAndroid Build Coastguard Worker : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1235*4bdc9457SAndroid Build Coastguard Worker
1236*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
1237*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
1238*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, c_min, c_max);
1239*4bdc9457SAndroid Build Coastguard Worker
1240*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1241*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1242*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1243*4bdc9457SAndroid Build Coastguard Worker }
1244*4bdc9457SAndroid Build Coastguard Worker }
1245*4bdc9457SAndroid Build Coastguard Worker
1246*4bdc9457SAndroid Build Coastguard Worker gemm_minmax(m(), n(), k() * sizeof(float),
1247*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(float),
1248*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
1249*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1250*4bdc9457SAndroid Build Coastguard Worker ¶ms);
1251*4bdc9457SAndroid Build Coastguard Worker
1252*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
1253*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1254*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1255*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1256*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1257*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1258*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1259*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1260*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1261*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1262*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1263*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1264*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1265*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1266*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1267*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1268*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1269*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1270*4bdc9457SAndroid Build Coastguard Worker }
1271*4bdc9457SAndroid Build Coastguard Worker }
1272*4bdc9457SAndroid Build Coastguard Worker }
1273*4bdc9457SAndroid Build Coastguard Worker }
1274*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_gemminc_minmax_ukernel_function gemminc,xnn_init_f32_minmax_params_fn init_params) const1275*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_gemminc_minmax_ukernel_function gemminc, xnn_init_f32_minmax_params_fn init_params) const {
1276*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1277*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
1278*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
1279*4bdc9457SAndroid Build Coastguard Worker
1280*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1281*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1282*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1283*4bdc9457SAndroid Build Coastguard Worker
1284*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1285*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * k());
1286*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1287*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k()); // no packed_n()
1288*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1289*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1290*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> acc(mr() * packed_n());
1291*4bdc9457SAndroid Build Coastguard Worker
1292*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1293*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1294*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1295*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1296*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1297*4bdc9457SAndroid Build Coastguard Worker std::generate(acc.begin(), acc.end(), [&]() { return f32dist(rng); });
1298*4bdc9457SAndroid Build Coastguard Worker
1299*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1300*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data(), nullptr);
1301*4bdc9457SAndroid Build Coastguard Worker
1302*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1303*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1304*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1305*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
1306*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
1307*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1308*4bdc9457SAndroid Build Coastguard Worker a[m_index * a_stride() + k_index] *
1309*4bdc9457SAndroid Build Coastguard Worker b[n_index * k() + k_index];
1310*4bdc9457SAndroid Build Coastguard Worker }
1311*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
1312*4bdc9457SAndroid Build Coastguard Worker }
1313*4bdc9457SAndroid Build Coastguard Worker }
1314*4bdc9457SAndroid Build Coastguard Worker
1315*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1316*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1317*4bdc9457SAndroid Build Coastguard Worker const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1318*4bdc9457SAndroid Build Coastguard Worker const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1319*4bdc9457SAndroid Build Coastguard Worker
1320*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
1321*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
1322*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, c_min, c_max);
1323*4bdc9457SAndroid Build Coastguard Worker
1324*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1325*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1326*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1327*4bdc9457SAndroid Build Coastguard Worker }
1328*4bdc9457SAndroid Build Coastguard Worker }
1329*4bdc9457SAndroid Build Coastguard Worker
1330*4bdc9457SAndroid Build Coastguard Worker gemminc(m(), n(), k() * sizeof(float),
1331*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(float),
1332*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
1333*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1334*4bdc9457SAndroid Build Coastguard Worker acc.data(),
1335*4bdc9457SAndroid Build Coastguard Worker ¶ms);
1336*4bdc9457SAndroid Build Coastguard Worker
1337*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
1338*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1339*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1340*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1341*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1342*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1343*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1344*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1345*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1346*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1347*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1348*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1349*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1350*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1351*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1352*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1353*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1354*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1355*4bdc9457SAndroid Build Coastguard Worker }
1356*4bdc9457SAndroid Build Coastguard Worker }
1357*4bdc9457SAndroid Build Coastguard Worker }
1358*4bdc9457SAndroid Build Coastguard Worker }
1359*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_igemm_ukernel_function igemm) const1360*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_igemm_ukernel_function igemm) const {
1361*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1362*4bdc9457SAndroid Build Coastguard Worker
1363*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1364*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1365*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1366*4bdc9457SAndroid Build Coastguard Worker
1367*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1368*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * ks() * k());
1369*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1370*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1371*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1372*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1373*4bdc9457SAndroid Build Coastguard Worker std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1374*4bdc9457SAndroid Build Coastguard Worker std::vector<const float*> im2col(mr() * ks());
1375*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), nanf(""));
1376*4bdc9457SAndroid Build Coastguard Worker
1377*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1378*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1379*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1380*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1381*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1382*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1383*4bdc9457SAndroid Build Coastguard Worker
1384*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1385*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_conv_goki_w(
1386*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
1387*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1388*4bdc9457SAndroid Build Coastguard Worker
1389*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1390*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
1391*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1392*4bdc9457SAndroid Build Coastguard Worker }
1393*4bdc9457SAndroid Build Coastguard Worker }
1394*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
1395*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
1396*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1397*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
1398*4bdc9457SAndroid Build Coastguard Worker }
1399*4bdc9457SAndroid Build Coastguard Worker }
1400*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1401*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
1402*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
1403*4bdc9457SAndroid Build Coastguard Worker }
1404*4bdc9457SAndroid Build Coastguard Worker }
1405*4bdc9457SAndroid Build Coastguard Worker
1406*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0);
1407*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1408*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1409*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1410*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1411*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1412*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, k());
1413*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, a_stride());
1414*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
1415*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1416*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index]) *
1417*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1418*4bdc9457SAndroid Build Coastguard Worker } else {
1419*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1420*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1421*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1422*4bdc9457SAndroid Build Coastguard Worker }
1423*4bdc9457SAndroid Build Coastguard Worker }
1424*4bdc9457SAndroid Build Coastguard Worker }
1425*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += bias[n_index];
1426*4bdc9457SAndroid Build Coastguard Worker }
1427*4bdc9457SAndroid Build Coastguard Worker }
1428*4bdc9457SAndroid Build Coastguard Worker
1429*4bdc9457SAndroid Build Coastguard Worker const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1430*4bdc9457SAndroid Build Coastguard Worker
1431*4bdc9457SAndroid Build Coastguard Worker igemm(
1432*4bdc9457SAndroid Build Coastguard Worker m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1433*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
1434*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1435*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(float), zero_pointer,
1436*4bdc9457SAndroid Build Coastguard Worker nullptr);
1437*4bdc9457SAndroid Build Coastguard Worker
1438*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1439*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1440*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1441*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1442*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1443*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1444*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1445*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1446*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1447*4bdc9457SAndroid Build Coastguard Worker }
1448*4bdc9457SAndroid Build Coastguard Worker }
1449*4bdc9457SAndroid Build Coastguard Worker }
1450*4bdc9457SAndroid Build Coastguard Worker }
1451*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const1452*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_igemm_relu_ukernel_function igemm_relu) const {
1453*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1454*4bdc9457SAndroid Build Coastguard Worker
1455*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1456*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1457*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1458*4bdc9457SAndroid Build Coastguard Worker
1459*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1460*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * ks() * k());
1461*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1462*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1463*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1464*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1465*4bdc9457SAndroid Build Coastguard Worker std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1466*4bdc9457SAndroid Build Coastguard Worker std::vector<const float*> im2col(mr() * ks());
1467*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), nanf(""));
1468*4bdc9457SAndroid Build Coastguard Worker
1469*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1470*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1471*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1472*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1473*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1474*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1475*4bdc9457SAndroid Build Coastguard Worker
1476*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1477*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_conv_goki_w(
1478*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
1479*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1480*4bdc9457SAndroid Build Coastguard Worker
1481*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1482*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
1483*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1484*4bdc9457SAndroid Build Coastguard Worker }
1485*4bdc9457SAndroid Build Coastguard Worker }
1486*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
1487*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
1488*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1489*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
1490*4bdc9457SAndroid Build Coastguard Worker }
1491*4bdc9457SAndroid Build Coastguard Worker }
1492*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1493*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
1494*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
1495*4bdc9457SAndroid Build Coastguard Worker }
1496*4bdc9457SAndroid Build Coastguard Worker }
1497*4bdc9457SAndroid Build Coastguard Worker
1498*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0);
1499*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1500*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1501*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1502*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1503*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1504*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, k());
1505*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, a_stride());
1506*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
1507*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1508*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index]) *
1509*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1510*4bdc9457SAndroid Build Coastguard Worker } else {
1511*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1512*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1513*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1514*4bdc9457SAndroid Build Coastguard Worker }
1515*4bdc9457SAndroid Build Coastguard Worker }
1516*4bdc9457SAndroid Build Coastguard Worker }
1517*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(0.0f, bias[n_index] + c_ref[m_index * n() + n_index]);
1518*4bdc9457SAndroid Build Coastguard Worker }
1519*4bdc9457SAndroid Build Coastguard Worker }
1520*4bdc9457SAndroid Build Coastguard Worker
1521*4bdc9457SAndroid Build Coastguard Worker const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1522*4bdc9457SAndroid Build Coastguard Worker
1523*4bdc9457SAndroid Build Coastguard Worker igemm_relu(
1524*4bdc9457SAndroid Build Coastguard Worker m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1525*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
1526*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1527*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(float), zero_pointer,
1528*4bdc9457SAndroid Build Coastguard Worker nullptr);
1529*4bdc9457SAndroid Build Coastguard Worker
1530*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1531*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1532*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], 0.0f)
1533*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1534*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1535*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1536*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1537*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1538*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1539*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1540*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1541*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1542*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1543*4bdc9457SAndroid Build Coastguard Worker }
1544*4bdc9457SAndroid Build Coastguard Worker }
1545*4bdc9457SAndroid Build Coastguard Worker }
1546*4bdc9457SAndroid Build Coastguard Worker }
1547*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax,xnn_init_f32_minmax_params_fn init_params) const1548*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(xnn_f32_igemm_minmax_ukernel_function igemm_minmax, xnn_init_f32_minmax_params_fn init_params) const {
1549*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1550*4bdc9457SAndroid Build Coastguard Worker
1551*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1552*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1553*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1554*4bdc9457SAndroid Build Coastguard Worker
1555*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1556*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * ks() * k());
1557*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1558*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1559*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1560*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1561*4bdc9457SAndroid Build Coastguard Worker std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1562*4bdc9457SAndroid Build Coastguard Worker std::vector<const float*> im2col(mr() * ks());
1563*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), nanf(""));
1564*4bdc9457SAndroid Build Coastguard Worker
1565*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1566*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1567*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1568*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1569*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1570*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1571*4bdc9457SAndroid Build Coastguard Worker
1572*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1573*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_conv_goki_w(
1574*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
1575*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1576*4bdc9457SAndroid Build Coastguard Worker
1577*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1578*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
1579*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1580*4bdc9457SAndroid Build Coastguard Worker }
1581*4bdc9457SAndroid Build Coastguard Worker }
1582*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
1583*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
1584*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1585*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
1586*4bdc9457SAndroid Build Coastguard Worker }
1587*4bdc9457SAndroid Build Coastguard Worker }
1588*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1589*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
1590*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
1591*4bdc9457SAndroid Build Coastguard Worker }
1592*4bdc9457SAndroid Build Coastguard Worker }
1593*4bdc9457SAndroid Build Coastguard Worker
1594*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0);
1595*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1596*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1597*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1598*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1599*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1600*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, k());
1601*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, a_stride());
1602*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
1603*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1604*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index]) *
1605*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1606*4bdc9457SAndroid Build Coastguard Worker } else {
1607*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1608*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1609*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1610*4bdc9457SAndroid Build Coastguard Worker }
1611*4bdc9457SAndroid Build Coastguard Worker }
1612*4bdc9457SAndroid Build Coastguard Worker }
1613*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += bias[n_index];
1614*4bdc9457SAndroid Build Coastguard Worker }
1615*4bdc9457SAndroid Build Coastguard Worker }
1616*4bdc9457SAndroid Build Coastguard Worker
1617*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1618*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1619*4bdc9457SAndroid Build Coastguard Worker const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1620*4bdc9457SAndroid Build Coastguard Worker const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1621*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1622*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1623*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
1624*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
1625*4bdc9457SAndroid Build Coastguard Worker }
1626*4bdc9457SAndroid Build Coastguard Worker }
1627*4bdc9457SAndroid Build Coastguard Worker
1628*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
1629*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
1630*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, c_min, c_max);
1631*4bdc9457SAndroid Build Coastguard Worker
1632*4bdc9457SAndroid Build Coastguard Worker const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1633*4bdc9457SAndroid Build Coastguard Worker
1634*4bdc9457SAndroid Build Coastguard Worker igemm_minmax(
1635*4bdc9457SAndroid Build Coastguard Worker m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1636*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
1637*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1638*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(float), zero_pointer,
1639*4bdc9457SAndroid Build Coastguard Worker ¶ms);
1640*4bdc9457SAndroid Build Coastguard Worker
1641*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1642*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1643*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1644*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1645*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1646*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1647*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1648*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1649*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1650*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1651*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1652*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1653*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1654*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1655*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1656*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1657*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1658*4bdc9457SAndroid Build Coastguard Worker }
1659*4bdc9457SAndroid Build Coastguard Worker }
1660*4bdc9457SAndroid Build Coastguard Worker }
1661*4bdc9457SAndroid Build Coastguard Worker }
1662*4bdc9457SAndroid Build Coastguard Worker
1663*4bdc9457SAndroid Build Coastguard Worker #if XNN_PLATFORM_JIT
Test(xnn_jit_gemm_code_generator_function gemm_generator,xnn_init_f32_minmax_params_fn init_params) const1664*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
1665*4bdc9457SAndroid Build Coastguard Worker xnn_jit_gemm_code_generator_function gemm_generator,
1666*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_params_fn init_params) const
1667*4bdc9457SAndroid Build Coastguard Worker {
1668*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1669*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(a_stride(), k());
1670*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(cm_stride(), n());
1671*4bdc9457SAndroid Build Coastguard Worker
1672*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1673*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1674*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1675*4bdc9457SAndroid Build Coastguard Worker
1676*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1677*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * k());
1678*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1679*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + packed_n());
1680*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1681*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1682*4bdc9457SAndroid Build Coastguard Worker
1683*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1684*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1685*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1686*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1687*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1688*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1689*4bdc9457SAndroid Build Coastguard Worker
1690*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1691*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data(), 0, nullptr);
1692*4bdc9457SAndroid Build Coastguard Worker
1693*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1694*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1695*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1696*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(n(), packed_n());
1697*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(m_index * n() + n_index, c_ref.size());
1698*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1699*4bdc9457SAndroid Build Coastguard Worker a[m_index * a_stride() + k_index] *
1700*4bdc9457SAndroid Build Coastguard Worker b[n_index * k() + k_index];
1701*4bdc9457SAndroid Build Coastguard Worker }
1702*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += bias[n_index];
1703*4bdc9457SAndroid Build Coastguard Worker }
1704*4bdc9457SAndroid Build Coastguard Worker }
1705*4bdc9457SAndroid Build Coastguard Worker
1706*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1707*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1708*4bdc9457SAndroid Build Coastguard Worker const float c_min =
1709*4bdc9457SAndroid Build Coastguard Worker qmin() == std::numeric_limits<uint8_t>::min() ? -std::numeric_limits<float>::infinity()
1710*4bdc9457SAndroid Build Coastguard Worker : accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1711*4bdc9457SAndroid Build Coastguard Worker const float c_max =
1712*4bdc9457SAndroid Build Coastguard Worker qmax() == std::numeric_limits<uint8_t>::max() ? +std::numeric_limits<float>::infinity()
1713*4bdc9457SAndroid Build Coastguard Worker : accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1714*4bdc9457SAndroid Build Coastguard Worker
1715*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
1716*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
1717*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, c_min, c_max);
1718*4bdc9457SAndroid Build Coastguard Worker
1719*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1720*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1721*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
1722*4bdc9457SAndroid Build Coastguard Worker }
1723*4bdc9457SAndroid Build Coastguard Worker }
1724*4bdc9457SAndroid Build Coastguard Worker
1725*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
1726*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_buffer code_buffer;
1727*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1728*4bdc9457SAndroid Build Coastguard Worker jit_gemm_params p = (jit_gemm_params) {
1729*4bdc9457SAndroid Build Coastguard Worker .f32_minmax = {
1730*4bdc9457SAndroid Build Coastguard Worker .min = c_min,
1731*4bdc9457SAndroid Build Coastguard Worker .max = c_max
1732*4bdc9457SAndroid Build Coastguard Worker }
1733*4bdc9457SAndroid Build Coastguard Worker };
1734*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, mr(), n() % nr(), k() * sizeof(float), &p));
1735*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
1736*4bdc9457SAndroid Build Coastguard Worker xnn_f32_gemm_minmax_ukernel_function gemm_minmax =
1737*4bdc9457SAndroid Build Coastguard Worker reinterpret_cast<xnn_f32_gemm_minmax_ukernel_function>(code_buffer.start);
1738*4bdc9457SAndroid Build Coastguard Worker
1739*4bdc9457SAndroid Build Coastguard Worker gemm_minmax(m(), n(), k() * sizeof(float),
1740*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(float),
1741*4bdc9457SAndroid Build Coastguard Worker packed_w.data(),
1742*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1743*4bdc9457SAndroid Build Coastguard Worker ¶ms);
1744*4bdc9457SAndroid Build Coastguard Worker
1745*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1746*4bdc9457SAndroid Build Coastguard Worker
1747*4bdc9457SAndroid Build Coastguard Worker // Validate micro-kernel outputs.
1748*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1749*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1750*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1751*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1752*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1753*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1754*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1755*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1756*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1757*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1758*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1759*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1760*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1761*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1762*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
1763*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1764*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
1765*4bdc9457SAndroid Build Coastguard Worker }
1766*4bdc9457SAndroid Build Coastguard Worker }
1767*4bdc9457SAndroid Build Coastguard Worker }
1768*4bdc9457SAndroid Build Coastguard Worker }
1769*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_jit_igemm_code_generator_function igemm_generator,xnn_init_f32_minmax_params_fn init_params) const1770*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
1771*4bdc9457SAndroid Build Coastguard Worker xnn_jit_igemm_code_generator_function igemm_generator,
1772*4bdc9457SAndroid Build Coastguard Worker xnn_init_f32_minmax_params_fn init_params) const
1773*4bdc9457SAndroid Build Coastguard Worker {
1774*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1775*4bdc9457SAndroid Build Coastguard Worker
1776*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1777*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1778*4bdc9457SAndroid Build Coastguard Worker std::uniform_real_distribution<float> f32dist;
1779*4bdc9457SAndroid Build Coastguard Worker
1780*4bdc9457SAndroid Build Coastguard Worker std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
1781*4bdc9457SAndroid Build Coastguard Worker std::vector<float> b(n() * ks() * k());
1782*4bdc9457SAndroid Build Coastguard Worker std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + packed_n());
1783*4bdc9457SAndroid Build Coastguard Worker std::vector<float> bias(n());
1784*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1785*4bdc9457SAndroid Build Coastguard Worker std::vector<float> c_ref(m() * n());
1786*4bdc9457SAndroid Build Coastguard Worker std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
1787*4bdc9457SAndroid Build Coastguard Worker std::vector<const float*> im2col(mr() * ks());
1788*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), nanf(""));
1789*4bdc9457SAndroid Build Coastguard Worker
1790*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1791*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), [&]() { return f32dist(rng); });
1792*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), [&]() { return f32dist(rng); });
1793*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); });
1794*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), nanf(""));
1795*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0f);
1796*4bdc9457SAndroid Build Coastguard Worker
1797*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0.0f);
1798*4bdc9457SAndroid Build Coastguard Worker xnn_pack_f32_conv_goki_w(
1799*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
1800*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, nullptr);
1801*4bdc9457SAndroid Build Coastguard Worker
1802*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1803*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
1804*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
1805*4bdc9457SAndroid Build Coastguard Worker }
1806*4bdc9457SAndroid Build Coastguard Worker }
1807*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
1808*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
1809*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1810*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
1811*4bdc9457SAndroid Build Coastguard Worker }
1812*4bdc9457SAndroid Build Coastguard Worker }
1813*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1814*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
1815*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
1816*4bdc9457SAndroid Build Coastguard Worker }
1817*4bdc9457SAndroid Build Coastguard Worker }
1818*4bdc9457SAndroid Build Coastguard Worker
1819*4bdc9457SAndroid Build Coastguard Worker std::fill(c_ref.begin(), c_ref.end(), 0.0);
1820*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1821*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1822*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
1823*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1824*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(ks_index * mr() + m_index, im2col.size());
1825*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, k());
1826*4bdc9457SAndroid Build Coastguard Worker ASSERT_LT(k_index, a_stride());
1827*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
1828*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1829*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index]) *
1830*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1831*4bdc9457SAndroid Build Coastguard Worker } else {
1832*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] +=
1833*4bdc9457SAndroid Build Coastguard Worker (im2col[ks_index * mr() + m_index][k_index + a_offset()]) *
1834*4bdc9457SAndroid Build Coastguard Worker (b[(n_index * ks() + ks_index) * k() + k_index]);
1835*4bdc9457SAndroid Build Coastguard Worker }
1836*4bdc9457SAndroid Build Coastguard Worker }
1837*4bdc9457SAndroid Build Coastguard Worker }
1838*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] += bias[n_index];
1839*4bdc9457SAndroid Build Coastguard Worker }
1840*4bdc9457SAndroid Build Coastguard Worker }
1841*4bdc9457SAndroid Build Coastguard Worker
1842*4bdc9457SAndroid Build Coastguard Worker const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
1843*4bdc9457SAndroid Build Coastguard Worker const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
1844*4bdc9457SAndroid Build Coastguard Worker const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
1845*4bdc9457SAndroid Build Coastguard Worker const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
1846*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1847*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1848*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
1849*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
1850*4bdc9457SAndroid Build Coastguard Worker }
1851*4bdc9457SAndroid Build Coastguard Worker }
1852*4bdc9457SAndroid Build Coastguard Worker
1853*4bdc9457SAndroid Build Coastguard Worker // Prepare parameters.
1854*4bdc9457SAndroid Build Coastguard Worker xnn_f32_minmax_params params;
1855*4bdc9457SAndroid Build Coastguard Worker init_params(¶ms, c_min, c_max);
1856*4bdc9457SAndroid Build Coastguard Worker
1857*4bdc9457SAndroid Build Coastguard Worker const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
1858*4bdc9457SAndroid Build Coastguard Worker
1859*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
1860*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_buffer code_buffer;
1861*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1862*4bdc9457SAndroid Build Coastguard Worker jit_gemm_params p = (jit_gemm_params) {
1863*4bdc9457SAndroid Build Coastguard Worker .f32_minmax = {
1864*4bdc9457SAndroid Build Coastguard Worker .min = c_min,
1865*4bdc9457SAndroid Build Coastguard Worker .max = c_max
1866*4bdc9457SAndroid Build Coastguard Worker }
1867*4bdc9457SAndroid Build Coastguard Worker };
1868*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success,
1869*4bdc9457SAndroid Build Coastguard Worker igemm_generator(&code_buffer, mr(), n() % nr(), k() * sizeof(float), ks() * mr() * sizeof(void *), &p));
1870*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
1871*4bdc9457SAndroid Build Coastguard Worker xnn_f32_igemm_minmax_ukernel_function igemm_minmax =
1872*4bdc9457SAndroid Build Coastguard Worker reinterpret_cast<xnn_f32_igemm_minmax_ukernel_function>(code_buffer.start);
1873*4bdc9457SAndroid Build Coastguard Worker
1874*4bdc9457SAndroid Build Coastguard Worker igemm_minmax(
1875*4bdc9457SAndroid Build Coastguard Worker m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
1876*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
1877*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
1878*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(float), zero_pointer,
1879*4bdc9457SAndroid Build Coastguard Worker ¶ms);
1880*4bdc9457SAndroid Build Coastguard Worker
1881*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
1882*4bdc9457SAndroid Build Coastguard Worker
1883*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
1884*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
1885*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
1886*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1887*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1888*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1889*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
1890*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1891*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1892*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1893*4bdc9457SAndroid Build Coastguard Worker ASSERT_NEAR(
1894*4bdc9457SAndroid Build Coastguard Worker c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
1895*4bdc9457SAndroid Build Coastguard Worker c_ref[i * n() + j],
1896*4bdc9457SAndroid Build Coastguard Worker std::max(1.0e-5f, std::abs(c_ref[i * n() + j]) * 1.0e-6f))
1897*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << i << ": reference = " << c_ref[i * n() + j]
1898*4bdc9457SAndroid Build Coastguard Worker << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
1899*4bdc9457SAndroid Build Coastguard Worker << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
1900*4bdc9457SAndroid Build Coastguard Worker }
1901*4bdc9457SAndroid Build Coastguard Worker }
1902*4bdc9457SAndroid Build Coastguard Worker }
1903*4bdc9457SAndroid Build Coastguard Worker }
1904*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_jit_gemm_code_generator_function gemm_generator,xnn_init_qc8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const1905*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
1906*4bdc9457SAndroid Build Coastguard Worker xnn_jit_gemm_code_generator_function gemm_generator,
1907*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_conv_minmax_params_fn init_params,
1908*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
1909*4bdc9457SAndroid Build Coastguard Worker {
1910*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
1911*4bdc9457SAndroid Build Coastguard Worker
1912*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
1913*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
1914*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
1915*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
1916*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
1917*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
1918*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
1919*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
1920*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
1921*4bdc9457SAndroid Build Coastguard Worker
1922*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
1923*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * k());
1924*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
1925*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
1926*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int16_t));
1927*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
1928*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
1929*4bdc9457SAndroid Build Coastguard Worker std::vector<float> scale(n());
1930*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
1931*4bdc9457SAndroid Build Coastguard Worker
1932*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
1933*4bdc9457SAndroid Build Coastguard Worker do {
1934*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
1935*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
1936*4bdc9457SAndroid Build Coastguard Worker do {
1937*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
1938*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
1939*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
1940*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
1941*4bdc9457SAndroid Build Coastguard Worker
1942*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
1943*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
1944*4bdc9457SAndroid Build Coastguard Worker if (extended_weights()) {
1945*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
1946*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_xw.data(), nr() * sizeof(float), &packing_params);
1947*4bdc9457SAndroid Build Coastguard Worker } else {
1948*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
1949*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
1950*4bdc9457SAndroid Build Coastguard Worker }
1951*4bdc9457SAndroid Build Coastguard Worker
1952*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
1953*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
1954*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1955*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1956*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
1957*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
1958*4bdc9457SAndroid Build Coastguard Worker (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
1959*4bdc9457SAndroid Build Coastguard Worker int32_t(b[n_index * k() + k_index]);
1960*4bdc9457SAndroid Build Coastguard Worker }
1961*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
1962*4bdc9457SAndroid Build Coastguard Worker }
1963*4bdc9457SAndroid Build Coastguard Worker }
1964*4bdc9457SAndroid Build Coastguard Worker
1965*4bdc9457SAndroid Build Coastguard Worker const int8_t c_zero_point = -1;
1966*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
1967*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_min = acc[n_index];
1968*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_max = acc[n_index];
1969*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
1970*4bdc9457SAndroid Build Coastguard Worker accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
1971*4bdc9457SAndroid Build Coastguard Worker accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
1972*4bdc9457SAndroid Build Coastguard Worker }
1973*4bdc9457SAndroid Build Coastguard Worker const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
1974*4bdc9457SAndroid Build Coastguard Worker const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
1975*4bdc9457SAndroid Build Coastguard Worker scale[n_index] = 1.0f / c_scale;
1976*4bdc9457SAndroid Build Coastguard Worker }
1977*4bdc9457SAndroid Build Coastguard Worker
1978*4bdc9457SAndroid Build Coastguard Worker if (extended_weights()) {
1979*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_fp32_params(
1980*4bdc9457SAndroid Build Coastguard Worker n(), nr(),
1981*4bdc9457SAndroid Build Coastguard Worker nr() * (packed_k() * sizeof(int16_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
1982*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) packed_xw.data() + nr() * (packed_k() * sizeof(int16_t) + sizeof(int32_t))));
1983*4bdc9457SAndroid Build Coastguard Worker } else {
1984*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_fp32_params(
1985*4bdc9457SAndroid Build Coastguard Worker n(), nr(),
1986*4bdc9457SAndroid Build Coastguard Worker nr() * (packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
1987*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) packed_w.data() + nr() * (packed_k() * sizeof(int8_t) + sizeof(int32_t))));
1988*4bdc9457SAndroid Build Coastguard Worker }
1989*4bdc9457SAndroid Build Coastguard Worker
1990*4bdc9457SAndroid Build Coastguard Worker union xnn_qc8_conv_minmax_params minmax_params;
1991*4bdc9457SAndroid Build Coastguard Worker init_params(&minmax_params,
1992*4bdc9457SAndroid Build Coastguard Worker c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
1993*4bdc9457SAndroid Build Coastguard Worker
1994*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
1995*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_buffer code_buffer;
1996*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
1997*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, mr(), n() % nr(), k(), nullptr));
1998*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
1999*4bdc9457SAndroid Build Coastguard Worker xnn_qc8_gemm_minmax_ukernel_function gemm = reinterpret_cast<xnn_qc8_gemm_minmax_ukernel_function>(code_buffer.start);
2000*4bdc9457SAndroid Build Coastguard Worker
2001*4bdc9457SAndroid Build Coastguard Worker gemm(
2002*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(),
2003*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(int8_t),
2004*4bdc9457SAndroid Build Coastguard Worker extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
2005*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2006*4bdc9457SAndroid Build Coastguard Worker &minmax_params);
2007*4bdc9457SAndroid Build Coastguard Worker
2008*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2009*4bdc9457SAndroid Build Coastguard Worker
2010*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2011*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2012*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
2013*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2014*4bdc9457SAndroid Build Coastguard Worker }
2015*4bdc9457SAndroid Build Coastguard Worker }
2016*4bdc9457SAndroid Build Coastguard Worker
2017*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
2018*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
2019*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2020*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2021*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2022*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2023*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
2024*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2025*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2026*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
2027*4bdc9457SAndroid Build Coastguard Worker }
2028*4bdc9457SAndroid Build Coastguard Worker }
2029*4bdc9457SAndroid Build Coastguard Worker }
2030*4bdc9457SAndroid Build Coastguard Worker }
2031*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_jit_igemm_code_generator_function igemm_generator,xnn_init_qc8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const2032*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
2033*4bdc9457SAndroid Build Coastguard Worker xnn_jit_igemm_code_generator_function igemm_generator,
2034*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_conv_minmax_params_fn init_params,
2035*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
2036*4bdc9457SAndroid Build Coastguard Worker {
2037*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
2038*4bdc9457SAndroid Build Coastguard Worker
2039*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
2040*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
2041*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2042*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
2043*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2044*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
2045*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
2046*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2047*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
2048*4bdc9457SAndroid Build Coastguard Worker
2049*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
2050*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * ks() * k());
2051*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * (sizeof(int32_t) + sizeof(float)) / sizeof(int8_t));
2052*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
2053*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
2054*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
2055*4bdc9457SAndroid Build Coastguard Worker std::vector<float> scale(n());
2056*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
2057*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> junk(k() + 8);
2058*4bdc9457SAndroid Build Coastguard Worker std::vector<const int8_t*> im2col(mr() * ks());
2059*4bdc9457SAndroid Build Coastguard Worker
2060*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), 0xA5);
2061*4bdc9457SAndroid Build Coastguard Worker
2062*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
2063*4bdc9457SAndroid Build Coastguard Worker do {
2064*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
2065*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
2066*4bdc9457SAndroid Build Coastguard Worker do {
2067*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
2068*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
2069*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
2070*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
2071*4bdc9457SAndroid Build Coastguard Worker
2072*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
2073*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
2074*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_conv_goki_w(
2075*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
2076*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), nr() * sizeof(float), &packing_params);
2077*4bdc9457SAndroid Build Coastguard Worker
2078*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2079*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
2080*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
2081*4bdc9457SAndroid Build Coastguard Worker }
2082*4bdc9457SAndroid Build Coastguard Worker }
2083*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
2084*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
2085*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2086*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
2087*4bdc9457SAndroid Build Coastguard Worker }
2088*4bdc9457SAndroid Build Coastguard Worker }
2089*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2090*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
2091*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
2092*4bdc9457SAndroid Build Coastguard Worker }
2093*4bdc9457SAndroid Build Coastguard Worker }
2094*4bdc9457SAndroid Build Coastguard Worker
2095*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
2096*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
2097*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2098*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2099*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2100*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
2101*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
2102*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
2103*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
2104*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2105*4bdc9457SAndroid Build Coastguard Worker } else {
2106*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
2107*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
2108*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2109*4bdc9457SAndroid Build Coastguard Worker }
2110*4bdc9457SAndroid Build Coastguard Worker }
2111*4bdc9457SAndroid Build Coastguard Worker }
2112*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
2113*4bdc9457SAndroid Build Coastguard Worker }
2114*4bdc9457SAndroid Build Coastguard Worker }
2115*4bdc9457SAndroid Build Coastguard Worker
2116*4bdc9457SAndroid Build Coastguard Worker const int8_t c_zero_point = -1;
2117*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2118*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_min = acc[n_index];
2119*4bdc9457SAndroid Build Coastguard Worker int32_t accumulated_max = acc[n_index];
2120*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2121*4bdc9457SAndroid Build Coastguard Worker accumulated_min = std::min(accumulated_min, acc[m_index * n() + n_index]);
2122*4bdc9457SAndroid Build Coastguard Worker accumulated_max = std::max(accumulated_max, acc[m_index * n() + n_index]);
2123*4bdc9457SAndroid Build Coastguard Worker }
2124*4bdc9457SAndroid Build Coastguard Worker const uint32_t accumulated_range = uint32_t(accumulated_max - accumulated_min);
2125*4bdc9457SAndroid Build Coastguard Worker const float c_scale = accumulated_range >= 256 ? double(accumulated_range) / 255.0 : 1.00001;
2126*4bdc9457SAndroid Build Coastguard Worker scale[n_index] = 1.0f / c_scale;
2127*4bdc9457SAndroid Build Coastguard Worker }
2128*4bdc9457SAndroid Build Coastguard Worker
2129*4bdc9457SAndroid Build Coastguard Worker xnn_init_qc8_scale_fp32_params(
2130*4bdc9457SAndroid Build Coastguard Worker n(), nr(),
2131*4bdc9457SAndroid Build Coastguard Worker nr() * (ks() * packed_k() * sizeof(int8_t) + (sizeof(int32_t) + sizeof(float))), scale.data(),
2132*4bdc9457SAndroid Build Coastguard Worker (void*) ((uintptr_t) packed_w.data() + nr() * (ks() * packed_k() * sizeof(int8_t) + sizeof(int32_t))));
2133*4bdc9457SAndroid Build Coastguard Worker
2134*4bdc9457SAndroid Build Coastguard Worker union xnn_qc8_conv_minmax_params minmax_params;
2135*4bdc9457SAndroid Build Coastguard Worker init_params(&minmax_params,
2136*4bdc9457SAndroid Build Coastguard Worker c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2137*4bdc9457SAndroid Build Coastguard Worker
2138*4bdc9457SAndroid Build Coastguard Worker const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
2139*4bdc9457SAndroid Build Coastguard Worker
2140*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
2141*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_buffer code_buffer;
2142*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2143*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, mr(), n() % nr(), k(), ks() * mr() * sizeof(void *), nullptr));
2144*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
2145*4bdc9457SAndroid Build Coastguard Worker xnn_qc8_igemm_minmax_ukernel_function igemm = reinterpret_cast<xnn_qc8_igemm_minmax_ukernel_function>(code_buffer.start);
2146*4bdc9457SAndroid Build Coastguard Worker
2147*4bdc9457SAndroid Build Coastguard Worker igemm(
2148*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(), ks() * mr() * sizeof(void*),
2149*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
2150*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2151*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(uint8_t), zero_pointer,
2152*4bdc9457SAndroid Build Coastguard Worker &minmax_params);
2153*4bdc9457SAndroid Build Coastguard Worker
2154*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2155*4bdc9457SAndroid Build Coastguard Worker
2156*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2157*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2158*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
2159*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], scale[n_index], c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2160*4bdc9457SAndroid Build Coastguard Worker }
2161*4bdc9457SAndroid Build Coastguard Worker }
2162*4bdc9457SAndroid Build Coastguard Worker
2163*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
2164*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
2165*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2166*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2167*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2168*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
2169*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
2170*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
2171*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2172*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << scale[j] << ", output zero point = " << int32_t(c_zero_point);
2173*4bdc9457SAndroid Build Coastguard Worker }
2174*4bdc9457SAndroid Build Coastguard Worker }
2175*4bdc9457SAndroid Build Coastguard Worker }
2176*4bdc9457SAndroid Build Coastguard Worker }
2177*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_jit_gemm_code_generator_function gemm_generator,xnn_init_qs8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const2178*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
2179*4bdc9457SAndroid Build Coastguard Worker xnn_jit_gemm_code_generator_function gemm_generator,
2180*4bdc9457SAndroid Build Coastguard Worker xnn_init_qs8_conv_minmax_params_fn init_params,
2181*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
2182*4bdc9457SAndroid Build Coastguard Worker {
2183*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
2184*4bdc9457SAndroid Build Coastguard Worker
2185*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
2186*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
2187*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2188*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
2189*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2190*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
2191*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
2192*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2193*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
2194*4bdc9457SAndroid Build Coastguard Worker
2195*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(int8_t));
2196*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * k());
2197*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
2198*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
2199*4bdc9457SAndroid Build Coastguard Worker std::vector<int16_t, AlignedAllocator<int16_t, 64>> packed_xw(packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int16_t));
2200*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
2201*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
2202*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
2203*4bdc9457SAndroid Build Coastguard Worker
2204*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
2205*4bdc9457SAndroid Build Coastguard Worker do {
2206*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
2207*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
2208*4bdc9457SAndroid Build Coastguard Worker do {
2209*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
2210*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
2211*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
2212*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
2213*4bdc9457SAndroid Build Coastguard Worker
2214*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
2215*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
2216*4bdc9457SAndroid Build Coastguard Worker if (extended_weights()) {
2217*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_xw_goi_w(1, n(), k(), nr(), kr(), sr(),
2218*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_xw.data(), 0, &packing_params);
2219*4bdc9457SAndroid Build Coastguard Worker } else {
2220*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_gemm_goi_w(1, n(), k(), nr(), kr(), sr(),
2221*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0, &packing_params);
2222*4bdc9457SAndroid Build Coastguard Worker }
2223*4bdc9457SAndroid Build Coastguard Worker
2224*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
2225*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
2226*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2227*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2228*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
2229*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
2230*4bdc9457SAndroid Build Coastguard Worker (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point() - 0x80)) *
2231*4bdc9457SAndroid Build Coastguard Worker int32_t(b[n_index * k() + k_index]);
2232*4bdc9457SAndroid Build Coastguard Worker }
2233*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
2234*4bdc9457SAndroid Build Coastguard Worker }
2235*4bdc9457SAndroid Build Coastguard Worker }
2236*4bdc9457SAndroid Build Coastguard Worker
2237*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
2238*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
2239*4bdc9457SAndroid Build Coastguard Worker const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
2240*4bdc9457SAndroid Build Coastguard Worker const int8_t c_zero_point = int8_t(std::max(std::min(
2241*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
2242*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
2243*4bdc9457SAndroid Build Coastguard Worker
2244*4bdc9457SAndroid Build Coastguard Worker const float requantization_scale = 1.0f / float(c_scale);
2245*4bdc9457SAndroid Build Coastguard Worker union xnn_qs8_conv_minmax_params quantization_params;
2246*4bdc9457SAndroid Build Coastguard Worker init_params(&quantization_params,
2247*4bdc9457SAndroid Build Coastguard Worker requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2248*4bdc9457SAndroid Build Coastguard Worker
2249*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
2250*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_buffer code_buffer;
2251*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2252*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, gemm_generator(&code_buffer, mr(), n() % nr(), k(), nullptr));
2253*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
2254*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_gemm_minmax_ukernel_function gemm = reinterpret_cast<xnn_qs8_gemm_minmax_ukernel_function >(code_buffer.start);
2255*4bdc9457SAndroid Build Coastguard Worker
2256*4bdc9457SAndroid Build Coastguard Worker gemm(
2257*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(),
2258*4bdc9457SAndroid Build Coastguard Worker a.data(), a_stride() * sizeof(int8_t),
2259*4bdc9457SAndroid Build Coastguard Worker extended_weights() ? static_cast<const void*>(packed_xw.data()) : static_cast<const void*>(packed_w.data()),
2260*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2261*4bdc9457SAndroid Build Coastguard Worker &quantization_params);
2262*4bdc9457SAndroid Build Coastguard Worker
2263*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2264*4bdc9457SAndroid Build Coastguard Worker
2265*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2266*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2267*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
2268*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2269*4bdc9457SAndroid Build Coastguard Worker }
2270*4bdc9457SAndroid Build Coastguard Worker }
2271*4bdc9457SAndroid Build Coastguard Worker
2272*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
2273*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
2274*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2275*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2276*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2277*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << int32_t(c_ref[i * n() + j])
2278*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
2279*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
2280*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2281*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
2282*4bdc9457SAndroid Build Coastguard Worker }
2283*4bdc9457SAndroid Build Coastguard Worker }
2284*4bdc9457SAndroid Build Coastguard Worker }
2285*4bdc9457SAndroid Build Coastguard Worker }
2286*4bdc9457SAndroid Build Coastguard Worker
Test(xnn_jit_igemm_code_generator_function igemm_generator,xnn_init_qs8_conv_minmax_params_fn init_params,xnn_qs8_requantize_fn requantize) const2287*4bdc9457SAndroid Build Coastguard Worker void GemmMicrokernelTester::Test(
2288*4bdc9457SAndroid Build Coastguard Worker xnn_jit_igemm_code_generator_function igemm_generator,
2289*4bdc9457SAndroid Build Coastguard Worker xnn_init_qs8_conv_minmax_params_fn init_params,
2290*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_requantize_fn requantize) const
2291*4bdc9457SAndroid Build Coastguard Worker {
2292*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(m(), mr());
2293*4bdc9457SAndroid Build Coastguard Worker
2294*4bdc9457SAndroid Build Coastguard Worker std::random_device random_device;
2295*4bdc9457SAndroid Build Coastguard Worker auto rng = std::mt19937(random_device());
2296*4bdc9457SAndroid Build Coastguard Worker auto i32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), std::ref(rng));
2297*4bdc9457SAndroid Build Coastguard Worker auto i8rng = std::bind(
2298*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(std::numeric_limits<int8_t>::min(), std::numeric_limits<int8_t>::max()),
2299*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
2300*4bdc9457SAndroid Build Coastguard Worker auto w8rng = std::bind(
2301*4bdc9457SAndroid Build Coastguard Worker std::uniform_int_distribution<int32_t>(-std::numeric_limits<int8_t>::max(), std::numeric_limits<int8_t>::max()),
2302*4bdc9457SAndroid Build Coastguard Worker std::ref(rng));
2303*4bdc9457SAndroid Build Coastguard Worker
2304*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
2305*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> b(n() * ks() * k());
2306*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t, AlignedAllocator<int8_t, 64>> packed_w(ks() * packed_n() * packed_k() + packed_n() * sizeof(int32_t) / sizeof(int8_t));
2307*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> bias(n());
2308*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
2309*4bdc9457SAndroid Build Coastguard Worker std::vector<int32_t> acc(m() * n());
2310*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> c_ref(m() * n());
2311*4bdc9457SAndroid Build Coastguard Worker std::vector<int8_t> junk(k() + 8);
2312*4bdc9457SAndroid Build Coastguard Worker std::vector<const int8_t*> im2col(mr() * ks());
2313*4bdc9457SAndroid Build Coastguard Worker
2314*4bdc9457SAndroid Build Coastguard Worker std::fill(junk.begin(), junk.end(), 0xA5);
2315*4bdc9457SAndroid Build Coastguard Worker
2316*4bdc9457SAndroid Build Coastguard Worker for (size_t iteration = 0; iteration < iterations(); iteration++) {
2317*4bdc9457SAndroid Build Coastguard Worker do {
2318*4bdc9457SAndroid Build Coastguard Worker std::generate(a.begin(), a.end(), std::ref(i8rng));
2319*4bdc9457SAndroid Build Coastguard Worker } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
2320*4bdc9457SAndroid Build Coastguard Worker do {
2321*4bdc9457SAndroid Build Coastguard Worker std::generate(b.begin(), b.end(), std::ref(w8rng));
2322*4bdc9457SAndroid Build Coastguard Worker } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
2323*4bdc9457SAndroid Build Coastguard Worker std::generate(bias.begin(), bias.end(), std::ref(i32rng));
2324*4bdc9457SAndroid Build Coastguard Worker std::fill(c.begin(), c.end(), 0xA5);
2325*4bdc9457SAndroid Build Coastguard Worker
2326*4bdc9457SAndroid Build Coastguard Worker std::fill(packed_w.begin(), packed_w.end(), 0);
2327*4bdc9457SAndroid Build Coastguard Worker const xnn_qs8_packing_params packing_params = { int8_t(a_zero_point() - 0x80) };
2328*4bdc9457SAndroid Build Coastguard Worker xnn_pack_qs8_conv_goki_w(
2329*4bdc9457SAndroid Build Coastguard Worker 1, n(), ks(), k(), nr(), kr(), sr(),
2330*4bdc9457SAndroid Build Coastguard Worker b.data(), bias.data(), packed_w.data(), 0 /* extra bytes */, &packing_params);
2331*4bdc9457SAndroid Build Coastguard Worker
2332*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2333*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < mr(); m_index++) {
2334*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
2335*4bdc9457SAndroid Build Coastguard Worker }
2336*4bdc9457SAndroid Build Coastguard Worker }
2337*4bdc9457SAndroid Build Coastguard Worker std::shuffle(im2col.begin(), im2col.end(), rng);
2338*4bdc9457SAndroid Build Coastguard Worker if (zero_index() != SIZE_MAX) {
2339*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2340*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + zero_index()] = a.data();
2341*4bdc9457SAndroid Build Coastguard Worker }
2342*4bdc9457SAndroid Build Coastguard Worker }
2343*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2344*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = m(); m_index < mr(); m_index++) {
2345*4bdc9457SAndroid Build Coastguard Worker im2col[ks_index * mr() + m_index] = junk.data();
2346*4bdc9457SAndroid Build Coastguard Worker }
2347*4bdc9457SAndroid Build Coastguard Worker }
2348*4bdc9457SAndroid Build Coastguard Worker
2349*4bdc9457SAndroid Build Coastguard Worker // Compute 32-bit results and output quantization arguments.
2350*4bdc9457SAndroid Build Coastguard Worker std::fill(acc.begin(), acc.end(), 0);
2351*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2352*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2353*4bdc9457SAndroid Build Coastguard Worker for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
2354*4bdc9457SAndroid Build Coastguard Worker for (size_t k_index = 0; k_index < k(); k_index++) {
2355*4bdc9457SAndroid Build Coastguard Worker if (im2col[ks_index * mr() + m_index] == a.data()) {
2356*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
2357*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index]) - int32_t(a_zero_point() - 0x80)) *
2358*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2359*4bdc9457SAndroid Build Coastguard Worker } else {
2360*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] +=
2361*4bdc9457SAndroid Build Coastguard Worker (int32_t(im2col[ks_index * mr() + m_index][k_index + a_offset()]) - int32_t(a_zero_point() - 0x80)) *
2362*4bdc9457SAndroid Build Coastguard Worker int32_t(b[(n_index * ks() + ks_index) * k() + k_index]);
2363*4bdc9457SAndroid Build Coastguard Worker }
2364*4bdc9457SAndroid Build Coastguard Worker }
2365*4bdc9457SAndroid Build Coastguard Worker }
2366*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index] += bias[n_index];
2367*4bdc9457SAndroid Build Coastguard Worker }
2368*4bdc9457SAndroid Build Coastguard Worker }
2369*4bdc9457SAndroid Build Coastguard Worker
2370*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
2371*4bdc9457SAndroid Build Coastguard Worker const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
2372*4bdc9457SAndroid Build Coastguard Worker const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
2373*4bdc9457SAndroid Build Coastguard Worker const uint8_t c_zero_point = uint8_t(std::max(std::min(
2374*4bdc9457SAndroid Build Coastguard Worker lrint(-0.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
2375*4bdc9457SAndroid Build Coastguard Worker long(std::numeric_limits<int8_t>::max())), long(std::numeric_limits<int8_t>::min())));
2376*4bdc9457SAndroid Build Coastguard Worker
2377*4bdc9457SAndroid Build Coastguard Worker const float requantization_scale = 1.0f / float(c_scale);
2378*4bdc9457SAndroid Build Coastguard Worker union xnn_qs8_conv_minmax_params quantization_params;
2379*4bdc9457SAndroid Build Coastguard Worker init_params(&quantization_params,
2380*4bdc9457SAndroid Build Coastguard Worker requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2381*4bdc9457SAndroid Build Coastguard Worker
2382*4bdc9457SAndroid Build Coastguard Worker const int8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;
2383*4bdc9457SAndroid Build Coastguard Worker
2384*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));
2385*4bdc9457SAndroid Build Coastguard Worker struct xnn_code_buffer code_buffer;
2386*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_allocate_code_memory(&code_buffer, XNN_DEFAULT_CODE_BUFFER_SIZE));
2387*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, igemm_generator(&code_buffer, mr(), n() % nr(), k(), ks() * mr() * sizeof(void *), nullptr));
2388*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_finalize_code_memory(&code_buffer));
2389*4bdc9457SAndroid Build Coastguard Worker xnn_qs8_igemm_minmax_ukernel_function igemm = reinterpret_cast<xnn_qs8_igemm_minmax_ukernel_function>(code_buffer.start);
2390*4bdc9457SAndroid Build Coastguard Worker
2391*4bdc9457SAndroid Build Coastguard Worker igemm(
2392*4bdc9457SAndroid Build Coastguard Worker m(), n(), k(), ks() * mr() * sizeof(void*),
2393*4bdc9457SAndroid Build Coastguard Worker im2col.data(), packed_w.data(),
2394*4bdc9457SAndroid Build Coastguard Worker c.data(), cm_stride() * sizeof(int8_t), cn_stride() * sizeof(int8_t),
2395*4bdc9457SAndroid Build Coastguard Worker a_offset() * sizeof(uint8_t), zero_pointer,
2396*4bdc9457SAndroid Build Coastguard Worker &quantization_params);
2397*4bdc9457SAndroid Build Coastguard Worker
2398*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(xnn_status_success, xnn_release_code_memory(&code_buffer));
2399*4bdc9457SAndroid Build Coastguard Worker
2400*4bdc9457SAndroid Build Coastguard Worker for (size_t m_index = 0; m_index < m(); m_index++) {
2401*4bdc9457SAndroid Build Coastguard Worker for (size_t n_index = 0; n_index < n(); n_index++) {
2402*4bdc9457SAndroid Build Coastguard Worker c_ref[m_index * n() + n_index] = requantize(
2403*4bdc9457SAndroid Build Coastguard Worker acc[m_index * n() + n_index], requantization_scale, c_zero_point, int8_t(qmin() - 0x80), int8_t(qmax() - 0x80));
2404*4bdc9457SAndroid Build Coastguard Worker }
2405*4bdc9457SAndroid Build Coastguard Worker }
2406*4bdc9457SAndroid Build Coastguard Worker
2407*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < m(); i++) {
2408*4bdc9457SAndroid Build Coastguard Worker for (size_t j = 0; j < n(); j++) {
2409*4bdc9457SAndroid Build Coastguard Worker ASSERT_LE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmax()) - 0x80);
2410*4bdc9457SAndroid Build Coastguard Worker ASSERT_GE(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(qmin()) - 0x80);
2411*4bdc9457SAndroid Build Coastguard Worker ASSERT_EQ(int32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), int32_t(c_ref[i * n() + j]))
2412*4bdc9457SAndroid Build Coastguard Worker << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
2413*4bdc9457SAndroid Build Coastguard Worker << " (accumulator = " << acc[i * n() + j]
2414*4bdc9457SAndroid Build Coastguard Worker << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
2415*4bdc9457SAndroid Build Coastguard Worker << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
2416*4bdc9457SAndroid Build Coastguard Worker << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
2417*4bdc9457SAndroid Build Coastguard Worker }
2418*4bdc9457SAndroid Build Coastguard Worker }
2419*4bdc9457SAndroid Build Coastguard Worker }
2420*4bdc9457SAndroid Build Coastguard Worker }
2421*4bdc9457SAndroid Build Coastguard Worker
2422*4bdc9457SAndroid Build Coastguard Worker #endif // XNN_PLATFORM_JIT
2423