xref: /aosp_15_r20/external/zstd/doc/educational_decoder/zstd_decompress.c (revision 01826a4963a0d8a59bc3812d29bdf0fb76416722)
1*01826a49SYabin Cui /*
2*01826a49SYabin Cui  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*01826a49SYabin Cui  * All rights reserved.
4*01826a49SYabin Cui  *
5*01826a49SYabin Cui  * This source code is licensed under both the BSD-style license (found in the
6*01826a49SYabin Cui  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7*01826a49SYabin Cui  * in the COPYING file in the root directory of this source tree).
8*01826a49SYabin Cui  * You may select, at your option, one of the above-listed licenses.
9*01826a49SYabin Cui  */
10*01826a49SYabin Cui 
11*01826a49SYabin Cui /// Zstandard educational decoder implementation
12*01826a49SYabin Cui /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
13*01826a49SYabin Cui 
14*01826a49SYabin Cui #include <stdint.h>   // uint8_t, etc.
15*01826a49SYabin Cui #include <stdlib.h>   // malloc, free, exit
16*01826a49SYabin Cui #include <stdio.h>    // fprintf
17*01826a49SYabin Cui #include <string.h>   // memset, memcpy
18*01826a49SYabin Cui #include "zstd_decompress.h"
19*01826a49SYabin Cui 
20*01826a49SYabin Cui 
21*01826a49SYabin Cui /******* IMPORTANT CONSTANTS *********************************************/
22*01826a49SYabin Cui 
23*01826a49SYabin Cui // Zstandard frame
24*01826a49SYabin Cui // "Magic_Number
25*01826a49SYabin Cui // 4 Bytes, little-endian format. Value : 0xFD2FB528"
26*01826a49SYabin Cui #define ZSTD_MAGIC_NUMBER 0xFD2FB528U
27*01826a49SYabin Cui 
28*01826a49SYabin Cui // The size of `Block_Content` is limited by `Block_Maximum_Size` (at most 128 KB).
29*01826a49SYabin Cui #define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024)
30*01826a49SYabin Cui 
31*01826a49SYabin Cui // literal blocks can't be larger than their block
32*01826a49SYabin Cui #define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX
33*01826a49SYabin Cui 
34*01826a49SYabin Cui 
35*01826a49SYabin Cui /******* UTILITY MACROS AND TYPES *********************************************/
36*01826a49SYabin Cui #define MAX(a, b) ((a) > (b) ? (a) : (b))
37*01826a49SYabin Cui #define MIN(a, b) ((a) < (b) ? (a) : (b))
38*01826a49SYabin Cui 
39*01826a49SYabin Cui #if defined(ZDEC_NO_MESSAGE)
40*01826a49SYabin Cui #define MESSAGE(...)
41*01826a49SYabin Cui #else
42*01826a49SYabin Cui #define MESSAGE(...)  fprintf(stderr, "" __VA_ARGS__)
43*01826a49SYabin Cui #endif
44*01826a49SYabin Cui 
45*01826a49SYabin Cui /// This decoder calls exit(1) when it encounters an error; a production
46*01826a49SYabin Cui /// library should propagate error codes instead.
47*01826a49SYabin Cui #define ERROR(s)                                                               \
48*01826a49SYabin Cui     do {                                                                       \
49*01826a49SYabin Cui         MESSAGE("Error: %s\n", s);                                     \
50*01826a49SYabin Cui         exit(1);                                                               \
51*01826a49SYabin Cui     } while (0)
52*01826a49SYabin Cui #define INP_SIZE()                                                             \
53*01826a49SYabin Cui     ERROR("Input buffer smaller than it should be or input is "                \
54*01826a49SYabin Cui           "corrupted")
55*01826a49SYabin Cui #define OUT_SIZE() ERROR("Output buffer too small for output")
56*01826a49SYabin Cui #define CORRUPTION() ERROR("Corruption detected while decompressing")
57*01826a49SYabin Cui #define BAD_ALLOC() ERROR("Memory allocation error")
58*01826a49SYabin Cui #define IMPOSSIBLE() ERROR("An impossibility has occurred")
59*01826a49SYabin Cui 
60*01826a49SYabin Cui typedef uint8_t  u8;
61*01826a49SYabin Cui typedef uint16_t u16;
62*01826a49SYabin Cui typedef uint32_t u32;
63*01826a49SYabin Cui typedef uint64_t u64;
64*01826a49SYabin Cui 
65*01826a49SYabin Cui typedef int8_t  i8;
66*01826a49SYabin Cui typedef int16_t i16;
67*01826a49SYabin Cui typedef int32_t i32;
68*01826a49SYabin Cui typedef int64_t i64;
69*01826a49SYabin Cui /******* END UTILITY MACROS AND TYPES *****************************************/
70*01826a49SYabin Cui 
71*01826a49SYabin Cui /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
72*01826a49SYabin Cui /// The implementations for these functions can be found at the bottom of this
73*01826a49SYabin Cui /// file.  They implement low-level functionality needed for the higher level
74*01826a49SYabin Cui /// decompression functions.
75*01826a49SYabin Cui 
76*01826a49SYabin Cui /*** IO STREAM OPERATIONS *************/
77*01826a49SYabin Cui 
78*01826a49SYabin Cui /// ostream_t/istream_t are used to wrap the pointers/length data passed into
79*01826a49SYabin Cui /// ZSTD_decompress, so that all IO operations are safely bounds checked
80*01826a49SYabin Cui /// They are written/read forward, and reads are treated as little-endian
81*01826a49SYabin Cui /// They should be used opaquely to ensure safety
82*01826a49SYabin Cui typedef struct {
    // Write cursor: ZSTD_decompress_with_dict reports the decompressed size as
    // the distance between the caller's `dst` and this pointer.
83*01826a49SYabin Cui     u8 *ptr;
    // Capacity in bytes; decode_data_frame compares the frame's declared
    // content size against it before decoding.
    // NOTE(review): presumably shrinks as `ptr` advances -- confirm against
    // IO_get_write_ptr / IO_write_byte.
84*01826a49SYabin Cui     size_t len;
85*01826a49SYabin Cui } ostream_t;
86*01826a49SYabin Cui 
87*01826a49SYabin Cui typedef struct {
    // NOTE(review): presumably the read cursor and the number of bytes left
    // from it, mirroring ostream_t -- confirm against IO_get_read_ptr and
    // IO_istream_len.
88*01826a49SYabin Cui     const u8 *ptr;
89*01826a49SYabin Cui     size_t len;
90*01826a49SYabin Cui 
91*01826a49SYabin Cui     // Input often reads a few bits at a time, so maintain an internal offset
    // NOTE(review): presumably the bit position inside the byte at `ptr` --
    // confirm against IO_read_bits.
92*01826a49SYabin Cui     int bit_offset;
93*01826a49SYabin Cui } istream_t;
94*01826a49SYabin Cui 
95*01826a49SYabin Cui /// The following two functions are the only ones that allow the istream to be
96*01826a49SYabin Cui /// non-byte aligned
97*01826a49SYabin Cui 
98*01826a49SYabin Cui /// Reads `num_bits` bits from a bitstream, and updates the internal offset
99*01826a49SYabin Cui static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
100*01826a49SYabin Cui /// Backs up the stream by `num_bits` bits so they can be read again
101*01826a49SYabin Cui static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
102*01826a49SYabin Cui /// If the remaining bits in a byte will be unused, advance to the end of the
103*01826a49SYabin Cui /// byte
104*01826a49SYabin Cui static inline void IO_align_stream(istream_t *const in);
105*01826a49SYabin Cui 
106*01826a49SYabin Cui /// Write the given byte into the output stream
107*01826a49SYabin Cui static inline void IO_write_byte(ostream_t *const out, u8 symb);
108*01826a49SYabin Cui 
109*01826a49SYabin Cui /// Returns the number of bytes left to be read in this stream.  The stream must
110*01826a49SYabin Cui /// be byte aligned.
111*01826a49SYabin Cui static inline size_t IO_istream_len(const istream_t *const in);
112*01826a49SYabin Cui 
113*01826a49SYabin Cui /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
114*01826a49SYabin Cui /// was skipped.  The stream must be byte aligned.
115*01826a49SYabin Cui static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
116*01826a49SYabin Cui /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
117*01826a49SYabin Cui /// was skipped so it can be written to.
118*01826a49SYabin Cui static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
119*01826a49SYabin Cui 
120*01826a49SYabin Cui /// Advance the inner state by `len` bytes.  The stream must be byte aligned.
121*01826a49SYabin Cui static inline void IO_advance_input(istream_t *const in, size_t len);
122*01826a49SYabin Cui 
123*01826a49SYabin Cui /// Returns an `ostream_t` constructed from the given pointer and length.
124*01826a49SYabin Cui static inline ostream_t IO_make_ostream(u8 *out, size_t len);
125*01826a49SYabin Cui /// Returns an `istream_t` constructed from the given pointer and length.
126*01826a49SYabin Cui static inline istream_t IO_make_istream(const u8 *in, size_t len);
127*01826a49SYabin Cui 
128*01826a49SYabin Cui /// Returns an `istream_t` with the same base as `in`, and length `len`.
129*01826a49SYabin Cui /// Then, advance `in` to account for the consumed bytes.
130*01826a49SYabin Cui /// `in` must be byte aligned.
131*01826a49SYabin Cui static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
132*01826a49SYabin Cui /*** END IO STREAM OPERATIONS *********/
133*01826a49SYabin Cui 
134*01826a49SYabin Cui /*** BITSTREAM OPERATIONS *************/
135*01826a49SYabin Cui /// Read `num_bits` bits (up to 64) from `src + offset`, where `offset` is in bits,
136*01826a49SYabin Cui /// and return them interpreted as a little-endian unsigned integer.
137*01826a49SYabin Cui static inline u64 read_bits_LE(const u8 *src, const int num_bits,
138*01826a49SYabin Cui                                const size_t offset);
139*01826a49SYabin Cui 
140*01826a49SYabin Cui /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
141*01826a49SYabin Cui /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
142*01826a49SYabin Cui /// `src + offset`.  If the offset becomes negative, the extra bits at the
143*01826a49SYabin Cui /// bottom are filled in with `0` bits instead of reading from before `src`.
144*01826a49SYabin Cui static inline u64 STREAM_read_bits(const u8 *src, const int bits,
145*01826a49SYabin Cui                                    i64 *const offset);
146*01826a49SYabin Cui /*** END BITSTREAM OPERATIONS *********/
147*01826a49SYabin Cui 
148*01826a49SYabin Cui /*** BIT COUNTING OPERATIONS **********/
149*01826a49SYabin Cui /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
150*01826a49SYabin Cui static inline int highest_set_bit(const u64 num);
151*01826a49SYabin Cui /*** END BIT COUNTING OPERATIONS ******/
152*01826a49SYabin Cui 
153*01826a49SYabin Cui /*** HUFFMAN PRIMITIVES ***************/
154*01826a49SYabin Cui // Table decode method uses exponential memory, so we need to limit depth
155*01826a49SYabin Cui #define HUF_MAX_BITS (16)
156*01826a49SYabin Cui 
157*01826a49SYabin Cui // Limit the maximum number of symbols to 256 so we can store a symbol in a byte
158*01826a49SYabin Cui #define HUF_MAX_SYMBS (256)
159*01826a49SYabin Cui 
160*01826a49SYabin Cui /// Structure containing all tables necessary for efficient Huffman decoding
161*01826a49SYabin Cui typedef struct {
    // NOTE(review): `symbols` and `num_bits` are the heap-allocated parts that
    // HUF_free_dtable is described as releasing; presumably parallel lookup
    // tables giving the decoded symbol and the bits consumed for each decoder
    // state -- confirm against HUF_init_dtable.
162*01826a49SYabin Cui     u8 *symbols;
163*01826a49SYabin Cui     u8 *num_bits;
    // NOTE(review): presumably the longest code length in this table, bounded
    // by HUF_MAX_BITS -- confirm against HUF_init_dtable.
164*01826a49SYabin Cui     int max_bits;
165*01826a49SYabin Cui } HUF_dtable;
166*01826a49SYabin Cui 
167*01826a49SYabin Cui /// Decode a single symbol and read in enough bits to refresh the state
168*01826a49SYabin Cui static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
169*01826a49SYabin Cui                                    u16 *const state, const u8 *const src,
170*01826a49SYabin Cui                                    i64 *const offset);
171*01826a49SYabin Cui /// Read in a full state's worth of bits to initialize it
172*01826a49SYabin Cui static inline void HUF_init_state(const HUF_dtable *const dtable,
173*01826a49SYabin Cui                                   u16 *const state, const u8 *const src,
174*01826a49SYabin Cui                                   i64 *const offset);
175*01826a49SYabin Cui 
176*01826a49SYabin Cui /// Decompresses a single Huffman stream, returns the number of bytes decoded.
177*01826a49SYabin Cui /// The length of `in` must be the exact length of the Huffman-coded block.
178*01826a49SYabin Cui static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
179*01826a49SYabin Cui                                      ostream_t *const out, istream_t *const in);
180*01826a49SYabin Cui /// Same as previous but decodes 4 streams, formatted as in the Zstandard
181*01826a49SYabin Cui /// specification.
182*01826a49SYabin Cui /// The length of `in` must be the exact length of the Huffman-coded block.
183*01826a49SYabin Cui static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
184*01826a49SYabin Cui                                      ostream_t *const out, istream_t *const in);
185*01826a49SYabin Cui 
186*01826a49SYabin Cui /// Initialize a Huffman decoding table using the table of bit counts provided
187*01826a49SYabin Cui static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
188*01826a49SYabin Cui                             const int num_symbs);
189*01826a49SYabin Cui /// Initialize a Huffman decoding table using the table of weights provided
190*01826a49SYabin Cui /// Weights follow the definition provided in the Zstandard specification
191*01826a49SYabin Cui static void HUF_init_dtable_usingweights(HUF_dtable *const table,
192*01826a49SYabin Cui                                          const u8 *const weights,
193*01826a49SYabin Cui                                          const int num_symbs);
194*01826a49SYabin Cui 
195*01826a49SYabin Cui /// Free the malloc'ed parts of a decoding table
196*01826a49SYabin Cui static void HUF_free_dtable(HUF_dtable *const dtable);
197*01826a49SYabin Cui /*** END HUFFMAN PRIMITIVES ***********/
198*01826a49SYabin Cui 
199*01826a49SYabin Cui /*** FSE PRIMITIVES *******************/
200*01826a49SYabin Cui /// For more description of FSE see
201*01826a49SYabin Cui /// https://github.com/Cyan4973/FiniteStateEntropy/
202*01826a49SYabin Cui 
203*01826a49SYabin Cui // FSE table decoding uses exponential memory, so limit the maximum accuracy
204*01826a49SYabin Cui #define FSE_MAX_ACCURACY_LOG (15)
205*01826a49SYabin Cui // Limit the maximum number of symbols so they can be stored in a single byte
206*01826a49SYabin Cui #define FSE_MAX_SYMBS (256)
207*01826a49SYabin Cui 
208*01826a49SYabin Cui /// The tables needed to decode FSE encoded streams
209*01826a49SYabin Cui typedef struct {
    // NOTE(review): the three arrays are the heap-allocated parts that
    // FSE_free_dtable is described as releasing; presumably parallel tables
    // indexed by state (symbol for the state, bits to read, and base of the
    // next state) -- confirm against FSE_init_dtable.
210*01826a49SYabin Cui     u8 *symbols;
211*01826a49SYabin Cui     u8 *num_bits;
212*01826a49SYabin Cui     u16 *new_state_base;
    // NOTE(review): presumably log2 of the table size, bounded by
    // FSE_MAX_ACCURACY_LOG -- confirm against FSE_init_dtable.
213*01826a49SYabin Cui     int accuracy_log;
214*01826a49SYabin Cui } FSE_dtable;
215*01826a49SYabin Cui 
216*01826a49SYabin Cui /// Return the symbol for the current state
217*01826a49SYabin Cui static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
218*01826a49SYabin Cui                                  const u16 state);
219*01826a49SYabin Cui /// Read the number of bits necessary to update state, update, and shift offset
220*01826a49SYabin Cui /// back to reflect the bits read
221*01826a49SYabin Cui static inline void FSE_update_state(const FSE_dtable *const dtable,
222*01826a49SYabin Cui                                     u16 *const state, const u8 *const src,
223*01826a49SYabin Cui                                     i64 *const offset);
224*01826a49SYabin Cui 
225*01826a49SYabin Cui /// Combine peek and update: decode a symbol and update the state
226*01826a49SYabin Cui static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
227*01826a49SYabin Cui                                    u16 *const state, const u8 *const src,
228*01826a49SYabin Cui                                    i64 *const offset);
229*01826a49SYabin Cui 
230*01826a49SYabin Cui /// Read bits from the stream to initialize the state and shift offset back
231*01826a49SYabin Cui static inline void FSE_init_state(const FSE_dtable *const dtable,
232*01826a49SYabin Cui                                   u16 *const state, const u8 *const src,
233*01826a49SYabin Cui                                   i64 *const offset);
234*01826a49SYabin Cui 
235*01826a49SYabin Cui /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
236*01826a49SYabin Cui /// using an FSE decoding table.  The length of `in` must be the exact length
237*01826a49SYabin Cui /// of the block.
238*01826a49SYabin Cui static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
239*01826a49SYabin Cui                                           ostream_t *const out,
240*01826a49SYabin Cui                                           istream_t *const in);
241*01826a49SYabin Cui 
242*01826a49SYabin Cui /// Initialize a decoding table using normalized frequencies.
243*01826a49SYabin Cui static void FSE_init_dtable(FSE_dtable *const dtable,
244*01826a49SYabin Cui                             const i16 *const norm_freqs, const int num_symbs,
245*01826a49SYabin Cui                             const int accuracy_log);
246*01826a49SYabin Cui 
247*01826a49SYabin Cui /// Decode an FSE header as defined in the Zstandard format specification and
248*01826a49SYabin Cui /// use the decoded frequencies to initialize a decoding table.
249*01826a49SYabin Cui static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
250*01826a49SYabin Cui                                 const int max_accuracy_log);
251*01826a49SYabin Cui 
252*01826a49SYabin Cui /// Initialize an FSE table that will always return the same symbol and consume
253*01826a49SYabin Cui /// 0 bits per symbol, to be used for RLE mode in sequence commands
254*01826a49SYabin Cui static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
255*01826a49SYabin Cui 
256*01826a49SYabin Cui /// Free the malloc'ed parts of a decoding table
257*01826a49SYabin Cui static void FSE_free_dtable(FSE_dtable *const dtable);
258*01826a49SYabin Cui /*** END FSE PRIMITIVES ***************/
259*01826a49SYabin Cui 
260*01826a49SYabin Cui /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
261*01826a49SYabin Cui 
262*01826a49SYabin Cui /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
263*01826a49SYabin Cui 
264*01826a49SYabin Cui /// A small structure that can be reused in various places that need to access
265*01826a49SYabin Cui /// frame header information
266*01826a49SYabin Cui typedef struct {
267*01826a49SYabin Cui     // The size of window that we need to be able to contiguously store for
268*01826a49SYabin Cui     // references
269*01826a49SYabin Cui     size_t window_size;
270*01826a49SYabin Cui     // The total output size of this compressed frame
    // decode_data_frame rejects the frame up front when this is non-zero and
    // larger than the output capacity; a value of 0 skips that check.
271*01826a49SYabin Cui     size_t frame_content_size;
272*01826a49SYabin Cui 
273*01826a49SYabin Cui     // The dictionary id if this frame uses one
274*01826a49SYabin Cui     u32 dictionary_id;
275*01826a49SYabin Cui 
276*01826a49SYabin Cui     // Whether or not the content of this frame has a checksum
    // (bit 2 of the Frame_Header_Descriptor, as read by parse_frame_header)
277*01826a49SYabin Cui     int content_checksum_flag;
278*01826a49SYabin Cui     // Whether or not the output for this frame is in a single segment
    // (bit 5 of the Frame_Header_Descriptor, as read by parse_frame_header)
279*01826a49SYabin Cui     int single_segment_flag;
280*01826a49SYabin Cui } frame_header_t;
281*01826a49SYabin Cui 
282*01826a49SYabin Cui /// The context needed to decode blocks in a frame
283*01826a49SYabin Cui typedef struct {
    // Filled in by parse_frame_header during init_frame_context
284*01826a49SYabin Cui     frame_header_t header;
285*01826a49SYabin Cui 
286*01826a49SYabin Cui     // The total amount of data available for backreferences, to determine if an
287*01826a49SYabin Cui     // offset is too large to be correct
288*01826a49SYabin Cui     size_t current_total_output;
289*01826a49SYabin Cui 
    // NOTE(review): presumably the raw dictionary content installed by
    // frame_context_apply_dict; both stay zero after init_frame_context's
    // memset unless something sets them -- confirm.
290*01826a49SYabin Cui     const u8 *dict_content;
291*01826a49SYabin Cui     size_t dict_content_len;
292*01826a49SYabin Cui 
293*01826a49SYabin Cui     // Entropy encoding tables so they can be repeated by future blocks instead
294*01826a49SYabin Cui     // of retransmitting
    // All four are released by free_frame_context.
295*01826a49SYabin Cui     HUF_dtable literals_dtable;
296*01826a49SYabin Cui     FSE_dtable ll_dtable;
297*01826a49SYabin Cui     FSE_dtable ml_dtable;
298*01826a49SYabin Cui     FSE_dtable of_dtable;
299*01826a49SYabin Cui 
300*01826a49SYabin Cui     // The last 3 offsets for the special "repeat offsets".
    // init_frame_context seeds them with 1, 4, 8 before the dictionary (if
    // any) is applied.
301*01826a49SYabin Cui     u64 previous_offsets[3];
302*01826a49SYabin Cui } frame_context_t;
303*01826a49SYabin Cui 
304*01826a49SYabin Cui /// The decoded contents of a dictionary so that it doesn't have to be repeated
305*01826a49SYabin Cui /// for each frame that uses it
/// NOTE(review): `dictionary_t`, the name used by the functions in this file,
/// is presumably a typedef of this struct declared in zstd_decompress.h --
/// confirm.
306*01826a49SYabin Cui struct dictionary_s {
307*01826a49SYabin Cui     // Entropy tables
308*01826a49SYabin Cui     HUF_dtable literals_dtable;
309*01826a49SYabin Cui     FSE_dtable ll_dtable;
310*01826a49SYabin Cui     FSE_dtable ml_dtable;
311*01826a49SYabin Cui     FSE_dtable of_dtable;
312*01826a49SYabin Cui 
313*01826a49SYabin Cui     // Raw content for backreferences
314*01826a49SYabin Cui     u8 *content;
315*01826a49SYabin Cui     size_t content_size;
316*01826a49SYabin Cui 
317*01826a49SYabin Cui     // Offset history to prepopulate the frame's history
318*01826a49SYabin Cui     u64 previous_offsets[3];
319*01826a49SYabin Cui 
    // NOTE(review): presumably matched against frame_header_t.dictionary_id
    // when the dictionary is applied -- confirm against
    // frame_context_apply_dict.
320*01826a49SYabin Cui     u32 dictionary_id;
321*01826a49SYabin Cui };
322*01826a49SYabin Cui 
323*01826a49SYabin Cui /// A tuple containing the parts necessary to decode and execute a ZSTD sequence
324*01826a49SYabin Cui /// command
325*01826a49SYabin Cui typedef struct {
    // NOTE(review): presumably the number of literal bytes to emit before the
    // match -- confirm against execute_sequences.
326*01826a49SYabin Cui     u32 literal_length;
    // NOTE(review): presumably the number of bytes the match copies --
    // confirm against execute_sequences.
327*01826a49SYabin Cui     u32 match_length;
    // An offset code rather than a plain distance: per compute_offset's
    // description it is either an actual offset value or an index selecting a
    // previous offset.
328*01826a49SYabin Cui     u32 offset;
329*01826a49SYabin Cui } sequence_command_t;
330*01826a49SYabin Cui 
331*01826a49SYabin Cui /// The decoder works top-down, starting at the high level like Zstd frames, and
332*01826a49SYabin Cui /// working down to lower, more technical levels such as blocks, literals, and
333*01826a49SYabin Cui /// sequences.  The high-level functions roughly follow the outline of the
334*01826a49SYabin Cui /// format specification:
335*01826a49SYabin Cui /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
336*01826a49SYabin Cui 
337*01826a49SYabin Cui /// Before the implementation of each high-level function declared here, the
338*01826a49SYabin Cui /// prototypes for their helper functions are defined and explained
339*01826a49SYabin Cui 
340*01826a49SYabin Cui /// Decode a single Zstd frame, or error if the input is not a valid frame.
341*01826a49SYabin Cui /// Accepts a dict argument, which may be NULL indicating no dictionary.
342*01826a49SYabin Cui /// See
343*01826a49SYabin Cui /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
344*01826a49SYabin Cui static void decode_frame(ostream_t *const out, istream_t *const in,
345*01826a49SYabin Cui                          const dictionary_t *const dict);
346*01826a49SYabin Cui 
347*01826a49SYabin Cui // Decode data in a compressed block
348*01826a49SYabin Cui static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
349*01826a49SYabin Cui                              istream_t *const in);
350*01826a49SYabin Cui 
351*01826a49SYabin Cui // Decode the literals section of a block
352*01826a49SYabin Cui static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
353*01826a49SYabin Cui                               u8 **const literals);
354*01826a49SYabin Cui 
355*01826a49SYabin Cui // Decode the sequences part of a block
356*01826a49SYabin Cui static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
357*01826a49SYabin Cui                                sequence_command_t **const sequences);
358*01826a49SYabin Cui 
359*01826a49SYabin Cui // Execute the decoded sequences on the literals block
360*01826a49SYabin Cui static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
361*01826a49SYabin Cui                               const u8 *const literals,
362*01826a49SYabin Cui                               const size_t literals_len,
363*01826a49SYabin Cui                               const sequence_command_t *const sequences,
364*01826a49SYabin Cui                               const size_t num_sequences);
365*01826a49SYabin Cui 
366*01826a49SYabin Cui // Copies literals and returns the total literal length that was copied
367*01826a49SYabin Cui static u32 copy_literals(const size_t seq, istream_t *litstream,
368*01826a49SYabin Cui                          ostream_t *const out);
369*01826a49SYabin Cui 
370*01826a49SYabin Cui // Given an offset code from a sequence command (either an actual offset value
371*01826a49SYabin Cui // or an index for previous offset), computes the correct offset and updates
372*01826a49SYabin Cui // the offset history
373*01826a49SYabin Cui static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
374*01826a49SYabin Cui 
375*01826a49SYabin Cui // Given an offset, match length, and total output, as well as the frame
376*01826a49SYabin Cui // context for the dictionary, determines if the dictionary is used and
377*01826a49SYabin Cui // executes the copy operation
378*01826a49SYabin Cui static void execute_match_copy(frame_context_t *const ctx, size_t offset,
379*01826a49SYabin Cui                               size_t match_length, size_t total_output,
380*01826a49SYabin Cui                               ostream_t *const out);
381*01826a49SYabin Cui 
382*01826a49SYabin Cui /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
383*01826a49SYabin Cui 
ZSTD_decompress(void * const dst,const size_t dst_len,const void * const src,const size_t src_len)384*01826a49SYabin Cui size_t ZSTD_decompress(void *const dst, const size_t dst_len,
385*01826a49SYabin Cui                        const void *const src, const size_t src_len) {
386*01826a49SYabin Cui     dictionary_t* const uninit_dict = create_dictionary();
387*01826a49SYabin Cui     size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
388*01826a49SYabin Cui                                                          src_len, uninit_dict);
389*01826a49SYabin Cui     free_dictionary(uninit_dict);
390*01826a49SYabin Cui     return decomp_size;
391*01826a49SYabin Cui }
392*01826a49SYabin Cui 
/// Decompress the frame in `src` (`src_len` bytes) into `dst` (`dst_len`
/// bytes of capacity), handing `parsed_dict` to the frame decoder.
/// Returns the number of bytes written, measured as how far the output
/// cursor moved away from `dst`.
size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
                                 const void *const src, const size_t src_len,
                                 dictionary_t* parsed_dict) {
    u8 *const dst_begin = (u8 *)dst;

    // Wrap both raw buffers in stream objects before decoding
    istream_t in = IO_make_istream((const u8 *)src, src_len);
    ostream_t out = IO_make_ostream(dst_begin, dst_len);

    // The format allows several independent frames to be concatenated into
    // one file or stream, each carrying its own parameters; this decoder
    // handles exactly one frame per call.
    decode_frame(&out, &in, parsed_dict);

    return (size_t)(out.ptr - dst_begin);
}
410*01826a49SYabin Cui 
411*01826a49SYabin Cui /******* FRAME DECODING ******************************************************/
412*01826a49SYabin Cui 
413*01826a49SYabin Cui static void decode_data_frame(ostream_t *const out, istream_t *const in,
414*01826a49SYabin Cui                               const dictionary_t *const dict);
415*01826a49SYabin Cui static void init_frame_context(frame_context_t *const context,
416*01826a49SYabin Cui                                istream_t *const in,
417*01826a49SYabin Cui                                const dictionary_t *const dict);
418*01826a49SYabin Cui static void free_frame_context(frame_context_t *const context);
419*01826a49SYabin Cui static void parse_frame_header(frame_header_t *const header,
420*01826a49SYabin Cui                                istream_t *const in);
421*01826a49SYabin Cui static void frame_context_apply_dict(frame_context_t *const ctx,
422*01826a49SYabin Cui                                      const dictionary_t *const dict);
423*01826a49SYabin Cui 
424*01826a49SYabin Cui static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
425*01826a49SYabin Cui                             istream_t *const in);
426*01826a49SYabin Cui 
/// Read the 32-bit magic number from `in` and, if it identifies a Zstandard
/// frame, decode that frame into `out` using `dict`.  Any other magic number
/// terminates the process through ERROR().
static void decode_frame(ostream_t *const out, istream_t *const in,
                         const dictionary_t *const dict) {
    const u32 magic = (u32)IO_read_bits(in, 32);

    if (magic != ZSTD_MAGIC_NUMBER) {
        // Neither a Zstandard frame nor anything else this decoder handles
        // (skippable frames are rejected as well).  ERROR() calls exit(1).
        ERROR("Tried to decode non-ZSTD frame");
    }

    // ZSTD frame
    decode_data_frame(out, in, dict);
}
440*01826a49SYabin Cui 
441*01826a49SYabin Cui /// Decode a frame that contains compressed data.  Not all frames do as there
442*01826a49SYabin Cui /// are skippable frames.
443*01826a49SYabin Cui /// See
444*01826a49SYabin Cui /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
decode_data_frame(ostream_t * const out,istream_t * const in,const dictionary_t * const dict)445*01826a49SYabin Cui static void decode_data_frame(ostream_t *const out, istream_t *const in,
446*01826a49SYabin Cui                               const dictionary_t *const dict) {
447*01826a49SYabin Cui     frame_context_t ctx;
448*01826a49SYabin Cui 
449*01826a49SYabin Cui     // Initialize the context that needs to be carried from block to block
450*01826a49SYabin Cui     init_frame_context(&ctx, in, dict);
451*01826a49SYabin Cui 
452*01826a49SYabin Cui     if (ctx.header.frame_content_size != 0 &&
453*01826a49SYabin Cui         ctx.header.frame_content_size > out->len) {
454*01826a49SYabin Cui         OUT_SIZE();
455*01826a49SYabin Cui     }
456*01826a49SYabin Cui 
457*01826a49SYabin Cui     decompress_data(&ctx, out, in);
458*01826a49SYabin Cui 
459*01826a49SYabin Cui     free_frame_context(&ctx);
460*01826a49SYabin Cui }
461*01826a49SYabin Cui 
462*01826a49SYabin Cui /// Takes the information provided in the header and dictionary, and initializes
463*01826a49SYabin Cui /// the context for this frame
init_frame_context(frame_context_t * const context,istream_t * const in,const dictionary_t * const dict)464*01826a49SYabin Cui static void init_frame_context(frame_context_t *const context,
465*01826a49SYabin Cui                                istream_t *const in,
466*01826a49SYabin Cui                                const dictionary_t *const dict) {
467*01826a49SYabin Cui     // Most fields in context are correct when initialized to 0
468*01826a49SYabin Cui     memset(context, 0, sizeof(frame_context_t));
469*01826a49SYabin Cui 
470*01826a49SYabin Cui     // Parse data from the frame header
471*01826a49SYabin Cui     parse_frame_header(&context->header, in);
472*01826a49SYabin Cui 
473*01826a49SYabin Cui     // Set up the offset history for the repeat offset commands
474*01826a49SYabin Cui     context->previous_offsets[0] = 1;
475*01826a49SYabin Cui     context->previous_offsets[1] = 4;
476*01826a49SYabin Cui     context->previous_offsets[2] = 8;
477*01826a49SYabin Cui 
478*01826a49SYabin Cui     // Apply details from the dict if it exists
479*01826a49SYabin Cui     frame_context_apply_dict(context, dict);
480*01826a49SYabin Cui }
481*01826a49SYabin Cui 
/// Release every entropy table held by `context` (the Huffman literals table
/// first, then the three FSE tables) and wipe the whole context to zero.
static void free_frame_context(frame_context_t *const context) {
    FSE_dtable *const fse_tables[] = {
        &context->ll_dtable,
        &context->ml_dtable,
        &context->of_dtable,
    };
    size_t i;

    HUF_free_dtable(&context->literals_dtable);

    for (i = 0; i < sizeof(fse_tables) / sizeof(fse_tables[0]); i++) {
        FSE_free_dtable(fse_tables[i]);
    }

    memset(context, 0, sizeof(*context));
}
491*01826a49SYabin Cui 
/// Parses the Frame_Header located at the current position of `in` and stores
/// the decoded fields in `header`.
///
/// Fields are consumed in this order: Frame_Header_Descriptor (1 byte),
/// Window_Descriptor (1 byte, only when Single_Segment_flag is clear),
/// Dictionary_ID (0, 1, 2 or 4 bytes) and Frame_Content_Size (0, 1, 2, 4 or 8
/// bytes). A set Reserved_bit is reported through CORRUPTION().
static void parse_frame_header(frame_header_t *const header,
                               istream_t *const in) {
    // "The first header's byte is called the Frame_Header_Descriptor. It tells
    // which other fields are present. Decoding this byte is enough to tell the
    // size of Frame_Header.
    //
    // Bit number   Field name
    // 7-6  Frame_Content_Size_flag
    // 5    Single_Segment_flag
    // 4    Unused_bit
    // 3    Reserved_bit
    // 2    Content_Checksum_flag
    // 1-0  Dictionary_ID_flag"
    const u8 descriptor = (u8)IO_read_bits(in, 8);

    // Split the descriptor into its individual flags
    const u8 frame_content_size_flag = descriptor >> 6;
    const u8 single_segment_flag = (descriptor >> 5) & 1;
    const u8 reserved_bit = (descriptor >> 3) & 1;
    const u8 content_checksum_flag = (descriptor >> 2) & 1;
    const u8 dictionary_id_flag = descriptor & 3;

    if (reserved_bit != 0) {
        CORRUPTION();
    }

    header->single_segment_flag = single_segment_flag;
    header->content_checksum_flag = content_checksum_flag;

    // Window size, when a Window_Descriptor is present
    if (!single_segment_flag) {
        // "Provides guarantees on maximum back-reference distance that will be
        // used within compressed data. This information is important for
        // decoders to allocate enough memory.
        //
        // Bit numbers  7-3         2-0
        // Field name   Exponent    Mantissa"
        const u8 window_descriptor = (u8)IO_read_bits(in, 8);
        const u8 exponent = window_descriptor >> 3;
        const u8 mantissa = window_descriptor & 7;

        // windowBase = 1 << (10 + Exponent), windowAdd = (windowBase / 8) *
        // Mantissa. The exponent is a 5-bit value, so the shift count can
        // reach 41: do the arithmetic in 64 bits, because shifting a narrower
        // size_t by that much would be undefined behavior.
        const uint64_t window_base = (uint64_t)1 << (10 + exponent);
        const uint64_t window_add = (window_base / 8) * mantissa;
        const uint64_t window_size = window_base + window_add;

        // Saturate when the value does not fit in size_t (only possible when
        // size_t is narrower than 64 bits)
        header->window_size =
            window_size > SIZE_MAX ? SIZE_MAX : (size_t)window_size;
    }

    // Dictionary id, when present
    if (dictionary_id_flag) {
        // "This is a variable size field, which contains the ID of the
        // dictionary required to properly decode the frame. Note that this
        // field is optional. When it's not present, it's up to the caller to
        // make sure it uses the correct dictionary. Format is little-endian."
        const int bytes_array[] = {0, 1, 2, 4};
        const int bytes = bytes_array[dictionary_id_flag];

        header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
    } else {
        header->dictionary_id = 0;
    }

    // Frame content size, when present
    if (single_segment_flag || frame_content_size_flag) {
        // "This is the original (uncompressed) size. This information is
        // optional. The Field_Size is provided according to value of
        // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
        // present), 1, 2, 4 or 8 bytes. Format is little-endian."
        //
        // When frame_content_size_flag == 0 but single_segment_flag is set,
        // the field is still present and is 1 byte long
        const int bytes_array[] = {1, 2, 4, 8};
        const int bytes = bytes_array[frame_content_size_flag];

        header->frame_content_size = IO_read_bits(in, bytes * 8);
        if (bytes == 2) {
            // "When Field_Size is 2, the offset of 256 is added."
            header->frame_content_size += 256;
        }
    } else {
        header->frame_content_size = 0;
    }

    if (single_segment_flag) {
        // "The Window_Descriptor byte is optional. It is absent when
        // Single_Segment_flag is set. In this case, the maximum back-reference
        // distance is the content size itself, which can be any value from 1 to
        // 2^64-1 bytes (16 EB)."
        header->window_size = header->frame_content_size;
    }
}
583*01826a49SYabin Cui 
584*01826a49SYabin Cui /// Decompress the data from a frame block by block
decompress_data(frame_context_t * const ctx,ostream_t * const out,istream_t * const in)585*01826a49SYabin Cui static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
586*01826a49SYabin Cui                             istream_t *const in) {
587*01826a49SYabin Cui     // "A frame encapsulates one or multiple blocks. Each block can be
588*01826a49SYabin Cui     // compressed or not, and has a guaranteed maximum content size, which
589*01826a49SYabin Cui     // depends on frame parameters. Unlike frames, each block depends on
590*01826a49SYabin Cui     // previous blocks for proper decoding. However, each block can be
591*01826a49SYabin Cui     // decompressed without waiting for its successor, allowing streaming
592*01826a49SYabin Cui     // operations."
593*01826a49SYabin Cui     int last_block = 0;
594*01826a49SYabin Cui     do {
595*01826a49SYabin Cui         // "Last_Block
596*01826a49SYabin Cui         //
597*01826a49SYabin Cui         // The lowest bit signals if this block is the last one. Frame ends
598*01826a49SYabin Cui         // right after this block.
599*01826a49SYabin Cui         //
600*01826a49SYabin Cui         // Block_Type and Block_Size
601*01826a49SYabin Cui         //
602*01826a49SYabin Cui         // The next 2 bits represent the Block_Type, while the remaining 21 bits
603*01826a49SYabin Cui         // represent the Block_Size. Format is little-endian."
604*01826a49SYabin Cui         last_block = (int)IO_read_bits(in, 1);
605*01826a49SYabin Cui         const int block_type = (int)IO_read_bits(in, 2);
606*01826a49SYabin Cui         const size_t block_len = IO_read_bits(in, 21);
607*01826a49SYabin Cui 
608*01826a49SYabin Cui         switch (block_type) {
609*01826a49SYabin Cui         case 0: {
610*01826a49SYabin Cui             // "Raw_Block - this is an uncompressed block. Block_Size is the
611*01826a49SYabin Cui             // number of bytes to read and copy."
612*01826a49SYabin Cui             const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
613*01826a49SYabin Cui             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
614*01826a49SYabin Cui 
615*01826a49SYabin Cui             // Copy the raw data into the output
616*01826a49SYabin Cui             memcpy(write_ptr, read_ptr, block_len);
617*01826a49SYabin Cui 
618*01826a49SYabin Cui             ctx->current_total_output += block_len;
619*01826a49SYabin Cui             break;
620*01826a49SYabin Cui         }
621*01826a49SYabin Cui         case 1: {
622*01826a49SYabin Cui             // "RLE_Block - this is a single byte, repeated N times. In which
623*01826a49SYabin Cui             // case, Block_Size is the size to regenerate, while the
624*01826a49SYabin Cui             // "compressed" block is just 1 byte (the byte to repeat)."
625*01826a49SYabin Cui             const u8 *const read_ptr = IO_get_read_ptr(in, 1);
626*01826a49SYabin Cui             u8 *const write_ptr = IO_get_write_ptr(out, block_len);
627*01826a49SYabin Cui 
628*01826a49SYabin Cui             // Copy `block_len` copies of `read_ptr[0]` to the output
629*01826a49SYabin Cui             memset(write_ptr, read_ptr[0], block_len);
630*01826a49SYabin Cui 
631*01826a49SYabin Cui             ctx->current_total_output += block_len;
632*01826a49SYabin Cui             break;
633*01826a49SYabin Cui         }
634*01826a49SYabin Cui         case 2: {
635*01826a49SYabin Cui             // "Compressed_Block - this is a Zstandard compressed block,
636*01826a49SYabin Cui             // detailed in another section of this specification. Block_Size is
637*01826a49SYabin Cui             // the compressed size.
638*01826a49SYabin Cui 
639*01826a49SYabin Cui             // Create a sub-stream for the block
640*01826a49SYabin Cui             istream_t block_stream = IO_make_sub_istream(in, block_len);
641*01826a49SYabin Cui             decompress_block(ctx, out, &block_stream);
642*01826a49SYabin Cui             break;
643*01826a49SYabin Cui         }
644*01826a49SYabin Cui         case 3:
645*01826a49SYabin Cui             // "Reserved - this is not a block. This value cannot be used with
646*01826a49SYabin Cui             // current version of this specification."
647*01826a49SYabin Cui             CORRUPTION();
648*01826a49SYabin Cui             break;
649*01826a49SYabin Cui         default:
650*01826a49SYabin Cui             IMPOSSIBLE();
651*01826a49SYabin Cui         }
652*01826a49SYabin Cui     } while (!last_block);
653*01826a49SYabin Cui 
654*01826a49SYabin Cui     if (ctx->header.content_checksum_flag) {
655*01826a49SYabin Cui         // This program does not support checking the checksum, so skip over it
656*01826a49SYabin Cui         // if it's present
657*01826a49SYabin Cui         IO_advance_input(in, 4);
658*01826a49SYabin Cui     }
659*01826a49SYabin Cui }
660*01826a49SYabin Cui /******* END FRAME DECODING ***************************************************/
661*01826a49SYabin Cui 
662*01826a49SYabin Cui /******* BLOCK DECOMPRESSION **************************************************/
/// Decodes one compressed block from `in` into `out`.
///
/// "A compressed block consists of 2 sections :
///
/// Literals_Section
/// Sequences_Section"
static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
                             istream_t *const in) {
    u8 *lits = NULL;
    sequence_command_t *seqs = NULL;

    // Step 1: literals section
    const size_t lits_size = decode_literals(ctx, in, &lits);

    // Step 2: sequences section
    const size_t seqs_count = decode_sequences(ctx, in, &seqs);

    // Step 3: run the sequence commands against the literals to produce output
    execute_sequences(ctx, out, lits, lits_size, seqs, seqs_count);

    // Both buffers are released here once the block has been produced
    free(seqs);
    free(lits);
}
686*01826a49SYabin Cui /******* END BLOCK DECOMPRESSION **********************************************/
687*01826a49SYabin Cui 
688*01826a49SYabin Cui /******* LITERALS DECODING ****************************************************/
689*01826a49SYabin Cui static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
690*01826a49SYabin Cui                                      const int block_type,
691*01826a49SYabin Cui                                      const int size_format);
692*01826a49SYabin Cui static size_t decode_literals_compressed(frame_context_t *const ctx,
693*01826a49SYabin Cui                                          istream_t *const in,
694*01826a49SYabin Cui                                          u8 **const literals,
695*01826a49SYabin Cui                                          const int block_type,
696*01826a49SYabin Cui                                          const int size_format);
697*01826a49SYabin Cui static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
698*01826a49SYabin Cui static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
699*01826a49SYabin Cui                                     int *const num_symbs);
700*01826a49SYabin Cui 
/// Reads the first two fields of the Literals_Section_Header and dispatches to
/// the matching decoder. Returns whatever size that decoder returns; the
/// buffer it allocates is stored in `*literals`.
static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
                              u8 **const literals) {
    // "Literals can be stored uncompressed or compressed using Huffman prefix
    // codes. When compressed, an optional tree description can be present,
    // followed by 1 or 4 streams."
    //
    // "Header is in charge of describing how literals are packed. It's a
    // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
    // little-endian convention."
    //
    // Literals_Block_Type "uses 2 lowest bits of first byte, describing 4
    // different block types"; size_format takes between 1 and 2 bits, and two
    // bits are read for it here.
    const int block_type = (int)IO_read_bits(in, 2);
    const int size_format = (int)IO_read_bits(in, 2);

    if (block_type > 1) {
        // Types 2 and 3: Huffman compressed literals
        return decode_literals_compressed(ctx, in, literals, block_type,
                                          size_format);
    }

    // Types 0 and 1: raw or RLE literals
    return decode_literals_simple(in, literals, block_type, size_format);
}
732*01826a49SYabin Cui 
733*01826a49SYabin Cui /// Decodes literals blocks in raw or RLE form
decode_literals_simple(istream_t * const in,u8 ** const literals,const int block_type,const int size_format)734*01826a49SYabin Cui static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
735*01826a49SYabin Cui                                      const int block_type,
736*01826a49SYabin Cui                                      const int size_format) {
737*01826a49SYabin Cui     size_t size;
738*01826a49SYabin Cui     switch (size_format) {
739*01826a49SYabin Cui     // These cases are in the form ?0
740*01826a49SYabin Cui     // In this case, the ? bit is actually part of the size field
741*01826a49SYabin Cui     case 0:
742*01826a49SYabin Cui     case 2:
743*01826a49SYabin Cui         // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
744*01826a49SYabin Cui         IO_rewind_bits(in, 1);
745*01826a49SYabin Cui         size = IO_read_bits(in, 5);
746*01826a49SYabin Cui         break;
747*01826a49SYabin Cui     case 1:
748*01826a49SYabin Cui         // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
749*01826a49SYabin Cui         size = IO_read_bits(in, 12);
750*01826a49SYabin Cui         break;
751*01826a49SYabin Cui     case 3:
752*01826a49SYabin Cui         // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
753*01826a49SYabin Cui         size = IO_read_bits(in, 20);
754*01826a49SYabin Cui         break;
755*01826a49SYabin Cui     default:
756*01826a49SYabin Cui         // Size format is in range 0-3
757*01826a49SYabin Cui         IMPOSSIBLE();
758*01826a49SYabin Cui     }
759*01826a49SYabin Cui 
760*01826a49SYabin Cui     if (size > MAX_LITERALS_SIZE) {
761*01826a49SYabin Cui         CORRUPTION();
762*01826a49SYabin Cui     }
763*01826a49SYabin Cui 
764*01826a49SYabin Cui     *literals = malloc(size);
765*01826a49SYabin Cui     if (!*literals) {
766*01826a49SYabin Cui         BAD_ALLOC();
767*01826a49SYabin Cui     }
768*01826a49SYabin Cui 
769*01826a49SYabin Cui     switch (block_type) {
770*01826a49SYabin Cui     case 0: {
771*01826a49SYabin Cui         // "Raw_Literals_Block - Literals are stored uncompressed."
772*01826a49SYabin Cui         const u8 *const read_ptr = IO_get_read_ptr(in, size);
773*01826a49SYabin Cui         memcpy(*literals, read_ptr, size);
774*01826a49SYabin Cui         break;
775*01826a49SYabin Cui     }
776*01826a49SYabin Cui     case 1: {
777*01826a49SYabin Cui         // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
778*01826a49SYabin Cui         const u8 *const read_ptr = IO_get_read_ptr(in, 1);
779*01826a49SYabin Cui         memset(*literals, read_ptr[0], size);
780*01826a49SYabin Cui         break;
781*01826a49SYabin Cui     }
782*01826a49SYabin Cui     default:
783*01826a49SYabin Cui         IMPOSSIBLE();
784*01826a49SYabin Cui     }
785*01826a49SYabin Cui 
786*01826a49SYabin Cui     return size;
787*01826a49SYabin Cui }
788*01826a49SYabin Cui 
789*01826a49SYabin Cui /// Decodes Huffman compressed literals
decode_literals_compressed(frame_context_t * const ctx,istream_t * const in,u8 ** const literals,const int block_type,const int size_format)790*01826a49SYabin Cui static size_t decode_literals_compressed(frame_context_t *const ctx,
791*01826a49SYabin Cui                                          istream_t *const in,
792*01826a49SYabin Cui                                          u8 **const literals,
793*01826a49SYabin Cui                                          const int block_type,
794*01826a49SYabin Cui                                          const int size_format) {
795*01826a49SYabin Cui     size_t regenerated_size, compressed_size;
796*01826a49SYabin Cui     // Only size_format=0 has 1 stream, so default to 4
797*01826a49SYabin Cui     int num_streams = 4;
798*01826a49SYabin Cui     switch (size_format) {
799*01826a49SYabin Cui     case 0:
800*01826a49SYabin Cui         // "A single stream. Both Compressed_Size and Regenerated_Size use 10
801*01826a49SYabin Cui         // bits (0-1023)."
802*01826a49SYabin Cui         num_streams = 1;
803*01826a49SYabin Cui     // Fall through as it has the same size format
804*01826a49SYabin Cui         /* fallthrough */
805*01826a49SYabin Cui     case 1:
806*01826a49SYabin Cui         // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
807*01826a49SYabin Cui         // (0-1023)."
808*01826a49SYabin Cui         regenerated_size = IO_read_bits(in, 10);
809*01826a49SYabin Cui         compressed_size = IO_read_bits(in, 10);
810*01826a49SYabin Cui         break;
811*01826a49SYabin Cui     case 2:
812*01826a49SYabin Cui         // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
813*01826a49SYabin Cui         // (0-16383)."
814*01826a49SYabin Cui         regenerated_size = IO_read_bits(in, 14);
815*01826a49SYabin Cui         compressed_size = IO_read_bits(in, 14);
816*01826a49SYabin Cui         break;
817*01826a49SYabin Cui     case 3:
818*01826a49SYabin Cui         // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
819*01826a49SYabin Cui         // (0-262143)."
820*01826a49SYabin Cui         regenerated_size = IO_read_bits(in, 18);
821*01826a49SYabin Cui         compressed_size = IO_read_bits(in, 18);
822*01826a49SYabin Cui         break;
823*01826a49SYabin Cui     default:
824*01826a49SYabin Cui         // Impossible
825*01826a49SYabin Cui         IMPOSSIBLE();
826*01826a49SYabin Cui     }
827*01826a49SYabin Cui     if (regenerated_size > MAX_LITERALS_SIZE) {
828*01826a49SYabin Cui         CORRUPTION();
829*01826a49SYabin Cui     }
830*01826a49SYabin Cui 
831*01826a49SYabin Cui     *literals = malloc(regenerated_size);
832*01826a49SYabin Cui     if (!*literals) {
833*01826a49SYabin Cui         BAD_ALLOC();
834*01826a49SYabin Cui     }
835*01826a49SYabin Cui 
836*01826a49SYabin Cui     ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
837*01826a49SYabin Cui     istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
838*01826a49SYabin Cui 
839*01826a49SYabin Cui     if (block_type == 2) {
840*01826a49SYabin Cui         // Decode the provided Huffman table
841*01826a49SYabin Cui         // "This section is only present when Literals_Block_Type type is
842*01826a49SYabin Cui         // Compressed_Literals_Block (2)."
843*01826a49SYabin Cui 
844*01826a49SYabin Cui         HUF_free_dtable(&ctx->literals_dtable);
845*01826a49SYabin Cui         decode_huf_table(&ctx->literals_dtable, &huf_stream);
846*01826a49SYabin Cui     } else {
847*01826a49SYabin Cui         // If the previous Huffman table is being repeated, ensure it exists
848*01826a49SYabin Cui         if (!ctx->literals_dtable.symbols) {
849*01826a49SYabin Cui             CORRUPTION();
850*01826a49SYabin Cui         }
851*01826a49SYabin Cui     }
852*01826a49SYabin Cui 
853*01826a49SYabin Cui     size_t symbols_decoded;
854*01826a49SYabin Cui     if (num_streams == 1) {
855*01826a49SYabin Cui         symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
856*01826a49SYabin Cui     } else {
857*01826a49SYabin Cui         symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
858*01826a49SYabin Cui     }
859*01826a49SYabin Cui 
860*01826a49SYabin Cui     if (symbols_decoded != regenerated_size) {
861*01826a49SYabin Cui         CORRUPTION();
862*01826a49SYabin Cui     }
863*01826a49SYabin Cui 
864*01826a49SYabin Cui     return regenerated_size;
865*01826a49SYabin Cui }
866*01826a49SYabin Cui 
867*01826a49SYabin Cui // Decode the Huffman table description
decode_huf_table(HUF_dtable * const dtable,istream_t * const in)868*01826a49SYabin Cui static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
869*01826a49SYabin Cui     // "All literal values from zero (included) to last present one (excluded)
870*01826a49SYabin Cui     // are represented by Weight with values from 0 to Max_Number_of_Bits."
871*01826a49SYabin Cui 
872*01826a49SYabin Cui     // "This is a single byte value (0-255), which describes how to decode the list of weights."
873*01826a49SYabin Cui     const u8 header = IO_read_bits(in, 8);
874*01826a49SYabin Cui 
875*01826a49SYabin Cui     u8 weights[HUF_MAX_SYMBS];
876*01826a49SYabin Cui     memset(weights, 0, sizeof(weights));
877*01826a49SYabin Cui 
878*01826a49SYabin Cui     int num_symbs;
879*01826a49SYabin Cui 
880*01826a49SYabin Cui     if (header >= 128) {
881*01826a49SYabin Cui         // "This is a direct representation, where each Weight is written
882*01826a49SYabin Cui         // directly as a 4 bits field (0-15). The full representation occupies
883*01826a49SYabin Cui         // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
884*01826a49SYabin Cui         // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
885*01826a49SYabin Cui         // 127"
886*01826a49SYabin Cui         num_symbs = header - 127;
887*01826a49SYabin Cui         const size_t bytes = (num_symbs + 1) / 2;
888*01826a49SYabin Cui 
889*01826a49SYabin Cui         const u8 *const weight_src = IO_get_read_ptr(in, bytes);
890*01826a49SYabin Cui 
891*01826a49SYabin Cui         for (int i = 0; i < num_symbs; i++) {
892*01826a49SYabin Cui             // "They are encoded forward, 2
893*01826a49SYabin Cui             // weights to a byte with the first weight taking the top four bits
894*01826a49SYabin Cui             // and the second taking the bottom four (e.g. the following
895*01826a49SYabin Cui             // operations could be used to read the weights: Weight[0] =
896*01826a49SYabin Cui             // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
897*01826a49SYabin Cui             if (i % 2 == 0) {
898*01826a49SYabin Cui                 weights[i] = weight_src[i / 2] >> 4;
899*01826a49SYabin Cui             } else {
900*01826a49SYabin Cui                 weights[i] = weight_src[i / 2] & 0xf;
901*01826a49SYabin Cui             }
902*01826a49SYabin Cui         }
903*01826a49SYabin Cui     } else {
904*01826a49SYabin Cui         // The weights are FSE encoded, decode them before we can construct the
905*01826a49SYabin Cui         // table
906*01826a49SYabin Cui         istream_t fse_stream = IO_make_sub_istream(in, header);
907*01826a49SYabin Cui         ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
908*01826a49SYabin Cui         fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
909*01826a49SYabin Cui     }
910*01826a49SYabin Cui 
911*01826a49SYabin Cui     // Construct the table using the decoded weights
912*01826a49SYabin Cui     HUF_init_dtable_usingweights(dtable, weights, num_symbs);
913*01826a49SYabin Cui }
914*01826a49SYabin Cui 
/// Decodes FSE compressed Huffman weights from `in` into `weights` and stores
/// the value returned by the FSE decoder in `*num_symbs`.
static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
                                    int *const num_symbs) {
    // "An FSE bitstream starts by a header, describing probabilities
    // distribution. It will create a Decoding Table. For a list of Huffman
    // weights, maximum accuracy is 7 bits."
    const int max_accuracy_log = 7;

    FSE_dtable weight_table;
    FSE_decode_header(&weight_table, in, max_accuracy_log);

    // The weights themselves follow the header
    *num_symbs = FSE_decompress_interleaved2(&weight_table, weights, in);

    // The table is only needed for this one decode
    FSE_free_dtable(&weight_table);
}
931*01826a49SYabin Cui /******* END LITERALS DECODING ************************************************/
932*01826a49SYabin Cui 
933*01826a49SYabin Cui /******* SEQUENCE DECODING ****************************************************/
934*01826a49SYabin Cui /// The combination of FSE states needed to decode sequences
935*01826a49SYabin Cui typedef struct {
936*01826a49SYabin Cui     FSE_dtable ll_table;
937*01826a49SYabin Cui     FSE_dtable of_table;
938*01826a49SYabin Cui     FSE_dtable ml_table;
939*01826a49SYabin Cui 
940*01826a49SYabin Cui     u16 ll_state;
941*01826a49SYabin Cui     u16 of_state;
942*01826a49SYabin Cui     u16 ml_state;
943*01826a49SYabin Cui } sequence_states_t;
944*01826a49SYabin Cui 
/// Identifies the sequence component a table is being decoded for; passed as
/// the `type` argument of decode_seq_table
typedef enum {
    seq_literal_length = 0,
    seq_offset = 1,
    seq_match_length = 2,
} seq_part_t;
951*01826a49SYabin Cui 
/// The table modes that can be passed as the `mode` argument of
/// decode_seq_table
typedef enum {
    seq_predefined = 0,
    seq_rle = 1,
    seq_fse = 2,
    seq_repeat = 3,
} seq_mode_t;
958*01826a49SYabin Cui 
959*01826a49SYabin Cui /// The predefined FSE distribution tables for `seq_predefined` mode
960*01826a49SYabin Cui static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
961*01826a49SYabin Cui     4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,  1,  2,  2,
962*01826a49SYabin Cui     2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
963*01826a49SYabin Cui static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
964*01826a49SYabin Cui     1, 1, 1, 1, 1, 1, 2, 2, 2, 1,  1,  1,  1,  1, 1,
965*01826a49SYabin Cui     1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
966*01826a49SYabin Cui static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
967*01826a49SYabin Cui     1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1,  1,  1,  1,  1,  1,  1, 1,
968*01826a49SYabin Cui     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1,
969*01826a49SYabin Cui     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
970*01826a49SYabin Cui 
971*01826a49SYabin Cui /// The sequence decoding baseline and number of additional bits to read/add
972*01826a49SYabin Cui /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
973*01826a49SYabin Cui static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
974*01826a49SYabin Cui     0,  1,  2,   3,   4,   5,    6,    7,    8,    9,     10,    11,
975*01826a49SYabin Cui     12, 13, 14,  15,  16,  18,   20,   22,   24,   28,    32,    40,
976*01826a49SYabin Cui     48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
977*01826a49SYabin Cui static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
978*01826a49SYabin Cui     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  1,  1,
979*01826a49SYabin Cui     1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
980*01826a49SYabin Cui 
981*01826a49SYabin Cui static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
982*01826a49SYabin Cui     3,  4,   5,   6,   7,    8,    9,    10,   11,    12,    13,   14, 15, 16,
983*01826a49SYabin Cui     17, 18,  19,  20,  21,   22,   23,   24,   25,    26,    27,   28, 29, 30,
984*01826a49SYabin Cui     31, 32,  33,  34,  35,   37,   39,   41,   43,    47,    51,   59, 67, 83,
985*01826a49SYabin Cui     99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
986*01826a49SYabin Cui static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
987*01826a49SYabin Cui     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,
988*01826a49SYabin Cui     0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  1,  1,  1, 1,
989*01826a49SYabin Cui     2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
990*01826a49SYabin Cui 
991*01826a49SYabin Cui /// Offset decoding is simpler so we just need a maximum code value
992*01826a49SYabin Cui static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
993*01826a49SYabin Cui 
994*01826a49SYabin Cui static void decompress_sequences(frame_context_t *const ctx,
995*01826a49SYabin Cui                                  istream_t *const in,
996*01826a49SYabin Cui                                  sequence_command_t *const sequences,
997*01826a49SYabin Cui                                  const size_t num_sequences);
998*01826a49SYabin Cui static sequence_command_t decode_sequence(sequence_states_t *const state,
999*01826a49SYabin Cui                                           const u8 *const src,
1000*01826a49SYabin Cui                                           i64 *const offset,
1001*01826a49SYabin Cui                                           int lastSequence);
1002*01826a49SYabin Cui static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1003*01826a49SYabin Cui                                const seq_part_t type, const seq_mode_t mode);
1004*01826a49SYabin Cui 
/// Parse the `Number_of_Sequences` field, allocate the command array and fill
/// it by decoding the sequences section.
/// On return `*sequences` points to a `malloc`ed array the caller owns, or is
/// NULL when the block holds no sequences. Returns the number of sequences.
static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
                               sequence_command_t **const sequences) {
    // "A compressed block is a succession of sequences . A sequence is a
    // literal copy command, followed by a match copy command. A literal copy
    // command specifies a length. It is the number of bytes to be copied (or
    // extracted) from the literal section. A match copy command specifies an
    // offset and a length. The offset gives the position to copy from, which
    // can be within a previous block."

    // "Number_of_Sequences
    //
    // This is a variable size field using between 1 and 3 bytes. Let's call its
    // first byte byte0."
    const u8 byte0 = IO_read_bits(in, 8);

    size_t count;
    if (byte0 == 255) {
        // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
        count = IO_read_bits(in, 16) + 0x7F00;
    } else if (byte0 >= 128) {
        // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
        count = ((byte0 - 128) << 8) + IO_read_bits(in, 8);
    } else {
        // "Number_of_Sequences = byte0 . Uses 1 byte."
        count = byte0;
    }

    if (count != 0) {
        *sequences = malloc(count * sizeof(sequence_command_t));
        if (*sequences == NULL) {
            BAD_ALLOC();
        }

        decompress_sequences(ctx, in, *sequences, count);
        return count;
    }

    // "There are no sequences. The sequence section stops there."
    *sequences = NULL;
    return 0;
}
1046*01826a49SYabin Cui 
1047*01826a49SYabin Cui /// Decompress the FSE encoded sequence commands
decompress_sequences(frame_context_t * const ctx,istream_t * in,sequence_command_t * const sequences,const size_t num_sequences)1048*01826a49SYabin Cui static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1049*01826a49SYabin Cui                                  sequence_command_t *const sequences,
1050*01826a49SYabin Cui                                  const size_t num_sequences) {
1051*01826a49SYabin Cui     // "The Sequences_Section regroup all symbols required to decode commands.
1052*01826a49SYabin Cui     // There are 3 symbol types : literals lengths, offsets and match lengths.
1053*01826a49SYabin Cui     // They are encoded together, interleaved, in a single bitstream."
1054*01826a49SYabin Cui 
1055*01826a49SYabin Cui     // "Symbol compression modes
1056*01826a49SYabin Cui     //
1057*01826a49SYabin Cui     // This is a single byte, defining the compression mode of each symbol
1058*01826a49SYabin Cui     // type."
1059*01826a49SYabin Cui     //
1060*01826a49SYabin Cui     // Bit number : Field name
1061*01826a49SYabin Cui     // 7-6        : Literals_Lengths_Mode
1062*01826a49SYabin Cui     // 5-4        : Offsets_Mode
1063*01826a49SYabin Cui     // 3-2        : Match_Lengths_Mode
1064*01826a49SYabin Cui     // 1-0        : Reserved
1065*01826a49SYabin Cui     u8 compression_modes = IO_read_bits(in, 8);
1066*01826a49SYabin Cui 
1067*01826a49SYabin Cui     if ((compression_modes & 3) != 0) {
1068*01826a49SYabin Cui         // Reserved bits set
1069*01826a49SYabin Cui         CORRUPTION();
1070*01826a49SYabin Cui     }
1071*01826a49SYabin Cui 
1072*01826a49SYabin Cui     // "Following the header, up to 3 distribution tables can be described. When
1073*01826a49SYabin Cui     // present, they are in this order :
1074*01826a49SYabin Cui     //
1075*01826a49SYabin Cui     // Literals lengths
1076*01826a49SYabin Cui     // Offsets
1077*01826a49SYabin Cui     // Match Lengths"
1078*01826a49SYabin Cui     // Update the tables we have stored in the context
1079*01826a49SYabin Cui     decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1080*01826a49SYabin Cui                      (compression_modes >> 6) & 3);
1081*01826a49SYabin Cui 
1082*01826a49SYabin Cui     decode_seq_table(&ctx->of_dtable, in, seq_offset,
1083*01826a49SYabin Cui                      (compression_modes >> 4) & 3);
1084*01826a49SYabin Cui 
1085*01826a49SYabin Cui     decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1086*01826a49SYabin Cui                      (compression_modes >> 2) & 3);
1087*01826a49SYabin Cui 
1088*01826a49SYabin Cui 
1089*01826a49SYabin Cui     sequence_states_t states;
1090*01826a49SYabin Cui 
1091*01826a49SYabin Cui     // Initialize the decoding tables
1092*01826a49SYabin Cui     {
1093*01826a49SYabin Cui         states.ll_table = ctx->ll_dtable;
1094*01826a49SYabin Cui         states.of_table = ctx->of_dtable;
1095*01826a49SYabin Cui         states.ml_table = ctx->ml_dtable;
1096*01826a49SYabin Cui     }
1097*01826a49SYabin Cui 
1098*01826a49SYabin Cui     const size_t len = IO_istream_len(in);
1099*01826a49SYabin Cui     const u8 *const src = IO_get_read_ptr(in, len);
1100*01826a49SYabin Cui 
1101*01826a49SYabin Cui     // "After writing the last bit containing information, the compressor writes
1102*01826a49SYabin Cui     // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1103*01826a49SYabin Cui     const int padding = 8 - highest_set_bit(src[len - 1]);
1104*01826a49SYabin Cui     // The offset starts at the end because FSE streams are read backwards
1105*01826a49SYabin Cui     i64 bit_offset = (i64)(len * 8 - (size_t)padding);
1106*01826a49SYabin Cui 
1107*01826a49SYabin Cui     // "The bitstream starts with initial state values, each using the required
1108*01826a49SYabin Cui     // number of bits in their respective accuracy, decoded previously from
1109*01826a49SYabin Cui     // their normalized distribution.
1110*01826a49SYabin Cui     //
1111*01826a49SYabin Cui     // It starts by Literals_Length_State, followed by Offset_State, and finally
1112*01826a49SYabin Cui     // Match_Length_State."
1113*01826a49SYabin Cui     FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1114*01826a49SYabin Cui     FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1115*01826a49SYabin Cui     FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1116*01826a49SYabin Cui 
1117*01826a49SYabin Cui     for (size_t i = 0; i < num_sequences; i++) {
1118*01826a49SYabin Cui         // Decode sequences one by one
1119*01826a49SYabin Cui         sequences[i] = decode_sequence(&states, src, &bit_offset, i==num_sequences-1);
1120*01826a49SYabin Cui     }
1121*01826a49SYabin Cui 
1122*01826a49SYabin Cui     if (bit_offset != 0) {
1123*01826a49SYabin Cui         CORRUPTION();
1124*01826a49SYabin Cui     }
1125*01826a49SYabin Cui }
1126*01826a49SYabin Cui 
1127*01826a49SYabin Cui // Decode a single sequence and update the state
decode_sequence(sequence_states_t * const states,const u8 * const src,i64 * const offset,int lastSequence)1128*01826a49SYabin Cui static sequence_command_t decode_sequence(sequence_states_t *const states,
1129*01826a49SYabin Cui                                           const u8 *const src,
1130*01826a49SYabin Cui                                           i64 *const offset,
1131*01826a49SYabin Cui                                           int lastSequence) {
1132*01826a49SYabin Cui     // "Each symbol is a code in its own context, which specifies Baseline and
1133*01826a49SYabin Cui     // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1134*01826a49SYabin Cui     // additional bits in the same bitstream."
1135*01826a49SYabin Cui 
1136*01826a49SYabin Cui     // Decode symbols, but don't update states
1137*01826a49SYabin Cui     const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1138*01826a49SYabin Cui     const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1139*01826a49SYabin Cui     const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1140*01826a49SYabin Cui 
1141*01826a49SYabin Cui     // Offset doesn't need a max value as it's not decoded using a table
1142*01826a49SYabin Cui     if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1143*01826a49SYabin Cui         ml_code > SEQ_MAX_CODES[seq_match_length]) {
1144*01826a49SYabin Cui         CORRUPTION();
1145*01826a49SYabin Cui     }
1146*01826a49SYabin Cui 
1147*01826a49SYabin Cui     // Read the interleaved bits
1148*01826a49SYabin Cui     sequence_command_t seq;
1149*01826a49SYabin Cui     // "Decoding starts by reading the Number_of_Bits required to decode Offset.
1150*01826a49SYabin Cui     // It then does the same for Match_Length, and then for Literals_Length."
1151*01826a49SYabin Cui     seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1152*01826a49SYabin Cui 
1153*01826a49SYabin Cui     seq.match_length =
1154*01826a49SYabin Cui         SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1155*01826a49SYabin Cui         STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1156*01826a49SYabin Cui 
1157*01826a49SYabin Cui     seq.literal_length =
1158*01826a49SYabin Cui         SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1159*01826a49SYabin Cui         STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1160*01826a49SYabin Cui 
1161*01826a49SYabin Cui     // "If it is not the last sequence in the block, the next operation is to
1162*01826a49SYabin Cui     // update states. Using the rules pre-calculated in the decoding tables,
1163*01826a49SYabin Cui     // Literals_Length_State is updated, followed by Match_Length_State, and
1164*01826a49SYabin Cui     // then Offset_State."
1165*01826a49SYabin Cui     // If the stream is complete don't read bits to update state
1166*01826a49SYabin Cui     if (!lastSequence) {
1167*01826a49SYabin Cui         FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1168*01826a49SYabin Cui         FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1169*01826a49SYabin Cui         FSE_update_state(&states->of_table, &states->of_state, src, offset);
1170*01826a49SYabin Cui     }
1171*01826a49SYabin Cui 
1172*01826a49SYabin Cui     return seq;
1173*01826a49SYabin Cui }
1174*01826a49SYabin Cui 
1175*01826a49SYabin Cui /// Given a sequence part and table mode, decode the FSE distribution
1176*01826a49SYabin Cui /// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
decode_seq_table(FSE_dtable * const table,istream_t * const in,const seq_part_t type,const seq_mode_t mode)1177*01826a49SYabin Cui static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1178*01826a49SYabin Cui                              const seq_part_t type, const seq_mode_t mode) {
1179*01826a49SYabin Cui     // Constant arrays indexed by seq_part_t
1180*01826a49SYabin Cui     const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1181*01826a49SYabin Cui                                                 SEQ_OFFSET_DEFAULT_DIST,
1182*01826a49SYabin Cui                                                 SEQ_MATCH_LENGTH_DEFAULT_DIST};
1183*01826a49SYabin Cui     const size_t default_distribution_lengths[] = {36, 29, 53};
1184*01826a49SYabin Cui     const size_t default_distribution_accuracies[] = {6, 5, 6};
1185*01826a49SYabin Cui 
1186*01826a49SYabin Cui     const size_t max_accuracies[] = {9, 8, 9};
1187*01826a49SYabin Cui 
1188*01826a49SYabin Cui     if (mode != seq_repeat) {
1189*01826a49SYabin Cui         // Free old one before overwriting
1190*01826a49SYabin Cui         FSE_free_dtable(table);
1191*01826a49SYabin Cui     }
1192*01826a49SYabin Cui 
1193*01826a49SYabin Cui     switch (mode) {
1194*01826a49SYabin Cui     case seq_predefined: {
1195*01826a49SYabin Cui         // "Predefined_Mode : uses a predefined distribution table."
1196*01826a49SYabin Cui         const i16 *distribution = default_distributions[type];
1197*01826a49SYabin Cui         const size_t symbs = default_distribution_lengths[type];
1198*01826a49SYabin Cui         const size_t accuracy_log = default_distribution_accuracies[type];
1199*01826a49SYabin Cui 
1200*01826a49SYabin Cui         FSE_init_dtable(table, distribution, symbs, accuracy_log);
1201*01826a49SYabin Cui         break;
1202*01826a49SYabin Cui     }
1203*01826a49SYabin Cui     case seq_rle: {
1204*01826a49SYabin Cui         // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1205*01826a49SYabin Cui         const u8 symb = IO_get_read_ptr(in, 1)[0];
1206*01826a49SYabin Cui         FSE_init_dtable_rle(table, symb);
1207*01826a49SYabin Cui         break;
1208*01826a49SYabin Cui     }
1209*01826a49SYabin Cui     case seq_fse: {
1210*01826a49SYabin Cui         // "FSE_Compressed_Mode : standard FSE compression. A distribution table
1211*01826a49SYabin Cui         // will be present "
1212*01826a49SYabin Cui         FSE_decode_header(table, in, max_accuracies[type]);
1213*01826a49SYabin Cui         break;
1214*01826a49SYabin Cui     }
1215*01826a49SYabin Cui     case seq_repeat:
1216*01826a49SYabin Cui         // "Repeat_Mode : reuse distribution table from previous compressed
1217*01826a49SYabin Cui         // block."
1218*01826a49SYabin Cui         // Nothing to do here, table will be unchanged
1219*01826a49SYabin Cui         if (!table->symbols) {
1220*01826a49SYabin Cui             // This mode is invalid if we don't already have a table
1221*01826a49SYabin Cui             CORRUPTION();
1222*01826a49SYabin Cui         }
1223*01826a49SYabin Cui         break;
1224*01826a49SYabin Cui     default:
1225*01826a49SYabin Cui         // Impossible, as mode is from 0-3
1226*01826a49SYabin Cui         IMPOSSIBLE();
1227*01826a49SYabin Cui         break;
1228*01826a49SYabin Cui     }
1229*01826a49SYabin Cui 
1230*01826a49SYabin Cui }
1231*01826a49SYabin Cui /******* END SEQUENCE DECODING ************************************************/
1232*01826a49SYabin Cui 
1233*01826a49SYabin Cui /******* SEQUENCE EXECUTION ***************************************************/
/// Run every sequence command against the output stream: copy the command's
/// literals, then its match, and finally flush whatever literals remain.
/// The running output total and the repeat-offset history kept in `ctx` are
/// updated along the way.
static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
                              const u8 *const literals,
                              const size_t literals_len,
                              const sequence_command_t *const sequences,
                              const size_t num_sequences) {
    istream_t lit_in = IO_make_istream(literals, literals_len);

    u64 *const recent_offsets = ctx->previous_offsets;
    size_t produced = ctx->current_total_output;

    size_t next = 0;
    while (next < num_sequences) {
        const sequence_command_t cmd = sequences[next++];

        // Literal copy command
        produced += copy_literals(cmd.literal_length, &lit_in, out);

        // Match copy command; resolving the offset also updates the history
        const size_t match_offset = compute_offset(cmd, recent_offsets);
        const size_t match_len = cmd.match_length;
        execute_match_copy(ctx, match_offset, match_len, produced, out);
        produced += match_len;
    }

    // Literals not claimed by any sequence go at the end of the block
    const size_t remaining = IO_istream_len(&lit_in);
    copy_literals(remaining, &lit_in, out);
    produced += remaining;

    ctx->current_total_output = produced;
}
1269*01826a49SYabin Cui 
copy_literals(const size_t literal_length,istream_t * litstream,ostream_t * const out)1270*01826a49SYabin Cui static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1271*01826a49SYabin Cui                          ostream_t *const out) {
1272*01826a49SYabin Cui     // If the sequence asks for more literals than are left, the
1273*01826a49SYabin Cui     // sequence must be corrupted
1274*01826a49SYabin Cui     if (literal_length > IO_istream_len(litstream)) {
1275*01826a49SYabin Cui         CORRUPTION();
1276*01826a49SYabin Cui     }
1277*01826a49SYabin Cui 
1278*01826a49SYabin Cui     u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1279*01826a49SYabin Cui     const u8 *const read_ptr =
1280*01826a49SYabin Cui          IO_get_read_ptr(litstream, literal_length);
1281*01826a49SYabin Cui     // Copy literals to output
1282*01826a49SYabin Cui     memcpy(write_ptr, read_ptr, literal_length);
1283*01826a49SYabin Cui 
1284*01826a49SYabin Cui     return literal_length;
1285*01826a49SYabin Cui }
1286*01826a49SYabin Cui 
/// Turn the offset value of a sequence into the actual match offset and update
/// the history of the three most recent offsets (`offset_hist`, most recent
/// first).
static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
    if (seq.offset > 3) {
        // Not a repeat offset:
        // "if (Offset_Value > 3) offset = Offset_Value - 3;"
        const size_t fresh = seq.offset - 3;

        // The new offset becomes the most recent one
        offset_hist[2] = offset_hist[1];
        offset_hist[1] = offset_hist[0];
        offset_hist[0] = fresh;
        return fresh;
    }

    // "The first 3 values define a repeated offset and we will call
    // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
    // They are sorted in recency order, with Repeated_Offset1 meaning
    // 'most recent one'".

    // Position in the history array, which is 0 indexed
    u32 slot = seq.offset - 1;
    if (seq.literal_length == 0) {
        // "There is an exception though, when current sequence's
        // literals length is 0. In this case, repeated offsets are
        // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
        // Repeated_Offset2 becomes Repeated_Offset3, and
        // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
        slot += 1;
    }

    if (slot == 0) {
        // Most recent offset again: the history already has the right order
        return offset_hist[0];
    }

    size_t repeated;
    if (slot < 3) {
        repeated = offset_hist[slot];
    } else {
        // Only reached through the literal-length-0 exception quoted above
        repeated = offset_hist[0] - 1;
    }

    // Move the chosen offset to the front. When the second entry was chosen
    // the third one keeps its place.
    if (slot > 1) {
        offset_hist[2] = offset_hist[1];
    }
    offset_hist[1] = offset_hist[0];
    offset_hist[0] = repeated;

    return repeated;
}
1334*01826a49SYabin Cui 
/// Append `match_length` bytes to `out` by copying from `offset` bytes back in
/// the already produced data, reaching into the dictionary content when the
/// offset points before the start of the output.
/// `total_output` is the number of bytes produced so far.
/// Signals corruption when the offset is zero or points further back than the
/// data that may be referenced.
static void execute_match_copy(frame_context_t *const ctx, size_t offset,
                              size_t match_length, size_t total_output,
                              ostream_t *const out) {
    if (offset == 0) {
        // A match has to point at least one byte back. With a zero offset the
        // copy loop below would read the very byte it is about to write, which
        // has not been produced yet.
        CORRUPTION();
    }

    u8 *write_ptr = IO_get_write_ptr(out, match_length);
    if (total_output <= ctx->header.window_size) {
        // In this case offset might go back into the dictionary
        if (offset > total_output + ctx->dict_content_len) {
            // The offset goes beyond even the dictionary
            CORRUPTION();
        }

        if (offset > total_output) {
            // "The rest of the dictionary is its content. The content act
            // as a "past" in front of data to compress or decompress, so it
            // can be referenced in sequence commands."
            const size_t dict_copy =
                MIN(offset - total_output, match_length);
            const size_t dict_offset =
                ctx->dict_content_len - (offset - total_output);

            memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
            write_ptr += dict_copy;
            match_length -= dict_copy;
        }
    } else if (offset > ctx->header.window_size) {
        CORRUPTION();
    }

    // We must copy byte by byte because the match length might be larger
    // than the offset
    // ex: if the output so far was "abc", a command with offset=3 and
    // match_length=6 would produce "abcabcabc" as the new output
    for (size_t j = 0; j < match_length; j++) {
        *write_ptr = *(write_ptr - offset);
        write_ptr++;
    }
}
1372*01826a49SYabin Cui /******* END SEQUENCE EXECUTION ***********************************************/
1373*01826a49SYabin Cui 
1374*01826a49SYabin Cui /******* OUTPUT SIZE COUNTING *************************************************/
1375*01826a49SYabin Cui /// Get the decompressed size of an input stream so memory can be allocated in
1376*01826a49SYabin Cui /// advance.
1377*01826a49SYabin Cui /// This implementation assumes `src` points to a single ZSTD-compressed frame
ZSTD_get_decompressed_size(const void * src,const size_t src_len)1378*01826a49SYabin Cui size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1379*01826a49SYabin Cui     istream_t in = IO_make_istream(src, src_len);
1380*01826a49SYabin Cui 
1381*01826a49SYabin Cui     // get decompressed size from ZSTD frame header
1382*01826a49SYabin Cui     {
1383*01826a49SYabin Cui         const u32 magic_number = (u32)IO_read_bits(&in, 32);
1384*01826a49SYabin Cui 
1385*01826a49SYabin Cui         if (magic_number == ZSTD_MAGIC_NUMBER) {
1386*01826a49SYabin Cui             // ZSTD frame
1387*01826a49SYabin Cui             frame_header_t header;
1388*01826a49SYabin Cui             parse_frame_header(&header, &in);
1389*01826a49SYabin Cui 
1390*01826a49SYabin Cui             if (header.frame_content_size == 0 && !header.single_segment_flag) {
1391*01826a49SYabin Cui                 // Content size not provided, we can't tell
1392*01826a49SYabin Cui                 return (size_t)-1;
1393*01826a49SYabin Cui             }
1394*01826a49SYabin Cui 
1395*01826a49SYabin Cui             return header.frame_content_size;
1396*01826a49SYabin Cui         } else {
1397*01826a49SYabin Cui             // not a real frame or skippable frame
1398*01826a49SYabin Cui             ERROR("ZSTD frame magic number did not match");
1399*01826a49SYabin Cui         }
1400*01826a49SYabin Cui     }
1401*01826a49SYabin Cui }
1402*01826a49SYabin Cui /******* END OUTPUT SIZE COUNTING *********************************************/
1403*01826a49SYabin Cui 
1404*01826a49SYabin Cui /******* DICTIONARY PARSING ***************************************************/
create_dictionary(void)1405*01826a49SYabin Cui dictionary_t* create_dictionary(void) {
1406*01826a49SYabin Cui     dictionary_t* const dict = calloc(1, sizeof(dictionary_t));
1407*01826a49SYabin Cui     if (!dict) {
1408*01826a49SYabin Cui         BAD_ALLOC();
1409*01826a49SYabin Cui     }
1410*01826a49SYabin Cui     return dict;
1411*01826a49SYabin Cui }
1412*01826a49SYabin Cui 
1413*01826a49SYabin Cui /// Free an allocated dictionary
free_dictionary(dictionary_t * const dict)1414*01826a49SYabin Cui void free_dictionary(dictionary_t *const dict) {
1415*01826a49SYabin Cui     HUF_free_dtable(&dict->literals_dtable);
1416*01826a49SYabin Cui     FSE_free_dtable(&dict->ll_dtable);
1417*01826a49SYabin Cui     FSE_free_dtable(&dict->of_dtable);
1418*01826a49SYabin Cui     FSE_free_dtable(&dict->ml_dtable);
1419*01826a49SYabin Cui 
1420*01826a49SYabin Cui     free(dict->content);
1421*01826a49SYabin Cui 
1422*01826a49SYabin Cui     memset(dict, 0, sizeof(dictionary_t));
1423*01826a49SYabin Cui 
1424*01826a49SYabin Cui     free(dict);
1425*01826a49SYabin Cui }
1426*01826a49SYabin Cui 
1427*01826a49SYabin Cui 
1428*01826a49SYabin Cui #if !defined(ZDEC_NO_DICTIONARY)
1429*01826a49SYabin Cui #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
1430*01826a49SYabin Cui #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
1431*01826a49SYabin Cui 
1432*01826a49SYabin Cui static void init_dictionary_content(dictionary_t *const dict,
1433*01826a49SYabin Cui                                     istream_t *const in);
1434*01826a49SYabin Cui 
/// Fill `dict` from the `src_len` bytes at `src`.
/// Input that does not start with the dictionary magic number is taken as a
/// raw content dictionary. Otherwise the dictionary ID, the entropy tables and
/// the three recent offsets are read first, and the rest becomes the content.
void parse_dictionary(dictionary_t *const dict, const void *src,
                             size_t src_len) {
    const u8 *bytes = (const u8 *)src;
    memset(dict, 0, sizeof(dictionary_t));
    if (src == NULL) { /* cannot initialize dictionary with null src */
        NULL_SRC();
    }
    if (src_len < 8) {
        DICT_SIZE_ERROR();
    }

    istream_t in = IO_make_istream(bytes, src_len);

    const u32 magic_number = IO_read_bits(&in, 32);
    if (magic_number != 0xEC30A437) {
        // raw content dict: the 4 bytes just read are content too
        IO_rewind_bits(&in, 32);
        init_dictionary_content(dict, &in);
        return;
    }

    dict->dictionary_id = IO_read_bits(&in, 32);

    // "Entropy_Tables : following the same format as the tables in compressed
    // blocks. They are stored in following order : Huffman tables for literals,
    // FSE table for offsets, FSE table for match lengths, and FSE table for
    // literals lengths. It's finally followed by 3 offset values, populating
    // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
    // little-endian each, for a total of 12 bytes. Each recent offset must have
    // a value < dictionary size."
    decode_huf_table(&dict->literals_dtable, &in);
    decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
    decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
    decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);

    // Read in the previous offset history, all three values first
    for (int slot = 0; slot < 3; slot++) {
        dict->previous_offsets[slot] = IO_read_bits(&in, 32);
    }

    // Then ensure the provided offsets aren't too large
    // "Each recent offset must have a value < dictionary size."
    for (int slot = 0; slot < 3; slot++) {
        if (dict->previous_offsets[slot] > src_len) {
            ERROR("Dictionary corrupted");
        }
    }

    // "Content : The rest of the dictionary is its content. The content act as
    // a "past" in front of data to compress or decompress, so it can be
    // referenced in sequence commands."
    init_dictionary_content(dict, &in);
}
1488*01826a49SYabin Cui 
/// Copy everything left in `in` into a newly allocated buffer owned by `dict`
/// and record its size in `dict->content_size`.
static void init_dictionary_content(dictionary_t *const dict,
                                    istream_t *const in) {
    // Copy in the content
    dict->content_size = IO_istream_len(in);

    // `malloc(0)` is allowed to return NULL, which would be mistaken for an
    // allocation failure when the dictionary has no content. Asking for at
    // least one byte keeps the failure check meaningful and leaves `content`
    // non-NULL either way.
    dict->content = malloc(dict->content_size ? dict->content_size : 1);
    if (!dict->content) {
        BAD_ALLOC();
    }

    const u8 *const content = IO_get_read_ptr(in, dict->content_size);

    memcpy(dict->content, content, dict->content_size);
}
1502*01826a49SYabin Cui 
/// Deep-copies the Huffman decoding table `src` into `dst`.  A source table
/// whose `max_bits` is 0 is treated as empty, in which case `dst` is zeroed
/// and nothing is allocated.
static void HUF_copy_dtable(HUF_dtable *const dst,
                            const HUF_dtable *const src) {
    if (src->max_bits != 0) {
        // Both arrays are copied as 2^max_bits bytes
        const size_t table_bytes = (size_t)1 << src->max_bits;

        dst->max_bits = src->max_bits;
        dst->symbols = malloc(table_bytes);
        dst->num_bits = malloc(table_bytes);
        if (dst->symbols == NULL || dst->num_bits == NULL) {
            BAD_ALLOC();
        }

        memcpy(dst->symbols, src->symbols, table_bytes);
        memcpy(dst->num_bits, src->num_bits, table_bytes);
        return;
    }

    memset(dst, 0, sizeof(HUF_dtable));
}
1522*01826a49SYabin Cui 
/// Deep-copies the FSE decoding table `src` into `dst`.  A source table whose
/// `accuracy_log` is 0 is treated as empty, in which case `dst` is zeroed and
/// nothing is allocated.
static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
    if (src->accuracy_log == 0) {
        memset(dst, 0, sizeof(FSE_dtable));
        return;
    }

    // The table has 2^accuracy_log entries; `new_state_base` is copied as u16
    // entries while the other two arrays are copied as single bytes
    const size_t entries = (size_t)1 << src->accuracy_log;
    const size_t state_bytes = entries * sizeof(u16);

    dst->accuracy_log = src->accuracy_log;
    dst->symbols = malloc(entries);
    dst->num_bits = malloc(entries);
    dst->new_state_base = malloc(state_bytes);
    if (dst->symbols == NULL || dst->num_bits == NULL ||
        dst->new_state_base == NULL) {
        BAD_ALLOC();
    }

    memcpy(dst->symbols, src->symbols, entries);
    memcpy(dst->num_bits, src->num_bits, entries);
    memcpy(dst->new_state_base, src->new_state_base, state_bytes);
}
1543*01826a49SYabin Cui 
1544*01826a49SYabin Cui /// A dictionary acts as initializing values for the frame context before
1545*01826a49SYabin Cui /// decompression, so we implement it by applying it's predetermined
1546*01826a49SYabin Cui /// tables and content to the context before beginning decompression
frame_context_apply_dict(frame_context_t * const ctx,const dictionary_t * const dict)1547*01826a49SYabin Cui static void frame_context_apply_dict(frame_context_t *const ctx,
1548*01826a49SYabin Cui                                      const dictionary_t *const dict) {
1549*01826a49SYabin Cui     // If the content pointer is NULL then it must be an empty dict
1550*01826a49SYabin Cui     if (!dict || !dict->content)
1551*01826a49SYabin Cui         return;
1552*01826a49SYabin Cui 
1553*01826a49SYabin Cui     // If the requested dictionary_id is non-zero, the correct dictionary must
1554*01826a49SYabin Cui     // be present
1555*01826a49SYabin Cui     if (ctx->header.dictionary_id != 0 &&
1556*01826a49SYabin Cui         ctx->header.dictionary_id != dict->dictionary_id) {
1557*01826a49SYabin Cui         ERROR("Wrong dictionary provided");
1558*01826a49SYabin Cui     }
1559*01826a49SYabin Cui 
1560*01826a49SYabin Cui     // Copy the dict content to the context for references during sequence
1561*01826a49SYabin Cui     // execution
1562*01826a49SYabin Cui     ctx->dict_content = dict->content;
1563*01826a49SYabin Cui     ctx->dict_content_len = dict->content_size;
1564*01826a49SYabin Cui 
1565*01826a49SYabin Cui     // If it's a formatted dict copy the precomputed tables in so they can
1566*01826a49SYabin Cui     // be used in the table repeat modes
1567*01826a49SYabin Cui     if (dict->dictionary_id != 0) {
1568*01826a49SYabin Cui         // Deep copy the entropy tables so they can be freed independently of
1569*01826a49SYabin Cui         // the dictionary struct
1570*01826a49SYabin Cui         HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
1571*01826a49SYabin Cui         FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
1572*01826a49SYabin Cui         FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
1573*01826a49SYabin Cui         FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
1574*01826a49SYabin Cui 
1575*01826a49SYabin Cui         // Copy the repeated offsets
1576*01826a49SYabin Cui         memcpy(ctx->previous_offsets, dict->previous_offsets,
1577*01826a49SYabin Cui                sizeof(ctx->previous_offsets));
1578*01826a49SYabin Cui     }
1579*01826a49SYabin Cui }
1580*01826a49SYabin Cui 
1581*01826a49SYabin Cui #else  // ZDEC_NO_DICTIONARY is defined
1582*01826a49SYabin Cui 
/// Variant used when dictionary support is compiled out: a missing dictionary
/// or one with a NULL content pointer is accepted, anything else is an error.
static void frame_context_apply_dict(frame_context_t *const ctx,
                                     const dictionary_t *const dict) {
    (void)ctx;
    if (dict != NULL && dict->content != NULL) {
        ERROR("dictionary not supported");
    }
}
1588*01826a49SYabin Cui 
1589*01826a49SYabin Cui #endif
1590*01826a49SYabin Cui /******* END DICTIONARY PARSING ***********************************************/
1591*01826a49SYabin Cui 
1592*01826a49SYabin Cui /******* IO STREAM OPERATIONS *************************************************/
1593*01826a49SYabin Cui 
1594*01826a49SYabin Cui /// Reads `num` bits from a bitstream, and updates the internal offset
IO_read_bits(istream_t * const in,const int num_bits)1595*01826a49SYabin Cui static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1596*01826a49SYabin Cui     if (num_bits > 64 || num_bits <= 0) {
1597*01826a49SYabin Cui         ERROR("Attempt to read an invalid number of bits");
1598*01826a49SYabin Cui     }
1599*01826a49SYabin Cui 
1600*01826a49SYabin Cui     const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1601*01826a49SYabin Cui     const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1602*01826a49SYabin Cui     if (bytes > in->len) {
1603*01826a49SYabin Cui         INP_SIZE();
1604*01826a49SYabin Cui     }
1605*01826a49SYabin Cui 
1606*01826a49SYabin Cui     const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1607*01826a49SYabin Cui 
1608*01826a49SYabin Cui     in->bit_offset = (num_bits + in->bit_offset) % 8;
1609*01826a49SYabin Cui     in->ptr += full_bytes;
1610*01826a49SYabin Cui     in->len -= full_bytes;
1611*01826a49SYabin Cui 
1612*01826a49SYabin Cui     return result;
1613*01826a49SYabin Cui }
1614*01826a49SYabin Cui 
1615*01826a49SYabin Cui /// If a non-zero number of bits have been read from the current byte, advance
1616*01826a49SYabin Cui /// the offset to the next byte
IO_rewind_bits(istream_t * const in,int num_bits)1617*01826a49SYabin Cui static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1618*01826a49SYabin Cui     if (num_bits < 0) {
1619*01826a49SYabin Cui         ERROR("Attempting to rewind stream by a negative number of bits");
1620*01826a49SYabin Cui     }
1621*01826a49SYabin Cui 
1622*01826a49SYabin Cui     // move the offset back by `num_bits` bits
1623*01826a49SYabin Cui     const int new_offset = in->bit_offset - num_bits;
1624*01826a49SYabin Cui     // determine the number of whole bytes we have to rewind, rounding up to an
1625*01826a49SYabin Cui     // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1626*01826a49SYabin Cui     const i64 bytes = -(new_offset - 7) / 8;
1627*01826a49SYabin Cui 
1628*01826a49SYabin Cui     in->ptr -= bytes;
1629*01826a49SYabin Cui     in->len += bytes;
1630*01826a49SYabin Cui     // make sure the resulting `bit_offset` is positive, as mod in C does not
1631*01826a49SYabin Cui     // convert numbers from negative to positive (e.g. -22 % 8 == -6)
1632*01826a49SYabin Cui     in->bit_offset = ((new_offset % 8) + 8) % 8;
1633*01826a49SYabin Cui }
1634*01826a49SYabin Cui 
1635*01826a49SYabin Cui /// If the remaining bits in a byte will be unused, advance to the end of the
1636*01826a49SYabin Cui /// byte
IO_align_stream(istream_t * const in)1637*01826a49SYabin Cui static inline void IO_align_stream(istream_t *const in) {
1638*01826a49SYabin Cui     if (in->bit_offset != 0) {
1639*01826a49SYabin Cui         if (in->len == 0) {
1640*01826a49SYabin Cui             INP_SIZE();
1641*01826a49SYabin Cui         }
1642*01826a49SYabin Cui         in->ptr++;
1643*01826a49SYabin Cui         in->len--;
1644*01826a49SYabin Cui         in->bit_offset = 0;
1645*01826a49SYabin Cui     }
1646*01826a49SYabin Cui }
1647*01826a49SYabin Cui 
1648*01826a49SYabin Cui /// Write the given byte into the output stream
IO_write_byte(ostream_t * const out,u8 symb)1649*01826a49SYabin Cui static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1650*01826a49SYabin Cui     if (out->len == 0) {
1651*01826a49SYabin Cui         OUT_SIZE();
1652*01826a49SYabin Cui     }
1653*01826a49SYabin Cui 
1654*01826a49SYabin Cui     out->ptr[0] = symb;
1655*01826a49SYabin Cui     out->ptr++;
1656*01826a49SYabin Cui     out->len--;
1657*01826a49SYabin Cui }
1658*01826a49SYabin Cui 
1659*01826a49SYabin Cui /// Returns the number of bytes left to be read in this stream.  The stream must
1660*01826a49SYabin Cui /// be byte aligned.
IO_istream_len(const istream_t * const in)1661*01826a49SYabin Cui static inline size_t IO_istream_len(const istream_t *const in) {
1662*01826a49SYabin Cui     return in->len;
1663*01826a49SYabin Cui }
1664*01826a49SYabin Cui 
1665*01826a49SYabin Cui /// Returns a pointer where `len` bytes can be read, and advances the internal
1666*01826a49SYabin Cui /// state.  The stream must be byte aligned.
IO_get_read_ptr(istream_t * const in,size_t len)1667*01826a49SYabin Cui static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1668*01826a49SYabin Cui     if (len > in->len) {
1669*01826a49SYabin Cui         INP_SIZE();
1670*01826a49SYabin Cui     }
1671*01826a49SYabin Cui     if (in->bit_offset != 0) {
1672*01826a49SYabin Cui         ERROR("Attempting to operate on a non-byte aligned stream");
1673*01826a49SYabin Cui     }
1674*01826a49SYabin Cui     const u8 *const ptr = in->ptr;
1675*01826a49SYabin Cui     in->ptr += len;
1676*01826a49SYabin Cui     in->len -= len;
1677*01826a49SYabin Cui 
1678*01826a49SYabin Cui     return ptr;
1679*01826a49SYabin Cui }
1680*01826a49SYabin Cui /// Returns a pointer to write `len` bytes to, and advances the internal state
IO_get_write_ptr(ostream_t * const out,size_t len)1681*01826a49SYabin Cui static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1682*01826a49SYabin Cui     if (len > out->len) {
1683*01826a49SYabin Cui         OUT_SIZE();
1684*01826a49SYabin Cui     }
1685*01826a49SYabin Cui     u8 *const ptr = out->ptr;
1686*01826a49SYabin Cui     out->ptr += len;
1687*01826a49SYabin Cui     out->len -= len;
1688*01826a49SYabin Cui 
1689*01826a49SYabin Cui     return ptr;
1690*01826a49SYabin Cui }
1691*01826a49SYabin Cui 
1692*01826a49SYabin Cui /// Advance the inner state by `len` bytes
IO_advance_input(istream_t * const in,size_t len)1693*01826a49SYabin Cui static inline void IO_advance_input(istream_t *const in, size_t len) {
1694*01826a49SYabin Cui     if (len > in->len) {
1695*01826a49SYabin Cui          INP_SIZE();
1696*01826a49SYabin Cui     }
1697*01826a49SYabin Cui     if (in->bit_offset != 0) {
1698*01826a49SYabin Cui         ERROR("Attempting to operate on a non-byte aligned stream");
1699*01826a49SYabin Cui     }
1700*01826a49SYabin Cui 
1701*01826a49SYabin Cui     in->ptr += len;
1702*01826a49SYabin Cui     in->len -= len;
1703*01826a49SYabin Cui }
1704*01826a49SYabin Cui 
1705*01826a49SYabin Cui /// Returns an `ostream_t` constructed from the given pointer and length
IO_make_ostream(u8 * out,size_t len)1706*01826a49SYabin Cui static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1707*01826a49SYabin Cui     return (ostream_t) { out, len };
1708*01826a49SYabin Cui }
1709*01826a49SYabin Cui 
1710*01826a49SYabin Cui /// Returns an `istream_t` constructed from the given pointer and length
IO_make_istream(const u8 * in,size_t len)1711*01826a49SYabin Cui static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1712*01826a49SYabin Cui     return (istream_t) { in, len, 0 };
1713*01826a49SYabin Cui }
1714*01826a49SYabin Cui 
1715*01826a49SYabin Cui /// Returns an `istream_t` with the same base as `in`, and length `len`
1716*01826a49SYabin Cui /// Then, advance `in` to account for the consumed bytes
1717*01826a49SYabin Cui /// `in` must be byte aligned
IO_make_sub_istream(istream_t * const in,size_t len)1718*01826a49SYabin Cui static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1719*01826a49SYabin Cui     // Consume `len` bytes of the parent stream
1720*01826a49SYabin Cui     const u8 *const ptr = IO_get_read_ptr(in, len);
1721*01826a49SYabin Cui 
1722*01826a49SYabin Cui     // Make a substream using the pointer to those `len` bytes
1723*01826a49SYabin Cui     return IO_make_istream(ptr, len);
1724*01826a49SYabin Cui }
1725*01826a49SYabin Cui /******* END IO STREAM OPERATIONS *********************************************/
1726*01826a49SYabin Cui 
1727*01826a49SYabin Cui /******* BITSTREAM OPERATIONS *************************************************/
1728*01826a49SYabin Cui /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
read_bits_LE(const u8 * src,const int num_bits,const size_t offset)1729*01826a49SYabin Cui static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1730*01826a49SYabin Cui                                const size_t offset) {
1731*01826a49SYabin Cui     if (num_bits > 64) {
1732*01826a49SYabin Cui         ERROR("Attempt to read an invalid number of bits");
1733*01826a49SYabin Cui     }
1734*01826a49SYabin Cui 
1735*01826a49SYabin Cui     // Skip over bytes that aren't in range
1736*01826a49SYabin Cui     src += offset / 8;
1737*01826a49SYabin Cui     size_t bit_offset = offset % 8;
1738*01826a49SYabin Cui     u64 res = 0;
1739*01826a49SYabin Cui 
1740*01826a49SYabin Cui     int shift = 0;
1741*01826a49SYabin Cui     int left = num_bits;
1742*01826a49SYabin Cui     while (left > 0) {
1743*01826a49SYabin Cui         u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1744*01826a49SYabin Cui         // Read the next byte, shift it to account for the offset, and then mask
1745*01826a49SYabin Cui         // out the top part if we don't need all the bits
1746*01826a49SYabin Cui         res += (((u64)*src++ >> bit_offset) & mask) << shift;
1747*01826a49SYabin Cui         shift += 8 - bit_offset;
1748*01826a49SYabin Cui         left -= 8 - bit_offset;
1749*01826a49SYabin Cui         bit_offset = 0;
1750*01826a49SYabin Cui     }
1751*01826a49SYabin Cui 
1752*01826a49SYabin Cui     return res;
1753*01826a49SYabin Cui }
1754*01826a49SYabin Cui 
1755*01826a49SYabin Cui /// Read bits from the end of a HUF or FSE bitstream.  `offset` is in bits, so
1756*01826a49SYabin Cui /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1757*01826a49SYabin Cui /// `src + offset`.  If the offset becomes negative, the extra bits at the
1758*01826a49SYabin Cui /// bottom are filled in with `0` bits instead of reading from before `src`.
STREAM_read_bits(const u8 * const src,const int bits,i64 * const offset)1759*01826a49SYabin Cui static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1760*01826a49SYabin Cui                                    i64 *const offset) {
1761*01826a49SYabin Cui     *offset = *offset - bits;
1762*01826a49SYabin Cui     size_t actual_off = *offset;
1763*01826a49SYabin Cui     size_t actual_bits = bits;
1764*01826a49SYabin Cui     // Don't actually read bits from before the start of src, so if `*offset <
1765*01826a49SYabin Cui     // 0` fix actual_off and actual_bits to reflect the quantity to read
1766*01826a49SYabin Cui     if (*offset < 0) {
1767*01826a49SYabin Cui         actual_bits += *offset;
1768*01826a49SYabin Cui         actual_off = 0;
1769*01826a49SYabin Cui     }
1770*01826a49SYabin Cui     u64 res = read_bits_LE(src, actual_bits, actual_off);
1771*01826a49SYabin Cui 
1772*01826a49SYabin Cui     if (*offset < 0) {
1773*01826a49SYabin Cui         // Fill in the bottom "overflowed" bits with 0's
1774*01826a49SYabin Cui         res = -*offset >= 64 ? 0 : (res << -*offset);
1775*01826a49SYabin Cui     }
1776*01826a49SYabin Cui     return res;
1777*01826a49SYabin Cui }
1778*01826a49SYabin Cui /******* END BITSTREAM OPERATIONS *********************************************/
1779*01826a49SYabin Cui 
1780*01826a49SYabin Cui /******* BIT COUNTING OPERATIONS **********************************************/
1781*01826a49SYabin Cui /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1782*01826a49SYabin Cui /// `num`, or `-1` if `num == 0`.
highest_set_bit(const u64 num)1783*01826a49SYabin Cui static inline int highest_set_bit(const u64 num) {
1784*01826a49SYabin Cui     for (int i = 63; i >= 0; i--) {
1785*01826a49SYabin Cui         if (((u64)1 << i) <= num) {
1786*01826a49SYabin Cui             return i;
1787*01826a49SYabin Cui         }
1788*01826a49SYabin Cui     }
1789*01826a49SYabin Cui     return -1;
1790*01826a49SYabin Cui }
1791*01826a49SYabin Cui /******* END BIT COUNTING OPERATIONS ******************************************/
1792*01826a49SYabin Cui 
1793*01826a49SYabin Cui /******* HUFFMAN PRIMITIVES ***************************************************/
/// Decodes one symbol using the current `state`, then refills the state with
/// as many new bits from the stream as the decoded symbol consumed
static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
                                   u16 *const state, const u8 *const src,
                                   i64 *const offset) {
    // The state indexes directly into the table: it yields the symbol and the
    // number of bits its code occupies
    const u8 decoded = dtable->symbols[*state];
    const u8 code_len = dtable->num_bits[*state];
    const u16 fresh = STREAM_read_bits(src, code_len, offset);

    // Push the used bits out of the top of the state, append the freshly read
    // bits at the bottom, and keep only the low `max_bits` bits
    const int state_mask = ((u16)1 << dtable->max_bits) - 1;
    const int shifted = *state << code_len;
    *state = (shifted + fresh) & state_mask;

    return decoded;
}
1808*01826a49SYabin Cui 
/// Primes `state` with a full `dtable->max_bits` bits read from the stream
static inline void HUF_init_state(const HUF_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset) {
    *state = STREAM_read_bits(src, (u8)dtable->max_bits, offset);
}
1816*01826a49SYabin Cui 
/// Decodes a single Huffman-coded bitstream occupying all of `in`, writing the
/// decoded symbols to `out`.  Returns the number of symbols written.
static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
                                     ostream_t *const out,
                                     istream_t *const in) {
    const size_t len = IO_istream_len(in);
    if (len == 0) {
        INP_SIZE();
    }
    const u8 *const src = IO_get_read_ptr(in, len);

    // "Each bitstream must be read backward, that is starting from the end down
    // to the beginning. Therefore it's necessary to know the size of each
    // bitstream.
    //
    // It's also necessary to know exactly which bit is the latest. This is
    // detected by a final bit flag : the highest bit of latest byte is a
    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
    // final-bit-flag itself is not part of the useful bitstream. Hence, the
    // last byte contains between 0 and 7 useful bits."
    const u8 last_byte = src[len - 1];
    // A zero last byte carries no final-bit-flag, so the stream is invalid.
    // Without this check the padding below would be computed as 9 bits and
    // decoding would silently start inside the previous byte.
    if (last_byte == 0) {
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(last_byte);

    // Offset starts at the end because HUF streams are read backwards
    i64 bit_offset = len * 8 - padding;
    u16 state;

    HUF_init_state(dtable, &state, src, &bit_offset);

    size_t symbols_written = 0;
    while (bit_offset > -dtable->max_bits) {
        // Iterate over the stream, decoding one symbol at a time
        IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
        symbols_written++;
    }
    // "The process continues up to reading the required number of symbols per
    // stream. If a bitstream is not entirely and exactly consumed, hence
    // reaching exactly its beginning position with all bits consumed, the
    // decoding process is considered faulty."

    // When all symbols have been decoded, the final state value shouldn't have
    // any data from the stream, so it should have "read" dtable->max_bits from
    // before the start of `src`
    // Therefore `offset`, the edge to start reading new bits at, should be
    // dtable->max_bits before the start of the stream
    if (bit_offset != -dtable->max_bits) {
        CORRUPTION();
    }

    return symbols_written;
}
1865*01826a49SYabin Cui 
/// Decodes the 4-stream Huffman layout found in `in`, writing the decoded
/// symbols of all four streams to `out` one stream after another.  Returns the
/// total number of symbols written.
static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
                                     ostream_t *const out, istream_t *const in) {
    // "Compressed size is provided explicitly : in the 4-streams variant,
    // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
    // value represents the compressed size of one stream, in order. The last
    // stream size is deducted from total compressed size and from previously
    // decoded stream sizes"
    size_t explicit_sizes[3];
    for (int i = 0; i < 3; i++) {
        explicit_sizes[i] = IO_read_bits(in, 16);
    }

    // Split the input: three streams of known size, then whatever is left
    istream_t streams[4];
    for (int i = 0; i < 3; i++) {
        streams[i] = IO_make_sub_istream(in, explicit_sizes[i]);
    }
    streams[3] = IO_make_sub_istream(in, IO_istream_len(in));

    // Decode the streams one after another for simplicity; decoding all four
    // at the same time would be a possible speed optimization
    size_t total_output = 0;
    for (int i = 0; i < 4; i++) {
        total_output += HUF_decompress_1stream(dtable, out, &streams[i]);
    }

    return total_output;
}
1893*01826a49SYabin Cui 
1894*01826a49SYabin Cui /// Initializes a Huffman table using canonical Huffman codes
1895*01826a49SYabin Cui /// For more explanation on canonical Huffman codes see
1896*01826a49SYabin Cui /// https://www.cs.scranton.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1897*01826a49SYabin Cui /// Codes within a level are allocated in symbol order (i.e. smaller symbols get
1898*01826a49SYabin Cui /// earlier codes)
HUF_init_dtable(HUF_dtable * const table,const u8 * const bits,const int num_symbs)1899*01826a49SYabin Cui static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1900*01826a49SYabin Cui                             const int num_symbs) {
1901*01826a49SYabin Cui     memset(table, 0, sizeof(HUF_dtable));
1902*01826a49SYabin Cui     if (num_symbs > HUF_MAX_SYMBS) {
1903*01826a49SYabin Cui         ERROR("Too many symbols for Huffman");
1904*01826a49SYabin Cui     }
1905*01826a49SYabin Cui 
1906*01826a49SYabin Cui     u8 max_bits = 0;
1907*01826a49SYabin Cui     u16 rank_count[HUF_MAX_BITS + 1];
1908*01826a49SYabin Cui     memset(rank_count, 0, sizeof(rank_count));
1909*01826a49SYabin Cui 
1910*01826a49SYabin Cui     // Count the number of symbols for each number of bits, and determine the
1911*01826a49SYabin Cui     // depth of the tree
1912*01826a49SYabin Cui     for (int i = 0; i < num_symbs; i++) {
1913*01826a49SYabin Cui         if (bits[i] > HUF_MAX_BITS) {
1914*01826a49SYabin Cui             ERROR("Huffman table depth too large");
1915*01826a49SYabin Cui         }
1916*01826a49SYabin Cui         max_bits = MAX(max_bits, bits[i]);
1917*01826a49SYabin Cui         rank_count[bits[i]]++;
1918*01826a49SYabin Cui     }
1919*01826a49SYabin Cui 
1920*01826a49SYabin Cui     const size_t table_size = 1 << max_bits;
1921*01826a49SYabin Cui     table->max_bits = max_bits;
1922*01826a49SYabin Cui     table->symbols = malloc(table_size);
1923*01826a49SYabin Cui     table->num_bits = malloc(table_size);
1924*01826a49SYabin Cui 
1925*01826a49SYabin Cui     if (!table->symbols || !table->num_bits) {
1926*01826a49SYabin Cui         free(table->symbols);
1927*01826a49SYabin Cui         free(table->num_bits);
1928*01826a49SYabin Cui         BAD_ALLOC();
1929*01826a49SYabin Cui     }
1930*01826a49SYabin Cui 
1931*01826a49SYabin Cui     // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1932*01826a49SYabin Cui     // order. Symbols with a Weight of zero are removed. Then, starting from
1933*01826a49SYabin Cui     // lowest weight, prefix codes are distributed in order."
1934*01826a49SYabin Cui 
1935*01826a49SYabin Cui     u32 rank_idx[HUF_MAX_BITS + 1];
1936*01826a49SYabin Cui     // Initialize the starting codes for each rank (number of bits)
1937*01826a49SYabin Cui     rank_idx[max_bits] = 0;
1938*01826a49SYabin Cui     for (int i = max_bits; i >= 1; i--) {
1939*01826a49SYabin Cui         rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1940*01826a49SYabin Cui         // The entire range takes the same number of bits so we can memset it
1941*01826a49SYabin Cui         memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
1942*01826a49SYabin Cui     }
1943*01826a49SYabin Cui 
1944*01826a49SYabin Cui     if (rank_idx[0] != table_size) {
1945*01826a49SYabin Cui         CORRUPTION();
1946*01826a49SYabin Cui     }
1947*01826a49SYabin Cui 
1948*01826a49SYabin Cui     // Allocate codes and fill in the table
1949*01826a49SYabin Cui     for (int i = 0; i < num_symbs; i++) {
1950*01826a49SYabin Cui         if (bits[i] != 0) {
1951*01826a49SYabin Cui             // Allocate a code for this symbol and set its range in the table
1952*01826a49SYabin Cui             const u16 code = rank_idx[bits[i]];
1953*01826a49SYabin Cui             // Since the code doesn't care about the bottom `max_bits - bits[i]`
1954*01826a49SYabin Cui             // bits of state, it gets a range that spans all possible values of
1955*01826a49SYabin Cui             // the lower bits
1956*01826a49SYabin Cui             const u16 len = 1 << (max_bits - bits[i]);
1957*01826a49SYabin Cui             memset(&table->symbols[code], i, len);
1958*01826a49SYabin Cui             rank_idx[bits[i]] += len;
1959*01826a49SYabin Cui         }
1960*01826a49SYabin Cui     }
1961*01826a49SYabin Cui }
1962*01826a49SYabin Cui 
/// Builds a Huffman decoding table from `num_symbs` transmitted weights.  The
/// weight of one further symbol is not transmitted and is deduced here, so the
/// table is built for `num_symbs + 1` symbols.
static void HUF_init_dtable_usingweights(HUF_dtable *const table,
                                         const u8 *const weights,
                                         const int num_symbs) {
    // Room is needed for the implicit last symbol as well
    if (num_symbs + 1 > HUF_MAX_SYMBS) {
        ERROR("Too many symbols for Huffman");
    }

    u8 bits[HUF_MAX_SYMBS];

    // Every non-zero weight w contributes 2^(w - 1) to the total
    u64 weight_sum = 0;
    for (int i = 0; i < num_symbs; i++) {
        const u8 w = weights[i];
        // Weights share their valid range with bit counts
        if (w > HUF_MAX_BITS) {
            CORRUPTION();
        }
        if (w != 0) {
            weight_sum += (u64)1 << (w - 1);
        }
    }

    // The complete total, implicit symbol included, is the first power of 2
    // strictly larger than the transmitted sum
    const int max_bits = highest_set_bit(weight_sum) + 1;
    const u64 left_over = ((u64)1 << max_bits) - weight_sum;
    // What is missing must itself be a power of 2, otherwise no single weight
    // can account for it and the weights are invalid
    if ((left_over & (left_over - 1)) != 0) {
        CORRUPTION();
    }

    // Invert 2^(weight - 1) to recover the weight that was not transmitted
    const int last_weight = highest_set_bit(left_over) + 1;

    // Convert weights to code lengths: a zero weight stays 0, any other weight
    // w becomes max_bits + 1 - w
    for (int i = 0; i < num_symbs; i++) {
        if (weights[i] > 0) {
            bits[i] = max_bits + 1 - weights[i];
        } else {
            bits[i] = 0;
        }
    }
    // The last weight is always non-zero, so no zero case is needed for it
    bits[num_symbs] = max_bits + 1 - last_weight;

    HUF_init_dtable(table, bits, num_symbs + 1);
}
2003*01826a49SYabin Cui 
/// Frees the `symbols` and `num_bits` arrays of a Huffman decoding table and
/// then zeroes the whole struct, leaving both pointers NULL.
static void HUF_free_dtable(HUF_dtable *const dtable) {
    free(dtable->num_bits);
    free(dtable->symbols);
    memset(dtable, 0, sizeof(*dtable));
}
2009*01826a49SYabin Cui /******* END HUFFMAN PRIMITIVES ***********************************************/
2010*01826a49SYabin Cui 
2011*01826a49SYabin Cui /******* FSE PRIMITIVES *******************************************************/
2012*01826a49SYabin Cui /// For more description of FSE see
2013*01826a49SYabin Cui /// https://github.com/Cyan4973/FiniteStateEntropy/
2014*01826a49SYabin Cui 
2015*01826a49SYabin Cui /// Allow a symbol to be decoded without updating state
FSE_peek_symbol(const FSE_dtable * const dtable,const u16 state)2016*01826a49SYabin Cui static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
2017*01826a49SYabin Cui                                  const u16 state) {
2018*01826a49SYabin Cui     return dtable->symbols[state];
2019*01826a49SYabin Cui }
2020*01826a49SYabin Cui 
2021*01826a49SYabin Cui /// Consumes bits from the input and uses the current state to determine the
2022*01826a49SYabin Cui /// next state
FSE_update_state(const FSE_dtable * const dtable,u16 * const state,const u8 * const src,i64 * const offset)2023*01826a49SYabin Cui static inline void FSE_update_state(const FSE_dtable *const dtable,
2024*01826a49SYabin Cui                                     u16 *const state, const u8 *const src,
2025*01826a49SYabin Cui                                     i64 *const offset) {
2026*01826a49SYabin Cui     const u8 bits = dtable->num_bits[*state];
2027*01826a49SYabin Cui     const u16 rest = STREAM_read_bits(src, bits, offset);
2028*01826a49SYabin Cui     *state = dtable->new_state_base[*state] + rest;
2029*01826a49SYabin Cui }
2030*01826a49SYabin Cui 
2031*01826a49SYabin Cui /// Decodes a single FSE symbol and updates the offset
FSE_decode_symbol(const FSE_dtable * const dtable,u16 * const state,const u8 * const src,i64 * const offset)2032*01826a49SYabin Cui static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
2033*01826a49SYabin Cui                                    u16 *const state, const u8 *const src,
2034*01826a49SYabin Cui                                    i64 *const offset) {
2035*01826a49SYabin Cui     const u8 symb = FSE_peek_symbol(dtable, *state);
2036*01826a49SYabin Cui     FSE_update_state(dtable, state, src, offset);
2037*01826a49SYabin Cui     return symb;
2038*01826a49SYabin Cui }
2039*01826a49SYabin Cui 
/// Sets `*state` to its starting value by reading `accuracy_log` bits from
/// `src` at `*offset`
static inline void FSE_init_state(const FSE_dtable *const dtable,
                                  u16 *const state, const u8 *const src,
                                  i64 *const offset) {
    // The initial state is a full-width read: one `accuracy_log`-bit value
    *state = STREAM_read_bits(src, (u8)dtable->accuracy_log, offset);
}
2047*01826a49SYabin Cui 
/// Decodes an FSE bitstream that was encoded with two interleaved states,
/// writing one byte to `out` per decoded symbol.
///
/// The whole remaining length of `in` is requested through `IO_get_read_ptr`
/// and treated as the bitstream. An empty input is reported through
/// `INP_SIZE()`, and a bitstream whose last byte is zero (no final-bit-flag)
/// is reported through `CORRUPTION()`.
///
/// Returns the number of symbols written.
static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
                                          ostream_t *const out,
                                          istream_t *const in) {
    const size_t len = IO_istream_len(in);
    if (len == 0) {
        INP_SIZE();
    }
    const u8 *const src = IO_get_read_ptr(in, len);

    // "Each bitstream must be read backward, that is starting from the end down
    // to the beginning. Therefore it's necessary to know the size of each
    // bitstream.
    //
    // It's also necessary to know exactly which bit is the latest. This is
    // detected by a final bit flag : the highest bit of latest byte is a
    // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
    // final-bit-flag itself is not part of the useful bitstream. Hence, the
    // last byte contains between 0 and 7 useful bits."
    if (src[len - 1] == 0) {
        // A zero byte has no set bit to act as the final-bit-flag, so the end
        // of the stream cannot be located. Rejecting it here also keeps
        // `padding` within 1..8, so `len * 8 - padding` cannot wrap around.
        CORRUPTION();
    }
    const int padding = 8 - highest_set_bit(src[len - 1]);
    i64 offset = len * 8 - padding;

    u16 state1, state2;
    // "The first state (State1) encodes the even indexed symbols, and the
    // second (State2) encodes the odd indexes. State1 is initialized first, and
    // then State2, and they take turns decoding a single symbol and updating
    // their state."
    FSE_init_state(dtable, &state1, src, &offset);
    FSE_init_state(dtable, &state2, src, &offset);

    // Decode until we overflow the stream
    // Since we decode in reverse order, overflowing the stream is offset going
    // negative
    size_t symbols_written = 0;
    while (1) {
        // "The number of symbols to decode is determined by tracking bitStream
        // overflow condition: If updating state after decoding a symbol would
        // require more bits than remain in the stream, it is assumed the extra
        // bits are 0. Then, the symbols for each of the final states are
        // decoded and the process is complete."
        IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
        symbols_written++;
        if (offset < 0) {
            // There's still a symbol to decode in state2
            IO_write_byte(out, FSE_peek_symbol(dtable, state2));
            symbols_written++;
            break;
        }

        IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
        symbols_written++;
        if (offset < 0) {
            // There's still a symbol to decode in state1
            IO_write_byte(out, FSE_peek_symbol(dtable, state1));
            symbols_written++;
            break;
        }
    }

    return symbols_written;
}
2108*01826a49SYabin Cui 
/// Allocates and fills an FSE decoding table of `1 << accuracy_log` states
/// from the normalized frequencies `norm_freqs[0 .. num_symbs - 1]`.
/// A frequency of -1 marks a "less than 1" probability symbol; frequencies
/// of 0 mark absent symbols. The three arrays stored in `dtable` are
/// allocated with malloc.
static void FSE_init_dtable(FSE_dtable *const dtable,
                            const i16 *const norm_freqs, const int num_symbs,
                            const int accuracy_log) {
    if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
        ERROR("FSE accuracy too large");
    }
    if (num_symbs > FSE_MAX_SYMBS) {
        ERROR("Too many symbols for FSE");
    }

    const size_t size = (size_t)1 << accuracy_log;

    dtable->accuracy_log = accuracy_log;
    dtable->symbols = malloc(size * sizeof(u8));
    dtable->num_bits = malloc(size * sizeof(u8));
    dtable->new_state_base = malloc(size * sizeof(u16));

    if (dtable->symbols == NULL || dtable->num_bits == NULL ||
        dtable->new_state_base == NULL) {
        BAD_ALLOC();
    }

    // Per-symbol counter that drives both the bit count and the baseline of
    // each state. It has to be u16: it can reach twice the symbol's frequency,
    // which does not fit in a byte
    u16 state_desc[FSE_MAX_SYMBS];

    // "Symbols are scanned in their natural order for "less than 1"
    // probabilities. Symbols with this probability are being attributed a
    // single cell, starting from the end of the table. These symbols define a
    // full state reset, reading Accuracy_Log bits."
    int high_threshold = size;
    for (int s = 0; s < num_symbs; s++) {
        if (norm_freqs[s] != -1) {
            continue;
        }
        // Low probability symbols are stacked downwards from the table's end
        high_threshold--;
        dtable->symbols[high_threshold] = s;
        state_desc[s] = 1;
    }

    // "All remaining symbols are sorted in their natural order. Starting from
    // symbol 0 and table position 0, each symbol gets attributed as many cells
    // as its probability. Cell allocation is spread, not linear."
    const u16 step = (size >> 1) + (size >> 3) + 3;
    const u16 mask = size - 1;
    u16 pos = 0;
    for (int s = 0; s < num_symbs; s++) {
        const i16 freq = norm_freqs[s];
        if (freq > 0) {
            state_desc[s] = freq;

            // Symbol s receives `freq` cells
            for (int cell = 0; cell < freq; cell++) {
                dtable->symbols[pos] = s;
                // "A position is skipped if already occupied, typically by a
                // "less than 1" probability symbol."
                // Only that region needs checking: `step` is coprime to
                // `size`, so the walk visits every position exactly once
                pos = (pos + step) & mask;
                while (pos >= high_threshold) {
                    pos = (pos + step) & mask;
                }
            }
        }
    }
    // The walk has to end exactly where it began
    if (pos != 0) {
        CORRUPTION();
    }

    // With every cell assigned a symbol, derive bit counts and baselines
    for (size_t i = 0; i < size; i++) {
        const u8 symbol = dtable->symbols[i];
        // Take the symbol's counter for this state, then bump it for the next
        // state belonging to the same symbol
        const u16 next_state_desc = state_desc[symbol];
        state_desc[symbol] = next_state_desc + 1;

        // As the counter grows for a symbol, fewer bits are read per state
        const u8 nb = (u8)(accuracy_log - highest_set_bit(next_state_desc));
        dtable->num_bits[i] = nb;
        // The baseline climbs until the bit count drops, where it restarts
        // from 0
        dtable->new_state_base[i] = ((u16)next_state_desc << nb) - size;
    }
}
2194*01826a49SYabin Cui 
2195*01826a49SYabin Cui /// Decode an FSE header as defined in the Zstandard format specification and
2196*01826a49SYabin Cui /// use the decoded frequencies to initialize a decoding table.
FSE_decode_header(FSE_dtable * const dtable,istream_t * const in,const int max_accuracy_log)2197*01826a49SYabin Cui static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2198*01826a49SYabin Cui                                 const int max_accuracy_log) {
2199*01826a49SYabin Cui     // "An FSE distribution table describes the probabilities of all symbols
2200*01826a49SYabin Cui     // from 0 to the last present one (included) on a normalized scale of 1 <<
2201*01826a49SYabin Cui     // Accuracy_Log .
2202*01826a49SYabin Cui     //
2203*01826a49SYabin Cui     // It's a bitstream which is read forward, in little-endian fashion. It's
2204*01826a49SYabin Cui     // not necessary to know its exact size, since it will be discovered and
2205*01826a49SYabin Cui     // reported by the decoding process.
2206*01826a49SYabin Cui     if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2207*01826a49SYabin Cui         ERROR("FSE accuracy too large");
2208*01826a49SYabin Cui     }
2209*01826a49SYabin Cui 
2210*01826a49SYabin Cui     // The bitstream starts by reporting on which scale it operates.
2211*01826a49SYabin Cui     // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2212*01826a49SYabin Cui     // and match lengths is 9, and for offsets is 8. Higher values are
2213*01826a49SYabin Cui     // considered errors."
2214*01826a49SYabin Cui     const int accuracy_log = 5 + IO_read_bits(in, 4);
2215*01826a49SYabin Cui     if (accuracy_log > max_accuracy_log) {
2216*01826a49SYabin Cui         ERROR("FSE accuracy too large");
2217*01826a49SYabin Cui     }
2218*01826a49SYabin Cui 
2219*01826a49SYabin Cui     // "Then follows each symbol value, from 0 to last present one. The number
2220*01826a49SYabin Cui     // of bits used by each field is variable. It depends on :
2221*01826a49SYabin Cui     //
2222*01826a49SYabin Cui     // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2223*01826a49SYabin Cui     // and presuming 100 probabilities points have already been distributed, the
2224*01826a49SYabin Cui     // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2225*01826a49SYabin Cui     // Therefore, it must read log2sup(156) == 8 bits.
2226*01826a49SYabin Cui     //
2227*01826a49SYabin Cui     // Value decoded : small values use 1 less bit : example : Presuming values
2228*01826a49SYabin Cui     // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2229*01826a49SYabin Cui     // in an 8-bits field. They are used this way : first 99 values (hence from
2230*01826a49SYabin Cui     // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2231*01826a49SYabin Cui 
2232*01826a49SYabin Cui     i32 remaining = 1 << accuracy_log;
2233*01826a49SYabin Cui     i16 frequencies[FSE_MAX_SYMBS];
2234*01826a49SYabin Cui 
2235*01826a49SYabin Cui     int symb = 0;
2236*01826a49SYabin Cui     while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2237*01826a49SYabin Cui         // Log of the number of possible values we could read
2238*01826a49SYabin Cui         int bits = highest_set_bit(remaining + 1) + 1;
2239*01826a49SYabin Cui 
2240*01826a49SYabin Cui         u16 val = IO_read_bits(in, bits);
2241*01826a49SYabin Cui 
2242*01826a49SYabin Cui         // Try to mask out the lower bits to see if it qualifies for the "small
2243*01826a49SYabin Cui         // value" threshold
2244*01826a49SYabin Cui         const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2245*01826a49SYabin Cui         const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2246*01826a49SYabin Cui 
2247*01826a49SYabin Cui         if ((val & lower_mask) < threshold) {
2248*01826a49SYabin Cui             IO_rewind_bits(in, 1);
2249*01826a49SYabin Cui             val = val & lower_mask;
2250*01826a49SYabin Cui         } else if (val > lower_mask) {
2251*01826a49SYabin Cui             val = val - threshold;
2252*01826a49SYabin Cui         }
2253*01826a49SYabin Cui 
2254*01826a49SYabin Cui         // "Probability is obtained from Value decoded by following formula :
2255*01826a49SYabin Cui         // Proba = value - 1"
2256*01826a49SYabin Cui         const i16 proba = (i16)val - 1;
2257*01826a49SYabin Cui 
2258*01826a49SYabin Cui         // "It means value 0 becomes negative probability -1. -1 is a special
2259*01826a49SYabin Cui         // probability, which means "less than 1". Its effect on distribution
2260*01826a49SYabin Cui         // table is described in next paragraph. For the purpose of calculating
2261*01826a49SYabin Cui         // cumulated distribution, it counts as one."
2262*01826a49SYabin Cui         remaining -= proba < 0 ? -proba : proba;
2263*01826a49SYabin Cui 
2264*01826a49SYabin Cui         frequencies[symb] = proba;
2265*01826a49SYabin Cui         symb++;
2266*01826a49SYabin Cui 
2267*01826a49SYabin Cui         // "When a symbol has a probability of zero, it is followed by a 2-bits
2268*01826a49SYabin Cui         // repeat flag. This repeat flag tells how many probabilities of zeroes
2269*01826a49SYabin Cui         // follow the current one. It provides a number ranging from 0 to 3. If
2270*01826a49SYabin Cui         // it is a 3, another 2-bits repeat flag follows, and so on."
2271*01826a49SYabin Cui         if (proba == 0) {
2272*01826a49SYabin Cui             // Read the next two bits to see how many more 0s
2273*01826a49SYabin Cui             int repeat = IO_read_bits(in, 2);
2274*01826a49SYabin Cui 
2275*01826a49SYabin Cui             while (1) {
2276*01826a49SYabin Cui                 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2277*01826a49SYabin Cui                     frequencies[symb++] = 0;
2278*01826a49SYabin Cui                 }
2279*01826a49SYabin Cui                 if (repeat == 3) {
2280*01826a49SYabin Cui                     repeat = IO_read_bits(in, 2);
2281*01826a49SYabin Cui                 } else {
2282*01826a49SYabin Cui                     break;
2283*01826a49SYabin Cui                 }
2284*01826a49SYabin Cui             }
2285*01826a49SYabin Cui         }
2286*01826a49SYabin Cui     }
2287*01826a49SYabin Cui     IO_align_stream(in);
2288*01826a49SYabin Cui 
2289*01826a49SYabin Cui     // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2290*01826a49SYabin Cui     // is complete. If the last symbol makes cumulated total go above 1 <<
2291*01826a49SYabin Cui     // Accuracy_Log, distribution is considered corrupted."
2292*01826a49SYabin Cui     if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2293*01826a49SYabin Cui         CORRUPTION();
2294*01826a49SYabin Cui     }
2295*01826a49SYabin Cui 
2296*01826a49SYabin Cui     // Initialize the decoding table using the determined weights
2297*01826a49SYabin Cui     FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2298*01826a49SYabin Cui }
2299*01826a49SYabin Cui 
/// Builds a one-entry FSE decoding table whose only state is 0: it always
/// yields `symb` and reads 0 bits. The three single-element arrays are
/// allocated with malloc.
static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
    dtable->accuracy_log = 0;
    dtable->symbols = malloc(sizeof(u8));
    dtable->num_bits = malloc(sizeof(u8));
    dtable->new_state_base = malloc(sizeof(u16));

    if (dtable->symbols == NULL || dtable->num_bits == NULL ||
        dtable->new_state_base == NULL) {
        BAD_ALLOC();
    }

    // State 0 maps to `symb`, consumes no bits and leads back to state 0
    *dtable->symbols = symb;
    *dtable->num_bits = 0;
    *dtable->new_state_base = 0;
}
2316*01826a49SYabin Cui 
/// Frees the three arrays of an FSE decoding table and then zeroes the whole
/// struct, leaving all of its pointers NULL.
static void FSE_free_dtable(FSE_dtable *const dtable) {
    free(dtable->new_state_base);
    free(dtable->num_bits);
    free(dtable->symbols);
    memset(dtable, 0, sizeof(*dtable));
}
2323*01826a49SYabin Cui /******* END FSE PRIMITIVES ***************************************************/
2324