1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Bridge package to expose http internals to tests in the http_test
6// package.
7
8package http
9
10import (
11	"context"
12	"fmt"
13	"net"
14	"net/url"
15	"slices"
16	"sync"
17	"testing"
18	"time"
19)
20
21var (
22	DefaultUserAgent                  = defaultUserAgent
23	NewLoggingConn                    = newLoggingConn
24	ExportAppendTime                  = appendTime
25	ExportRefererForURL               = refererForURL
26	ExportServerNewConn               = (*Server).newConn
27	ExportCloseWriteAndWait           = (*conn).closeWriteAndWait
28	ExportErrRequestCanceled          = errRequestCanceled
29	ExportErrRequestCanceledConn      = errRequestCanceledConn
30	ExportErrServerClosedIdle         = errServerClosedIdle
31	ExportServeFile                   = serveFile
32	ExportScanETag                    = scanETag
33	ExportHttp2ConfigureServer        = http2ConfigureServer
34	Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect
35	Export_writeStatusLine            = writeStatusLine
36	Export_is408Message               = is408Message
37)
38
39var MaxWriteWaitBeforeConnReuse = &maxWriteWaitBeforeConnReuse
40
41func init() {
42	// We only want to pay for this cost during testing.
43	// When not under test, these values are always nil
44	// and never assigned to.
45	testHookMu = new(sync.Mutex)
46
47	testHookClientDoResult = func(res *Response, err error) {
48		if err != nil {
49			if _, ok := err.(*url.Error); !ok {
50				panic(fmt.Sprintf("unexpected Client.Do error of type %T; want *url.Error", err))
51			}
52		} else {
53			if res == nil {
54				panic("Client.Do returned nil, nil")
55			}
56			if res.Body == nil {
57				panic("Client.Do returned nil res.Body and no error")
58			}
59		}
60	}
61}
62
63func CondSkipHTTP2(t testing.TB) {
64	if omitBundledHTTP2 {
65		t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use")
66	}
67}
68
69var (
70	SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip)
71	SetRoundTripRetried   = hookSetter(&testHookRoundTripRetried)
72)
73
74func SetReadLoopBeforeNextReadHook(f func()) {
75	unnilTestHook(&f)
76	testHookReadLoopBeforeNextRead = f
77}
78
79// SetPendingDialHooks sets the hooks that run before and after handling
80// pending dials.
81func SetPendingDialHooks(before, after func()) {
82	unnilTestHook(&before)
83	unnilTestHook(&after)
84	testHookPrePendingDial, testHookPostPendingDial = before, after
85}
86
87func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
88
89func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) {
90	orig := testHookProxyConnectTimeout
91	t.Cleanup(func() {
92		testHookProxyConnectTimeout = orig
93	})
94	testHookProxyConnectTimeout = f
95}
96
97func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
98	return &timeoutHandler{
99		handler:     handler,
100		testContext: ctx,
101		// (no body)
102	}
103}
104
105func ResetCachedEnvironment() {
106	resetProxyConfig()
107}
108
109func (t *Transport) NumPendingRequestsForTesting() int {
110	t.reqMu.Lock()
111	defer t.reqMu.Unlock()
112	return len(t.reqCanceler)
113}
114
115func (t *Transport) IdleConnKeysForTesting() (keys []string) {
116	keys = make([]string, 0)
117	t.idleMu.Lock()
118	defer t.idleMu.Unlock()
119	for key := range t.idleConn {
120		keys = append(keys, key.String())
121	}
122	slices.Sort(keys)
123	return
124}
125
126func (t *Transport) IdleConnKeyCountForTesting() int {
127	t.idleMu.Lock()
128	defer t.idleMu.Unlock()
129	return len(t.idleConn)
130}
131
132func (t *Transport) IdleConnStrsForTesting() []string {
133	var ret []string
134	t.idleMu.Lock()
135	defer t.idleMu.Unlock()
136	for _, conns := range t.idleConn {
137		for _, pc := range conns {
138			ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String())
139		}
140	}
141	slices.Sort(ret)
142	return ret
143}
144
145func (t *Transport) IdleConnStrsForTesting_h2() []string {
146	var ret []string
147	noDialPool := t.h2transport.(*http2Transport).ConnPool.(http2noDialClientConnPool)
148	pool := noDialPool.http2clientConnPool
149
150	pool.mu.Lock()
151	defer pool.mu.Unlock()
152
153	for k, ccs := range pool.conns {
154		for _, cc := range ccs {
155			if cc.idleState().canTakeNewRequest {
156				ret = append(ret, k)
157			}
158		}
159	}
160
161	slices.Sort(ret)
162	return ret
163}
164
165func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
166	t.idleMu.Lock()
167	defer t.idleMu.Unlock()
168	key := connectMethodKey{"", scheme, addr, false}
169	cacheKey := key.String()
170	for k, conns := range t.idleConn {
171		if k.String() == cacheKey {
172			return len(conns)
173		}
174	}
175	return 0
176}
177
178func (t *Transport) IdleConnWaitMapSizeForTesting() int {
179	t.idleMu.Lock()
180	defer t.idleMu.Unlock()
181	return len(t.idleConnWait)
182}
183
184func (t *Transport) IsIdleForTesting() bool {
185	t.idleMu.Lock()
186	defer t.idleMu.Unlock()
187	return t.closeIdle
188}
189
190func (t *Transport) QueueForIdleConnForTesting() {
191	t.queueForIdleConn(nil)
192}
193
194// PutIdleTestConn reports whether it was able to insert a fresh
195// persistConn for scheme, addr into the idle connection pool.
196func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
197	c, _ := net.Pipe()
198	key := connectMethodKey{"", scheme, addr, false}
199
200	if t.MaxConnsPerHost > 0 {
201		// Transport is tracking conns-per-host.
202		// Increment connection count to account
203		// for new persistConn created below.
204		t.connsPerHostMu.Lock()
205		if t.connsPerHost == nil {
206			t.connsPerHost = make(map[connectMethodKey]int)
207		}
208		t.connsPerHost[key]++
209		t.connsPerHostMu.Unlock()
210	}
211
212	return t.tryPutIdleConn(&persistConn{
213		t:        t,
214		conn:     c,                   // dummy
215		closech:  make(chan struct{}), // so it can be closed
216		cacheKey: key,
217	}) == nil
218}
219
220// PutIdleTestConnH2 reports whether it was able to insert a fresh
221// HTTP/2 persistConn for scheme, addr into the idle connection pool.
222func (t *Transport) PutIdleTestConnH2(scheme, addr string, alt RoundTripper) bool {
223	key := connectMethodKey{"", scheme, addr, false}
224
225	if t.MaxConnsPerHost > 0 {
226		// Transport is tracking conns-per-host.
227		// Increment connection count to account
228		// for new persistConn created below.
229		t.connsPerHostMu.Lock()
230		if t.connsPerHost == nil {
231			t.connsPerHost = make(map[connectMethodKey]int)
232		}
233		t.connsPerHost[key]++
234		t.connsPerHostMu.Unlock()
235	}
236
237	return t.tryPutIdleConn(&persistConn{
238		t:        t,
239		alt:      alt,
240		cacheKey: key,
241	}) == nil
242}
243
244// All test hooks must be non-nil so they can be called directly,
245// but the tests use nil to mean hook disabled.
246func unnilTestHook(f *func()) {
247	if *f == nil {
248		*f = nop
249	}
250}
251
252func hookSetter(dst *func()) func(func()) {
253	return func(fn func()) {
254		unnilTestHook(&fn)
255		*dst = fn
256	}
257}
258
259func ExportHttp2ConfigureTransport(t *Transport) error {
260	t2, err := http2configureTransports(t)
261	if err != nil {
262		return err
263	}
264	t.h2transport = t2
265	return nil
266}
267
268func (s *Server) ExportAllConnsIdle() bool {
269	s.mu.Lock()
270	defer s.mu.Unlock()
271	for c := range s.activeConn {
272		st, unixSec := c.getState()
273		if unixSec == 0 || st != StateIdle {
274			return false
275		}
276	}
277	return true
278}
279
280func (s *Server) ExportAllConnsByState() map[ConnState]int {
281	states := map[ConnState]int{}
282	s.mu.Lock()
283	defer s.mu.Unlock()
284	for c := range s.activeConn {
285		st, _ := c.getState()
286		states[st] += 1
287	}
288	return states
289}
290
291func (r *Request) WithT(t *testing.T) *Request {
292	return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
293}
294
295func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) {
296	old := http2goAwayTimeout
297	http2goAwayTimeout = d
298	return func() { http2goAwayTimeout = old }
299}
300
301func (r *Request) ExportIsReplayable() bool { return r.isReplayable() }
302
303// ExportCloseTransportConnsAbruptly closes all idle connections from
304// tr in an abrupt way, just reaching into the underlying Conns and
305// closing them, without telling the Transport or its persistConns
306// that it's doing so. This is to simulate the server closing connections
307// on the Transport.
308func ExportCloseTransportConnsAbruptly(tr *Transport) {
309	tr.idleMu.Lock()
310	for _, pcs := range tr.idleConn {
311		for _, pc := range pcs {
312			pc.conn.Close()
313		}
314	}
315	tr.idleMu.Unlock()
316}
317
318// ResponseWriterConnForTesting returns w's underlying connection, if w
319// is a regular *response ResponseWriter.
320func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) {
321	if r, ok := w.(*response); ok {
322		return r.conn.rwc, true
323	}
324	return nil, false
325}
326
327func init() {
328	// Set the default rstAvoidanceDelay to the minimum possible value to shake
329	// out tests that unexpectedly depend on it. Such tests should use
330	// runTimeSensitiveTest and SetRSTAvoidanceDelay to explicitly raise the delay
331	// if needed.
332	rstAvoidanceDelay = 1 * time.Nanosecond
333}
334
335// SetRSTAvoidanceDelay sets how long we are willing to wait between calling
336// CloseWrite on a connection and fully closing the connection.
337func SetRSTAvoidanceDelay(t *testing.T, d time.Duration) {
338	prevDelay := rstAvoidanceDelay
339	t.Cleanup(func() {
340		rstAvoidanceDelay = prevDelay
341	})
342	rstAvoidanceDelay = d
343}
344