1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lzw
6
7import (
8	"bytes"
9	"fmt"
10	"io"
11	"math"
12	"os"
13	"runtime"
14	"strconv"
15	"strings"
16	"testing"
17)
18
// lzwTest is a single decoder test vector. desc has the form
// "name;order;litWidth"; decoding compressed should produce raw, or fail
// with err (nil meaning decoding is expected to succeed).
type lzwTest struct {
	desc, raw, compressed string
	err                   error
}
25
26var lzwTests = []lzwTest{
27	{
28		"empty;LSB;8",
29		"",
30		"\x01\x01",
31		nil,
32	},
33	{
34		"empty;MSB;8",
35		"",
36		"\x80\x80",
37		nil,
38	},
39	{
40		"tobe;LSB;7",
41		"TOBEORNOTTOBEORTOBEORNOT",
42		"\x54\x4f\x42\x45\x4f\x52\x4e\x4f\x54\x82\x84\x86\x8b\x85\x87\x89\x81",
43		nil,
44	},
45	{
46		"tobe;LSB;8",
47		"TOBEORNOTTOBEORTOBEORNOT",
48		"\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04\x12\x34\xb8\xb0\xe0\xc1\x84\x01\x01",
49		nil,
50	},
51	{
52		"tobe;MSB;7",
53		"TOBEORNOTTOBEORTOBEORNOT",
54		"\x54\x4f\x42\x45\x4f\x52\x4e\x4f\x54\x82\x84\x86\x8b\x85\x87\x89\x81",
55		nil,
56	},
57	{
58		"tobe;MSB;8",
59		"TOBEORNOTTOBEORTOBEORNOT",
60		"\x2a\x13\xc8\x44\x52\x79\x48\x9c\x4f\x2a\x40\xa0\x90\x68\x5c\x16\x0f\x09\x80\x80",
61		nil,
62	},
63	{
64		"tobe-truncated;LSB;8",
65		"TOBEORNOTTOBEORTOBEORNOT",
66		"\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04",
67		io.ErrUnexpectedEOF,
68	},
69	// This example comes from https://en.wikipedia.org/wiki/Graphics_Interchange_Format.
70	{
71		"gif;LSB;8",
72		"\x28\xff\xff\xff\x28\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff",
73		"\x00\x51\xfc\x1b\x28\x70\xa0\xc1\x83\x01\x01",
74		nil,
75	},
76	// This example comes from http://compgroups.net/comp.lang.ruby/Decompressing-LZW-compression-from-PDF-file
77	{
78		"pdf;MSB;8",
79		"-----A---B",
80		"\x80\x0b\x60\x50\x22\x0c\x0c\x85\x01",
81		nil,
82	},
83}
84
85func TestReader(t *testing.T) {
86	var b bytes.Buffer
87	for _, tt := range lzwTests {
88		d := strings.Split(tt.desc, ";")
89		var order Order
90		switch d[1] {
91		case "LSB":
92			order = LSB
93		case "MSB":
94			order = MSB
95		default:
96			t.Errorf("%s: bad order %q", tt.desc, d[1])
97		}
98		litWidth, _ := strconv.Atoi(d[2])
99		rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
100		defer rc.Close()
101		b.Reset()
102		n, err := io.Copy(&b, rc)
103		s := b.String()
104		if err != nil {
105			if err != tt.err {
106				t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
107			}
108			if err == io.ErrUnexpectedEOF {
109				// Even if the input is truncated, we should still return the
110				// partial decoded result.
111				if n == 0 || !strings.HasPrefix(tt.raw, s) {
112					t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, s, tt.raw)
113				}
114			}
115			continue
116		}
117		if s != tt.raw {
118			t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw)
119		}
120	}
121}
122
123func TestReaderReset(t *testing.T) {
124	var b bytes.Buffer
125	for _, tt := range lzwTests {
126		d := strings.Split(tt.desc, ";")
127		var order Order
128		switch d[1] {
129		case "LSB":
130			order = LSB
131		case "MSB":
132			order = MSB
133		default:
134			t.Errorf("%s: bad order %q", tt.desc, d[1])
135		}
136		litWidth, _ := strconv.Atoi(d[2])
137		rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
138		defer rc.Close()
139		b.Reset()
140		n, err := io.Copy(&b, rc)
141		b1 := b.Bytes()
142		if err != nil {
143			if err != tt.err {
144				t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
145			}
146			if err == io.ErrUnexpectedEOF {
147				// Even if the input is truncated, we should still return the
148				// partial decoded result.
149				if n == 0 || !strings.HasPrefix(tt.raw, b.String()) {
150					t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, b.String(), tt.raw)
151				}
152			}
153			continue
154		}
155
156		b.Reset()
157		rc.(*Reader).Reset(strings.NewReader(tt.compressed), order, litWidth)
158		n, err = io.Copy(&b, rc)
159		b2 := b.Bytes()
160		if err != nil {
161			t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, nil)
162			continue
163		}
164		if !bytes.Equal(b1, b2) {
165			t.Errorf("bytes read were not the same")
166		}
167	}
168}
169
// devZero is an io.Reader that never runs dry. Every Read overwrites the
// whole destination with zero bytes and reports it as filled, with no error.
type devZero struct{}

// Read zeroes dst and returns len(dst), nil.
func (devZero) Read(dst []byte) (int, error) {
	for i := range dst {
		dst[i] = 0
	}
	return len(dst), nil
}
176
177func TestHiCodeDoesNotOverflow(t *testing.T) {
178	r := NewReader(devZero{}, LSB, 8)
179	d := r.(*Reader)
180	buf := make([]byte, 1024)
181	oldHi := uint16(0)
182	for i := 0; i < 100; i++ {
183		if _, err := io.ReadFull(r, buf); err != nil {
184			t.Fatalf("i=%d: %v", i, err)
185		}
186		// The hi code should never decrease.
187		if d.hi < oldHi {
188			t.Fatalf("i=%d: hi=%d decreased from previous value %d", i, d.hi, oldHi)
189		}
190		oldHi = d.hi
191	}
192}
193
194// TestNoLongerSavingPriorExpansions tests the decoder state when codes other
195// than clear codes continue to be seen after decoder.hi and decoder.width
196// reach their maximum values (4095 and 12), i.e. after we no longer save prior
197// expansions. In particular, it tests seeing the highest possible code, 4095.
198func TestNoLongerSavingPriorExpansions(t *testing.T) {
199	// Iterations is used to calculate how many input bits are needed to get
200	// the decoder.hi and decoder.width values up to their maximum.
201	iterations := []struct {
202		width, n int
203	}{
204		// The final term is 257, not 256, as NewReader initializes d.hi to
205		// d.clear+1 and the clear code is 256.
206		{9, 512 - 257},
207		{10, 1024 - 512},
208		{11, 2048 - 1024},
209		{12, 4096 - 2048},
210	}
211	nCodes, nBits := 0, 0
212	for _, e := range iterations {
213		nCodes += e.n
214		nBits += e.n * e.width
215	}
216	if nCodes != 3839 {
217		t.Fatalf("nCodes: got %v, want %v", nCodes, 3839)
218	}
219	if nBits != 43255 {
220		t.Fatalf("nBits: got %v, want %v", nBits, 43255)
221	}
222
223	// Construct our input of 43255 zero bits (which gets d.hi and d.width up
224	// to 4095 and 12), followed by 0xfff (4095) as 12 bits, followed by 0x101
225	// (EOF) as 12 bits.
226	//
227	// 43255 = 5406*8 + 7, and codes are read in LSB order. The final bytes are
228	// therefore:
229	//
230	// xwwwwwww xxxxxxxx yyyyyxxx zyyyyyyy
231	// 10000000 11111111 00001111 00001000
232	//
233	// or split out:
234	//
235	// .0000000 ........ ........ ........   w = 0x000
236	// 1....... 11111111 .....111 ........   x = 0xfff
237	// ........ ........ 00001... .0001000   y = 0x101
238	//
239	// The 12 'w' bits (not all are shown) form the 3839'th code, with value
240	// 0x000. Just after decoder.read returns that code, d.hi == 4095 and
241	// d.last == 0.
242	//
243	// The 12 'x' bits form the 3840'th code, with value 0xfff or 4095. Just
244	// after decoder.read returns that code, d.hi == 4095 and d.last ==
245	// decoderInvalidCode.
246	//
247	// The 12 'y' bits form the 3841'st code, with value 0x101, the EOF code.
248	//
249	// The 'z' bit is unused.
250	in := make([]byte, 5406)
251	in = append(in, 0x80, 0xff, 0x0f, 0x08)
252
253	r := NewReader(bytes.NewReader(in), LSB, 8)
254	nDecoded, err := io.Copy(io.Discard, r)
255	if err != nil {
256		t.Fatalf("Copy: %v", err)
257	}
258	// nDecoded should be 3841: 3839 literal codes and then 2 decoded bytes
259	// from 1 non-literal code. The EOF code contributes 0 decoded bytes.
260	if nDecoded != int64(nCodes+2) {
261		t.Fatalf("nDecoded: got %v, want %v", nDecoded, nCodes+2)
262	}
263}
264
265func BenchmarkDecoder(b *testing.B) {
266	buf, err := os.ReadFile("../testdata/e.txt")
267	if err != nil {
268		b.Fatal(err)
269	}
270	if len(buf) == 0 {
271		b.Fatalf("test file has no data")
272	}
273
274	getInputBuf := func(buf []byte, n int) []byte {
275		compressed := new(bytes.Buffer)
276		w := NewWriter(compressed, LSB, 8)
277		for i := 0; i < n; i += len(buf) {
278			if len(buf) > n-i {
279				buf = buf[:n-i]
280			}
281			w.Write(buf)
282		}
283		w.Close()
284		return compressed.Bytes()
285	}
286
287	for e := 4; e <= 6; e++ {
288		n := int(math.Pow10(e))
289		b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
290			b.StopTimer()
291			b.SetBytes(int64(n))
292			buf1 := getInputBuf(buf, n)
293			runtime.GC()
294			b.StartTimer()
295			for i := 0; i < b.N; i++ {
296				io.Copy(io.Discard, NewReader(bytes.NewReader(buf1), LSB, 8))
297			}
298		})
299		b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
300			b.StopTimer()
301			b.SetBytes(int64(n))
302			buf1 := getInputBuf(buf, n)
303			runtime.GC()
304			b.StartTimer()
305			r := NewReader(bytes.NewReader(buf1), LSB, 8)
306			for i := 0; i < b.N; i++ {
307				io.Copy(io.Discard, r)
308				r.Close()
309				r.(*Reader).Reset(bytes.NewReader(buf1), LSB, 8)
310			}
311		})
312	}
313}
314