/* xref: /aosp_15_r20/external/grpc-grpc/third_party/utf8_range/range-avx2.c (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2) */
1 #ifdef __AVX2__
2 
3 #include <stdio.h>
4 #include <stdint.h>
5 #include <x86intrin.h>
6 
7 int utf8_naive(const unsigned char *data, int len);
8 
#if 0
/*
 * Debug helper (compiled out): print an optional label followed by the
 * 32 bytes of v256, in memory order, as space-separated hex, then a newline.
 */
static void print256(const char *s, const __m256i v256)
{
    const unsigned char *bytes = (const unsigned char *)&v256;

    if (s != NULL)
        printf("%s:\t", s);
    for (const unsigned char *p = bytes; p != bytes + 32; ++p)
        printf("%02x ", *p);
    putchar('\n');
}
#endif
20 
/*
 * First-byte length table, indexed by a byte's high nibble (0x0-0xF).
 * Yields the length of the UTF-8 sequence the byte would start, minus one:
 *   high nibble 0x0-0xB (bytes 0x00-0xBF) -> 0
 *   high nibble 0xC-0xD (bytes 0xC0-0xDF) -> 1
 *   high nibble 0xE     (bytes 0xE0-0xEF) -> 2
 *   high nibble 0xF     (bytes 0xF0-0xFF) -> 3
 * Entries 16-31 repeat entries 0-15 so that each 128-bit lane of the loaded
 * vector holds the whole table. Entries not listed are zero.
 */
static const int8_t _first_len_tbl[32] = {
    [0x0C] = 1, [0x0D] = 1, [0x0E] = 2, [0x0F] = 3,
    [0x1C] = 1, [0x1D] = 1, [0x1E] = 2, [0x1F] = 3,
};
32 
/*
 * First-byte range table, indexed by a byte's high nibble.
 * Bytes 0xC0-0xFF (high nibble 0xC-0xF) map to range index 8. That is the
 * "non-ASCII First Byte" entry of the range tables, whose allowed values are
 * 0xC2-0xF4. Every other nibble maps to 0. Entries 16-31 repeat entries 0-15
 * for the second 128-bit lane.
 */
static const int8_t _first_range_tbl[32] = {
    [0x0C] = 8, [0x0D] = 8, [0x0E] = 8, [0x0F] = 8,
    [0x1C] = 8, [0x1D] = 8, [0x1E] = 8, [0x1F] = 8,
};
38 
/*
 * Range tables: map a range index to the minimum (_range_min_tbl) and
 * maximum (_range_max_tbl) value a byte may take.
 *   Index 0    : 00 ~ 7F (First Byte, ascii)
 *   Index 1,2,3: 80 ~ BF (Second, Third, Fourth Byte)
 *   Index 4    : A0 ~ BF (Second Byte after E0)
 *   Index 5    : 80 ~ 9F (Second Byte after ED)
 *   Index 6    : 90 ~ BF (Second Byte after F0)
 *   Index 7    : 80 ~ 8F (Second Byte after F4)
 *   Index 8    : C2 ~ F4 (First Byte, non ascii)
 *   Index 9~15 : illegal. The min is 0x7F (127) and the max is 0x80 (-128 as a
 *                signed byte), so no byte satisfies min <= b <= max.
 *
 * The validator compares bytes as signed 8-bit values (_mm256_cmpgt_epi8).
 * For example, 80 ~ BF means -128 ~ -65.
 *
 * The entries are stored as uint8_t so the hex bit patterns can be written
 * directly, without out-of-range int8_t conversions. Those conversions are
 * implementation-defined and draw constant-conversion warnings. Only the raw
 * bytes matter, because the tables are loaded with _mm256_loadu_si256.
 *
 * Entries 16-31 repeat 0-15 so each 128-bit lane holds the full table.
 */
static const uint8_t _range_min_tbl[] = {
    0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
    0xC2, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F,
    0x00, 0x80, 0x80, 0x80, 0xA0, 0x80, 0x90, 0x80,
    0xC2, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F, 0x7F,
};
static const uint8_t _range_max_tbl[] = {
    0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
    0xF4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
    0x7F, 0xBF, 0xBF, 0xBF, 0xBF, 0x9F, 0xBF, 0x8F,
    0xF4, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,
};
62 
/*
 * Tables for fast handling of the four special First Bytes (E0, ED, F0, F4).
 * After these bytes, the Second Byte is not simply 80~BF. Each table holds a
 * "range index adjustment" that is added to the range index of the byte
 * following the special First Byte.
 *
 * The original range is the First Byte's length-minus-one (first_len). The
 * adjusted range is the index used in the min/max range tables.
 * +------------+----------------+------------------+----------------+
 * | First Byte | original range | range adjustment | adjusted range |
 * +------------+----------------+------------------+----------------+
 * | E0         | 2              | 2                | 4              |
 * +------------+----------------+------------------+----------------+
 * | ED         | 2              | 3                | 5              |
 * +------------+----------------+------------------+----------------+
 * | F0         | 3              | 3                | 6              |
 * +------------+----------------+------------------+----------------+
 * | F4         | 3              | 4                | 7              |
 * +------------+----------------+------------------+----------------+
 *
 * Both tables repeat their 16 entries for the two 128-bit lanes.
 */
/* Indexed by (First Byte - 0xDF), covering DF..EE: index 1 -> E0, index 14 -> ED */
static const int8_t _df_ee_tbl[] = {
    0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0,
    0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0,
};
/* Indexed by (First Byte - 0xEF), covering EF..FE: index 1 -> F0, index 5 -> F4 */
static const int8_t _ef_fe_tbl[] = {
    0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
88 
/*
 * Error-reporting mode of utf8_range_avx2():
 *   0 - return 0 or a failure code. The vector loop accumulates errors and
 *       tests them once after the loop.
 *   1 - return the 1-based index of the first invalid byte. The vector loop
 *       branches out at the first failing 32-byte block.
 * May be overridden at build time (e.g. -DRET_ERR_IDX=1).
 * NOTE(review): the RET_ERR_IDX=1 path assumes utf8_naive() reports a 1-based
 * error position. Keep both settings consistent (TODO confirm).
 */
#ifndef RET_ERR_IDX
#define RET_ERR_IDX 0
#endif
90 
/*
 * Shift b up by one byte across the whole 256-bit register, filling the
 * vacated lowest byte with the highest byte of a.
 * In memory order the result is (a[31], b[0], ..., b[30]). Each byte of b is
 * thus lined up with the byte that precedes it in the stream.
 */
static inline __m256i push_last_byte_of_a_to_b(__m256i a, __m256i b) {
    /* carry = (a.high128, b.low128): supplies the byte entering each lane */
    const __m256i carry = _mm256_permute2x128_si256(a, b, 0x21);
    return _mm256_alignr_epi8(b, carry, 15);
}
94 
/*
 * Shift b up by two bytes across the whole 256-bit register, filling the two
 * vacated low bytes with the two highest bytes of a.
 * In memory order the result is (a[30], a[31], b[0], ..., b[29]).
 */
static inline __m256i push_last_2bytes_of_a_to_b(__m256i a, __m256i b) {
    /* carry = (a.high128, b.low128): supplies the bytes entering each lane */
    const __m256i carry = _mm256_permute2x128_si256(a, b, 0x21);
    return _mm256_alignr_epi8(b, carry, 14);
}
98 
/*
 * Shift b up by three bytes across the whole 256-bit register, filling the
 * three vacated low bytes with the three highest bytes of a.
 * In memory order the result is (a[29], a[30], a[31], b[0], ..., b[28]).
 */
static inline __m256i push_last_3bytes_of_a_to_b(__m256i a, __m256i b) {
    /* carry = (a.high128, b.low128): supplies the bytes entering each lane */
    const __m256i carry = _mm256_permute2x128_si256(a, b, 0x21);
    return _mm256_alignr_epi8(b, carry, 13);
}
102 
103 /* 5x faster than naive method */
104 /* Return 0 - success, -1 - error, >0 - first error char(if RET_ERR_IDX = 1) */
utf8_range_avx2(const unsigned char * data,int len)105 int utf8_range_avx2(const unsigned char *data, int len)
106 {
107 #if  RET_ERR_IDX
108     int err_pos = 1;
109 #endif
110 
111     if (len >= 32) {
112         __m256i prev_input = _mm256_set1_epi8(0);
113         __m256i prev_first_len = _mm256_set1_epi8(0);
114 
115         /* Cached tables */
116         const __m256i first_len_tbl =
117             _mm256_loadu_si256((const __m256i *)_first_len_tbl);
118         const __m256i first_range_tbl =
119             _mm256_loadu_si256((const __m256i *)_first_range_tbl);
120         const __m256i range_min_tbl =
121             _mm256_loadu_si256((const __m256i *)_range_min_tbl);
122         const __m256i range_max_tbl =
123             _mm256_loadu_si256((const __m256i *)_range_max_tbl);
124         const __m256i df_ee_tbl =
125             _mm256_loadu_si256((const __m256i *)_df_ee_tbl);
126         const __m256i ef_fe_tbl =
127             _mm256_loadu_si256((const __m256i *)_ef_fe_tbl);
128 
129 #if !RET_ERR_IDX
130         __m256i error1 = _mm256_set1_epi8(0);
131         __m256i error2 = _mm256_set1_epi8(0);
132 #endif
133 
134         while (len >= 32) {
135             const __m256i input = _mm256_loadu_si256((const __m256i *)data);
136 
137             /* high_nibbles = input >> 4 */
138             const __m256i high_nibbles =
139                 _mm256_and_si256(_mm256_srli_epi16(input, 4), _mm256_set1_epi8(0x0F));
140 
141             /* first_len = legal character length minus 1 */
142             /* 0 for 00~7F, 1 for C0~DF, 2 for E0~EF, 3 for F0~FF */
143             /* first_len = first_len_tbl[high_nibbles] */
144             __m256i first_len = _mm256_shuffle_epi8(first_len_tbl, high_nibbles);
145 
146             /* First Byte: set range index to 8 for bytes within 0xC0 ~ 0xFF */
147             /* range = first_range_tbl[high_nibbles] */
148             __m256i range = _mm256_shuffle_epi8(first_range_tbl, high_nibbles);
149 
150             /* Second Byte: set range index to first_len */
151             /* 0 for 00~7F, 1 for C0~DF, 2 for E0~EF, 3 for F0~FF */
152             /* range |= (first_len, prev_first_len) << 1 byte */
153             range = _mm256_or_si256(
154                     range, push_last_byte_of_a_to_b(prev_first_len, first_len));
155 
156             /* Third Byte: set range index to saturate_sub(first_len, 1) */
157             /* 0 for 00~7F, 0 for C0~DF, 1 for E0~EF, 2 for F0~FF */
158             __m256i tmp1, tmp2;
159 
160             /* tmp1 = (first_len, prev_first_len) << 2 bytes */
161             tmp1 = push_last_2bytes_of_a_to_b(prev_first_len, first_len);
162             /* tmp2 = saturate_sub(tmp1, 1) */
163             tmp2 = _mm256_subs_epu8(tmp1, _mm256_set1_epi8(1));
164 
165             /* range |= tmp2 */
166             range = _mm256_or_si256(range, tmp2);
167 
168             /* Fourth Byte: set range index to saturate_sub(first_len, 2) */
169             /* 0 for 00~7F, 0 for C0~DF, 0 for E0~EF, 1 for F0~FF */
170             /* tmp1 = (first_len, prev_first_len) << 3 bytes */
171             tmp1 = push_last_3bytes_of_a_to_b(prev_first_len, first_len);
172             /* tmp2 = saturate_sub(tmp1, 2) */
173             tmp2 = _mm256_subs_epu8(tmp1, _mm256_set1_epi8(2));
174             /* range |= tmp2 */
175             range = _mm256_or_si256(range, tmp2);
176 
177             /*
178              * Now we have below range indices caluclated
179              * Correct cases:
180              * - 8 for C0~FF
181              * - 3 for 1st byte after F0~FF
182              * - 2 for 1st byte after E0~EF or 2nd byte after F0~FF
183              * - 1 for 1st byte after C0~DF or 2nd byte after E0~EF or
184              *         3rd byte after F0~FF
185              * - 0 for others
186              * Error cases:
187              *   9,10,11 if non ascii First Byte overlaps
188              *   E.g., F1 80 C2 90 --> 8 3 10 2, where 10 indicates error
189              */
190 
191             /* Adjust Second Byte range for special First Bytes(E0,ED,F0,F4) */
192             /* Overlaps lead to index 9~15, which are illegal in range table */
193             __m256i shift1, pos, range2;
194             /* shift1 = (input, prev_input) << 1 byte */
195             shift1 = push_last_byte_of_a_to_b(prev_input, input);
196             pos = _mm256_sub_epi8(shift1, _mm256_set1_epi8(0xEF));
197             /*
198              * shift1:  | EF  F0 ... FE | FF  00  ... ...  DE | DF  E0 ... EE |
199              * pos:     | 0   1      15 | 16  17           239| 240 241    255|
200              * pos-240: | 0   0      0  | 0   0            0  | 0   1      15 |
201              * pos+112: | 112 113    127|       >= 128        |     >= 128    |
202              */
203             tmp1 = _mm256_subs_epu8(pos, _mm256_set1_epi8(240));
204             range2 = _mm256_shuffle_epi8(df_ee_tbl, tmp1);
205             tmp2 = _mm256_adds_epu8(pos, _mm256_set1_epi8(112));
206             range2 = _mm256_add_epi8(range2, _mm256_shuffle_epi8(ef_fe_tbl, tmp2));
207 
208             range = _mm256_add_epi8(range, range2);
209 
210             /* Load min and max values per calculated range index */
211             __m256i minv = _mm256_shuffle_epi8(range_min_tbl, range);
212             __m256i maxv = _mm256_shuffle_epi8(range_max_tbl, range);
213 
214             /* Check value range */
215 #if RET_ERR_IDX
216             __m256i error = _mm256_cmpgt_epi8(minv, input);
217             error = _mm256_or_si256(error, _mm256_cmpgt_epi8(input, maxv));
218             /* 5% performance drop from this conditional branch */
219             if (!_mm256_testz_si256(error, error))
220                 break;
221 #else
222             error1 = _mm256_or_si256(error1, _mm256_cmpgt_epi8(minv, input));
223             error2 = _mm256_or_si256(error2, _mm256_cmpgt_epi8(input, maxv));
224 #endif
225 
226             prev_input = input;
227             prev_first_len = first_len;
228 
229             data += 32;
230             len -= 32;
231 #if RET_ERR_IDX
232             err_pos += 32;
233 #endif
234         }
235 
236 #if RET_ERR_IDX
237         /* Error in first 16 bytes */
238         if (err_pos == 1)
239             goto do_naive;
240 #else
241         __m256i error = _mm256_or_si256(error1, error2);
242         if (!_mm256_testz_si256(error, error))
243             return -1;
244 #endif
245 
246         /* Find previous token (not 80~BF) */
247         int32_t token4 = _mm256_extract_epi32(prev_input, 7);
248         const int8_t *token = (const int8_t *)&token4;
249         int lookahead = 0;
250         if (token[3] > (int8_t)0xBF)
251             lookahead = 1;
252         else if (token[2] > (int8_t)0xBF)
253             lookahead = 2;
254         else if (token[1] > (int8_t)0xBF)
255             lookahead = 3;
256 
257         data -= lookahead;
258         len += lookahead;
259 #if RET_ERR_IDX
260         err_pos -= lookahead;
261 #endif
262     }
263 
264     /* Check remaining bytes with naive method */
265 #if RET_ERR_IDX
266     int err_pos2;
267 do_naive:
268     err_pos2 = utf8_naive(data, len);
269     if (err_pos2)
270         return err_pos + err_pos2 - 1;
271     return 0;
272 #else
273     return utf8_naive(data, len);
274 #endif
275 }
276 
277 #endif
278