xref: /aosp_15_r20/external/XNNPACK/src/math/sqrt-u32-scalar-hashemian.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <stddef.h>
8 
9 #include <xnnpack/math.h>
10 #include <xnnpack/math-stubs.h>
11 
12 
xnn_math_u32_sqrt__scalar_hashemian(size_t n,const uint32_t * input,uint32_t * output)13 void xnn_math_u32_sqrt__scalar_hashemian(
14     size_t n,
15     const uint32_t* input,
16     uint32_t* output)
17 {
18   assert(n % sizeof(uint32_t) == 0);
19 
20   for (; n != 0; n -= sizeof(uint32_t)) {
21     const uint32_t vx = *input++;
22 
23     uint32_t vy = vx;
24     if (vx != 0) {
25       /*
26        * Based on "Square Rooting Algorithms for Integer and Floating-Point Numbers" by Reza Hashemian
27        * and StackOverflow answer https://stackoverflow.com/a/31149161
28       */
29 
30       const uint32_t vn = math_clz_nonzero_u32(vx);
31       const uint32_t vleft_shift = vn & 1;
32       const uint32_t vm_minus_1 = 15 - (vn >> 1);
33       const uint32_t vm_plus_1 = vm_minus_1 + 2;
34       const uint32_t vexp2_m_minus_1 = UINT32_C(1) << vm_minus_1;
35       const uint32_t vz = vexp2_m_minus_1 - (vx >> (vm_plus_1 - vleft_shift));
36 
37       vy = vz;
38       // Iterate until y[i] == y[i-1]. Alternatively, we can do 7 iterations:
39       //   for (uint32_t i = 0; i < 7; i++) {
40       //     vy = vz + ((vy * vy) >> vm_plus_1);
41       //   }
42       uint32_t vy_prev;
43       do {
44         vy_prev = vy;
45         vy = vz + ((vy * vy) >> vm_plus_1);
46       } while (vy != vy_prev);
47 
48       // Reconstruct Y = 2**m - vy
49       vy = (vexp2_m_minus_1 << 1) - vy;
50       if XNN_UNPREDICTABLE(vleft_shift) {
51         // Multiply by sqrt(0.5) by subtracting vy * (1 - sqrt(0.5)), 1 - sqrt(0.5) is represented
52         // as a .16 fixed-point number to guarantee than the product doesn't overflow 32 bits.
53         // Using 1 - sqrt(0.5) under these constraints is 1 bit more accurate than using sqrt(0.5) directly.
54         vy -= (vy * UINT32_C(19195)) >> 16;
55       }
56 
57       // When X has an even number of bits, Y can overestimate isqrt(X) by 1 due to truncations in fixed-point
58       // arithmetics. When X has an odd number of bits, Y can overestimate isqrt(X) by an extra 1 (2 total) due to
59       // truncation in the multiplication by sqrt(0.5).
60       // We decrement Y once if X < Y * Y and decrement it once again if Y * Y - X > X - (Y - 1) * (Y - 1).
61       uint32_t vsquared_y = vy * vy;
62       if XNN_UNPREDICTABLE(vsquared_y > vx) {
63         vsquared_y -= 2 * vy - 1;
64         vy -= 1;
65       }
66 
67       // Y is within a distance of 1 from properly rounded sqrt(X).
68       // - Increment Y if (Y + 1) * (Y + 1) - X < X - Y * Y.
69       // - Decrement Y if Y * Y - X > X - (Y - 1) * (Y - 1).
70       // The increment + decrement are combined together to re-use the (Y * Y) value.
71       if XNN_UNPREDICTABLE(vsquared_y < vx - vy) {
72         vy += 1;
73       } else if XNN_UNPREDICTABLE(vsquared_y - vy >= vx) {
74         vy -= 1;
75       }
76     }
77 
78     *output++ = vy;
79   }
80 }
81