/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <limits.h>
#include <stdint.h> // int32_t, UINT64_C, INT32_MIN

#include <immintrin.h>

/*
 * The code below is adapted from Google's gemmlowp library.
 * It is used only in QNNPACK unit tests and comparative benchmarks,
 * not in the library itself.
 */

// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

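// Rounding arithmetic right shift by `exponent`: divides each signed 32-bit
// lane by 2^exponent, rounding to nearest with ties rounded away from zero
// (the SSE counterpart of gemmlowp's RoundingDivideByPOT). A scalar sketch of
// the per-lane computation:
//   mask      = (1 << exponent) - 1;
//   remainder = x & mask;
//   threshold = (mask >> 1) + (x < 0 ? 1 : 0);
//   result    = (x >> exponent) + (remainder > threshold ? 1 : 0);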
static inline __m128i gemmlowp_sse_rdivbypo2_s32(__m128i x, int exponent) {
  const __m128i mask =
      _mm_set1_epi32((int32_t)((UINT64_C(1) << exponent) - UINT64_C(1)));
  const __m128i remainder = _mm_and_si128(x, mask);
  const __m128i threshold = _mm_sub_epi32(
      _mm_srli_epi32(mask, 1), _mm_cmplt_epi32(x, _mm_setzero_si128()));
  return _mm_sub_epi32(
      _mm_sra_epi32(x, _mm_cvtsi32_si128(exponent)),
      _mm_cmpgt_epi32(remainder, threshold));
}

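// Signed 32x32 -> 64-bit multiply of the even (0th and 2nd) 32-bit lanes of a
// and b, i.e. the signed analogue of _mm_mul_epu32. With SSE4.1 this is a
// single _mm_mul_epi32; the fallback emulates it by multiplying absolute
// values and then restoring the sign of each 64-bit product.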
static inline __m128i gemmlowp_sse_mul_s32(__m128i a, __m128i b) {
#ifdef __SSE4_1__
  return _mm_mul_epi32(a, b);
#else
  __m128i sign, zero, mul_us, a_neg, b_neg, mul_us_neg;
  sign = _mm_xor_si128(a, b);
  sign = _mm_srai_epi32(sign, 31); // broadcast the sign bit across each lane:
                                   // all ones if negative, all zeros otherwise
  sign = _mm_shuffle_epi32(
      sign,
      _MM_SHUFFLE(2, 2, 0, 0)); // replicate the signs of lanes 0 and 2 across
                                // the corresponding 64-bit elements
  zero = _mm_setzero_si128();
#ifdef __SSSE3__
  a_neg = _mm_abs_epi32(a); // absolute value of a
  b_neg = _mm_abs_epi32(b); // absolute value of b
#else /* pre-SSSE3 */
  const __m128i a_neg_mask = _mm_cmplt_epi32(a, zero);
  a_neg = _mm_sub_epi32(_mm_xor_si128(a, a_neg_mask), a_neg_mask);
  const __m128i b_neg_mask = _mm_cmplt_epi32(b, zero);
  b_neg = _mm_sub_epi32(_mm_xor_si128(b, b_neg_mask), b_neg_mask);
#endif /* pre-SSSE3 */
  mul_us = _mm_mul_epu32(a_neg, b_neg); // unsigned multiply of the absolute
                                        // values in lanes 0 and 2; each
                                        // product is a full 64-bit result
  mul_us_neg = _mm_sub_epi64(zero, mul_us);
  mul_us_neg = _mm_and_si128(sign, mul_us_neg);
  mul_us = _mm_andnot_si128(sign, mul_us);
  return _mm_or_si128(mul_us, mul_us_neg);
#endif
}

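// Saturating rounding doubling high multiply, analogous to ARM NEON's
// VQRDMULH.S32 and gemmlowp's SaturatingRoundingDoublingHighMul: each lane
// computes roughly (int32_t)(((int64_t)a * b + (1 << 30)) >> 31). The only
// input pair that can overflow is a == b == INT32_MIN, which the saturation
// mask below handles.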
static inline __m128i gemmlowp_sse_vqrdmulh_s32(__m128i a, __m128i b) {
  // saturation only happens if a == b == INT32_MIN
  const __m128i min = _mm_set1_epi32(INT32_MIN);
  const __m128i saturation_mask =
      _mm_and_si128(_mm_cmpeq_epi32(a, b), _mm_cmpeq_epi32(a, min));

  // a = a0 | a1 | a2 | a3
  // b = b0 | b1 | b2 | b3
  const __m128i a0_a2 = a;
  const __m128i a1_a3 = _mm_srli_si128(a, 4);
  const __m128i b0_b2 = b;
  const __m128i b1_b3 = _mm_srli_si128(b, 4);

  const __m128i a0b0_a2b2 = gemmlowp_sse_mul_s32(a0_a2, b0_b2);
  const __m128i a1b1_a3b3 = gemmlowp_sse_mul_s32(a1_a3, b1_b3);

  // do the rounding and take into account that it will be doubled
  const __m128i nudge = _mm_set1_epi64x(1 << 30);
  const __m128i a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge);
  const __m128i a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge);

  // do the doubling
  const __m128i a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1);
  const __m128i a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1);

// get the high part of the products
#ifdef __SSE4_1__
  const __m128i result = _mm_blend_epi16(
      _mm_srli_epi64(a0b0_a2b2_rounded_2x, 32), a1b1_a3b3_rounded_2x, 0xCC);
#else
  const __m128i result0213 = _mm_castps_si128(_mm_shuffle_ps(
      _mm_castsi128_ps(a0b0_a2b2_rounded_2x),
      _mm_castsi128_ps(a1b1_a3b3_rounded_2x),
      _MM_SHUFFLE(3, 1, 3, 1)));
  const __m128i result = _mm_shuffle_epi32(result0213, _MM_SHUFFLE(3, 1, 2, 0));
#endif

// saturate those which overflowed
#ifdef __SSE4_1__
  const __m128i saturated_result =
      _mm_blendv_epi8(result, min, saturation_mask);
#else
  const __m128i saturated_result = _mm_or_si128(
      _mm_and_si128(saturation_mask, min),
      _mm_andnot_si128(saturation_mask, result));
#endif
  return saturated_result;
}
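
// Illustrative composition (the names `acc`, `vmultiplier`, and `shift` are
// hypothetical): gemmlowp-style requantization typically rescales a 32-bit
// accumulator by a fixed-point multiplier followed by a rounding right shift,
// e.g.
//
//   const __m128i product = gemmlowp_sse_vqrdmulh_s32(acc, vmultiplier);
//   const __m128i scaled = gemmlowp_sse_rdivbypo2_s32(product, shift);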