1 /*!
2 Types and routines that support the wire format of finite automata.
3 
4 Currently, this module just exports a few error types and some small helpers
5 for deserializing [dense DFAs](crate::dfa::dense::DFA) using correct alignment.
6 */
7 
8 /*
9 A collection of helper functions, types and traits for serializing automata.
10 
11 This crate defines its own bespoke serialization mechanism for some structures
12 provided in the public API, namely, DFAs. A bespoke mechanism was developed
13 primarily because structures like automata demand a specific binary format.
14 Attempting to encode their rich structure in an existing serialization
15 format is just not feasible. Moreover, the format for each structure is
generally designed such that deserialization is cheap. More specifically,
deserialization can be done in constant time. (The idea is that you can
embed it into your binary or mmap it, and then use it immediately.)
19 
In order to achieve this, the dense and sparse DFAs in this crate use an
in-memory representation that very closely corresponds to their binary
serialized form. This pervades and complicates everything, and in some cases,
requires dealing with alignment and reasoning about safety.
24 
25 This technique does have major advantages. In particular, it permits doing
26 the potentially costly work of compiling a finite state machine in an offline
27 manner, and then loading it at runtime not only without having to re-compile
28 the regex, but even without the code required to do the compilation. This, for
29 example, permits one to use a pre-compiled DFA not only in environments without
30 Rust's standard library, but also in environments without a heap.
31 
32 In the code below, whenever we insert some kind of padding, it's to enforce a
33 4-byte alignment, unless otherwise noted. Namely, u32 is the only state ID type
34 supported. (In a previous version of this library, DFAs were generic over the
35 state ID representation.)
36 
Also, serialization generally requires the caller to specify endianness,
whereas deserialization always assumes native endianness (otherwise cheap
deserialization would be impossible). This implies that serializing a structure
generally requires serializing both its big-endian and little-endian variants,
and then loading the correct one based on the target's endianness.
42 */
43 
44 use core::{cmp, mem::size_of};
45 
46 #[cfg(feature = "alloc")]
47 use alloc::{vec, vec::Vec};
48 
49 use crate::util::{
50     int::Pointer,
51     primitives::{PatternID, PatternIDError, StateID, StateIDError},
52 };
53 
/// A hack to align a smaller type `B` with a bigger type `T`.
///
/// The usual use of this is with `B = [u8]` and `T = u32`. That is,
/// it permits aligning a sequence of bytes on a 4-byte boundary. This
/// is useful in contexts where one wants to embed a serialized [dense
/// DFA](crate::dfa::dense::DFA) into a Rust program while guaranteeing the
/// alignment required for the DFA.
///
/// See [`dense::DFA::from_bytes`](crate::dfa::dense::DFA::from_bytes) for an
/// example of how to use this type.
// `repr(C)` lays the fields out in declaration order. The zero-length
// `_align` array comes first and contributes `T`'s alignment (but no size)
// to the whole struct, so `bytes` begins at an address aligned to `T`. The
// possibly-unsized `bytes` field must stay last.
#[repr(C)]
#[derive(Debug)]
pub struct AlignAs<B: ?Sized, T> {
    /// A zero-sized field indicating the alignment we want.
    pub _align: [T; 0],
    /// A possibly non-sized field containing a sequence of bytes.
    pub bytes: B,
}
72 
/// An error that occurs when serializing an object from this crate.
///
/// Serialization, as used in this crate, universally refers to the process
/// of transforming a structure (like a DFA) into a custom binary format
/// represented by `&[u8]`. Serialization is therefore generally infallible.
/// The one exception is when a caller-provided buffer is too small, in which
/// case a serialization error is reported.
///
/// A `SerializeError` provides no introspection capabilities. Its only
/// supported operation is conversion to a human readable error message.
///
/// This error type implements the `std::error::Error` trait only when the
/// `std` feature is enabled. Regardless of features, the type itself is
/// always defined.
#[derive(Debug)]
pub struct SerializeError {
    /// The name of the thing that a destination buffer is too small for.
    ///
    /// At present, the only serialization failure is one the caller commits:
    /// handing over a destination buffer that cannot hold the serialized
    /// object. This is intentional, since every valid inhabitant of a type
    /// should be serializable.
    ///
    /// The public API leans on this. For example, the
    /// `to_bytes_{big,little}_endian` APIs return a `Vec<u8>` and promise
    /// never to panic or error. That promise holds only because the
    /// implementation always allocates a `Vec<u8>` that is large enough.
    ///
    /// So any new kind of serialization error needs careful consideration.
    what: &'static str,
}

impl SerializeError {
    /// Reports that the destination buffer cannot hold `what`.
    pub(crate) fn buffer_too_small(what: &'static str) -> SerializeError {
        Self { what }
    }
}

impl core::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let SerializeError { what } = *self;
        write!(f, "destination buffer is too small to write {}", what)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SerializeError {}
121 
/// An error that occurs when deserializing an object defined in this crate.
///
/// Serialization, as used in this crate, universally refers to the process
/// of transforming a structure (like a DFA) into a custom binary format
/// represented by `&[u8]`. Deserialization, then, refers to the process of
/// cheaply converting this binary format back to the object's in-memory
/// representation as defined in this crate. To the extent possible,
/// deserialization will report this error whenever this process fails.
///
/// A `DeserializeError` provides no introspection capabilities. Its only
/// supported operation is conversion to a human readable error message.
///
/// This error type implements the `std::error::Error` trait only when the
/// `std` feature is enabled. Regardless of features, the type itself is
/// always defined.
#[derive(Debug)]
pub struct DeserializeError(DeserializeErrorKind);

/// The private reason behind a `DeserializeError`. It is kept private so
/// that callers can only render the error as a message.
#[derive(Debug)]
enum DeserializeErrorKind {
    /// A free-form message that is displayed verbatim.
    Generic { msg: &'static str },
    /// The input slice ended before `what` could be read.
    BufferTooSmall { what: &'static str },
    /// A decoded integer for `what` does not fit in a `usize`.
    InvalidUsize { what: &'static str },
    /// The serialized version number differs from the supported one.
    VersionMismatch { expected: u32, found: u32 },
    /// The endianness marker differs from the expected value, which suggests
    /// the object was serialized with a different endianness.
    EndianMismatch { expected: u32, found: u32 },
    /// The input slice's start `address` is not a multiple of `alignment`.
    AlignmentMismatch { alignment: usize, address: usize },
    /// The leading NUL terminated label is not `expected`.
    LabelMismatch { expected: &'static str },
    /// Arithmetic on a deserialized quantity for `what` overflowed.
    ArithmeticOverflow { what: &'static str },
    /// A pattern ID read for `what` was rejected as invalid.
    PatternID { err: PatternIDError, what: &'static str },
    /// A state ID read for `what` was rejected as invalid.
    StateID { err: StateIDError, what: &'static str },
}
153 
154 impl DeserializeError {
generic(msg: &'static str) -> DeserializeError155     pub(crate) fn generic(msg: &'static str) -> DeserializeError {
156         DeserializeError(DeserializeErrorKind::Generic { msg })
157     }
158 
buffer_too_small(what: &'static str) -> DeserializeError159     pub(crate) fn buffer_too_small(what: &'static str) -> DeserializeError {
160         DeserializeError(DeserializeErrorKind::BufferTooSmall { what })
161     }
162 
invalid_usize(what: &'static str) -> DeserializeError163     fn invalid_usize(what: &'static str) -> DeserializeError {
164         DeserializeError(DeserializeErrorKind::InvalidUsize { what })
165     }
166 
version_mismatch(expected: u32, found: u32) -> DeserializeError167     fn version_mismatch(expected: u32, found: u32) -> DeserializeError {
168         DeserializeError(DeserializeErrorKind::VersionMismatch {
169             expected,
170             found,
171         })
172     }
173 
endian_mismatch(expected: u32, found: u32) -> DeserializeError174     fn endian_mismatch(expected: u32, found: u32) -> DeserializeError {
175         DeserializeError(DeserializeErrorKind::EndianMismatch {
176             expected,
177             found,
178         })
179     }
180 
alignment_mismatch( alignment: usize, address: usize, ) -> DeserializeError181     fn alignment_mismatch(
182         alignment: usize,
183         address: usize,
184     ) -> DeserializeError {
185         DeserializeError(DeserializeErrorKind::AlignmentMismatch {
186             alignment,
187             address,
188         })
189     }
190 
label_mismatch(expected: &'static str) -> DeserializeError191     fn label_mismatch(expected: &'static str) -> DeserializeError {
192         DeserializeError(DeserializeErrorKind::LabelMismatch { expected })
193     }
194 
arithmetic_overflow(what: &'static str) -> DeserializeError195     fn arithmetic_overflow(what: &'static str) -> DeserializeError {
196         DeserializeError(DeserializeErrorKind::ArithmeticOverflow { what })
197     }
198 
pattern_id_error( err: PatternIDError, what: &'static str, ) -> DeserializeError199     fn pattern_id_error(
200         err: PatternIDError,
201         what: &'static str,
202     ) -> DeserializeError {
203         DeserializeError(DeserializeErrorKind::PatternID { err, what })
204     }
205 
state_id_error( err: StateIDError, what: &'static str, ) -> DeserializeError206     pub(crate) fn state_id_error(
207         err: StateIDError,
208         what: &'static str,
209     ) -> DeserializeError {
210         DeserializeError(DeserializeErrorKind::StateID { err, what })
211     }
212 }
213 
214 #[cfg(feature = "std")]
215 impl std::error::Error for DeserializeError {}
216 
217 impl core::fmt::Display for DeserializeError {
fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result218     fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
219         use self::DeserializeErrorKind::*;
220 
221         match self.0 {
222             Generic { msg } => write!(f, "{}", msg),
223             BufferTooSmall { what } => {
224                 write!(f, "buffer is too small to read {}", what)
225             }
226             InvalidUsize { what } => {
227                 write!(f, "{} is too big to fit in a usize", what)
228             }
229             VersionMismatch { expected, found } => write!(
230                 f,
231                 "unsupported version: \
232                  expected version {} but found version {}",
233                 expected, found,
234             ),
235             EndianMismatch { expected, found } => write!(
236                 f,
237                 "endianness mismatch: expected 0x{:X} but got 0x{:X}. \
238                  (Are you trying to load an object serialized with a \
239                  different endianness?)",
240                 expected, found,
241             ),
242             AlignmentMismatch { alignment, address } => write!(
243                 f,
244                 "alignment mismatch: slice starts at address \
245                  0x{:X}, which is not aligned to a {} byte boundary",
246                 address, alignment,
247             ),
248             LabelMismatch { expected } => write!(
249                 f,
250                 "label mismatch: start of serialized object should \
251                  contain a NUL terminated {:?} label, but a different \
252                  label was found",
253                 expected,
254             ),
255             ArithmeticOverflow { what } => {
256                 write!(f, "arithmetic overflow for {}", what)
257             }
258             PatternID { ref err, what } => {
259                 write!(f, "failed to read pattern ID for {}: {}", what, err)
260             }
261             StateID { ref err, what } => {
262                 write!(f, "failed to read state ID for {}: {}", what, err)
263             }
264         }
265     }
266 }
267 
268 /// Safely converts a `&[u32]` to `&[StateID]` with zero cost.
269 #[cfg_attr(feature = "perf-inline", inline(always))]
u32s_to_state_ids(slice: &[u32]) -> &[StateID]270 pub(crate) fn u32s_to_state_ids(slice: &[u32]) -> &[StateID] {
271     // SAFETY: This is safe because StateID is defined to have the same memory
272     // representation as a u32 (it is repr(transparent)). While not every u32
273     // is a "valid" StateID, callers are not permitted to rely on the validity
274     // of StateIDs for memory safety. It can only lead to logical errors. (This
275     // is why StateID::new_unchecked is safe.)
276     unsafe {
277         core::slice::from_raw_parts(
278             slice.as_ptr().cast::<StateID>(),
279             slice.len(),
280         )
281     }
282 }
283 
284 /// Safely converts a `&mut [u32]` to `&mut [StateID]` with zero cost.
u32s_to_state_ids_mut(slice: &mut [u32]) -> &mut [StateID]285 pub(crate) fn u32s_to_state_ids_mut(slice: &mut [u32]) -> &mut [StateID] {
286     // SAFETY: This is safe because StateID is defined to have the same memory
287     // representation as a u32 (it is repr(transparent)). While not every u32
288     // is a "valid" StateID, callers are not permitted to rely on the validity
289     // of StateIDs for memory safety. It can only lead to logical errors. (This
290     // is why StateID::new_unchecked is safe.)
291     unsafe {
292         core::slice::from_raw_parts_mut(
293             slice.as_mut_ptr().cast::<StateID>(),
294             slice.len(),
295         )
296     }
297 }
298 
299 /// Safely converts a `&[u32]` to `&[PatternID]` with zero cost.
300 #[cfg_attr(feature = "perf-inline", inline(always))]
u32s_to_pattern_ids(slice: &[u32]) -> &[PatternID]301 pub(crate) fn u32s_to_pattern_ids(slice: &[u32]) -> &[PatternID] {
302     // SAFETY: This is safe because PatternID is defined to have the same
303     // memory representation as a u32 (it is repr(transparent)). While not
304     // every u32 is a "valid" PatternID, callers are not permitted to rely
305     // on the validity of PatternIDs for memory safety. It can only lead to
306     // logical errors. (This is why PatternID::new_unchecked is safe.)
307     unsafe {
308         core::slice::from_raw_parts(
309             slice.as_ptr().cast::<PatternID>(),
310             slice.len(),
311         )
312     }
313 }
314 
315 /// Checks that the given slice has an alignment that matches `T`.
316 ///
317 /// This is useful for checking that a slice has an appropriate alignment
318 /// before casting it to a &[T]. Note though that alignment is not itself
319 /// sufficient to perform the cast for any `T`.
check_alignment<T>( slice: &[u8], ) -> Result<(), DeserializeError>320 pub(crate) fn check_alignment<T>(
321     slice: &[u8],
322 ) -> Result<(), DeserializeError> {
323     let alignment = core::mem::align_of::<T>();
324     let address = slice.as_ptr().as_usize();
325     if address % alignment == 0 {
326         return Ok(());
327     }
328     Err(DeserializeError::alignment_mismatch(alignment, address))
329 }
330 
/// Counts the leading NUL bytes of `slice`, stopping after at most 7.
///
/// A serialized object may be preceded by NUL padding so that it starts at
/// a correctly aligned address. That padding sits immediately before the
/// label, and this function measures how much of it there is.
///
/// Returns the number of padding bytes consumed from the front of `slice`.
pub(crate) fn skip_initial_padding(slice: &[u8]) -> usize {
    slice.iter().take(7).take_while(|&&byte| byte == 0).count()
}
347 
/// Allocate a byte buffer of the given size, along with some initial padding
/// such that `buf[padding..]` has the same alignment as `T`, where the
/// alignment of `T` must be at most `8`. In particular, callers should treat
/// the first N bytes (second return value) as padding bytes that must not be
/// overwritten. In all cases, the following identity holds:
///
/// ```ignore
/// let (buf, padding) = alloc_aligned_buffer::<StateID>(SIZE);
/// assert_eq!(SIZE, buf[padding..].len());
/// ```
///
/// In practice, padding is often zero. All bytes of the returned buffer,
/// padding included, are initialized to zero.
///
/// The requirement for `8` as a maximum here is somewhat arbitrary. In
/// practice, we never need anything bigger in this crate, and so this function
/// does some sanity asserts under the assumption of a max alignment of `8`.
///
/// # Panics
///
/// Panics if the computed padding is larger than 7 bytes (which can only
/// happen when `T`'s alignment exceeds 8), or if any internal sanity
/// assertion fails.
#[cfg(feature = "alloc")]
pub(crate) fn alloc_aligned_buffer<T>(size: usize) -> (Vec<u8>, usize) {
    // NOTE: This is a kludge because there's no easy way to allocate a Vec<u8>
    // with an alignment guaranteed to be greater than 1. We could create a
    // Vec<u32>, but this cannot be safely transmuted to a Vec<u8> without
    // concern, since reallocing or dropping the Vec<u8> is UB (different
    // alignment than the initial allocation). We could define a wrapper type
    // to manage this for us, but it seems like more machinery than it's worth.
    let buf = vec![0; size];
    let align = core::mem::align_of::<T>();
    let address = buf.as_ptr().as_usize();
    // Fast path: the first allocation is already suitably aligned, so no
    // padding is required.
    if address % align == 0 {
        return (buf, 0);
    }
    // Let's try this again. We have to create a totally new alloc with
    // the maximum amount of bytes we might need. We can't just extend our
    // pre-existing 'buf' because that might create a new alloc with a
    // different alignment.
    //
    // Over-allocating by `align - 1` bytes guarantees that some offset in
    // `0..align` within the new buffer is aligned to `align`.
    let extra = align - 1;
    let mut buf = vec![0; size + extra];
    let address = buf.as_ptr().as_usize();
    // The code below handles the case where 'address' is aligned to T, so if
    // we got lucky and 'address' is now aligned to T (when it previously
    // wasn't), then we're done.
    if address % align == 0 {
        buf.truncate(size);
        return (buf, 0);
    }
    // Distance from `address` up to the next multiple of `align`: clearing
    // the low bits rounds down (Rust alignments are powers of two), and adding
    // `align` steps to the next boundary. Checked arithmetic turns any
    // overflow into a panic rather than a silently wrong offset.
    let padding = ((address & !(align - 1)).checked_add(align).unwrap())
        .checked_sub(address)
        .unwrap();
    assert!(padding <= 7, "padding of {} is bigger than 7", padding);
    assert!(
        padding <= extra,
        "padding of {} is bigger than extra {} bytes",
        padding,
        extra
    );
    // Drop the unused tail. Truncation does not reallocate, so the address
    // (and therefore the alignment computed above) is unchanged.
    buf.truncate(size + padding);
    assert_eq!(size + padding, buf.len());
    assert_eq!(
        0,
        buf[padding..].as_ptr().as_usize() % align,
        "expected end of initial padding to be aligned to {}",
        align,
    );
    (buf, padding)
}
412 
413 /// Reads a NUL terminated label starting at the beginning of the given slice.
414 ///
415 /// If a NUL terminated label could not be found, then an error is returned.
416 /// Similarly, if a label is found but doesn't match the expected label, then
417 /// an error is returned.
418 ///
419 /// Upon success, the total number of bytes read (including padding bytes) is
420 /// returned.
read_label( slice: &[u8], expected_label: &'static str, ) -> Result<usize, DeserializeError>421 pub(crate) fn read_label(
422     slice: &[u8],
423     expected_label: &'static str,
424 ) -> Result<usize, DeserializeError> {
425     // Set an upper bound on how many bytes we scan for a NUL. Since no label
426     // in this crate is longer than 256 bytes, if we can't find one within that
427     // range, then we have corrupted data.
428     let first_nul =
429         slice[..cmp::min(slice.len(), 256)].iter().position(|&b| b == 0);
430     let first_nul = match first_nul {
431         Some(first_nul) => first_nul,
432         None => {
433             return Err(DeserializeError::generic(
434                 "could not find NUL terminated label \
435                  at start of serialized object",
436             ));
437         }
438     };
439     let len = first_nul + padding_len(first_nul);
440     if slice.len() < len {
441         return Err(DeserializeError::generic(
442             "could not find properly sized label at start of serialized object"
443         ));
444     }
445     if expected_label.as_bytes() != &slice[..first_nul] {
446         return Err(DeserializeError::label_mismatch(expected_label));
447     }
448     Ok(len)
449 }
450 
451 /// Writes the given label to the buffer as a NUL terminated string. The label
452 /// given must not contain NUL, otherwise this will panic. Similarly, the label
453 /// must not be longer than 255 bytes, otherwise this will panic.
454 ///
455 /// Additional NUL bytes are written as necessary to ensure that the number of
456 /// bytes written is always a multiple of 4.
457 ///
458 /// Upon success, the total number of bytes written (including padding) is
459 /// returned.
write_label( label: &str, dst: &mut [u8], ) -> Result<usize, SerializeError>460 pub(crate) fn write_label(
461     label: &str,
462     dst: &mut [u8],
463 ) -> Result<usize, SerializeError> {
464     let nwrite = write_label_len(label);
465     if dst.len() < nwrite {
466         return Err(SerializeError::buffer_too_small("label"));
467     }
468     dst[..label.len()].copy_from_slice(label.as_bytes());
469     for i in 0..(nwrite - label.len()) {
470         dst[label.len() + i] = 0;
471     }
472     assert_eq!(nwrite % 4, 0);
473     Ok(nwrite)
474 }
475 
476 /// Returns the total number of bytes (including padding) that would be written
477 /// for the given label. This panics if the given label contains a NUL byte or
478 /// is longer than 255 bytes. (The size restriction exists so that searching
479 /// for a label during deserialization can be done in small bounded space.)
write_label_len(label: &str) -> usize480 pub(crate) fn write_label_len(label: &str) -> usize {
481     if label.len() > 255 {
482         panic!("label must not be longer than 255 bytes");
483     }
484     if label.as_bytes().iter().position(|&b| b == 0).is_some() {
485         panic!("label must not contain NUL bytes");
486     }
487     let label_len = label.len() + 1; // +1 for the NUL terminator
488     label_len + padding_len(label_len)
489 }
490 
491 /// Reads the endianness check from the beginning of the given slice and
492 /// confirms that the endianness of the serialized object matches the expected
493 /// endianness. If the slice is too small or if the endianness check fails,
494 /// this returns an error.
495 ///
496 /// Upon success, the total number of bytes read is returned.
read_endianness_check( slice: &[u8], ) -> Result<usize, DeserializeError>497 pub(crate) fn read_endianness_check(
498     slice: &[u8],
499 ) -> Result<usize, DeserializeError> {
500     let (n, nr) = try_read_u32(slice, "endianness check")?;
501     assert_eq!(nr, write_endianness_check_len());
502     if n != 0xFEFF {
503         return Err(DeserializeError::endian_mismatch(0xFEFF, n));
504     }
505     Ok(nr)
506 }
507 
508 /// Writes 0xFEFF as an integer using the given endianness.
509 ///
510 /// This is useful for writing into the header of a serialized object. It can
511 /// be read during deserialization as a sanity check to ensure the proper
512 /// endianness is used.
513 ///
514 /// Upon success, the total number of bytes written is returned.
write_endianness_check<E: Endian>( dst: &mut [u8], ) -> Result<usize, SerializeError>515 pub(crate) fn write_endianness_check<E: Endian>(
516     dst: &mut [u8],
517 ) -> Result<usize, SerializeError> {
518     let nwrite = write_endianness_check_len();
519     if dst.len() < nwrite {
520         return Err(SerializeError::buffer_too_small("endianness check"));
521     }
522     E::write_u32(0xFEFF, dst);
523     Ok(nwrite)
524 }
525 
/// The size in bytes of the endianness marker: a single `u32`.
pub(crate) fn write_endianness_check_len() -> usize {
    size_of::<u32>()
}
530 
531 /// Reads a version number from the beginning of the given slice and confirms
532 /// that is matches the expected version number given. If the slice is too
533 /// small or if the version numbers aren't equivalent, this returns an error.
534 ///
535 /// Upon success, the total number of bytes read is returned.
536 ///
537 /// N.B. Currently, we require that the version number is exactly equivalent.
538 /// In the future, if we bump the version number without a semver bump, then
539 /// we'll need to relax this a bit and support older versions.
read_version( slice: &[u8], expected_version: u32, ) -> Result<usize, DeserializeError>540 pub(crate) fn read_version(
541     slice: &[u8],
542     expected_version: u32,
543 ) -> Result<usize, DeserializeError> {
544     let (n, nr) = try_read_u32(slice, "version")?;
545     assert_eq!(nr, write_version_len());
546     if n != expected_version {
547         return Err(DeserializeError::version_mismatch(expected_version, n));
548     }
549     Ok(nr)
550 }
551 
552 /// Writes the given version number to the beginning of the given slice.
553 ///
554 /// This is useful for writing into the header of a serialized object. It can
555 /// be read during deserialization as a sanity check to ensure that the library
556 /// code supports the format of the serialized object.
557 ///
558 /// Upon success, the total number of bytes written is returned.
write_version<E: Endian>( version: u32, dst: &mut [u8], ) -> Result<usize, SerializeError>559 pub(crate) fn write_version<E: Endian>(
560     version: u32,
561     dst: &mut [u8],
562 ) -> Result<usize, SerializeError> {
563     let nwrite = write_version_len();
564     if dst.len() < nwrite {
565         return Err(SerializeError::buffer_too_small("version number"));
566     }
567     E::write_u32(version, dst);
568     Ok(nwrite)
569 }
570 
/// The size in bytes of a serialized version number: a single `u32`.
pub(crate) fn write_version_len() -> usize {
    size_of::<u32>()
}
575 
576 /// Reads a pattern ID from the given slice. If the slice has insufficient
577 /// length, then this panics. If the deserialized integer exceeds the pattern
578 /// ID limit for the current target, then this returns an error.
579 ///
580 /// Upon success, this also returns the number of bytes read.
read_pattern_id( slice: &[u8], what: &'static str, ) -> Result<(PatternID, usize), DeserializeError>581 pub(crate) fn read_pattern_id(
582     slice: &[u8],
583     what: &'static str,
584 ) -> Result<(PatternID, usize), DeserializeError> {
585     let bytes: [u8; PatternID::SIZE] =
586         slice[..PatternID::SIZE].try_into().unwrap();
587     let pid = PatternID::from_ne_bytes(bytes)
588         .map_err(|err| DeserializeError::pattern_id_error(err, what))?;
589     Ok((pid, PatternID::SIZE))
590 }
591 
592 /// Reads a pattern ID from the given slice. If the slice has insufficient
593 /// length, then this panics. Otherwise, the deserialized integer is assumed
594 /// to be a valid pattern ID.
595 ///
596 /// This also returns the number of bytes read.
read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize)597 pub(crate) fn read_pattern_id_unchecked(slice: &[u8]) -> (PatternID, usize) {
598     let pid = PatternID::from_ne_bytes_unchecked(
599         slice[..PatternID::SIZE].try_into().unwrap(),
600     );
601     (pid, PatternID::SIZE)
602 }
603 
604 /// Write the given pattern ID to the beginning of the given slice of bytes
605 /// using the specified endianness. The given slice must have length at least
606 /// `PatternID::SIZE`, or else this panics. Upon success, the total number of
607 /// bytes written is returned.
write_pattern_id<E: Endian>( pid: PatternID, dst: &mut [u8], ) -> usize608 pub(crate) fn write_pattern_id<E: Endian>(
609     pid: PatternID,
610     dst: &mut [u8],
611 ) -> usize {
612     E::write_u32(pid.as_u32(), dst);
613     PatternID::SIZE
614 }
615 
616 /// Attempts to read a state ID from the given slice. If the slice has an
617 /// insufficient number of bytes or if the state ID exceeds the limit for
618 /// the current target, then this returns an error.
619 ///
620 /// Upon success, this also returns the number of bytes read.
try_read_state_id( slice: &[u8], what: &'static str, ) -> Result<(StateID, usize), DeserializeError>621 pub(crate) fn try_read_state_id(
622     slice: &[u8],
623     what: &'static str,
624 ) -> Result<(StateID, usize), DeserializeError> {
625     if slice.len() < StateID::SIZE {
626         return Err(DeserializeError::buffer_too_small(what));
627     }
628     read_state_id(slice, what)
629 }
630 
631 /// Reads a state ID from the given slice. If the slice has insufficient
632 /// length, then this panics. If the deserialized integer exceeds the state ID
633 /// limit for the current target, then this returns an error.
634 ///
635 /// Upon success, this also returns the number of bytes read.
read_state_id( slice: &[u8], what: &'static str, ) -> Result<(StateID, usize), DeserializeError>636 pub(crate) fn read_state_id(
637     slice: &[u8],
638     what: &'static str,
639 ) -> Result<(StateID, usize), DeserializeError> {
640     let bytes: [u8; StateID::SIZE] =
641         slice[..StateID::SIZE].try_into().unwrap();
642     let sid = StateID::from_ne_bytes(bytes)
643         .map_err(|err| DeserializeError::state_id_error(err, what))?;
644     Ok((sid, StateID::SIZE))
645 }
646 
647 /// Reads a state ID from the given slice. If the slice has insufficient
648 /// length, then this panics. Otherwise, the deserialized integer is assumed
649 /// to be a valid state ID.
650 ///
651 /// This also returns the number of bytes read.
read_state_id_unchecked(slice: &[u8]) -> (StateID, usize)652 pub(crate) fn read_state_id_unchecked(slice: &[u8]) -> (StateID, usize) {
653     let sid = StateID::from_ne_bytes_unchecked(
654         slice[..StateID::SIZE].try_into().unwrap(),
655     );
656     (sid, StateID::SIZE)
657 }
658 
659 /// Write the given state ID to the beginning of the given slice of bytes
660 /// using the specified endianness. The given slice must have length at least
661 /// `StateID::SIZE`, or else this panics. Upon success, the total number of
662 /// bytes written is returned.
write_state_id<E: Endian>( sid: StateID, dst: &mut [u8], ) -> usize663 pub(crate) fn write_state_id<E: Endian>(
664     sid: StateID,
665     dst: &mut [u8],
666 ) -> usize {
667     E::write_u32(sid.as_u32(), dst);
668     StateID::SIZE
669 }
670 
671 /// Try to read a u16 as a usize from the beginning of the given slice in
672 /// native endian format. If the slice has fewer than 2 bytes or if the
673 /// deserialized number cannot be represented by usize, then this returns an
674 /// error. The error message will include the `what` description of what is
675 /// being deserialized, for better error messages. `what` should be a noun in
676 /// singular form.
677 ///
678 /// Upon success, this also returns the number of bytes read.
try_read_u16_as_usize( slice: &[u8], what: &'static str, ) -> Result<(usize, usize), DeserializeError>679 pub(crate) fn try_read_u16_as_usize(
680     slice: &[u8],
681     what: &'static str,
682 ) -> Result<(usize, usize), DeserializeError> {
683     try_read_u16(slice, what).and_then(|(n, nr)| {
684         usize::try_from(n)
685             .map(|n| (n, nr))
686             .map_err(|_| DeserializeError::invalid_usize(what))
687     })
688 }
689 
690 /// Try to read a u32 as a usize from the beginning of the given slice in
691 /// native endian format. If the slice has fewer than 4 bytes or if the
692 /// deserialized number cannot be represented by usize, then this returns an
693 /// error. The error message will include the `what` description of what is
694 /// being deserialized, for better error messages. `what` should be a noun in
695 /// singular form.
696 ///
697 /// Upon success, this also returns the number of bytes read.
try_read_u32_as_usize( slice: &[u8], what: &'static str, ) -> Result<(usize, usize), DeserializeError>698 pub(crate) fn try_read_u32_as_usize(
699     slice: &[u8],
700     what: &'static str,
701 ) -> Result<(usize, usize), DeserializeError> {
702     try_read_u32(slice, what).and_then(|(n, nr)| {
703         usize::try_from(n)
704             .map(|n| (n, nr))
705             .map_err(|_| DeserializeError::invalid_usize(what))
706     })
707 }
708 
709 /// Try to read a u16 from the beginning of the given slice in native endian
710 /// format. If the slice has fewer than 2 bytes, then this returns an error.
711 /// The error message will include the `what` description of what is being
712 /// deserialized, for better error messages. `what` should be a noun in
713 /// singular form.
714 ///
715 /// Upon success, this also returns the number of bytes read.
try_read_u16( slice: &[u8], what: &'static str, ) -> Result<(u16, usize), DeserializeError>716 pub(crate) fn try_read_u16(
717     slice: &[u8],
718     what: &'static str,
719 ) -> Result<(u16, usize), DeserializeError> {
720     check_slice_len(slice, size_of::<u16>(), what)?;
721     Ok((read_u16(slice), size_of::<u16>()))
722 }
723 
724 /// Try to read a u32 from the beginning of the given slice in native endian
725 /// format. If the slice has fewer than 4 bytes, then this returns an error.
726 /// The error message will include the `what` description of what is being
727 /// deserialized, for better error messages. `what` should be a noun in
728 /// singular form.
729 ///
730 /// Upon success, this also returns the number of bytes read.
try_read_u32( slice: &[u8], what: &'static str, ) -> Result<(u32, usize), DeserializeError>731 pub(crate) fn try_read_u32(
732     slice: &[u8],
733     what: &'static str,
734 ) -> Result<(u32, usize), DeserializeError> {
735     check_slice_len(slice, size_of::<u32>(), what)?;
736     Ok((read_u32(slice), size_of::<u32>()))
737 }
738 
739 /// Try to read a u128 from the beginning of the given slice in native endian
740 /// format. If the slice has fewer than 16 bytes, then this returns an error.
741 /// The error message will include the `what` description of what is being
742 /// deserialized, for better error messages. `what` should be a noun in
743 /// singular form.
744 ///
745 /// Upon success, this also returns the number of bytes read.
try_read_u128( slice: &[u8], what: &'static str, ) -> Result<(u128, usize), DeserializeError>746 pub(crate) fn try_read_u128(
747     slice: &[u8],
748     what: &'static str,
749 ) -> Result<(u128, usize), DeserializeError> {
750     check_slice_len(slice, size_of::<u128>(), what)?;
751     Ok((read_u128(slice), size_of::<u128>()))
752 }
753 
/// Decode a native-endian `u16` from the first two bytes of `slice`.
///
/// # Panics
///
/// Panics when `slice` is shorter than two bytes.
///
/// Sparse DFA searching decodes integers straight out of the automaton at
/// search time, which is why this may be force-inlined.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn read_u16(slice: &[u8]) -> u16 {
    const WIDTH: usize = size_of::<u16>();
    let mut buf = [0u8; WIDTH];
    // Slicing first yields the same out-of-bounds panic as direct indexing.
    buf.copy_from_slice(&slice[..WIDTH]);
    u16::from_ne_bytes(buf)
}
764 
/// Decode a native-endian `u32` from the first four bytes of `slice`.
///
/// # Panics
///
/// Panics when `slice` is shorter than four bytes.
///
/// Sparse DFA searching decodes integers straight out of the automaton at
/// search time, which is why this may be force-inlined.
#[cfg_attr(feature = "perf-inline", inline(always))]
pub(crate) fn read_u32(slice: &[u8]) -> u32 {
    const WIDTH: usize = size_of::<u32>();
    let mut buf = [0u8; WIDTH];
    // Slicing first yields the same out-of-bounds panic as direct indexing.
    buf.copy_from_slice(&slice[..WIDTH]);
    u32::from_ne_bytes(buf)
}
775 
/// Decode a native-endian `u128` from the first sixteen bytes of `slice`.
///
/// # Panics
///
/// Panics when `slice` is shorter than sixteen bytes.
pub(crate) fn read_u128(slice: &[u8]) -> u128 {
    const WIDTH: usize = size_of::<u128>();
    let mut buf = [0u8; WIDTH];
    // Slicing first yields the same out-of-bounds panic as direct indexing.
    buf.copy_from_slice(&slice[..WIDTH]);
    u128::from_ne_bytes(buf)
}
782 
783 /// Checks that the given slice has some minimal length. If it's smaller than
784 /// the bound given, then a "buffer too small" error is returned with `what`
785 /// describing what the buffer represents.
check_slice_len<T>( slice: &[T], at_least_len: usize, what: &'static str, ) -> Result<(), DeserializeError>786 pub(crate) fn check_slice_len<T>(
787     slice: &[T],
788     at_least_len: usize,
789     what: &'static str,
790 ) -> Result<(), DeserializeError> {
791     if slice.len() < at_least_len {
792         return Err(DeserializeError::buffer_too_small(what));
793     }
794     Ok(())
795 }
796 
797 /// Multiply the given numbers, and on overflow, return an error that includes
798 /// 'what' in the error message.
799 ///
800 /// This is useful when doing arithmetic with untrusted data.
mul( a: usize, b: usize, what: &'static str, ) -> Result<usize, DeserializeError>801 pub(crate) fn mul(
802     a: usize,
803     b: usize,
804     what: &'static str,
805 ) -> Result<usize, DeserializeError> {
806     match a.checked_mul(b) {
807         Some(c) => Ok(c),
808         None => Err(DeserializeError::arithmetic_overflow(what)),
809     }
810 }
811 
812 /// Add the given numbers, and on overflow, return an error that includes
813 /// 'what' in the error message.
814 ///
815 /// This is useful when doing arithmetic with untrusted data.
add( a: usize, b: usize, what: &'static str, ) -> Result<usize, DeserializeError>816 pub(crate) fn add(
817     a: usize,
818     b: usize,
819     what: &'static str,
820 ) -> Result<usize, DeserializeError> {
821     match a.checked_add(b) {
822         Some(c) => Ok(c),
823         None => Err(DeserializeError::arithmetic_overflow(what)),
824     }
825 }
826 
827 /// Shift `a` left by `b`, and on overflow, return an error that includes
828 /// 'what' in the error message.
829 ///
830 /// This is useful when doing arithmetic with untrusted data.
shl( a: usize, b: usize, what: &'static str, ) -> Result<usize, DeserializeError>831 pub(crate) fn shl(
832     a: usize,
833     b: usize,
834     what: &'static str,
835 ) -> Result<usize, DeserializeError> {
836     let amount = u32::try_from(b)
837         .map_err(|_| DeserializeError::arithmetic_overflow(what))?;
838     match a.checked_shl(amount) {
839         Some(c) => Ok(c),
840         None => Err(DeserializeError::arithmetic_overflow(what)),
841     }
842 }
843 
/// Compute how many padding bytes must follow `non_padding_len` bytes so that
/// the total length is a multiple of 4.
///
/// The result is always in `0..4`; it is 0 when the length is already a
/// multiple of 4.
pub(crate) fn padding_len(non_padding_len: usize) -> usize {
    // How far past the previous 4-byte boundary the data extends.
    let overshoot = non_padding_len % 4;
    if overshoot == 0 {
        0
    } else {
        4 - overshoot
    }
}
850 
/// Byte-order abstraction that lets serialization code be written once and
/// instantiated for either endianness.
///
/// Only the handful of operations the serializers need are provided: a tiny
/// subset of what a crate like `byteorder` offers.
pub(crate) trait Endian {
    /// Encode `n` into the first 2 bytes of `dst` in this byte order.
    ///
    /// # Panics
    ///
    /// Panics when `dst` is shorter than 2 bytes.
    fn write_u16(n: u16, dst: &mut [u8]);

    /// Encode `n` into the first 4 bytes of `dst` in this byte order.
    ///
    /// # Panics
    ///
    /// Panics when `dst` is shorter than 4 bytes.
    fn write_u32(n: u32, dst: &mut [u8]);

    /// Encode `n` into the first 16 bytes of `dst` in this byte order.
    ///
    /// # Panics
    ///
    /// Panics when `dst` is shorter than 16 bytes.
    fn write_u128(n: u128, dst: &mut [u8]);
}
871 
/// Marker type selecting little endian byte order when writing integers.
pub(crate) enum LE {}
/// Marker type selecting big endian byte order when writing integers.
pub(crate) enum BE {}

/// Native endian writing: the byte order of this (big endian) target.
#[cfg(target_endian = "big")]
pub(crate) type NE = BE;
/// Native endian writing: the byte order of this (little endian) target.
#[cfg(target_endian = "little")]
pub(crate) type NE = LE;
881 
882 impl Endian for LE {
write_u16(n: u16, dst: &mut [u8])883     fn write_u16(n: u16, dst: &mut [u8]) {
884         dst[..2].copy_from_slice(&n.to_le_bytes());
885     }
886 
write_u32(n: u32, dst: &mut [u8])887     fn write_u32(n: u32, dst: &mut [u8]) {
888         dst[..4].copy_from_slice(&n.to_le_bytes());
889     }
890 
write_u128(n: u128, dst: &mut [u8])891     fn write_u128(n: u128, dst: &mut [u8]) {
892         dst[..16].copy_from_slice(&n.to_le_bytes());
893     }
894 }
895 
896 impl Endian for BE {
write_u16(n: u16, dst: &mut [u8])897     fn write_u16(n: u16, dst: &mut [u8]) {
898         dst[..2].copy_from_slice(&n.to_be_bytes());
899     }
900 
write_u32(n: u32, dst: &mut [u8])901     fn write_u32(n: u32, dst: &mut [u8]) {
902         dst[..4].copy_from_slice(&n.to_be_bytes());
903     }
904 
write_u128(n: u128, dst: &mut [u8])905     fn write_u128(n: u128, dst: &mut [u8]) {
906         dst[..16].copy_from_slice(&n.to_be_bytes());
907     }
908 }
909 
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    #[test]
    fn labels() {
        let mut buf = [0; 1024];

        let nwrite = write_label("fooba", &mut buf).unwrap();
        assert_eq!(nwrite, 8);
        assert_eq!(&buf[..nwrite], b"fooba\x00\x00\x00");

        let nread = read_label(&buf, "fooba").unwrap();
        assert_eq!(nread, 8);
    }

    #[test]
    #[should_panic]
    fn bad_label_interior_nul() {
        // interior NULs are not allowed
        write_label("foo\x00bar", &mut [0; 1024]).unwrap();
    }

    #[test]
    fn bad_label_almost_too_long() {
        // ok
        write_label(&"z".repeat(255), &mut [0; 1024]).unwrap();
    }

    #[test]
    #[should_panic]
    fn bad_label_too_long() {
        // labels longer than 255 bytes are banned
        write_label(&"z".repeat(256), &mut [0; 1024]).unwrap();
    }

    #[test]
    fn padding() {
        // Cover four full periods starting at zero, so the `len % 4 == 0`
        // boundary (including an empty buffer) is exercised too.
        let expected = [0, 3, 2, 1];
        for len in 0..16 {
            assert_eq!(expected[len % 4], padding_len(len), "len = {}", len);
            assert_eq!(0, (len + padding_len(len)) % 4, "len = {}", len);
        }
        // The computation must not overflow for huge lengths.
        assert_eq!(1, padding_len(usize::MAX));
    }

    #[test]
    fn endian_writers() {
        let mut buf = [0u8; 16];

        LE::write_u16(0x0102, &mut buf);
        assert_eq!(&buf[..2], &[0x02u8, 0x01]);
        BE::write_u16(0x0102, &mut buf);
        assert_eq!(&buf[..2], &[0x01u8, 0x02]);

        LE::write_u32(0x0102_0304, &mut buf);
        assert_eq!(&buf[..4], &[0x04u8, 0x03, 0x02, 0x01]);
        BE::write_u32(0x0102_0304, &mut buf);
        assert_eq!(&buf[..4], &[0x01u8, 0x02, 0x03, 0x04]);

        let n: u128 = 0x0102_0304_0506_0708_090A_0B0C_0D0E_0F10;
        LE::write_u128(n, &mut buf);
        assert_eq!(buf, n.to_le_bytes());
        BE::write_u128(n, &mut buf);
        assert_eq!(buf, n.to_be_bytes());
    }

    #[test]
    fn native_endian_round_trip() {
        // Writing with `NE` and reading with the native-endian readers must
        // be inverse operations.
        let mut buf = [0u8; 16];
        NE::write_u16(0xBEEF, &mut buf);
        assert_eq!(0xBEEF, read_u16(&buf));
        NE::write_u32(0xDEAD_BEEF, &mut buf);
        assert_eq!(0xDEAD_BEEF, read_u32(&buf));
        NE::write_u128(u128::MAX - 7, &mut buf);
        assert_eq!(u128::MAX - 7, read_u128(&buf));
    }

    #[test]
    fn try_read_checks_length() {
        assert!(try_read_u16(&[0; 1], "x").is_err());
        assert!(try_read_u32(&[0; 3], "x").is_err());
        assert!(try_read_u128(&[0; 15], "x").is_err());

        let bytes = 7u32.to_ne_bytes();
        assert_eq!((7, 4), try_read_u32(&bytes, "x").unwrap());
        assert_eq!((7, 4), try_read_u32_as_usize(&bytes, "x").unwrap());
    }

    #[test]
    fn checked_arithmetic() {
        assert_eq!(12, mul(3, 4, "x").unwrap());
        assert!(mul(usize::MAX, 2, "x").is_err());
        assert_eq!(7, add(3, 4, "x").unwrap());
        assert!(add(usize::MAX, 1, "x").is_err());
        assert_eq!(8, shl(1, 3, "x").unwrap());
        // A shift amount at or beyond the bit width is always rejected.
        assert!(shl(1, usize::BITS as usize, "x").is_err());
    }
}
959