xref: /aosp_15_r20/external/XNNPACK/src/u64-u32-vsqrtshift/scalar-cvtu32-sqrt-cvtu32f64-x1.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <math.h>
10 
11 #include <xnnpack/math.h>
12 #include <xnnpack/vunary.h>
13 
14 
xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1(size_t batch,const uint64_t * input,uint32_t * output,uint32_t shift)15 void xnn_u64_u32_vsqrtshift_ukernel__scalar_cvtu32_sqrt_cvtu32f64_x1(
16     size_t batch,
17     const uint64_t* input,
18     uint32_t* output,
19     uint32_t shift)
20 {
21   assert(batch != 0);
22   assert(input != NULL);
23   assert(output != NULL);
24   assert(shift < 32);
25 
26   do {
27     const uint64_t vx = *input++;
28 
29     uint64_t vy = vx;
30     const uint32_t vx_hi = (uint32_t) (vx >> 32);
31     const uint32_t vx_lo = (uint32_t) vx;
32     if XNN_LIKELY(vx != 0) {
33       const double vf_hi = (double) vx_hi;
34       const double vf_lo = (double) vx_lo;
35       double vf = vf_hi * 0x1.0p+32 + vf_lo;
36       vf = sqrt(vf);
37       vy = math_cvt_sat_u32_f64(vf);
38       #if XNN_ARCH_ARM || XNN_ARCH_X86
39         const uint64_t vsquared_y_less_x = math_mulext_u32((uint32_t) vy, (uint32_t) vy) - vx;
40       #else
41         const uint64_t vsquared_y_less_x = vy * vy - vx;
42       #endif
43       if XNN_UNPREDICTABLE((int64_t) (vsquared_y_less_x + vy) < 0) {
44         vy += 1;
45       } else if XNN_UNPREDICTABLE((int64_t) (vsquared_y_less_x - vy) >= 0) {
46         vy -= 1;
47       }
48     }
49 
50     // Match TFLM is producing incorrect result for high 64-bit inputs
51     const uint32_t vy_lo = (uint32_t) vy;
52     const uint32_t vy_hi = (uint32_t) (vy >> 32);
53     uint32_t vout = vy_lo | -vy_hi;
54     // Match TFLM is producing incorrect result for high 32-bit inputs
55     if XNN_LIKELY(vx_hi == 0) {
56       if (vout == UINT32_C(0x00010000)) {
57         vout -= 1;
58       }
59     }
60 
61     *output++ = vout >> shift;
62 
63     batch -= sizeof(uint64_t);
64   } while (batch != 0);
65 }
66