1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package httptest
6
7import (
8	"bytes"
9	"fmt"
10	"io"
11	"net/http"
12	"net/textproto"
13	"strconv"
14	"strings"
15
16	"golang.org/x/net/http/httpguts"
17)
18
// ResponseRecorder is an implementation of [http.ResponseWriter] that
// records its mutations for later inspection in tests.
//
// Handlers write to it through Header, WriteHeader, Write, WriteString
// and Flush; tests then inspect the exported fields or, preferably, the
// [http.Response] built by the Result method.
type ResponseRecorder struct {
	// Code is the HTTP response code set by WriteHeader.
	//
	// NewRecorder initializes Code to 200. On a ResponseRecorder that
	// was not created by NewRecorder, Code stays 0 until a status is
	// recorded (explicitly by WriteHeader, or implicitly as 200 by
	// Write, WriteString or Flush). To always get the implicit
	// http.StatusOK in that case, use the Result method.
	Code int

	// HeaderMap contains the headers explicitly set by the Handler.
	// It is an internal detail.
	//
	// Deprecated: HeaderMap exists for historical compatibility
	// and should not be used. To access the headers returned by a handler,
	// use the Response.Header map as returned by the Result method.
	HeaderMap http.Header

	// Body is the buffer to which the Handler's Write calls are sent.
	// If nil, the Writes are silently discarded.
	Body *bytes.Buffer

	// Flushed is whether the Handler called Flush.
	Flushed bool

	result      *http.Response // cache of Result's return value
	snapHeader  http.Header    // snapshot of HeaderMap taken by the first WriteHeader, or by Result if none happened
	wroteHeader bool           // set by the first WriteHeader; later WriteHeader calls are ignored
}
49
50// NewRecorder returns an initialized [ResponseRecorder].
51func NewRecorder() *ResponseRecorder {
52	return &ResponseRecorder{
53		HeaderMap: make(http.Header),
54		Body:      new(bytes.Buffer),
55		Code:      200,
56	}
57}
58
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
// an explicit DefaultRemoteAddr isn't set on [ResponseRecorder].
//
// NOTE(review): the visible part of this file never reads this constant, and
// ResponseRecorder as declared here has no RemoteAddr or DefaultRemoteAddr
// field; confirm where (or whether) this value is still consumed.
const DefaultRemoteAddr = "1.2.3.4"
62
63// Header implements [http.ResponseWriter]. It returns the response
64// headers to mutate within a handler. To test the headers that were
65// written after a handler completes, use the [ResponseRecorder.Result] method and see
66// the returned Response value's Header.
67func (rw *ResponseRecorder) Header() http.Header {
68	m := rw.HeaderMap
69	if m == nil {
70		m = make(http.Header)
71		rw.HeaderMap = m
72	}
73	return m
74}
75
76// writeHeader writes a header if it was not written yet and
77// detects Content-Type if needed.
78//
79// bytes or str are the beginning of the response body.
80// We pass both to avoid unnecessarily generate garbage
81// in rw.WriteString which was created for performance reasons.
82// Non-nil bytes win.
83func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
84	if rw.wroteHeader {
85		return
86	}
87	if len(str) > 512 {
88		str = str[:512]
89	}
90
91	m := rw.Header()
92
93	_, hasType := m["Content-Type"]
94	hasTE := m.Get("Transfer-Encoding") != ""
95	if !hasType && !hasTE {
96		if b == nil {
97			b = []byte(str)
98		}
99		m.Set("Content-Type", http.DetectContentType(b))
100	}
101
102	rw.WriteHeader(200)
103}
104
105// Write implements http.ResponseWriter. The data in buf is written to
106// rw.Body, if not nil.
107func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
108	rw.writeHeader(buf, "")
109	if rw.Body != nil {
110		rw.Body.Write(buf)
111	}
112	return len(buf), nil
113}
114
115// WriteString implements [io.StringWriter]. The data in str is written
116// to rw.Body, if not nil.
117func (rw *ResponseRecorder) WriteString(str string) (int, error) {
118	rw.writeHeader(nil, str)
119	if rw.Body != nil {
120		rw.Body.WriteString(str)
121	}
122	return len(str), nil
123}
124
// checkWriteHeaderCode panics unless code is a three-digit status code
// (100 through 999 inclusive).
//
// Issue 22880: require valid WriteHeader status codes.
// Only the three-digit shape is enforced for now. Codes above 599 (not
// defined at https://httpwg.org/specs/rfc7231.html#status.codes) and
// codes below 200 (pending more mature 1xx support) might be rejected
// in the future.
//
// "HTTP/1.1 000 0" used to be sent on the wire for bogus codes, but
// there is no equivalent bogus thing that can realistically be sent in
// HTTP/2, so the code panics consistently instead to surface bugs
// early. (WriteHeader has no way to return an error.)
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
141
142// WriteHeader implements [http.ResponseWriter].
143func (rw *ResponseRecorder) WriteHeader(code int) {
144	if rw.wroteHeader {
145		return
146	}
147
148	checkWriteHeaderCode(code)
149	rw.Code = code
150	rw.wroteHeader = true
151	if rw.HeaderMap == nil {
152		rw.HeaderMap = make(http.Header)
153	}
154	rw.snapHeader = rw.HeaderMap.Clone()
155}
156
157// Flush implements [http.Flusher]. To test whether Flush was
158// called, see rw.Flushed.
159func (rw *ResponseRecorder) Flush() {
160	if !rw.wroteHeader {
161		rw.WriteHeader(200)
162	}
163	rw.Flushed = true
164}
165
166// Result returns the response generated by the handler.
167//
168// The returned Response will have at least its StatusCode,
169// Header, Body, and optionally Trailer populated.
170// More fields may be populated in the future, so callers should
171// not DeepEqual the result in tests.
172//
173// The Response.Header is a snapshot of the headers at the time of the
174// first write call, or at the time of this call, if the handler never
175// did a write.
176//
177// The Response.Body is guaranteed to be non-nil and Body.Read call is
178// guaranteed to not return any error other than [io.EOF].
179//
180// Result must only be called after the handler has finished running.
181func (rw *ResponseRecorder) Result() *http.Response {
182	if rw.result != nil {
183		return rw.result
184	}
185	if rw.snapHeader == nil {
186		rw.snapHeader = rw.HeaderMap.Clone()
187	}
188	res := &http.Response{
189		Proto:      "HTTP/1.1",
190		ProtoMajor: 1,
191		ProtoMinor: 1,
192		StatusCode: rw.Code,
193		Header:     rw.snapHeader,
194	}
195	rw.result = res
196	if res.StatusCode == 0 {
197		res.StatusCode = 200
198	}
199	res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
200	if rw.Body != nil {
201		res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
202	} else {
203		res.Body = http.NoBody
204	}
205	res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
206
207	if trailers, ok := rw.snapHeader["Trailer"]; ok {
208		res.Trailer = make(http.Header, len(trailers))
209		for _, k := range trailers {
210			for _, k := range strings.Split(k, ",") {
211				k = http.CanonicalHeaderKey(textproto.TrimString(k))
212				if !httpguts.ValidTrailerHeader(k) {
213					// Ignore since forbidden by RFC 7230, section 4.1.2.
214					continue
215				}
216				vv, ok := rw.HeaderMap[k]
217				if !ok {
218					continue
219				}
220				vv2 := make([]string, len(vv))
221				copy(vv2, vv)
222				res.Trailer[k] = vv2
223			}
224		}
225	}
226	for k, vv := range rw.HeaderMap {
227		if !strings.HasPrefix(k, http.TrailerPrefix) {
228			continue
229		}
230		if res.Trailer == nil {
231			res.Trailer = make(http.Header)
232		}
233		for _, v := range vv {
234			res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
235		}
236	}
237	return res
238}
239
// parseContentLength trims whitespace from cl and returns the parsed
// value if it is a non-negative decimal integer that fits in an int64.
// It returns -1 if cl is empty or cannot be parsed.
//
// This is a modified version of the same function found in
// net/http/transfer.go. This one just ignores an invalid header.
func parseContentLength(cl string) int64 {
	trimmed := textproto.TrimString(cl)
	if trimmed == "" {
		return -1
	}
	// A bit size of 63 keeps the value within the non-negative int64 range.
	if n, err := strconv.ParseUint(trimmed, 10, 63); err == nil {
		return int64(n)
	}
	return -1
}
256