1 // Copyright 2021 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7 #include <stdint.h>
8 #include <stddef.h>
9
10 #include <arm_neon.h>
11
12 #include <xnnpack/math.h>
13 #include <xnnpack/requantization-stubs.h>
14
15
// Requantizes n signed 32-bit values from `input` into signed 8-bit values in
// `output`, computing approximately
//   output[i] = clamp(round(input[i] * scale) + zero_point, qmin, qmax)
// with NEON integer arithmetic only. The multiplication by `scale` is done as:
// a saturating pre-shift, a saturating doubling high multiply by a 32-bit
// fixed-point multiplier (VQDMULH/SQDMULH, truncating), and a rounding
// arithmetic right post-shift (VRSHL/SRSHL). Intermediate narrowing steps
// saturate.
//
// Parameters:
//   n          - number of elements; must be a multiple of 16 (asserted).
//   input      - n int32 values to requantize.
//   scale      - requantization scale; must be in [0x1.0p-32f, 1.0f) (asserted).
//   zero_point - added (with int16 saturation) after scaling.
//   qmin, qmax - bounds the int8 results are clamped to.
//   output     - destination for n int8 values.
void xnn_qs8_requantize_rndnu__neon_qdmulh(
    size_t n,
    const int32_t* input,
    float scale,
    int8_t zero_point,
    int8_t qmin,
    int8_t qmax,
    int8_t* output)
{
  assert(n % 16 == 0);
  assert(scale < 1.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = float_as_uint32(scale);

  // Multiplier is in [0x40000000, 0x7FFFFF80] range: the 24-bit significand of
  // scale (implicit leading 1 restored) placed in bits 30..7.
  const int32_t multiplier = (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
  assert(multiplier >= INT32_C(0x40000000));
  assert(multiplier <= INT32_C(0x7FFFFF80));

  // Shift is in [0, 31] range. With the multiplier above,
  // scale == multiplier * 2**-(31 + shift).
  const int32_t shift = 127 + 31 - 32 - (int32_t) (scale_bits >> 23);
  assert(shift >= 0);
  assert(shift < 32);

  // Split shift into pre_shift + post_shift, post_shift in [1, 31] range.
  // The rounding shift (VRSHL/SRSHL) only rounds when it shifts right by at
  // least one bit. So when shift == 0 we use post_shift = 1 and compensate
  // with pre_shift = -1, i.e. a one-bit left shift of the input.
  const int32_t post_shift = math_max_s32(shift, 1);
  const int32_t pre_shift = shift - post_shift;

  // NEON shift-by-register intrinsics shift left for positive counts and right
  // for negative counts, hence the negations: vpre_shift is 0 or +1 (left),
  // vpost_shift is negative (right).
  const int32x4_t vmultiplier = vdupq_n_s32(multiplier);
  const int16x8_t vzero_point = vdupq_n_s16((int16_t) zero_point);
  const int32x4_t vpre_shift = vdupq_n_s32(-pre_shift);
  const int32x4_t vpost_shift = vdupq_n_s32(-post_shift);
  const int8x16_t vqmin = vdupq_n_s8(qmin);
  const int8x16_t vqmax = vdupq_n_s8(qmax);
  for (; n != 0; n -= 16) {
    const int32x4_t x = vld1q_s32(input);
    const int32x4_t y = vld1q_s32(input + 4);
    const int32x4_t z = vld1q_s32(input + 8);
    const int32x4_t w = vld1q_s32(input + 12);
    input += 16;

    // Saturating pre-shift. When vpre_shift == 1, a plain VSHL/SSHL would
    // wrap for inputs outside [-2**30, 2**30) and flip the sign of the
    // result. VQSHL/SQSHL saturates instead, so such inputs still end up at
    // the correct clamp bound. It is identical to VSHL for all other inputs.
    const int32x4_t x_preshifted = vqshlq_s32(x, vpre_shift);
    const int32x4_t y_preshifted = vqshlq_s32(y, vpre_shift);
    const int32x4_t z_preshifted = vqshlq_s32(z, vpre_shift);
    const int32x4_t w_preshifted = vqshlq_s32(w, vpre_shift);

    // Saturating doubling multiply, keeping the high 32 bits (truncated).
    const int32x4_t x_product = vqdmulhq_s32(x_preshifted, vmultiplier);
    const int32x4_t y_product = vqdmulhq_s32(y_preshifted, vmultiplier);
    const int32x4_t z_product = vqdmulhq_s32(z_preshifted, vmultiplier);
    const int32x4_t w_product = vqdmulhq_s32(w_preshifted, vmultiplier);

    // Rounding arithmetic right shift by post_shift.
    const int32x4_t x_scaled = vrshlq_s32(x_product, vpost_shift);
    const int32x4_t y_scaled = vrshlq_s32(y_product, vpost_shift);
    const int32x4_t z_scaled = vrshlq_s32(z_product, vpost_shift);
    const int32x4_t w_scaled = vrshlq_s32(w_product, vpost_shift);

    // Saturating narrow to int16, saturating add of the zero point, then
    // saturating narrow to int8.
#ifdef __aarch64__
    const int16x8_t xy_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point);
    const int16x8_t zw_packed = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point);
    const int8x16_t xyzw_packed = vqmovn_high_s16(vqmovn_s16(xy_packed), zw_packed);
#else
    const int16x8_t xy_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point);
    const int16x8_t zw_packed = vqaddq_s16(vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point);
    const int8x16_t xyzw_packed = vcombine_s8(vqmovn_s16(xy_packed), vqmovn_s16(zw_packed));
#endif

    const int8x16_t xyzw_clamped = vmaxq_s8(vminq_s8(xyzw_packed, vqmax), vqmin);

    // AArch32 version:
    //   4x VQSHL.S32 Qd, Qm, Qn
    //   4x VQDMULH.S32 Qd, Qm, Qn
    //   4x VRSHL.S32 Qd, Qm, Qn
    //   4x VQMOVN.S32 Dd, Qm
    //   2x VQADD.S16 Qd, Qm, Qn
    //   2x VQMOVN.S16 Dd, Qm
    //   1x VMIN.S8 Qd, Qm, Qn
    //   1x VMAX.S8 Qd, Qm, Qn
    // ---------------------
    // 22 instructions total
    //
    // AArch64 version:
    //   4x SQSHL Vd.4S, Vn.4S, Vm.4S
    //   4x SQDMULH Vd.4S, Vn.4S, Vm.4S
    //   4x SRSHL Vd.4S, Vn.4S, Vm.4S
    //   2x SQXTN Vd.4H, Vn.4S
    //   2x SQXTN2 Vd.8H, Vn.4S
    //   2x SQADD Vd.8H, Vn.8H, Vm.8H
    //   1x SQXTN Vd.8B, Vn.8H
    //   1x SQXTN2 Vd.16B, Vn.8H
    //   1x SMIN Vd.16B, Vn.16B, Vm.16B
    //   1x SMAX Vd.16B, Vn.16B, Vm.16B
    // ---------------------
    // 22 instructions total

    vst1q_s8(output, xyzw_clamped);
    output += 16;
  }
}
115