1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzw
6
7import (
8	"bufio"
9	"errors"
10	"fmt"
11	"io"
12)
13
// writer is the sink the compressor emits into: it accepts output one byte
// at a time and must be flushed explicitly (for example, a *bufio.Writer).
type writer interface {
	Flush() error
	io.ByteWriter
}
19
const (
	// maxCode is the largest 12-bit code. Codes are held in uint32 values
	// while encoding so that shifting them needs no type conversions.
	maxCode = 1<<12 - 1
	// invalidCode is a sentinel that no real code can equal; savedCode holds
	// it until the first Write.
	invalidCode = 1<<32 - 1
	// At most 1<<12 codes exist, which bounds the number of live hash table
	// entries at any time; the table is given four times that many slots.
	tableSize = 4 * (1 << 12)
	tableMask = tableSize - 1
	// invalidEntry marks an empty hash table slot. Zero is safe because the
	// low 12 bits of an occupied entry always hold a non-literal code.
	invalidEntry = 0
)
33
// Writer is an LZW compressor. It writes the compressed form of the data
// to an underlying writer (see [NewWriter]).
type Writer struct {
	// w is the writer that compressed bytes are written to. It is nil when
	// the Writer was initialized with a nil destination.
	w writer
	// litWidth is the width in bits of literal codes.
	litWidth uint
	// order, write, bits, nBits and width are the state for
	// converting a code stream into a byte stream.
	// write is writeLSB or writeMSB, chosen from order by init. bits holds
	// the nBits (always fewer than 8 between calls) pending output bits:
	// low-aligned for LSB, high-aligned for MSB. width is the current code
	// width in bits.
	order Order
	write func(*Writer, uint32) error
	nBits uint
	width uint
	bits  uint32
	// hi is the code implied by the next code emission.
	// overflow is the code at which hi overflows the code width.
	hi, overflow uint32
	// savedCode is the accumulated code at the end of the most recent Write
	// call. It is equal to invalidCode if there was no such call.
	savedCode uint32
	// err is the first error encountered during writing. Closing the writer
	// will make any future Write calls return errClosed.
	err error
	// table is the hash table from 20-bit keys to 12-bit values. Each table
	// entry contains key<<12|val and collisions resolve by linear probing.
	// The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
	// The values are a 12-bit code.
	table [tableSize]uint32
}
63
64// writeLSB writes the code c for "Least Significant Bits first" data.
65func (w *Writer) writeLSB(c uint32) error {
66	w.bits |= c << w.nBits
67	w.nBits += w.width
68	for w.nBits >= 8 {
69		if err := w.w.WriteByte(uint8(w.bits)); err != nil {
70			return err
71		}
72		w.bits >>= 8
73		w.nBits -= 8
74	}
75	return nil
76}
77
78// writeMSB writes the code c for "Most Significant Bits first" data.
79func (w *Writer) writeMSB(c uint32) error {
80	w.bits |= c << (32 - w.width - w.nBits)
81	w.nBits += w.width
82	for w.nBits >= 8 {
83		if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
84			return err
85		}
86		w.bits <<= 8
87		w.nBits -= 8
88	}
89	return nil
90}
91
var (
	// errOutOfCodes is an internal sentinel returned by incHi once every
	// code has been used. By then incHi has already sent a clear code and
	// reset the writer state, so Write and Close treat it as a signal
	// rather than a failure.
	errOutOfCodes = errors.New("lzw: out of codes")
)
95
96// incHi increments e.hi and checks for both overflow and running out of
97// unused codes. In the latter case, incHi sends a clear code, resets the
98// writer state and returns errOutOfCodes.
99func (w *Writer) incHi() error {
100	w.hi++
101	if w.hi == w.overflow {
102		w.width++
103		w.overflow <<= 1
104	}
105	if w.hi == maxCode {
106		clear := uint32(1) << w.litWidth
107		if err := w.write(w, clear); err != nil {
108			return err
109		}
110		w.width = w.litWidth + 1
111		w.hi = clear + 1
112		w.overflow = clear << 1
113		for i := range w.table {
114			w.table[i] = invalidEntry
115		}
116		return errOutOfCodes
117	}
118	return nil
119}
120
121// Write writes a compressed representation of p to w's underlying writer.
122func (w *Writer) Write(p []byte) (n int, err error) {
123	if w.err != nil {
124		return 0, w.err
125	}
126	if len(p) == 0 {
127		return 0, nil
128	}
129	if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
130		for _, x := range p {
131			if x > maxLit {
132				w.err = errors.New("lzw: input byte too large for the litWidth")
133				return 0, w.err
134			}
135		}
136	}
137	n = len(p)
138	code := w.savedCode
139	if code == invalidCode {
140		// This is the first write; send a clear code.
141		// https://www.w3.org/Graphics/GIF/spec-gif89a.txt Appendix F
142		// "Variable-Length-Code LZW Compression" says that "Encoders should
143		// output a Clear code as the first code of each image data stream".
144		//
145		// LZW compression isn't only used by GIF, but it's cheap to follow
146		// that directive unconditionally.
147		clear := uint32(1) << w.litWidth
148		if err := w.write(w, clear); err != nil {
149			return 0, err
150		}
151		// After the starting clear code, the next code sent (for non-empty
152		// input) is always a literal code.
153		code, p = uint32(p[0]), p[1:]
154	}
155loop:
156	for _, x := range p {
157		literal := uint32(x)
158		key := code<<8 | literal
159		// If there is a hash table hit for this key then we continue the loop
160		// and do not emit a code yet.
161		hash := (key>>12 ^ key) & tableMask
162		for h, t := hash, w.table[hash]; t != invalidEntry; {
163			if key == t>>12 {
164				code = t & maxCode
165				continue loop
166			}
167			h = (h + 1) & tableMask
168			t = w.table[h]
169		}
170		// Otherwise, write the current code, and literal becomes the start of
171		// the next emitted code.
172		if w.err = w.write(w, code); w.err != nil {
173			return 0, w.err
174		}
175		code = literal
176		// Increment e.hi, the next implied code. If we run out of codes, reset
177		// the writer state (including clearing the hash table) and continue.
178		if err1 := w.incHi(); err1 != nil {
179			if err1 == errOutOfCodes {
180				continue
181			}
182			w.err = err1
183			return 0, w.err
184		}
185		// Otherwise, insert key -> e.hi into the map that e.table represents.
186		for {
187			if w.table[hash] == invalidEntry {
188				w.table[hash] = (key << 12) | w.hi
189				break
190			}
191			hash = (hash + 1) & tableMask
192		}
193	}
194	w.savedCode = code
195	return n, nil
196}
197
198// Close closes the [Writer], flushing any pending output. It does not close
199// w's underlying writer.
200func (w *Writer) Close() error {
201	if w.err != nil {
202		if w.err == errClosed {
203			return nil
204		}
205		return w.err
206	}
207	// Make any future calls to Write return errClosed.
208	w.err = errClosed
209	// Write the savedCode if valid.
210	if w.savedCode != invalidCode {
211		if err := w.write(w, w.savedCode); err != nil {
212			return err
213		}
214		if err := w.incHi(); err != nil && err != errOutOfCodes {
215			return err
216		}
217	} else {
218		// Write the starting clear code, as w.Write did not.
219		clear := uint32(1) << w.litWidth
220		if err := w.write(w, clear); err != nil {
221			return err
222		}
223	}
224	// Write the eof code.
225	eof := uint32(1)<<w.litWidth + 1
226	if err := w.write(w, eof); err != nil {
227		return err
228	}
229	// Write the final bits.
230	if w.nBits > 0 {
231		if w.order == MSB {
232			w.bits >>= 24
233		}
234		if err := w.w.WriteByte(uint8(w.bits)); err != nil {
235			return err
236		}
237	}
238	return w.w.Flush()
239}
240
241// Reset clears the [Writer]'s state and allows it to be reused again
242// as a new [Writer].
243func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
244	*w = Writer{}
245	w.init(dst, order, litWidth)
246}
247
248// NewWriter creates a new [io.WriteCloser].
249// Writes to the returned [io.WriteCloser] are compressed and written to w.
250// It is the caller's responsibility to call Close on the WriteCloser when
251// finished writing.
252// The number of bits to use for literal codes, litWidth, must be in the
253// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
254//
255// It is guaranteed that the underlying type of the returned [io.WriteCloser]
256// is a *[Writer].
257func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
258	return newWriter(w, order, litWidth)
259}
260
261func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
262	w := new(Writer)
263	w.init(dst, order, litWidth)
264	return w
265}
266
267func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
268	switch order {
269	case LSB:
270		w.write = (*Writer).writeLSB
271	case MSB:
272		w.write = (*Writer).writeMSB
273	default:
274		w.err = errors.New("lzw: unknown order")
275		return
276	}
277	if litWidth < 2 || 8 < litWidth {
278		w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
279		return
280	}
281	bw, ok := dst.(writer)
282	if !ok && dst != nil {
283		bw = bufio.NewWriter(dst)
284	}
285	w.w = bw
286	lw := uint(litWidth)
287	w.order = order
288	w.width = 1 + lw
289	w.litWidth = lw
290	w.hi = 1<<lw + 1
291	w.overflow = 1 << (lw + 1)
292	w.savedCode = invalidCode
293}
294