1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zstd
6
7import (
8	"io"
9)
10
// debug can be set in the source to print debug info using println.
// It is a compile-time constant; when true, execSeqs prints the literal,
// offset and match values of every decoded sequence.
const debug = false
13
14// compressedBlock decompresses a compressed block, storing the decompressed
15// data in r.buffer. The blockSize argument is the compressed size.
16// RFC 3.1.1.3.
17func (r *Reader) compressedBlock(blockSize int) error {
18	if len(r.compressedBuf) >= blockSize {
19		r.compressedBuf = r.compressedBuf[:blockSize]
20	} else {
21		// We know that blockSize <= 128K,
22		// so this won't allocate an enormous amount.
23		need := blockSize - len(r.compressedBuf)
24		r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
25	}
26
27	if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
28		return r.wrapNonEOFError(0, err)
29	}
30
31	data := block(r.compressedBuf)
32	off := 0
33	r.buffer = r.buffer[:0]
34
35	litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
36	if err != nil {
37		return err
38	}
39	r.literals = litbuf
40
41	off = litoff
42
43	seqCount, off, err := r.initSeqs(data, off)
44	if err != nil {
45		return err
46	}
47
48	if seqCount == 0 {
49		// No sequences, just literals.
50		if off < len(data) {
51			return r.makeError(off, "extraneous data after no sequences")
52		}
53
54		r.buffer = append(r.buffer, litbuf...)
55
56		return nil
57	}
58
59	return r.execSeqs(data, off, litbuf, seqCount)
60}
61
// seqCode is the kind of sequence codes we have to handle.
// Its values are used as indexes into seqCodeInfo and into the
// per-kind tables stored in the Reader (r.seqTables, r.seqTableBits,
// r.seqTableBuffers).
type seqCode int

// The three kinds of sequence code, numbered 0 through 2.
const (
	seqLiteral seqCode = iota
	seqOffset
	seqMatch
)
70
// seqCodeInfoData is the information needed to set up seqTables and
// seqTableBits for a particular kind of sequence code.
// setSeqTable reads one of these for the kind it is configuring.
type seqCodeInfoData struct {
	predefTable     []fseBaselineEntry // predefined FSE
	predefTableBits int                // number of bits in predefTable
	maxSym          int                // max symbol value in FSE
	maxBits         int                // max bits for FSE

	// toBaseline converts from an FSE table to an FSE baseline table.
	// setSeqTable calls it with the Reader, the current offset in the
	// block, the source FSE entries, and the destination baseline slice.
	toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
}
82
// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
// It is indexed by seqCode; setSeqTable looks up &seqCodeInfo[kind].
var seqCodeInfo = [3]seqCodeInfoData{
	seqLiteral: {
		predefTable:     predefinedLiteralTable[:],
		predefTableBits: 6,
		maxSym:          35,
		maxBits:         9,
		toBaseline:      (*Reader).makeLiteralBaselineFSE,
	},
	seqOffset: {
		predefTable:     predefinedOffsetTable[:],
		predefTableBits: 5,
		maxSym:          31,
		maxBits:         8,
		toBaseline:      (*Reader).makeOffsetBaselineFSE,
	},
	seqMatch: {
		predefTable:     predefinedMatchTable[:],
		predefTableBits: 6,
		maxSym:          52,
		maxBits:         9,
		toBaseline:      (*Reader).makeMatchBaselineFSE,
	},
}
107
108// initSeqs reads the Sequences_Section_Header and sets up the FSE
109// tables used to read the sequence codes. It returns the number of
110// sequences and the new offset. RFC 3.1.1.3.2.1.
111func (r *Reader) initSeqs(data block, off int) (int, int, error) {
112	if off >= len(data) {
113		return 0, 0, r.makeEOFError(off)
114	}
115
116	seqHdr := data[off]
117	off++
118	if seqHdr == 0 {
119		return 0, off, nil
120	}
121
122	var seqCount int
123	if seqHdr < 128 {
124		seqCount = int(seqHdr)
125	} else if seqHdr < 255 {
126		if off >= len(data) {
127			return 0, 0, r.makeEOFError(off)
128		}
129		seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
130		off++
131	} else {
132		if off+1 >= len(data) {
133			return 0, 0, r.makeEOFError(off)
134		}
135		seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
136		off += 2
137	}
138
139	// Read the Symbol_Compression_Modes byte.
140
141	if off >= len(data) {
142		return 0, 0, r.makeEOFError(off)
143	}
144	symMode := data[off]
145	if symMode&3 != 0 {
146		return 0, 0, r.makeError(off, "invalid symbol compression mode")
147	}
148	off++
149
150	// Set up the FSE tables used to decode the sequence codes.
151
152	var err error
153	off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
154	if err != nil {
155		return 0, 0, err
156	}
157
158	off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
159	if err != nil {
160		return 0, 0, err
161	}
162
163	off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
164	if err != nil {
165		return 0, 0, err
166	}
167
168	return seqCount, off, nil
169}
170
171// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
172// r.seqTableBits for kind. We store these in the Reader because one of
173// the modes simply reuses the value from the last block in the frame.
174func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
175	info := &seqCodeInfo[kind]
176	switch mode {
177	case 0:
178		// Predefined_Mode
179		r.seqTables[kind] = info.predefTable
180		r.seqTableBits[kind] = uint8(info.predefTableBits)
181		return off, nil
182
183	case 1:
184		// RLE_Mode
185		if off >= len(data) {
186			return 0, r.makeEOFError(off)
187		}
188		rle := data[off]
189		off++
190
191		// Build a simple baseline table that always returns rle.
192
193		entry := []fseEntry{
194			{
195				sym:  rle,
196				bits: 0,
197				base: 0,
198			},
199		}
200		if cap(r.seqTableBuffers[kind]) == 0 {
201			r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
202		}
203		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
204		if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
205			return 0, err
206		}
207
208		r.seqTables[kind] = r.seqTableBuffers[kind]
209		r.seqTableBits[kind] = 0
210		return off, nil
211
212	case 2:
213		// FSE_Compressed_Mode
214		if cap(r.fseScratch) < 1<<info.maxBits {
215			r.fseScratch = make([]fseEntry, 1<<info.maxBits)
216		}
217		r.fseScratch = r.fseScratch[:1<<info.maxBits]
218
219		tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
220		if err != nil {
221			return 0, err
222		}
223		r.fseScratch = r.fseScratch[:1<<tableBits]
224
225		if cap(r.seqTableBuffers[kind]) == 0 {
226			r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
227		}
228		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
229
230		if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
231			return 0, err
232		}
233
234		r.seqTables[kind] = r.seqTableBuffers[kind]
235		r.seqTableBits[kind] = uint8(tableBits)
236		return roff, nil
237
238	case 3:
239		// Repeat_Mode
240		if len(r.seqTables[kind]) == 0 {
241			return 0, r.makeError(off, "missing repeat sequence FSE table")
242		}
243		return off, nil
244	}
245	panic("unreachable")
246}
247
// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
//
// data is the compressed block and off is the offset that initSeqs
// returned after the sequences header. litbuf holds the decoded literal
// bytes and seqCount is the number of sequences to execute. The decoded
// output is appended to r.buffer. The FSE tables in r.seqTables and
// r.seqTableBits must already have been set up.
func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
	// Set up the initial states for the sequence code readers.

	// NOTE(review): presumably the bit reader consumes data backward
	// from its last byte down to off — confirm in makeReverseBitReader.
	rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
	if err != nil {
		return err
	}

	// The initial states are read in the order literal, offset, match,
	// each using the table bit count for its kind.
	literalState, err := rbr.val(r.seqTableBits[seqLiteral])
	if err != nil {
		return err
	}

	offsetState, err := rbr.val(r.seqTableBits[seqOffset])
	if err != nil {
		return err
	}

	matchState, err := rbr.val(r.seqTableBits[seqMatch])
	if err != nil {
		return err
	}

	// Read and perform all the sequences. RFC 3.1.1.4.

	seq := 0
	for seq < seqCount {
		// Fail if the output produced so far plus the literals still
		// to be emitted would exceed 128K.
		if len(r.buffer)+len(litbuf) > 128<<10 {
			return rbr.makeError("uncompressed size too big")
		}

		// Look up the baseline table entry for each current state.
		ptoffset := &r.seqTables[seqOffset][offsetState]
		ptmatch := &r.seqTables[seqMatch][matchState]
		ptliteral := &r.seqTables[seqLiteral][literalState]

		// Each value is the entry's baseline plus basebits extra bits.
		// The extra bits are read in the order offset, match, literal.
		add, err := rbr.val(ptoffset.basebits)
		if err != nil {
			return err
		}
		offset := ptoffset.baseline + add

		add, err = rbr.val(ptmatch.basebits)
		if err != nil {
			return err
		}
		match := ptmatch.baseline + add

		add, err = rbr.val(ptliteral.basebits)
		if err != nil {
			return err
		}
		literal := ptliteral.baseline + add

		// Handle repeat offsets. RFC 3.1.1.5.
		// See the comment in makeOffsetBaselineFSE.
		if ptoffset.basebits > 1 {
			// A new offset: push it onto the front of the
			// three-entry repeated-offset history.
			r.repeatedOffset3 = r.repeatedOffset2
			r.repeatedOffset2 = r.repeatedOffset1
			r.repeatedOffset1 = offset
		} else {
			// The value selects an entry in the repeated-offset
			// history. With no literals in this sequence the
			// selector is shifted up by one.
			if literal == 0 {
				offset++
			}
			switch offset {
			case 1:
				// Most recent offset; the history is unchanged.
				offset = r.repeatedOffset1
			case 2:
				// Second offset; swap it with the first.
				offset = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			case 3:
				// Third offset; move it to the front.
				offset = r.repeatedOffset3
				r.repeatedOffset3 = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			case 4:
				// Most recent offset minus one, pushed as a new
				// entry. Only reachable via the increment above.
				offset = r.repeatedOffset1 - 1
				r.repeatedOffset3 = r.repeatedOffset2
				r.repeatedOffset2 = r.repeatedOffset1
				r.repeatedOffset1 = offset
			}
		}

		seq++
		if seq < seqCount {
			// Update the states.
			// No state bits are read after the last sequence.
			// The order here is literal, match, offset.
			add, err = rbr.val(ptliteral.bits)
			if err != nil {
				return err
			}
			literalState = uint32(ptliteral.base) + add

			add, err = rbr.val(ptmatch.bits)
			if err != nil {
				return err
			}
			matchState = uint32(ptmatch.base) + add

			add, err = rbr.val(ptoffset.bits)
			if err != nil {
				return err
			}
			offsetState = uint32(ptoffset.base) + add
		}

		// The next sequence is now in literal, offset, match.

		if debug {
			println("literal", literal, "offset", offset, "match", match)
		}

		// Copy literal bytes from litbuf.
		if literal > uint32(len(litbuf)) {
			return rbr.makeError("literal byte overflow")
		}
		if literal > 0 {
			r.buffer = append(r.buffer, litbuf[:literal]...)
			litbuf = litbuf[literal:]
		}

		// Then copy match bytes from earlier output.
		if match > 0 {
			if err := r.copyFromWindow(&rbr, offset, match); err != nil {
				return err
			}
		}
	}

	// Any literals not consumed by a sequence go at the end.
	r.buffer = append(r.buffer, litbuf...)

	// NOTE(review): rbr.cnt is presumably the number of bits not yet
	// consumed; a nonzero value is treated as trailing garbage — confirm.
	if rbr.cnt != 0 {
		return r.makeError(off, "extraneous data after sequences")
	}

	return nil
}
384
385// Copy match bytes from the decoded output, or the window, at offset.
386func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
387	if offset == 0 {
388		return rbr.makeError("invalid zero offset")
389	}
390
391	// Offset may point into the buffer or the window and
392	// match may extend past the end of the initial buffer.
393	// |--r.window--|--r.buffer--|
394	//        |<-----offset------|
395	//        |------match----------->|
396	bufferOffset := uint32(0)
397	lenBlock := uint32(len(r.buffer))
398	if lenBlock < offset {
399		lenWindow := r.window.len()
400		copy := offset - lenBlock
401		if copy > lenWindow {
402			return rbr.makeError("offset past window")
403		}
404		windowOffset := lenWindow - copy
405		if copy > match {
406			copy = match
407		}
408		r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy)
409		match -= copy
410	} else {
411		bufferOffset = lenBlock - offset
412	}
413
414	// We are being asked to copy data that we are adding to the
415	// buffer in the same copy.
416	for match > 0 {
417		copy := uint32(len(r.buffer)) - bufferOffset
418		if copy > match {
419			copy = match
420		}
421		r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...)
422		match -= copy
423	}
424	return nil
425}
426