xref: /aosp_15_r20/external/XNNPACK/test/vlog-microkernel-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
11*4bdc9457SAndroid Build Coastguard Worker #include <cassert>
12*4bdc9457SAndroid Build Coastguard Worker #include <cmath>
13*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
14*4bdc9457SAndroid Build Coastguard Worker #include <cstdlib>
15*4bdc9457SAndroid Build Coastguard Worker #include <random>
16*4bdc9457SAndroid Build Coastguard Worker #include <vector>
17*4bdc9457SAndroid Build Coastguard Worker 
18*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
19*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/aligned-allocator.h>
20*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
21*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/microfnptr.h>
22*4bdc9457SAndroid Build Coastguard Worker 
23*4bdc9457SAndroid Build Coastguard Worker 
24*4bdc9457SAndroid Build Coastguard Worker extern XNN_INTERNAL const uint16_t xnn_table_vlog[129];
25*4bdc9457SAndroid Build Coastguard Worker 
26*4bdc9457SAndroid Build Coastguard Worker class VLogMicrokernelTester {
27*4bdc9457SAndroid Build Coastguard Worker  public:
batch(size_t batch)28*4bdc9457SAndroid Build Coastguard Worker   inline VLogMicrokernelTester& batch(size_t batch) {
29*4bdc9457SAndroid Build Coastguard Worker     assert(batch != 0);
30*4bdc9457SAndroid Build Coastguard Worker     this->batch_ = batch;
31*4bdc9457SAndroid Build Coastguard Worker     return *this;
32*4bdc9457SAndroid Build Coastguard Worker   }
33*4bdc9457SAndroid Build Coastguard Worker 
batch()34*4bdc9457SAndroid Build Coastguard Worker   inline size_t batch() const {
35*4bdc9457SAndroid Build Coastguard Worker     return this->batch_;
36*4bdc9457SAndroid Build Coastguard Worker   }
37*4bdc9457SAndroid Build Coastguard Worker 
input_lshift(uint32_t input_lshift)38*4bdc9457SAndroid Build Coastguard Worker   inline VLogMicrokernelTester& input_lshift(uint32_t input_lshift) {
39*4bdc9457SAndroid Build Coastguard Worker     assert(input_lshift < 32);
40*4bdc9457SAndroid Build Coastguard Worker     this->input_lshift_ = input_lshift;
41*4bdc9457SAndroid Build Coastguard Worker     return *this;
42*4bdc9457SAndroid Build Coastguard Worker   }
43*4bdc9457SAndroid Build Coastguard Worker 
input_lshift()44*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t input_lshift() const {
45*4bdc9457SAndroid Build Coastguard Worker     return this->input_lshift_;
46*4bdc9457SAndroid Build Coastguard Worker   }
47*4bdc9457SAndroid Build Coastguard Worker 
output_scale(uint32_t output_scale)48*4bdc9457SAndroid Build Coastguard Worker   inline VLogMicrokernelTester& output_scale(uint32_t output_scale) {
49*4bdc9457SAndroid Build Coastguard Worker     this->output_scale_ = output_scale;
50*4bdc9457SAndroid Build Coastguard Worker     return *this;
51*4bdc9457SAndroid Build Coastguard Worker   }
52*4bdc9457SAndroid Build Coastguard Worker 
output_scale()53*4bdc9457SAndroid Build Coastguard Worker   inline uint32_t output_scale() const {
54*4bdc9457SAndroid Build Coastguard Worker     return this->output_scale_;
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
inplace(bool inplace)57*4bdc9457SAndroid Build Coastguard Worker   inline VLogMicrokernelTester& inplace(bool inplace) {
58*4bdc9457SAndroid Build Coastguard Worker     this->inplace_ = inplace;
59*4bdc9457SAndroid Build Coastguard Worker     return *this;
60*4bdc9457SAndroid Build Coastguard Worker   }
61*4bdc9457SAndroid Build Coastguard Worker 
inplace()62*4bdc9457SAndroid Build Coastguard Worker   inline bool inplace() const {
63*4bdc9457SAndroid Build Coastguard Worker     return this->inplace_;
64*4bdc9457SAndroid Build Coastguard Worker   }
65*4bdc9457SAndroid Build Coastguard Worker 
iterations(size_t iterations)66*4bdc9457SAndroid Build Coastguard Worker   inline VLogMicrokernelTester& iterations(size_t iterations) {
67*4bdc9457SAndroid Build Coastguard Worker     this->iterations_ = iterations;
68*4bdc9457SAndroid Build Coastguard Worker     return *this;
69*4bdc9457SAndroid Build Coastguard Worker   }
70*4bdc9457SAndroid Build Coastguard Worker 
iterations()71*4bdc9457SAndroid Build Coastguard Worker   inline size_t iterations() const {
72*4bdc9457SAndroid Build Coastguard Worker     return this->iterations_;
73*4bdc9457SAndroid Build Coastguard Worker   }
74*4bdc9457SAndroid Build Coastguard Worker 
Test(xnn_u32_vlog_ukernel_function vlog)75*4bdc9457SAndroid Build Coastguard Worker   void Test(xnn_u32_vlog_ukernel_function vlog) const {
76*4bdc9457SAndroid Build Coastguard Worker     std::random_device random_device;
77*4bdc9457SAndroid Build Coastguard Worker     auto rng = std::mt19937(random_device());
78*4bdc9457SAndroid Build Coastguard Worker     auto i16rng = std::bind(std::uniform_int_distribution<uint16_t>(), std::ref(rng));
79*4bdc9457SAndroid Build Coastguard Worker     auto i32rng = std::bind(std::uniform_int_distribution<uint32_t>(), std::ref(rng));
80*4bdc9457SAndroid Build Coastguard Worker 
81*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint32_t> x(batch() + XNN_EXTRA_BYTES / sizeof(uint32_t));
82*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y(batch() * (inplace() ? sizeof(uint32_t) / sizeof(uint16_t) : 1) + XNN_EXTRA_BYTES / sizeof(uint32_t));
83*4bdc9457SAndroid Build Coastguard Worker     std::vector<uint16_t> y_ref(batch());
84*4bdc9457SAndroid Build Coastguard Worker     const uint32_t* x_data = inplace() ? reinterpret_cast<const uint32_t*>(y.data()) : x.data();
85*4bdc9457SAndroid Build Coastguard Worker 
86*4bdc9457SAndroid Build Coastguard Worker     for (size_t iteration = 0; iteration < iterations(); iteration++) {
87*4bdc9457SAndroid Build Coastguard Worker       std::generate(x.begin(), x.end(), std::ref(i32rng));
88*4bdc9457SAndroid Build Coastguard Worker       std::generate(y.begin(), y.end(), std::ref(i16rng));
89*4bdc9457SAndroid Build Coastguard Worker       std::generate(y_ref.begin(), y_ref.end(), std::ref(i16rng));
90*4bdc9457SAndroid Build Coastguard Worker 
91*4bdc9457SAndroid Build Coastguard Worker       // Compute reference results.
92*4bdc9457SAndroid Build Coastguard Worker       for (size_t n = 0; n < batch(); n++) {
93*4bdc9457SAndroid Build Coastguard Worker         const uint32_t x_value = x_data[n];
94*4bdc9457SAndroid Build Coastguard Worker         const uint32_t scaled = x_value << input_lshift();
95*4bdc9457SAndroid Build Coastguard Worker         uint32_t log_value = 0;
96*4bdc9457SAndroid Build Coastguard Worker         if (scaled != 0) {
97*4bdc9457SAndroid Build Coastguard Worker           const uint32_t out_scale = output_scale();
98*4bdc9457SAndroid Build Coastguard Worker 
99*4bdc9457SAndroid Build Coastguard Worker           const int log_scale = 65536;
100*4bdc9457SAndroid Build Coastguard Worker           const int log_scale_log2 = 16;
101*4bdc9457SAndroid Build Coastguard Worker           const int log_coeff = 45426;
102*4bdc9457SAndroid Build Coastguard Worker           const uint32_t log2x = math_clz_nonzero_u32(scaled) ^ 31;  // log2 of scaled
103*4bdc9457SAndroid Build Coastguard Worker           assert(log2x < 32);
104*4bdc9457SAndroid Build Coastguard Worker 
105*4bdc9457SAndroid Build Coastguard Worker           // Number of segments in the log lookup table. The table will be log_segments+1
106*4bdc9457SAndroid Build Coastguard Worker           // in length (with some padding).
107*4bdc9457SAndroid Build Coastguard Worker           const int log_segments_log2 = 7;
108*4bdc9457SAndroid Build Coastguard Worker 
109*4bdc9457SAndroid Build Coastguard Worker           // Part 1
110*4bdc9457SAndroid Build Coastguard Worker           uint32_t frac = scaled - (UINT32_C(1) << log2x);
111*4bdc9457SAndroid Build Coastguard Worker 
112*4bdc9457SAndroid Build Coastguard Worker           // Shift the fractional part into msb of 16 bits
113*4bdc9457SAndroid Build Coastguard Worker           frac =  XNN_UNPREDICTABLE(log2x < log_scale_log2) ?
114*4bdc9457SAndroid Build Coastguard Worker               (frac << (log_scale_log2 - log2x)) :
115*4bdc9457SAndroid Build Coastguard Worker               (frac >> (log2x - log_scale_log2));
116*4bdc9457SAndroid Build Coastguard Worker 
117*4bdc9457SAndroid Build Coastguard Worker           // Part 2
118*4bdc9457SAndroid Build Coastguard Worker           const uint32_t base_seg = frac >> (log_scale_log2 - log_segments_log2);
119*4bdc9457SAndroid Build Coastguard Worker           const uint32_t seg_unit = (UINT32_C(1) << log_scale_log2) >> log_segments_log2;
120*4bdc9457SAndroid Build Coastguard Worker 
121*4bdc9457SAndroid Build Coastguard Worker           assert(128 == (1 << log_segments_log2));
122*4bdc9457SAndroid Build Coastguard Worker           assert(base_seg < (1 << log_segments_log2));
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker           const uint32_t c0 = xnn_table_vlog[base_seg];
125*4bdc9457SAndroid Build Coastguard Worker           const uint32_t c1 = xnn_table_vlog[base_seg + 1];
126*4bdc9457SAndroid Build Coastguard Worker           const uint32_t seg_base = seg_unit * base_seg;
127*4bdc9457SAndroid Build Coastguard Worker           const uint32_t rel_pos = ((c1 - c0) * (frac - seg_base)) >> log_scale_log2;
128*4bdc9457SAndroid Build Coastguard Worker           const uint32_t fraction =  frac + c0 + rel_pos;
129*4bdc9457SAndroid Build Coastguard Worker 
130*4bdc9457SAndroid Build Coastguard Worker           const uint32_t log2 = (log2x << log_scale_log2) + fraction;
131*4bdc9457SAndroid Build Coastguard Worker           const uint32_t round = log_scale / 2;
132*4bdc9457SAndroid Build Coastguard Worker           const uint32_t loge = (((uint64_t) log_coeff) * log2 + round) >> log_scale_log2;
133*4bdc9457SAndroid Build Coastguard Worker 
134*4bdc9457SAndroid Build Coastguard Worker           // Finally scale to our output scale
135*4bdc9457SAndroid Build Coastguard Worker           log_value = (out_scale * loge + round) >> log_scale_log2;
136*4bdc9457SAndroid Build Coastguard Worker         }
137*4bdc9457SAndroid Build Coastguard Worker 
138*4bdc9457SAndroid Build Coastguard Worker         const uint32_t vout = math_min_u32(log_value, (uint32_t) INT16_MAX);
139*4bdc9457SAndroid Build Coastguard Worker         y_ref[n] = vout;
140*4bdc9457SAndroid Build Coastguard Worker       }
141*4bdc9457SAndroid Build Coastguard Worker 
142*4bdc9457SAndroid Build Coastguard Worker       // Call optimized micro-kernel.
143*4bdc9457SAndroid Build Coastguard Worker       vlog(batch(), x_data, input_lshift(), output_scale(), y.data());
144*4bdc9457SAndroid Build Coastguard Worker 
145*4bdc9457SAndroid Build Coastguard Worker       // Verify results.
146*4bdc9457SAndroid Build Coastguard Worker       for (size_t n = 0; n < batch(); n++) {
147*4bdc9457SAndroid Build Coastguard Worker         ASSERT_EQ(y[n], y_ref[n])
148*4bdc9457SAndroid Build Coastguard Worker           << ", input_lshift " << input_lshift()
149*4bdc9457SAndroid Build Coastguard Worker           << ", output_scale " << output_scale()
150*4bdc9457SAndroid Build Coastguard Worker           << ", batch " << n << " / " << batch();
151*4bdc9457SAndroid Build Coastguard Worker       }
152*4bdc9457SAndroid Build Coastguard Worker     }
153*4bdc9457SAndroid Build Coastguard Worker   }
154*4bdc9457SAndroid Build Coastguard Worker 
155*4bdc9457SAndroid Build Coastguard Worker  private:
156*4bdc9457SAndroid Build Coastguard Worker   size_t batch_{1};
157*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_lshift_{4};
158*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_scale_{16};
159*4bdc9457SAndroid Build Coastguard Worker   bool inplace_{false};
160*4bdc9457SAndroid Build Coastguard Worker   size_t iterations_{15};
161*4bdc9457SAndroid Build Coastguard Worker };
162