1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http_test
6
7import (
8	"context"
9	"io"
10	"net"
11	"net/http"
12	"net/http/httptrace"
13	"testing"
14)
15
16func TestTransportPoolConnReusePriorConnection(t *testing.T) {
17	dt := newTransportDialTester(t, http1Mode)
18
19	// First request creates a new connection.
20	rt1 := dt.roundTrip()
21	c1 := dt.wantDial()
22	c1.finish(nil)
23	rt1.wantDone(c1)
24	rt1.finish()
25
26	// Second request reuses the first connection.
27	rt2 := dt.roundTrip()
28	rt2.wantDone(c1)
29	rt2.finish()
30}
31
32func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) {
33	dt := newTransportDialTester(t, http1Mode)
34
35	// First request creates a new connection.
36	rt1 := dt.roundTrip()
37	c1 := dt.wantDial()
38	c1.finish(nil)
39	rt1.wantDone(c1)
40
41	// Second request is made while the first request is still using its connection,
42	// so it goes on a new connection.
43	rt2 := dt.roundTrip()
44	c2 := dt.wantDial()
45	c2.finish(nil)
46	rt2.wantDone(c2)
47}
48
49func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) {
50	dt := newTransportDialTester(t, http1Mode)
51
52	// First request creates a new connection.
53	rt1 := dt.roundTrip()
54	c1 := dt.wantDial()
55	c1.finish(nil)
56	rt1.wantDone(c1)
57
58	// Second request is made while the first request is still using its connection.
59	// The first connection completes while the second Dial is in progress, so the
60	// second request uses the first connection.
61	rt2 := dt.roundTrip()
62	c2 := dt.wantDial()
63	rt1.finish()
64	rt2.wantDone(c1)
65
66	// This section is a bit overfitted to the current Transport implementation:
67	// A third request starts. We have an in-progress dial that was started by rt2,
68	// but this new request (rt3) is going to ignore it and make a dial of its own.
69	// rt3 will use the first of these dials that completes.
70	rt3 := dt.roundTrip()
71	c3 := dt.wantDial()
72	c2.finish(nil)
73	rt3.wantDone(c2)
74
75	c3.finish(nil)
76}
77
78// A transportDialTester manages a test of a connection's Dials.
79type transportDialTester struct {
80	t   *testing.T
81	cst *clientServerTest
82
83	dials chan *transportDialTesterConn // each new conn is sent to this channel
84
85	roundTripCount int
86	dialCount      int
87}
88
89// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test.
90type transportDialTesterRoundTrip struct {
91	t *testing.T
92
93	roundTripID int                // distinguishes RoundTrips in logs
94	cancel      context.CancelFunc // cancels the Request context
95	reqBody     io.WriteCloser     // write half of the Request.Body
96	finished    bool
97
98	done chan struct{} // closed when RoundTrip returns:w
99	res  *http.Response
100	err  error
101	conn *transportDialTesterConn
102}
103
// A transportDialTesterConn is a client connection created by the Transport as
// part of a dial test.
type transportDialTesterConn struct {
	// Conn is the underlying network connection.
	// It is nil until the Dial has completed successfully.
	net.Conn

	t *testing.T

	// connID distinguishes Dials in logs.
	connID int
	// ready is sent on to complete the Dial.
	ready chan error
}
114
115func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester {
116	t.Helper()
117	dt := &transportDialTester{
118		t:     t,
119		dials: make(chan *transportDialTesterConn),
120	}
121	dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
122		// Write response headers when we receive a request.
123		http.NewResponseController(w).EnableFullDuplex()
124		w.WriteHeader(200)
125		http.NewResponseController(w).Flush()
126		// Wait for the client to send the request body,
127		// to synchronize with the rest of the test.
128		io.ReadAll(r.Body)
129	}), func(tr *http.Transport) {
130		tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
131			c := &transportDialTesterConn{
132				t:     t,
133				ready: make(chan error),
134			}
135			// Notify the test that a Dial has started,
136			// and wait for the test to notify us that it should complete.
137			dt.dials <- c
138			if err := <-c.ready; err != nil {
139				return nil, err
140			}
141			nc, err := net.Dial(network, address)
142			if err != nil {
143				return nil, err
144			}
145			// Use the *transportDialTesterConn as the net.Conn,
146			// to let tests associate requests with connections.
147			c.Conn = nc
148			return c, err
149		}
150	})
151	return dt
152}
153
154// roundTrip starts a RoundTrip.
155// It returns immediately, without waiting for the RoundTrip call to complete.
156func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip {
157	dt.t.Helper()
158	ctx, cancel := context.WithCancel(context.Background())
159	pr, pw := io.Pipe()
160	rt := &transportDialTesterRoundTrip{
161		t:           dt.t,
162		roundTripID: dt.roundTripCount,
163		done:        make(chan struct{}),
164		reqBody:     pw,
165		cancel:      cancel,
166	}
167	dt.roundTripCount++
168	dt.t.Logf("RoundTrip %v: started", rt.roundTripID)
169	dt.t.Cleanup(func() {
170		rt.cancel()
171		rt.finish()
172	})
173	go func() {
174		ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
175			GotConn: func(info httptrace.GotConnInfo) {
176				rt.conn = info.Conn.(*transportDialTesterConn)
177			},
178		})
179		req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr)
180		req.Header.Set("Content-Type", "text/plain")
181		rt.res, rt.err = dt.cst.tr.RoundTrip(req)
182		dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err)
183		close(rt.done)
184	}()
185	return rt
186}
187
188// wantDone indicates that a RoundTrip should have returned.
189func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) {
190	rt.t.Helper()
191	<-rt.done
192	if rt.err != nil {
193		rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err)
194	}
195	if rt.conn != c {
196		rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID)
197	}
198}
199
200// finish completes a RoundTrip by sending the request body, consuming the response body,
201// and closing the response body.
202func (rt *transportDialTesterRoundTrip) finish() {
203	rt.t.Helper()
204
205	if rt.finished {
206		return
207	}
208	rt.finished = true
209
210	<-rt.done
211
212	if rt.err != nil {
213		return
214	}
215	rt.reqBody.Close()
216	io.ReadAll(rt.res.Body)
217	rt.res.Body.Close()
218	rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID)
219}
220
221// wantDial waits for the Transport to start a Dial.
222func (dt *transportDialTester) wantDial() *transportDialTesterConn {
223	c := <-dt.dials
224	c.connID = dt.dialCount
225	dt.dialCount++
226	dt.t.Logf("Dial %v: started", c.connID)
227	return c
228}
229
230// finish completes a Dial.
231func (c *transportDialTesterConn) finish(err error) {
232	c.t.Logf("Dial %v: finished (err:%v)", c.connID, err)
233	c.ready <- err
234	close(c.ready)
235}
236