1// Copyright 2024 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package http_test 6 7import ( 8 "context" 9 "io" 10 "net" 11 "net/http" 12 "net/http/httptrace" 13 "testing" 14) 15 16func TestTransportPoolConnReusePriorConnection(t *testing.T) { 17 dt := newTransportDialTester(t, http1Mode) 18 19 // First request creates a new connection. 20 rt1 := dt.roundTrip() 21 c1 := dt.wantDial() 22 c1.finish(nil) 23 rt1.wantDone(c1) 24 rt1.finish() 25 26 // Second request reuses the first connection. 27 rt2 := dt.roundTrip() 28 rt2.wantDone(c1) 29 rt2.finish() 30} 31 32func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) { 33 dt := newTransportDialTester(t, http1Mode) 34 35 // First request creates a new connection. 36 rt1 := dt.roundTrip() 37 c1 := dt.wantDial() 38 c1.finish(nil) 39 rt1.wantDone(c1) 40 41 // Second request is made while the first request is still using its connection, 42 // so it goes on a new connection. 43 rt2 := dt.roundTrip() 44 c2 := dt.wantDial() 45 c2.finish(nil) 46 rt2.wantDone(c2) 47} 48 49func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) { 50 dt := newTransportDialTester(t, http1Mode) 51 52 // First request creates a new connection. 53 rt1 := dt.roundTrip() 54 c1 := dt.wantDial() 55 c1.finish(nil) 56 rt1.wantDone(c1) 57 58 // Second request is made while the first request is still using its connection. 59 // The first connection completes while the second Dial is in progress, so the 60 // second request uses the first connection. 61 rt2 := dt.roundTrip() 62 c2 := dt.wantDial() 63 rt1.finish() 64 rt2.wantDone(c1) 65 66 // This section is a bit overfitted to the current Transport implementation: 67 // A third request starts. We have an in-progress dial that was started by rt2, 68 // but this new request (rt3) is going to ignore it and make a dial of its own. 69 // rt3 will use the first of these dials that completes. 70 rt3 := dt.roundTrip() 71 c3 := dt.wantDial() 72 c2.finish(nil) 73 rt3.wantDone(c2) 74 75 c3.finish(nil) 76} 77 78// A transportDialTester manages a test of a connection's Dials. 79type transportDialTester struct { 80 t *testing.T 81 cst *clientServerTest 82 83 dials chan *transportDialTesterConn // each new conn is sent to this channel 84 85 roundTripCount int 86 dialCount int 87} 88 89// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test. 90type transportDialTesterRoundTrip struct { 91 t *testing.T 92 93 roundTripID int // distinguishes RoundTrips in logs 94 cancel context.CancelFunc // cancels the Request context 95 reqBody io.WriteCloser // write half of the Request.Body 96 finished bool 97 98 done chan struct{} // closed when RoundTrip returns:w 99 res *http.Response 100 err error 101 conn *transportDialTesterConn 102} 103 104// A transportDialTesterConn is a client connection created by the Transport as 105// part of a dial test. 106type transportDialTesterConn struct { 107 t *testing.T 108 109 connID int // distinguished Dials in logs 110 ready chan error // sent on to complete the Dial 111 112 net.Conn 113} 114 115func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester { 116 t.Helper() 117 dt := &transportDialTester{ 118 t: t, 119 dials: make(chan *transportDialTesterConn), 120 } 121 dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 122 // Write response headers when we receive a request. 123 http.NewResponseController(w).EnableFullDuplex() 124 w.WriteHeader(200) 125 http.NewResponseController(w).Flush() 126 // Wait for the client to send the request body, 127 // to synchronize with the rest of the test. 128 io.ReadAll(r.Body) 129 }), func(tr *http.Transport) { 130 tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { 131 c := &transportDialTesterConn{ 132 t: t, 133 ready: make(chan error), 134 } 135 // Notify the test that a Dial has started, 136 // and wait for the test to notify us that it should complete. 137 dt.dials <- c 138 if err := <-c.ready; err != nil { 139 return nil, err 140 } 141 nc, err := net.Dial(network, address) 142 if err != nil { 143 return nil, err 144 } 145 // Use the *transportDialTesterConn as the net.Conn, 146 // to let tests associate requests with connections. 147 c.Conn = nc 148 return c, err 149 } 150 }) 151 return dt 152} 153 154// roundTrip starts a RoundTrip. 155// It returns immediately, without waiting for the RoundTrip call to complete. 156func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip { 157 dt.t.Helper() 158 ctx, cancel := context.WithCancel(context.Background()) 159 pr, pw := io.Pipe() 160 rt := &transportDialTesterRoundTrip{ 161 t: dt.t, 162 roundTripID: dt.roundTripCount, 163 done: make(chan struct{}), 164 reqBody: pw, 165 cancel: cancel, 166 } 167 dt.roundTripCount++ 168 dt.t.Logf("RoundTrip %v: started", rt.roundTripID) 169 dt.t.Cleanup(func() { 170 rt.cancel() 171 rt.finish() 172 }) 173 go func() { 174 ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ 175 GotConn: func(info httptrace.GotConnInfo) { 176 rt.conn = info.Conn.(*transportDialTesterConn) 177 }, 178 }) 179 req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr) 180 req.Header.Set("Content-Type", "text/plain") 181 rt.res, rt.err = dt.cst.tr.RoundTrip(req) 182 dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err) 183 close(rt.done) 184 }() 185 return rt 186} 187 188// wantDone indicates that a RoundTrip should have returned. 189func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) { 190 rt.t.Helper() 191 <-rt.done 192 if rt.err != nil { 193 rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err) 194 } 195 if rt.conn != c { 196 rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID) 197 } 198} 199 200// finish completes a RoundTrip by sending the request body, consuming the response body, 201// and closing the response body. 202func (rt *transportDialTesterRoundTrip) finish() { 203 rt.t.Helper() 204 205 if rt.finished { 206 return 207 } 208 rt.finished = true 209 210 <-rt.done 211 212 if rt.err != nil { 213 return 214 } 215 rt.reqBody.Close() 216 io.ReadAll(rt.res.Body) 217 rt.res.Body.Close() 218 rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID) 219} 220 221// wantDial waits for the Transport to start a Dial. 222func (dt *transportDialTester) wantDial() *transportDialTesterConn { 223 c := <-dt.dials 224 c.connID = dt.dialCount 225 dt.dialCount++ 226 dt.t.Logf("Dial %v: started", c.connID) 227 return c 228} 229 230// finish completes a Dial. 231func (c *transportDialTesterConn) finish(err error) { 232 c.t.Logf("Dial %v: finished (err:%v)", c.connID, err) 233 c.ready <- err 234 close(c.ready) 235} 236