1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package zstd provides a decompressor for zstd streams,
6// described in RFC 8878. It does not support dictionaries.
7package zstd
8
9import (
10	"encoding/binary"
11	"errors"
12	"fmt"
13	"io"
14)
15
// fuzzing is a hook that fuzz tests flip to true. While it is set, the
// decoder additionally rejects inputs that the reference zstd implementation
// refuses (for example over-large window sizes), so fuzz comparisons against
// zstd agree. It is false by default.
var fuzzing bool
19
// Reader implements [io.Reader] to read a zstd compressed stream.
//
// The stream is decoded one block at a time into buffer; Read and ReadByte
// return bytes from buffer[off:] and decode the next block once the current
// one is exhausted. Use NewReader to create a Reader, or Reset to reuse one.
type Reader struct {
	// The underlying Reader.
	r io.Reader

	// Whether we have read the frame header.
	// This is of interest when buffer is empty.
	// If true we expect to see a new block.
	// It is cleared after the last block of a frame, so the next refill
	// starts by reading a new frame header.
	sawFrameHeader bool

	// Whether the current frame expects a checksum.
	// When set, every decoded block is fed to checksum, and the 4-byte
	// value following the frame's last block is compared against it.
	hasChecksum bool

	// Whether we have read at least one frame (skippable frames count).
	// A stream that ends cleanly before any frame is reported as
	// io.ErrUnexpectedEOF rather than io.EOF.
	readOneFrame bool

	// True if the frame size is not known.
	// That is, the frame header carried no Frame_Content_Size field.
	frameSizeUnknown bool

	// The number of uncompressed bytes remaining in the current frame.
	// If frameSizeUnknown is true, this is not valid.
	remainingFrameSize uint64

	// The number of bytes read from r up to the start of the current
	// block, for error reporting.
	blockOffset int64

	// Buffered decompressed data.
	buffer []byte
	// Current read offset in buffer.
	off int

	// The current repeated offsets.
	// They are set to 1, 4 and 8 at the start of every frame.
	repeatedOffset1 uint32
	repeatedOffset2 uint32
	repeatedOffset3 uint32

	// The current Huffman table used for decoding literals.
	// huffmanTableBits is set to 0 at the start of every frame;
	// NOTE(review): presumably 0 means "no table yet" — confirm.
	huffmanTable     []uint16
	huffmanTableBits int

	// The window for back references.
	// It is reset at each frame header to the frame's window size,
	// capped at 8M.
	window window

	// A buffer available to hold a compressed block.
	compressedBuf []byte

	// A buffer for literals.
	literals []byte

	// Sequence decode FSE tables.
	// They are cleared to nil at the start of every frame.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// Buffers for sequence decode FSE tables.
	// They are kept across Reset to avoid reallocation.
	// NOTE(review): presumably the backing storage for seqTables — confirm.
	seqTableBuffers [3][]fseBaselineEntry

	// Scratch space used for small reads, to avoid allocation.
	// The optional frame header fields read in one go by readFrameHeader
	// take at most 13 bytes.
	scratch [16]byte

	// A scratch table for reading an FSE. Only temporarily valid.
	fseScratch []fseEntry

	// For checksum computation.
	// It is reset at each frame header whose frame carries a checksum.
	checksum xxhash64
}
86
87// NewReader creates a new Reader that decompresses data from the given reader.
88func NewReader(input io.Reader) *Reader {
89	r := new(Reader)
90	r.Reset(input)
91	return r
92}
93
94// Reset discards the current state and starts reading a new stream from r.
95// This permits reusing a Reader rather than allocating a new one.
96func (r *Reader) Reset(input io.Reader) {
97	r.r = input
98
99	// Several fields are preserved to avoid allocation.
100	// Others are always set before they are used.
101	r.sawFrameHeader = false
102	r.hasChecksum = false
103	r.readOneFrame = false
104	r.frameSizeUnknown = false
105	r.remainingFrameSize = 0
106	r.blockOffset = 0
107	r.buffer = r.buffer[:0]
108	r.off = 0
109	// repeatedOffset1
110	// repeatedOffset2
111	// repeatedOffset3
112	// huffmanTable
113	// huffmanTableBits
114	// window
115	// compressedBuf
116	// literals
117	// seqTables
118	// seqTableBits
119	// seqTableBuffers
120	// scratch
121	// fseScratch
122}
123
124// Read implements [io.Reader].
125func (r *Reader) Read(p []byte) (int, error) {
126	if err := r.refillIfNeeded(); err != nil {
127		return 0, err
128	}
129	n := copy(p, r.buffer[r.off:])
130	r.off += n
131	return n, nil
132}
133
134// ReadByte implements [io.ByteReader].
135func (r *Reader) ReadByte() (byte, error) {
136	if err := r.refillIfNeeded(); err != nil {
137		return 0, err
138	}
139	ret := r.buffer[r.off]
140	r.off++
141	return ret, nil
142}
143
144// refillIfNeeded reads the next block if necessary.
145func (r *Reader) refillIfNeeded() error {
146	for r.off >= len(r.buffer) {
147		if err := r.refill(); err != nil {
148			return err
149		}
150		r.off = 0
151	}
152	return nil
153}
154
155// refill reads and decompresses the next block.
156func (r *Reader) refill() error {
157	if !r.sawFrameHeader {
158		if err := r.readFrameHeader(); err != nil {
159			return err
160		}
161	}
162	return r.readBlock()
163}
164
// readFrameHeader reads the frame header and prepares to read a block.
//
// It consumes the 4-byte magic number and the Frame_Header (RFC 3.1.1.1):
// the Frame_Header_Descriptor byte, then the optional Window_Descriptor,
// Dictionary_ID and Frame_Content_Size fields. Skippable frames found in
// place of a zstd frame are skipped with skipFrame, and the next magic
// number is tried. On success r.blockOffset is advanced past the header and
// the per-frame decoding state (repeated offsets, Huffman table bits,
// window, sequence tables) is reinitialized.
//
// A clean io.EOF before the magic number is returned as io.EOF only once at
// least one frame has been read; otherwise it becomes io.ErrUnexpectedEOF,
// as does an EOF anywhere later in the header.
func (r *Reader) readFrameHeader() error {
retry:
	relativeOffset := 0

	// Read magic number. RFC 3.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		// We require that the stream contains at least one frame.
		if err == io.EOF && !r.readOneFrame {
			err = io.ErrUnexpectedEOF
		}
		return r.wrapError(relativeOffset, err)
	}

	// 0xfd2fb528 is the zstd frame magic number; 0x184d2a50 through
	// 0x184d2a5f introduce skippable frames (RFC 3.1.2).
	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
			// This is a skippable frame.
			r.blockOffset += int64(relativeOffset) + 4
			if err := r.skipFrame(); err != nil {
				return err
			}
			// A skippable frame counts as a frame, so a clean EOF
			// after it is not an error.
			r.readOneFrame = true
			goto retry
		}

		return r.makeError(relativeOffset, "invalid magic number")
	}

	relativeOffset += 4

	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	descriptor := r.scratch[0]

	// Bit 5: Single_Segment_Flag.
	singleSegment := descriptor&(1<<5) != 0

	// Bits 6-7: Frame_Content_Size_Flag, selecting a 1, 2, 4 or 8 byte
	// field. Flag value 0 means a 1-byte field for single-segment frames
	// and no field at all otherwise.
	fcsFieldSize := 1 << (descriptor >> 6)
	if fcsFieldSize == 1 && !singleSegment {
		fcsFieldSize = 0
	}

	// The one-byte Window_Descriptor is present only when the frame is
	// not single-segment.
	var windowDescriptorSize int
	if singleSegment {
		windowDescriptorSize = 0
	} else {
		windowDescriptorSize = 1
	}

	// Bit 3 is reserved and must be zero. (Bit 4 is not checked.)
	if descriptor&(1<<3) != 0 {
		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
	}

	// Bit 2: Content_Checksum_Flag.
	r.hasChecksum = descriptor&(1<<2) != 0
	if r.hasChecksum {
		r.checksum.reset()
	}

	// Dictionary_ID_Flag. RFC 3.1.1.1.1.6.
	// Bits 0-1 select a Dictionary_ID field of 0, 1, 2 or 4 bytes.
	dictionaryIdSize := 0
	if dictIdFlag := descriptor & 3; dictIdFlag != 0 {
		dictionaryIdSize = 1 << (dictIdFlag - 1)
	}

	relativeOffset++

	// Read all remaining optional header fields at once; together they
	// are at most 1 + 4 + 8 = 13 bytes, which fits in r.scratch.
	headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize

	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	// Figure out the maximum amount of data we need to retain
	// for backreferences.
	var windowSize uint64
	if !singleSegment {
		// Window descriptor. RFC 3.1.1.1.2.
		// Window_Size = 2^(Exponent+10) + (2^(Exponent+10) / 8) * Mantissa.
		windowDescriptor := r.scratch[0]
		exponent := uint64(windowDescriptor >> 3)
		mantissa := uint64(windowDescriptor & 7)
		windowLog := exponent + 10
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * mantissa
		windowSize = windowBase + windowAdd

		// Default zstd sets limits on the window size.
		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
			return r.makeError(relativeOffset, "windowSize too large")
		}
	}

	// Dictionary_ID. RFC 3.1.1.1.3.
	if dictionaryIdSize != 0 {
		dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize]
		// Allow only zero Dictionary ID.
		for _, b := range dictionaryId {
			if b != 0 {
				return r.makeError(relativeOffset, "dictionaries are not supported")
			}
		}
	}

	// Frame_Content_Size. RFC 3.1.1.1.4.
	// All sizes are little-endian; the 2-byte form is offset by 256.
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	fb := r.scratch[windowDescriptorSize+dictionaryIdSize:]
	switch fcsFieldSize {
	case 0:
		r.frameSizeUnknown = true
	case 1:
		r.remainingFrameSize = uint64(fb[0])
	case 2:
		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
	case 4:
		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
	case 8:
		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
	default:
		panic("unreachable")
	}

	// RFC 3.1.1.1.2.
	// When Single_Segment_Flag is set, Window_Descriptor is not present.
	// In this case, Window_Size is Frame_Content_Size.
	if singleSegment {
		windowSize = r.remainingFrameSize
	}

	// RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size.
	const maxWindowSize = 8 << 20
	if windowSize > maxWindowSize {
		windowSize = maxWindowSize
	}

	relativeOffset += headerSize

	r.sawFrameHeader = true
	r.readOneFrame = true
	r.blockOffset += int64(relativeOffset)

	// Prepare to read blocks from the frame.
	// Each frame starts with repeated offsets 1, 4, 8.
	r.repeatedOffset1 = 1
	r.repeatedOffset2 = 4
	r.repeatedOffset3 = 8
	// NOTE(review): presumably marks that no Huffman table from a previous
	// frame may be reused — confirm against the literals decoder.
	r.huffmanTableBits = 0
	r.window.reset(int(windowSize))
	// NOTE(review): presumably prevents repeat-mode sequence tables from
	// reusing a previous frame's tables — confirm.
	r.seqTables[0] = nil
	r.seqTables[1] = nil
	r.seqTables[2] = nil

	return nil
}
318
319// skipFrame skips a skippable frame. RFC 3.1.2.
320func (r *Reader) skipFrame() error {
321	relativeOffset := 0
322
323	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
324		return r.wrapNonEOFError(relativeOffset, err)
325	}
326
327	relativeOffset += 4
328
329	size := binary.LittleEndian.Uint32(r.scratch[:4])
330	if size == 0 {
331		r.blockOffset += int64(relativeOffset)
332		return nil
333	}
334
335	if seeker, ok := r.r.(io.Seeker); ok {
336		r.blockOffset += int64(relativeOffset)
337		// Implementations of Seeker do not always detect invalid offsets,
338		// so check that the new offset is valid by comparing to the end.
339		prev, err := seeker.Seek(0, io.SeekCurrent)
340		if err != nil {
341			return r.wrapError(0, err)
342		}
343		end, err := seeker.Seek(0, io.SeekEnd)
344		if err != nil {
345			return r.wrapError(0, err)
346		}
347		if prev > end-int64(size) {
348			r.blockOffset += end - prev
349			return r.makeEOFError(0)
350		}
351
352		// The new offset is valid, so seek to it.
353		_, err = seeker.Seek(prev+int64(size), io.SeekStart)
354		if err != nil {
355			return r.wrapError(0, err)
356		}
357		r.blockOffset += int64(size)
358		return nil
359	}
360
361	var skip []byte
362	const chunk = 1 << 20 // 1M
363	for size >= chunk {
364		if len(skip) == 0 {
365			skip = make([]byte, chunk)
366		}
367		if _, err := io.ReadFull(r.r, skip); err != nil {
368			return r.wrapNonEOFError(relativeOffset, err)
369		}
370		relativeOffset += chunk
371		size -= chunk
372	}
373	if size > 0 {
374		if len(skip) == 0 {
375			skip = make([]byte, size)
376		}
377		if _, err := io.ReadFull(r.r, skip); err != nil {
378			return r.wrapNonEOFError(relativeOffset, err)
379		}
380		relativeOffset += int(size)
381	}
382
383	r.blockOffset += int64(relativeOffset)
384
385	return nil
386}
387
// readBlock reads the next block from a frame.
//
// It reads the 3-byte Block_Header and then produces the block's decoded
// bytes in r.buffer:
//   - raw blocks are copied from the input;
//   - RLE blocks repeat a single byte;
//   - compressed blocks are handed to compressedBlock, which is defined
//     elsewhere and presumably fills r.buffer (confirm).
//
// The decoded bytes are then checked against the declared frame content
// size and fed to the checksum. Unless this is the frame's last block, they
// are also saved in the window. After the last block, the frame checksum
// (if any) is verified and sawFrameHeader is cleared, so the next refill
// starts a new frame.
//
// NOTE(review): the two frame-size errors after the switch report
// relativeOffset on top of an r.blockOffset that has already been advanced
// past the block, so their reported offsets point beyond the block rather
// than at its header.
func (r *Reader) readBlock() error {
	relativeOffset := 0

	// Read Block_Header. RFC 3.1.1.2.
	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	relativeOffset += 3

	// The header is a 24-bit little-endian value: bit 0 is Last_Block,
	// bits 1-2 are Block_Type and bits 3-23 are Block_Size.
	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)

	lastBlock := header&1 != 0
	blockType := (header >> 1) & 3
	blockSize := int(header >> 3)

	// Maximum block size is the smaller of the window size and 128K.
	// RFC 3.1.1.2.3, 3.1.1.2.4. A window size of 0 disables the window
	// bound, leaving only the 128K limit. NOTE(review): presumably this is
	// the case for a single-segment frame whose Frame_Content_Size is 0,
	// since readFrameHeader uses Frame_Content_Size as the window size for
	// single-segment frames — confirm against window.reset.
	if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) {
		return r.makeError(relativeOffset, "block size too large")
	}

	// Handle different block types. RFC 3.1.1.2.2.
	switch blockType {
	case 0:
		// Raw_Block: Block_Size bytes stored uncompressed.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += blockSize
		r.blockOffset += int64(relativeOffset)
	case 1:
		// RLE_Block: one input byte repeated Block_Size times.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset++
		v := r.scratch[0]
		for i := range r.buffer {
			r.buffer[i] = v
		}
		r.blockOffset += int64(relativeOffset)
	case 2:
		// Compressed_Block: Block_Size bytes of compressed input.
		r.blockOffset += int64(relativeOffset)
		if err := r.compressedBlock(blockSize); err != nil {
			return err
		}
		r.blockOffset += int64(blockSize)
	case 3:
		// Block_Type 3 is reserved.
		return r.makeError(relativeOffset, "invalid block type")
	}

	// Enforce the declared Frame_Content_Size, when there is one.
	if !r.frameSizeUnknown {
		if uint64(len(r.buffer)) > r.remainingFrameSize {
			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
		}
		r.remainingFrameSize -= uint64(len(r.buffer))
	}

	if r.hasChecksum {
		r.checksum.update(r.buffer)
	}

	if !lastBlock {
		// Keep this block's output available for later back references.
		// The last block is not saved, because readFrameHeader resets the
		// window for the next frame anyway.
		r.window.save(r.buffer)
	} else {
		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
			return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
		}
		// Check for checksum at end of frame. RFC 3.1.1.
		// Only the low 32 bits of the digest are stored in the stream.
		if r.hasChecksum {
			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
				return r.wrapNonEOFError(0, err)
			}

			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
			dataChecksum := uint32(r.checksum.digest())
			if inputChecksum != dataChecksum {
				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
			}

			r.blockOffset += 4
		}
		r.sawFrameHeader = false
	}

	return nil
}
478
479// setBufferSize sets the decompressed buffer size.
480// When this is called the buffer is empty.
481func (r *Reader) setBufferSize(size int) {
482	if cap(r.buffer) < size {
483		need := size - cap(r.buffer)
484		r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
485	}
486	r.buffer = r.buffer[:size]
487}
488
// zstdError is an error while decompressing. It pairs the underlying cause
// with the position in the compressed input where decoding failed.
type zstdError struct {
	offset int64 // byte offset in the compressed input
	err    error // underlying cause, returned by Unwrap
}

// Error reports the failure together with its input offset.
func (e *zstdError) Error() string {
	return fmt.Sprintf("zstd decompression error at %d: %v", e.offset, e.err)
}

// Unwrap exposes the underlying cause to errors.Is and errors.As,
// for example io.ErrUnexpectedEOF for truncated input.
func (e *zstdError) Unwrap() error {
	return e.err
}
502
503func (r *Reader) makeEOFError(off int) error {
504	return r.wrapError(off, io.ErrUnexpectedEOF)
505}
506
507func (r *Reader) wrapNonEOFError(off int, err error) error {
508	if err == io.EOF {
509		err = io.ErrUnexpectedEOF
510	}
511	return r.wrapError(off, err)
512}
513
514func (r *Reader) makeError(off int, msg string) error {
515	return r.wrapError(off, errors.New(msg))
516}
517
518func (r *Reader) wrapError(off int, err error) error {
519	if err == io.EOF {
520		return err
521	}
522	return &zstdError{r.blockOffset + int64(off), err}
523}
524