1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zstd
6
7import (
8	"bytes"
9	"crypto/sha256"
10	"fmt"
11	"internal/race"
12	"internal/testenv"
13	"io"
14	"os"
15	"os/exec"
16	"path/filepath"
17	"strings"
18	"sync"
19	"testing"
20)
21
22// tests holds some simple test cases, including some found by fuzzing.
23var tests = []struct {
24	name, uncompressed, compressed string
25}{
26	{
27		"hello",
28		"hello, world\n",
29		"\x28\xb5\x2f\xfd\x24\x0d\x69\x00\x00\x68\x65\x6c\x6c\x6f\x2c\x20\x77\x6f\x72\x6c\x64\x0a\x4c\x1f\xf9\xf1",
30	},
31	{
32		// a small compressed .debug_ranges section.
33		"ranges",
34		"\xcc\x11\x00\x00\x00\x00\x00\x00\xd5\x13\x00\x00\x00\x00\x00\x00" +
35			"\x1c\x14\x00\x00\x00\x00\x00\x00\x72\x14\x00\x00\x00\x00\x00\x00" +
36			"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
37			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
38			"\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
39			"\x0c\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
40			"\x29\x14\x00\x00\x00\x00\x00\x00\x4e\x14\x00\x00\x00\x00\x00\x00" +
41			"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
42			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
43			"\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
44			"\x67\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
45			"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
46			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
47			"\x5f\x0b\x00\x00\x00\x00\x00\x00\x6c\x0b\x00\x00\x00\x00\x00\x00" +
48			"\x7d\x0b\x00\x00\x00\x00\x00\x00\x7e\x0c\x00\x00\x00\x00\x00\x00" +
49			"\x38\x0f\x00\x00\x00\x00\x00\x00\x5c\x0f\x00\x00\x00\x00\x00\x00" +
50			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
51			"\x83\x0c\x00\x00\x00\x00\x00\x00\xfa\x0c\x00\x00\x00\x00\x00\x00" +
52			"\xfd\x0d\x00\x00\x00\x00\x00\x00\xef\x0e\x00\x00\x00\x00\x00\x00" +
53			"\x14\x0f\x00\x00\x00\x00\x00\x00\x38\x0f\x00\x00\x00\x00\x00\x00" +
54			"\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
55			"\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
56			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
57			"\xfd\x0d\x00\x00\x00\x00\x00\x00\xd8\x0e\x00\x00\x00\x00\x00\x00" +
58			"\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
59			"\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
60			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
61			"\xfa\x0c\x00\x00\x00\x00\x00\x00\xea\x0d\x00\x00\x00\x00\x00\x00" +
62			"\xef\x0e\x00\x00\x00\x00\x00\x00\x14\x0f\x00\x00\x00\x00\x00\x00" +
63			"\x5c\x0f\x00\x00\x00\x00\x00\x00\x9f\x0f\x00\x00\x00\x00\x00\x00" +
64			"\xac\x0f\x00\x00\x00\x00\x00\x00\xdb\x0f\x00\x00\x00\x00\x00\x00" +
65			"\xff\x0f\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
66			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
67			"\x60\x11\x00\x00\x00\x00\x00\x00\xd1\x16\x00\x00\x00\x00\x00\x00" +
68			"\x40\x0b\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
69			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
70			"\x7a\x00\x00\x00\x00\x00\x00\x00\xb6\x00\x00\x00\x00\x00\x00\x00" +
71			"\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
72			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
73			"\x7a\x00\x00\x00\x00\x00\x00\x00\xa9\x00\x00\x00\x00\x00\x00\x00" +
74			"\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
75			"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
76
77		"\x28\xb5\x2f\xfd\x64\xa0\x01\x2d\x05\x00\xc4\x04\xcc\x11\x00\xd5" +
78			"\x13\x00\x1c\x14\x00\x72\x9d\xd5\xfb\x12\x00\x09\x0c\x13\xcb\x13" +
79			"\x29\x4e\x67\x5f\x0b\x6c\x0b\x7d\x0b\x7e\x0c\x38\x0f\x5c\x0f\x83" +
80			"\x0c\xfa\x0c\xfd\x0d\xef\x0e\x14\x38\x9f\x0f\xac\x0f\xdb\x0f\xff" +
81			"\x0f\xd8\x9f\xac\xdb\xff\xea\x5c\x2c\x10\x60\xd1\x16\x40\x0b\x7a" +
82			"\x00\xb6\x00\x9f\x01\xa7\x01\xa9\x36\x20\xa0\x83\x14\x34\x63\x4a" +
83			"\x21\x70\x8c\x07\x46\x03\x4e\x10\x62\x3c\x06\x4e\xc8\x8c\xb0\x32" +
84			"\x2a\x59\xad\xb2\xf1\x02\x82\x7c\x33\xcb\x92\x6f\x32\x4f\x9b\xb0" +
85			"\xa2\x30\xf0\xc0\x06\x1e\x98\x99\x2c\x06\x1e\xd8\xc0\x03\x56\xd8" +
86			"\xc0\x03\x0f\x6c\xe0\x01\xf1\xf0\xee\x9a\xc6\xc8\x97\x99\xd1\x6c" +
87			"\xb4\x21\x45\x3b\x10\xe4\x7b\x99\x4d\x8a\x36\x64\x5c\x77\x08\x02" +
88			"\xcb\xe0\xce",
89	},
90	{
91		"fuzz1",
92		"0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000",
93		"(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G",
94	},
95	{
96		"empty block",
97		"",
98		"\x28\xb5\x2f\xfd\x00\x00\x15\x00\x00\x00\x00",
99	},
100	{
101		"single skippable frame",
102		"",
103		"\x50\x2a\x4d\x18\x00\x00\x00\x00",
104	},
105	{
106		"two skippable frames",
107		"",
108		"\x50\x2a\x4d\x18\x00\x00\x00\x00" +
109			"\x50\x2a\x4d\x18\x00\x00\x00\x00",
110	},
111}
112
113func TestSamples(t *testing.T) {
114	for _, test := range tests {
115		test := test
116		t.Run(test.name, func(t *testing.T) {
117			r := NewReader(strings.NewReader(test.compressed))
118			got, err := io.ReadAll(r)
119			if err != nil {
120				t.Fatal(err)
121			}
122			gotstr := string(got)
123			if gotstr != test.uncompressed {
124				t.Errorf("got %q want %q", gotstr, test.uncompressed)
125			}
126		})
127	}
128}
129
130func TestReset(t *testing.T) {
131	input := strings.NewReader("")
132	r := NewReader(input)
133	for _, test := range tests {
134		test := test
135		t.Run(test.name, func(t *testing.T) {
136			input.Reset(test.compressed)
137			r.Reset(input)
138			got, err := io.ReadAll(r)
139			if err != nil {
140				t.Fatal(err)
141			}
142			gotstr := string(got)
143			if gotstr != test.uncompressed {
144				t.Errorf("got %q want %q", gotstr, test.uncompressed)
145			}
146		})
147	}
148}
149
150var (
151	bigDataOnce  sync.Once
152	bigDataBytes []byte
153	bigDataErr   error
154)
155
156// bigData returns the contents of our large test file repeated multiple times.
157func bigData(t testing.TB) []byte {
158	bigDataOnce.Do(func() {
159		bigDataBytes, bigDataErr = os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
160		if bigDataErr == nil {
161			bigDataBytes = bytes.Repeat(bigDataBytes, 20)
162		}
163	})
164	if bigDataErr != nil {
165		t.Fatal(bigDataErr)
166	}
167	return bigDataBytes
168}
169
170func findZstd(t testing.TB) string {
171	zstd, err := exec.LookPath("zstd")
172	if err != nil {
173		t.Skip("skipping because zstd not found")
174	}
175	return zstd
176}
177
178var (
179	zstdBigOnce  sync.Once
180	zstdBigBytes []byte
181	zstdBigErr   error
182)
183
184// zstdBigData returns the compressed contents of our large test file.
185// This will only run on Unix systems with zstd installed.
186// That's OK as the package is GOOS-independent.
187func zstdBigData(t testing.TB) []byte {
188	input := bigData(t)
189
190	zstd := findZstd(t)
191
192	zstdBigOnce.Do(func() {
193		cmd := exec.Command(zstd, "-z")
194		cmd.Stdin = bytes.NewReader(input)
195		var compressed bytes.Buffer
196		cmd.Stdout = &compressed
197		cmd.Stderr = os.Stderr
198		if err := cmd.Run(); err != nil {
199			zstdBigErr = fmt.Errorf("running zstd failed: %v", err)
200			return
201		}
202
203		zstdBigBytes = compressed.Bytes()
204	})
205	if zstdBigErr != nil {
206		t.Fatal(zstdBigErr)
207	}
208	return zstdBigBytes
209}
210
211// Test decompressing a large file. We don't have a compressor,
212// so this test only runs on systems with zstd installed.
213func TestLarge(t *testing.T) {
214	if testing.Short() {
215		t.Skip("skipping expensive test in short mode")
216	}
217
218	data := bigData(t)
219	compressed := zstdBigData(t)
220
221	t.Logf("zstd compressed %d bytes to %d", len(data), len(compressed))
222
223	r := NewReader(bytes.NewReader(compressed))
224	got, err := io.ReadAll(r)
225	if err != nil {
226		t.Fatal(err)
227	}
228
229	if !bytes.Equal(got, data) {
230		showDiffs(t, got, data)
231	}
232}
233
234// showDiffs reports the first few differences in two []byte.
235func showDiffs(t *testing.T, got, want []byte) {
236	t.Error("data mismatch")
237	if len(got) != len(want) {
238		t.Errorf("got data length %d, want %d", len(got), len(want))
239	}
240	diffs := 0
241	for i, b := range got {
242		if i >= len(want) {
243			break
244		}
245		if b != want[i] {
246			diffs++
247			if diffs > 20 {
248				break
249			}
250			t.Logf("%d: %#x != %#x", i, b, want[i])
251		}
252	}
253}
254
255func TestAlloc(t *testing.T) {
256	testenv.SkipIfOptimizationOff(t)
257	if race.Enabled {
258		t.Skip("skipping allocation test under race detector")
259	}
260
261	compressed := zstdBigData(t)
262	input := bytes.NewReader(compressed)
263	r := NewReader(input)
264	c := testing.AllocsPerRun(10, func() {
265		input.Reset(compressed)
266		r.Reset(input)
267		io.Copy(io.Discard, r)
268	})
269	if c != 0 {
270		t.Errorf("got %v allocs, want 0", c)
271	}
272}
273
274func TestFileSamples(t *testing.T) {
275	samples, err := os.ReadDir("testdata")
276	if err != nil {
277		t.Fatal(err)
278	}
279
280	for _, sample := range samples {
281		name := sample.Name()
282		if !strings.HasSuffix(name, ".zst") {
283			continue
284		}
285
286		t.Run(name, func(t *testing.T) {
287			f, err := os.Open(filepath.Join("testdata", name))
288			if err != nil {
289				t.Fatal(err)
290			}
291
292			r := NewReader(f)
293			h := sha256.New()
294			if _, err := io.Copy(h, r); err != nil {
295				t.Fatal(err)
296			}
297			got := fmt.Sprintf("%x", h.Sum(nil))[:8]
298
299			want, _, _ := strings.Cut(name, ".")
300			if got != want {
301				t.Errorf("Wrong uncompressed content hash: got %s, want %s", got, want)
302			}
303		})
304	}
305}
306
307func TestReaderBad(t *testing.T) {
308	for i, s := range badStrings {
309		t.Run(fmt.Sprintf("badStrings#%d", i), func(t *testing.T) {
310			_, err := io.Copy(io.Discard, NewReader(strings.NewReader(s)))
311			if err == nil {
312				t.Error("expected error")
313			}
314		})
315	}
316}
317
318func BenchmarkLarge(b *testing.B) {
319	b.StopTimer()
320	b.ReportAllocs()
321
322	compressed := zstdBigData(b)
323
324	b.SetBytes(int64(len(compressed)))
325
326	input := bytes.NewReader(compressed)
327	r := NewReader(input)
328
329	b.StartTimer()
330	for i := 0; i < b.N; i++ {
331		input.Reset(compressed)
332		r.Reset(input)
333		io.Copy(io.Discard, r)
334	}
335}
336