//! The compression algorithm.
//!
//! We make use of hash tables to find duplicates. This gives a reasonable compression ratio
//! at high performance. It has fixed memory usage, which, contrary to other approaches, makes
//! it less memory hungry.

use crate::block::hashtable::HashTable;
use crate::block::END_OFFSET;
use crate::block::LZ4_MIN_LENGTH;
use crate::block::MAX_DISTANCE;
use crate::block::MFLIMIT;
use crate::block::MINMATCH;
#[cfg(not(feature = "safe-encode"))]
use crate::sink::PtrSink;
use crate::sink::Sink;
use crate::sink::SliceSink;
#[allow(unused_imports)]
use alloc::vec;
use alloc::vec::Vec;

#[cfg(feature = "safe-encode")]
use core::convert::TryInto;

use super::hashtable::HashTable4K;
use super::hashtable::HashTable4KU16;
use super::{CompressError, WINDOW_SIZE};

/// Increase the step size after `1 << INCREASE_STEPSIZE_BITSHIFT` non-matches.
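/// E.g. `non_match_count` starts at `1 << 5 == 32`, so the probe step
/// (`non_match_count >> 5`) stays at 1 for the first 32 misses, becomes 2 for the
/// next 32, and so on.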
const INCREASE_STEPSIZE_BITSHIFT: usize = 5;

/// Read a 4-byte "batch" from some position.
///
/// This will read a native-endian 4-byte integer from some position.
#[inline]
#[cfg(not(feature = "safe-encode"))]
pub(super) fn get_batch(input: &[u8], n: usize) -> u32 {
    unsafe { read_u32_ptr(input.as_ptr().add(n)) }
}

#[inline]
#[cfg(feature = "safe-encode")]
pub(super) fn get_batch(input: &[u8], n: usize) -> u32 {
    u32::from_ne_bytes(input[n..n + 4].try_into().unwrap())
}
45 
46 /// Read an usize sized "batch" from some position.
47 ///
48 /// This will read a native-endian usize from some position.
49 #[inline]
50 #[allow(dead_code)]
51 #[cfg(not(feature = "safe-encode"))]
get_batch_arch(input: &[u8], n: usize) -> usize52 pub(super) fn get_batch_arch(input: &[u8], n: usize) -> usize {
53     unsafe { read_usize_ptr(input.as_ptr().add(n)) }
54 }
55 
56 #[inline]
57 #[allow(dead_code)]
58 #[cfg(feature = "safe-encode")]
get_batch_arch(input: &[u8], n: usize) -> usize59 pub(super) fn get_batch_arch(input: &[u8], n: usize) -> usize {
60     const USIZE_SIZE: usize = core::mem::size_of::<usize>();
61     let arr: &[u8; USIZE_SIZE] = input[n..n + USIZE_SIZE].try_into().unwrap();
62     usize::from_ne_bytes(*arr)
63 }
64 
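/// Builds the literal-length half of a sequence token: `lit_len` occupies the high nibble
/// and saturates at 0xF (e.g. `token_from_literal(5)` is `0x50`; any `lit_len >= 15`
/// yields `0xF0`).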
#[inline]
fn token_from_literal(lit_len: usize) -> u8 {
    if lit_len < 0xF {
        // Since we can fit the literals length into it, there is no need for saturation.
        (lit_len as u8) << 4
    } else {
        // We were unable to fit the literals length into it, so we saturate to 0xF. We will
        // later write the extended length bytes.
        0xF0
    }
}

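/// Builds a full sequence token: literal length in the high nibble, match length in the low
/// nibble, each saturating at 0xF. E.g. `token_from_literal_and_match_length(3, 20)` is
/// `0x3F`, since the match length saturates and is later written as extension bytes.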
#[inline]
fn token_from_literal_and_match_length(lit_len: usize, duplicate_length: usize) -> u8 {
    let mut token = if lit_len < 0xF {
        // Since we can fit the literals length into it, there is no need for saturation.
        (lit_len as u8) << 4
    } else {
        // We were unable to fit the literals length into it, so we saturate to 0xF. We will
        // later write the extended length bytes.
        0xF0
    };

    token |= if duplicate_length < 0xF {
        // We could fit it in.
        duplicate_length as u8
    } else {
        // We were unable to fit it in, so we default to 0xF, which will later be extended.
        0xF
    };

    token
}

/// Counts the number of identical bytes in two byte streams.
///
/// `input` is the complete input.
/// `cur` is the current position in the input; it will be incremented by the number of
/// matched bytes.
/// `source` is either the same as `input` or an external slice.
/// `candidate` is the candidate position in `source`.
///
/// The function ignores the last `END_OFFSET` bytes of `input`, as those should be literals.
#[inline]
#[cfg(feature = "safe-encode")]
fn count_same_bytes(input: &[u8], cur: &mut usize, source: &[u8], candidate: usize) -> usize {
    const USIZE_SIZE: usize = core::mem::size_of::<usize>();
    let cur_slice = &input[*cur..input.len() - END_OFFSET];
    let cand_slice = &source[candidate..];

    let mut num = 0;
    for (block1, block2) in cur_slice
        .chunks_exact(USIZE_SIZE)
        .zip(cand_slice.chunks_exact(USIZE_SIZE))
    {
        let input_block = usize::from_ne_bytes(block1.try_into().unwrap());
        let match_block = usize::from_ne_bytes(block2.try_into().unwrap());

        if input_block == match_block {
            num += USIZE_SIZE;
        } else {
            let diff = input_block ^ match_block;
            num += (diff.to_le().trailing_zeros() / 8) as usize;
            *cur += num;
            return num;
        }
    }

    // If we're here we may have 1 to 7 bytes left to check, close to the end of the input
    // or source slices. Since this is a rare occurrence, we mark it cold for ~5% better
    // performance.
    #[cold]
    fn count_same_bytes_tail(a: &[u8], b: &[u8], offset: usize) -> usize {
        a.iter()
            .zip(b)
            .skip(offset)
            .take_while(|(a, b)| a == b)
            .count()
    }
    num += count_same_bytes_tail(cur_slice, cand_slice, num);

    *cur += num;
    num
}

/// Counts the number of identical bytes in two byte streams.
///
/// `input` is the complete input.
/// `cur` is the current position in the input; it will be incremented by the number of
/// matched bytes.
/// `source` is either the same as `input` or an external slice.
/// `candidate` is the candidate position in `source`.
///
/// The function ignores the last `END_OFFSET` bytes of `input`, as those should be literals.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn count_same_bytes(input: &[u8], cur: &mut usize, source: &[u8], candidate: usize) -> usize {
    let max_input_match = input.len().saturating_sub(*cur + END_OFFSET);
    let max_candidate_match = source.len() - candidate;
    // Considering both limits, calculate how far we may match in input.
    let input_end = *cur + max_input_match.min(max_candidate_match);

    let start = *cur;
    let mut source_ptr = unsafe { source.as_ptr().add(candidate) };

    // Compare 4-byte or 8-byte blocks, depending on the arch.
    const STEP_SIZE: usize = core::mem::size_of::<usize>();
    while *cur + STEP_SIZE <= input_end {
        let diff = read_usize_ptr(unsafe { input.as_ptr().add(*cur) }) ^ read_usize_ptr(source_ptr);

        if diff == 0 {
            *cur += STEP_SIZE;
            unsafe {
                source_ptr = source_ptr.add(STEP_SIZE);
            }
        } else {
            *cur += (diff.to_le().trailing_zeros() / 8) as usize;
            return *cur - start;
        }
    }

    // Compare a 4-byte block.
    #[cfg(target_pointer_width = "64")]
    {
        if input_end - *cur >= 4 {
            let diff = read_u32_ptr(unsafe { input.as_ptr().add(*cur) }) ^ read_u32_ptr(source_ptr);

            if diff == 0 {
                *cur += 4;
                unsafe {
                    source_ptr = source_ptr.add(4);
                }
            } else {
                *cur += (diff.to_le().trailing_zeros() / 8) as usize;
                return *cur - start;
            }
        }
    }

    // Compare a 2-byte block.
    if input_end - *cur >= 2
        && unsafe { read_u16_ptr(input.as_ptr().add(*cur)) == read_u16_ptr(source_ptr) }
    {
        *cur += 2;
        unsafe {
            source_ptr = source_ptr.add(2);
        }
    }

    if *cur < input_end
        && unsafe { input.as_ptr().add(*cur).read() } == unsafe { source_ptr.read() }
    {
        *cur += 1;
    }

    *cur - start
}

/// Write an integer to the output.
///
/// Each additional byte represents a value from 0 to 255, which is added to the previous value
/// to produce the total length. When a byte has the value 255, another byte must be read and
/// added, and so on. There can be any number of bytes with value 255 following the token.
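/// E.g. `write_integer(output, 510)` emits the bytes `0xFF 0xFF 0x00` (255 + 255 + 0).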
#[inline]
#[cfg(feature = "safe-encode")]
fn write_integer(output: &mut impl Sink, mut n: usize) {
    // Note: Since `n` is usually < 0xFF and writing multiple bytes to the output
    // requires 2 branches of bounds checks (due to the possibility of add overflows),
    // the simple byte-at-a-time implementation below is faster in most cases.
    while n >= 0xFF {
        n -= 0xFF;
        push_byte(output, 0xFF);
    }
    push_byte(output, n as u8);
}

/// Write an integer to the output.
///
/// Each additional byte represents a value from 0 to 255, which is added to the previous value
/// to produce the total length. When a byte has the value 255, another byte must be read and
/// added, and so on. There can be any number of bytes with value 255 following the token.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn write_integer(output: &mut impl Sink, mut n: usize) {
    // Write the 0xFF bytes as long as the integer is higher than said value.
    if n >= 4 * 0xFF {
        // In this unlikely branch we use a fill instead of a loop,
        // otherwise rustc may output a large unrolled/vectorized loop.
        let bulk = n / (4 * 0xFF);
        n %= 4 * 0xFF;
        unsafe {
            core::ptr::write_bytes(output.pos_mut_ptr(), 0xFF, 4 * bulk);
            output.set_pos(output.pos() + 4 * bulk);
        }
    }

    // Handle the last 1 to 4 bytes.
    push_u32(output, 0xFFFFFFFF);
    // Update the output position for the remainder.
    unsafe {
        output.set_pos(output.pos() - 4 + 1 + n / 255);
        // Write the remaining byte.
        *output.pos_mut_ptr().sub(1) = (n % 255) as u8;
    }
}

/// Handle the last bytes from the input as literals.
#[cold]
fn handle_last_literals(output: &mut impl Sink, input: &[u8], start: usize) {
    let lit_len = input.len() - start;

    let token = token_from_literal(lit_len);
    push_byte(output, token);
    if lit_len >= 0xF {
        write_integer(output, lit_len - 0xF);
    }
    // Now, write the actual literals.
    output.extend_from_slice(&input[start..]);
}

/// Moves the cursors back as long as the bytes match, to find additional bytes in a duplicate.
#[inline]
#[cfg(feature = "safe-encode")]
fn backtrack_match(
    input: &[u8],
    cur: &mut usize,
    literal_start: usize,
    source: &[u8],
    candidate: &mut usize,
) {
    // Note: Even if the iterator version of this loop has fewer branches inside the loop,
    // it has more branches before the loop. In practice that seems to make it slower than
    // the while version below.
    // TODO: It should be possible to remove all bounds checks, since we are walking backwards.
    while *candidate > 0 && *cur > literal_start && input[*cur - 1] == source[*candidate - 1] {
        *cur -= 1;
        *candidate -= 1;
    }
}

/// Moves the cursors back as long as the bytes match, to find additional bytes in a duplicate.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn backtrack_match(
    input: &[u8],
    cur: &mut usize,
    literal_start: usize,
    source: &[u8],
    candidate: &mut usize,
) {
    while unsafe {
        *candidate > 0
            && *cur > literal_start
            && input.get_unchecked(*cur - 1) == source.get_unchecked(*candidate - 1)
    } {
        *cur -= 1;
        *candidate -= 1;
    }
}

/// Compress all bytes of `input[input_pos..]` into `output`.
///
/// Bytes in `input[..input_pos]` are treated as a preamble and can be used for lookback.
/// This part is known as the compressor "prefix".
/// Bytes in `ext_dict` logically precede the bytes in `input` and can also be used for lookback.
///
/// `input_stream_offset` is the logical position of the first byte of `input`. This allows the
/// same `dict` to be used for many calls to `compress_internal`, as we can "readdress" the
/// first byte of `input` to be something other than 0.
///
/// `dict` is the dictionary of previously encoded sequences.
///
/// This is used to find duplicates in the stream so they are not written multiple times.
///
/// Every four bytes are hashed, and in the resulting slot their position in the input buffer
/// is placed in the dict. This way we can easily look up candidates for back references.
///
/// Returns the number of bytes written (compressed) into `output`.
///
/// # Const parameters
/// `USE_DICT`: When `false`, disables usage of `ext_dict` (the call will panic if a non-empty
/// slice is passed). In other words, this generates more optimized code when an external
/// dictionary isn't used.
///
/// A similar const argument could be used to disable the prefix mode (e.g. `USE_PREFIX`),
/// which would impose `input_pos == 0 && input_stream_offset == 0`. Experiments didn't
/// show significant improvement, though.
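///
/// For reference, the whole-buffer, no-dictionary case used by
/// `compress_into_sink_with_dict` below boils down to
/// `compress_internal::<_, false, _>(input, 0, output, &mut dict, b"", 0)`.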
// Intentionally avoid inlining.
// Empirical tests revealed it to be rarely better but often significantly detrimental.
#[inline(never)]
pub(crate) fn compress_internal<T: HashTable, const USE_DICT: bool, S: Sink>(
    input: &[u8],
    input_pos: usize,
    output: &mut S,
    dict: &mut T,
    ext_dict: &[u8],
    input_stream_offset: usize,
) -> Result<usize, CompressError> {
    assert!(input_pos <= input.len());
    if USE_DICT {
        assert!(ext_dict.len() <= super::WINDOW_SIZE);
        assert!(ext_dict.len() <= input_stream_offset);
        // Check for overflow hazard when using ext_dict
        assert!(input_stream_offset
            .checked_add(input.len())
            .and_then(|i| i.checked_add(ext_dict.len()))
            .map_or(false, |i| i <= isize::MAX as usize));
    } else {
        assert!(ext_dict.is_empty());
    }
    if output.capacity() - output.pos() < get_maximum_output_size(input.len() - input_pos) {
        return Err(CompressError::OutputTooSmall);
    }

    let output_start_pos = output.pos();
    if input.len() - input_pos < LZ4_MIN_LENGTH {
        handle_last_literals(output, input, input_pos);
        return Ok(output.pos() - output_start_pos);
    }

    let ext_dict_stream_offset = input_stream_offset - ext_dict.len();
    let end_pos_check = input.len() - MFLIMIT;
    let mut literal_start = input_pos;
    let mut cur = input_pos;

    if cur == 0 && input_stream_offset == 0 {
        // According to the spec we can't start with a match,
        // except when referencing another block.
        let hash = T::get_hash_at(input, 0);
        dict.put_at(hash, 0);
        cur = 1;
    }

    loop {
        // Read the next block into two sections, the literals and the duplicates.
        let mut step_size;
        let mut candidate;
        let mut candidate_source;
        // The number of bytes before our cursor, where the duplicate starts.
        let mut offset;
        let mut non_match_count = 1 << INCREASE_STEPSIZE_BITSHIFT;
        let mut next_cur = cur;

        // In this loop we search for duplicates via the hashtable. Batches of 4 or 8 bytes
        // are hashed and compared.
        loop {
            step_size = non_match_count >> INCREASE_STEPSIZE_BITSHIFT;
            non_match_count += 1;

            cur = next_cur;
            next_cur += step_size;

            // Same as `cur + MFLIMIT > input.len()`
            if cur > end_pos_check {
                handle_last_literals(output, input, literal_start);
                return Ok(output.pos() - output_start_pos);
            }
            // Find a candidate in the dictionary with the hash of the current four bytes.
            // Unchecked is safe as long as the values from the hash function don't exceed the
            // size of the table. This is ensured by right shifting the hash values
            // (`dict_bitshift`) to fit them in the table.

            // [Bounds Check]: Can be elided due to `end_pos_check` above.
            let hash = T::get_hash_at(input, cur);
            candidate = dict.get_at(hash);
            dict.put_at(hash, cur + input_stream_offset);

            // Sanity check: Matches can't be ahead of `cur`.
            debug_assert!(candidate <= input_stream_offset + cur);

            // Two requirements for the candidate exist:
            // - We should not return a position which is merely a hash collision, so that the
            //   candidate actually matches what we search for.
            // - We can address up to a 16-bit offset, hence we are only able to address the
            //   candidate if its offset is less than or equal to 0xFFFF.
            if input_stream_offset + cur - candidate > MAX_DISTANCE {
                continue;
            }

            if candidate >= input_stream_offset {
                // match within input
                offset = (input_stream_offset + cur - candidate) as u16;
                candidate -= input_stream_offset;
                candidate_source = input;
            } else if USE_DICT {
                // Sanity check, which may fail if we lost history beyond MAX_DISTANCE
                debug_assert!(
                    candidate >= ext_dict_stream_offset,
                    "Lost history in ext dict mode"
                );
                // match within ext dict
                offset = (input_stream_offset + cur - candidate) as u16;
                candidate -= ext_dict_stream_offset;
                candidate_source = ext_dict;
            } else {
                // Match is not reachable anymore,
                // e.g. when compressing an independent block frame w/o clearing
                // the match tables, only increasing input_stream_offset.
                // Sanity check
                debug_assert!(input_pos == 0, "Lost history in prefix mode");
                continue;
            }
            // [Bounds Check]: The candidate is coming from the hashmap. It can't be out of
            // bounds, but that is impossible for the compiler to prove, so the bounds checks
            // remain.
            let cand_bytes: u32 = get_batch(candidate_source, candidate);
            // [Bounds Check]: Should be able to be elided due to `end_pos_check`.
            let curr_bytes: u32 = get_batch(input, cur);

            if cand_bytes == curr_bytes {
                break;
            }
        }

        // Extend the match backwards if we can.
        backtrack_match(
            input,
            &mut cur,
            literal_start,
            candidate_source,
            &mut candidate,
        );

        // The length (in bytes) of the literals section.
        let lit_len = cur - literal_start;

        // Generate the higher half of the token.
        cur += MINMATCH;
        candidate += MINMATCH;
        let duplicate_length = count_same_bytes(input, &mut cur, candidate_source, candidate);

        // Note: The `- 2` offset was copied from the reference implementation; it could be
        // arbitrary.
        let hash = T::get_hash_at(input, cur - 2);
        dict.put_at(hash, cur - 2 + input_stream_offset);

        let token = token_from_literal_and_match_length(lit_len, duplicate_length);

        // Push the token to the output stream.
        push_byte(output, token);
        // If we were unable to fit the literals length into the token, write the extension
        // bytes.
        if lit_len >= 0xF {
            write_integer(output, lit_len - 0xF);
        }

        // Now, write the actual literals.
        //
        // The unsafe version copies blocks of 8 bytes, and may therefore copy up to 7 bytes
        // more than needed. This is safe, because the last 12 bytes (MFLIMIT) are handled in
        // handle_last_literals.
        copy_literals_wild(output, input, literal_start, lit_len);
        // Write the offset in little endian.
        push_u16(output, offset);

        // If we were unable to fit the duplicates length into the token, write the
        // extension bytes.
        if duplicate_length >= 0xF {
            write_integer(output, duplicate_length - 0xF);
        }
        literal_start = cur;
    }
}

#[inline]
#[cfg(feature = "safe-encode")]
fn push_byte(output: &mut impl Sink, el: u8) {
    output.push(el);
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_byte(output: &mut impl Sink, el: u8) {
    unsafe {
        core::ptr::write(output.pos_mut_ptr(), el);
        output.set_pos(output.pos() + 1);
    }
}

#[inline]
#[cfg(feature = "safe-encode")]
fn push_u16(output: &mut impl Sink, el: u16) {
    output.extend_from_slice(&el.to_le_bytes());
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_u16(output: &mut impl Sink, el: u16) {
    unsafe {
        core::ptr::copy_nonoverlapping(el.to_le_bytes().as_ptr(), output.pos_mut_ptr(), 2);
        output.set_pos(output.pos() + 2);
    }
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_u32(output: &mut impl Sink, el: u32) {
    unsafe {
        core::ptr::copy_nonoverlapping(el.to_le_bytes().as_ptr(), output.pos_mut_ptr(), 4);
        output.set_pos(output.pos() + 4);
    }
}

#[inline(always)] // (always) necessary, otherwise the compiler fails to inline it
#[cfg(feature = "safe-encode")]
fn copy_literals_wild(output: &mut impl Sink, input: &[u8], input_start: usize, len: usize) {
    output.extend_from_slice_wild(&input[input_start..input_start + len], len)
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn copy_literals_wild(output: &mut impl Sink, input: &[u8], input_start: usize, len: usize) {
    debug_assert!(input_start + len / 8 * 8 + ((len % 8) != 0) as usize * 8 <= input.len());
    debug_assert!(output.pos() + len / 8 * 8 + ((len % 8) != 0) as usize * 8 <= output.capacity());
    unsafe {
        // Note: This used to be a wild copy loop of 8 bytes, but the compiler consistently
        // transformed it into a call to memcpy, which hurts performance significantly for
        // small copies, which are common.
        let start_ptr = input.as_ptr().add(input_start);
        match len {
            0..=8 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 8),
            9..=16 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 16),
            17..=24 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 24),
            _ => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), len),
        }
        output.set_pos(output.pos() + len);
    }
}

/// Compress all bytes of `input` into `output`.
///
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a size of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
#[inline]
pub(crate) fn compress_into_sink_with_dict<const USE_DICT: bool>(
    input: &[u8],
    output: &mut impl Sink,
    mut dict_data: &[u8],
) -> Result<usize, CompressError> {
    if dict_data.len() + input.len() < u16::MAX as usize {
        let mut dict = HashTable4KU16::new();
        init_dict(&mut dict, &mut dict_data);
        compress_internal::<_, USE_DICT, _>(input, 0, output, &mut dict, dict_data, dict_data.len())
    } else {
        let mut dict = HashTable4K::new();
        init_dict(&mut dict, &mut dict_data);
        compress_internal::<_, USE_DICT, _>(input, 0, output, &mut dict, dict_data, dict_data.len())
    }
}

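/// Seeds `dict` with positions from `dict_data`: the data is first truncated to its last
/// `WINDOW_SIZE` bytes, then every third position is hashed (the stride mirrors the
/// reference implementation).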
#[inline]
fn init_dict<T: HashTable>(dict: &mut T, dict_data: &mut &[u8]) {
    if dict_data.len() > WINDOW_SIZE {
        *dict_data = &dict_data[dict_data.len() - WINDOW_SIZE..];
    }
    let mut i = 0usize;
    while i + core::mem::size_of::<usize>() <= dict_data.len() {
        let hash = T::get_hash_at(dict_data, i);
        dict.put_at(hash, i);
        // Note: The 3 byte step was copied from the reference implementation; it could be
        // arbitrary.
        i += 3;
    }
}

/// Returns the maximum output size of the compressed data.
/// Can be used to preallocate capacity on the output vector.
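/// E.g. for a 1024-byte input this reserves `16 + 4 + 1126 = 1146` bytes.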
#[inline]
pub fn get_maximum_output_size(input_len: usize) -> usize {
    16 + 4 + (input_len as f64 * 1.1) as usize
}

/// Compress all bytes of `input` into `output`.
///
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a size of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
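///
/// # Example
/// A minimal sketch, assuming the usual `lz4_flex::block` re-exports:
/// ```
/// use lz4_flex::block::{compress_into, get_maximum_output_size};
///
/// let input = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
/// let mut buf = vec![0u8; get_maximum_output_size(input.len())];
/// let written = compress_into(input, &mut buf).unwrap();
/// buf.truncate(written);
/// ```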
#[inline]
pub fn compress_into(input: &[u8], output: &mut [u8]) -> Result<usize, CompressError> {
    compress_into_sink_with_dict::<false>(input, &mut SliceSink::new(output, 0), b"")
}

/// Compress all bytes of `input` into `output` using the given dictionary.
///
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a size of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
#[inline]
pub fn compress_into_with_dict(
    input: &[u8],
    output: &mut [u8],
    dict_data: &[u8],
) -> Result<usize, CompressError> {
    compress_into_sink_with_dict::<true>(input, &mut SliceSink::new(output, 0), dict_data)
}

#[inline]
fn compress_into_vec_with_dict<const USE_DICT: bool>(
    input: &[u8],
    prepend_size: bool,
    mut dict_data: &[u8],
) -> Vec<u8> {
    let prepend_size_num_bytes = if prepend_size { 4 } else { 0 };
    let max_compressed_size = get_maximum_output_size(input.len()) + prepend_size_num_bytes;
    if dict_data.len() <= 3 {
        dict_data = b"";
    }
    #[cfg(feature = "safe-encode")]
    let mut compressed = {
        let mut compressed: Vec<u8> = vec![0u8; max_compressed_size];
        let out = if prepend_size {
            compressed[..4].copy_from_slice(&(input.len() as u32).to_le_bytes());
            &mut compressed[4..]
        } else {
            &mut compressed
        };
        let compressed_len =
            compress_into_sink_with_dict::<USE_DICT>(input, &mut SliceSink::new(out, 0), dict_data)
                .unwrap();

        compressed.truncate(prepend_size_num_bytes + compressed_len);
        compressed
    };
    #[cfg(not(feature = "safe-encode"))]
    let mut compressed = {
        let mut vec = Vec::with_capacity(max_compressed_size);
        let start_pos = if prepend_size {
            vec.extend_from_slice(&(input.len() as u32).to_le_bytes());
            4
        } else {
            0
        };
        let compressed_len = compress_into_sink_with_dict::<USE_DICT>(
            input,
            &mut PtrSink::from_vec(&mut vec, start_pos),
            dict_data,
        )
        .unwrap();
        unsafe {
            vec.set_len(prepend_size_num_bytes + compressed_len);
        }
        vec
    };

    compressed.shrink_to_fit();
    compressed
}

/// Compress all bytes of `input`. The uncompressed size will be prepended as a
/// little-endian u32. Can be used in conjunction with `decompress_size_prepended`.
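///
/// # Example
/// A round-trip sketch, assuming the crate-root re-exports:
/// ```
/// let input = b"Hello people, what's up?";
/// let compressed = lz4_flex::compress_prepend_size(input);
/// let uncompressed = lz4_flex::decompress_size_prepended(&compressed).unwrap();
/// assert_eq!(uncompressed, input);
/// ```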
#[inline]
pub fn compress_prepend_size(input: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<false>(input, true, b"")
}

/// Compress all bytes of `input`.
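///
/// # Example
/// A round-trip sketch; `decompress` needs the uncompressed size up front (paths assume the
/// crate-root re-exports):
/// ```
/// let input = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
/// let compressed = lz4_flex::compress(input);
/// let uncompressed = lz4_flex::decompress(&compressed, input.len()).unwrap();
/// assert_eq!(uncompressed, input);
/// ```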
#[inline]
pub fn compress(input: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<false>(input, false, b"")
}

/// Compress all bytes of `input` with an external dictionary.
#[inline]
pub fn compress_with_dict(input: &[u8], ext_dict: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<true>(input, false, ext_dict)
}

/// Compress all bytes of `input` with an external dictionary. The uncompressed size will be
/// prepended as a little-endian u32. Can be used in conjunction with
/// `decompress_size_prepended_with_dict`.
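///
/// # Example
/// A sketch pairing it with the dict-aware decompressor (assuming the `lz4_flex::block`
/// re-exports):
/// ```
/// let dict = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
/// let input = b"aaaaaaaaaaaaaaaaaaaa";
/// let compressed = lz4_flex::block::compress_prepend_size_with_dict(input, dict);
/// let uncompressed =
///     lz4_flex::block::decompress_size_prepended_with_dict(&compressed, dict).unwrap();
/// assert_eq!(uncompressed, input);
/// ```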
#[inline]
pub fn compress_prepend_size_with_dict(input: &[u8], ext_dict: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<true>(input, true, ext_dict)
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_u16_ptr(input: *const u8) -> u16 {
    let mut num: u16 = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(input, &mut num as *mut u16 as *mut u8, 2);
    }
    num
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_u32_ptr(input: *const u8) -> u32 {
    let mut num: u32 = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(input, &mut num as *mut u32 as *mut u8, 4);
    }
    num
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_usize_ptr(input: *const u8) -> usize {
    let mut num: usize = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(
            input,
            &mut num as *mut usize as *mut u8,
            core::mem::size_of::<usize>(),
        );
    }
    num
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_same_bytes() {
        // 8-byte aligned block; the trailing zeros and ones are padding for END_OFFSET.
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 16);

        // 4-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 20);

        // 2-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 22);

        // 1-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 23);

        // 1-byte aligned block - last byte different
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 6, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 22);

        // 1-byte aligned block - differing in the middle
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 9, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 6, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 21);

        for diff_idx in 8..100 {
            let first: Vec<u8> = (0u8..255).cycle().take(100 + 12).collect();
            let mut second = first.clone();
            second[diff_idx] = 255;
            for start in 0..=diff_idx {
                let same_bytes = count_same_bytes(&first, &mut start.clone(), &second, start);
                assert_eq!(same_bytes, diff_idx - start);
            }
        }
    }

    #[test]
    fn test_bug() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let _out = compress(input);
    }
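
    #[test]
    fn test_roundtrip_size_prepended() {
        // A minimal round-trip sketch for the size-prepended API. The decompress
        // counterpart is assumed to live at `crate::block::decompress_size_prepended`,
        // mirroring `decompress_size_prepended_with_dict` used in `test_dict_size` below.
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let compressed = compress_prepend_size(input);
        let decompressed = crate::block::decompress_size_prepended(&compressed).unwrap();
        assert_eq!(decompressed, input);
    }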

    #[test]
    fn test_dict() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = input;
        let compressed = compress_with_dict(input, dict);
        assert_lt!(compressed.len(), compress(input).len());

        let mut uncompressed = vec![0u8; input.len()];
        let uncomp_size = crate::block::decompress::decompress_into_with_dict(
            &compressed,
            &mut uncompressed,
            dict,
        )
        .unwrap();
        uncompressed.truncate(uncomp_size);
        assert_eq!(input, uncompressed);
    }

    #[test]
    fn test_dict_no_panic() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = &[10, 12, 14];
        let _compressed = compress_with_dict(input, dict);
    }

    #[test]
    fn test_dict_match_crossing() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = input;
        let compressed = compress_with_dict(input, dict);
        assert_lt!(compressed.len(), compress(input).len());

        let mut uncompressed = vec![0u8; input.len() * 2];
        // Copy the second half of the dict into the start of the output.
        let dict_cutoff = dict.len() / 2;
        let output_start = dict.len() - dict_cutoff;
        uncompressed[..output_start].copy_from_slice(&dict[dict_cutoff..]);
        let uncomp_len = {
            let mut sink = SliceSink::new(&mut uncompressed[..], output_start);
            crate::block::decompress::decompress_internal::<true, _>(
                &compressed,
                &mut sink,
                &dict[..dict_cutoff],
            )
            .unwrap()
        };
        assert_eq!(input.len(), uncomp_len);
        assert_eq!(
            input,
            &uncompressed[output_start..output_start + uncomp_len]
        );
    }

    #[test]
    fn test_conformant_last_block() {
        // From the spec:
        // The last match must start at least 12 bytes before the end of block.
        // The last match is part of the penultimate sequence. It is followed by the last
        // sequence, which contains only literals. Note that, as a consequence, an independent
        // block < 13 bytes cannot be compressed, because the match must copy "something",
        // so it needs at least one prior byte.
        // When a block can reference data from another block, it can start immediately with a
        // match and no literal, so a block of 12 bytes can be compressed.
        let aaas: &[u8] = b"aaaaaaaaaaaaaaa";

        // uncompressible
        let out = compress(&aaas[..12]);
        assert_gt!(out.len(), 12);
        // compressible
        let out = compress(&aaas[..13]);
        assert_le!(out.len(), 13);
        let out = compress(&aaas[..14]);
        assert_le!(out.len(), 14);
        let out = compress(&aaas[..15]);
        assert_le!(out.len(), 15);

        // dict uncompressible
        let out = compress_with_dict(&aaas[..11], aaas);
        assert_gt!(out.len(), 11);
        // compressible
        let out = compress_with_dict(&aaas[..12], aaas);
        // According to the spec this _could_ compress, but it doesn't in this lib
        // as it aborts compression for any input len < LZ4_MIN_LENGTH.
        assert_gt!(out.len(), 12);
        let out = compress_with_dict(&aaas[..13], aaas);
        assert_le!(out.len(), 13);
        let out = compress_with_dict(&aaas[..14], aaas);
        assert_le!(out.len(), 14);
        let out = compress_with_dict(&aaas[..15], aaas);
        assert_le!(out.len(), 15);
    }

    #[test]
    fn test_dict_size() {
        let dict = vec![b'a'; 1024 * 1024];
        let input = &b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa"[..];
        let compressed = compress_prepend_size_with_dict(input, &dict);
        let decompressed =
            crate::block::decompress_size_prepended_with_dict(&compressed, &dict).unwrap();
        assert_eq!(decompressed, input);
    }
}