1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package multipart
6
7import (
8	"bytes"
9	"crypto/rand"
10	"errors"
11	"fmt"
12	"io"
13	"net/textproto"
14	"slices"
15	"strings"
16)
17
18// A Writer generates multipart messages.
19type Writer struct {
20	w        io.Writer
21	boundary string
22	lastpart *part
23}
24
25// NewWriter returns a new multipart [Writer] with a random boundary,
26// writing to w.
27func NewWriter(w io.Writer) *Writer {
28	return &Writer{
29		w:        w,
30		boundary: randomBoundary(),
31	}
32}
33
34// Boundary returns the [Writer]'s boundary.
35func (w *Writer) Boundary() string {
36	return w.boundary
37}
38
39// SetBoundary overrides the [Writer]'s default randomly-generated
40// boundary separator with an explicit value.
41//
42// SetBoundary must be called before any parts are created, may only
43// contain certain ASCII characters, and must be non-empty and
44// at most 70 bytes long.
45func (w *Writer) SetBoundary(boundary string) error {
46	if w.lastpart != nil {
47		return errors.New("mime: SetBoundary called after write")
48	}
49	// rfc2046#section-5.1.1
50	if len(boundary) < 1 || len(boundary) > 70 {
51		return errors.New("mime: invalid boundary length")
52	}
53	end := len(boundary) - 1
54	for i, b := range boundary {
55		if 'A' <= b && b <= 'Z' || 'a' <= b && b <= 'z' || '0' <= b && b <= '9' {
56			continue
57		}
58		switch b {
59		case '\'', '(', ')', '+', '_', ',', '-', '.', '/', ':', '=', '?':
60			continue
61		case ' ':
62			if i != end {
63				continue
64			}
65		}
66		return errors.New("mime: invalid boundary character")
67	}
68	w.boundary = boundary
69	return nil
70}
71
72// FormDataContentType returns the Content-Type for an HTTP
73// multipart/form-data with this [Writer]'s Boundary.
74func (w *Writer) FormDataContentType() string {
75	b := w.boundary
76	// We must quote the boundary if it contains any of the
77	// tspecials characters defined by RFC 2045, or space.
78	if strings.ContainsAny(b, `()<>@,;:\"/[]?= `) {
79		b = `"` + b + `"`
80	}
81	return "multipart/form-data; boundary=" + b
82}
83
// randomBoundary returns a fresh boundary string: 30 bytes read from
// crypto/rand, rendered as 60 lowercase hexadecimal characters.
// It panics if the random source cannot supply the bytes.
func randomBoundary() string {
	raw := make([]byte, 30)
	if _, err := io.ReadFull(rand.Reader, raw); err != nil {
		panic(err)
	}
	return fmt.Sprintf("%x", raw)
}
92
93// CreatePart creates a new multipart section with the provided
94// header. The body of the part should be written to the returned
95// [Writer]. After calling CreatePart, any previous part may no longer
96// be written to.
97func (w *Writer) CreatePart(header textproto.MIMEHeader) (io.Writer, error) {
98	if w.lastpart != nil {
99		if err := w.lastpart.close(); err != nil {
100			return nil, err
101		}
102	}
103	var b bytes.Buffer
104	if w.lastpart != nil {
105		fmt.Fprintf(&b, "\r\n--%s\r\n", w.boundary)
106	} else {
107		fmt.Fprintf(&b, "--%s\r\n", w.boundary)
108	}
109
110	keys := make([]string, 0, len(header))
111	for k := range header {
112		keys = append(keys, k)
113	}
114	slices.Sort(keys)
115	for _, k := range keys {
116		for _, v := range header[k] {
117			fmt.Fprintf(&b, "%s: %s\r\n", k, v)
118		}
119	}
120	fmt.Fprintf(&b, "\r\n")
121	_, err := io.Copy(w.w, &b)
122	if err != nil {
123		return nil, err
124	}
125	p := &part{
126		mw: w,
127	}
128	w.lastpart = p
129	return p, nil
130}
131
// quoteEscaper rewrites the characters that must not appear verbatim
// inside a double-quoted Content-Disposition parameter value.
var quoteEscaper = strings.NewReplacer(
	"\\", "\\\\",
	`"`, "\\\"",
	"\r", "%0D",
	"\n", "%0A",
)

// escapeQuotes returns s made safe for embedding in a double-quoted
// header parameter value such as name="..." or filename="...".
//
// Backslash and double quote are backslash-escaped. CR and LF are
// percent-encoded so that a caller-supplied field name or file name
// cannot terminate the header line and inject additional headers or
// body content into the part.
func escapeQuotes(s string) string {
	return quoteEscaper.Replace(s)
}
137
138// CreateFormFile is a convenience wrapper around [Writer.CreatePart]. It creates
139// a new form-data header with the provided field name and file name.
140func (w *Writer) CreateFormFile(fieldname, filename string) (io.Writer, error) {
141	h := make(textproto.MIMEHeader)
142	h.Set("Content-Disposition",
143		fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
144			escapeQuotes(fieldname), escapeQuotes(filename)))
145	h.Set("Content-Type", "application/octet-stream")
146	return w.CreatePart(h)
147}
148
149// CreateFormField calls [Writer.CreatePart] with a header using the
150// given field name.
151func (w *Writer) CreateFormField(fieldname string) (io.Writer, error) {
152	h := make(textproto.MIMEHeader)
153	h.Set("Content-Disposition",
154		fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(fieldname)))
155	return w.CreatePart(h)
156}
157
158// WriteField calls [Writer.CreateFormField] and then writes the given value.
159func (w *Writer) WriteField(fieldname, value string) error {
160	p, err := w.CreateFormField(fieldname)
161	if err != nil {
162		return err
163	}
164	_, err = p.Write([]byte(value))
165	return err
166}
167
168// Close finishes the multipart message and writes the trailing
169// boundary end line to the output.
170func (w *Writer) Close() error {
171	if w.lastpart != nil {
172		if err := w.lastpart.close(); err != nil {
173			return err
174		}
175		w.lastpart = nil
176	}
177	_, err := fmt.Fprintf(w.w, "\r\n--%s--\r\n", w.boundary)
178	return err
179}
180
181type part struct {
182	mw     *Writer
183	closed bool
184	we     error // last error that occurred writing
185}
186
187func (p *part) close() error {
188	p.closed = true
189	return p.we
190}
191
192func (p *part) Write(d []byte) (n int, err error) {
193	if p.closed {
194		return 0, errors.New("multipart: can't write to finished part")
195	}
196	n, err = p.mw.w.Write(d)
197	if err != nil {
198		p.we = err
199	}
200	return
201}
202