1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package base32
6
7import (
8	"bytes"
9	"errors"
10	"io"
11	"math"
12	"strconv"
13	"strings"
14	"testing"
15)
16
// testpair couples a plaintext value with its expected base32 text.
type testpair struct {
	decoded string // raw input bytes, as a string
	encoded string // expected StdEncoding output
}
20
21var pairs = []testpair{
22	// RFC 4648 examples
23	{"", ""},
24	{"f", "MY======"},
25	{"fo", "MZXQ===="},
26	{"foo", "MZXW6==="},
27	{"foob", "MZXW6YQ="},
28	{"fooba", "MZXW6YTB"},
29	{"foobar", "MZXW6YTBOI======"},
30
31	// Wikipedia examples, converted to base32
32	{"sure.", "ON2XEZJO"},
33	{"sure", "ON2XEZI="},
34	{"sur", "ON2XE==="},
35	{"su", "ON2Q===="},
36	{"leasure.", "NRSWC43VOJSS4==="},
37	{"easure.", "MVQXG5LSMUXA===="},
38	{"asure.", "MFZXK4TFFY======"},
39	{"sure.", "ON2XEZJO"},
40}
41
42var bigtest = testpair{
43	"Twas brillig, and the slithy toves",
44	"KR3WC4ZAMJZGS3DMNFTSYIDBNZSCA5DIMUQHG3DJORUHSIDUN53GK4Y=",
45}
46
// testEqual compares the final two values of args (got, then want). When they
// differ it reports t.Errorf(msg, args...), so msg must have a verb for every
// element of args. It returns whether the two values were equal.
func testEqual(t *testing.T, msg string, args ...any) bool {
	t.Helper()
	got, want := args[len(args)-2], args[len(args)-1]
	if got == want {
		return true
	}
	t.Errorf(msg, args...)
	return false
}
55
56func TestEncode(t *testing.T) {
57	for _, p := range pairs {
58		got := StdEncoding.EncodeToString([]byte(p.decoded))
59		testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, p.encoded)
60		dst := StdEncoding.AppendEncode([]byte("lead"), []byte(p.decoded))
61		testEqual(t, `AppendEncode("lead", %q) = %q, want %q`, p.decoded, string(dst), "lead"+p.encoded)
62	}
63}
64
65func TestEncoder(t *testing.T) {
66	for _, p := range pairs {
67		bb := &strings.Builder{}
68		encoder := NewEncoder(StdEncoding, bb)
69		encoder.Write([]byte(p.decoded))
70		encoder.Close()
71		testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded)
72	}
73}
74
75func TestEncoderBuffering(t *testing.T) {
76	input := []byte(bigtest.decoded)
77	for bs := 1; bs <= 12; bs++ {
78		bb := &strings.Builder{}
79		encoder := NewEncoder(StdEncoding, bb)
80		for pos := 0; pos < len(input); pos += bs {
81			end := pos + bs
82			if end > len(input) {
83				end = len(input)
84			}
85			n, err := encoder.Write(input[pos:end])
86			testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil))
87			testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos)
88		}
89		err := encoder.Close()
90		testEqual(t, "Close gave error %v, want %v", err, error(nil))
91		testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded)
92	}
93}
94
95func TestDecoderBufferingWithPadding(t *testing.T) {
96	for bs := 0; bs <= 12; bs++ {
97		for _, s := range pairs {
98			decoder := NewDecoder(StdEncoding, strings.NewReader(s.encoded))
99			buf := make([]byte, len(s.decoded)+bs)
100
101			var n int
102			var err error
103			n, err = decoder.Read(buf)
104
105			if err != nil && err != io.EOF {
106				t.Errorf("Read from %q at pos %d = %d, unexpected error %v", s.encoded, len(s.decoded), n, err)
107			}
108			testEqual(t, "Decoding/%d of %q = %q, want %q\n", bs, s.encoded, string(buf[:n]), s.decoded)
109		}
110	}
111}
112
113func TestDecoderBufferingWithoutPadding(t *testing.T) {
114	for bs := 0; bs <= 12; bs++ {
115		for _, s := range pairs {
116			encoded := strings.TrimRight(s.encoded, "=")
117			decoder := NewDecoder(StdEncoding.WithPadding(NoPadding), strings.NewReader(encoded))
118			buf := make([]byte, len(s.decoded)+bs)
119
120			var n int
121			var err error
122			n, err = decoder.Read(buf)
123
124			if err != nil && err != io.EOF {
125				t.Errorf("Read from %q at pos %d = %d, unexpected error %v", encoded, len(s.decoded), n, err)
126			}
127			testEqual(t, "Decoding/%d of %q = %q, want %q\n", bs, encoded, string(buf[:n]), s.decoded)
128		}
129	}
130}
131
132func TestDecode(t *testing.T) {
133	for _, p := range pairs {
134		dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded)))
135		count, end, err := StdEncoding.decode(dbuf, []byte(p.encoded))
136		testEqual(t, "Decode(%q) = error %v, want %v", p.encoded, err, error(nil))
137		testEqual(t, "Decode(%q) = length %v, want %v", p.encoded, count, len(p.decoded))
138		if len(p.encoded) > 0 {
139			testEqual(t, "Decode(%q) = end %v, want %v", p.encoded, end, (p.encoded[len(p.encoded)-1] == '='))
140		}
141		testEqual(t, "Decode(%q) = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded)
142
143		dbuf, err = StdEncoding.DecodeString(p.encoded)
144		testEqual(t, "DecodeString(%q) = error %v, want %v", p.encoded, err, error(nil))
145		testEqual(t, "DecodeString(%q) = %q, want %q", p.encoded, string(dbuf), p.decoded)
146
147		dst, err := StdEncoding.AppendDecode([]byte("lead"), []byte(p.encoded))
148		testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
149		testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded)
150
151		dst2, err := StdEncoding.AppendDecode(dst[:0:len(p.decoded)], []byte(p.encoded))
152		testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
153		testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded)
154		if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] {
155			t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst))
156		}
157	}
158}
159
160func TestDecoder(t *testing.T) {
161	for _, p := range pairs {
162		decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded))
163		dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded)))
164		count, err := decoder.Read(dbuf)
165		if err != nil && err != io.EOF {
166			t.Fatal("Read failed", err)
167		}
168		testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded))
169		testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded)
170		if err != io.EOF {
171			_, err = decoder.Read(dbuf)
172		}
173		testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF)
174	}
175}
176
// badReader is an io.Reader test double. It serves data in bounded chunks and
// returns a scripted sequence of errors.
type badReader struct {
	data   []byte  // bytes not yet returned to the caller
	errs   []error // errs[i] is returned by the i-th call to Read
	called int     // number of Read calls made so far
	limit  int     // maximum bytes per Read; 0 means no extra limit
}

// Read copies the remaining data into p and returns the number of bytes
// copied. The count is the smallest of len(p), b.limit (when non-zero) and
// the bytes left in b.data. The i-th call returns b.errs[i] as its error.
// Once errs is exhausted, every call returns io.EOF. The error accompanies
// any data, so a call may return both n > 0 and a non-nil error.
func (b *badReader) Read(p []byte) (int, error) {
	lim := len(p)
	if b.limit != 0 && b.limit < lim {
		lim = b.limit
	}
	// copy also caps the count at len(b.data).
	n := copy(p[:lim], b.data)
	b.data = b.data[n:]

	err := io.EOF
	if b.called < len(b.errs) {
		err = b.errs[b.called]
	}
	b.called++
	return n, err
}
210
211// TestIssue20044 tests that decoder.Read behaves correctly when the caller
212// supplied reader returns an error.
213func TestIssue20044(t *testing.T) {
214	badErr := errors.New("bad reader error")
215	testCases := []struct {
216		r       badReader
217		res     string
218		err     error
219		dbuflen int
220	}{
221		// Check valid input data accompanied by an error is processed and the error is propagated.
222		{r: badReader{data: []byte("MY======"), errs: []error{badErr}},
223			res: "f", err: badErr},
224		// Check a read error accompanied by input data consisting of newlines only is propagated.
225		{r: badReader{data: []byte("\n\n\n\n\n\n\n\n"), errs: []error{badErr, nil}},
226			res: "", err: badErr},
227		// Reader will be called twice.  The first time it will return 8 newline characters.  The
228		// second time valid base32 encoded data and an error.  The data should be decoded
229		// correctly and the error should be propagated.
230		{r: badReader{data: []byte("\n\n\n\n\n\n\n\nMY======"), errs: []error{nil, badErr}},
231			res: "f", err: badErr, dbuflen: 8},
232		// Reader returns invalid input data (too short) and an error.  Verify the reader
233		// error is returned.
234		{r: badReader{data: []byte("MY====="), errs: []error{badErr}},
235			res: "", err: badErr},
236		// Reader returns invalid input data (too short) but no error.  Verify io.ErrUnexpectedEOF
237		// is returned.
238		{r: badReader{data: []byte("MY====="), errs: []error{nil}},
239			res: "", err: io.ErrUnexpectedEOF},
240		// Reader returns invalid input data and an error.  Verify the reader and not the
241		// decoder error is returned.
242		{r: badReader{data: []byte("Ma======"), errs: []error{badErr}},
243			res: "", err: badErr},
244		// Reader returns valid data and io.EOF.  Check data is decoded and io.EOF is propagated.
245		{r: badReader{data: []byte("MZXW6YTB"), errs: []error{io.EOF}},
246			res: "fooba", err: io.EOF},
247		// Check errors are properly reported when decoder.Read is called multiple times.
248		// decoder.Read will be called 8 times, badReader.Read will be called twice, returning
249		// valid data both times but an error on the second call.
250		{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}},
251			res: "leasure.", err: badErr, dbuflen: 1},
252		// Check io.EOF is properly reported when decoder.Read is called multiple times.
253		// decoder.Read will be called 8 times, badReader.Read will be called twice, returning
254		// valid data both times but io.EOF on the second call.
255		{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, io.EOF}},
256			res: "leasure.", err: io.EOF, dbuflen: 1},
257		// The following two test cases check that errors are propagated correctly when more than
258		// 8 bytes are read at a time.
259		{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{io.EOF}},
260			res: "leasure.", err: io.EOF, dbuflen: 11},
261		{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{badErr}},
262			res: "leasure.", err: badErr, dbuflen: 11},
263		// Check that errors are correctly propagated when the reader returns valid bytes in
264		// groups that are not divisible by 8.  The first read will return 11 bytes and no
265		// error.  The second will return 7 and an error.  The data should be decoded correctly
266		// and the error should be propagated.
267		{r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}, limit: 11},
268			res: "leasure.", err: badErr},
269	}
270
271	for _, tc := range testCases {
272		input := tc.r.data
273		decoder := NewDecoder(StdEncoding, &tc.r)
274		var dbuflen int
275		if tc.dbuflen > 0 {
276			dbuflen = tc.dbuflen
277		} else {
278			dbuflen = StdEncoding.DecodedLen(len(input))
279		}
280		dbuf := make([]byte, dbuflen)
281		var err error
282		var res []byte
283		for err == nil {
284			var n int
285			n, err = decoder.Read(dbuf)
286			if n > 0 {
287				res = append(res, dbuf[:n]...)
288			}
289		}
290
291		testEqual(t, "Decoding of %q = %q, want %q", string(input), string(res), tc.res)
292		testEqual(t, "Decoding of %q err = %v, expected %v", string(input), err, tc.err)
293	}
294}
295
296// TestDecoderError verifies decode errors are propagated when there are no read
297// errors.
298func TestDecoderError(t *testing.T) {
299	for _, readErr := range []error{io.EOF, nil} {
300		input := "MZXW6YTb"
301		dbuf := make([]byte, StdEncoding.DecodedLen(len(input)))
302		br := badReader{data: []byte(input), errs: []error{readErr}}
303		decoder := NewDecoder(StdEncoding, &br)
304		n, err := decoder.Read(dbuf)
305		testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
306		if _, ok := err.(CorruptInputError); !ok {
307			t.Errorf("Corrupt input error expected.  Found %T", err)
308		}
309	}
310}
311
312// TestReaderEOF ensures decoder.Read behaves correctly when input data is
313// exhausted.
314func TestReaderEOF(t *testing.T) {
315	for _, readErr := range []error{io.EOF, nil} {
316		input := "MZXW6YTB"
317		br := badReader{data: []byte(input), errs: []error{nil, readErr}}
318		decoder := NewDecoder(StdEncoding, &br)
319		dbuf := make([]byte, StdEncoding.DecodedLen(len(input)))
320		n, err := decoder.Read(dbuf)
321		testEqual(t, "Decoding of %q err = %v, expected %v", input, err, error(nil))
322		n, err = decoder.Read(dbuf)
323		testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
324		testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF)
325		n, err = decoder.Read(dbuf)
326		testEqual(t, "Read after EOF, n = %d, expected %d", n, 0)
327		testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF)
328	}
329}
330
331func TestDecoderBuffering(t *testing.T) {
332	for bs := 1; bs <= 12; bs++ {
333		decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded))
334		buf := make([]byte, len(bigtest.decoded)+12)
335		var total int
336		var n int
337		var err error
338		for total = 0; total < len(bigtest.decoded) && err == nil; {
339			n, err = decoder.Read(buf[total : total+bs])
340			total += n
341		}
342		if err != nil && err != io.EOF {
343			t.Errorf("Read from %q at pos %d = %d, unexpected error %v", bigtest.encoded, total, n, err)
344		}
345		testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded)
346	}
347}
348
349func TestDecodeCorrupt(t *testing.T) {
350	testCases := []struct {
351		input  string
352		offset int // -1 means no corruption.
353	}{
354		{"", -1},
355		{"!!!!", 0},
356		{"x===", 0},
357		{"AA=A====", 2},
358		{"AAA=AAAA", 3},
359		{"MMMMMMMMM", 8},
360		{"MMMMMM", 0},
361		{"A=", 1},
362		{"AA=", 3},
363		{"AA==", 4},
364		{"AA===", 5},
365		{"AAAA=", 5},
366		{"AAAA==", 6},
367		{"AAAAA=", 6},
368		{"AAAAA==", 7},
369		{"A=======", 1},
370		{"AA======", -1},
371		{"AAA=====", 3},
372		{"AAAA====", -1},
373		{"AAAAA===", -1},
374		{"AAAAAA==", 6},
375		{"AAAAAAA=", -1},
376		{"AAAAAAAA", -1},
377	}
378	for _, tc := range testCases {
379		dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input)))
380		_, err := StdEncoding.Decode(dbuf, []byte(tc.input))
381		if tc.offset == -1 {
382			if err != nil {
383				t.Error("Decoder wrongly detected corruption in", tc.input)
384			}
385			continue
386		}
387		switch err := err.(type) {
388		case CorruptInputError:
389			testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset)
390		default:
391			t.Error("Decoder failed to detect corruption in", tc)
392		}
393	}
394}
395
396func TestBig(t *testing.T) {
397	n := 3*1000 + 1
398	raw := make([]byte, n)
399	const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
400	for i := 0; i < n; i++ {
401		raw[i] = alpha[i%len(alpha)]
402	}
403	encoded := new(bytes.Buffer)
404	w := NewEncoder(StdEncoding, encoded)
405	nn, err := w.Write(raw)
406	if nn != n || err != nil {
407		t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n)
408	}
409	err = w.Close()
410	if err != nil {
411		t.Fatalf("Encoder.Close() = %v want nil", err)
412	}
413	decoded, err := io.ReadAll(NewDecoder(StdEncoding, encoded))
414	if err != nil {
415		t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err)
416	}
417
418	if !bytes.Equal(raw, decoded) {
419		var i int
420		for i = 0; i < len(decoded) && i < len(raw); i++ {
421			if decoded[i] != raw[i] {
422				break
423			}
424		}
425		t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i)
426	}
427}
428
429func testStringEncoding(t *testing.T, expected string, examples []string) {
430	for _, e := range examples {
431		buf, err := StdEncoding.DecodeString(e)
432		if err != nil {
433			t.Errorf("Decode(%q) failed: %v", e, err)
434			continue
435		}
436		if s := string(buf); s != expected {
437			t.Errorf("Decode(%q) = %q, want %q", e, s, expected)
438		}
439	}
440}
441
442func TestNewLineCharacters(t *testing.T) {
443	// Each of these should decode to the string "sure", without errors.
444	examples := []string{
445		"ON2XEZI=",
446		"ON2XEZI=\r",
447		"ON2XEZI=\n",
448		"ON2XEZI=\r\n",
449		"ON2XEZ\r\nI=",
450		"ON2X\rEZ\nI=",
451		"ON2X\nEZ\rI=",
452		"ON2XEZ\nI=",
453		"ON2XEZI\n=",
454	}
455	testStringEncoding(t, "sure", examples)
456
457	// Each of these should decode to the string "foobar", without errors.
458	examples = []string{
459		"MZXW6YTBOI======",
460		"MZXW6YTBOI=\r\n=====",
461	}
462	testStringEncoding(t, "foobar", examples)
463}
464
465func TestDecoderIssue4779(t *testing.T) {
466	encoded := `JRXXEZLNEBUXA43VNUQGI33MN5ZCA43JOQQGC3LFOQWCAY3PNZZWKY3UMV2HK4
467RAMFSGS4DJONUWG2LOM4QGK3DJOQWCA43FMQQGI3YKMVUXK43NN5SCA5DFNVYG64RANFXGG2LENFSH
468K3TUEB2XIIDMMFRG64TFEBSXIIDEN5WG64TFEBWWCZ3OMEQGC3DJOF2WCLRAKV2CAZLONFWQUYLEEB
469WWS3TJNUQHMZLONFQW2LBAOF2WS4ZANZXXG5DSOVSCAZLYMVZGG2LUMF2GS33OEB2WY3DBNVRW6IDM
470MFRG64TJOMQG42LTNEQHK5AKMFWGS4LVNFYCAZLYEBSWCIDDN5WW233EN4QGG33OONSXC5LBOQXCAR
471DVNFZSAYLVORSSA2LSOVZGKIDEN5WG64RANFXAU4TFOBZGK2DFNZSGK4TJOQQGS3RAOZXWY5LQORQX
472IZJAOZSWY2LUEBSXG43FEBRWS3DMOVWSAZDPNRXXEZJAMV2SAZTVM5UWC5BANZ2WY3DBBJYGC4TJMF
4732HK4ROEBCXQY3FOB2GK5LSEBZWS3TUEBXWGY3BMVRWC5BAMN2XA2LEMF2GC5BANZXW4IDQOJXWSZDF
474NZ2CYIDTOVXHIIDJNYFGG5LMOBQSA4LVNEQG6ZTGNFRWSYJAMRSXGZLSOVXHIIDNN5WGY2LUEBQW42
475LNEBUWIIDFON2CA3DBMJXXE5LNFY==
476====`
477	encodedShort := strings.ReplaceAll(encoded, "\n", "")
478
479	dec := NewDecoder(StdEncoding, strings.NewReader(encoded))
480	res1, err := io.ReadAll(dec)
481	if err != nil {
482		t.Errorf("ReadAll failed: %v", err)
483	}
484
485	dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort))
486	var res2 []byte
487	res2, err = io.ReadAll(dec)
488	if err != nil {
489		t.Errorf("ReadAll failed: %v", err)
490	}
491
492	if !bytes.Equal(res1, res2) {
493		t.Error("Decoded results not equal")
494	}
495}
496
497func BenchmarkEncode(b *testing.B) {
498	data := make([]byte, 8192)
499	buf := make([]byte, StdEncoding.EncodedLen(len(data)))
500	b.SetBytes(int64(len(data)))
501	for i := 0; i < b.N; i++ {
502		StdEncoding.Encode(buf, data)
503	}
504}
505
506func BenchmarkEncodeToString(b *testing.B) {
507	data := make([]byte, 8192)
508	b.SetBytes(int64(len(data)))
509	for i := 0; i < b.N; i++ {
510		StdEncoding.EncodeToString(data)
511	}
512}
513
514func BenchmarkDecode(b *testing.B) {
515	data := make([]byte, StdEncoding.EncodedLen(8192))
516	StdEncoding.Encode(data, make([]byte, 8192))
517	buf := make([]byte, 8192)
518	b.SetBytes(int64(len(data)))
519	for i := 0; i < b.N; i++ {
520		StdEncoding.Decode(buf, data)
521	}
522}
523func BenchmarkDecodeString(b *testing.B) {
524	data := StdEncoding.EncodeToString(make([]byte, 8192))
525	b.SetBytes(int64(len(data)))
526	for i := 0; i < b.N; i++ {
527		StdEncoding.DecodeString(data)
528	}
529}
530
531func TestWithCustomPadding(t *testing.T) {
532	for _, testcase := range pairs {
533		defaultPadding := StdEncoding.EncodeToString([]byte(testcase.decoded))
534		customPadding := StdEncoding.WithPadding('@').EncodeToString([]byte(testcase.decoded))
535		expected := strings.ReplaceAll(defaultPadding, "=", "@")
536
537		if expected != customPadding {
538			t.Errorf("Expected custom %s, got %s", expected, customPadding)
539		}
540		if testcase.encoded != defaultPadding {
541			t.Errorf("Expected %s, got %s", testcase.encoded, defaultPadding)
542		}
543	}
544}
545
546func TestWithoutPadding(t *testing.T) {
547	for _, testcase := range pairs {
548		defaultPadding := StdEncoding.EncodeToString([]byte(testcase.decoded))
549		customPadding := StdEncoding.WithPadding(NoPadding).EncodeToString([]byte(testcase.decoded))
550		expected := strings.TrimRight(defaultPadding, "=")
551
552		if expected != customPadding {
553			t.Errorf("Expected custom %s, got %s", expected, customPadding)
554		}
555		if testcase.encoded != defaultPadding {
556			t.Errorf("Expected %s, got %s", testcase.encoded, defaultPadding)
557		}
558	}
559}
560
561func TestDecodeWithPadding(t *testing.T) {
562	encodings := []*Encoding{
563		StdEncoding,
564		StdEncoding.WithPadding('-'),
565		StdEncoding.WithPadding(NoPadding),
566	}
567
568	for i, enc := range encodings {
569		for _, pair := range pairs {
570
571			input := pair.decoded
572			encoded := enc.EncodeToString([]byte(input))
573
574			decoded, err := enc.DecodeString(encoded)
575			if err != nil {
576				t.Errorf("DecodeString Error for encoding %d (%q): %v", i, input, err)
577			}
578
579			if input != string(decoded) {
580				t.Errorf("Unexpected result for encoding %d: got %q; want %q", i, decoded, input)
581			}
582		}
583	}
584}
585
586func TestDecodeWithWrongPadding(t *testing.T) {
587	encoded := StdEncoding.EncodeToString([]byte("foobar"))
588
589	_, err := StdEncoding.WithPadding('-').DecodeString(encoded)
590	if err == nil {
591		t.Error("expected error")
592	}
593
594	_, err = StdEncoding.WithPadding(NoPadding).DecodeString(encoded)
595	if err == nil {
596		t.Error("expected error")
597	}
598}
599
600func TestBufferedDecodingSameError(t *testing.T) {
601	testcases := []struct {
602		prefix            string
603		chunkCombinations [][]string
604		expected          error
605	}{
606		// NBSWY3DPO5XXE3DE == helloworld
607		// Test with "ZZ" as extra input
608		{"helloworld", [][]string{
609			{"NBSW", "Y3DP", "O5XX", "E3DE", "ZZ"},
610			{"NBSWY3DPO5XXE3DE", "ZZ"},
611			{"NBSWY3DPO5XXE3DEZZ"},
612			{"NBS", "WY3", "DPO", "5XX", "E3D", "EZZ"},
613			{"NBSWY3DPO5XXE3", "DEZZ"},
614		}, io.ErrUnexpectedEOF},
615
616		// Test with "ZZY" as extra input
617		{"helloworld", [][]string{
618			{"NBSW", "Y3DP", "O5XX", "E3DE", "ZZY"},
619			{"NBSWY3DPO5XXE3DE", "ZZY"},
620			{"NBSWY3DPO5XXE3DEZZY"},
621			{"NBS", "WY3", "DPO", "5XX", "E3D", "EZZY"},
622			{"NBSWY3DPO5XXE3", "DEZZY"},
623		}, io.ErrUnexpectedEOF},
624
625		// Normal case, this is valid input
626		{"helloworld", [][]string{
627			{"NBSW", "Y3DP", "O5XX", "E3DE"},
628			{"NBSWY3DPO5XXE3DE"},
629			{"NBS", "WY3", "DPO", "5XX", "E3D", "E"},
630			{"NBSWY3DPO5XXE3", "DE"},
631		}, nil},
632
633		// MZXW6YTB = fooba
634		{"fooba", [][]string{
635			{"MZXW6YTBZZ"},
636			{"MZXW6YTBZ", "Z"},
637			{"MZXW6YTB", "ZZ"},
638			{"MZXW6YT", "BZZ"},
639			{"MZXW6Y", "TBZZ"},
640			{"MZXW6Y", "TB", "ZZ"},
641			{"MZXW6", "YTBZZ"},
642			{"MZXW6", "YTB", "ZZ"},
643			{"MZXW6", "YT", "BZZ"},
644		}, io.ErrUnexpectedEOF},
645
646		// Normal case, this is valid input
647		{"fooba", [][]string{
648			{"MZXW6YTB"},
649			{"MZXW6YT", "B"},
650			{"MZXW6Y", "TB"},
651			{"MZXW6", "YTB"},
652			{"MZXW6", "YT", "B"},
653			{"MZXW", "6YTB"},
654			{"MZXW", "6Y", "TB"},
655		}, nil},
656	}
657
658	for _, testcase := range testcases {
659		for _, chunks := range testcase.chunkCombinations {
660			pr, pw := io.Pipe()
661
662			// Write the encoded chunks into the pipe
663			go func() {
664				for _, chunk := range chunks {
665					pw.Write([]byte(chunk))
666				}
667				pw.Close()
668			}()
669
670			decoder := NewDecoder(StdEncoding, pr)
671			_, err := io.ReadAll(decoder)
672
673			if err != testcase.expected {
674				t.Errorf("Expected %v, got %v; case %s %+v", testcase.expected, err, testcase.prefix, chunks)
675			}
676		}
677	}
678}
679
680func TestBufferedDecodingPadding(t *testing.T) {
681	testcases := []struct {
682		chunks        []string
683		expectedError string
684	}{
685		{[]string{
686			"I4======",
687			"==",
688		}, "unexpected EOF"},
689
690		{[]string{
691			"I4======N4======",
692		}, "illegal base32 data at input byte 2"},
693
694		{[]string{
695			"I4======",
696			"N4======",
697		}, "illegal base32 data at input byte 0"},
698
699		{[]string{
700			"I4======",
701			"========",
702		}, "illegal base32 data at input byte 0"},
703
704		{[]string{
705			"I4I4I4I4",
706			"I4======",
707			"I4======",
708		}, "illegal base32 data at input byte 0"},
709	}
710
711	for _, testcase := range testcases {
712		testcase := testcase
713		pr, pw := io.Pipe()
714		go func() {
715			for _, chunk := range testcase.chunks {
716				_, _ = pw.Write([]byte(chunk))
717			}
718			_ = pw.Close()
719		}()
720
721		decoder := NewDecoder(StdEncoding, pr)
722		_, err := io.ReadAll(decoder)
723
724		if err == nil && len(testcase.expectedError) != 0 {
725			t.Errorf("case %q: got nil error, want %v", testcase.chunks, testcase.expectedError)
726		} else if err.Error() != testcase.expectedError {
727			t.Errorf("case %q: got %v, want %v", testcase.chunks, err, testcase.expectedError)
728		}
729	}
730}
731
732func TestEncodedLen(t *testing.T) {
733	var rawStdEncoding = StdEncoding.WithPadding(NoPadding)
734	type test struct {
735		enc  *Encoding
736		n    int
737		want int64
738	}
739	tests := []test{
740		{StdEncoding, 0, 0},
741		{StdEncoding, 1, 8},
742		{StdEncoding, 2, 8},
743		{StdEncoding, 3, 8},
744		{StdEncoding, 4, 8},
745		{StdEncoding, 5, 8},
746		{StdEncoding, 6, 16},
747		{StdEncoding, 10, 16},
748		{StdEncoding, 11, 24},
749		{rawStdEncoding, 0, 0},
750		{rawStdEncoding, 1, 2},
751		{rawStdEncoding, 2, 4},
752		{rawStdEncoding, 3, 5},
753		{rawStdEncoding, 4, 7},
754		{rawStdEncoding, 5, 8},
755		{rawStdEncoding, 6, 10},
756		{rawStdEncoding, 7, 12},
757		{rawStdEncoding, 10, 16},
758		{rawStdEncoding, 11, 18},
759	}
760	// check overflow
761	switch strconv.IntSize {
762	case 32:
763		tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 429496730})
764		tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt})
765	case 64:
766		tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 1844674407370955162})
767		tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt})
768	}
769	for _, tt := range tests {
770		if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want {
771			t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
772		}
773	}
774}
775
776func TestDecodedLen(t *testing.T) {
777	var rawStdEncoding = StdEncoding.WithPadding(NoPadding)
778	type test struct {
779		enc  *Encoding
780		n    int
781		want int64
782	}
783	tests := []test{
784		{StdEncoding, 0, 0},
785		{StdEncoding, 8, 5},
786		{StdEncoding, 16, 10},
787		{StdEncoding, 24, 15},
788		{rawStdEncoding, 0, 0},
789		{rawStdEncoding, 2, 1},
790		{rawStdEncoding, 4, 2},
791		{rawStdEncoding, 5, 3},
792		{rawStdEncoding, 7, 4},
793		{rawStdEncoding, 8, 5},
794		{rawStdEncoding, 10, 6},
795		{rawStdEncoding, 12, 7},
796		{rawStdEncoding, 16, 10},
797		{rawStdEncoding, 18, 11},
798	}
799	// check overflow
800	switch strconv.IntSize {
801	case 32:
802		tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 268435456})
803		tests = append(tests, test{rawStdEncoding, math.MaxInt, 1342177279})
804	case 64:
805		tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 1152921504606846976})
806		tests = append(tests, test{rawStdEncoding, math.MaxInt, 5764607523034234879})
807	}
808	for _, tt := range tests {
809		if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want {
810			t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
811		}
812	}
813}
814
815func TestWithoutPaddingClose(t *testing.T) {
816	encodings := []*Encoding{
817		StdEncoding,
818		StdEncoding.WithPadding(NoPadding),
819	}
820
821	for _, encoding := range encodings {
822		for _, testpair := range pairs {
823
824			var buf strings.Builder
825			encoder := NewEncoder(encoding, &buf)
826			encoder.Write([]byte(testpair.decoded))
827			encoder.Close()
828
829			expected := testpair.encoded
830			if encoding.padChar == NoPadding {
831				expected = strings.ReplaceAll(expected, "=", "")
832			}
833
834			res := buf.String()
835
836			if res != expected {
837				t.Errorf("Expected %s got %s; padChar=%d", expected, res, encoding.padChar)
838			}
839		}
840	}
841}
842
843func TestDecodeReadAll(t *testing.T) {
844	encodings := []*Encoding{
845		StdEncoding,
846		StdEncoding.WithPadding(NoPadding),
847	}
848
849	for _, pair := range pairs {
850		for encIndex, encoding := range encodings {
851			encoded := pair.encoded
852			if encoding.padChar == NoPadding {
853				encoded = strings.ReplaceAll(encoded, "=", "")
854			}
855
856			decReader, err := io.ReadAll(NewDecoder(encoding, strings.NewReader(encoded)))
857			if err != nil {
858				t.Errorf("NewDecoder error: %v", err)
859			}
860
861			if pair.decoded != string(decReader) {
862				t.Errorf("Expected %s got %s; Encoding %d", pair.decoded, decReader, encIndex)
863			}
864		}
865	}
866}
867
868func TestDecodeSmallBuffer(t *testing.T) {
869	encodings := []*Encoding{
870		StdEncoding,
871		StdEncoding.WithPadding(NoPadding),
872	}
873
874	for bufferSize := 1; bufferSize < 200; bufferSize++ {
875		for _, pair := range pairs {
876			for encIndex, encoding := range encodings {
877				encoded := pair.encoded
878				if encoding.padChar == NoPadding {
879					encoded = strings.ReplaceAll(encoded, "=", "")
880				}
881
882				decoder := NewDecoder(encoding, strings.NewReader(encoded))
883
884				var allRead []byte
885
886				for {
887					buf := make([]byte, bufferSize)
888					n, err := decoder.Read(buf)
889					allRead = append(allRead, buf[0:n]...)
890					if err == io.EOF {
891						break
892					}
893					if err != nil {
894						t.Error(err)
895					}
896				}
897
898				if pair.decoded != string(allRead) {
899					t.Errorf("Expected %s got %s; Encoding %d; bufferSize %d", pair.decoded, allRead, encIndex, bufferSize)
900				}
901			}
902		}
903	}
904}
905