/*!
This module provides APIs for dealing with the alphabets of finite state
machines.

There are two principal types in this module, [`ByteClasses`] and [`Unit`].
The former defines the alphabet of a finite state machine while the latter
represents an element of that alphabet.

To a first approximation, the alphabet of all automata in this crate is just
the set of `u8` values. Namely, every distinct byte value, all 256 of them. In
practice, this can be quite wasteful when building a transition table for a
DFA, since it requires storing a state identifier for each element in the
alphabet. Instead, we collapse the alphabet of an automaton down into
equivalence classes, where no byte in an equivalence class ever discriminates
between a match and a non-match relative to any other byte in the same class.
For example, the regex `[a-z]+` can be thought of as having an alphabet of two
equivalence classes: `a-z` and everything else. The transitions of the
automaton then need not be defined for every distinct byte, only for each
equivalence class.

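As a rough sketch of the idea (not necessarily how this crate computes its
classes), a byte-to-class map can be built by starting a new class at every
range boundary. The `classes_from_ranges` helper below is hypothetical and
exists only for illustration. Because this sketch only groups contiguous bytes,
`[a-z]+` yields three byte classes (below `a`, `a-z` and above `z`) rather than
the conceptual two.

```rust
/// Hypothetical helper, not part of this crate's API: assigns a class to
/// every byte, starting a new class after each range boundary.
fn classes_from_ranges(ranges: &[(u8, u8)]) -> [u8; 256] {
    // boundary[b] is true when a class ends at byte b.
    let mut boundary = [false; 256];
    for &(start, end) in ranges {
        if start > 0 {
            boundary[usize::from(start) - 1] = true;
        }
        boundary[usize::from(end)] = true;
    }
    let mut map = [0u8; 256];
    let mut class = 0u8;
    for b in 0..256usize {
        map[b] = class;
        // Never increment after the last byte, so `class` cannot overflow.
        if boundary[b] && b < 255 {
            class += 1;
        }
    }
    map
}

fn main() {
    let map = classes_from_ranges(&[(b'a', b'z')]);
    // 'a' and 'z' share a class, but 'a' and 'A' do not.
    assert_eq!(map[usize::from(b'a')], map[usize::from(b'z')]);
    assert_ne!(map[usize::from(b'a')], map[usize::from(b'A')]);
    // The largest class identifier plus one is the number of byte classes.
    assert_eq!(usize::from(map[255]) + 1, 3);
}
```
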
The downside of equivalence classes is that, of course, searching a haystack
deals with individual byte values. Those byte values need to be mapped to
their corresponding equivalence class. This is what `ByteClasses` does. In
practice, doing this for every state transition has negligible impact on modern
CPUs. Moreover, it helps make more efficient use of the CPU cache by (possibly
considerably) shrinking the size of the transition table.

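To make the mapping step concrete, here is a minimal sketch of a search loop
that looks up each haystack byte's class before indexing a transition table.
The table, state names and functions are all hypothetical, hand-written for
`[a-z]+` matched against an entire haystack, and are not this crate's DFA
representation. Each row of the table has 3 entries instead of 256.

```rust
const NUM_CLASSES: usize = 3;
const DEAD: usize = 0;
const START: usize = 1;
const SEEN: usize = 2;

// One row per state and one column per equivalence class.
const TABLE: [[usize; NUM_CLASSES]; 3] = [
    [DEAD, DEAD, DEAD], // DEAD: no way back
    [DEAD, SEEN, DEAD], // START: only class 1 (a-z) makes progress
    [DEAD, SEEN, DEAD], // SEEN: at least one a-z byte consumed
];

/// Builds the byte-to-class map: 0 => below 'a', 1 => a-z, 2 => above 'z'.
fn build_classes() -> [u8; 256] {
    let mut classes = [0u8; 256];
    for b in 0..=255u8 {
        classes[usize::from(b)] = match b {
            b'a'..=b'z' => 1,
            b'{'..=u8::MAX => 2,
            _ => 0,
        };
    }
    classes
}

/// Returns true when the whole haystack consists of one or more a-z bytes.
fn is_all_lowercase(haystack: &[u8]) -> bool {
    let classes = build_classes();
    let mut state = START;
    for &b in haystack {
        // Map the byte to its class, then index the (small) table.
        let class = usize::from(classes[usize::from(b)]);
        state = TABLE[state][class];
    }
    state == SEEN
}

fn main() {
    assert!(is_all_lowercase(b"abc"));
    assert!(!is_all_lowercase(b"aBc"));
    assert!(!is_all_lowercase(b""));
}
```
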
One last hiccup concerns `Unit`. Namely, because of look-around and how the
DFAs in this crate work, we need to add a sentinel value to our alphabet
of equivalence classes that represents the "end" of a search. We call that
sentinel [`Unit::eoi`] or "end of input." Thus, a `Unit` is either an
equivalence class corresponding to a set of bytes, or it is a special "end of
input" sentinel.

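The following sketch illustrates the shape of such an alphabet under the
assumption of three byte classes. The `Symbol` type and the `column` and
`symbols` functions are hypothetical stand-ins, not this crate's `Unit` type.
The sentinel takes the index one past the largest byte class, and a search
feeds it to the automaton once, after the last haystack byte.

```rust
/// Hypothetical stand-in for an alphabet element: a byte class or the
/// "end of input" sentinel.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Symbol {
    Class(u8),
    Eoi,
}

// Assumed for this sketch: 0 => below 'a', 1 => a-z, 2 => above 'z'.
const NUM_BYTE_CLASSES: usize = 3;

fn class_of(b: u8) -> u8 {
    match b {
        b'a'..=b'z' => 1,
        b'{'..=u8::MAX => 2,
        _ => 0,
    }
}

/// The column of a symbol in a transition table. The sentinel sits one
/// past the largest byte class.
fn column(sym: Symbol) -> usize {
    match sym {
        Symbol::Class(c) => usize::from(c),
        Symbol::Eoi => NUM_BYTE_CLASSES,
    }
}

/// The sequence of symbols a search would feed to the automaton: one
/// class per haystack byte, followed by exactly one sentinel.
fn symbols(haystack: &[u8]) -> Vec<Symbol> {
    let mut out: Vec<Symbol> =
        haystack.iter().map(|&b| Symbol::Class(class_of(b))).collect();
    out.push(Symbol::Eoi);
    out
}

fn main() {
    let syms = symbols(b"ab!");
    assert_eq!(
        syms,
        vec![
            Symbol::Class(1),
            Symbol::Class(1),
            Symbol::Class(0),
            Symbol::Eoi,
        ],
    );
    assert_eq!(column(Symbol::Eoi), 3);
}
```
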
In general, you should not expect to need either of these types unless you're
doing lower level shenanigans with DFAs, or even building your own DFAs.
(Although, you don't have to use these types to build your own DFAs of course.)
For example, if you're walking a DFA's state graph, it's probably useful to
make use of [`ByteClasses`] to visit each element in the DFA's alphabet instead
of just visiting every distinct `u8` value. The latter isn't necessarily wrong,
but it is potentially very wasteful.
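
To give a sense of the savings, the sketch below picks one representative byte
per class from a byte-to-class map, so that a graph walk would follow 3
transitions per state instead of 256. The `lowercase_classes` and
`representatives` functions here are hypothetical free functions written for
this illustration, and the sketch assumes every class covers a contiguous range
of bytes.

```rust
/// Assumed class map for `[a-z]+`: 0 => below 'a', 1 => a-z, 2 => above 'z'.
fn lowercase_classes() -> [u8; 256] {
    let mut classes = [0u8; 256];
    for b in 0..=255u8 {
        classes[usize::from(b)] = match b {
            b'a'..=b'z' => 1,
            b'{'..=u8::MAX => 2,
            _ => 0,
        };
    }
    classes
}

/// Returns the first byte of every run of bytes that share a class. When
/// classes are contiguous, this is exactly one byte per class.
fn representatives(classes: &[u8; 256]) -> Vec<u8> {
    let mut reps = Vec::new();
    let mut last: Option<u8> = None;
    for b in 0..=255u8 {
        let class = classes[usize::from(b)];
        if last != Some(class) {
            last = Some(class);
            reps.push(b);
        }
    }
    reps
}

fn main() {
    let reps = representatives(&lowercase_classes());
    // Three bytes to try per state instead of 256.
    assert_eq!(reps, vec![0x00, b'a', b'{']);
}
```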
*/
use crate::util::{
    escape::DebugByte,
    wire::{self, DeserializeError, SerializeError},
};

/// Unit represents a single unit of haystack for DFA based regex engines.
///
/// It is not expected for consumers of this crate to need to use this type
/// unless they are implementing their own DFA. And even then, it's not
/// required: implementors may use other techniques to handle haystack units.
///
/// Typically, a single unit of haystack for a DFA would be a single byte.
/// However, for the DFAs in this crate, matches are delayed by a single byte
/// in order to handle look-ahead assertions (`\b`, `$` and `\z`). Thus, once
/// we have consumed the haystack, we must run the DFA through one additional
/// transition using a unit that indicates the haystack has ended.
///
/// There is no way to represent a sentinel with a `u8` since all possible
/// values *may* be valid haystack units to a DFA, therefore this type
/// explicitly adds room for a sentinel value.
///
/// The sentinel EOI value is always its own equivalence class and is
/// ultimately represented by adding 1 to the maximum equivalence class value.
/// So for example, the regex `^[a-z]+$` might be split into the following
/// equivalence classes:
///
/// ```text
/// 0 => [\x00-`]
/// 1 => [a-z]
/// 2 => [{-\xFF]
/// 3 => [EOI]
/// ```
///
/// Where EOI is the special sentinel value that is always in its own
/// singleton equivalence class.
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord)]
pub struct Unit(UnitKind);

#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord)]
enum UnitKind {
    /// Represents a byte value, or more typically, an equivalence class
    /// represented as a byte value.
    U8(u8),
    /// Represents the "end of input" sentinel. We regrettably use a `u16`
    /// here since the maximum sentinel value is `256`. Thankfully, we don't
    /// actually store a `Unit` anywhere, so this extra space shouldn't be too
    /// bad.
    EOI(u16),
}

impl Unit {
    /// Create a new haystack unit from a byte value.
    ///
    /// All possible byte values are legal. However, when creating a haystack
    /// unit for a specific DFA, one should be careful to only construct units
    /// that are in that DFA's alphabet. Namely, one way to compact a DFA's
    /// in-memory representation is to collapse the set of all possible byte
    /// values into a set of equivalence classes. If a DFA uses equivalence
    /// classes instead of byte values, then the byte given here should be the
    /// equivalence class.
    pub fn u8(byte: u8) -> Unit {
        Unit(UnitKind::U8(byte))
    }

    /// Create a new "end of input" haystack unit.
    ///
    /// The value given is the sentinel value used by this unit to represent
    /// the "end of input." The value should be the total number of equivalence
    /// classes in the corresponding alphabet. Its maximum value is `256`,
    /// which occurs when every byte is its own equivalence class.
    ///
    /// # Panics
    ///
    /// This panics when `num_byte_equiv_classes` is greater than `256`.
    pub fn eoi(num_byte_equiv_classes: usize) -> Unit {
        assert!(
            num_byte_equiv_classes <= 256,
            "max number of byte-based equivalent classes is 256, but got {}",
            num_byte_equiv_classes,
        );
        Unit(UnitKind::EOI(u16::try_from(num_byte_equiv_classes).unwrap()))
    }

    /// If this unit is not an "end of input" sentinel, then returns its
    /// underlying byte value. Otherwise return `None`.
    pub fn as_u8(self) -> Option<u8> {
        match self.0 {
            UnitKind::U8(b) => Some(b),
            UnitKind::EOI(_) => None,
        }
    }

    /// If this unit is an "end of input" sentinel, then return the underlying
    /// sentinel value that was given to [`Unit::eoi`]. Otherwise return
    /// `None`.
    pub fn as_eoi(self) -> Option<u16> {
        match self.0 {
            UnitKind::U8(_) => None,
            UnitKind::EOI(sentinel) => Some(sentinel),
        }
    }

    /// Return this unit as a `usize`, regardless of whether it is a byte value
    /// or an "end of input" sentinel. In the latter case, the underlying
    /// sentinel value given to [`Unit::eoi`] is returned.
    pub fn as_usize(self) -> usize {
        match self.0 {
            UnitKind::U8(b) => usize::from(b),
            UnitKind::EOI(eoi) => usize::from(eoi),
        }
    }

    /// Returns true if and only if this unit is a byte value equivalent to the
    /// byte given. This always returns false when this is an "end of input"
    /// sentinel.
    pub fn is_byte(self, byte: u8) -> bool {
        self.as_u8().map_or(false, |b| b == byte)
    }

    /// Returns true when this unit represents an "end of input" sentinel.
    pub fn is_eoi(self) -> bool {
        self.as_eoi().is_some()
    }

    /// Returns true when this unit corresponds to an ASCII word byte.
    ///
    /// This always returns false when this unit represents an "end of input"
    /// sentinel.
    pub fn is_word_byte(self) -> bool {
        self.as_u8().map_or(false, crate::util::utf8::is_word_byte)
    }
}

impl core::fmt::Debug for Unit {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        match self.0 {
            UnitKind::U8(b) => write!(f, "{:?}", DebugByte(b)),
            UnitKind::EOI(_) => write!(f, "EOI"),
        }
    }
}

/// A representation of byte oriented equivalence classes.
///
/// This is used in a DFA to reduce the size of the transition table. This can
/// have a particularly large impact not only on the total size of a dense DFA,
/// but also on compile times.
///
/// The essential idea here is that the alphabet of a DFA is shrunk from the
/// usual 256 distinct byte values down to a set of equivalence classes. The
/// guarantee you get is that any byte belonging to the same equivalence class
/// can be treated as if it were any other byte in the same class, and the
/// result of a search wouldn't change.
///
/// # Example
///
/// This example shows how to get byte classes from an
/// [`NFA`](crate::nfa::thompson::NFA) and ask for the class of various bytes.
///
/// ```
/// use regex_automata::nfa::thompson::NFA;
///
/// let nfa = NFA::new("[a-z]+")?;
/// let classes = nfa.byte_classes();
/// // 'a' and 'z' are in the same class for this regex.
/// assert_eq!(classes.get(b'a'), classes.get(b'z'));
/// // But 'a' and 'A' are not.
/// assert_ne!(classes.get(b'a'), classes.get(b'A'));
///
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Clone, Copy)]
pub struct ByteClasses([u8; 256]);

impl ByteClasses {
    /// Creates a new set of equivalence classes where all bytes are mapped to
    /// the same class.
    #[inline]
    pub fn empty() -> ByteClasses {
        ByteClasses([0; 256])
    }

    /// Creates a new set of equivalence classes where each byte belongs to
    /// its own equivalence class.
    #[inline]
    pub fn singletons() -> ByteClasses {
        let mut classes = ByteClasses::empty();
        for b in 0..=255 {
            classes.set(b, b);
        }
        classes
    }

    /// Deserializes a byte class map from the given slice. If the slice is of
    /// insufficient length or otherwise contains an impossible mapping, then
    /// an error is returned. Upon success, the number of bytes read along with
    /// the map are returned. The number of bytes read is always a multiple of
    /// 8.
    pub(crate) fn from_bytes(
        slice: &[u8],
    ) -> Result<(ByteClasses, usize), DeserializeError> {
        wire::check_slice_len(slice, 256, "byte class map")?;
        let mut classes = ByteClasses::empty();
        for (b, &class) in slice[..256].iter().enumerate() {
            classes.set(u8::try_from(b).unwrap(), class);
        }
        // We specifically don't use 'classes.iter()' here because that
        // iterator depends on 'classes.alphabet_len()' being correct. But that
        // is precisely the thing we're trying to verify below!
        for &b in classes.0.iter() {
            if usize::from(b) >= classes.alphabet_len() {
                return Err(DeserializeError::generic(
                    "found equivalence class greater than alphabet len",
                ));
            }
        }
        Ok((classes, 256))
    }

    /// Writes this byte class map to the given byte buffer. If the given
    /// buffer is too small, then an error is returned. Upon success, the total
    /// number of bytes written is returned. The number of bytes written is
    /// guaranteed to be a multiple of 8.
    pub(crate) fn write_to(
        &self,
        mut dst: &mut [u8],
    ) -> Result<usize, SerializeError> {
        let nwrite = self.write_to_len();
        if dst.len() < nwrite {
            return Err(SerializeError::buffer_too_small("byte class map"));
        }
        for b in 0..=255 {
            dst[0] = self.get(b);
            dst = &mut dst[1..];
        }
        Ok(nwrite)
    }

    /// Returns the total number of bytes written by `write_to`.
    pub(crate) fn write_to_len(&self) -> usize {
        256
    }

    /// Set the equivalence class for the given byte.
    #[inline]
    pub fn set(&mut self, byte: u8, class: u8) {
        self.0[usize::from(byte)] = class;
    }

    /// Get the equivalence class for the given byte.
    #[inline]
    pub fn get(&self, byte: u8) -> u8 {
        self.0[usize::from(byte)]
    }

    /// Get the equivalence class for the given haystack unit and return the
    /// class as a `usize`.
    #[inline]
    pub fn get_by_unit(&self, unit: Unit) -> usize {
        match unit.0 {
            UnitKind::U8(b) => usize::from(self.get(b)),
            UnitKind::EOI(b) => usize::from(b),
        }
    }

    /// Create a unit that represents the "end of input" sentinel based on the
    /// number of equivalence classes.
    #[inline]
    pub fn eoi(&self) -> Unit {
        // The alphabet length already includes the EOI sentinel, hence why
        // we subtract 1.
        Unit::eoi(self.alphabet_len().checked_sub(1).unwrap())
    }

    /// Return the total number of elements in the alphabet represented by
    /// these equivalence classes. Equivalently, this returns the total number
    /// of equivalence classes.
    #[inline]
    pub fn alphabet_len(&self) -> usize {
        // Add one since the number of equivalence classes is one bigger than
        // the last one. But add another to account for the final EOI class
        // that isn't explicitly represented.
        usize::from(self.0[255]) + 1 + 1
    }

    /// Returns the stride, as a base-2 exponent, required for these
    /// equivalence classes.
    ///
    /// The stride is always the smallest power of 2 that is greater than or
    /// equal to the alphabet length, and the `stride2` returned here is the
    /// exponent applied to `2` to get the smallest power. This is done so that
    /// converting between premultiplied state IDs and indices can be done with
    /// shifts alone, which is much faster than integer division.
    #[inline]
    pub fn stride2(&self) -> usize {
        let zeros = self.alphabet_len().next_power_of_two().trailing_zeros();
        usize::try_from(zeros).unwrap()
    }

    /// Returns true if and only if every byte in this class maps to its own
    /// equivalence class. Equivalently, there are 257 equivalence classes
    /// and each class contains either exactly one byte or corresponds to the
    /// singleton class containing the "end of input" sentinel.
    #[inline]
    pub fn is_singleton(&self) -> bool {
        self.alphabet_len() == 257
    }

    /// Returns an iterator over all equivalence classes in this set.
    #[inline]
    pub fn iter(&self) -> ByteClassIter<'_> {
        ByteClassIter { classes: self, i: 0 }
    }

    /// Returns an iterator over a sequence of representative bytes from each
    /// equivalence class within the range of bytes given.
    ///
    /// When the given range is unbounded on both sides, the iterator yields
    /// exactly N items, where N is equivalent to the number of equivalence
    /// classes. Each item is an arbitrary byte drawn from each equivalence
    /// class.
    ///
    /// This is useful when one is determinizing an NFA and the NFA's alphabet
    /// hasn't been converted to equivalence classes. Picking an arbitrary byte
    /// from each equivalence class then permits a full exploration of the NFA
    /// instead of using every possible byte value and thus potentially saves
    /// quite a lot of redundant work.
    ///
    /// # Example
    ///
    /// This shows an example of what a complete sequence of representatives
    /// might look like from a real example.
    ///
    /// ```
    /// use regex_automata::{nfa::thompson::NFA, util::alphabet::Unit};
    ///
    /// let nfa = NFA::new("[a-z]+")?;
    /// let classes = nfa.byte_classes();
    /// let reps: Vec<Unit> = classes.representatives(..).collect();
    /// // Note that the specific byte values yielded are not guaranteed!
    /// let expected = vec![
    ///     Unit::u8(b'\x00'),
    ///     Unit::u8(b'a'),
    ///     Unit::u8(b'{'),
    ///     Unit::eoi(3),
    /// ];
    /// assert_eq!(expected, reps);
    ///
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// Note though, that you can ask for an arbitrary range of bytes, and only
    /// representatives for that range will be returned:
    ///
    /// ```
    /// use regex_automata::{nfa::thompson::NFA, util::alphabet::Unit};
    ///
    /// let nfa = NFA::new("[a-z]+")?;
    /// let classes = nfa.byte_classes();
    /// let reps: Vec<Unit> = classes.representatives(b'A'..=b'z').collect();
    /// // Note that the specific byte values yielded are not guaranteed!
    /// let expected = vec![
    ///     Unit::u8(b'A'),
    ///     Unit::u8(b'a'),
    /// ];
    /// assert_eq!(expected, reps);
    ///
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    pub fn representatives<R: core::ops::RangeBounds<u8>>(
        &self,
        range: R,
    ) -> ByteClassRepresentatives<'_> {
        use core::ops::Bound;

        let cur_byte = match range.start_bound() {
            Bound::Included(&i) => usize::from(i),
            Bound::Excluded(&i) => usize::from(i).checked_add(1).unwrap(),
            Bound::Unbounded => 0,
        };
        let end_byte = match range.end_bound() {
            Bound::Included(&i) => {
                Some(usize::from(i).checked_add(1).unwrap())
            }
            Bound::Excluded(&i) => Some(usize::from(i)),
            Bound::Unbounded => None,
        };
        assert_ne!(
            cur_byte,
            usize::MAX,
            "start range must be less than usize::MAX",
        );
        ByteClassRepresentatives {
            classes: self,
            cur_byte,
            end_byte,
            last_class: None,
        }
    }

    /// Returns an iterator of the bytes in the given equivalence class.
    ///
    /// This is useful when one needs to know the actual bytes that belong to
    /// an equivalence class. For example, conceptually speaking, accelerating
    /// a DFA state occurs when a state only has a few outgoing transitions.
    /// But in reality, what is required is that there are only a small
    /// number of distinct bytes that can lead to an outgoing transition. The
    /// difference is that any one transition can correspond to an equivalence
    /// class which may contain many bytes. Therefore, DFA state acceleration
    /// considers the actual elements in each equivalence class of each
    /// outgoing transition.
    ///
    /// # Example
    ///
    /// This shows an example of how to get all of the elements in an
    /// equivalence class.
    ///
    /// ```
    /// use regex_automata::{nfa::thompson::NFA, util::alphabet::Unit};
    ///
    /// let nfa = NFA::new("[a-z]+")?;
    /// let classes = nfa.byte_classes();
    /// let elements: Vec<Unit> = classes.elements(Unit::u8(1)).collect();
    /// let expected: Vec<Unit> = (b'a'..=b'z').map(Unit::u8).collect();
    /// assert_eq!(expected, elements);
    ///
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    #[inline]
    pub fn elements(&self, class: Unit) -> ByteClassElements {
        ByteClassElements { classes: self, class, byte: 0 }
    }

    /// Returns an iterator of byte ranges in the given equivalence class.
    ///
    /// That is, a sequence of contiguous ranges are returned. Typically, every
    /// class maps to a single contiguous range.
    fn element_ranges(&self, class: Unit) -> ByteClassElementRanges {
        ByteClassElementRanges { elements: self.elements(class), range: None }
    }
}

impl Default for ByteClasses {
    fn default() -> ByteClasses {
        ByteClasses::singletons()
    }
}

impl core::fmt::Debug for ByteClasses {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        if self.is_singleton() {
            write!(f, "ByteClasses({{singletons}})")
        } else {
            write!(f, "ByteClasses(")?;
            for (i, class) in self.iter().enumerate() {
                if i > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{:?} => [", class.as_usize())?;
                for (start, end) in self.element_ranges(class) {
                    if start == end {
                        write!(f, "{:?}", start)?;
                    } else {
                        write!(f, "{:?}-{:?}", start, end)?;
                    }
                }
                write!(f, "]")?;
            }
            write!(f, ")")
        }
    }
}

/// An iterator over each equivalence class.
///
/// The last element in this iterator always corresponds to [`Unit::eoi`].
///
/// This is created by the [`ByteClasses::iter`] method.
///
/// The lifetime `'a` refers to the lifetime of the byte classes that this
/// iterator was created from.
#[derive(Debug)]
pub struct ByteClassIter<'a> {
    classes: &'a ByteClasses,
    i: usize,
}

impl<'a> Iterator for ByteClassIter<'a> {
    type Item = Unit;

    fn next(&mut self) -> Option<Unit> {
        if self.i + 1 == self.classes.alphabet_len() {
            self.i += 1;
            Some(self.classes.eoi())
        } else if self.i < self.classes.alphabet_len() {
            let class = u8::try_from(self.i).unwrap();
            self.i += 1;
            Some(Unit::u8(class))
        } else {
            None
        }
    }
}

/// An iterator over representative bytes from each equivalence class.
///
/// This is created by the [`ByteClasses::representatives`] method.
///
/// The lifetime `'a` refers to the lifetime of the byte classes that this
/// iterator was created from.
#[derive(Debug)]
pub struct ByteClassRepresentatives<'a> {
    classes: &'a ByteClasses,
    cur_byte: usize,
    end_byte: Option<usize>,
    last_class: Option<u8>,
}

impl<'a> Iterator for ByteClassRepresentatives<'a> {
    type Item = Unit;

    fn next(&mut self) -> Option<Unit> {
        while self.cur_byte < self.end_byte.unwrap_or(256) {
            let byte = u8::try_from(self.cur_byte).unwrap();
            let class = self.classes.get(byte);
            self.cur_byte += 1;

            if self.last_class != Some(class) {
                self.last_class = Some(class);
                return Some(Unit::u8(byte));
            }
        }
        if self.cur_byte != usize::MAX && self.end_byte.is_none() {
            // Using usize::MAX as a sentinel is OK because we ban usize::MAX
            // from appearing as a start bound in iterator construction. But
            // why do it this way? Well, we want to return the EOI class
            // whenever the end of the given range is unbounded because EOI
            // isn't really a "byte" per se, so the only way it should be
            // excluded is if there is a bounded end to the range. Therefore,
            // when the end is unbounded, we just need to know whether we've
            // reported EOI or not. When we do, we set cur_byte to a value it
            // can never otherwise be.
            self.cur_byte = usize::MAX;
            return Some(self.classes.eoi());
        }
        None
    }
}

/// An iterator over all elements in an equivalence class.
///
/// This is created by the [`ByteClasses::elements`] method.
///
/// The lifetime `'a` refers to the lifetime of the byte classes that this
/// iterator was created from.
#[derive(Debug)]
pub struct ByteClassElements<'a> {
    classes: &'a ByteClasses,
    class: Unit,
    byte: usize,
}

impl<'a> Iterator for ByteClassElements<'a> {
    type Item = Unit;

    fn next(&mut self) -> Option<Unit> {
        while self.byte < 256 {
            let byte = u8::try_from(self.byte).unwrap();
            self.byte += 1;
            if self.class.is_byte(self.classes.get(byte)) {
                return Some(Unit::u8(byte));
            }
        }
        if self.byte < 257 {
            self.byte += 1;
            if self.class.is_eoi() {
                return Some(Unit::eoi(256));
            }
        }
        None
    }
}

/// An iterator over all elements in an equivalence class expressed as a
/// sequence of contiguous ranges.
#[derive(Debug)]
struct ByteClassElementRanges<'a> {
    elements: ByteClassElements<'a>,
    range: Option<(Unit, Unit)>,
}

impl<'a> Iterator for ByteClassElementRanges<'a> {
    type Item = (Unit, Unit);

    fn next(&mut self) -> Option<(Unit, Unit)> {
        loop {
            let element = match self.elements.next() {
                None => return self.range.take(),
                Some(element) => element,
            };
            match self.range.take() {
                None => {
                    self.range = Some((element, element));
                }
                Some((start, end)) => {
                    if end.as_usize() + 1 != element.as_usize()
                        || element.is_eoi()
                    {
                        self.range = Some((element, element));
                        return Some((start, end));
                    }
                    self.range = Some((start, element));
                }
            }
        }
    }
}

/// A partitioning of bytes into equivalence classes.
///
/// A byte class set keeps track of an *approximation* of equivalence classes
/// of bytes during NFA construction. That is, every byte in an equivalence
/// class cannot discriminate between a match and a non-match.
///
/// For example, in the regex `[ab]+`, the bytes `a` and `b` would be in the
/// same equivalence class because it never matters whether an `a` or a `b` is
/// seen, and no combination of `a`s and `b`s in the text can discriminate a
/// match.
///
/// Note though that this does not compute the minimal set of equivalence
/// classes. For example, in the regex `[ac]+`, both `a` and `c` are in the
/// same equivalence class for the same reason that `a` and `b` are in the
/// same equivalence class in the aforementioned regex. However, in this
/// implementation, `a` and `c` are put into distinct equivalence classes. The
/// reason for this is implementation complexity. In the future, we should
/// endeavor to compute the minimal equivalence classes since they can have a
/// rather large impact on the size of the DFA. (Doing this will likely require
/// rethinking how equivalence classes are computed, including changing the
/// representation here, which is only able to group contiguous bytes into the
/// same equivalence class.)
#[cfg(feature = "alloc")]
#[derive(Clone, Debug)]
pub(crate) struct ByteClassSet(ByteSet);

#[cfg(feature = "alloc")]
impl Default for ByteClassSet {
    fn default() -> ByteClassSet {
        ByteClassSet::empty()
    }
}

#[cfg(feature = "alloc")]
impl ByteClassSet {
    /// Create a new set of byte classes where all bytes are part of the same
    /// equivalence class.
    pub(crate) fn empty() -> Self {
        ByteClassSet(ByteSet::empty())
    }

    /// Indicate that the range of bytes given (inclusive) can discriminate a
    /// match between it and all other bytes outside of the range.
    pub(crate) fn set_range(&mut self, start: u8, end: u8) {
        debug_assert!(start <= end);
        if start > 0 {
            self.0.add(start - 1);
        }
        self.0.add(end);
    }

    /// Add the contiguous ranges in the set given to this byte class set.
    pub(crate) fn add_set(&mut self, set: &ByteSet) {
        for (start, end) in set.iter_ranges() {
            self.set_range(start, end);
        }
    }

    /// Convert this boolean set to a map that maps all byte values to their
    /// corresponding equivalence class. The last mapping indicates the largest
    /// equivalence class identifier (which is never bigger than 255).
    pub(crate) fn byte_classes(&self) -> ByteClasses {
        let mut classes = ByteClasses::empty();
        let mut class = 0u8;
        let mut b = 0u8;
        loop {
            classes.set(b, class);
            if b == 255 {
                break;
            }
            if self.0.contains(b) {
                class = class.checked_add(1).unwrap();
            }
            b = b.checked_add(1).unwrap();
        }
        classes
    }
}

740 /// A simple set of bytes that is reasonably cheap to copy and allocation free.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct ByteSet {
    bits: BitSet,
}

/// The representation of a byte set. Split out so that we can define a
/// convenient Debug impl for it while keeping "ByteSet" in the output.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
struct BitSet([u128; 2]);

impl ByteSet {
    /// Create an empty set of bytes.
    pub(crate) fn empty() -> ByteSet {
        ByteSet { bits: BitSet([0; 2]) }
    }

    /// Add a byte to this set.
    ///
    /// If the given byte already belongs to this set, then this is a no-op.
    pub(crate) fn add(&mut self, byte: u8) {
        let bucket = byte / 128;
        let bit = byte % 128;
        self.bits.0[usize::from(bucket)] |= 1 << bit;
    }

    /// Remove a byte from this set.
    ///
    /// If the given byte is not in this set, then this is a no-op.
    pub(crate) fn remove(&mut self, byte: u8) {
        let bucket = byte / 128;
        let bit = byte % 128;
        self.bits.0[usize::from(bucket)] &= !(1 << bit);
    }

    /// Return true if and only if the given byte is in this set.
    pub(crate) fn contains(&self, byte: u8) -> bool {
        let bucket = byte / 128;
        let bit = byte % 128;
        self.bits.0[usize::from(bucket)] & (1 << bit) > 0
    }

    /// Return true if and only if the given inclusive range of bytes is in
    /// this set.
    pub(crate) fn contains_range(&self, start: u8, end: u8) -> bool {
        (start..=end).all(|b| self.contains(b))
    }

    /// Returns an iterator over all bytes in this set.
    pub(crate) fn iter(&self) -> ByteSetIter {
        ByteSetIter { set: self, b: 0 }
    }

    /// Returns an iterator over all contiguous ranges of bytes in this set.
    pub(crate) fn iter_ranges(&self) -> ByteSetRangeIter {
        ByteSetRangeIter { set: self, b: 0 }
    }

    /// Return true if and only if this set is empty.
    #[cfg_attr(feature = "perf-inline", inline(always))]
    pub(crate) fn is_empty(&self) -> bool {
        self.bits.0 == [0, 0]
    }

    /// Deserializes a byte set from the given slice. If the slice is of
    /// incorrect length or is otherwise malformed, then an error is returned.
    /// Upon success, the number of bytes read along with the set are returned.
    /// The number of bytes read is always a multiple of 8.
    pub(crate) fn from_bytes(
        slice: &[u8],
    ) -> Result<(ByteSet, usize), DeserializeError> {
        use core::mem::size_of;

        wire::check_slice_len(slice, 2 * size_of::<u128>(), "byte set")?;
        let mut nread = 0;
        let (low, nr) = wire::try_read_u128(slice, "byte set low bucket")?;
        nread += nr;
        // The high bucket starts where the low bucket ended, so read it from
        // the remainder of the slice rather than from the start again.
        let (high, nr) =
            wire::try_read_u128(&slice[nread..], "byte set high bucket")?;
        nread += nr;
        Ok((ByteSet { bits: BitSet([low, high]) }, nread))
    }

    /// Writes this byte set to the given byte buffer. If the given buffer is
    /// too small, then an error is returned. Upon success, the total number of
    /// bytes written is returned. The number of bytes written is guaranteed to
    /// be a multiple of 8.
    pub(crate) fn write_to<E: crate::util::wire::Endian>(
        &self,
        dst: &mut [u8],
    ) -> Result<usize, SerializeError> {
        use core::mem::size_of;

        let nwrite = self.write_to_len();
        if dst.len() < nwrite {
            return Err(SerializeError::buffer_too_small("byte set"));
        }
        let mut nw = 0;
        E::write_u128(self.bits.0[0], &mut dst[nw..]);
        nw += size_of::<u128>();
        E::write_u128(self.bits.0[1], &mut dst[nw..]);
        nw += size_of::<u128>();
        assert_eq!(nwrite, nw, "expected to write certain number of bytes",);
        assert_eq!(
            nw % 8,
            0,
            "expected to write multiple of 8 bytes for byte set",
        );
        Ok(nw)
    }

    /// Returns the total number of bytes written by `write_to`.
    pub(crate) fn write_to_len(&self) -> usize {
        2 * core::mem::size_of::<u128>()
    }
}

impl core::fmt::Debug for BitSet {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let mut fmtd = f.debug_set();
        for b in 0u8..=255 {
            if (ByteSet { bits: *self }).contains(b) {
                fmtd.entry(&b);
            }
        }
        fmtd.finish()
    }
}

#[derive(Debug)]
pub(crate) struct ByteSetIter<'a> {
    set: &'a ByteSet,
    b: usize,
}

impl<'a> Iterator for ByteSetIter<'a> {
    type Item = u8;

    fn next(&mut self) -> Option<u8> {
        while self.b <= 255 {
            let b = u8::try_from(self.b).unwrap();
            self.b += 1;
            if self.set.contains(b) {
                return Some(b);
            }
        }
        None
    }
}

#[derive(Debug)]
pub(crate) struct ByteSetRangeIter<'a> {
    set: &'a ByteSet,
    b: usize,
}

impl<'a> Iterator for ByteSetRangeIter<'a> {
    type Item = (u8, u8);

    fn next(&mut self) -> Option<(u8, u8)> {
        let asu8 = |n: usize| u8::try_from(n).unwrap();
        while self.b <= 255 {
            let start = asu8(self.b);
            self.b += 1;
            if !self.set.contains(start) {
                continue;
            }

            let mut end = start;
            while self.b <= 255 && self.set.contains(asu8(self.b)) {
                end = asu8(self.b);
                self.b += 1;
            }
            return Some((start, end));
        }
        None
    }
}

#[cfg(all(test, feature = "alloc"))]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::*;

    #[test]
    fn byte_classes() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');

        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(b'a' - 1), 0);
        assert_eq!(classes.get(b'a'), 1);
        assert_eq!(classes.get(b'm'), 1);
        assert_eq!(classes.get(b'z'), 1);
        assert_eq!(classes.get(b'z' + 1), 2);
        assert_eq!(classes.get(254), 2);
        assert_eq!(classes.get(255), 2);

        let mut set = ByteClassSet::empty();
        set.set_range(0, 2);
        set.set_range(4, 6);
        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(3), 1);
        assert_eq!(classes.get(4), 2);
        assert_eq!(classes.get(5), 2);
        assert_eq!(classes.get(6), 2);
        assert_eq!(classes.get(7), 3);
        assert_eq!(classes.get(255), 3);
    }
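
    // A minimal, self-contained sketch of the boundary-marking idea behind
    // `set_range` and `byte_classes`. It is an independent model written for
    // illustration only (the name `model_classes` is hypothetical), not the
    // code path this crate uses: every range marks the byte just before its
    // start and its own end as the last byte of a class, and one pass over
    // all 256 byte values then hands out class identifiers.
    fn model_classes(ranges: &[(u8, u8)]) -> [u8; 256] {
        // boundary[b] is true when byte b is the last member of its class.
        let mut boundary = [false; 256];
        for &(start, end) in ranges {
            if start > 0 {
                boundary[usize::from(start - 1)] = true;
            }
            boundary[usize::from(end)] = true;
        }
        let mut classes = [0u8; 256];
        let mut class = 0u8;
        for (b, slot) in classes.iter_mut().enumerate() {
            *slot = class;
            // Byte 255 always closes the final class, so a boundary there
            // never opens a new one. This also keeps `class` within `u8`.
            if b < 255 && boundary[b] {
                class += 1;
            }
        }
        classes
    }

    #[test]
    fn model_classes_layout() {
        // Same inputs and expectations as the `byte_classes` test.
        let classes = model_classes(&[(b'a', b'z')]);
        assert_eq!(classes[usize::from(b'a' - 1)], 0);
        assert_eq!(classes[usize::from(b'a')], 1);
        assert_eq!(classes[usize::from(b'z')], 1);
        assert_eq!(classes[usize::from(b'z' + 1)], 2);
        assert_eq!(classes[255], 2);

        let classes = model_classes(&[(0, 2), (4, 6)]);
        assert_eq!(classes[2], 0);
        assert_eq!(classes[3], 1);
        assert_eq!(classes[6], 2);
        assert_eq!(classes[7], 3);
    }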

    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        for b in 0u8..=255 {
            set.set_range(b, b);
        }
        assert_eq!(set.byte_classes().alphabet_len(), 257);
    }

    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        // class 0: \x00-a
        // class 1: b-d
        // class 2: e-f
        // class 3: g-m
        // class 4: n-y
        // class 5: z-z
        // class 6: \x7B-\xFF
        // class 7: EOI
        assert_eq!(classes.alphabet_len(), 8);

        let elements = classes.elements(Unit::u8(0)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 98);
        assert_eq!(elements[0], Unit::u8(b'\x00'));
        assert_eq!(elements[97], Unit::u8(b'a'));

        let elements = classes.elements(Unit::u8(1)).collect::<Vec<_>>();
        assert_eq!(
            elements,
            vec![Unit::u8(b'b'), Unit::u8(b'c'), Unit::u8(b'd')],
        );

        let elements = classes.elements(Unit::u8(2)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::u8(b'e'), Unit::u8(b'f')],);

        let elements = classes.elements(Unit::u8(3)).collect::<Vec<_>>();
        assert_eq!(
            elements,
            vec![
                Unit::u8(b'g'),
                Unit::u8(b'h'),
                Unit::u8(b'i'),
                Unit::u8(b'j'),
                Unit::u8(b'k'),
                Unit::u8(b'l'),
                Unit::u8(b'm'),
            ],
        );

        let elements = classes.elements(Unit::u8(4)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 12);
        assert_eq!(elements[0], Unit::u8(b'n'));
        assert_eq!(elements[11], Unit::u8(b'y'));

        let elements = classes.elements(Unit::u8(5)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::u8(b'z')]);

        let elements = classes.elements(Unit::u8(6)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 133);
        assert_eq!(elements[0], Unit::u8(b'\x7B'));
        assert_eq!(elements[132], Unit::u8(b'\xFF'));

        let elements = classes.elements(Unit::eoi(7)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::eoi(256)]);
    }

    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 257);

        let elements = classes.elements(Unit::u8(b'a')).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::u8(b'a')]);

        let elements = classes.elements(Unit::eoi(5)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::eoi(256)]);
    }

    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 2);

        let elements = classes.elements(Unit::u8(0)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 256);
        assert_eq!(elements[0], Unit::u8(b'\x00'));
        assert_eq!(elements[255], Unit::u8(b'\xFF'));

        let elements = classes.elements(Unit::eoi(1)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::eoi(256)]);
    }
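
    // A hedged sketch of the 32-byte layout that `ByteSet::write_to` and
    // `ByteSet::from_bytes` appear to share: the low bucket followed by the
    // high bucket, 16 bytes each. It assumes little-endian order and uses
    // only `u128::to_le_bytes` and `u128::from_le_bytes` from `core`, not
    // this crate's `wire` helpers. The names `model_write` and `model_read`
    // are hypothetical.
    fn model_write(bits: [u128; 2]) -> [u8; 32] {
        let mut dst = [0u8; 32];
        dst[..16].copy_from_slice(&bits[0].to_le_bytes());
        dst[16..].copy_from_slice(&bits[1].to_le_bytes());
        dst
    }

    fn model_read(src: &[u8; 32]) -> [u128; 2] {
        let mut low = [0u8; 16];
        let mut high = [0u8; 16];
        low.copy_from_slice(&src[..16]);
        // The high bucket is read at offset 16, after the low bucket.
        high.copy_from_slice(&src[16..]);
        [u128::from_le_bytes(low), u128::from_le_bytes(high)]
    }

    #[test]
    fn model_wire_round_trip() {
        // The two buckets differ on purpose, so reading the low bucket
        // twice would be caught.
        let bits = [1u128, 1u128 << 127];
        let dst = model_write(bits);
        assert_eq!(dst.len() % 8, 0);
        assert_eq!(model_read(&dst), bits);
    }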

    #[test]
    fn representatives() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();

        let got: Vec<Unit> = classes.representatives(..).collect();
        let expected = vec![
            Unit::u8(b'\x00'),
            Unit::u8(b'b'),
            Unit::u8(b'e'),
            Unit::u8(b'g'),
            Unit::u8(b'n'),
            Unit::u8(b'z'),
            Unit::u8(b'\x7B'),
            Unit::eoi(7),
        ];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(..0).collect();
        assert!(got.is_empty());
        let got: Vec<Unit> = classes.representatives(1..1).collect();
        assert!(got.is_empty());
        let got: Vec<Unit> = classes.representatives(255..255).collect();
        assert!(got.is_empty());

        // A weird case: the only guaranteed way to get an iterator of just
        // the EOI class is to exclude all possible byte values.
        let got: Vec<Unit> = classes
            .representatives((
                core::ops::Bound::Excluded(255),
                core::ops::Bound::Unbounded,
            ))
            .collect();
        let expected = vec![Unit::eoi(7)];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(..=255).collect();
        let expected = vec![
            Unit::u8(b'\x00'),
            Unit::u8(b'b'),
            Unit::u8(b'e'),
            Unit::u8(b'g'),
            Unit::u8(b'n'),
            Unit::u8(b'z'),
            Unit::u8(b'\x7B'),
        ];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'b'..=b'd').collect();
        let expected = vec![Unit::u8(b'b')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'a'..=b'd').collect();
        let expected = vec![Unit::u8(b'a'), Unit::u8(b'b')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'b'..=b'e').collect();
        let expected = vec![Unit::u8(b'b'), Unit::u8(b'e')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'A'..=b'Z').collect();
        let expected = vec![Unit::u8(b'A')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'A'..=b'z').collect();
        let expected = vec![
            Unit::u8(b'A'),
            Unit::u8(b'b'),
            Unit::u8(b'e'),
            Unit::u8(b'g'),
            Unit::u8(b'n'),
            Unit::u8(b'z'),
        ];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'z'..).collect();
        let expected = vec![Unit::u8(b'z'), Unit::u8(b'\x7B'), Unit::eoi(7)];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'z'..=0xFF).collect();
        let expected = vec![Unit::u8(b'z'), Unit::u8(b'\x7B')];
        assert_eq!(expected, got);
    }
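
    // A minimal, self-contained sketch of the two-bucket bit layout used by
    // `ByteSet` and of the contiguous-range grouping done by `iter_ranges`.
    // It is an independent model for illustration (the names `model_bits`
    // and `model_ranges` are hypothetical), not the crate's implementation:
    // byte `b` lives in bucket `b / 128` at bit `b % 128`.
    fn model_bits(bytes: &[u8]) -> [u128; 2] {
        let mut bits = [0u128; 2];
        for &byte in bytes {
            bits[usize::from(byte / 128)] |= 1u128 << (byte % 128);
        }
        bits
    }

    fn model_ranges(bits: [u128; 2]) -> Vec<(u8, u8)> {
        let contains =
            |b: usize| (bits[b / 128] & (1u128 << (b % 128))) != 0;
        let mut ranges = Vec::new();
        let mut b = 0usize;
        while b <= 255 {
            if !contains(b) {
                b += 1;
                continue;
            }
            let start = b;
            // Extend the range while the next byte is also a member.
            while b < 255 && contains(b + 1) {
                b += 1;
            }
            // Both values are at most 255 here, so the casts are lossless.
            ranges.push((start as u8, b as u8));
            b += 1;
        }
        ranges
    }

    #[test]
    fn model_ranges_layout() {
        // 127 and 128 sit in different buckets but form one range.
        let bits = model_bits(&[0, 1, 2, 127, 128, 255]);
        assert_eq!(
            model_ranges(bits),
            vec![(0, 2), (127, 128), (255, 255)],
        );
        assert!(model_ranges([0, 0]).is_empty());
    }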
}