1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package zstd provides a decompressor for zstd streams, 6// described in RFC 8878. It does not support dictionaries. 7package zstd 8 9import ( 10 "encoding/binary" 11 "errors" 12 "fmt" 13 "io" 14) 15 16// fuzzing is a fuzzer hook set to true when fuzzing. 17// This is used to reject cases where we don't match zstd. 18var fuzzing = false 19 20// Reader implements [io.Reader] to read a zstd compressed stream. 21type Reader struct { 22 // The underlying Reader. 23 r io.Reader 24 25 // Whether we have read the frame header. 26 // This is of interest when buffer is empty. 27 // If true we expect to see a new block. 28 sawFrameHeader bool 29 30 // Whether the current frame expects a checksum. 31 hasChecksum bool 32 33 // Whether we have read at least one frame. 34 readOneFrame bool 35 36 // True if the frame size is not known. 37 frameSizeUnknown bool 38 39 // The number of uncompressed bytes remaining in the current frame. 40 // If frameSizeUnknown is true, this is not valid. 41 remainingFrameSize uint64 42 43 // The number of bytes read from r up to the start of the current 44 // block, for error reporting. 45 blockOffset int64 46 47 // Buffered decompressed data. 48 buffer []byte 49 // Current read offset in buffer. 50 off int 51 52 // The current repeated offsets. 53 repeatedOffset1 uint32 54 repeatedOffset2 uint32 55 repeatedOffset3 uint32 56 57 // The current Huffman tree used for compressing literals. 58 huffmanTable []uint16 59 huffmanTableBits int 60 61 // The window for back references. 62 window window 63 64 // A buffer available to hold a compressed block. 65 compressedBuf []byte 66 67 // A buffer for literals. 68 literals []byte 69 70 // Sequence decode FSE tables. 71 seqTables [3][]fseBaselineEntry 72 seqTableBits [3]uint8 73 74 // Buffers for sequence decode FSE tables. 75 seqTableBuffers [3][]fseBaselineEntry 76 77 // Scratch space used for small reads, to avoid allocation. 78 scratch [16]byte 79 80 // A scratch table for reading an FSE. Only temporarily valid. 81 fseScratch []fseEntry 82 83 // For checksum computation. 84 checksum xxhash64 85} 86 87// NewReader creates a new Reader that decompresses data from the given reader. 88func NewReader(input io.Reader) *Reader { 89 r := new(Reader) 90 r.Reset(input) 91 return r 92} 93 94// Reset discards the current state and starts reading a new stream from r. 95// This permits reusing a Reader rather than allocating a new one. 96func (r *Reader) Reset(input io.Reader) { 97 r.r = input 98 99 // Several fields are preserved to avoid allocation. 100 // Others are always set before they are used. 101 r.sawFrameHeader = false 102 r.hasChecksum = false 103 r.readOneFrame = false 104 r.frameSizeUnknown = false 105 r.remainingFrameSize = 0 106 r.blockOffset = 0 107 r.buffer = r.buffer[:0] 108 r.off = 0 109 // repeatedOffset1 110 // repeatedOffset2 111 // repeatedOffset3 112 // huffmanTable 113 // huffmanTableBits 114 // window 115 // compressedBuf 116 // literals 117 // seqTables 118 // seqTableBits 119 // seqTableBuffers 120 // scratch 121 // fseScratch 122} 123 124// Read implements [io.Reader]. 125func (r *Reader) Read(p []byte) (int, error) { 126 if err := r.refillIfNeeded(); err != nil { 127 return 0, err 128 } 129 n := copy(p, r.buffer[r.off:]) 130 r.off += n 131 return n, nil 132} 133 134// ReadByte implements [io.ByteReader]. 135func (r *Reader) ReadByte() (byte, error) { 136 if err := r.refillIfNeeded(); err != nil { 137 return 0, err 138 } 139 ret := r.buffer[r.off] 140 r.off++ 141 return ret, nil 142} 143 144// refillIfNeeded reads the next block if necessary. 145func (r *Reader) refillIfNeeded() error { 146 for r.off >= len(r.buffer) { 147 if err := r.refill(); err != nil { 148 return err 149 } 150 r.off = 0 151 } 152 return nil 153} 154 155// refill reads and decompresses the next block. 156func (r *Reader) refill() error { 157 if !r.sawFrameHeader { 158 if err := r.readFrameHeader(); err != nil { 159 return err 160 } 161 } 162 return r.readBlock() 163} 164 165// readFrameHeader reads the frame header and prepares to read a block. 166func (r *Reader) readFrameHeader() error { 167retry: 168 relativeOffset := 0 169 170 // Read magic number. RFC 3.1.1. 171 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 172 // We require that the stream contains at least one frame. 173 if err == io.EOF && !r.readOneFrame { 174 err = io.ErrUnexpectedEOF 175 } 176 return r.wrapError(relativeOffset, err) 177 } 178 179 if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 { 180 if magic >= 0x184d2a50 && magic <= 0x184d2a5f { 181 // This is a skippable frame. 182 r.blockOffset += int64(relativeOffset) + 4 183 if err := r.skipFrame(); err != nil { 184 return err 185 } 186 r.readOneFrame = true 187 goto retry 188 } 189 190 return r.makeError(relativeOffset, "invalid magic number") 191 } 192 193 relativeOffset += 4 194 195 // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. 196 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { 197 return r.wrapNonEOFError(relativeOffset, err) 198 } 199 descriptor := r.scratch[0] 200 201 singleSegment := descriptor&(1<<5) != 0 202 203 fcsFieldSize := 1 << (descriptor >> 6) 204 if fcsFieldSize == 1 && !singleSegment { 205 fcsFieldSize = 0 206 } 207 208 var windowDescriptorSize int 209 if singleSegment { 210 windowDescriptorSize = 0 211 } else { 212 windowDescriptorSize = 1 213 } 214 215 if descriptor&(1<<3) != 0 { 216 return r.makeError(relativeOffset, "reserved bit set in frame header descriptor") 217 } 218 219 r.hasChecksum = descriptor&(1<<2) != 0 220 if r.hasChecksum { 221 r.checksum.reset() 222 } 223 224 // Dictionary_ID_Flag. RFC 3.1.1.1.1.6. 225 dictionaryIdSize := 0 226 if dictIdFlag := descriptor & 3; dictIdFlag != 0 { 227 dictionaryIdSize = 1 << (dictIdFlag - 1) 228 } 229 230 relativeOffset++ 231 232 headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize 233 234 if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil { 235 return r.wrapNonEOFError(relativeOffset, err) 236 } 237 238 // Figure out the maximum amount of data we need to retain 239 // for backreferences. 240 var windowSize uint64 241 if !singleSegment { 242 // Window descriptor. RFC 3.1.1.1.2. 243 windowDescriptor := r.scratch[0] 244 exponent := uint64(windowDescriptor >> 3) 245 mantissa := uint64(windowDescriptor & 7) 246 windowLog := exponent + 10 247 windowBase := uint64(1) << windowLog 248 windowAdd := (windowBase / 8) * mantissa 249 windowSize = windowBase + windowAdd 250 251 // Default zstd sets limits on the window size. 252 if fuzzing && (windowLog > 31 || windowSize > 1<<27) { 253 return r.makeError(relativeOffset, "windowSize too large") 254 } 255 } 256 257 // Dictionary_ID. RFC 3.1.1.1.3. 258 if dictionaryIdSize != 0 { 259 dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize] 260 // Allow only zero Dictionary ID. 261 for _, b := range dictionaryId { 262 if b != 0 { 263 return r.makeError(relativeOffset, "dictionaries are not supported") 264 } 265 } 266 } 267 268 // Frame_Content_Size. RFC 3.1.1.1.4. 269 r.frameSizeUnknown = false 270 r.remainingFrameSize = 0 271 fb := r.scratch[windowDescriptorSize+dictionaryIdSize:] 272 switch fcsFieldSize { 273 case 0: 274 r.frameSizeUnknown = true 275 case 1: 276 r.remainingFrameSize = uint64(fb[0]) 277 case 2: 278 r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb)) 279 case 4: 280 r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb)) 281 case 8: 282 r.remainingFrameSize = binary.LittleEndian.Uint64(fb) 283 default: 284 panic("unreachable") 285 } 286 287 // RFC 3.1.1.1.2. 288 // When Single_Segment_Flag is set, Window_Descriptor is not present. 289 // In this case, Window_Size is Frame_Content_Size. 290 if singleSegment { 291 windowSize = r.remainingFrameSize 292 } 293 294 // RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size. 295 const maxWindowSize = 8 << 20 296 if windowSize > maxWindowSize { 297 windowSize = maxWindowSize 298 } 299 300 relativeOffset += headerSize 301 302 r.sawFrameHeader = true 303 r.readOneFrame = true 304 r.blockOffset += int64(relativeOffset) 305 306 // Prepare to read blocks from the frame. 307 r.repeatedOffset1 = 1 308 r.repeatedOffset2 = 4 309 r.repeatedOffset3 = 8 310 r.huffmanTableBits = 0 311 r.window.reset(int(windowSize)) 312 r.seqTables[0] = nil 313 r.seqTables[1] = nil 314 r.seqTables[2] = nil 315 316 return nil 317} 318 319// skipFrame skips a skippable frame. RFC 3.1.2. 320func (r *Reader) skipFrame() error { 321 relativeOffset := 0 322 323 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 324 return r.wrapNonEOFError(relativeOffset, err) 325 } 326 327 relativeOffset += 4 328 329 size := binary.LittleEndian.Uint32(r.scratch[:4]) 330 if size == 0 { 331 r.blockOffset += int64(relativeOffset) 332 return nil 333 } 334 335 if seeker, ok := r.r.(io.Seeker); ok { 336 r.blockOffset += int64(relativeOffset) 337 // Implementations of Seeker do not always detect invalid offsets, 338 // so check that the new offset is valid by comparing to the end. 339 prev, err := seeker.Seek(0, io.SeekCurrent) 340 if err != nil { 341 return r.wrapError(0, err) 342 } 343 end, err := seeker.Seek(0, io.SeekEnd) 344 if err != nil { 345 return r.wrapError(0, err) 346 } 347 if prev > end-int64(size) { 348 r.blockOffset += end - prev 349 return r.makeEOFError(0) 350 } 351 352 // The new offset is valid, so seek to it. 353 _, err = seeker.Seek(prev+int64(size), io.SeekStart) 354 if err != nil { 355 return r.wrapError(0, err) 356 } 357 r.blockOffset += int64(size) 358 return nil 359 } 360 361 var skip []byte 362 const chunk = 1 << 20 // 1M 363 for size >= chunk { 364 if len(skip) == 0 { 365 skip = make([]byte, chunk) 366 } 367 if _, err := io.ReadFull(r.r, skip); err != nil { 368 return r.wrapNonEOFError(relativeOffset, err) 369 } 370 relativeOffset += chunk 371 size -= chunk 372 } 373 if size > 0 { 374 if len(skip) == 0 { 375 skip = make([]byte, size) 376 } 377 if _, err := io.ReadFull(r.r, skip); err != nil { 378 return r.wrapNonEOFError(relativeOffset, err) 379 } 380 relativeOffset += int(size) 381 } 382 383 r.blockOffset += int64(relativeOffset) 384 385 return nil 386} 387 388// readBlock reads the next block from a frame. 389func (r *Reader) readBlock() error { 390 relativeOffset := 0 391 392 // Read Block_Header. RFC 3.1.1.2. 393 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil { 394 return r.wrapNonEOFError(relativeOffset, err) 395 } 396 397 relativeOffset += 3 398 399 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16) 400 401 lastBlock := header&1 != 0 402 blockType := (header >> 1) & 3 403 blockSize := int(header >> 3) 404 405 // Maximum block size is smaller of window size and 128K. 406 // We don't record the window size for a single segment frame, 407 // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. 408 if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) { 409 return r.makeError(relativeOffset, "block size too large") 410 } 411 412 // Handle different block types. RFC 3.1.1.2.2. 413 switch blockType { 414 case 0: 415 r.setBufferSize(blockSize) 416 if _, err := io.ReadFull(r.r, r.buffer); err != nil { 417 return r.wrapNonEOFError(relativeOffset, err) 418 } 419 relativeOffset += blockSize 420 r.blockOffset += int64(relativeOffset) 421 case 1: 422 r.setBufferSize(blockSize) 423 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { 424 return r.wrapNonEOFError(relativeOffset, err) 425 } 426 relativeOffset++ 427 v := r.scratch[0] 428 for i := range r.buffer { 429 r.buffer[i] = v 430 } 431 r.blockOffset += int64(relativeOffset) 432 case 2: 433 r.blockOffset += int64(relativeOffset) 434 if err := r.compressedBlock(blockSize); err != nil { 435 return err 436 } 437 r.blockOffset += int64(blockSize) 438 case 3: 439 return r.makeError(relativeOffset, "invalid block type") 440 } 441 442 if !r.frameSizeUnknown { 443 if uint64(len(r.buffer)) > r.remainingFrameSize { 444 return r.makeError(relativeOffset, "too many uncompressed bytes in frame") 445 } 446 r.remainingFrameSize -= uint64(len(r.buffer)) 447 } 448 449 if r.hasChecksum { 450 r.checksum.update(r.buffer) 451 } 452 453 if !lastBlock { 454 r.window.save(r.buffer) 455 } else { 456 if !r.frameSizeUnknown && r.remainingFrameSize != 0 { 457 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame") 458 } 459 // Check for checksum at end of frame. RFC 3.1.1. 460 if r.hasChecksum { 461 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 462 return r.wrapNonEOFError(0, err) 463 } 464 465 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4]) 466 dataChecksum := uint32(r.checksum.digest()) 467 if inputChecksum != dataChecksum { 468 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum)) 469 } 470 471 r.blockOffset += 4 472 } 473 r.sawFrameHeader = false 474 } 475 476 return nil 477} 478 479// setBufferSize sets the decompressed buffer size. 480// When this is called the buffer is empty. 481func (r *Reader) setBufferSize(size int) { 482 if cap(r.buffer) < size { 483 need := size - cap(r.buffer) 484 r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) 485 } 486 r.buffer = r.buffer[:size] 487} 488 489// zstdError is an error while decompressing. 490type zstdError struct { 491 offset int64 492 err error 493} 494 495func (ze *zstdError) Error() string { 496 return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err) 497} 498 499func (ze *zstdError) Unwrap() error { 500 return ze.err 501} 502 503func (r *Reader) makeEOFError(off int) error { 504 return r.wrapError(off, io.ErrUnexpectedEOF) 505} 506 507func (r *Reader) wrapNonEOFError(off int, err error) error { 508 if err == io.EOF { 509 err = io.ErrUnexpectedEOF 510 } 511 return r.wrapError(off, err) 512} 513 514func (r *Reader) makeError(off int, msg string) error { 515 return r.wrapError(off, errors.New(msg)) 516} 517 518func (r *Reader) wrapError(off int, err error) error { 519 if err == io.EOF { 520 return err 521 } 522 return &zstdError{r.blockOffset + int64(off), err} 523} 524