1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zstd
6
7import (
8	"math/bits"
9)
10
// fseEntry describes a single slot of an FSE decoding table.
type fseEntry struct {
	// sym is the symbol value recorded by this slot.
	sym uint8
	// bits is how many bits must be read to determine the next state.
	bits uint8
	// base is the value those bits are added to, giving the next state.
	base uint16
}
17
18// readFSE reads an FSE table from data starting at off.
19// maxSym is the maximum symbol value.
20// maxBits is the maximum number of bits permitted for symbols in the table.
21// The FSE is written into table, which must be at least 1<<maxBits in size.
22// This returns the number of bits in the FSE table and the new offset.
23// RFC 4.1.1.
24func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
25	br := r.makeBitReader(data, off)
26	if err := br.moreBits(); err != nil {
27		return 0, 0, err
28	}
29
30	accuracyLog := int(br.val(4)) + 5
31	if accuracyLog > maxBits {
32		return 0, 0, br.makeError("FSE accuracy log too large")
33	}
34
35	// The number of remaining probabilities, plus 1.
36	// This determines the number of bits to be read for the next value.
37	remaining := (1 << accuracyLog) + 1
38
39	// The current difference between small and large values,
40	// which depends on the number of remaining values.
41	// Small values use 1 less bit.
42	threshold := 1 << accuracyLog
43
44	// The number of bits needed to compute threshold.
45	bitsNeeded := accuracyLog + 1
46
47	// The next character value.
48	sym := 0
49
50	// Whether the last count was 0.
51	prev0 := false
52
53	var norm [256]int16
54
55	for remaining > 1 && sym <= maxSym {
56		if err := br.moreBits(); err != nil {
57			return 0, 0, err
58		}
59
60		if prev0 {
61			// Previous count was 0, so there is a 2-bit
62			// repeat flag. If the 2-bit flag is 0b11,
63			// it adds 3 and then there is another repeat flag.
64			zsym := sym
65			for (br.bits & 0xfff) == 0xfff {
66				zsym += 3 * 6
67				br.bits >>= 12
68				br.cnt -= 12
69				if err := br.moreBits(); err != nil {
70					return 0, 0, err
71				}
72			}
73			for (br.bits & 3) == 3 {
74				zsym += 3
75				br.bits >>= 2
76				br.cnt -= 2
77				if err := br.moreBits(); err != nil {
78					return 0, 0, err
79				}
80			}
81
82			// We have at least 14 bits here,
83			// no need to call moreBits
84
85			zsym += int(br.val(2))
86
87			if zsym > maxSym {
88				return 0, 0, br.makeError("FSE symbol index overflow")
89			}
90
91			for ; sym < zsym; sym++ {
92				norm[uint8(sym)] = 0
93			}
94
95			prev0 = false
96			continue
97		}
98
99		max := (2*threshold - 1) - remaining
100		var count int
101		if int(br.bits&uint32(threshold-1)) < max {
102			// A small value.
103			count = int(br.bits & uint32((threshold - 1)))
104			br.bits >>= bitsNeeded - 1
105			br.cnt -= uint32(bitsNeeded - 1)
106		} else {
107			// A large value.
108			count = int(br.bits & uint32((2*threshold - 1)))
109			if count >= threshold {
110				count -= max
111			}
112			br.bits >>= bitsNeeded
113			br.cnt -= uint32(bitsNeeded)
114		}
115
116		count--
117		if count >= 0 {
118			remaining -= count
119		} else {
120			remaining--
121		}
122		if sym >= 256 {
123			return 0, 0, br.makeError("FSE sym overflow")
124		}
125		norm[uint8(sym)] = int16(count)
126		sym++
127
128		prev0 = count == 0
129
130		for remaining < threshold {
131			bitsNeeded--
132			threshold >>= 1
133		}
134	}
135
136	if remaining != 1 {
137		return 0, 0, br.makeError("too many symbols in FSE table")
138	}
139
140	for ; sym <= maxSym; sym++ {
141		norm[uint8(sym)] = 0
142	}
143
144	br.backup()
145
146	if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
147		return 0, 0, err
148	}
149
150	return accuracyLog, int(br.off), nil
151}
152
153// buildFSE builds an FSE decoding table from a list of probabilities.
154// The probabilities are in norm. next is scratch space. The number of bits
155// in the table is tableBits.
156func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
157	tableSize := 1 << tableBits
158	highThreshold := tableSize - 1
159
160	var next [256]uint16
161
162	for i, n := range norm {
163		if n >= 0 {
164			next[uint8(i)] = uint16(n)
165		} else {
166			table[highThreshold].sym = uint8(i)
167			highThreshold--
168			next[uint8(i)] = 1
169		}
170	}
171
172	pos := 0
173	step := (tableSize >> 1) + (tableSize >> 3) + 3
174	mask := tableSize - 1
175	for i, n := range norm {
176		for j := 0; j < int(n); j++ {
177			table[pos].sym = uint8(i)
178			pos = (pos + step) & mask
179			for pos > highThreshold {
180				pos = (pos + step) & mask
181			}
182		}
183	}
184	if pos != 0 {
185		return r.makeError(off, "FSE count error")
186	}
187
188	for i := 0; i < tableSize; i++ {
189		sym := table[i].sym
190		nextState := next[sym]
191		next[sym]++
192
193		if nextState == 0 {
194			return r.makeError(off, "FSE state error")
195		}
196
197		highBit := 15 - bits.LeadingZeros16(nextState)
198
199		bits := tableBits - highBit
200		table[i].bits = uint8(bits)
201		table[i].base = (nextState << bits) - uint16(tableSize)
202	}
203
204	return nil
205}
206
// fseBaselineEntry is one slot of an FSE baseline table, the form used
// for literal/match/length values. Decoding such a value means mapping
// the symbol to a baseline, reading zero or more bits, and adding them
// to the baseline. Instead of consulting a separate lookup table for
// that mapping, the FSE table is converted up front so that each slot
// carries its baseline directly.
type fseBaselineEntry struct {
	// baseline is the baseline of the value this slot represents.
	baseline uint32
	// basebits is how many bits to read and add to baseline.
	basebits uint8
	// bits is how many bits to read to determine the next state.
	bits uint8
	// base is the value those bits are added to, giving the next state.
	base uint16
}
219
// A literal length code is turned into a value by reading a number of
// bits and adding them to a baseline. Codes 0 to 15 are their own
// baseline and read no bits; the remaining codes are described by
// literalLengthBase. RFC 3.1.1.3.2.1.1.

// literalLengthOffset is the literal length code that corresponds to
// the first entry of literalLengthBase.
const literalLengthOffset = 16

// literalLengthBase holds one packed value per literal length code,
// starting at code literalLengthOffset: the low 24 bits are the
// baseline and the bits above them are the number of bits to read.
var literalLengthBase = []uint32{
	// One bit to read.
	1<<24 | 16,
	1<<24 | 18,
	1<<24 | 20,
	1<<24 | 22,
	// Two bits to read.
	2<<24 | 24,
	2<<24 | 28,
	// Three bits to read.
	3<<24 | 32,
	3<<24 | 40,
	// Four bits to read.
	4<<24 | 48,
	// From here on the baseline is 1 shifted left by the bit count.
	6<<24 | 64,
	7<<24 | 128,
	8<<24 | 256,
	9<<24 | 512,
	10<<24 | 1024,
	11<<24 | 2048,
	12<<24 | 4096,
	13<<24 | 8192,
	14<<24 | 16384,
	15<<24 | 32768,
	16<<24 | 65536,
}
248
249// makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
250func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
251	for i, e := range fseTable {
252		be := fseBaselineEntry{
253			bits: e.bits,
254			base: e.base,
255		}
256		if e.sym < literalLengthOffset {
257			be.baseline = uint32(e.sym)
258			be.basebits = 0
259		} else {
260			if e.sym > 35 {
261				return r.makeError(off, "FSE baseline symbol overflow")
262			}
263			idx := e.sym - literalLengthOffset
264			basebits := literalLengthBase[idx]
265			be.baseline = basebits & 0xffffff
266			be.basebits = uint8(basebits >> 24)
267		}
268		baselineTable[i] = be
269	}
270	return nil
271}
272
273// makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
274func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
275	for i, e := range fseTable {
276		be := fseBaselineEntry{
277			bits: e.bits,
278			base: e.base,
279		}
280		if e.sym > 31 {
281			return r.makeError(off, "FSE offset symbol overflow")
282		}
283
284		// The simple way to write this is
285		//     be.baseline = 1 << e.sym
286		//     be.basebits = e.sym
287		// That would give us an offset value that corresponds to
288		// the one described in the RFC. However, for offsets > 3
289		// we have to subtract 3. And for offset values 1, 2, 3
290		// we use a repeated offset.
291		//
292		// The baseline is always a power of 2, and is never 0,
293		// so for those low values we will see one entry that is
294		// baseline 1, basebits 0, and one entry that is baseline 2,
295		// basebits 1. All other entries will have baseline >= 4
296		// basebits >= 2.
297		//
298		// So we can check for RFC offset <= 3 by checking for
299		// basebits <= 1. That means that we can subtract 3 here
300		// and not worry about doing it in the hot loop.
301
302		be.baseline = 1 << e.sym
303		if e.sym >= 2 {
304			be.baseline -= 3
305		}
306		be.basebits = e.sym
307		baselineTable[i] = be
308	}
309	return nil
310}
311
// A match length code is turned into a value by reading a number of
// bits and adding them to a baseline. For codes 0 to 31 the baseline
// is the code plus 3 and no bits are read; the remaining codes are
// described by matchLengthBase. RFC 3.1.1.3.2.1.1.

// matchLengthOffset is the match length code that corresponds to
// the first entry of matchLengthBase.
const matchLengthOffset = 32

// matchLengthBase holds one packed value per match length code,
// starting at code matchLengthOffset: the low 24 bits are the
// baseline and the bits above them are the number of bits to read.
var matchLengthBase = []uint32{
	// One bit to read.
	1<<24 | 35,
	1<<24 | 37,
	1<<24 | 39,
	1<<24 | 41,
	// Two bits to read.
	2<<24 | 43,
	2<<24 | 47,
	// Three bits to read.
	3<<24 | 51,
	3<<24 | 59,
	// Four bits to read.
	4<<24 | 67,
	4<<24 | 83,
	// Five bits to read.
	5<<24 | 99,
	// From here on the baseline is 3 plus 1 shifted left by the bit count.
	7<<24 | 131,
	8<<24 | 259,
	9<<24 | 515,
	10<<24 | 1027,
	11<<24 | 2051,
	12<<24 | 4099,
	13<<24 | 8195,
	14<<24 | 16387,
	15<<24 | 32771,
	16<<24 | 65539,
}
341
342// makeMatchBaselineFSE converts the match length fseTable to baselineTable.
343func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
344	for i, e := range fseTable {
345		be := fseBaselineEntry{
346			bits: e.bits,
347			base: e.base,
348		}
349		if e.sym < matchLengthOffset {
350			be.baseline = uint32(e.sym) + 3
351			be.basebits = 0
352		} else {
353			if e.sym > 52 {
354				return r.makeError(off, "FSE baseline symbol overflow")
355			}
356			idx := e.sym - matchLengthOffset
357			basebits := matchLengthBase[idx]
358			be.baseline = basebits & 0xffffff
359			be.basebits = uint8(basebits >> 24)
360		}
361		baselineTable[i] = be
362	}
363	return nil
364}
365
366// predefinedLiteralTable is the predefined table to use for literal lengths.
367// Generated from table in RFC 3.1.1.3.2.2.1.
368// Checked by TestPredefinedTables.
369var predefinedLiteralTable = [...]fseBaselineEntry{
370	{0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
371	{3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
372	{7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
373	{12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
374	{20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
375	{32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
376	{128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
377	{4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
378	{2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
379	{7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
380	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
381	{18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
382	{32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
383	{64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
384	{2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
385	{2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
386	{6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
387	{11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
388	{18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
389	{28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
390	{65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
391	{8192, 13, 6, 0},
392}
393
394// predefinedOffsetTable is the predefined table to use for offsets.
395// Generated from table in RFC 3.1.1.3.2.2.3.
396// Checked by TestPredefinedTables.
397var predefinedOffsetTable = [...]fseBaselineEntry{
398	{1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
399	{32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
400	{125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
401	{8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
402	{16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
403	{125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
404	{4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
405	{8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
406	{61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
407	{268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
408	{33554429, 25, 5, 0}, {16777213, 24, 5, 0},
409}
410
411// predefinedMatchTable is the predefined table to use for match lengths.
412// Generated from table in RFC 3.1.1.3.2.2.2.
413// Checked by TestPredefinedTables.
414var predefinedMatchTable = [...]fseBaselineEntry{
415	{3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
416	{6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
417	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
418	{19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
419	{28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
420	{37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
421	{59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
422	{515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
423	{6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
424	{10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
425	{18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
426	{27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
427	{35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
428	{51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
429	{259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
430	{5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
431	{10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
432	{17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
433	{26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
434	{65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
435	{8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
436	{1027, 10, 6, 0},
437}
438