// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package flate

import (
	"bytes"
	"errors"
	"fmt"
	"internal/testenv"
	"io"
	"math/rand"
	"os"
	"reflect"
	"runtime/debug"
	"sync"
	"testing"
)

type deflateTest struct {
	in    []byte
	level int
	out   []byte
}

type deflateInflateTest struct {
	in []byte
}

type reverseBitsTest struct {
	in       uint16
	bitCount uint8
	out      uint16
}

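// deflateTests pairs an input and a compression level with the exact bytes the
// compressor is expected to emit. For example, {1, 0, 0, 255, 255} is an empty
// final stored block: BFINAL=1 and BTYPE=00 in the first byte, followed by
// LEN=0 and its one's complement NLEN=0xFFFF.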
var deflateTests = []*deflateTest{
	{[]byte{}, 0, []byte{1, 0, 0, 255, 255}},
	{[]byte{0x11}, -1, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11}, DefaultCompression, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11}, 4, []byte{18, 4, 4, 0, 0, 255, 255}},

	{[]byte{0x11}, 0, []byte{0, 1, 0, 254, 255, 17, 1, 0, 0, 255, 255}},
	{[]byte{0x11, 0x12}, 0, []byte{0, 2, 0, 253, 255, 17, 18, 1, 0, 0, 255, 255}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 0,
		[]byte{0, 8, 0, 247, 255, 17, 17, 17, 17, 17, 17, 17, 17, 1, 0, 0, 255, 255},
	},
	{[]byte{}, 2, []byte{1, 0, 0, 255, 255}},
	{[]byte{0x11}, 2, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x12}, 2, []byte{18, 20, 2, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 2, []byte{18, 132, 2, 64, 0, 0, 0, 255, 255}},
	{[]byte{}, 9, []byte{1, 0, 0, 255, 255}},
	{[]byte{0x11}, 9, []byte{18, 4, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x12}, 9, []byte{18, 20, 2, 4, 0, 0, 255, 255}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}, 9, []byte{18, 132, 2, 64, 0, 0, 0, 255, 255}},
}

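// deflateInflateTests lists inputs that are compressed and decompressed again
// to verify that the round trip reproduces the original data.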
var deflateInflateTests = []*deflateInflateTest{
	{[]byte{}},
	{[]byte{0x11}},
	{[]byte{0x11, 0x12}},
	{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
	{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
	{largeDataChunk()},
}

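// reverseBitsTests pairs an input with the expected result of reversing its
// low bitCount bits. For example, 29 = 0b11101 reversed within 5 bits is
// 0b10111 = 23.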
var reverseBitsTests = []*reverseBitsTest{
	{1, 1, 1},
	{1, 2, 2},
	{1, 3, 4},
	{1, 4, 8},
	{1, 5, 16},
	{17, 5, 17},
	{257, 9, 257},
	{29, 5, 23},
}

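// largeDataChunk returns 100000 bytes of deterministic, non-constant data
// (byte(i*i)) for exercising the compressor on a larger input.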
func largeDataChunk() []byte {
	result := make([]byte, 100000)
	for i := range result {
		result[i] = byte(i * i & 0xFF)
	}
	return result
}

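// TestBulkHash4 checks that bulkHash4 writes the same hash to every position
// of dst as calling hash4 on the corresponding 4-byte window would, and that
// it does not leave any destination entry unmodified.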
func TestBulkHash4(t *testing.T) {
	for _, x := range deflateTests {
		y := x.out
		if len(y) < minMatchLength {
			continue
		}
		y = append(y, y...)
		for j := 4; j < len(y); j++ {
			y := y[:j]
			dst := make([]uint32, len(y)-minMatchLength+1)
			for i := range dst {
				dst[i] = uint32(i + 100)
			}
			bulkHash4(y, dst)
			for i, got := range dst {
				want := hash4(y[i:])
				if got != want && got == uint32(i)+100 {
					t.Errorf("Len:%d Index:%d, want 0x%08x but not modified", len(y), i, want)
				} else if got != want {
					t.Errorf("Len:%d Index:%d, got 0x%08x want:0x%08x", len(y), i, got, want)
				}
			}
		}
	}
}

func TestDeflate(t *testing.T) {
	for _, h := range deflateTests {
		var buf bytes.Buffer
		w, err := NewWriter(&buf, h.level)
		if err != nil {
			t.Errorf("NewWriter: %v", err)
			continue
		}
		w.Write(h.in)
		w.Close()
		if !bytes.Equal(buf.Bytes(), h.out) {
			t.Errorf("Deflate(%d, %x) = \n%#v, want \n%#v", h.level, h.in, buf.Bytes(), h.out)
		}
	}
}

func TestWriterClose(t *testing.T) {
	b := new(bytes.Buffer)
	zw, err := NewWriter(b, 6)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}

	if c, err := zw.Write([]byte("Test")); err != nil || c != 4 {
		t.Fatalf("Write to unclosed writer: %s, %d", err, c)
	}

	if err := zw.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	afterClose := b.Len()

	if c, err := zw.Write([]byte("Test")); err == nil || c != 0 {
		t.Fatalf("Write to closed writer: %v, %d", err, c)
	}

	if err := zw.Flush(); err == nil {
		t.Fatalf("Flush to closed writer: %s", err)
	}

	if err := zw.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	if afterClose != b.Len() {
		t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len())
	}
}

// A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
// This tests missing hash references in a very large input.
type sparseReader struct {
	l   int64
	cur int64
}

func (r *sparseReader) Read(b []byte) (n int, err error) {
	if r.cur >= r.l {
		return 0, io.EOF
	}
	n = len(b)
	cur := r.cur + int64(n)
	if cur > r.l {
		n -= int(cur - r.l)
		cur = r.l
	}
	for i := range b[0:n] {
		if r.cur+int64(i) >= r.l-1<<16 {
			b[i] = 1
		} else {
			b[i] = 0
		}
	}
	r.cur = cur
	return
}

func TestVeryLongSparseChunk(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping sparse chunk during short test")
	}
	w, err := NewWriter(io.Discard, 1)
	if err != nil {
		t.Errorf("NewWriter: %v", err)
		return
	}
	if _, err = io.Copy(w, &sparseReader{l: 23e8}); err != nil {
		t.Errorf("Compress failed: %v", err)
		return
	}
}

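// syncBuffer is a bytes.Buffer whose Read blocks until data is available or
// the buffer is closed. WriteMode and ReadMode pass ownership back and forth
// between the writing and reading sides so that testSync can interleave
// writes and reads deterministically.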
type syncBuffer struct {
	buf    bytes.Buffer
	mu     sync.RWMutex
	closed bool
	ready  chan bool
}

func newSyncBuffer() *syncBuffer {
	return &syncBuffer{ready: make(chan bool, 1)}
}

func (b *syncBuffer) Read(p []byte) (n int, err error) {
	for {
		b.mu.RLock()
		n, err = b.buf.Read(p)
		b.mu.RUnlock()
		if n > 0 || b.closed {
			return
		}
		<-b.ready
	}
}

func (b *syncBuffer) signal() {
	select {
	case b.ready <- true:
	default:
	}
}

func (b *syncBuffer) Write(p []byte) (n int, err error) {
	n, err = b.buf.Write(p)
	b.signal()
	return
}

func (b *syncBuffer) WriteMode() {
	b.mu.Lock()
}

func (b *syncBuffer) ReadMode() {
	b.mu.Unlock()
	b.signal()
}

func (b *syncBuffer) Close() error {
	b.closed = true
	b.signal()
	return nil
}

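// testSync compresses input in two halves with a Flush after the first,
// reading each half back as soon as it has been written to verify that Flush
// makes the data written so far available to the decompressor. It then
// decompresses the complete stream again from an ordinary buffer.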
func testSync(t *testing.T, level int, input []byte, name string) {
	if len(input) == 0 {
		return
	}

	t.Logf("--testSync %d, %d, %s", level, len(input), name)
	buf := newSyncBuffer()
	buf1 := new(bytes.Buffer)
	buf.WriteMode()
	w, err := NewWriter(io.MultiWriter(buf, buf1), level)
	if err != nil {
		t.Errorf("NewWriter: %v", err)
		return
	}
	r := NewReader(buf)

	// Write half the input and read back.
	for i := 0; i < 2; i++ {
		var lo, hi int
		if i == 0 {
			lo, hi = 0, (len(input)+1)/2
		} else {
			lo, hi = (len(input)+1)/2, len(input)
		}
		t.Logf("#%d: write %d-%d", i, lo, hi)
		if _, err := w.Write(input[lo:hi]); err != nil {
			t.Errorf("testSync: write: %v", err)
			return
		}
		if i == 0 {
			if err := w.Flush(); err != nil {
				t.Errorf("testSync: flush: %v", err)
				return
			}
		} else {
			if err := w.Close(); err != nil {
				t.Errorf("testSync: close: %v", err)
			}
		}
		buf.ReadMode()
		out := make([]byte, hi-lo+1)
		m, err := io.ReadAtLeast(r, out, hi-lo)
		t.Logf("#%d: read %d", i, m)
		if m != hi-lo || err != nil {
			t.Errorf("testSync/%d (%d, %d, %s): read %d: %d, %v (%d left)", i, level, len(input), name, hi-lo, m, err, buf.buf.Len())
			return
		}
		if !bytes.Equal(input[lo:hi], out[:hi-lo]) {
			t.Errorf("testSync/%d: read wrong bytes: %x vs %x", i, input[lo:hi], out[:hi-lo])
			return
		}
		// This test originally checked that after reading
		// the first half of the input, there was nothing left
		// in the read buffer (buf.buf.Len() != 0) but that is
		// not necessarily the case: the write Flush may emit
		// some extra framing bits that are not necessary
		// to process to obtain the first half of the uncompressed
		// data. The test ran correctly most of the time, because
		// the background goroutine had usually read even
		// those extra bits by now, but it's not a useful thing to
		// check.
		buf.WriteMode()
	}
	buf.ReadMode()
	out := make([]byte, 10)
	if n, err := r.Read(out); n > 0 || err != io.EOF {
		t.Errorf("testSync (%d, %d, %s): final Read: %d, %v (hex: %x)", level, len(input), name, n, err, out[0:n])
	}
	if buf.buf.Len() != 0 {
		t.Errorf("testSync (%d, %d, %s): extra data at end", level, len(input), name)
	}
	r.Close()

	// stream should work for ordinary reader too
	r = NewReader(buf1)
	out, err = io.ReadAll(r)
	if err != nil {
		t.Errorf("testSync: read: %s", err)
		return
	}
	r.Close()
	if !bytes.Equal(input, out) {
		t.Errorf("testSync: decompress(compress(data)) != data: level=%d input=%s", level, name)
	}
}

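// testToFromWithLevelAndLimit compresses input at the given level, checks that
// the compressed size does not exceed limit (when limit > 0), verifies that
// decompression recovers the input, and then runs testSync on the same data.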
func testToFromWithLevelAndLimit(t *testing.T, level int, input []byte, name string, limit int) {
	var buffer bytes.Buffer
	w, err := NewWriter(&buffer, level)
	if err != nil {
		t.Errorf("NewWriter: %v", err)
		return
	}
	w.Write(input)
	w.Close()
	if limit > 0 && buffer.Len() > limit {
		t.Errorf("level: %d, len(compress(data)) = %d > limit = %d", level, buffer.Len(), limit)
		return
	}
	if limit > 0 {
		t.Logf("level: %d, size:%.2f%%, %d b\n", level, float64(buffer.Len()*100)/float64(limit), buffer.Len())
	}
	r := NewReader(&buffer)
	out, err := io.ReadAll(r)
	if err != nil {
		t.Errorf("read: %s", err)
		return
	}
	r.Close()
	if !bytes.Equal(input, out) {
		t.Errorf("decompress(compress(data)) != data: level=%d input=%s", level, name)
		return
	}
	testSync(t, level, input, name)
}

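// testToFromWithLimit runs testToFromWithLevelAndLimit at levels 0 through 9
// using limit[0] through limit[9], and at HuffmanOnly (level -2) using limit[10].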
func testToFromWithLimit(t *testing.T, input []byte, name string, limit [11]int) {
	for i := 0; i < 10; i++ {
		testToFromWithLevelAndLimit(t, i, input, name, limit[i])
	}
	// Test HuffmanOnly (level -2).
	testToFromWithLevelAndLimit(t, -2, input, name, limit[10])
}

func TestDeflateInflate(t *testing.T) {
	t.Parallel()
	for i, h := range deflateInflateTests {
		if testing.Short() && len(h.in) > 10000 {
			continue
		}
		testToFromWithLimit(t, h.in, fmt.Sprintf("#%d", i), [11]int{})
	}
}

func TestReverseBits(t *testing.T) {
	for _, h := range reverseBitsTests {
		if v := reverseBits(h.in, h.bitCount); v != h.out {
			t.Errorf("reverseBits(%v,%v) = %v, want %v",
				h.in, h.bitCount, v, h.out)
		}
	}
}

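// deflateInflateStringTest names a reference file to round-trip. limit holds
// the maximum allowed compressed size for levels 0 through 9 in limit[0:10]
// and for HuffmanOnly in limit[10].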
type deflateInflateStringTest struct {
	filename string
	label    string
	limit    [11]int
}

var deflateInflateStringTests = []deflateInflateStringTest{
	{
		"../testdata/e.txt",
		"2.718281828...",
		[...]int{100018, 50650, 50960, 51150, 50930, 50790, 50790, 50790, 50790, 50790, 43683},
	},
	{
		"../../testdata/Isaac.Newton-Opticks.txt",
		"Isaac.Newton-Opticks",
		[...]int{567248, 218338, 198211, 193152, 181100, 175427, 175427, 173597, 173422, 173422, 325240},
	},
}

func TestDeflateInflateString(t *testing.T) {
	t.Parallel()
	if testing.Short() && testenv.Builder() == "" {
		t.Skip("skipping in short mode")
	}
	for _, test := range deflateInflateStringTests {
		gold, err := os.ReadFile(test.filename)
		if err != nil {
			t.Error(err)
		}
		testToFromWithLimit(t, gold, test.label, test.limit)
		if testing.Short() {
			break
		}
	}
}

func TestReaderDict(t *testing.T) {
	const (
		dict = "hello world"
		text = "hello again world"
	)
	var b bytes.Buffer
	w, err := NewWriter(&b, 5)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}
	w.Write([]byte(dict))
	w.Flush()
	b.Reset()
	w.Write([]byte(text))
	w.Close()

	r := NewReaderDict(&b, []byte(dict))
	data, err := io.ReadAll(r)
	if err != nil {
		t.Fatal(err)
	}
	if string(data) != "hello again world" {
		t.Fatalf("read returned %q want %q", string(data), text)
	}
}

func TestWriterDict(t *testing.T) {
	const (
		dict = "hello world"
		text = "hello again world"
	)
	var b bytes.Buffer
	w, err := NewWriter(&b, 5)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}
	w.Write([]byte(dict))
	w.Flush()
	b.Reset()
	w.Write([]byte(text))
	w.Close()

	var b1 bytes.Buffer
	w, _ = NewWriterDict(&b1, 5, []byte(dict))
	w.Write([]byte(text))
	w.Close()

	if !bytes.Equal(b1.Bytes(), b.Bytes()) {
		t.Fatalf("writer wrote %q want %q", b1.Bytes(), b.Bytes())
	}
}

// See https://golang.org/issue/2508
func TestRegression2508(t *testing.T) {
	if testing.Short() {
		t.Logf("test disabled with -short")
		return
	}
	w, err := NewWriter(io.Discard, 1)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}
	buf := make([]byte, 1024)
	for i := 0; i < 131072; i++ {
		if _, err := w.Write(buf); err != nil {
			t.Fatalf("writer failed: %v", err)
		}
	}
	w.Close()
}

func TestWriterReset(t *testing.T) {
	t.Parallel()
	for level := 0; level <= 9; level++ {
		if testing.Short() && level > 1 {
			break
		}
		w, err := NewWriter(io.Discard, level)
		if err != nil {
			t.Fatalf("NewWriter: %v", err)
		}
		buf := []byte("hello world")
		n := 1024
		if testing.Short() {
			n = 10
		}
		for i := 0; i < n; i++ {
			w.Write(buf)
		}
		w.Reset(io.Discard)

		wref, err := NewWriter(io.Discard, level)
		if err != nil {
			t.Fatalf("NewWriter: %v", err)
		}

		// DeepEqual doesn't compare functions.
		w.d.fill, wref.d.fill = nil, nil
		w.d.step, wref.d.step = nil, nil
		w.d.bulkHasher, wref.d.bulkHasher = nil, nil
		w.d.bestSpeed, wref.d.bestSpeed = nil, nil
		// hashMatch is always overwritten when used.
		copy(w.d.hashMatch[:], wref.d.hashMatch[:])
		if len(w.d.tokens) != 0 {
			t.Errorf("level %d Writer not reset after Reset. %d tokens were present", level, len(w.d.tokens))
		}
		// As long as the length is 0, we don't care about the content.
		w.d.tokens = wref.d.tokens

		// We don't care if there are values in the window, as long as d.index is 0.
		w.d.window = wref.d.window
		if !reflect.DeepEqual(w, wref) {
			t.Errorf("level %d Writer not reset after Reset", level)
		}
	}

	levels := []int{0, 1, 2, 5, 9}
	for _, level := range levels {
		t.Run(fmt.Sprint(level), func(t *testing.T) {
			testResetOutput(t, level, nil)
		})
	}

	t.Run("dict", func(t *testing.T) {
		for _, level := range levels {
			t.Run(fmt.Sprint(level), func(t *testing.T) {
				testResetOutput(t, level, []byte("hello world")) // exercise the preset-dictionary path
			})
		}
	})
}

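// testResetOutput verifies that a Writer produces identical compressed output
// for the same data before and after Reset, optionally using a preset dictionary.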
func testResetOutput(t *testing.T, level int, dict []byte) {
	writeData := func(w *Writer) {
		msg := []byte("now is the time for all good gophers")
		w.Write(msg)
		w.Flush()

		hello := []byte("hello world")
		for i := 0; i < 1024; i++ {
			w.Write(hello)
		}

		fill := bytes.Repeat([]byte("x"), 65000)
		w.Write(fill)
	}

	buf := new(bytes.Buffer)
	var w *Writer
	var err error
	if dict == nil {
		w, err = NewWriter(buf, level)
	} else {
		w, err = NewWriterDict(buf, level, dict)
	}
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}

	writeData(w)
	w.Close()
	out1 := buf.Bytes()

	buf2 := new(bytes.Buffer)
	w.Reset(buf2)
	writeData(w)
	w.Close()
	out2 := buf2.Bytes()

	if len(out1) != len(out2) {
		t.Errorf("got %d, expected %d bytes", len(out2), len(out1))
		return
	}
	if !bytes.Equal(out1, out2) {
		mm := 0
		for i, b := range out1[:len(out2)] {
			if b != out2[i] {
				t.Errorf("mismatch index %d: %#02x, expected %#02x", i, out2[i], b)
				mm++
			}
			if mm == 10 {
				t.Fatal("Stopping")
			}
		}
	}
	t.Logf("got %d bytes", len(out1))
}

// TestBestSpeed tests that round-tripping through deflate and then inflate
// recovers the original input. The Write sizes are near the thresholds in the
// compressor.encSpeed method (0, 16, 128), as well as near maxStoreBlockSize
// (65535).
func TestBestSpeed(t *testing.T) {
	t.Parallel()
	abc := make([]byte, 128)
	for i := range abc {
		abc[i] = byte(i)
	}
	abcabc := bytes.Repeat(abc, 131072/len(abc))
	var want []byte

	testCases := [][]int{
		{65536, 0},
		{65536, 1},
		{65536, 1, 256},
		{65536, 1, 65536},
		{65536, 14},
		{65536, 15},
		{65536, 16},
		{65536, 16, 256},
		{65536, 16, 65536},
		{65536, 127},
		{65536, 128},
		{65536, 128, 256},
		{65536, 128, 65536},
		{65536, 129},
		{65536, 65536, 256},
		{65536, 65536, 65536},
	}

	for i, tc := range testCases {
		if i >= 3 && testing.Short() {
			break
		}
		for _, firstN := range []int{1, 65534, 65535, 65536, 65537, 131072} {
			tc[0] = firstN
		outer:
			for _, flush := range []bool{false, true} {
				buf := new(bytes.Buffer)
				want = want[:0]

				w, err := NewWriter(buf, BestSpeed)
				if err != nil {
					t.Errorf("i=%d, firstN=%d, flush=%t: NewWriter: %v", i, firstN, flush, err)
					continue
				}
				for _, n := range tc {
					want = append(want, abcabc[:n]...)
					if _, err := w.Write(abcabc[:n]); err != nil {
						t.Errorf("i=%d, firstN=%d, flush=%t: Write: %v", i, firstN, flush, err)
						continue outer
					}
					if !flush {
						continue
					}
					if err := w.Flush(); err != nil {
						t.Errorf("i=%d, firstN=%d, flush=%t: Flush: %v", i, firstN, flush, err)
						continue outer
					}
				}
				if err := w.Close(); err != nil {
					t.Errorf("i=%d, firstN=%d, flush=%t: Close: %v", i, firstN, flush, err)
					continue
				}

				r := NewReader(buf)
				got, err := io.ReadAll(r)
				if err != nil {
					t.Errorf("i=%d, firstN=%d, flush=%t: ReadAll: %v", i, firstN, flush, err)
					continue
				}
				r.Close()

				if !bytes.Equal(got, want) {
					t.Errorf("i=%d, firstN=%d, flush=%t: corruption during deflate-then-inflate", i, firstN, flush)
					continue
				}
			}
		}
	}
}

var errIO = errors.New("IO error")

// failWriter fails with errIO exactly at the nth call to Write.
type failWriter struct{ n int }

func (w *failWriter) Write(b []byte) (int, error) {
	w.n--
	if w.n == -1 {
		return 0, errIO
	}
	return len(b), nil
}

func TestWriterPersistentWriteError(t *testing.T) {
	t.Parallel()
	d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
	if err != nil {
		t.Fatalf("ReadFile: %v", err)
	}
	d = d[:10000] // Keep this test short

	zw, err := NewWriter(nil, DefaultCompression)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}

	// Sweep over the threshold at which an error is returned.
	// The variable i makes it such that the ith call to failWriter.Write will
	// return errIO. Since failWriter errors are not persistent, we must ensure
	// that flate.Writer errors are persistent.
	for i := 0; i < 1000; i++ {
		fw := &failWriter{i}
		zw.Reset(fw)

		_, werr := zw.Write(d)
		cerr := zw.Close()
		ferr := zw.Flush()
		if werr != errIO && werr != nil {
			t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
		}
		if cerr != errIO && fw.n < 0 {
			t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
		}
		if ferr != errIO && fw.n < 0 {
			t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO)
		}
		if fw.n >= 0 {
			// At this point, the failure threshold was sufficiently high
			// that we wrote the whole stream without any errors.
			return
		}
	}
}

func TestWriterPersistentFlushError(t *testing.T) {
	zw, err := NewWriter(&failWriter{0}, DefaultCompression)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}
	flushErr := zw.Flush()
	closeErr := zw.Close()
	_, writeErr := zw.Write([]byte("Test"))
	checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
}

func TestWriterPersistentCloseError(t *testing.T) {
	// If the underlying writer returns an error when closing the stream,
	// we should persist this error across all writer calls.
	zw, err := NewWriter(&failWriter{0}, DefaultCompression)
	if err != nil {
		t.Fatalf("NewWriter: %v", err)
	}
	closeErr := zw.Close()
	flushErr := zw.Flush()
	_, writeErr := zw.Write([]byte("Test"))
	checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)

	// After closing the writer, we should persist a "write after close" error
	// across Flush and Write calls, but return nil on subsequent Close calls.
	var b bytes.Buffer
	zw.Reset(&b)
	err = zw.Close()
	if err != nil {
		t.Fatalf("First call to close returned error: %s", err)
	}
	err = zw.Close()
	if err != nil {
		t.Fatalf("Second call to close returned error: %s", err)
	}

	flushErr = zw.Flush()
	_, writeErr = zw.Write([]byte("Test"))
	checkErrors([]error{flushErr, writeErr}, errWriterClosed, t)
}

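// checkErrors reports a test error for every entry in got that is not equal to want.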
func checkErrors(got []error, want error, t *testing.T) {
	t.Helper()
	for _, err := range got {
		if err != want {
			t.Errorf("Error doesn't match\nWant: %s\nGot: %s", want, err)
		}
	}
}

func TestBestSpeedMatch(t *testing.T) {
	t.Parallel()
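	// In each case, matchLen extends a match starting at position s in current
	// against the candidate position t; a negative t refers to data in
	// previous (the prior block), indexed from its end. want is the expected
	// match length.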
	cases := []struct {
		previous, current []byte
		t, s, want        int32
	}{{
		previous: []byte{0, 0, 0, 1, 2},
		current:  []byte{3, 4, 5, 0, 1, 2, 3, 4, 5},
		t:        -3,
		s:        3,
		want:     6,
	}, {
		previous: []byte{0, 0, 0, 1, 2},
		current:  []byte{2, 4, 5, 0, 1, 2, 3, 4, 5},
		t:        -3,
		s:        3,
		want:     3,
	}, {
		previous: []byte{0, 0, 0, 1, 1},
		current:  []byte{3, 4, 5, 0, 1, 2, 3, 4, 5},
		t:        -3,
		s:        3,
		want:     2,
	}, {
		previous: []byte{0, 0, 0, 1, 2},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -1,
		s:        0,
		want:     4,
	}, {
		previous: []byte{0, 0, 0, 1, 2, 3, 4, 5, 2, 2},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -7,
		s:        4,
		want:     5,
	}, {
		previous: []byte{9, 9, 9, 9, 9},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -1,
		s:        0,
		want:     0,
	}, {
		previous: []byte{9, 9, 9, 9, 9},
		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        0,
		s:        1,
		want:     0,
	}, {
		previous: []byte{},
		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -5,
		s:        1,
		want:     0,
	}, {
		previous: []byte{},
		current:  []byte{9, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        -1,
		s:        1,
		want:     0,
	}, {
		previous: []byte{},
		current:  []byte{2, 2, 2, 2, 1, 2, 3, 4, 5},
		t:        0,
		s:        1,
		want:     3,
	}, {
		previous: []byte{3, 4, 5},
		current:  []byte{3, 4, 5},
		t:        -3,
		s:        0,
		want:     3,
	}, {
		previous: make([]byte, 1000),
		current:  make([]byte, 1000),
		t:        -1000,
		s:        0,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, 200),
		current:  make([]byte, 500),
		t:        -200,
		s:        0,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, 200),
		current:  make([]byte, 500),
		t:        0,
		s:        1,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, maxMatchLength-4),
		current:  make([]byte, 500),
		t:        -(maxMatchLength - 4),
		s:        0,
		want:     maxMatchLength - 4,
	}, {
		previous: make([]byte, 200),
		current:  make([]byte, 500),
		t:        -200,
		s:        400,
		want:     100,
	}, {
		previous: make([]byte, 10),
		current:  make([]byte, 500),
		t:        200,
		s:        400,
		want:     100,
	}}
	for i, c := range cases {
		e := deflateFast{prev: c.previous}
		got := e.matchLen(c.s, c.t, c.current)
		if got != c.want {
			t.Errorf("Test %d: match length, want %d, got %d", i, c.want, got)
		}
	}
}

func TestBestSpeedMaxMatchOffset(t *testing.T) {
	t.Parallel()
	const abc, xyz = "abcdefgh", "stuvwxyz"
	for _, matchBefore := range []bool{false, true} {
		for _, extra := range []int{0, inputMargin - 1, inputMargin, inputMargin + 1, 2 * inputMargin} {
			for offsetAdj := -5; offsetAdj <= +5; offsetAdj++ {
				report := func(desc string, err error) {
					t.Errorf("matchBefore=%t, extra=%d, offsetAdj=%d: %s%v",
						matchBefore, extra, offsetAdj, desc, err)
				}

				offset := maxMatchOffset + offsetAdj

				// Make src a []byte of the form
				//	"%s%s%s%s%s" % (abc, zeros0, xyzMaybe, abc, zeros1)
				// where:
				//	zeros0 is approximately maxMatchOffset zeros.
				//	xyzMaybe is either xyz or the empty string.
				//	zeros1 is between 0 and 30 zeros.
				// The difference between the two abc's will be offset, which
				// is maxMatchOffset plus or minus a small adjustment.
				src := make([]byte, offset+len(abc)+extra)
				copy(src, abc)
				if !matchBefore {
					copy(src[offset-len(xyz):], xyz)
				}
				copy(src[offset:], abc)

				buf := new(bytes.Buffer)
				w, err := NewWriter(buf, BestSpeed)
				if err != nil {
					report("NewWriter: ", err)
					continue
				}
				if _, err := w.Write(src); err != nil {
					report("Write: ", err)
					continue
				}
				if err := w.Close(); err != nil {
					report("Writer.Close: ", err)
					continue
				}

				r := NewReader(buf)
				dst, err := io.ReadAll(r)
				r.Close()
				if err != nil {
					report("ReadAll: ", err)
					continue
				}

				if !bytes.Equal(dst, src) {
					report("", fmt.Errorf("bytes differ after round-tripping"))
					continue
				}
			}
		}
	}
}

func TestBestSpeedShiftOffsets(t *testing.T) {
	// Test that shiftOffsets properly preserves matches and resets out-of-range matches,
	// as seen in https://github.com/golang/go/issues/4142
	enc := newDeflateFast()

	// testData must not generate internal matches.
	testData := make([]byte, 32)
	rng := rand.New(rand.NewSource(0))
	for i := range testData {
		testData[i] = byte(rng.Uint32())
	}

	// Encode the testdata with clean state.
	// Second part should pick up matches from the first block.
	wantFirstTokens := len(enc.encode(nil, testData))
	wantSecondTokens := len(enc.encode(nil, testData))

	if wantFirstTokens <= wantSecondTokens {
		t.Fatalf("test needs matches between inputs to be generated")
	}
	// Forward the current indicator to before wraparound.
	enc.cur = bufferReset - int32(len(testData))

	// Part 1 before wrap, should match clean state.
	got := len(enc.encode(nil, testData))
	if wantFirstTokens != got {
		t.Errorf("got %d, want %d tokens", got, wantFirstTokens)
	}

	// Verify we are about to wrap.
	if enc.cur != bufferReset {
		t.Errorf("got %d, want e.cur to be at bufferReset (%d)", enc.cur, bufferReset)
	}

	// Part 2 should match clean state as well even if wrapped.
	got = len(enc.encode(nil, testData))
	if wantSecondTokens != got {
		t.Errorf("got %d, want %d tokens", got, wantSecondTokens)
	}

	// Verify that we wrapped.
	if enc.cur >= bufferReset {
		t.Errorf("want e.cur to be < bufferReset (%d), got %d", bufferReset, enc.cur)
	}

	// Forward the current buffer, leaving the matches at the bottom.
	enc.cur = bufferReset
	enc.shiftOffsets()

	// Ensure that no matches were picked up.
	got = len(enc.encode(nil, testData))
	if wantFirstTokens != got {
		t.Errorf("got %d, want %d tokens", got, wantFirstTokens)
	}
}

func TestMaxStackSize(t *testing.T) {
	// This test must not run in parallel with other tests as debug.SetMaxStack
	// affects all goroutines.
	n := debug.SetMaxStack(1 << 16)
	defer debug.SetMaxStack(n)

	var wg sync.WaitGroup
	defer wg.Wait()

	b := make([]byte, 1<<20)
	for level := HuffmanOnly; level <= BestCompression; level++ {
		// Run in separate goroutine to increase probability of stack regrowth.
		wg.Add(1)
		go func(level int) {
			defer wg.Done()
			zw, err := NewWriter(io.Discard, level)
			if err != nil {
				t.Errorf("level %d, NewWriter() = %v, want nil", level, err)
			}
			if n, err := zw.Write(b); n != len(b) || err != nil {
				t.Errorf("level %d, Write() = (%d, %v), want (%d, nil)", level, n, err, len(b))
			}
			if err := zw.Close(); err != nil {
				t.Errorf("level %d, Close() = %v, want nil", level, err)
			}
			zw.Reset(io.Discard)
		}(level)
	}
}