1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
6
7package http_test
8
9import (
10	"bytes"
11	"compress/gzip"
12	"context"
13	"crypto/rand"
14	"crypto/sha1"
15	"crypto/tls"
16	"fmt"
17	"hash"
18	"io"
19	"log"
20	"net"
21	. "net/http"
22	"net/http/httptest"
23	"net/http/httptrace"
24	"net/http/httputil"
25	"net/textproto"
26	"net/url"
27	"os"
28	"reflect"
29	"runtime"
30	"slices"
31	"strings"
32	"sync"
33	"sync/atomic"
34	"testing"
35	"time"
36)
37
// testMode names the protocol configuration a client/server test runs
// under. It is used as the subtest name by run and selects how
// newClientServerTest starts its server.
type testMode string

// Test modes understood by newClientServerTest.
const (
	http1Mode  = testMode("h1")     // HTTP/1.1
	https1Mode = testMode("https1") // HTTPS/1.1
	http2Mode  = testMode("h2")     // HTTP/2
)
45
// testNotParallelOpt is the type of the testNotParallel option
// recognized by run.
type testNotParallelOpt struct{}

var (
	// testNotParallel, when passed to run, stops run from marking
	// the test and its per-mode subtests as parallel.
	testNotParallel = testNotParallelOpt{}
)
51
// TBRun is a testing.TB that can also start subtests of its own
// type T via Run. It is the constraint used by run so that the
// same helper can drive any test type providing such a Run method.
type TBRun[T any] interface {
	testing.TB
	Run(string, func(T)) bool
}
56
57// run runs a client/server test in a variety of test configurations.
58//
59// Tests execute in HTTP/1.1 and HTTP/2 modes by default.
60// To run in a different set of configurations, pass a []testMode option.
61//
62// Tests call t.Parallel() by default.
63// To disable parallel execution, pass the testNotParallel option.
64func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
65	t.Helper()
66	modes := []testMode{http1Mode, http2Mode}
67	parallel := true
68	for _, opt := range opts {
69		switch opt := opt.(type) {
70		case []testMode:
71			modes = opt
72		case testNotParallelOpt:
73			parallel = false
74		default:
75			t.Fatalf("unknown option type %T", opt)
76		}
77	}
78	if t, ok := any(t).(*testing.T); ok && parallel {
79		setParallel(t)
80	}
81	for _, mode := range modes {
82		t.Run(string(mode), func(t T) {
83			t.Helper()
84			if t, ok := any(t).(*testing.T); ok && parallel {
85				setParallel(t)
86			}
87			t.Cleanup(func() {
88				afterTest(t)
89			})
90			f(t, mode)
91		})
92	}
93}
94
// clientServerTest bundles a test server with a client configured to
// talk to it. Values are created by newClientServerTest.
type clientServerTest struct {
	t  testing.TB       // test that owns this fixture; used to report failures
	h2 bool             // true when the fixture was created in http2Mode
	h  Handler          // handler served by ts
	ts *httptest.Server // the test server
	tr *Transport       // the transport used by c
	c  *Client          // client obtained from ts.Client()
}
103
104func (t *clientServerTest) close() {
105	t.tr.CloseIdleConnections()
106	t.ts.Close()
107}
108
109func (t *clientServerTest) getURL(u string) string {
110	res, err := t.c.Get(u)
111	if err != nil {
112		t.t.Fatal(err)
113	}
114	defer res.Body.Close()
115	slurp, err := io.ReadAll(res.Body)
116	if err != nil {
117		t.t.Fatal(err)
118	}
119	return string(slurp)
120}
121
122func (t *clientServerTest) scheme() string {
123	if t.h2 {
124		return "https"
125	}
126	return "http"
127}
128
129var optQuietLog = func(ts *httptest.Server) {
130	ts.Config.ErrorLog = quietLog
131}
132
// optWithServerLog returns a newClientServerTest server option that
// installs lg as the server's error logger.
func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
	apply := func(srv *httptest.Server) {
		srv.Config.ErrorLog = lg
	}
	return apply
}
138
139// newClientServerTest creates and starts an httptest.Server.
140//
141// The mode parameter selects the implementation to test:
142// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
143// the 'run' function, which will start a subtests for each tested mode.
144//
145// The vararg opts parameter can include functions to configure the
146// test server or transport.
147//
148//	func(*httptest.Server) // run before starting the server
149//	func(*http.Transport)
150func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
151	if mode == http2Mode {
152		CondSkipHTTP2(t)
153	}
154	cst := &clientServerTest{
155		t:  t,
156		h2: mode == http2Mode,
157		h:  h,
158	}
159	cst.ts = httptest.NewUnstartedServer(h)
160
161	var transportFuncs []func(*Transport)
162	for _, opt := range opts {
163		switch opt := opt.(type) {
164		case func(*Transport):
165			transportFuncs = append(transportFuncs, opt)
166		case func(*httptest.Server):
167			opt(cst.ts)
168		default:
169			t.Fatalf("unhandled option type %T", opt)
170		}
171	}
172
173	if cst.ts.Config.ErrorLog == nil {
174		cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
175	}
176
177	switch mode {
178	case http1Mode:
179		cst.ts.Start()
180	case https1Mode:
181		cst.ts.StartTLS()
182	case http2Mode:
183		ExportHttp2ConfigureServer(cst.ts.Config, nil)
184		cst.ts.TLS = cst.ts.Config.TLSConfig
185		cst.ts.StartTLS()
186	default:
187		t.Fatalf("unknown test mode %v", mode)
188	}
189	cst.c = cst.ts.Client()
190	cst.tr = cst.c.Transport.(*Transport)
191	if mode == http2Mode {
192		if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
193			t.Fatal(err)
194		}
195	}
196	for _, f := range transportFuncs {
197		f(cst.tr)
198	}
199	t.Cleanup(func() {
200		cst.close()
201	})
202	return cst
203}
204
// testLogWriter is an io.Writer that forwards everything written to it
// to a test's log.
type testLogWriter struct {
	t testing.TB
}

// Write logs p, with surrounding whitespace trimmed, prefixed by
// "server log: ". It always reports the full length of p as written
// and never returns an error.
func (w testLogWriter) Write(p []byte) (int, error) {
	msg := strings.TrimSpace(string(p))
	w.t.Logf("server log: %v", msg)
	return len(p), nil
}
213
214// Testing the newClientServerTest helper itself.
215func TestNewClientServerTest(t *testing.T) {
216	run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
217}
218func testNewClientServerTest(t *testing.T, mode testMode) {
219	var got struct {
220		sync.Mutex
221		proto  string
222		hasTLS bool
223	}
224	h := HandlerFunc(func(w ResponseWriter, r *Request) {
225		got.Lock()
226		defer got.Unlock()
227		got.proto = r.Proto
228		got.hasTLS = r.TLS != nil
229	})
230	cst := newClientServerTest(t, mode, h)
231	if _, err := cst.c.Head(cst.ts.URL); err != nil {
232		t.Fatal(err)
233	}
234	var wantProto string
235	var wantTLS bool
236	switch mode {
237	case http1Mode:
238		wantProto = "HTTP/1.1"
239		wantTLS = false
240	case https1Mode:
241		wantProto = "HTTP/1.1"
242		wantTLS = true
243	case http2Mode:
244		wantProto = "HTTP/2.0"
245		wantTLS = true
246	}
247	if got.proto != wantProto {
248		t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
249	}
250	if got.hasTLS != wantTLS {
251		t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
252	}
253}
254
255func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
256func testChunkedResponseHeaders(t *testing.T, mode testMode) {
257	log.SetOutput(io.Discard) // is noisy otherwise
258	defer log.SetOutput(os.Stderr)
259	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
260		w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
261		w.(Flusher).Flush()
262		fmt.Fprintf(w, "I am a chunked response.")
263	}))
264
265	res, err := cst.c.Get(cst.ts.URL)
266	if err != nil {
267		t.Fatalf("Get error: %v", err)
268	}
269	defer res.Body.Close()
270	if g, e := res.ContentLength, int64(-1); g != e {
271		t.Errorf("expected ContentLength of %d; got %d", e, g)
272	}
273	wantTE := []string{"chunked"}
274	if mode == http2Mode {
275		wantTE = nil
276	}
277	if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
278		t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
279	}
280	if got, haveCL := res.Header["Content-Length"]; haveCL {
281		t.Errorf("Unexpected Content-Length: %q", got)
282	}
283}
284
// reqFunc issues a request for url using client c and returns the
// response. (*Client).Get and (*Client).Head have this shape.
type reqFunc func(c *Client, url string) (*Response, error)

// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
// against each other.
//
// The same Handler is served by an HTTP/1 and an HTTP/2 test server,
// one request is made to each, and the two responses are compared
// after normalization.
type h12Compare struct {
	Handler            func(ResponseWriter, *Request)    // required
	ReqFunc            reqFunc                           // optional; (*Client).Get is used when nil
	CheckResponse      func(proto string, res *Response) // optional; called for each response after the comparison
	EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
	Opts               []any                             // options forwarded to newClientServerTest
}
296
297func (tt h12Compare) reqFunc() reqFunc {
298	if tt.ReqFunc == nil {
299		return (*Client).Get
300	}
301	return tt.ReqFunc
302}
303
304func (tt h12Compare) run(t *testing.T) {
305	setParallel(t)
306	cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
307	defer cst1.close()
308	cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
309	defer cst2.close()
310
311	res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
312	if err != nil {
313		t.Errorf("HTTP/1 request: %v", err)
314		return
315	}
316	res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
317	if err != nil {
318		t.Errorf("HTTP/2 request: %v", err)
319		return
320	}
321
322	if fn := tt.EarlyCheckResponse; fn != nil {
323		fn("HTTP/1.1", res1)
324		fn("HTTP/2.0", res2)
325	}
326
327	tt.normalizeRes(t, res1, "HTTP/1.1")
328	tt.normalizeRes(t, res2, "HTTP/2.0")
329	res1body, res2body := res1.Body, res2.Body
330
331	eres1 := mostlyCopy(res1)
332	eres2 := mostlyCopy(res2)
333	if !reflect.DeepEqual(eres1, eres2) {
334		t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
335			cst1.ts.URL, eres1, cst2.ts.URL, eres2)
336	}
337	if !reflect.DeepEqual(res1body, res2body) {
338		t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
339	}
340	if fn := tt.CheckResponse; fn != nil {
341		res1.Body, res2.Body = res1body, res2body
342		fn("HTTP/1.1", res1)
343		fn("HTTP/2.0", res2)
344	}
345}
346
// mostlyCopy returns a shallow copy of r with the fields that are
// excluded from the HTTP/1 vs HTTP/2 comparison (Body,
// TransferEncoding, TLS and Request) cleared. r itself is not modified.
func mostlyCopy(r *Response) *Response {
	dup := new(Response)
	*dup = *r
	dup.Body, dup.Request = nil, nil
	dup.TLS = nil
	dup.TransferEncoding = nil
	return dup
}
355
// slurpResult is a response body that has already been read in full.
// It records the bytes read and the error the read ended with, while
// still satisfying io.ReadCloser through the embedded field.
type slurpResult struct {
	io.ReadCloser
	body []byte
	err  error
}

// String describes the slurped body and the error from reading it.
func (sr slurpResult) String() string {
	const format = "body %q; err %v"
	return fmt.Sprintf(format, sr.body, sr.err)
}
363
364func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
365	if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
366		res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
367	} else {
368		t.Errorf("got %q response; want %q", res.Proto, wantProto)
369	}
370	slurp, err := io.ReadAll(res.Body)
371
372	res.Body.Close()
373	res.Body = slurpResult{
374		ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
375		body:       slurp,
376		err:        err,
377	}
378	for i, v := range res.Header["Date"] {
379		res.Header["Date"][i] = strings.Repeat("x", len(v))
380	}
381	if res.Request == nil {
382		t.Errorf("for %s, no request", wantProto)
383	}
384	if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
385		t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
386	}
387}
388
389// Issue 13532
390func TestH12_HeadContentLengthNoBody(t *testing.T) {
391	h12Compare{
392		ReqFunc: (*Client).Head,
393		Handler: func(w ResponseWriter, r *Request) {
394		},
395	}.run(t)
396}
397
398func TestH12_HeadContentLengthSmallBody(t *testing.T) {
399	h12Compare{
400		ReqFunc: (*Client).Head,
401		Handler: func(w ResponseWriter, r *Request) {
402			io.WriteString(w, "small")
403		},
404	}.run(t)
405}
406
407func TestH12_HeadContentLengthLargeBody(t *testing.T) {
408	h12Compare{
409		ReqFunc: (*Client).Head,
410		Handler: func(w ResponseWriter, r *Request) {
411			chunk := strings.Repeat("x", 512<<10)
412			for i := 0; i < 10; i++ {
413				io.WriteString(w, chunk)
414			}
415		},
416	}.run(t)
417}
418
419func TestH12_200NoBody(t *testing.T) {
420	h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
421}
422
423func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
424func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
425func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
426
427func testH12_noBody(t *testing.T, status int) {
428	h12Compare{Handler: func(w ResponseWriter, r *Request) {
429		w.WriteHeader(status)
430	}}.run(t)
431}
432
433func TestH12_SmallBody(t *testing.T) {
434	h12Compare{Handler: func(w ResponseWriter, r *Request) {
435		io.WriteString(w, "small body")
436	}}.run(t)
437}
438
439func TestH12_ExplicitContentLength(t *testing.T) {
440	h12Compare{Handler: func(w ResponseWriter, r *Request) {
441		w.Header().Set("Content-Length", "3")
442		io.WriteString(w, "foo")
443	}}.run(t)
444}
445
446func TestH12_FlushBeforeBody(t *testing.T) {
447	h12Compare{Handler: func(w ResponseWriter, r *Request) {
448		w.(Flusher).Flush()
449		io.WriteString(w, "foo")
450	}}.run(t)
451}
452
453func TestH12_FlushMidBody(t *testing.T) {
454	h12Compare{Handler: func(w ResponseWriter, r *Request) {
455		io.WriteString(w, "foo")
456		w.(Flusher).Flush()
457		io.WriteString(w, "bar")
458	}}.run(t)
459}
460
461func TestH12_Head_ExplicitLen(t *testing.T) {
462	h12Compare{
463		ReqFunc: (*Client).Head,
464		Handler: func(w ResponseWriter, r *Request) {
465			if r.Method != "HEAD" {
466				t.Errorf("unexpected method %q", r.Method)
467			}
468			w.Header().Set("Content-Length", "1235")
469		},
470	}.run(t)
471}
472
473func TestH12_Head_ImplicitLen(t *testing.T) {
474	h12Compare{
475		ReqFunc: (*Client).Head,
476		Handler: func(w ResponseWriter, r *Request) {
477			if r.Method != "HEAD" {
478				t.Errorf("unexpected method %q", r.Method)
479			}
480			io.WriteString(w, "foo")
481		},
482	}.run(t)
483}
484
485func TestH12_HandlerWritesTooLittle(t *testing.T) {
486	h12Compare{
487		Handler: func(w ResponseWriter, r *Request) {
488			w.Header().Set("Content-Length", "3")
489			io.WriteString(w, "12") // one byte short
490		},
491		CheckResponse: func(proto string, res *Response) {
492			sr, ok := res.Body.(slurpResult)
493			if !ok {
494				t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
495				return
496			}
497			if sr.err != io.ErrUnexpectedEOF {
498				t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
499			}
500			if string(sr.body) != "12" {
501				t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
502			}
503		},
504	}.run(t)
505}
506
507// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
508// writing more than they declared. This test does not test whether
509// the transport deals with too much data, though, since the server
510// doesn't make it possible to send bogus data. For those tests, see
511// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
512// (for HTTP/2).
513func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) }
514func testHandlerWritesTooMuch(t *testing.T, mode testMode) {
515	wantBody := []byte("123")
516	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
517		rc := NewResponseController(w)
518		w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody)))
519		rc.Flush()
520		w.Write(wantBody)
521		rc.Flush()
522		n, err := io.WriteString(w, "x") // too many
523		if err == nil {
524			err = rc.Flush()
525		}
526		// TODO: Check that this is ErrContentLength, not just any error.
527		if err == nil {
528			t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err)
529		}
530	}))
531
532	res, err := cst.c.Get(cst.ts.URL)
533	if err != nil {
534		t.Fatal(err)
535	}
536	defer res.Body.Close()
537
538	gotBody, _ := io.ReadAll(res.Body)
539	if !bytes.Equal(gotBody, wantBody) {
540		t.Fatalf("got response body: %q; want %q", gotBody, wantBody)
541	}
542}
543
544// Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
545// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
546func TestH12_AutoGzip(t *testing.T) {
547	h12Compare{
548		Handler: func(w ResponseWriter, r *Request) {
549			if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
550				t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
551			}
552			w.Header().Set("Content-Encoding", "gzip")
553			gz := gzip.NewWriter(w)
554			io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
555			gz.Close()
556		},
557	}.run(t)
558}
559
560func TestH12_AutoGzip_Disabled(t *testing.T) {
561	h12Compare{
562		Opts: []any{
563			func(tr *Transport) { tr.DisableCompression = true },
564		},
565		Handler: func(w ResponseWriter, r *Request) {
566			fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
567			if ae := r.Header.Get("Accept-Encoding"); ae != "" {
568				t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
569			}
570		},
571	}.run(t)
572}
573
574// Test304Responses verifies that 304s don't declare that they're
575// chunking in their response headers and aren't allowed to produce
576// output.
577func Test304Responses(t *testing.T) { run(t, test304Responses) }
578func test304Responses(t *testing.T, mode testMode) {
579	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
580		w.WriteHeader(StatusNotModified)
581		_, err := w.Write([]byte("illegal body"))
582		if err != ErrBodyNotAllowed {
583			t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
584		}
585	}))
586	defer cst.close()
587	res, err := cst.c.Get(cst.ts.URL)
588	if err != nil {
589		t.Fatal(err)
590	}
591	if len(res.TransferEncoding) > 0 {
592		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
593	}
594	body, err := io.ReadAll(res.Body)
595	if err != nil {
596		t.Error(err)
597	}
598	if len(body) > 0 {
599		t.Errorf("got unexpected body %q", string(body))
600	}
601}
602
603func TestH12_ServerEmptyContentLength(t *testing.T) {
604	h12Compare{
605		Handler: func(w ResponseWriter, r *Request) {
606			w.Header()["Content-Type"] = []string{""}
607			io.WriteString(w, "<html><body>hi</body></html>")
608		},
609	}.run(t)
610}
611
612func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
613	h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
614}
615
616func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
617	h12requestContentLength(t, func() io.Reader { return nil }, 0)
618}
619
620func TestH12_RequestContentLength_Unknown(t *testing.T) {
621	h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
622}
623
624func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
625	h12Compare{
626		Handler: func(w ResponseWriter, r *Request) {
627			w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
628			fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
629		},
630		ReqFunc: func(c *Client, url string) (*Response, error) {
631			return c.Post(url, "text/plain", bodyfn())
632		},
633		CheckResponse: func(proto string, res *Response) {
634			if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
635				t.Errorf("Proto %q got length %q; want %q", proto, got, want)
636			}
637		},
638	}.run(t)
639}
640
641// Tests that closing the Request.Cancel channel also while still
642// reading the response body. Issue 13159.
643func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
644func testCancelRequestMidBody(t *testing.T, mode testMode) {
645	unblock := make(chan bool)
646	didFlush := make(chan bool, 1)
647	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
648		io.WriteString(w, "Hello")
649		w.(Flusher).Flush()
650		didFlush <- true
651		<-unblock
652		io.WriteString(w, ", world.")
653	}))
654	defer close(unblock)
655
656	req, _ := NewRequest("GET", cst.ts.URL, nil)
657	cancel := make(chan struct{})
658	req.Cancel = cancel
659
660	res, err := cst.c.Do(req)
661	if err != nil {
662		t.Fatal(err)
663	}
664	defer res.Body.Close()
665	<-didFlush
666
667	// Read a bit before we cancel. (Issue 13626)
668	// We should have "Hello" at least sitting there.
669	firstRead := make([]byte, 10)
670	n, err := res.Body.Read(firstRead)
671	if err != nil {
672		t.Fatal(err)
673	}
674	firstRead = firstRead[:n]
675
676	close(cancel)
677
678	rest, err := io.ReadAll(res.Body)
679	all := string(firstRead) + string(rest)
680	if all != "Hello" {
681		t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
682	}
683	if err != ExportErrRequestCanceled {
684		t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
685	}
686}
687
688// Tests that clients can send trailers to a server and that the server can read them.
689func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
690func testTrailersClientToServer(t *testing.T, mode testMode) {
691	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
692		var decl []string
693		for k := range r.Trailer {
694			decl = append(decl, k)
695		}
696		slices.Sort(decl)
697
698		slurp, err := io.ReadAll(r.Body)
699		if err != nil {
700			t.Errorf("Server reading request body: %v", err)
701		}
702		if string(slurp) != "foo" {
703			t.Errorf("Server read request body %q; want foo", slurp)
704		}
705		if r.Trailer == nil {
706			io.WriteString(w, "nil Trailer")
707		} else {
708			fmt.Fprintf(w, "decl: %v, vals: %s, %s",
709				decl,
710				r.Trailer.Get("Client-Trailer-A"),
711				r.Trailer.Get("Client-Trailer-B"))
712		}
713	}))
714
715	var req *Request
716	req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
717		eofReaderFunc(func() {
718			req.Trailer["Client-Trailer-A"] = []string{"valuea"}
719		}),
720		strings.NewReader("foo"),
721		eofReaderFunc(func() {
722			req.Trailer["Client-Trailer-B"] = []string{"valueb"}
723		}),
724	))
725	req.Trailer = Header{
726		"Client-Trailer-A": nil, //  to be set later
727		"Client-Trailer-B": nil, //  to be set later
728	}
729	req.ContentLength = -1
730	res, err := cst.c.Do(req)
731	if err != nil {
732		t.Fatal(err)
733	}
734	if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
735		t.Error(err)
736	}
737}
738
739// Tests that servers send trailers to a client and that the client can read them.
740func TestTrailersServerToClient(t *testing.T) {
741	run(t, func(t *testing.T, mode testMode) {
742		testTrailersServerToClient(t, mode, false)
743	})
744}
745func TestTrailersServerToClientFlush(t *testing.T) {
746	run(t, func(t *testing.T, mode testMode) {
747		testTrailersServerToClient(t, mode, true)
748	})
749}
750
751func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
752	const body = "Some body"
753	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
754		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
755		w.Header().Add("Trailer", "Server-Trailer-C")
756
757		io.WriteString(w, body)
758		if flush {
759			w.(Flusher).Flush()
760		}
761
762		// How handlers set Trailers: declare it ahead of time
763		// with the Trailer header, and then mutate the
764		// Header() of those values later, after the response
765		// has been written (we wrote to w above).
766		w.Header().Set("Server-Trailer-A", "valuea")
767		w.Header().Set("Server-Trailer-C", "valuec") // skipping B
768		w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
769	}))
770
771	res, err := cst.c.Get(cst.ts.URL)
772	if err != nil {
773		t.Fatal(err)
774	}
775
776	wantHeader := Header{
777		"Content-Type": {"text/plain; charset=utf-8"},
778	}
779	wantLen := -1
780	if mode == http2Mode && !flush {
781		// In HTTP/1.1, any use of trailers forces HTTP/1.1
782		// chunking and a flush at the first write. That's
783		// unnecessary with HTTP/2's framing, so the server
784		// is able to calculate the length while still sending
785		// trailers afterwards.
786		wantLen = len(body)
787		wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
788	}
789	if res.ContentLength != int64(wantLen) {
790		t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
791	}
792
793	delete(res.Header, "Date") // irrelevant for test
794	if !reflect.DeepEqual(res.Header, wantHeader) {
795		t.Errorf("Header = %v; want %v", res.Header, wantHeader)
796	}
797
798	if got, want := res.Trailer, (Header{
799		"Server-Trailer-A": nil,
800		"Server-Trailer-B": nil,
801		"Server-Trailer-C": nil,
802	}); !reflect.DeepEqual(got, want) {
803		t.Errorf("Trailer before body read = %v; want %v", got, want)
804	}
805
806	if err := wantBody(res, nil, body); err != nil {
807		t.Fatal(err)
808	}
809
810	if got, want := res.Trailer, (Header{
811		"Server-Trailer-A": {"valuea"},
812		"Server-Trailer-B": nil,
813		"Server-Trailer-C": {"valuec"},
814	}); !reflect.DeepEqual(got, want) {
815		t.Errorf("Trailer after body read = %v; want %v", got, want)
816	}
817}
818
819// Don't allow a Body.Read after Body.Close. Issue 13648.
820func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
821func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
822	const body = "Some body"
823	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
824		io.WriteString(w, body)
825	}))
826	res, err := cst.c.Get(cst.ts.URL)
827	if err != nil {
828		t.Fatal(err)
829	}
830	res.Body.Close()
831	data, err := io.ReadAll(res.Body)
832	if len(data) != 0 || err == nil {
833		t.Fatalf("ReadAll returned %q, %v; want error", data, err)
834	}
835}
836
837func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
838func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
839	const reqBody = "some request body"
840	const resBody = "some response body"
841	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
842		var wg sync.WaitGroup
843		wg.Add(2)
844		didRead := make(chan bool, 1)
845		// Read in one goroutine.
846		go func() {
847			defer wg.Done()
848			data, err := io.ReadAll(r.Body)
849			if string(data) != reqBody {
850				t.Errorf("Handler read %q; want %q", data, reqBody)
851			}
852			if err != nil {
853				t.Errorf("Handler Read: %v", err)
854			}
855			didRead <- true
856		}()
857		// Write in another goroutine.
858		go func() {
859			defer wg.Done()
860			if mode != http2Mode {
861				// our HTTP/1 implementation intentionally
862				// doesn't permit writes during read (mostly
863				// due to it being undefined); if that is ever
864				// relaxed, change this.
865				<-didRead
866			}
867			io.WriteString(w, resBody)
868		}()
869		wg.Wait()
870	}))
871	req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
872	req.Header.Add("Expect", "100-continue") // just to complicate things
873	res, err := cst.c.Do(req)
874	if err != nil {
875		t.Fatal(err)
876	}
877	data, err := io.ReadAll(res.Body)
878	defer res.Body.Close()
879	if err != nil {
880		t.Fatal(err)
881	}
882	if string(data) != resBody {
883		t.Errorf("read %q; want %q", data, resBody)
884	}
885}
886
887func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
888func testConnectRequest(t *testing.T, mode testMode) {
889	gotc := make(chan *Request, 1)
890	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
891		gotc <- r
892	}))
893
894	u, err := url.Parse(cst.ts.URL)
895	if err != nil {
896		t.Fatal(err)
897	}
898
899	tests := []struct {
900		req  *Request
901		want string
902	}{
903		{
904			req: &Request{
905				Method: "CONNECT",
906				Header: Header{},
907				URL:    u,
908			},
909			want: u.Host,
910		},
911		{
912			req: &Request{
913				Method: "CONNECT",
914				Header: Header{},
915				URL:    u,
916				Host:   "example.com:123",
917			},
918			want: "example.com:123",
919		},
920	}
921
922	for i, tt := range tests {
923		res, err := cst.c.Do(tt.req)
924		if err != nil {
925			t.Errorf("%d. RoundTrip = %v", i, err)
926			continue
927		}
928		res.Body.Close()
929		req := <-gotc
930		if req.Method != "CONNECT" {
931			t.Errorf("method = %q; want CONNECT", req.Method)
932		}
933		if req.Host != tt.want {
934			t.Errorf("Host = %q; want %q", req.Host, tt.want)
935		}
936		if req.URL.Host != tt.want {
937			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
938		}
939	}
940}
941
942func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
943func testTransportUserAgent(t *testing.T, mode testMode) {
944	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
945		fmt.Fprintf(w, "%q", r.Header["User-Agent"])
946	}))
947
948	either := func(a, b string) string {
949		if mode == http2Mode {
950			return b
951		}
952		return a
953	}
954
955	tests := []struct {
956		setup func(*Request)
957		want  string
958	}{
959		{
960			func(r *Request) {},
961			either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
962		},
963		{
964			func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
965			`["foo/1.2.3"]`,
966		},
967		{
968			func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
969			`["single"]`,
970		},
971		{
972			func(r *Request) { r.Header.Set("User-Agent", "") },
973			`[]`,
974		},
975		{
976			func(r *Request) { r.Header["User-Agent"] = nil },
977			`[]`,
978		},
979	}
980	for i, tt := range tests {
981		req, _ := NewRequest("GET", cst.ts.URL, nil)
982		tt.setup(req)
983		res, err := cst.c.Do(req)
984		if err != nil {
985			t.Errorf("%d. RoundTrip = %v", i, err)
986			continue
987		}
988		slurp, err := io.ReadAll(res.Body)
989		res.Body.Close()
990		if err != nil {
991			t.Errorf("%d. read body = %v", i, err)
992			continue
993		}
994		if string(slurp) != tt.want {
995			t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
996		}
997	}
998}
999
1000func TestStarRequestMethod(t *testing.T) {
1001	for _, method := range []string{"FOO", "OPTIONS"} {
1002		t.Run(method, func(t *testing.T) {
1003			run(t, func(t *testing.T, mode testMode) {
1004				testStarRequest(t, method, mode)
1005			})
1006		})
1007	}
1008}
1009func testStarRequest(t *testing.T, method string, mode testMode) {
1010	gotc := make(chan *Request, 1)
1011	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1012		w.Header().Set("foo", "bar")
1013		gotc <- r
1014		w.(Flusher).Flush()
1015	}))
1016
1017	u, err := url.Parse(cst.ts.URL)
1018	if err != nil {
1019		t.Fatal(err)
1020	}
1021	u.Path = "*"
1022
1023	req := &Request{
1024		Method: method,
1025		Header: Header{},
1026		URL:    u,
1027	}
1028
1029	res, err := cst.c.Do(req)
1030	if err != nil {
1031		t.Fatalf("RoundTrip = %v", err)
1032	}
1033	res.Body.Close()
1034
1035	wantFoo := "bar"
1036	wantLen := int64(-1)
1037	if method == "OPTIONS" {
1038		wantFoo = ""
1039		wantLen = 0
1040	}
1041	if res.StatusCode != 200 {
1042		t.Errorf("status code = %v; want %d", res.Status, 200)
1043	}
1044	if res.ContentLength != wantLen {
1045		t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
1046	}
1047	if got := res.Header.Get("foo"); got != wantFoo {
1048		t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
1049	}
1050	select {
1051	case req = <-gotc:
1052	default:
1053		req = nil
1054	}
1055	if req == nil {
1056		if method != "OPTIONS" {
1057			t.Fatalf("handler never got request")
1058		}
1059		return
1060	}
1061	if req.Method != method {
1062		t.Errorf("method = %q; want %q", req.Method, method)
1063	}
1064	if req.URL.Path != "*" {
1065		t.Errorf("URL.Path = %q; want *", req.URL.Path)
1066	}
1067	if req.RequestURI != "*" {
1068		t.Errorf("RequestURI = %q; want *", req.RequestURI)
1069	}
1070}
1071
1072// Issue 13957
1073func TestTransportDiscardsUnneededConns(t *testing.T) {
1074	run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
1075}
1076func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
1077	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1078		fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
1079	}))
1080	defer cst.close()
1081
1082	var numOpen, numClose int32 // atomic
1083
1084	tlsConfig := &tls.Config{InsecureSkipVerify: true}
1085	tr := &Transport{
1086		TLSClientConfig: tlsConfig,
1087		DialTLS: func(_, addr string) (net.Conn, error) {
1088			time.Sleep(10 * time.Millisecond)
1089			rc, err := net.Dial("tcp", addr)
1090			if err != nil {
1091				return nil, err
1092			}
1093			atomic.AddInt32(&numOpen, 1)
1094			c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
1095			return tls.Client(c, tlsConfig), nil
1096		},
1097	}
1098	if err := ExportHttp2ConfigureTransport(tr); err != nil {
1099		t.Fatal(err)
1100	}
1101	defer tr.CloseIdleConnections()
1102
1103	c := &Client{Transport: tr}
1104
1105	const N = 10
1106	gotBody := make(chan string, N)
1107	var wg sync.WaitGroup
1108	for i := 0; i < N; i++ {
1109		wg.Add(1)
1110		go func() {
1111			defer wg.Done()
1112			resp, err := c.Get(cst.ts.URL)
1113			if err != nil {
1114				// Try to work around spurious connection reset on loaded system.
1115				// See golang.org/issue/33585 and golang.org/issue/36797.
1116				time.Sleep(10 * time.Millisecond)
1117				resp, err = c.Get(cst.ts.URL)
1118				if err != nil {
1119					t.Errorf("Get: %v", err)
1120					return
1121				}
1122			}
1123			defer resp.Body.Close()
1124			slurp, err := io.ReadAll(resp.Body)
1125			if err != nil {
1126				t.Error(err)
1127			}
1128			gotBody <- string(slurp)
1129		}()
1130	}
1131	wg.Wait()
1132	close(gotBody)
1133
1134	var last string
1135	for got := range gotBody {
1136		if last == "" {
1137			last = got
1138			continue
1139		}
1140		if got != last {
1141			t.Errorf("Response body changed: %q -> %q", last, got)
1142		}
1143	}
1144
1145	var open, close int32
1146	for i := 0; i < 150; i++ {
1147		open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
1148		if open < 1 {
1149			t.Fatalf("open = %d; want at least", open)
1150		}
1151		if close == open-1 {
1152			// Success
1153			return
1154		}
1155		time.Sleep(10 * time.Millisecond)
1156	}
1157	t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
1158}
1159
1160// tests that Transport doesn't retain a pointer to the provided request.
1161func TestTransportGCRequest(t *testing.T) {
1162	run(t, func(t *testing.T, mode testMode) {
1163		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
1164		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
1165	})
1166}
1167func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
1168	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1169		io.ReadAll(r.Body)
1170		if body {
1171			io.WriteString(w, "Hello.")
1172		}
1173	}))
1174
1175	didGC := make(chan struct{})
1176	(func() {
1177		body := strings.NewReader("some body")
1178		req, _ := NewRequest("POST", cst.ts.URL, body)
1179		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
1180		res, err := cst.c.Do(req)
1181		if err != nil {
1182			t.Fatal(err)
1183		}
1184		if _, err := io.ReadAll(res.Body); err != nil {
1185			t.Fatal(err)
1186		}
1187		if err := res.Body.Close(); err != nil {
1188			t.Fatal(err)
1189		}
1190	})()
1191	for {
1192		select {
1193		case <-didGC:
1194			return
1195		case <-time.After(1 * time.Millisecond):
1196			runtime.GC()
1197		}
1198	}
1199}
1200
1201func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
1202func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
1203	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1204		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
1205	}), optQuietLog)
1206	cst.tr.DisableKeepAlives = true
1207
1208	tests := []struct {
1209		key, val string
1210		ok       bool
1211	}{
1212		{"Foo", "capital-key", true}, // verify h2 allows capital keys
1213		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
1214		{"Foo", "two\nlines", false}, // \n byte in value not allowed
1215		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
1216		{"A space", "v", false},      // spaces in keys not allowed
1217		{"имя", "v", false},          // key must be ascii
1218		{"name", "валю", true},       // value may be non-ascii
1219		{"", "v", false},             // key must be non-empty
1220		{"k", "", true},              // value may be empty
1221	}
1222	for _, tt := range tests {
1223		dialedc := make(chan bool, 1)
1224		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
1225			dialedc <- true
1226			return net.Dial(netw, addr)
1227		}
1228		req, _ := NewRequest("GET", cst.ts.URL, nil)
1229		req.Header[tt.key] = []string{tt.val}
1230		res, err := cst.c.Do(req)
1231		var body []byte
1232		if err == nil {
1233			body, _ = io.ReadAll(res.Body)
1234			res.Body.Close()
1235		}
1236		var dialed bool
1237		select {
1238		case <-dialedc:
1239			dialed = true
1240		default:
1241		}
1242
1243		if !tt.ok && dialed {
1244			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
1245		} else if (err == nil) != tt.ok {
1246			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
1247		}
1248	}
1249}
1250
1251func TestInterruptWithPanic(t *testing.T) {
1252	run(t, func(t *testing.T, mode testMode) {
1253		t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
1254		t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
1255		t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
1256	}, testNotParallel)
1257}
1258func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
1259	const msg = "hello"
1260
1261	testDone := make(chan struct{})
1262	defer close(testDone)
1263
1264	var errorLog lockedBytesBuffer
1265	gotHeaders := make(chan bool, 1)
1266	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1267		io.WriteString(w, msg)
1268		w.(Flusher).Flush()
1269
1270		select {
1271		case <-gotHeaders:
1272		case <-testDone:
1273		}
1274		panic(panicValue)
1275	}), func(ts *httptest.Server) {
1276		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1277	})
1278	res, err := cst.c.Get(cst.ts.URL)
1279	if err != nil {
1280		t.Fatal(err)
1281	}
1282	gotHeaders <- true
1283	defer res.Body.Close()
1284	slurp, err := io.ReadAll(res.Body)
1285	if string(slurp) != msg {
1286		t.Errorf("client read %q; want %q", slurp, msg)
1287	}
1288	if err == nil {
1289		t.Errorf("client read all successfully; want some error")
1290	}
1291	logOutput := func() string {
1292		errorLog.Lock()
1293		defer errorLog.Unlock()
1294		return errorLog.String()
1295	}
1296	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
1297
1298	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
1299		gotLog := logOutput()
1300		if !wantStackLogged {
1301			if gotLog == "" {
1302				return true
1303			}
1304			t.Fatalf("want no log output; got: %s", gotLog)
1305		}
1306		if gotLog == "" {
1307			if d > 0 {
1308				t.Logf("wanted a stack trace logged; got nothing after %v", d)
1309			}
1310			return false
1311		}
1312		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
1313			if d > 0 {
1314				t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
1315			}
1316			return false
1317		}
1318		return true
1319	})
1320}
1321
// lockedBytesBuffer is a bytes.Buffer paired with a mutex. Only Write takes
// the lock itself; callers using the other promoted Buffer methods must call
// Lock and Unlock on their own.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex.
func (lb *lockedBytesBuffer) Write(p []byte) (n int, err error) {
	lb.Mutex.Lock()
	defer lb.Mutex.Unlock()
	n, err = lb.Buffer.Write(p)
	return n, err
}
1332
1333// Issue 15366
1334func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
1335	h12Compare{
1336		Handler: func(w ResponseWriter, r *Request) {
1337			h := w.Header()
1338			h.Set("Content-Encoding", "gzip")
1339			h.Set("Content-Length", "23")
1340			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
1341		},
1342		EarlyCheckResponse: func(proto string, res *Response) {
1343			if !res.Uncompressed {
1344				t.Errorf("%s: expected Uncompressed to be set", proto)
1345			}
1346			dump, err := httputil.DumpResponse(res, true)
1347			if err != nil {
1348				t.Errorf("%s: DumpResponse: %v", proto, err)
1349				return
1350			}
1351			if strings.Contains(string(dump), "Connection: close") {
1352				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
1353			}
1354			if !strings.Contains(string(dump), "FOO") {
1355				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
1356			}
1357		},
1358	}.run(t)
1359}
1360
1361// Issue 14607
1362func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
1363func testCloseIdleConnections(t *testing.T, mode testMode) {
1364	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1365		w.Header().Set("X-Addr", r.RemoteAddr)
1366	}))
1367	get := func() string {
1368		res, err := cst.c.Get(cst.ts.URL)
1369		if err != nil {
1370			t.Fatal(err)
1371		}
1372		res.Body.Close()
1373		v := res.Header.Get("X-Addr")
1374		if v == "" {
1375			t.Fatal("didn't get X-Addr")
1376		}
1377		return v
1378	}
1379	a1 := get()
1380	cst.tr.CloseIdleConnections()
1381	a2 := get()
1382	if a1 == a2 {
1383		t.Errorf("didn't close connection")
1384	}
1385}
1386
// noteCloseConn wraps a net.Conn and invokes closeFunc every time Close is
// called on it.
type noteCloseConn struct {
	net.Conn
	closeFunc func()
}

// Close runs the closeFunc hook first and then closes the wrapped
// connection, returning that connection's Close error.
func (nc noteCloseConn) Close() error {
	nc.closeFunc()
	err := nc.Conn.Close()
	return err
}
1396
// testErrorReader is an io.Reader that must never be read from: any Read
// call is reported as an error on t.
type testErrorReader struct{ t *testing.T }

// Read records a test error and reports end of input without producing data.
func (er testErrorReader) Read(p []byte) (int, error) {
	er.t.Error("unexpected Read call")
	return 0, io.EOF
}
1403
1404func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
1405func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
1406	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1407		w.WriteHeader(StatusUnauthorized)
1408	}))
1409
1410	// Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
1411	cst.tr.ExpectContinueTimeout = 10 * time.Second
1412
1413	req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
1414	if err != nil {
1415		t.Fatal(err)
1416	}
1417	req.ContentLength = 0 // so transport is tempted to sniff it
1418	req.Header.Set("Expect", "100-continue")
1419	res, err := cst.tr.RoundTrip(req)
1420	if err != nil {
1421		t.Fatal(err)
1422	}
1423	defer res.Body.Close()
1424	if res.StatusCode != StatusUnauthorized {
1425		t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
1426	}
1427}
1428
1429func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
1430func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
1431	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1432		w.Header().Set("Foo", "Bar")
1433		w.Header().Set("Trailer:Foo", "Baz")
1434		w.(Flusher).Flush()
1435		w.Header().Add("Trailer:Foo", "Baz2")
1436		w.Header().Set("Trailer:Bar", "Quux")
1437	}))
1438	res, err := cst.c.Get(cst.ts.URL)
1439	if err != nil {
1440		t.Fatal(err)
1441	}
1442	if _, err := io.Copy(io.Discard, res.Body); err != nil {
1443		t.Fatal(err)
1444	}
1445	res.Body.Close()
1446	delete(res.Header, "Date")
1447	delete(res.Header, "Content-Type")
1448
1449	if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
1450		t.Errorf("Header = %#v; want %#v", res.Header, want)
1451	}
1452	if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
1453		t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
1454	}
1455}
1456
1457func TestBadResponseAfterReadingBody(t *testing.T) {
1458	run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
1459}
1460func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
1461	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1462		_, err := io.Copy(io.Discard, r.Body)
1463		if err != nil {
1464			t.Fatal(err)
1465		}
1466		c, _, err := w.(Hijacker).Hijack()
1467		if err != nil {
1468			t.Fatal(err)
1469		}
1470		defer c.Close()
1471		fmt.Fprintln(c, "some bogus crap")
1472	}))
1473
1474	closes := 0
1475	res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
1476	if err == nil {
1477		res.Body.Close()
1478		t.Fatal("expected an error to be returned from Post")
1479	}
1480	if closes != 1 {
1481		t.Errorf("closes = %d; want 1", closes)
1482	}
1483}
1484
1485func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
1486func testWriteHeader0(t *testing.T, mode testMode) {
1487	gotpanic := make(chan bool, 1)
1488	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1489		defer close(gotpanic)
1490		defer func() {
1491			if e := recover(); e != nil {
1492				got := fmt.Sprintf("%T, %v", e, e)
1493				want := "string, invalid WriteHeader code 0"
1494				if got != want {
1495					t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
1496				}
1497				gotpanic <- true
1498
1499				// Set an explicit 503. This also tests that the WriteHeader call panics
1500				// before it recorded that an explicit value was set and that bogus
1501				// value wasn't stuck.
1502				w.WriteHeader(503)
1503			}
1504		}()
1505		w.WriteHeader(0)
1506	}))
1507	res, err := cst.c.Get(cst.ts.URL)
1508	if err != nil {
1509		t.Fatal(err)
1510	}
1511	if res.StatusCode != 503 {
1512		t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
1513	}
1514	if !<-gotpanic {
1515		t.Error("expected panic in handler")
1516	}
1517}
1518
1519// Issue 23010: don't be super strict checking WriteHeader's code if
1520// it's not even valid to call WriteHeader then anyway.
1521func TestWriteHeaderNoCodeCheck(t *testing.T) {
1522	run(t, func(t *testing.T, mode testMode) {
1523		testWriteHeaderAfterWrite(t, mode, false)
1524	})
1525}
1526func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
1527	testWriteHeaderAfterWrite(t, http1Mode, true)
1528}
1529func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
1530	var errorLog lockedBytesBuffer
1531	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1532		if hijack {
1533			conn, _, _ := w.(Hijacker).Hijack()
1534			defer conn.Close()
1535			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
1536			w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
1537			conn.Write([]byte("bar"))
1538			return
1539		}
1540		io.WriteString(w, "foo")
1541		w.(Flusher).Flush()
1542		w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
1543		io.WriteString(w, "bar")
1544	}), func(ts *httptest.Server) {
1545		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
1546	})
1547	res, err := cst.c.Get(cst.ts.URL)
1548	if err != nil {
1549		t.Fatal(err)
1550	}
1551	defer res.Body.Close()
1552	body, err := io.ReadAll(res.Body)
1553	if err != nil {
1554		t.Fatal(err)
1555	}
1556	if got, want := string(body), "foobar"; got != want {
1557		t.Errorf("got = %q; want %q", got, want)
1558	}
1559
1560	// Also check the stderr output:
1561	if mode == http2Mode {
1562		// TODO: also emit this log message for HTTP/2?
1563		// We historically haven't, so don't check.
1564		return
1565	}
1566	gotLog := strings.TrimSpace(errorLog.String())
1567	wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1568	if hijack {
1569		wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
1570	}
1571	if !strings.HasPrefix(gotLog, wantLog) {
1572		t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
1573	}
1574}
1575
1576func TestBidiStreamReverseProxy(t *testing.T) {
1577	run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
1578}
1579func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
1580	backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1581		if _, err := io.Copy(w, r.Body); err != nil {
1582			log.Printf("bidi backend copy: %v", err)
1583		}
1584	}))
1585
1586	backURL, err := url.Parse(backend.ts.URL)
1587	if err != nil {
1588		t.Fatal(err)
1589	}
1590	rp := httputil.NewSingleHostReverseProxy(backURL)
1591	rp.Transport = backend.tr
1592	proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1593		rp.ServeHTTP(w, r)
1594	}))
1595
1596	bodyRes := make(chan any, 1) // error or hash.Hash
1597	pr, pw := io.Pipe()
1598	req, _ := NewRequest("PUT", proxy.ts.URL, pr)
1599	const size = 4 << 20
1600	go func() {
1601		h := sha1.New()
1602		_, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
1603		go pw.Close()
1604		if err != nil {
1605			bodyRes <- err
1606		} else {
1607			bodyRes <- h
1608		}
1609	}()
1610	res, err := backend.c.Do(req)
1611	if err != nil {
1612		t.Fatal(err)
1613	}
1614	defer res.Body.Close()
1615	hgot := sha1.New()
1616	n, err := io.Copy(hgot, res.Body)
1617	if err != nil {
1618		t.Fatal(err)
1619	}
1620	if n != size {
1621		t.Fatalf("got %d bytes; want %d", n, size)
1622	}
1623	select {
1624	case v := <-bodyRes:
1625		switch v := v.(type) {
1626		default:
1627			t.Fatalf("body copy: %v", err)
1628		case hash.Hash:
1629			if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
1630				t.Errorf("written bytes didn't match received bytes")
1631			}
1632		}
1633	case <-time.After(10 * time.Second):
1634		t.Fatal("timeout")
1635	}
1636
1637}
1638
1639// Always use HTTP/1.1 for WebSocket upgrades.
1640func TestH12_WebSocketUpgrade(t *testing.T) {
1641	h12Compare{
1642		Handler: func(w ResponseWriter, r *Request) {
1643			h := w.Header()
1644			h.Set("Foo", "bar")
1645		},
1646		ReqFunc: func(c *Client, url string) (*Response, error) {
1647			req, _ := NewRequest("GET", url, nil)
1648			req.Header.Set("Connection", "Upgrade")
1649			req.Header.Set("Upgrade", "WebSocket")
1650			return c.Do(req)
1651		},
1652		EarlyCheckResponse: func(proto string, res *Response) {
1653			if res.Proto != "HTTP/1.1" {
1654				t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
1655			}
1656			res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
1657		},
1658	}.run(t)
1659}
1660
1661func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
1662func testIdentityTransferEncoding(t *testing.T, mode testMode) {
1663	const body = "body"
1664	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1665		gotBody, _ := io.ReadAll(r.Body)
1666		if got, want := string(gotBody), body; got != want {
1667			t.Errorf("got request body = %q; want %q", got, want)
1668		}
1669		w.Header().Set("Transfer-Encoding", "identity")
1670		w.WriteHeader(StatusOK)
1671		w.(Flusher).Flush()
1672		io.WriteString(w, body)
1673	}))
1674	req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
1675	res, err := cst.c.Do(req)
1676	if err != nil {
1677		t.Fatal(err)
1678	}
1679	defer res.Body.Close()
1680	gotBody, err := io.ReadAll(res.Body)
1681	if err != nil {
1682		t.Fatal(err)
1683	}
1684	if got, want := string(gotBody), body; got != want {
1685		t.Errorf("got response body = %q; want %q", got, want)
1686	}
1687}
1688
1689func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
1690func testEarlyHintsRequest(t *testing.T, mode testMode) {
1691	var wg sync.WaitGroup
1692	wg.Add(1)
1693	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1694		h := w.Header()
1695
1696		h.Add("Content-Length", "123") // must be ignored
1697		h.Add("Link", "</style.css>; rel=preload; as=style")
1698		h.Add("Link", "</script.js>; rel=preload; as=script")
1699		w.WriteHeader(StatusEarlyHints)
1700
1701		wg.Wait()
1702
1703		h.Add("Link", "</foo.js>; rel=preload; as=script")
1704		w.WriteHeader(StatusEarlyHints)
1705
1706		w.Write([]byte("Hello"))
1707	}))
1708
1709	checkLinkHeaders := func(t *testing.T, expected, got []string) {
1710		t.Helper()
1711
1712		if len(expected) != len(got) {
1713			t.Errorf("got %d expected %d", len(got), len(expected))
1714		}
1715
1716		for i := range expected {
1717			if expected[i] != got[i] {
1718				t.Errorf("got %q expected %q", got[i], expected[i])
1719			}
1720		}
1721	}
1722
1723	checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
1724		t.Helper()
1725
1726		for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
1727			if v, ok := header[h]; ok {
1728				t.Errorf("%s is %q; must not be sent", h, v)
1729			}
1730		}
1731	}
1732
1733	var respCounter uint8
1734	trace := &httptrace.ClientTrace{
1735		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
1736			switch respCounter {
1737			case 0:
1738				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
1739				checkExcludedHeaders(t, header)
1740
1741				wg.Done()
1742			case 1:
1743				checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
1744				checkExcludedHeaders(t, header)
1745
1746			default:
1747				t.Error("Unexpected 1xx response")
1748			}
1749
1750			respCounter++
1751
1752			return nil
1753		},
1754	}
1755	req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
1756
1757	res, err := cst.c.Do(req)
1758	if err != nil {
1759		t.Fatal(err)
1760	}
1761	defer res.Body.Close()
1762
1763	checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
1764	if cl := res.Header.Get("Content-Length"); cl != "123" {
1765		t.Errorf("Content-Length is %q; want 123", cl)
1766	}
1767
1768	body, _ := io.ReadAll(res.Body)
1769	if string(body) != "Hello" {
1770		t.Errorf("Read body %q; want Hello", body)
1771	}
1772}
1773