// This is adapted from `fallback.rs` from rust-memchr. It's modified to
// answer the 'inverse' query of memchr, i.e., finding the first byte not in
// the provided set. This is simple for the 1-byte case: a word of haystack
// bytes contains a non-matching byte exactly when its XOR with the repeated
// needle byte is nonzero.

use core::{cmp, usize};

#[cfg(target_pointer_width = "32")]
const USIZE_BYTES: usize = 4;

#[cfg(target_pointer_width = "64")]
const USIZE_BYTES: usize = 8;

// The number of bytes processed per iteration of the unrolled loop in
// inv_memchr/inv_memrchr.
const LOOP_SIZE: usize = 2 * USIZE_BYTES;

/// Repeat the given byte into a word-sized number. That is, every 8 bits
/// is equivalent to the given byte. For example, if `b` is `\x4E` or
/// `01001110` in binary, then the returned value on a 32-bit system would be:
/// `01001110_01001110_01001110_01001110`.
#[inline(always)]
fn repeat_byte(b: u8) -> usize {
    (b as usize) * (usize::MAX / 255)
}
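
// A small added sanity check of the multiplication trick above (not part of
// the upstream memchr code): on the pointer widths handled here,
// `usize::MAX / 255` is the word with every byte equal to 1 (0x0101...01),
// so multiplying it by `b` replicates `b` into every byte.
#[cfg(all(test, feature = "std"))]
#[test]
fn repeat_byte_fills_every_byte() {
    #[cfg(target_pointer_width = "32")]
    assert_eq!(repeat_byte(0x4E), 0x4E4E_4E4E);
    #[cfg(target_pointer_width = "64")]
    assert_eq!(repeat_byte(0x4E), 0x4E4E_4E4E_4E4E_4E4E);
    assert_eq!(repeat_byte(0x00), 0);
    assert_eq!(repeat_byte(0xFF), usize::MAX);
}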

/// Return the first index not matching the byte `n1` in `haystack`.
pub fn inv_memchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = start_ptr;

        if haystack.len() < USIZE_BYTES {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

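        // The haystack is at least one word long here, so this full-word
        // read is in bounds. XOR against the repeated byte is zero only when
        // every byte of the word equals `n1`; a nonzero result means the
        // answer is within these first USIZE_BYTES bytes.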
        let chunk = read_unaligned_usize(ptr);
        if (chunk ^ vn1) != 0 {
            return forward_search(start_ptr, end_ptr, ptr, confirm);
        }

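        // The first word matched everywhere, so round `ptr` up to the next
        // word boundary and continue with aligned reads.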
        ptr = ptr.add(USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(USIZE_BYTES) >= start_ptr);
        while loop_size == LOOP_SIZE && ptr <= end_ptr.sub(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

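            // Scan two aligned words per iteration. If either XOR is
            // nonzero, some byte in this pair of words differs from `n1`,
            // and the byte-at-a-time search below will pinpoint it.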
            let a = *(ptr as *const usize);
            let b = *(ptr.add(USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.add(LOOP_SIZE);
        }
        forward_search(start_ptr, end_ptr, ptr, confirm)
    }
}

/// Return the last index not matching the byte `n1` in `haystack`.
pub fn inv_memrchr(n1: u8, haystack: &[u8]) -> Option<usize> {
    let vn1 = repeat_byte(n1);
    let confirm = |byte| byte != n1;
    let loop_size = cmp::min(LOOP_SIZE, haystack.len());
    let align = USIZE_BYTES - 1;
    let start_ptr = haystack.as_ptr();

    unsafe {
        let end_ptr = haystack.as_ptr().add(haystack.len());
        let mut ptr = end_ptr;

        if haystack.len() < USIZE_BYTES {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

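        // The haystack is at least one word long here, so the trailing-word
        // read is in bounds. A nonzero XOR means the last USIZE_BYTES bytes
        // contain a byte other than `n1`.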
        let chunk = read_unaligned_usize(ptr.sub(USIZE_BYTES));
        if (chunk ^ vn1) != 0 {
            return reverse_search(start_ptr, end_ptr, ptr, confirm);
        }

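        // The last word matched everywhere, so round `ptr` down to a word
        // boundary and continue with aligned reads.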
        ptr = ptr.sub(end_ptr as usize & align);
        debug_assert!(start_ptr <= ptr && ptr <= end_ptr);
        while loop_size == LOOP_SIZE && ptr >= start_ptr.add(loop_size) {
            debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

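            // Scan the two aligned words just below `ptr`; a nonzero XOR
            // means one of them contains a byte other than `n1`.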
            let a = *(ptr.sub(2 * USIZE_BYTES) as *const usize);
            let b = *(ptr.sub(1 * USIZE_BYTES) as *const usize);
            let eqa = (a ^ vn1) != 0;
            let eqb = (b ^ vn1) != 0;
            if eqa || eqb {
                break;
            }
            ptr = ptr.sub(loop_size);
        }
        reverse_search(start_ptr, end_ptr, ptr, confirm)
    }
}

#[inline(always)]
unsafe fn forward_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr < end_ptr {
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
        ptr = ptr.offset(1);
    }
    None
}

#[inline(always)]
unsafe fn reverse_search<F: Fn(u8) -> bool>(
    start_ptr: *const u8,
    end_ptr: *const u8,
    mut ptr: *const u8,
    confirm: F,
) -> Option<usize> {
    debug_assert!(start_ptr <= ptr);
    debug_assert!(ptr <= end_ptr);

    while ptr > start_ptr {
        ptr = ptr.offset(-1);
        if confirm(*ptr) {
            return Some(sub(ptr, start_ptr));
        }
    }
    None
}

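/// Read a `usize` from `ptr` with no alignment requirement. The caller must
/// ensure `ptr` is valid for a read of `USIZE_BYTES` bytes.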
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    (ptr as *const usize).read_unaligned()
}

/// Subtract `b` from `a` and return the difference. `a` should be greater than
/// or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    (a as usize) - (b as usize)
}

/// Safe wrapper around `forward_search`
#[inline]
pub(crate) fn forward_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start = s.as_ptr();
        let end = start.add(s.len());
        forward_search(start, end, start, confirm)
    }
}

/// Safe wrapper around `reverse_search`
#[inline]
pub(crate) fn reverse_search_bytes<F: Fn(u8) -> bool>(
    s: &[u8],
    confirm: F,
) -> Option<usize> {
    unsafe {
        let start = s.as_ptr();
        let end = start.add(s.len());
        reverse_search(start, end, end, confirm)
    }
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::{inv_memchr, inv_memrchr};

    // search string, search byte, inv_memchr result, inv_memrchr result.
    // these are expanded into a much larger set of tests in build_tests
    const TESTS: &[(&[u8], u8, usize, usize)] = &[
        (b"z", b'a', 0, 0),
        (b"zz", b'a', 0, 1),
        (b"aza", b'a', 1, 1),
        (b"zaz", b'a', 0, 2),
        (b"zza", b'a', 0, 1),
        (b"zaa", b'a', 0, 0),
        (b"zzz", b'a', 0, 2),
    ];

    type TestCase = (Vec<u8>, u8, Option<(usize, usize)>);

    fn build_tests() -> Vec<TestCase> {
        #[cfg(not(miri))]
        const MAX_PER: usize = 515;
        #[cfg(miri)]
        const MAX_PER: usize = 10;

        let mut result = vec![];
        for &(search, byte, fwd_pos, rev_pos) in TESTS {
            result.push((search.to_vec(), byte, Some((fwd_pos, rev_pos))));
            for i in 1..MAX_PER {
                // add a bunch of copies of the search byte to the end.
                let mut suffixed: Vec<u8> = search.into();
                suffixed.extend(std::iter::repeat(byte).take(i));
                result.push((suffixed, byte, Some((fwd_pos, rev_pos))));

                // add a bunch of copies of the search byte to the start.
                let mut prefixed: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                prefixed.extend(search);
                result.push((
                    prefixed,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));

                // add a bunch of copies of the search byte to both ends.
                let mut surrounded: Vec<u8> =
                    std::iter::repeat(byte).take(i).collect();
                surrounded.extend(search);
                surrounded.extend(std::iter::repeat(byte).take(i));
                result.push((
                    surrounded,
                    byte,
                    Some((fwd_pos + i, rev_pos + i)),
                ));
            }
        }

        // build non-matching tests for several sizes
        for i in 0..MAX_PER {
            result.push((
                std::iter::repeat(b'\0').take(i).collect(),
                b'\0',
                None,
            ));
        }

        result
    }

    #[test]
    fn test_inv_memchr() {
        use crate::{ByteSlice, B};

        #[cfg(not(miri))]
        const MAX_OFFSET: usize = 130;
        #[cfg(miri)]
        const MAX_OFFSET: usize = 13;

        for (search, byte, matching) in build_tests() {
            assert_eq!(
                inv_memchr(byte, &search),
                matching.map(|m| m.0),
                "inv_memchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            assert_eq!(
                inv_memrchr(byte, &search),
                matching.map(|m| m.1),
                "inv_memrchr when searching for {:?} in {:?}",
                byte as char,
                // better printing
                B(&search).as_bstr(),
            );
            // Test a rather large number of offsets for potential alignment
            // issues.
            for offset in 1..MAX_OFFSET {
                if offset >= search.len() {
                    break;
                }
                // If this would cause us to shift the results off the end,
                // skip it so that we don't have to recompute them.
                if let Some((f, r)) = matching {
                    if offset > f || offset > r {
                        break;
                    }
                }
                let realigned = &search[offset..];

                let forward_pos = matching.map(|m| m.0 - offset);
                let reverse_pos = matching.map(|m| m.1 - offset);

                assert_eq!(
                    inv_memchr(byte, &realigned),
                    forward_pos,
                    "inv_memchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
                assert_eq!(
                    inv_memrchr(byte, &realigned),
                    reverse_pos,
                    "inv_memrchr when searching (realigned by {}) for {:?} in {:?}",
                    offset,
                    byte as char,
                    realigned.as_bstr(),
                );
            }
        }
    }
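
    // An added sketch (not from the original test suite) exercising the
    // crate-internal `forward_search_bytes` and `reverse_search_bytes`
    // wrappers with an arbitrary predicate instead of a single-byte
    // comparison.
    #[test]
    fn search_bytes_with_predicate() {
        use super::{forward_search_bytes, reverse_search_bytes};

        let s = b"abc123xyz";
        // First and last ASCII digit.
        assert_eq!(forward_search_bytes(s, |b| b.is_ascii_digit()), Some(3));
        assert_eq!(reverse_search_bytes(s, |b| b.is_ascii_digit()), Some(5));
        // No byte matches the predicate.
        assert_eq!(forward_search_bytes(s, |b| b == b'!'), None);
        assert_eq!(reverse_search_bytes(s, |b| b == b'!'), None);
    }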
}