1 //! The block decompression algorithm.
2 use crate::block::{DecompressError, MINMATCH};
3 use crate::fastcpy_unsafe;
4 use crate::sink::SliceSink;
5 use crate::sink::{PtrSink, Sink};
6 use alloc::vec::Vec;
7 
/// Copies data to output_ptr by self-referential copy from start and match_length
///
/// Chooses between the fast wild copy (16-byte chunks, may over-copy) and the
/// byte-wise overlapping copy, depending on overlap and remaining output space.
///
/// # Safety
/// `start` must point into already-written output, and `*output_ptr` plus
/// `match_length` (plus up to `16 - 1` slack bytes for the wild copy, which is
/// guarded by the condition below) must stay within `output_end`.
#[inline]
unsafe fn duplicate(
    output_ptr: &mut *mut u8,
    output_end: *mut u8,
    start: *const u8,
    match_length: usize,
) {
    // We cannot simply use memcpy or `extend_from_slice`, because these do not allow
    // self-referential copies: http://ticki.github.io/img/lz4_runs_encoding_diagram.svg

    // Considering that `wild_copy_match_16` can copy up to `16 - 1` extra bytes.
    // Defer to `duplicate_overlapping` in case of an overlapping match
    // OR the if the wild copy would copy beyond the end of the output.
    if (output_ptr.offset_from(start) as usize) < match_length + 16 - 1
        || (output_end.offset_from(*output_ptr) as usize) < match_length + 16 - 1
    {
        duplicate_overlapping(output_ptr, start, match_length);
    } else {
        // match_length rounded up to the next multiple of 16 must still fit the output.
        debug_assert!(
            output_ptr.add(match_length / 16 * 16 + ((match_length % 16) != 0) as usize * 16)
                <= output_end
        );
        wild_copy_from_src_16(start, *output_ptr, match_length);
        *output_ptr = output_ptr.add(match_length);
    }
}
35 
/// Copies `num_items` bytes from `source` to `dst_ptr` in whole 16-byte chunks.
///
/// Because it always writes full chunks, it may write up to `16 - 1` bytes
/// past `dst_ptr + num_items`; the caller must guarantee that slack exists.
#[inline]
fn wild_copy_from_src_16(mut source: *const u8, mut dst_ptr: *mut u8, num_items: usize) {
    // Note: if the compiler auto-vectorizes this it'll hurt performance!
    // It's not the case for 16 bytes stepsize, but for 8 bytes.
    unsafe {
        let copy_end = dst_ptr.add(num_items);
        // Do-while shape: one 16-byte chunk is always copied, even for small
        // `num_items`, exactly like the original loop/break form.
        core::ptr::copy_nonoverlapping(source, dst_ptr, 16);
        source = source.add(16);
        dst_ptr = dst_ptr.add(16);
        while dst_ptr < copy_end {
            core::ptr::copy_nonoverlapping(source, dst_ptr, 16);
            source = source.add(16);
            dst_ptr = dst_ptr.add(16);
        }
    }
}
52 
/// Copy function, if the data start + match_length overlaps into output_ptr
///
/// Copies one byte at a time (manually unrolled by two) so that overlapping
/// regions correctly replicate earlier output bytes, as LZ4 runs require.
///
/// # Safety
/// `*output_ptr..*output_ptr + match_length.max(1)` must be valid for writes
/// (one dummy byte is written even for match_length == 0), and `start` must be
/// readable for the same number of bytes.
#[inline]
#[cfg_attr(nightly, optimize(size))] // to avoid loop unrolling
unsafe fn duplicate_overlapping(
    output_ptr: &mut *mut u8,
    mut start: *const u8,
    match_length: usize,
) {
    // There is an edge case when output_ptr == start, which causes the decoder to potentially
    // expose up to match_length bytes of uninitialized data in the decompression buffer.
    // To prevent that we write a dummy zero to output, which will zero out output in such cases.
    // This is the same strategy used by the reference C implementation https://github.com/lz4/lz4/pull/772
    output_ptr.write(0u8);
    let dst_ptr_end = output_ptr.add(match_length);

    // Copy two bytes per iteration while at least two bytes remain.
    while output_ptr.add(1) < dst_ptr_end {
        // Note that this loop unrolling is done, so that the compiler doesn't do it in a awful
        // way.
        // Without that the compiler will unroll/auto-vectorize the copy with a lot of branches.
        // This is not what we want, as large overlapping copies are not that common.
        core::ptr::copy(start, *output_ptr, 1);
        start = start.add(1);
        *output_ptr = output_ptr.add(1);

        core::ptr::copy(start, *output_ptr, 1);
        start = start.add(1);
        *output_ptr = output_ptr.add(1);
    }

    // Handle the odd trailing byte left over by the 2x-unrolled loop above.
    if *output_ptr < dst_ptr_end {
        core::ptr::copy(start, *output_ptr, 1);
        *output_ptr = output_ptr.add(1);
    }
}
87 
/// Copies the part of a match that lies inside `ext_dict` into the output and
/// returns how many bytes were taken from the dictionary. A match may straddle
/// the dict/output boundary; the caller copies any remainder from the output.
#[inline]
unsafe fn copy_from_dict(
    output_base: *mut u8,
    output_ptr: &mut *mut u8,
    ext_dict: &[u8],
    offset: usize,
    match_length: usize,
) -> usize {
    // Number of bytes already produced in the output buffer.
    let written = output_ptr.offset_from(output_base);
    // If we're here we know offset > output pos, so we have at least 1 byte to copy from dict
    debug_assert!(written >= 0);
    debug_assert!(offset > written as usize);
    // If unchecked-decode is not disabled we also know that the offset falls within ext_dict
    debug_assert!(ext_dict.len() + written as usize >= offset);

    // Translate the backwards offset into a forward index into `ext_dict`.
    let dict_offset = ext_dict.len() + written as usize - offset;
    // Can't copy past ext_dict len, the match may cross dict and output
    let dict_match_length = usize::min(match_length, ext_dict.len() - dict_offset);
    // TODO test fastcpy_unsafe
    core::ptr::copy_nonoverlapping(
        ext_dict.as_ptr().add(dict_offset),
        *output_ptr,
        dict_match_length,
    );
    *output_ptr = output_ptr.add(dict_match_length);
    dict_match_length
}
114 
115 /// Read an integer.
116 ///
117 /// In LZ4, we encode small integers in a way that we can have an arbitrary number of bytes. In
118 /// particular, we add the bytes repeatedly until we hit a non-0xFF byte. When we do, we add
119 /// this byte to our sum and terminate the loop.
120 ///
121 /// # Example
122 ///
123 /// ```notest
124 ///     255, 255, 255, 4, 2, 3, 4, 6, 7
125 /// ```
126 ///
127 /// is encoded to _255 + 255 + 255 + 4 = 769_. The bytes after the first 4 is ignored, because
128 /// 4 is the first non-0xFF byte.
129 #[inline]
read_integer_ptr( input_ptr: &mut *const u8, _input_ptr_end: *const u8, ) -> Result<u32, DecompressError>130 fn read_integer_ptr(
131     input_ptr: &mut *const u8,
132     _input_ptr_end: *const u8,
133 ) -> Result<u32, DecompressError> {
134     // We start at zero and count upwards.
135     let mut n: u32 = 0;
136     // If this byte takes value 255 (the maximum value it can take), another byte is read
137     // and added to the sum. This repeats until a byte lower than 255 is read.
138     loop {
139         // We add the next byte until we get a byte which we add to the counting variable.
140 
141         #[cfg(not(feature = "unchecked-decode"))]
142         {
143             if *input_ptr >= _input_ptr_end {
144                 return Err(DecompressError::ExpectedAnotherByte);
145             }
146         }
147         let extra = unsafe { input_ptr.read() };
148         *input_ptr = unsafe { input_ptr.add(1) };
149         n += extra as u32;
150 
151         // We continue if we got 255, break otherwise.
152         if extra != 0xFF {
153             break;
154         }
155     }
156 
157     // 255, 255, 255, 8
158     // 111, 111, 111, 101
159 
160     Ok(n)
161 }
162 
/// Read a little-endian 16-bit integer from the input stream and advance the
/// pointer past it.
#[inline]
fn read_u16_ptr(input_ptr: &mut *const u8) -> u16 {
    let mut bytes = [0u8; 2];
    unsafe {
        core::ptr::copy_nonoverlapping(*input_ptr, bytes.as_mut_ptr(), 2);
        *input_ptr = input_ptr.add(2);
    }

    u16::from_le_bytes(bytes)
}
174 
// Bit masks for the two 4-bit fields of a token.
const FIT_TOKEN_MASK_LITERAL: u8 = 0b00001111;
const FIT_TOKEN_MASK_MATCH: u8 = 0b11110000;

#[test]
fn check_token() {
    assert!(!does_token_fit(15));
    assert!(does_token_fit(14));
    assert!(does_token_fit(114));
    assert!(!does_token_fit(0b11110000));
    assert!(does_token_fit(0b10110000));
}

/// The token consists of two parts, the literal length (upper 4 bits) and match_length (lower 4
/// bits) if the literal length and match_length are both below 15, we don't need to read additional
/// data, so the token does fit the metadata in a single u8.
#[inline]
fn does_token_fit(token: u8) -> bool {
    // De Morgan form of the original check: neither nibble is saturated at 15.
    (token & FIT_TOKEN_MASK_LITERAL) != FIT_TOKEN_MASK_LITERAL
        && (token & FIT_TOKEN_MASK_MATCH) != FIT_TOKEN_MASK_MATCH
}
195 
/// Decompress all bytes of `input` into `output`.
///
/// Returns the number of bytes written (decompressed) into `output`.
///
/// `USE_DICT` statically enables external-dictionary support; when it is
/// `false`, `ext_dict` must be empty so the optimizer can remove every
/// dict-related branch. `S` is the output sink abstraction.
#[inline]
pub(crate) fn decompress_internal<const USE_DICT: bool, S: Sink>(
    input: &[u8],
    output: &mut S,
    ext_dict: &[u8],
) -> Result<usize, DecompressError> {
    // Prevent segfault for empty input
    if input.is_empty() {
        return Err(DecompressError::ExpectedAnotherByte);
    }

    let ext_dict = if USE_DICT {
        ext_dict
    } else {
        // ensure optimizer knows ext_dict length is 0 if !USE_DICT
        debug_assert!(ext_dict.is_empty());
        &[]
    };
    // Raw-pointer view of the output: [output_base, output_end) spans the full
    // capacity; output_ptr is the write cursor, starting at the sink's pos.
    let output_base = unsafe { output.base_mut_ptr() };
    let output_end = unsafe { output_base.add(output.capacity()) };
    let output_start_pos_ptr = unsafe { output.base_mut_ptr().add(output.pos()) as *mut u8 };
    let mut output_ptr = output_start_pos_ptr;

    let mut input_ptr = input.as_ptr();
    let input_ptr_end = unsafe { input.as_ptr().add(input.len()) };
    let safe_distance_from_end =  (16 /* literal copy */ +  2 /* u16 match offset */ + 1 /* The next token to read (we can skip the check) */).min(input.len()) ;
    // While input_ptr <= input_ptr_safe the fast path may read without bounds checks.
    let input_ptr_safe = unsafe { input_ptr_end.sub(safe_distance_from_end) };

    let safe_output_ptr = unsafe {
        let mut output_num_safe_bytes = output
            .capacity()
            .saturating_sub(16 /* literal copy */ + 18 /* match copy */);
        if USE_DICT {
            // In the dictionary case the output pointer is moved by the match length in the dictionary.
            // This may be up to 17 bytes without exiting the loop. So we need to ensure that we have
            // at least additional 17 bytes of space left in the output buffer in the fast loop.
            output_num_safe_bytes = output_num_safe_bytes.saturating_sub(17);
        };

        output_base.add(output_num_safe_bytes)
    };

    // Exhaust the decoder by reading and decompressing all blocks until the remaining buffer is
    // empty.
    loop {
        // Read the token. The token is the first byte in a block. It is divided into two 4-bit
        // subtokens, the higher and the lower.
        // This token contains to 4-bit "fields", a higher and a lower, representing the literals'
        // length and the back reference's length, respectively.
        let token = unsafe { input_ptr.read() };
        input_ptr = unsafe { input_ptr.add(1) };

        // Checking for hot-loop.
        // In most cases the metadata does fit in a single 1byte token (statistically) and we are in
        // a safe-distance to the end. This enables some optimized handling.
        //
        // Ideally we want to check for safe output pos like: output.pos() <= safe_output_pos; But
        // that doesn't work when the safe_output_ptr is == output_ptr due to insufficient
        // capacity. So we use `<` instead of `<=`, which covers that case.
        if does_token_fit(token)
            && (input_ptr as usize) <= input_ptr_safe as usize
            && output_ptr < safe_output_ptr
        {
            // Token fits: literal length <= 14 and match length <= 14 + MINMATCH.
            let literal_length = (token >> 4) as usize;
            let mut match_length = MINMATCH + (token & 0xF) as usize;

            // output_ptr <= safe_output_ptr should guarantee we have enough space in output
            debug_assert!(
                unsafe { output_ptr.add(literal_length + match_length) } <= output_end,
                "{literal_length} + {match_length} {} wont fit ",
                literal_length + match_length
            );

            // Copy the literal
            // The literal is at max 16 bytes, and the is_safe_distance check assures
            // that we are far away enough from the end so we can safely copy 16 bytes
            unsafe {
                core::ptr::copy_nonoverlapping(input_ptr, output_ptr, 16);
                input_ptr = input_ptr.add(literal_length);
                output_ptr = output_ptr.add(literal_length);
            }

            // input_ptr <= input_ptr_safe should guarantee we have enough space in input
            debug_assert!(input_ptr_end as usize - input_ptr as usize >= 2);
            let offset = read_u16_ptr(&mut input_ptr) as usize;

            let output_len = unsafe { output_ptr.offset_from(output_base) as usize };
            // Clamp the offset to the number of bytes available so far (output + dict),
            // so start_ptr below can never point before the decompression window.
            let offset = offset.min(output_len + ext_dict.len());

            // Check if part of the match is in the external dict
            if USE_DICT && offset > output_len {
                let copied = unsafe {
                    copy_from_dict(output_base, &mut output_ptr, ext_dict, offset, match_length)
                };
                if copied == match_length {
                    continue;
                }
                // match crosses ext_dict and output
                match_length -= copied;
            }

            // Calculate the start of this duplicate segment. At this point offset was already
            // checked to be in bounds and the external dictionary copy, if any, was
            // already copied and subtracted from match_length.
            let start_ptr = unsafe { output_ptr.sub(offset) };
            debug_assert!(start_ptr >= output_base);
            debug_assert!(start_ptr < output_end);
            debug_assert!(unsafe { output_end.offset_from(start_ptr) as usize } >= match_length);

            // In this branch we know that match_length is at most 18 (14 + MINMATCH).
            // But the blocks can overlap, so make sure they are at least 18 bytes apart
            // to enable an optimized copy of 18 bytes.
            if offset >= match_length {
                unsafe {
                    // _copy_, not copy_non_overlaping, as it may overlap.
                    // Compiles to the same assembly on x68_64.
                    core::ptr::copy(start_ptr, output_ptr, 18);
                    output_ptr = output_ptr.add(match_length);
                }
            } else {
                unsafe {
                    duplicate_overlapping(&mut output_ptr, start_ptr, match_length);
                }
            }

            continue;
        }

        // Slow path: extended lengths and/or near the end of input/output.
        // Now, we read the literals section.
        // Literal Section
        // If the initial value is 15, it is indicated that another byte will be read and added to
        // it
        let mut literal_length = (token >> 4) as usize;
        if literal_length != 0 {
            if literal_length == 15 {
                // The literal_length length took the maximal value, indicating that there is more
                // than 15 literal_length bytes. We read the extra integer.
                literal_length += read_integer_ptr(&mut input_ptr, input_ptr_end)? as usize;
            }

            #[cfg(not(feature = "unchecked-decode"))]
            {
                // Check if literal is out of bounds for the input, and if there is enough space on
                // the output
                if literal_length > input_ptr_end as usize - input_ptr as usize {
                    return Err(DecompressError::LiteralOutOfBounds);
                }
                if literal_length > unsafe { output_end.offset_from(output_ptr) as usize } {
                    return Err(DecompressError::OutputTooSmall {
                        expected: unsafe { output_ptr.offset_from(output_base) as usize }
                            + literal_length,
                        actual: output.capacity(),
                    });
                }
            }
            unsafe {
                fastcpy_unsafe::slice_copy(input_ptr, output_ptr, literal_length);
                output_ptr = output_ptr.add(literal_length);
                input_ptr = input_ptr.add(literal_length);
            }
        }

        // If the input stream is emptied, we break out of the loop. This is only the case
        // in the end of the stream, since the block is intact otherwise.
        if input_ptr >= input_ptr_end {
            break;
        }

        // Read duplicate section
        #[cfg(not(feature = "unchecked-decode"))]
        {
            if (input_ptr_end as usize) - (input_ptr as usize) < 2 {
                return Err(DecompressError::ExpectedAnotherByte);
            }
        }
        let offset = read_u16_ptr(&mut input_ptr) as usize;
        // Obtain the initial match length. The match length is the length of the duplicate segment
        // which will later be copied from data previously decompressed into the output buffer. The
        // initial length is derived from the second part of the token (the lower nibble), we read
        // earlier. Since having a match length of less than 4 would mean negative compression
        // ratio, we start at 4 (MINMATCH).

        // The initial match length can maximally be 19 (MINMATCH + 15). As with the literal length,
        // this indicates that there are more bytes to read.
        let mut match_length = MINMATCH + (token & 0xF) as usize;
        if match_length == MINMATCH + 15 {
            // The match length took the maximal value, indicating that there is more bytes. We
            // read the extra integer.
            match_length += read_integer_ptr(&mut input_ptr, input_ptr_end)? as usize;
        }

        // We now copy from the already decompressed buffer. This allows us for storing duplicates
        // by simply referencing the other location.
        let output_len = unsafe { output_ptr.offset_from(output_base) as usize };

        // We'll do a bounds check except unchecked-decode is enabled.
        #[cfg(not(feature = "unchecked-decode"))]
        {
            if offset > output_len + ext_dict.len() {
                return Err(DecompressError::OffsetOutOfBounds);
            }
            if match_length > unsafe { output_end.offset_from(output_ptr) as usize } {
                return Err(DecompressError::OutputTooSmall {
                    expected: output_len + match_length,
                    actual: output.capacity(),
                });
            }
        }

        if USE_DICT && offset > output_len {
            let copied = unsafe {
                copy_from_dict(output_base, &mut output_ptr, ext_dict, offset, match_length)
            };
            if copied == match_length {
                #[cfg(not(feature = "unchecked-decode"))]
                {
                    if input_ptr >= input_ptr_end {
                        return Err(DecompressError::ExpectedAnotherByte);
                    }
                }

                continue;
            }
            // match crosses ext_dict and output
            match_length -= copied;
        }

        // Calculate the start of this duplicate segment. At this point offset was already checked
        // to be in bounds and the external dictionary copy, if any, was already copied and
        // subtracted from match_length.
        let start_ptr = unsafe { output_ptr.sub(offset) };
        debug_assert!(start_ptr >= output_base);
        debug_assert!(start_ptr < output_end);
        debug_assert!(unsafe { output_end.offset_from(start_ptr) as usize } >= match_length);
        unsafe {
            duplicate(&mut output_ptr, output_end, start_ptr, match_length);
        }
        #[cfg(not(feature = "unchecked-decode"))]
        {
            if input_ptr >= input_ptr_end {
                return Err(DecompressError::ExpectedAnotherByte);
            }
        }
    }
    unsafe {
        // NOTE(review): output_ptr only ever advances from output_start_pos_ptr
        // within the output buffer, so both offsets are non-negative.
        output.set_pos(output_ptr.offset_from(output_base) as usize);
        Ok(output_ptr.offset_from(output_start_pos_ptr) as usize)
    }
}
448 
/// Decompress all bytes of `input` into `output`.
/// `output` should be preallocated with a size of the uncompressed data.
///
/// Returns the number of bytes written (decompressed) into `output`.
#[inline]
pub fn decompress_into(input: &[u8], output: &mut [u8]) -> Result<usize, DecompressError> {
    // No dictionary: `USE_DICT = false` lets the compiler drop the dict branches.
    decompress_internal::<false, _>(input, &mut SliceSink::new(output, 0), b"")
}
455 
/// Decompress all bytes of `input` into `output`, resolving back-references
/// that reach beyond the output's start against `ext_dict`.
///
/// Returns the number of bytes written (decompressed) into `output`.
#[inline]
pub fn decompress_into_with_dict(
    input: &[u8],
    output: &mut [u8],
    ext_dict: &[u8],
) -> Result<usize, DecompressError> {
    decompress_internal::<true, _>(input, &mut SliceSink::new(output, 0), ext_dict)
}
467 
/// Decompress all bytes of `input` into a new vec, resolving back-references
/// against `ext_dict` where they reach before the output's start.
/// The passed parameter `min_uncompressed_size` needs to be equal or larger than the uncompressed size.
///
/// # Panics
/// May panic if the parameter `min_uncompressed_size` is smaller than the
/// uncompressed data.

#[inline]
pub fn decompress_with_dict(
    input: &[u8],
    min_uncompressed_size: usize,
    ext_dict: &[u8],
) -> Result<Vec<u8>, DecompressError> {
    // Allocate a vector to contain the decompressed stream.
    let mut vec = Vec::with_capacity(min_uncompressed_size);
    let decomp_len =
        decompress_internal::<true, _>(input, &mut PtrSink::from_vec(&mut vec, 0), ext_dict)?;
    // SAFETY(review): `decompress_internal` reports `decomp_len` bytes written
    // into the vec's buffer, so those bytes are initialized before set_len.
    unsafe {
        vec.set_len(decomp_len);
    }
    Ok(vec)
}
490 
/// Decompress all bytes of `input` into a new vec. The first 4 bytes are the uncompressed size in
/// little endian. Can be used in conjunction with `compress_prepend_size`
#[inline]
pub fn decompress_size_prepended(input: &[u8]) -> Result<Vec<u8>, DecompressError> {
    // Split off the 4-byte size header, then decompress the remaining payload.
    let (uncompressed_size, input) = super::uncompressed_size(input)?;
    decompress(input, uncompressed_size)
}
498 
499 /// Decompress all bytes of `input` into a new vec.
500 /// The passed parameter `min_uncompressed_size` needs to be equal or larger than the uncompressed size.
501 ///
502 /// # Panics
503 /// May panic if the parameter `min_uncompressed_size` is smaller than the
504 /// uncompressed data.
505 #[inline]
decompress(input: &[u8], min_uncompressed_size: usize) -> Result<Vec<u8>, DecompressError>506 pub fn decompress(input: &[u8], min_uncompressed_size: usize) -> Result<Vec<u8>, DecompressError> {
507     // Allocate a vector to contain the decompressed stream.
508     let mut vec = Vec::with_capacity(min_uncompressed_size);
509     let decomp_len =
510         decompress_internal::<true, _>(input, &mut PtrSink::from_vec(&mut vec, 0), b"")?;
511     unsafe {
512         vec.set_len(decomp_len);
513     }
514     Ok(vec)
515 }
516 
517 /// Decompress all bytes of `input` into a new vec. The first 4 bytes are the uncompressed size in
518 /// little endian. Can be used in conjunction with `compress_prepend_size_with_dict`
519 #[inline]
decompress_size_prepended_with_dict( input: &[u8], ext_dict: &[u8], ) -> Result<Vec<u8>, DecompressError>520 pub fn decompress_size_prepended_with_dict(
521     input: &[u8],
522     ext_dict: &[u8],
523 ) -> Result<Vec<u8>, DecompressError> {
524     let (uncompressed_size, input) = super::uncompressed_size(input)?;
525     decompress_with_dict(input, uncompressed_size, ext_dict)
526 }
527 
#[cfg(test)]
mod test {
    use super::*;

    // Smoke test: token 0x30 = 3 literals, no match; output is the literals.
    #[test]
    fn all_literal() {
        assert_eq!(decompress(&[0x30, b'a', b'4', b'9'], 3).unwrap(), b"a49");
    }

    // this error test is only valid with checked-decode.
    #[cfg(not(feature = "unchecked-decode"))]
    #[test]
    fn offset_oob() {
        // Offsets that reach before the start of the produced output must
        // surface as an error instead of reading out of bounds.
        decompress(&[0x10, b'a', 2, 0], 4).unwrap_err();
        decompress(&[0x40, b'a', 1, 0], 4).unwrap_err();
    }
}
545