1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package base64 implements base64 encoding as specified by RFC 4648.
6package base64
7
8import (
9	"encoding/binary"
10	"io"
11	"slices"
12	"strconv"
13)
14
15/*
16 * Encodings
17 */
18
// Encoding is a radix 64 encoding/decoding scheme. It is defined by a
// 64-character alphabet plus an optional padding character. The most
// common alphabet is the "base64" one from RFC 4648, which is also used
// in MIME (RFC 2045) and PEM (RFC 1421). RFC 4648 additionally defines an
// alternate alphabet: the standard one with - and _ substituted for + and /.
type Encoding struct {
	// encode maps a symbol index to the symbol's byte value.
	encode [64]byte
	// decodeMap maps a symbol's byte value back to its symbol index.
	decodeMap [256]uint8
	// padChar is the padding character; it is compared against NoPadding
	// to decide whether padding is emitted and expected.
	padChar rune
	// strict is set by Strict and enables the trailing-bits check
	// performed while decoding.
	strict bool
}
30
const (
	// StdPadding is the standard padding character.
	StdPadding rune = '='
	// NoPadding is the value that disables padding.
	NoPadding rune = -1
)
35
const (
	// decodeMapInitialize is a 256-byte string in which every byte is 0xff.
	// It is copied into a fresh decode map so that every byte value starts
	// out marked as "not in the alphabet". It is laid out as 8 rows of
	// 32 bytes (two 16-byte literals per row).
	decodeMapInitialize = "" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" + "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"

	// invalidIndex is the decode-map value that marks a byte as not being
	// part of the alphabet.
	invalidIndex = '\xff'
)
56
57// NewEncoding returns a new padded Encoding defined by the given alphabet,
58// which must be a 64-byte string that contains unique byte values and
59// does not contain the padding character or CR / LF ('\r', '\n').
60// The alphabet is treated as a sequence of byte values
61// without any special treatment for multi-byte UTF-8.
62// The resulting Encoding uses the default padding character ('='),
63// which may be changed or disabled via [Encoding.WithPadding].
64func NewEncoding(encoder string) *Encoding {
65	if len(encoder) != 64 {
66		panic("encoding alphabet is not 64-bytes long")
67	}
68
69	e := new(Encoding)
70	e.padChar = StdPadding
71	copy(e.encode[:], encoder)
72	copy(e.decodeMap[:], decodeMapInitialize)
73
74	for i := 0; i < len(encoder); i++ {
75		// Note: While we document that the alphabet cannot contain
76		// the padding character, we do not enforce it since we do not know
77		// if the caller intends to switch the padding from StdPadding later.
78		switch {
79		case encoder[i] == '\n' || encoder[i] == '\r':
80			panic("encoding alphabet contains newline character")
81		case e.decodeMap[encoder[i]] != invalidIndex:
82			panic("encoding alphabet includes duplicate symbols")
83		}
84		e.decodeMap[encoder[i]] = uint8(i)
85	}
86	return e
87}
88
89// WithPadding creates a new encoding identical to enc except
90// with a specified padding character, or [NoPadding] to disable padding.
91// The padding character must not be '\r' or '\n',
92// must not be contained in the encoding's alphabet,
93// must not be negative, and must be a rune equal or below '\xff'.
94// Padding characters above '\x7f' are encoded as their exact byte value
95// rather than using the UTF-8 representation of the codepoint.
96func (enc Encoding) WithPadding(padding rune) *Encoding {
97	switch {
98	case padding < NoPadding || padding == '\r' || padding == '\n' || padding > 0xff:
99		panic("invalid padding")
100	case padding != NoPadding && enc.decodeMap[byte(padding)] != invalidIndex:
101		panic("padding contained in alphabet")
102	}
103	enc.padChar = padding
104	return &enc
105}
106
107// Strict creates a new encoding identical to enc except with
108// strict decoding enabled. In this mode, the decoder requires that
109// trailing padding bits are zero, as described in RFC 4648 section 3.5.
110//
111// Note that the input is still malleable, as new line characters
112// (CR and LF) are still ignored.
113func (enc Encoding) Strict() *Encoding {
114	enc.strict = true
115	return &enc
116}
117
118// StdEncoding is the standard base64 encoding, as defined in RFC 4648.
119var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
120
121// URLEncoding is the alternate base64 encoding defined in RFC 4648.
122// It is typically used in URLs and file names.
123var URLEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
124
125// RawStdEncoding is the standard raw, unpadded base64 encoding,
126// as defined in RFC 4648 section 3.2.
127// This is the same as [StdEncoding] but omits padding characters.
128var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
129
130// RawURLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
131// It is typically used in URLs and file names.
132// This is the same as [URLEncoding] but omits padding characters.
133var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
134
135/*
136 * Encoder
137 */
138
139// Encode encodes src using the encoding enc,
140// writing [Encoding.EncodedLen](len(src)) bytes to dst.
141//
142// The encoding pads the output to a multiple of 4 bytes,
143// so Encode is not appropriate for use on individual blocks
144// of a large data stream. Use [NewEncoder] instead.
145func (enc *Encoding) Encode(dst, src []byte) {
146	if len(src) == 0 {
147		return
148	}
149	// enc is a pointer receiver, so the use of enc.encode within the hot
150	// loop below means a nil check at every operation. Lift that nil check
151	// outside of the loop to speed up the encoder.
152	_ = enc.encode
153
154	di, si := 0, 0
155	n := (len(src) / 3) * 3
156	for si < n {
157		// Convert 3x 8bit source bytes into 4 bytes
158		val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
159
160		dst[di+0] = enc.encode[val>>18&0x3F]
161		dst[di+1] = enc.encode[val>>12&0x3F]
162		dst[di+2] = enc.encode[val>>6&0x3F]
163		dst[di+3] = enc.encode[val&0x3F]
164
165		si += 3
166		di += 4
167	}
168
169	remain := len(src) - si
170	if remain == 0 {
171		return
172	}
173	// Add the remaining small block
174	val := uint(src[si+0]) << 16
175	if remain == 2 {
176		val |= uint(src[si+1]) << 8
177	}
178
179	dst[di+0] = enc.encode[val>>18&0x3F]
180	dst[di+1] = enc.encode[val>>12&0x3F]
181
182	switch remain {
183	case 2:
184		dst[di+2] = enc.encode[val>>6&0x3F]
185		if enc.padChar != NoPadding {
186			dst[di+3] = byte(enc.padChar)
187		}
188	case 1:
189		if enc.padChar != NoPadding {
190			dst[di+2] = byte(enc.padChar)
191			dst[di+3] = byte(enc.padChar)
192		}
193	}
194}
195
196// AppendEncode appends the base64 encoded src to dst
197// and returns the extended buffer.
198func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
199	n := enc.EncodedLen(len(src))
200	dst = slices.Grow(dst, n)
201	enc.Encode(dst[len(dst):][:n], src)
202	return dst[:len(dst)+n]
203}
204
205// EncodeToString returns the base64 encoding of src.
206func (enc *Encoding) EncodeToString(src []byte) string {
207	buf := make([]byte, enc.EncodedLen(len(src)))
208	enc.Encode(buf, src)
209	return string(buf)
210}
211
212type encoder struct {
213	err  error
214	enc  *Encoding
215	w    io.Writer
216	buf  [3]byte    // buffered data waiting to be encoded
217	nbuf int        // number of bytes in buf
218	out  [1024]byte // output buffer
219}
220
221func (e *encoder) Write(p []byte) (n int, err error) {
222	if e.err != nil {
223		return 0, e.err
224	}
225
226	// Leading fringe.
227	if e.nbuf > 0 {
228		var i int
229		for i = 0; i < len(p) && e.nbuf < 3; i++ {
230			e.buf[e.nbuf] = p[i]
231			e.nbuf++
232		}
233		n += i
234		p = p[i:]
235		if e.nbuf < 3 {
236			return
237		}
238		e.enc.Encode(e.out[:], e.buf[:])
239		if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
240			return n, e.err
241		}
242		e.nbuf = 0
243	}
244
245	// Large interior chunks.
246	for len(p) >= 3 {
247		nn := len(e.out) / 4 * 3
248		if nn > len(p) {
249			nn = len(p)
250			nn -= nn % 3
251		}
252		e.enc.Encode(e.out[:], p[:nn])
253		if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
254			return n, e.err
255		}
256		n += nn
257		p = p[nn:]
258	}
259
260	// Trailing fringe.
261	copy(e.buf[:], p)
262	e.nbuf = len(p)
263	n += len(p)
264	return
265}
266
267// Close flushes any pending output from the encoder.
268// It is an error to call Write after calling Close.
269func (e *encoder) Close() error {
270	// If there's anything left in the buffer, flush it out
271	if e.err == nil && e.nbuf > 0 {
272		e.enc.Encode(e.out[:], e.buf[:e.nbuf])
273		_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
274		e.nbuf = 0
275	}
276	return e.err
277}
278
279// NewEncoder returns a new base64 stream encoder. Data written to
280// the returned writer will be encoded using enc and then written to w.
281// Base64 encodings operate in 4-byte blocks; when finished
282// writing, the caller must Close the returned encoder to flush any
283// partially written blocks.
284func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
285	return &encoder{enc: enc, w: w}
286}
287
288// EncodedLen returns the length in bytes of the base64 encoding
289// of an input buffer of length n.
290func (enc *Encoding) EncodedLen(n int) int {
291	if enc.padChar == NoPadding {
292		return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char
293	}
294	return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
295}
296
297/*
298 * Decoder
299 */
300
// CorruptInputError is the error type the decoder returns for illegal
// base64 data. Its integer value is the input byte offset that is
// reported in the error message.
type CorruptInputError int64

// Error implements the error interface, naming the offending input byte.
func (e CorruptInputError) Error() string {
	const prefix = "illegal base64 data at input byte "
	return prefix + strconv.FormatInt(int64(e), 10)
}
306
307// decodeQuantum decodes up to 4 base64 bytes. The received parameters are
308// the destination buffer dst, the source buffer src and an index in the
309// source buffer si.
310// It returns the number of bytes read from src, the number of bytes written
311// to dst, and an error, if any.
312func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
313	// Decode quantum using the base64 alphabet
314	var dbuf [4]byte
315	dlen := 4
316
317	// Lift the nil check outside of the loop.
318	_ = enc.decodeMap
319
320	for j := 0; j < len(dbuf); j++ {
321		if len(src) == si {
322			switch {
323			case j == 0:
324				return si, 0, nil
325			case j == 1, enc.padChar != NoPadding:
326				return si, 0, CorruptInputError(si - j)
327			}
328			dlen = j
329			break
330		}
331		in := src[si]
332		si++
333
334		out := enc.decodeMap[in]
335		if out != 0xff {
336			dbuf[j] = out
337			continue
338		}
339
340		if in == '\n' || in == '\r' {
341			j--
342			continue
343		}
344
345		if rune(in) != enc.padChar {
346			return si, 0, CorruptInputError(si - 1)
347		}
348
349		// We've reached the end and there's padding
350		switch j {
351		case 0, 1:
352			// incorrect padding
353			return si, 0, CorruptInputError(si - 1)
354		case 2:
355			// "==" is expected, the first "=" is already consumed.
356			// skip over newlines
357			for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
358				si++
359			}
360			if si == len(src) {
361				// not enough padding
362				return si, 0, CorruptInputError(len(src))
363			}
364			if rune(src[si]) != enc.padChar {
365				// incorrect padding
366				return si, 0, CorruptInputError(si - 1)
367			}
368
369			si++
370		}
371
372		// skip over newlines
373		for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
374			si++
375		}
376		if si < len(src) {
377			// trailing garbage
378			err = CorruptInputError(si)
379		}
380		dlen = j
381		break
382	}
383
384	// Convert 4x 6bit source bytes into 3 bytes
385	val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
386	dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
387	switch dlen {
388	case 4:
389		dst[2] = dbuf[2]
390		dbuf[2] = 0
391		fallthrough
392	case 3:
393		dst[1] = dbuf[1]
394		if enc.strict && dbuf[2] != 0 {
395			return si, 0, CorruptInputError(si - 1)
396		}
397		dbuf[1] = 0
398		fallthrough
399	case 2:
400		dst[0] = dbuf[0]
401		if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
402			return si, 0, CorruptInputError(si - 2)
403		}
404	}
405
406	return si, dlen - 1, err
407}
408
409// AppendDecode appends the base64 decoded src to dst
410// and returns the extended buffer.
411// If the input is malformed, it returns the partially decoded src and an error.
412func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
413	// Compute the output size without padding to avoid over allocating.
414	n := len(src)
415	for n > 0 && rune(src[n-1]) == enc.padChar {
416		n--
417	}
418	n = decodedLen(n, NoPadding)
419
420	dst = slices.Grow(dst, n)
421	n, err := enc.Decode(dst[len(dst):][:n], src)
422	return dst[:len(dst)+n], err
423}
424
425// DecodeString returns the bytes represented by the base64 string s.
426func (enc *Encoding) DecodeString(s string) ([]byte, error) {
427	dbuf := make([]byte, enc.DecodedLen(len(s)))
428	n, err := enc.Decode(dbuf, []byte(s))
429	return dbuf[:n], err
430}
431
432type decoder struct {
433	err     error
434	readErr error // error from r.Read
435	enc     *Encoding
436	r       io.Reader
437	buf     [1024]byte // leftover input
438	nbuf    int
439	out     []byte // leftover decoded output
440	outbuf  [1024 / 4 * 3]byte
441}
442
443func (d *decoder) Read(p []byte) (n int, err error) {
444	// Use leftover decoded output from last read.
445	if len(d.out) > 0 {
446		n = copy(p, d.out)
447		d.out = d.out[n:]
448		return n, nil
449	}
450
451	if d.err != nil {
452		return 0, d.err
453	}
454
455	// This code assumes that d.r strips supported whitespace ('\r' and '\n').
456
457	// Refill buffer.
458	for d.nbuf < 4 && d.readErr == nil {
459		nn := len(p) / 3 * 4
460		if nn < 4 {
461			nn = 4
462		}
463		if nn > len(d.buf) {
464			nn = len(d.buf)
465		}
466		nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
467		d.nbuf += nn
468	}
469
470	if d.nbuf < 4 {
471		if d.enc.padChar == NoPadding && d.nbuf > 0 {
472			// Decode final fragment, without padding.
473			var nw int
474			nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
475			d.nbuf = 0
476			d.out = d.outbuf[:nw]
477			n = copy(p, d.out)
478			d.out = d.out[n:]
479			if n > 0 || len(p) == 0 && len(d.out) > 0 {
480				return n, nil
481			}
482			if d.err != nil {
483				return 0, d.err
484			}
485		}
486		d.err = d.readErr
487		if d.err == io.EOF && d.nbuf > 0 {
488			d.err = io.ErrUnexpectedEOF
489		}
490		return 0, d.err
491	}
492
493	// Decode chunk into p, or d.out and then p if p is too small.
494	nr := d.nbuf / 4 * 4
495	nw := d.nbuf / 4 * 3
496	if nw > len(p) {
497		nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
498		d.out = d.outbuf[:nw]
499		n = copy(p, d.out)
500		d.out = d.out[n:]
501	} else {
502		n, d.err = d.enc.Decode(p, d.buf[:nr])
503	}
504	d.nbuf -= nr
505	copy(d.buf[:d.nbuf], d.buf[nr:])
506	return n, d.err
507}
508
509// Decode decodes src using the encoding enc. It writes at most
510// [Encoding.DecodedLen](len(src)) bytes to dst and returns the number of bytes
511// written. If src contains invalid base64 data, it will return the
512// number of bytes successfully written and [CorruptInputError].
513// New line characters (\r and \n) are ignored.
514func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
515	if len(src) == 0 {
516		return 0, nil
517	}
518
519	// Lift the nil check outside of the loop. enc.decodeMap is directly
520	// used later in this function, to let the compiler know that the
521	// receiver can't be nil.
522	_ = enc.decodeMap
523
524	si := 0
525	for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
526		src2 := src[si : si+8]
527		if dn, ok := assemble64(
528			enc.decodeMap[src2[0]],
529			enc.decodeMap[src2[1]],
530			enc.decodeMap[src2[2]],
531			enc.decodeMap[src2[3]],
532			enc.decodeMap[src2[4]],
533			enc.decodeMap[src2[5]],
534			enc.decodeMap[src2[6]],
535			enc.decodeMap[src2[7]],
536		); ok {
537			binary.BigEndian.PutUint64(dst[n:], dn)
538			n += 6
539			si += 8
540		} else {
541			var ninc int
542			si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
543			n += ninc
544			if err != nil {
545				return n, err
546			}
547		}
548	}
549
550	for len(src)-si >= 4 && len(dst)-n >= 4 {
551		src2 := src[si : si+4]
552		if dn, ok := assemble32(
553			enc.decodeMap[src2[0]],
554			enc.decodeMap[src2[1]],
555			enc.decodeMap[src2[2]],
556			enc.decodeMap[src2[3]],
557		); ok {
558			binary.BigEndian.PutUint32(dst[n:], dn)
559			n += 3
560			si += 4
561		} else {
562			var ninc int
563			si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
564			n += ninc
565			if err != nil {
566				return n, err
567			}
568		}
569	}
570
571	for si < len(src) {
572		var ninc int
573		si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
574		n += ninc
575		if err != nil {
576			return n, err
577		}
578	}
579	return n, err
580}
581
// assemble32 packs 4 base64 digits into the top 3 bytes of a uint32.
// Each digit comes from the decode map and is 0xff if it came from an
// invalid character. ok is false when the digits are rejected.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
	// If any digit was 0xff, the bitwise OR of all of them is 0xff,
	// so a single comparison checks all four.
	if n1|n2|n3|n4 == 0xff {
		return 0, false
	}
	// Each digit contributes 6 bits, most significant first; the low
	// 8 bits of the result stay zero.
	dn = uint32(n1) << 26
	dn |= uint32(n2) << 20
	dn |= uint32(n3) << 14
	dn |= uint32(n4) << 8
	return dn, true
}
597
// assemble64 packs 8 base64 digits into the top 6 bytes of a uint64.
// Each digit comes from the decode map and is 0xff if it came from an
// invalid character. ok is false when the digits are rejected.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
	// If any digit was 0xff, the bitwise OR of all of them is 0xff,
	// so a single comparison checks all eight.
	if n1|n2|n3|n4|n5|n6|n7|n8 == 0xff {
		return 0, false
	}
	// Each digit contributes 6 bits, most significant first; the low
	// 16 bits of the result stay zero.
	dn = uint64(n1) << 58
	dn |= uint64(n2) << 52
	dn |= uint64(n3) << 46
	dn |= uint64(n4) << 40
	dn |= uint64(n5) << 34
	dn |= uint64(n6) << 28
	dn |= uint64(n7) << 22
	dn |= uint64(n8) << 16
	return dn, true
}
617
// newlineFilteringReader wraps an io.Reader and removes '\r' and '\n'
// bytes from the data it returns.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the bytes that
// are not '\r' or '\n' to the front of p. If a read yields only newline
// characters it reads again, so a positive-length read from the wrapped
// reader never turns into a zero-length result.
// It returns the number of bytes kept, with the error from the last
// underlying read.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	for {
		n, err := r.wrapped.Read(p)
		if n <= 0 {
			return n, err
		}

		kept := 0
		for i, c := range p[:n] {
			if c == '\r' || c == '\n' {
				continue
			}
			if i != kept {
				// Close the gap left by newlines dropped earlier.
				p[kept] = c
			}
			kept++
		}
		if kept > 0 {
			return kept, err
		}
		// Previous buffer was entirely whitespace; read again.
	}
}
642
643// NewDecoder constructs a new base64 stream decoder.
644func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
645	return &decoder{enc: enc, r: &newlineFilteringReader{r}}
646}
647
648// DecodedLen returns the maximum length in bytes of the decoded data
649// corresponding to n bytes of base64-encoded data.
650func (enc *Encoding) DecodedLen(n int) int {
651	return decodedLen(n, enc.padChar)
652}
653
654func decodedLen(n int, padChar rune) int {
655	if padChar == NoPadding {
656		// Unpadded data may end with partial block of 2-3 characters.
657		return n/4*3 + n%4*6/8
658	}
659	// Padded base64 should always be a multiple of 4 characters in length.
660	return n / 4 * 3
661}
662