1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package zstd 6 7import ( 8 "io" 9) 10 11// debug can be set in the source to print debug info using println. 12const debug = false 13 14// compressedBlock decompresses a compressed block, storing the decompressed 15// data in r.buffer. The blockSize argument is the compressed size. 16// RFC 3.1.1.3. 17func (r *Reader) compressedBlock(blockSize int) error { 18 if len(r.compressedBuf) >= blockSize { 19 r.compressedBuf = r.compressedBuf[:blockSize] 20 } else { 21 // We know that blockSize <= 128K, 22 // so this won't allocate an enormous amount. 23 need := blockSize - len(r.compressedBuf) 24 r.compressedBuf = append(r.compressedBuf, make([]byte, need)...) 25 } 26 27 if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil { 28 return r.wrapNonEOFError(0, err) 29 } 30 31 data := block(r.compressedBuf) 32 off := 0 33 r.buffer = r.buffer[:0] 34 35 litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0]) 36 if err != nil { 37 return err 38 } 39 r.literals = litbuf 40 41 off = litoff 42 43 seqCount, off, err := r.initSeqs(data, off) 44 if err != nil { 45 return err 46 } 47 48 if seqCount == 0 { 49 // No sequences, just literals. 50 if off < len(data) { 51 return r.makeError(off, "extraneous data after no sequences") 52 } 53 54 r.buffer = append(r.buffer, litbuf...) 55 56 return nil 57 } 58 59 return r.execSeqs(data, off, litbuf, seqCount) 60} 61 62// seqCode is the kind of sequence codes we have to handle. 63type seqCode int 64 65const ( 66 seqLiteral seqCode = iota 67 seqOffset 68 seqMatch 69) 70 71// seqCodeInfoData is the information needed to set up seqTables and 72// seqTableBits for a particular kind of sequence code. 73type seqCodeInfoData struct { 74 predefTable []fseBaselineEntry // predefined FSE 75 predefTableBits int // number of bits in predefTable 76 maxSym int // max symbol value in FSE 77 maxBits int // max bits for FSE 78 79 // toBaseline converts from an FSE table to an FSE baseline table. 80 toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error 81} 82 83// seqCodeInfo is the seqCodeInfoData for each kind of sequence code. 84var seqCodeInfo = [3]seqCodeInfoData{ 85 seqLiteral: { 86 predefTable: predefinedLiteralTable[:], 87 predefTableBits: 6, 88 maxSym: 35, 89 maxBits: 9, 90 toBaseline: (*Reader).makeLiteralBaselineFSE, 91 }, 92 seqOffset: { 93 predefTable: predefinedOffsetTable[:], 94 predefTableBits: 5, 95 maxSym: 31, 96 maxBits: 8, 97 toBaseline: (*Reader).makeOffsetBaselineFSE, 98 }, 99 seqMatch: { 100 predefTable: predefinedMatchTable[:], 101 predefTableBits: 6, 102 maxSym: 52, 103 maxBits: 9, 104 toBaseline: (*Reader).makeMatchBaselineFSE, 105 }, 106} 107 108// initSeqs reads the Sequences_Section_Header and sets up the FSE 109// tables used to read the sequence codes. It returns the number of 110// sequences and the new offset. RFC 3.1.1.3.2.1. 111func (r *Reader) initSeqs(data block, off int) (int, int, error) { 112 if off >= len(data) { 113 return 0, 0, r.makeEOFError(off) 114 } 115 116 seqHdr := data[off] 117 off++ 118 if seqHdr == 0 { 119 return 0, off, nil 120 } 121 122 var seqCount int 123 if seqHdr < 128 { 124 seqCount = int(seqHdr) 125 } else if seqHdr < 255 { 126 if off >= len(data) { 127 return 0, 0, r.makeEOFError(off) 128 } 129 seqCount = ((int(seqHdr) - 128) << 8) + int(data[off]) 130 off++ 131 } else { 132 if off+1 >= len(data) { 133 return 0, 0, r.makeEOFError(off) 134 } 135 seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00 136 off += 2 137 } 138 139 // Read the Symbol_Compression_Modes byte. 140 141 if off >= len(data) { 142 return 0, 0, r.makeEOFError(off) 143 } 144 symMode := data[off] 145 if symMode&3 != 0 { 146 return 0, 0, r.makeError(off, "invalid symbol compression mode") 147 } 148 off++ 149 150 // Set up the FSE tables used to decode the sequence codes. 151 152 var err error 153 off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3) 154 if err != nil { 155 return 0, 0, err 156 } 157 158 off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3) 159 if err != nil { 160 return 0, 0, err 161 } 162 163 off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3) 164 if err != nil { 165 return 0, 0, err 166 } 167 168 return seqCount, off, nil 169} 170 171// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and 172// r.seqTableBits for kind. We store these in the Reader because one of 173// the modes simply reuses the value from the last block in the frame. 174func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) { 175 info := &seqCodeInfo[kind] 176 switch mode { 177 case 0: 178 // Predefined_Mode 179 r.seqTables[kind] = info.predefTable 180 r.seqTableBits[kind] = uint8(info.predefTableBits) 181 return off, nil 182 183 case 1: 184 // RLE_Mode 185 if off >= len(data) { 186 return 0, r.makeEOFError(off) 187 } 188 rle := data[off] 189 off++ 190 191 // Build a simple baseline table that always returns rle. 192 193 entry := []fseEntry{ 194 { 195 sym: rle, 196 bits: 0, 197 base: 0, 198 }, 199 } 200 if cap(r.seqTableBuffers[kind]) == 0 { 201 r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) 202 } 203 r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1] 204 if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil { 205 return 0, err 206 } 207 208 r.seqTables[kind] = r.seqTableBuffers[kind] 209 r.seqTableBits[kind] = 0 210 return off, nil 211 212 case 2: 213 // FSE_Compressed_Mode 214 if cap(r.fseScratch) < 1<<info.maxBits { 215 r.fseScratch = make([]fseEntry, 1<<info.maxBits) 216 } 217 r.fseScratch = r.fseScratch[:1<<info.maxBits] 218 219 tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch) 220 if err != nil { 221 return 0, err 222 } 223 r.fseScratch = r.fseScratch[:1<<tableBits] 224 225 if cap(r.seqTableBuffers[kind]) == 0 { 226 r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) 227 } 228 r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits] 229 230 if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil { 231 return 0, err 232 } 233 234 r.seqTables[kind] = r.seqTableBuffers[kind] 235 r.seqTableBits[kind] = uint8(tableBits) 236 return roff, nil 237 238 case 3: 239 // Repeat_Mode 240 if len(r.seqTables[kind]) == 0 { 241 return 0, r.makeError(off, "missing repeat sequence FSE table") 242 } 243 return off, nil 244 } 245 panic("unreachable") 246} 247 248// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2. 249func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error { 250 // Set up the initial states for the sequence code readers. 251 252 rbr, err := r.makeReverseBitReader(data, len(data)-1, off) 253 if err != nil { 254 return err 255 } 256 257 literalState, err := rbr.val(r.seqTableBits[seqLiteral]) 258 if err != nil { 259 return err 260 } 261 262 offsetState, err := rbr.val(r.seqTableBits[seqOffset]) 263 if err != nil { 264 return err 265 } 266 267 matchState, err := rbr.val(r.seqTableBits[seqMatch]) 268 if err != nil { 269 return err 270 } 271 272 // Read and perform all the sequences. RFC 3.1.1.4. 273 274 seq := 0 275 for seq < seqCount { 276 if len(r.buffer)+len(litbuf) > 128<<10 { 277 return rbr.makeError("uncompressed size too big") 278 } 279 280 ptoffset := &r.seqTables[seqOffset][offsetState] 281 ptmatch := &r.seqTables[seqMatch][matchState] 282 ptliteral := &r.seqTables[seqLiteral][literalState] 283 284 add, err := rbr.val(ptoffset.basebits) 285 if err != nil { 286 return err 287 } 288 offset := ptoffset.baseline + add 289 290 add, err = rbr.val(ptmatch.basebits) 291 if err != nil { 292 return err 293 } 294 match := ptmatch.baseline + add 295 296 add, err = rbr.val(ptliteral.basebits) 297 if err != nil { 298 return err 299 } 300 literal := ptliteral.baseline + add 301 302 // Handle repeat offsets. RFC 3.1.1.5. 303 // See the comment in makeOffsetBaselineFSE. 304 if ptoffset.basebits > 1 { 305 r.repeatedOffset3 = r.repeatedOffset2 306 r.repeatedOffset2 = r.repeatedOffset1 307 r.repeatedOffset1 = offset 308 } else { 309 if literal == 0 { 310 offset++ 311 } 312 switch offset { 313 case 1: 314 offset = r.repeatedOffset1 315 case 2: 316 offset = r.repeatedOffset2 317 r.repeatedOffset2 = r.repeatedOffset1 318 r.repeatedOffset1 = offset 319 case 3: 320 offset = r.repeatedOffset3 321 r.repeatedOffset3 = r.repeatedOffset2 322 r.repeatedOffset2 = r.repeatedOffset1 323 r.repeatedOffset1 = offset 324 case 4: 325 offset = r.repeatedOffset1 - 1 326 r.repeatedOffset3 = r.repeatedOffset2 327 r.repeatedOffset2 = r.repeatedOffset1 328 r.repeatedOffset1 = offset 329 } 330 } 331 332 seq++ 333 if seq < seqCount { 334 // Update the states. 335 add, err = rbr.val(ptliteral.bits) 336 if err != nil { 337 return err 338 } 339 literalState = uint32(ptliteral.base) + add 340 341 add, err = rbr.val(ptmatch.bits) 342 if err != nil { 343 return err 344 } 345 matchState = uint32(ptmatch.base) + add 346 347 add, err = rbr.val(ptoffset.bits) 348 if err != nil { 349 return err 350 } 351 offsetState = uint32(ptoffset.base) + add 352 } 353 354 // The next sequence is now in literal, offset, match. 355 356 if debug { 357 println("literal", literal, "offset", offset, "match", match) 358 } 359 360 // Copy literal bytes from litbuf. 361 if literal > uint32(len(litbuf)) { 362 return rbr.makeError("literal byte overflow") 363 } 364 if literal > 0 { 365 r.buffer = append(r.buffer, litbuf[:literal]...) 366 litbuf = litbuf[literal:] 367 } 368 369 if match > 0 { 370 if err := r.copyFromWindow(&rbr, offset, match); err != nil { 371 return err 372 } 373 } 374 } 375 376 r.buffer = append(r.buffer, litbuf...) 377 378 if rbr.cnt != 0 { 379 return r.makeError(off, "extraneous data after sequences") 380 } 381 382 return nil 383} 384 385// Copy match bytes from the decoded output, or the window, at offset. 386func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error { 387 if offset == 0 { 388 return rbr.makeError("invalid zero offset") 389 } 390 391 // Offset may point into the buffer or the window and 392 // match may extend past the end of the initial buffer. 393 // |--r.window--|--r.buffer--| 394 // |<-----offset------| 395 // |------match----------->| 396 bufferOffset := uint32(0) 397 lenBlock := uint32(len(r.buffer)) 398 if lenBlock < offset { 399 lenWindow := r.window.len() 400 copy := offset - lenBlock 401 if copy > lenWindow { 402 return rbr.makeError("offset past window") 403 } 404 windowOffset := lenWindow - copy 405 if copy > match { 406 copy = match 407 } 408 r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy) 409 match -= copy 410 } else { 411 bufferOffset = lenBlock - offset 412 } 413 414 // We are being asked to copy data that we are adding to the 415 // buffer in the same copy. 416 for match > 0 { 417 copy := uint32(len(r.buffer)) - bufferOffset 418 if copy > match { 419 copy = match 420 } 421 r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...) 422 match -= copy 423 } 424 return nil 425} 426