1 /// SWAR: SIMD Within A Register
2 /// SIMD validator backend that validates register-sized chunks of data at a time.
3 use crate::{is_header_name_token, is_header_value_token, is_uri_token, Bytes};
4 
5 // Adapt block-size to match native register size, i.e: 32bit => 4, 64bit => 8
6 const BLOCK_SIZE: usize = core::mem::size_of::<usize>();
7 type ByteBlock = [u8; BLOCK_SIZE];
8 
9 #[inline]
match_uri_vectored(bytes: &mut Bytes)10 pub fn match_uri_vectored(bytes: &mut Bytes) {
11     loop {
12         if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
13             let n = match_uri_char_8_swar(bytes8);
14             // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
15             // in `bytes`, so calling `advance(n)` is safe.
16             unsafe {
17                 bytes.advance(n);
18             }
19             if n == BLOCK_SIZE {
20                 continue;
21             }
22         }
23         if let Some(b) = bytes.peek() {
24             if is_uri_token(b) {
25                 // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
26                 // in bytes, so calling advance is safe.
27                 unsafe {
28                     bytes.advance(1);
29                 }
30                 continue;
31             }
32         }
33         break;
34     }
35 }
36 
37 #[inline]
match_header_value_vectored(bytes: &mut Bytes)38 pub fn match_header_value_vectored(bytes: &mut Bytes) {
39     loop {
40         if let Some(bytes8) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
41             let n = match_header_value_char_8_swar(bytes8);
42             // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
43             // in `bytes`, so calling `advance(n)` is safe.
44             unsafe {
45                 bytes.advance(n);
46             }
47             if n == BLOCK_SIZE {
48                 continue;
49             }
50         }
51         if let Some(b) = bytes.peek() {
52             if is_header_value_token(b) {
53                 // SAFETY: using peek to retrieve the byte ensures that there is at least 1 more byte
54                 // in bytes, so calling advance is safe.
55                 unsafe {
56                     bytes.advance(1);
57                 }
58                 continue;
59             }
60         }
61         break;
62     }
63 }
64 
65 #[inline]
match_header_name_vectored(bytes: &mut Bytes)66 pub fn match_header_name_vectored(bytes: &mut Bytes) {
67     while let Some(block) = bytes.peek_n::<ByteBlock>(BLOCK_SIZE) {
68         let n = match_block(is_header_name_token, block);
69         // SAFETY: using peek_n to retrieve the bytes ensures that there are at least n more bytes
70         // in `bytes`, so calling `advance(n)` is safe.
71         unsafe {
72             bytes.advance(n);
73         }
74         if n != BLOCK_SIZE {
75             return;
76         }
77     }
78     // SAFETY: match_tail processes at most the remaining data in `bytes`. advances `bytes` to the
79     // end, but no further.
80     unsafe { bytes.advance(match_tail(is_header_name_token, bytes.as_ref())) };
81 }
82 
// Matches "tail", i.e: when we have <BLOCK_SIZE bytes in the buffer, should be uncommon
//
// Returns the index of the first byte rejected by `f`, or `bytes.len()` when
// every byte matches.
#[cold]
#[inline]
fn match_tail(f: impl Fn(u8) -> bool, bytes: &[u8]) -> usize {
    bytes.iter().position(|&b| !f(b)).unwrap_or(bytes.len())
}
94 
95 // Naive fallback block matcher
96 #[inline(always)]
match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize97 fn match_block(f: impl Fn(u8) -> bool, block: ByteBlock) -> usize {
98     for (i, &b) in block.iter().enumerate() {
99         if !f(b) {
100             return i;
101         }
102     }
103     BLOCK_SIZE
104 }
105 
// A const alternative to u64::from_ne_bytes to avoid bumping MSRV (1.36 => 1.44)
// creates a u64 whose bytes are each equal to b
//
// Multiplying by 0x01_01_..._01 broadcasts `b` into every byte lane; the
// product can never overflow u64 (0xFF * 0x0101_0101_0101_0101 == u64::MAX).
// On 32-bit targets the `as usize` cast keeps the low 4 lanes, which is still
// a uniform block of the native register width.
const fn uniform_block(b: u8) -> usize {
    (b as u64 *  0x01_01_01_01_01_01_01_01 /* [1_u8; 8] */) as usize
}
111 
112 // A byte-wise range-check on an enire word/block,
113 // ensuring all bytes in the word satisfy
114 // `33 <= x <= 126 && x != '>' && x != '<'`
115 // IMPORTANT: it false negatives if the block contains '?'
116 #[inline]
match_uri_char_8_swar(block: ByteBlock) -> usize117 fn match_uri_char_8_swar(block: ByteBlock) -> usize {
118     // 33 <= x <= 126
119     const M: u8 = 0x21;
120     const N: u8 = 0x7E;
121     const BM: usize = uniform_block(M);
122     const BN: usize = uniform_block(127 - N);
123     const M128: usize = uniform_block(128);
124 
125     let x = usize::from_ne_bytes(block); // Really just a transmute
126     let lt = x.wrapping_sub(BM) & !x; // <= m
127     let gt = x.wrapping_add(BN) | x; // >= n
128 
129     // XOR checks to catch '<' & '>' for correctness
130     //
131     // XOR can be thought of as a "distance function"
132     // (somewhat extrapolating from the `xor(x, x) = 0` identity and ∀ x != y: xor(x, y) != 0`
133     // (each u8 "xor key" providing a unique total ordering of u8)
134     // '<' and '>' have a "xor distance" of 2 (`xor('<', '>') = 2`)
135     // xor(x, '>') <= 2 => {'>', '?', '<'}
136     // xor(x, '<') <= 2 => {'<', '=', '>'}
137     //
138     // We assume P('=') > P('?'),
139     // given well/commonly-formatted URLs with querystrings contain
140     // a single '?' but possibly many '='
141     //
142     // Thus it's preferable/near-optimal to "xor distance" on '>',
143     // since we'll slowpath at most one block per URL
144     //
145     // Some rust code to sanity check this yourself:
146     // ```rs
147     // fn xordist(x: u8, n: u8) -> Vec<(char, u8)> {
148     //     (0..=255).into_iter().map(|c| (c as char, c ^ x)).filter(|(_c, y)| *y <= n).collect()
149     // }
150     // (xordist(b'<', 2), xordist(b'>', 2))
151     // ```
152     const B3: usize = uniform_block(3); // (dist <= 2) + 1 to wrap
153     const BGT: usize = uniform_block(b'>');
154 
155     let xgt = x ^ BGT;
156     let ltgtq = xgt.wrapping_sub(B3) & !xgt;
157 
158     offsetnz((ltgtq | lt | gt) & M128)
159 }
160 
161 // A byte-wise range-check on an entire word/block,
162 // ensuring all bytes in the word satisfy `32 <= x <= 126`
163 // IMPORTANT: false negatives if obs-text is present (0x80..=0xFF)
164 #[inline]
match_header_value_char_8_swar(block: ByteBlock) -> usize165 fn match_header_value_char_8_swar(block: ByteBlock) -> usize {
166     // 32 <= x <= 126
167     const M: u8 = 0x20;
168     const N: u8 = 0x7E;
169     const BM: usize = uniform_block(M);
170     const BN: usize = uniform_block(127 - N);
171     const M128: usize = uniform_block(128);
172 
173     let x = usize::from_ne_bytes(block); // Really just a transmute
174     let lt = x.wrapping_sub(BM) & !x; // <= m
175     let gt = x.wrapping_add(BN) | x; // >= n
176     offsetnz((lt | gt) & M128)
177 }
178 
179 /// Check block to find offset of first non-zero byte
180 // NOTE: Curiously `block.trailing_zeros() >> 3` appears to be slower, maybe revisit
181 #[inline]
offsetnz(block: usize) -> usize182 fn offsetnz(block: usize) -> usize {
183     // fast path optimistic case (common for long valid sequences)
184     if block == 0 {
185         return BLOCK_SIZE;
186     }
187 
188     // perf: rust will unroll this loop
189     for (i, b) in block.to_ne_bytes().iter().copied().enumerate() {
190         if b != 0 {
191             return i;
192         }
193     }
194     unreachable!()
195 }
196 
#[test]
fn test_is_header_value_block() {
    let is_header_value_block = |b| match_header_value_char_8_swar(b) == BLOCK_SIZE;

    // 0..32 => false (control characters)
    for b in 0..32_u8 {
        assert!(!is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 32..127 => true (space + visible ASCII)
    for b in 32..127_u8 {
        assert!(is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 127..=255 => false (DEL + obs-text; SWAR false-negatives on obs-text)
    for b in 127..=255_u8 {
        assert!(!is_header_value_block([b; BLOCK_SIZE]), "b={}", b);
    }

    #[cfg(target_pointer_width = "64")]
    {
        // A few sanity checks on non-uniform bytes for safe-measure
        assert!(!is_header_value_block(*b"foo.com\n"));
        assert!(!is_header_value_block(*b"o.com\r\nU"));
    }
}
222 
#[test]
fn test_is_uri_block() {
    let is_uri_block = |b| match_uri_char_8_swar(b) == BLOCK_SIZE;

    // 0..33 => false (control characters + space)
    for b in 0..33_u8 {
        assert!(!is_uri_block([b; BLOCK_SIZE]), "b={}", b);
    }
    // 33..127 => true if b not in { '<', '?', '>' } ('?' is the known false negative)
    let falsy = |b| b"<?>".contains(&b);
    for b in 33..127_u8 {
        assert_eq!(is_uri_block([b; BLOCK_SIZE]), !falsy(b), "b={}", b);
    }
    // 127..=255 => false (DEL + non-ASCII)
    for b in 127..=255_u8 {
        assert!(!is_uri_block([b; BLOCK_SIZE]), "b={}", b);
    }
}
241 
#[test]
fn test_offsetnz() {
    // Placing a single non-zero byte at position i must yield offset i.
    let seq = [0_u8; BLOCK_SIZE];
    for i in 0..BLOCK_SIZE {
        let mut seq = seq;
        seq[i] = 1;
        let x = usize::from_ne_bytes(seq);
        assert_eq!(offsetnz(x), i);
    }
}
252