1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package bzip2
6
7import (
8	"cmp"
9	"slices"
10)
11
// A huffmanTree is a binary tree which is navigated, bit-by-bit, to reach a
// symbol. Only the non-leaf nodes are stored; leaves are represented inline
// in their parent node (see huffmanNode).
type huffmanTree struct {
	// nodes contains all the non-leaf nodes in the tree. nodes[0] is the
	// root of the tree, which is where Decode starts. nextNode contains the
	// index of the next unused element of nodes while the tree is being
	// constructed by buildHuffmanNode.
	nodes    []huffmanNode
	nextNode int
}
21
// A huffmanNode is a non-leaf node in the tree. left and right contain
// indexes into the nodes slice of the tree. If left or right is
// invalidNodeValue then that child is a leaf node and its symbol is in
// leftValue or rightValue respectively. When decoding, a 1 bit selects the
// left child and a 0 bit selects the right child.
//
// The symbols are uint16s because bzip2 encodes not only MTF indexes in the
// tree, but also two magic values for run-length encoding and an EOF symbol.
// Thus there are more than 256 possible symbols.
type huffmanNode struct {
	// left and right are child node indexes, or invalidNodeValue for a leaf.
	left, right           uint16
	// leftValue and rightValue hold the decoded symbol when the matching
	// child is a leaf; they are not consulted otherwise.
	leftValue, rightValue uint16
}
33
// invalidNodeValue is a sentinel stored in huffmanNode.left or
// huffmanNode.right in place of a node index. It marks that child as a leaf
// whose symbol is held in the corresponding leftValue/rightValue field.
const invalidNodeValue = 0xffff
36
// Decode reads bits one at a time from the given bitReader and navigates the
// tree from the root until a leaf is reached, returning that leaf's symbol.
// A 1 bit follows the left child of the current node and a 0 bit follows the
// right child.
//
// Decode has no error result of its own; how a failed read is surfaced
// depends on bitReader.ReadBits, which is not defined in this file.
func (t *huffmanTree) Decode(br *bitReader) (v uint16) {
	nodeIndex := uint16(0) // node 0 is the root of the tree.

	for {
		node := &t.nodes[nodeIndex]

		var bit uint16
		if br.bits > 0 {
			// Get next bit - fast path: take it directly from the
			// bits already buffered in br.n.
			// NOTE(review): the &63 mask presumably reflects br.n being
			// a 64-bit buffer; confirm against the bitReader definition.
			br.bits--
			bit = uint16(br.n>>(br.bits&63)) & 1
		} else {
			// Get next bit - slow path.
			// No bits are buffered, so use ReadBits to retrieve a
			// single bit from the underlying input.
			bit = uint16(br.ReadBits(1))
		}

		// Trick the compiler into generating a conditional move instead
		// of a branch, by making both loads unconditional.
		l, r := node.left, node.right

		if bit == 1 {
			nodeIndex = l
		} else {
			nodeIndex = r
		}

		if nodeIndex == invalidNodeValue {
			// We found a leaf. Use the value of bit to decide
			// whether it is the left or the right value. Both values
			// are again loaded unconditionally.
			l, r := node.leftValue, node.rightValue
			if bit == 1 {
				v = l
			} else {
				v = r
			}
			return
		}
	}
}
80
81// newHuffmanTree builds a Huffman tree from a slice containing the code
82// lengths of each symbol. The maximum code length is 32 bits.
83func newHuffmanTree(lengths []uint8) (huffmanTree, error) {
84	// There are many possible trees that assign the same code length to
85	// each symbol (consider reflecting a tree down the middle, for
86	// example). Since the code length assignments determine the
87	// efficiency of the tree, each of these trees is equally good. In
88	// order to minimize the amount of information needed to build a tree
89	// bzip2 uses a canonical tree so that it can be reconstructed given
90	// only the code length assignments.
91
92	if len(lengths) < 2 {
93		panic("newHuffmanTree: too few symbols")
94	}
95
96	var t huffmanTree
97
98	// First we sort the code length assignments by ascending code length,
99	// using the symbol value to break ties.
100	pairs := make([]huffmanSymbolLengthPair, len(lengths))
101	for i, length := range lengths {
102		pairs[i].value = uint16(i)
103		pairs[i].length = length
104	}
105
106	slices.SortFunc(pairs, func(a, b huffmanSymbolLengthPair) int {
107		if c := cmp.Compare(a.length, b.length); c != 0 {
108			return c
109		}
110		return cmp.Compare(a.value, b.value)
111	})
112
113	// Now we assign codes to the symbols, starting with the longest code.
114	// We keep the codes packed into a uint32, at the most-significant end.
115	// So branches are taken from the MSB downwards. This makes it easy to
116	// sort them later.
117	code := uint32(0)
118	length := uint8(32)
119
120	codes := make([]huffmanCode, len(lengths))
121	for i := len(pairs) - 1; i >= 0; i-- {
122		if length > pairs[i].length {
123			length = pairs[i].length
124		}
125		codes[i].code = code
126		codes[i].codeLen = length
127		codes[i].value = pairs[i].value
128		// We need to 'increment' the code, which means treating |code|
129		// like a |length| bit number.
130		code += 1 << (32 - length)
131	}
132
133	// Now we can sort by the code so that the left half of each branch are
134	// grouped together, recursively.
135	slices.SortFunc(codes, func(a, b huffmanCode) int {
136		return cmp.Compare(a.code, b.code)
137	})
138
139	t.nodes = make([]huffmanNode, len(codes))
140	_, err := buildHuffmanNode(&t, codes, 0)
141	return t, err
142}
143
// huffmanSymbolLengthPair contains a symbol and its code length. It is the
// element type that newHuffmanTree sorts by length, then by symbol value,
// before assigning codes.
type huffmanSymbolLengthPair struct {
	value  uint16 // the symbol, i.e. its index in the lengths slice
	length uint8  // the symbol's code length in bits
}
149
// huffmanCode contains a symbol, its code and code length.
type huffmanCode struct {
	// code holds the code bits packed at the most-significant end of the
	// word, so the first branch taken is bit 31.
	code    uint32
	codeLen uint8  // number of bits assigned to the code
	value   uint16 // the symbol the code decodes to
}
156
157// buildHuffmanNode takes a slice of sorted huffmanCodes and builds a node in
158// the Huffman tree at the given level. It returns the index of the newly
159// constructed node.
160func buildHuffmanNode(t *huffmanTree, codes []huffmanCode, level uint32) (nodeIndex uint16, err error) {
161	test := uint32(1) << (31 - level)
162
163	// We have to search the list of codes to find the divide between the left and right sides.
164	firstRightIndex := len(codes)
165	for i, code := range codes {
166		if code.code&test != 0 {
167			firstRightIndex = i
168			break
169		}
170	}
171
172	left := codes[:firstRightIndex]
173	right := codes[firstRightIndex:]
174
175	if len(left) == 0 || len(right) == 0 {
176		// There is a superfluous level in the Huffman tree indicating
177		// a bug in the encoder. However, this bug has been observed in
178		// the wild so we handle it.
179
180		// If this function was called recursively then we know that
181		// len(codes) >= 2 because, otherwise, we would have hit the
182		// "leaf node" case, below, and not recurred.
183		//
184		// However, for the initial call it's possible that len(codes)
185		// is zero or one. Both cases are invalid because a zero length
186		// tree cannot encode anything and a length-1 tree can only
187		// encode EOF and so is superfluous. We reject both.
188		if len(codes) < 2 {
189			return 0, StructuralError("empty Huffman tree")
190		}
191
192		// In this case the recursion doesn't always reduce the length
193		// of codes so we need to ensure termination via another
194		// mechanism.
195		if level == 31 {
196			// Since len(codes) >= 2 the only way that the values
197			// can match at all 32 bits is if they are equal, which
198			// is invalid. This ensures that we never enter
199			// infinite recursion.
200			return 0, StructuralError("equal symbols in Huffman tree")
201		}
202
203		if len(left) == 0 {
204			return buildHuffmanNode(t, right, level+1)
205		}
206		return buildHuffmanNode(t, left, level+1)
207	}
208
209	nodeIndex = uint16(t.nextNode)
210	node := &t.nodes[t.nextNode]
211	t.nextNode++
212
213	if len(left) == 1 {
214		// leaf node
215		node.left = invalidNodeValue
216		node.leftValue = left[0].value
217	} else {
218		node.left, err = buildHuffmanNode(t, left, level+1)
219	}
220
221	if err != nil {
222		return
223	}
224
225	if len(right) == 1 {
226		// leaf node
227		node.right = invalidNodeValue
228		node.rightValue = right[0].value
229	} else {
230		node.right, err = buildHuffmanNode(t, right, level+1)
231	}
232
233	return
234}
235