1// Copyright 2015 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode. 6 7package http_test 8 9import ( 10 "bytes" 11 "compress/gzip" 12 "context" 13 "crypto/rand" 14 "crypto/sha1" 15 "crypto/tls" 16 "fmt" 17 "hash" 18 "io" 19 "log" 20 "net" 21 . "net/http" 22 "net/http/httptest" 23 "net/http/httptrace" 24 "net/http/httputil" 25 "net/textproto" 26 "net/url" 27 "os" 28 "reflect" 29 "runtime" 30 "slices" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "testing" 35 "time" 36) 37 38type testMode string 39 40const ( 41 http1Mode = testMode("h1") // HTTP/1.1 42 https1Mode = testMode("https1") // HTTPS/1.1 43 http2Mode = testMode("h2") // HTTP/2 44) 45 46type testNotParallelOpt struct{} 47 48var ( 49 testNotParallel = testNotParallelOpt{} 50) 51 52type TBRun[T any] interface { 53 testing.TB 54 Run(string, func(T)) bool 55} 56 57// run runs a client/server test in a variety of test configurations. 58// 59// Tests execute in HTTP/1.1 and HTTP/2 modes by default. 60// To run in a different set of configurations, pass a []testMode option. 61// 62// Tests call t.Parallel() by default. 63// To disable parallel execution, pass the testNotParallel option. 
64func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { 65 t.Helper() 66 modes := []testMode{http1Mode, http2Mode} 67 parallel := true 68 for _, opt := range opts { 69 switch opt := opt.(type) { 70 case []testMode: 71 modes = opt 72 case testNotParallelOpt: 73 parallel = false 74 default: 75 t.Fatalf("unknown option type %T", opt) 76 } 77 } 78 if t, ok := any(t).(*testing.T); ok && parallel { 79 setParallel(t) 80 } 81 for _, mode := range modes { 82 t.Run(string(mode), func(t T) { 83 t.Helper() 84 if t, ok := any(t).(*testing.T); ok && parallel { 85 setParallel(t) 86 } 87 t.Cleanup(func() { 88 afterTest(t) 89 }) 90 f(t, mode) 91 }) 92 } 93} 94 95type clientServerTest struct { 96 t testing.TB 97 h2 bool 98 h Handler 99 ts *httptest.Server 100 tr *Transport 101 c *Client 102} 103 104func (t *clientServerTest) close() { 105 t.tr.CloseIdleConnections() 106 t.ts.Close() 107} 108 109func (t *clientServerTest) getURL(u string) string { 110 res, err := t.c.Get(u) 111 if err != nil { 112 t.t.Fatal(err) 113 } 114 defer res.Body.Close() 115 slurp, err := io.ReadAll(res.Body) 116 if err != nil { 117 t.t.Fatal(err) 118 } 119 return string(slurp) 120} 121 122func (t *clientServerTest) scheme() string { 123 if t.h2 { 124 return "https" 125 } 126 return "http" 127} 128 129var optQuietLog = func(ts *httptest.Server) { 130 ts.Config.ErrorLog = quietLog 131} 132 133func optWithServerLog(lg *log.Logger) func(*httptest.Server) { 134 return func(ts *httptest.Server) { 135 ts.Config.ErrorLog = lg 136 } 137} 138 139// newClientServerTest creates and starts an httptest.Server. 140// 141// The mode parameter selects the implementation to test: 142// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use 143// the 'run' function, which will start a subtests for each tested mode. 144// 145// The vararg opts parameter can include functions to configure the 146// test server or transport. 
147// 148// func(*httptest.Server) // run before starting the server 149// func(*http.Transport) 150func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { 151 if mode == http2Mode { 152 CondSkipHTTP2(t) 153 } 154 cst := &clientServerTest{ 155 t: t, 156 h2: mode == http2Mode, 157 h: h, 158 } 159 cst.ts = httptest.NewUnstartedServer(h) 160 161 var transportFuncs []func(*Transport) 162 for _, opt := range opts { 163 switch opt := opt.(type) { 164 case func(*Transport): 165 transportFuncs = append(transportFuncs, opt) 166 case func(*httptest.Server): 167 opt(cst.ts) 168 default: 169 t.Fatalf("unhandled option type %T", opt) 170 } 171 } 172 173 if cst.ts.Config.ErrorLog == nil { 174 cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0) 175 } 176 177 switch mode { 178 case http1Mode: 179 cst.ts.Start() 180 case https1Mode: 181 cst.ts.StartTLS() 182 case http2Mode: 183 ExportHttp2ConfigureServer(cst.ts.Config, nil) 184 cst.ts.TLS = cst.ts.Config.TLSConfig 185 cst.ts.StartTLS() 186 default: 187 t.Fatalf("unknown test mode %v", mode) 188 } 189 cst.c = cst.ts.Client() 190 cst.tr = cst.c.Transport.(*Transport) 191 if mode == http2Mode { 192 if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { 193 t.Fatal(err) 194 } 195 } 196 for _, f := range transportFuncs { 197 f(cst.tr) 198 } 199 t.Cleanup(func() { 200 cst.close() 201 }) 202 return cst 203} 204 205type testLogWriter struct { 206 t testing.TB 207} 208 209func (w testLogWriter) Write(b []byte) (int, error) { 210 w.t.Logf("server log: %v", strings.TrimSpace(string(b))) 211 return len(b), nil 212} 213 214// Testing the newClientServerTest helper itself. 
215func TestNewClientServerTest(t *testing.T) { 216 run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) 217} 218func testNewClientServerTest(t *testing.T, mode testMode) { 219 var got struct { 220 sync.Mutex 221 proto string 222 hasTLS bool 223 } 224 h := HandlerFunc(func(w ResponseWriter, r *Request) { 225 got.Lock() 226 defer got.Unlock() 227 got.proto = r.Proto 228 got.hasTLS = r.TLS != nil 229 }) 230 cst := newClientServerTest(t, mode, h) 231 if _, err := cst.c.Head(cst.ts.URL); err != nil { 232 t.Fatal(err) 233 } 234 var wantProto string 235 var wantTLS bool 236 switch mode { 237 case http1Mode: 238 wantProto = "HTTP/1.1" 239 wantTLS = false 240 case https1Mode: 241 wantProto = "HTTP/1.1" 242 wantTLS = true 243 case http2Mode: 244 wantProto = "HTTP/2.0" 245 wantTLS = true 246 } 247 if got.proto != wantProto { 248 t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) 249 } 250 if got.hasTLS != wantTLS { 251 t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) 252 } 253} 254 255func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } 256func testChunkedResponseHeaders(t *testing.T, mode testMode) { 257 log.SetOutput(io.Discard) // is noisy otherwise 258 defer log.SetOutput(os.Stderr) 259 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 260 w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted 261 w.(Flusher).Flush() 262 fmt.Fprintf(w, "I am a chunked response.") 263 })) 264 265 res, err := cst.c.Get(cst.ts.URL) 266 if err != nil { 267 t.Fatalf("Get error: %v", err) 268 } 269 defer res.Body.Close() 270 if g, e := res.ContentLength, int64(-1); g != e { 271 t.Errorf("expected ContentLength of %d; got %d", e, g) 272 } 273 wantTE := []string{"chunked"} 274 if mode == http2Mode { 275 wantTE = nil 276 } 277 if !reflect.DeepEqual(res.TransferEncoding, wantTE) { 278 t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE) 
279 } 280 if got, haveCL := res.Header["Content-Length"]; haveCL { 281 t.Errorf("Unexpected Content-Length: %q", got) 282 } 283} 284 285type reqFunc func(c *Client, url string) (*Response, error) 286 287// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior 288// against each other. 289type h12Compare struct { 290 Handler func(ResponseWriter, *Request) // required 291 ReqFunc reqFunc // optional 292 CheckResponse func(proto string, res *Response) // optional 293 EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize 294 Opts []any 295} 296 297func (tt h12Compare) reqFunc() reqFunc { 298 if tt.ReqFunc == nil { 299 return (*Client).Get 300 } 301 return tt.ReqFunc 302} 303 304func (tt h12Compare) run(t *testing.T) { 305 setParallel(t) 306 cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) 307 defer cst1.close() 308 cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) 309 defer cst2.close() 310 311 res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) 312 if err != nil { 313 t.Errorf("HTTP/1 request: %v", err) 314 return 315 } 316 res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL) 317 if err != nil { 318 t.Errorf("HTTP/2 request: %v", err) 319 return 320 } 321 322 if fn := tt.EarlyCheckResponse; fn != nil { 323 fn("HTTP/1.1", res1) 324 fn("HTTP/2.0", res2) 325 } 326 327 tt.normalizeRes(t, res1, "HTTP/1.1") 328 tt.normalizeRes(t, res2, "HTTP/2.0") 329 res1body, res2body := res1.Body, res2.Body 330 331 eres1 := mostlyCopy(res1) 332 eres2 := mostlyCopy(res2) 333 if !reflect.DeepEqual(eres1, eres2) { 334 t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v", 335 cst1.ts.URL, eres1, cst2.ts.URL, eres2) 336 } 337 if !reflect.DeepEqual(res1body, res2body) { 338 t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body) 339 } 340 if fn := tt.CheckResponse; fn != nil { 341 res1.Body, res2.Body = res1body, res2body 342 
fn("HTTP/1.1", res1) 343 fn("HTTP/2.0", res2) 344 } 345} 346 347func mostlyCopy(r *Response) *Response { 348 c := *r 349 c.Body = nil 350 c.TransferEncoding = nil 351 c.TLS = nil 352 c.Request = nil 353 return &c 354} 355 356type slurpResult struct { 357 io.ReadCloser 358 body []byte 359 err error 360} 361 362func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) } 363 364func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) { 365 if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" { 366 res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0 367 } else { 368 t.Errorf("got %q response; want %q", res.Proto, wantProto) 369 } 370 slurp, err := io.ReadAll(res.Body) 371 372 res.Body.Close() 373 res.Body = slurpResult{ 374 ReadCloser: io.NopCloser(bytes.NewReader(slurp)), 375 body: slurp, 376 err: err, 377 } 378 for i, v := range res.Header["Date"] { 379 res.Header["Date"][i] = strings.Repeat("x", len(v)) 380 } 381 if res.Request == nil { 382 t.Errorf("for %s, no request", wantProto) 383 } 384 if (res.TLS != nil) != (wantProto == "HTTP/2.0") { 385 t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil) 386 } 387} 388 389// Issue 13532 390func TestH12_HeadContentLengthNoBody(t *testing.T) { 391 h12Compare{ 392 ReqFunc: (*Client).Head, 393 Handler: func(w ResponseWriter, r *Request) { 394 }, 395 }.run(t) 396} 397 398func TestH12_HeadContentLengthSmallBody(t *testing.T) { 399 h12Compare{ 400 ReqFunc: (*Client).Head, 401 Handler: func(w ResponseWriter, r *Request) { 402 io.WriteString(w, "small") 403 }, 404 }.run(t) 405} 406 407func TestH12_HeadContentLengthLargeBody(t *testing.T) { 408 h12Compare{ 409 ReqFunc: (*Client).Head, 410 Handler: func(w ResponseWriter, r *Request) { 411 chunk := strings.Repeat("x", 512<<10) 412 for i := 0; i < 10; i++ { 413 io.WriteString(w, chunk) 414 } 415 }, 416 }.run(t) 417} 418 419func TestH12_200NoBody(t *testing.T) { 420 h12Compare{Handler: func(w ResponseWriter, r 
*Request) {}}.run(t) 421} 422 423func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) } 424func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) } 425func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) } 426 427func testH12_noBody(t *testing.T, status int) { 428 h12Compare{Handler: func(w ResponseWriter, r *Request) { 429 w.WriteHeader(status) 430 }}.run(t) 431} 432 433func TestH12_SmallBody(t *testing.T) { 434 h12Compare{Handler: func(w ResponseWriter, r *Request) { 435 io.WriteString(w, "small body") 436 }}.run(t) 437} 438 439func TestH12_ExplicitContentLength(t *testing.T) { 440 h12Compare{Handler: func(w ResponseWriter, r *Request) { 441 w.Header().Set("Content-Length", "3") 442 io.WriteString(w, "foo") 443 }}.run(t) 444} 445 446func TestH12_FlushBeforeBody(t *testing.T) { 447 h12Compare{Handler: func(w ResponseWriter, r *Request) { 448 w.(Flusher).Flush() 449 io.WriteString(w, "foo") 450 }}.run(t) 451} 452 453func TestH12_FlushMidBody(t *testing.T) { 454 h12Compare{Handler: func(w ResponseWriter, r *Request) { 455 io.WriteString(w, "foo") 456 w.(Flusher).Flush() 457 io.WriteString(w, "bar") 458 }}.run(t) 459} 460 461func TestH12_Head_ExplicitLen(t *testing.T) { 462 h12Compare{ 463 ReqFunc: (*Client).Head, 464 Handler: func(w ResponseWriter, r *Request) { 465 if r.Method != "HEAD" { 466 t.Errorf("unexpected method %q", r.Method) 467 } 468 w.Header().Set("Content-Length", "1235") 469 }, 470 }.run(t) 471} 472 473func TestH12_Head_ImplicitLen(t *testing.T) { 474 h12Compare{ 475 ReqFunc: (*Client).Head, 476 Handler: func(w ResponseWriter, r *Request) { 477 if r.Method != "HEAD" { 478 t.Errorf("unexpected method %q", r.Method) 479 } 480 io.WriteString(w, "foo") 481 }, 482 }.run(t) 483} 484 485func TestH12_HandlerWritesTooLittle(t *testing.T) { 486 h12Compare{ 487 Handler: func(w ResponseWriter, r *Request) { 488 w.Header().Set("Content-Length", "3") 489 io.WriteString(w, "12") // one byte short 490 }, 491 CheckResponse: func(proto string, res 
*Response) { 492 sr, ok := res.Body.(slurpResult) 493 if !ok { 494 t.Errorf("%s body is %T; want slurpResult", proto, res.Body) 495 return 496 } 497 if sr.err != io.ErrUnexpectedEOF { 498 t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err) 499 } 500 if string(sr.body) != "12" { 501 t.Errorf("%s body = %q; want %q", proto, sr.body, "12") 502 } 503 }, 504 }.run(t) 505} 506 507// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from 508// writing more than they declared. This test does not test whether 509// the transport deals with too much data, though, since the server 510// doesn't make it possible to send bogus data. For those tests, see 511// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go 512// (for HTTP/2). 513func TestHandlerWritesTooMuch(t *testing.T) { run(t, testHandlerWritesTooMuch) } 514func testHandlerWritesTooMuch(t *testing.T, mode testMode) { 515 wantBody := []byte("123") 516 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 517 rc := NewResponseController(w) 518 w.Header().Set("Content-Length", fmt.Sprintf("%v", len(wantBody))) 519 rc.Flush() 520 w.Write(wantBody) 521 rc.Flush() 522 n, err := io.WriteString(w, "x") // too many 523 if err == nil { 524 err = rc.Flush() 525 } 526 // TODO: Check that this is ErrContentLength, not just any error. 527 if err == nil { 528 t.Errorf("for proto %q, final write = %v, %v; want _, some error", r.Proto, n, err) 529 } 530 })) 531 532 res, err := cst.c.Get(cst.ts.URL) 533 if err != nil { 534 t.Fatal(err) 535 } 536 defer res.Body.Close() 537 538 gotBody, _ := io.ReadAll(res.Body) 539 if !bytes.Equal(gotBody, wantBody) { 540 t.Fatalf("got response body: %q; want %q", gotBody, wantBody) 541 } 542} 543 544// Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip. 
545// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298 546func TestH12_AutoGzip(t *testing.T) { 547 h12Compare{ 548 Handler: func(w ResponseWriter, r *Request) { 549 if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" { 550 t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae) 551 } 552 w.Header().Set("Content-Encoding", "gzip") 553 gz := gzip.NewWriter(w) 554 io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.") 555 gz.Close() 556 }, 557 }.run(t) 558} 559 560func TestH12_AutoGzip_Disabled(t *testing.T) { 561 h12Compare{ 562 Opts: []any{ 563 func(tr *Transport) { tr.DisableCompression = true }, 564 }, 565 Handler: func(w ResponseWriter, r *Request) { 566 fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"]) 567 if ae := r.Header.Get("Accept-Encoding"); ae != "" { 568 t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae) 569 } 570 }, 571 }.run(t) 572} 573 574// Test304Responses verifies that 304s don't declare that they're 575// chunking in their response headers and aren't allowed to produce 576// output. 
577func Test304Responses(t *testing.T) { run(t, test304Responses) } 578func test304Responses(t *testing.T, mode testMode) { 579 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 580 w.WriteHeader(StatusNotModified) 581 _, err := w.Write([]byte("illegal body")) 582 if err != ErrBodyNotAllowed { 583 t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err) 584 } 585 })) 586 defer cst.close() 587 res, err := cst.c.Get(cst.ts.URL) 588 if err != nil { 589 t.Fatal(err) 590 } 591 if len(res.TransferEncoding) > 0 { 592 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) 593 } 594 body, err := io.ReadAll(res.Body) 595 if err != nil { 596 t.Error(err) 597 } 598 if len(body) > 0 { 599 t.Errorf("got unexpected body %q", string(body)) 600 } 601} 602 603func TestH12_ServerEmptyContentLength(t *testing.T) { 604 h12Compare{ 605 Handler: func(w ResponseWriter, r *Request) { 606 w.Header()["Content-Type"] = []string{""} 607 io.WriteString(w, "<html><body>hi</body></html>") 608 }, 609 }.run(t) 610} 611 612func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { 613 h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4) 614} 615 616func TestH12_RequestContentLength_Known_Zero(t *testing.T) { 617 h12requestContentLength(t, func() io.Reader { return nil }, 0) 618} 619 620func TestH12_RequestContentLength_Unknown(t *testing.T) { 621 h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1) 622} 623 624func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) { 625 h12Compare{ 626 Handler: func(w ResponseWriter, r *Request) { 627 w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength)) 628 fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength) 629 }, 630 ReqFunc: func(c *Client, url string) (*Response, error) { 631 return c.Post(url, "text/plain", bodyfn()) 632 }, 633 CheckResponse: func(proto string, res *Response) { 
634 if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want { 635 t.Errorf("Proto %q got length %q; want %q", proto, got, want) 636 } 637 }, 638 }.run(t) 639} 640 641// Tests that closing the Request.Cancel channel also while still 642// reading the response body. Issue 13159. 643func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } 644func testCancelRequestMidBody(t *testing.T, mode testMode) { 645 unblock := make(chan bool) 646 didFlush := make(chan bool, 1) 647 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 648 io.WriteString(w, "Hello") 649 w.(Flusher).Flush() 650 didFlush <- true 651 <-unblock 652 io.WriteString(w, ", world.") 653 })) 654 defer close(unblock) 655 656 req, _ := NewRequest("GET", cst.ts.URL, nil) 657 cancel := make(chan struct{}) 658 req.Cancel = cancel 659 660 res, err := cst.c.Do(req) 661 if err != nil { 662 t.Fatal(err) 663 } 664 defer res.Body.Close() 665 <-didFlush 666 667 // Read a bit before we cancel. (Issue 13626) 668 // We should have "Hello" at least sitting there. 669 firstRead := make([]byte, 10) 670 n, err := res.Body.Read(firstRead) 671 if err != nil { 672 t.Fatal(err) 673 } 674 firstRead = firstRead[:n] 675 676 close(cancel) 677 678 rest, err := io.ReadAll(res.Body) 679 all := string(firstRead) + string(rest) 680 if all != "Hello" { 681 t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) 682 } 683 if err != ExportErrRequestCanceled { 684 t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) 685 } 686} 687 688// Tests that clients can send trailers to a server and that the server can read them. 
689func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } 690func testTrailersClientToServer(t *testing.T, mode testMode) { 691 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 692 var decl []string 693 for k := range r.Trailer { 694 decl = append(decl, k) 695 } 696 slices.Sort(decl) 697 698 slurp, err := io.ReadAll(r.Body) 699 if err != nil { 700 t.Errorf("Server reading request body: %v", err) 701 } 702 if string(slurp) != "foo" { 703 t.Errorf("Server read request body %q; want foo", slurp) 704 } 705 if r.Trailer == nil { 706 io.WriteString(w, "nil Trailer") 707 } else { 708 fmt.Fprintf(w, "decl: %v, vals: %s, %s", 709 decl, 710 r.Trailer.Get("Client-Trailer-A"), 711 r.Trailer.Get("Client-Trailer-B")) 712 } 713 })) 714 715 var req *Request 716 req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( 717 eofReaderFunc(func() { 718 req.Trailer["Client-Trailer-A"] = []string{"valuea"} 719 }), 720 strings.NewReader("foo"), 721 eofReaderFunc(func() { 722 req.Trailer["Client-Trailer-B"] = []string{"valueb"} 723 }), 724 )) 725 req.Trailer = Header{ 726 "Client-Trailer-A": nil, // to be set later 727 "Client-Trailer-B": nil, // to be set later 728 } 729 req.ContentLength = -1 730 res, err := cst.c.Do(req) 731 if err != nil { 732 t.Fatal(err) 733 } 734 if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil { 735 t.Error(err) 736 } 737} 738 739// Tests that servers send trailers to a client and that the client can read them. 
740func TestTrailersServerToClient(t *testing.T) { 741 run(t, func(t *testing.T, mode testMode) { 742 testTrailersServerToClient(t, mode, false) 743 }) 744} 745func TestTrailersServerToClientFlush(t *testing.T) { 746 run(t, func(t *testing.T, mode testMode) { 747 testTrailersServerToClient(t, mode, true) 748 }) 749} 750 751func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { 752 const body = "Some body" 753 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 754 w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") 755 w.Header().Add("Trailer", "Server-Trailer-C") 756 757 io.WriteString(w, body) 758 if flush { 759 w.(Flusher).Flush() 760 } 761 762 // How handlers set Trailers: declare it ahead of time 763 // with the Trailer header, and then mutate the 764 // Header() of those values later, after the response 765 // has been written (we wrote to w above). 766 w.Header().Set("Server-Trailer-A", "valuea") 767 w.Header().Set("Server-Trailer-C", "valuec") // skipping B 768 w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") 769 })) 770 771 res, err := cst.c.Get(cst.ts.URL) 772 if err != nil { 773 t.Fatal(err) 774 } 775 776 wantHeader := Header{ 777 "Content-Type": {"text/plain; charset=utf-8"}, 778 } 779 wantLen := -1 780 if mode == http2Mode && !flush { 781 // In HTTP/1.1, any use of trailers forces HTTP/1.1 782 // chunking and a flush at the first write. That's 783 // unnecessary with HTTP/2's framing, so the server 784 // is able to calculate the length while still sending 785 // trailers afterwards. 
786 wantLen = len(body) 787 wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)} 788 } 789 if res.ContentLength != int64(wantLen) { 790 t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen) 791 } 792 793 delete(res.Header, "Date") // irrelevant for test 794 if !reflect.DeepEqual(res.Header, wantHeader) { 795 t.Errorf("Header = %v; want %v", res.Header, wantHeader) 796 } 797 798 if got, want := res.Trailer, (Header{ 799 "Server-Trailer-A": nil, 800 "Server-Trailer-B": nil, 801 "Server-Trailer-C": nil, 802 }); !reflect.DeepEqual(got, want) { 803 t.Errorf("Trailer before body read = %v; want %v", got, want) 804 } 805 806 if err := wantBody(res, nil, body); err != nil { 807 t.Fatal(err) 808 } 809 810 if got, want := res.Trailer, (Header{ 811 "Server-Trailer-A": {"valuea"}, 812 "Server-Trailer-B": nil, 813 "Server-Trailer-C": {"valuec"}, 814 }); !reflect.DeepEqual(got, want) { 815 t.Errorf("Trailer after body read = %v; want %v", got, want) 816 } 817} 818 819// Don't allow a Body.Read after Body.Close. Issue 13648. 
820func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } 821func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { 822 const body = "Some body" 823 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 824 io.WriteString(w, body) 825 })) 826 res, err := cst.c.Get(cst.ts.URL) 827 if err != nil { 828 t.Fatal(err) 829 } 830 res.Body.Close() 831 data, err := io.ReadAll(res.Body) 832 if len(data) != 0 || err == nil { 833 t.Fatalf("ReadAll returned %q, %v; want error", data, err) 834 } 835} 836 837func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } 838func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { 839 const reqBody = "some request body" 840 const resBody = "some response body" 841 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 842 var wg sync.WaitGroup 843 wg.Add(2) 844 didRead := make(chan bool, 1) 845 // Read in one goroutine. 846 go func() { 847 defer wg.Done() 848 data, err := io.ReadAll(r.Body) 849 if string(data) != reqBody { 850 t.Errorf("Handler read %q; want %q", data, reqBody) 851 } 852 if err != nil { 853 t.Errorf("Handler Read: %v", err) 854 } 855 didRead <- true 856 }() 857 // Write in another goroutine. 858 go func() { 859 defer wg.Done() 860 if mode != http2Mode { 861 // our HTTP/1 implementation intentionally 862 // doesn't permit writes during read (mostly 863 // due to it being undefined); if that is ever 864 // relaxed, change this. 
865 <-didRead 866 } 867 io.WriteString(w, resBody) 868 }() 869 wg.Wait() 870 })) 871 req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) 872 req.Header.Add("Expect", "100-continue") // just to complicate things 873 res, err := cst.c.Do(req) 874 if err != nil { 875 t.Fatal(err) 876 } 877 data, err := io.ReadAll(res.Body) 878 defer res.Body.Close() 879 if err != nil { 880 t.Fatal(err) 881 } 882 if string(data) != resBody { 883 t.Errorf("read %q; want %q", data, resBody) 884 } 885} 886 887func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } 888func testConnectRequest(t *testing.T, mode testMode) { 889 gotc := make(chan *Request, 1) 890 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 891 gotc <- r 892 })) 893 894 u, err := url.Parse(cst.ts.URL) 895 if err != nil { 896 t.Fatal(err) 897 } 898 899 tests := []struct { 900 req *Request 901 want string 902 }{ 903 { 904 req: &Request{ 905 Method: "CONNECT", 906 Header: Header{}, 907 URL: u, 908 }, 909 want: u.Host, 910 }, 911 { 912 req: &Request{ 913 Method: "CONNECT", 914 Header: Header{}, 915 URL: u, 916 Host: "example.com:123", 917 }, 918 want: "example.com:123", 919 }, 920 } 921 922 for i, tt := range tests { 923 res, err := cst.c.Do(tt.req) 924 if err != nil { 925 t.Errorf("%d. 
RoundTrip = %v", i, err) 926 continue 927 } 928 res.Body.Close() 929 req := <-gotc 930 if req.Method != "CONNECT" { 931 t.Errorf("method = %q; want CONNECT", req.Method) 932 } 933 if req.Host != tt.want { 934 t.Errorf("Host = %q; want %q", req.Host, tt.want) 935 } 936 if req.URL.Host != tt.want { 937 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want) 938 } 939 } 940} 941 942func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } 943func testTransportUserAgent(t *testing.T, mode testMode) { 944 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 945 fmt.Fprintf(w, "%q", r.Header["User-Agent"]) 946 })) 947 948 either := func(a, b string) string { 949 if mode == http2Mode { 950 return b 951 } 952 return a 953 } 954 955 tests := []struct { 956 setup func(*Request) 957 want string 958 }{ 959 { 960 func(r *Request) {}, 961 either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`), 962 }, 963 { 964 func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") }, 965 `["foo/1.2.3"]`, 966 }, 967 { 968 func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} }, 969 `["single"]`, 970 }, 971 { 972 func(r *Request) { r.Header.Set("User-Agent", "") }, 973 `[]`, 974 }, 975 { 976 func(r *Request) { r.Header["User-Agent"] = nil }, 977 `[]`, 978 }, 979 } 980 for i, tt := range tests { 981 req, _ := NewRequest("GET", cst.ts.URL, nil) 982 tt.setup(req) 983 res, err := cst.c.Do(req) 984 if err != nil { 985 t.Errorf("%d. RoundTrip = %v", i, err) 986 continue 987 } 988 slurp, err := io.ReadAll(res.Body) 989 res.Body.Close() 990 if err != nil { 991 t.Errorf("%d. read body = %v", i, err) 992 continue 993 } 994 if string(slurp) != tt.want { 995 t.Errorf("%d. 
body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want) 996 } 997 } 998} 999 1000func TestStarRequestMethod(t *testing.T) { 1001 for _, method := range []string{"FOO", "OPTIONS"} { 1002 t.Run(method, func(t *testing.T) { 1003 run(t, func(t *testing.T, mode testMode) { 1004 testStarRequest(t, method, mode) 1005 }) 1006 }) 1007 } 1008} 1009func testStarRequest(t *testing.T, method string, mode testMode) { 1010 gotc := make(chan *Request, 1) 1011 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1012 w.Header().Set("foo", "bar") 1013 gotc <- r 1014 w.(Flusher).Flush() 1015 })) 1016 1017 u, err := url.Parse(cst.ts.URL) 1018 if err != nil { 1019 t.Fatal(err) 1020 } 1021 u.Path = "*" 1022 1023 req := &Request{ 1024 Method: method, 1025 Header: Header{}, 1026 URL: u, 1027 } 1028 1029 res, err := cst.c.Do(req) 1030 if err != nil { 1031 t.Fatalf("RoundTrip = %v", err) 1032 } 1033 res.Body.Close() 1034 1035 wantFoo := "bar" 1036 wantLen := int64(-1) 1037 if method == "OPTIONS" { 1038 wantFoo = "" 1039 wantLen = 0 1040 } 1041 if res.StatusCode != 200 { 1042 t.Errorf("status code = %v; want %d", res.Status, 200) 1043 } 1044 if res.ContentLength != wantLen { 1045 t.Errorf("content length = %v; want %d", res.ContentLength, wantLen) 1046 } 1047 if got := res.Header.Get("foo"); got != wantFoo { 1048 t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo) 1049 } 1050 select { 1051 case req = <-gotc: 1052 default: 1053 req = nil 1054 } 1055 if req == nil { 1056 if method != "OPTIONS" { 1057 t.Fatalf("handler never got request") 1058 } 1059 return 1060 } 1061 if req.Method != method { 1062 t.Errorf("method = %q; want %q", req.Method, method) 1063 } 1064 if req.URL.Path != "*" { 1065 t.Errorf("URL.Path = %q; want *", req.URL.Path) 1066 } 1067 if req.RequestURI != "*" { 1068 t.Errorf("RequestURI = %q; want *", req.RequestURI) 1069 } 1070} 1071 1072// Issue 13957 1073func TestTransportDiscardsUnneededConns(t *testing.T) { 1074 run(t, 
testTransportDiscardsUnneededConns, []testMode{http2Mode}) 1075} 1076func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { 1077 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1078 fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) 1079 })) 1080 defer cst.close() 1081 1082 var numOpen, numClose int32 // atomic 1083 1084 tlsConfig := &tls.Config{InsecureSkipVerify: true} 1085 tr := &Transport{ 1086 TLSClientConfig: tlsConfig, 1087 DialTLS: func(_, addr string) (net.Conn, error) { 1088 time.Sleep(10 * time.Millisecond) 1089 rc, err := net.Dial("tcp", addr) 1090 if err != nil { 1091 return nil, err 1092 } 1093 atomic.AddInt32(&numOpen, 1) 1094 c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }} 1095 return tls.Client(c, tlsConfig), nil 1096 }, 1097 } 1098 if err := ExportHttp2ConfigureTransport(tr); err != nil { 1099 t.Fatal(err) 1100 } 1101 defer tr.CloseIdleConnections() 1102 1103 c := &Client{Transport: tr} 1104 1105 const N = 10 1106 gotBody := make(chan string, N) 1107 var wg sync.WaitGroup 1108 for i := 0; i < N; i++ { 1109 wg.Add(1) 1110 go func() { 1111 defer wg.Done() 1112 resp, err := c.Get(cst.ts.URL) 1113 if err != nil { 1114 // Try to work around spurious connection reset on loaded system. 1115 // See golang.org/issue/33585 and golang.org/issue/36797. 
1116 time.Sleep(10 * time.Millisecond) 1117 resp, err = c.Get(cst.ts.URL) 1118 if err != nil { 1119 t.Errorf("Get: %v", err) 1120 return 1121 } 1122 } 1123 defer resp.Body.Close() 1124 slurp, err := io.ReadAll(resp.Body) 1125 if err != nil { 1126 t.Error(err) 1127 } 1128 gotBody <- string(slurp) 1129 }() 1130 } 1131 wg.Wait() 1132 close(gotBody) 1133 1134 var last string 1135 for got := range gotBody { 1136 if last == "" { 1137 last = got 1138 continue 1139 } 1140 if got != last { 1141 t.Errorf("Response body changed: %q -> %q", last, got) 1142 } 1143 } 1144 1145 var open, close int32 1146 for i := 0; i < 150; i++ { 1147 open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose) 1148 if open < 1 { 1149 t.Fatalf("open = %d; want at least", open) 1150 } 1151 if close == open-1 { 1152 // Success 1153 return 1154 } 1155 time.Sleep(10 * time.Millisecond) 1156 } 1157 t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1) 1158} 1159 1160// tests that Transport doesn't retain a pointer to the provided request. 
func TestTransportGCRequest(t *testing.T) {
	// The subtests differ only in whether the handler writes a response body;
	// the request itself always carries a body.
	run(t, func(t *testing.T, mode testMode) {
		t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
		t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
	})
}

// testTransportGCRequest sends a POST whose *Request has a finalizer attached,
// reads and closes the response, and then runs the garbage collector every
// millisecond until the finalizer fires. There is no deadline here: if the
// client side keeps the request reachable, the loop never exits and the test
// hangs.
func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.ReadAll(r.Body)
		if body {
			io.WriteString(w, "Hello.")
		}
	}))

	didGC := make(chan struct{})
	// The request is built inside this closure so that, once the closure
	// returns, no local in this frame still references it.
	(func() {
		body := strings.NewReader("some body")
		req, _ := NewRequest("POST", cst.ts.URL, body)
		runtime.SetFinalizer(req, func(*Request) { close(didGC) })
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.ReadAll(res.Body); err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	})()
	// Keep forcing GCs until the finalizer closes didGC.
	for {
		select {
		case <-didGC:
			return
		case <-time.After(1 * time.Millisecond):
			runtime.GC()
		}
	}
}

// TestTransportRejectsInvalidHeaders checks that requests carrying an invalid
// header key or value fail on the client without the transport dialing the
// server, while valid edge cases (capitalized key, non-ASCII value, empty
// value) still succeed.
func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
	}), optQuietLog)
	// NOTE(review): presumably disabled so that every request which reaches
	// the network performs its own dial, making the dial observable below.
	cst.tr.DisableKeepAlives = true

	tests := []struct {
		key, val string
		ok       bool
	}{
		{"Foo", "capital-key", true}, // verify h2 allows capital keys
		{"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
		{"Foo", "two\nlines", false}, // \n byte in value not allowed
		{"bogus\nkey", "v", false},   // \n byte also not allowed in key
		{"A space", "v", false},      // spaces in keys not allowed
		{"имя", "v", false},          // key must be ascii
		{"name", "валю", true},       // value may be non-ascii
		{"", "v", false},             // key must be non-empty
		{"k", "", true},              // value may be empty
	}
	for _, tt := range tests {
		// Buffered so the Dial hook never blocks; a value in the channel
		// records that this iteration's request dialed.
		dialedc := make(chan bool, 1)
		cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
			dialedc <- true
			return net.Dial(netw, addr)
		}
		req, _ := NewRequest("GET", cst.ts.URL, nil)
		// Assign the map entry directly so the raw key is used as-is
		// (Header.Set would canonicalize it).
		req.Header[tt.key] = []string{tt.val}
		res, err := cst.c.Do(req)
		var body []byte
		if err == nil {
			body, _ = io.ReadAll(res.Body)
			res.Body.Close()
		}
		// Non-blocking check of whether the Dial hook ran.
		var dialed bool
		select {
		case <-dialedc:
			dialed = true
		default:
		}

		if !tt.ok && dialed {
			t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
		} else if (err == nil) != tt.ok {
			t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
		}
	}
}

// TestInterruptWithPanic checks what happens when a handler panics after it
// has already written and flushed part of the response body. It runs with a
// string panic value, a nil panic value (with GODEBUG=panicnil=1 set for that
// subtest), and ErrAbortHandler. The subtests are not parallel
// (testNotParallel), which t.Setenv requires.
func TestInterruptWithPanic(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
		t.Run("nil", func(t *testing.T) { t.Setenv("GODEBUG", "panicnil=1"); testInterruptWithPanic(t, mode, nil) })
		t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
	}, testNotParallel)
}

// testInterruptWithPanic has the handler write msg, flush it, wait until the
// client has its response headers, and then panic with panicValue. The client
// must read exactly msg followed by a read error. The server's error log must
// contain a stack trace unless panicValue is nil or ErrAbortHandler, in which
// case it must stay empty.
func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
	const msg = "hello"

	// Closed when the test returns, so the handler's select below is released
	// even if the test exits before sending on gotHeaders.
	testDone := make(chan struct{})
	defer close(testDone)

	var errorLog lockedBytesBuffer
	gotHeaders := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, msg)
		w.(Flusher).Flush()

		// Wait until the client has seen the headers so the panic happens
		// mid-body rather than before the response starts.
		select {
		case <-gotHeaders:
		case <-testDone:
		}
		panic(panicValue)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(&errorLog, "", 0)
	})
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	gotHeaders <- true
	defer res.Body.Close()
	slurp, err := io.ReadAll(res.Body)
	if string(slurp) != msg {
		t.Errorf("client read %q; want %q", slurp, msg)
	}
	if err == nil {
		t.Errorf("client read all successfully; want some error")
	}
	// lockedBytesBuffer only locks Write, so readers take the lock themselves.
	logOutput := func() string {
		errorLog.Lock()
		defer errorLog.Unlock()
		return errorLog.String()
	}
	wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler

	// NOTE(review): waitCondition is defined elsewhere; presumably it re-invokes
	// the callback with a growing delay d until it returns true.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		gotLog := logOutput()
		if !wantStackLogged {
			if gotLog == "" {
				return true
			}
			t.Fatalf("want no log output; got: %s", gotLog)
		}
		if gotLog == "" {
			if d > 0 {
				t.Logf("wanted a stack trace logged; got nothing after %v", d)
			}
			return false
		}
		// Heuristic: a stack trace mentions "created by " or spans several lines.
		if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
			if d > 0 {
				t.Logf("output doesn't look like a panic stack trace after %v. Got: %s", d, gotLog)
			}
			return false
		}
		return true
	})
}

// lockedBytesBuffer is a bytes.Buffer whose Write is guarded by the embedded
// mutex. Other methods (such as String) are not locked; callers that read
// concurrently with writers must hold the mutex themselves.
type lockedBytesBuffer struct {
	sync.Mutex
	bytes.Buffer
}

// Write appends p to the buffer while holding the mutex.
func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
	b.Lock()
	defer b.Unlock()
	return b.Buffer.Write(p)
}

// Issue 15366
// TestH12_AutoGzipWithDumpResponse serves a pre-gzipped body with an explicit
// Content-Length. It checks that the client marks the response Uncompressed,
// and that httputil.DumpResponse neither shows "Connection: close" nor omits
// "FOO" (presumably the gunzipped payload). NOTE(review): h12Compare is
// defined elsewhere; it presumably runs the handler over both HTTP/1 and
// HTTP/2.
func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
	h12Compare{
		Handler: func(w ResponseWriter, r *Request) {
			h := w.Header()
			h.Set("Content-Encoding", "gzip")
			h.Set("Content-Length", "23")
			io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
		},
		EarlyCheckResponse: func(proto string, res *Response) {
			if !res.Uncompressed {
				t.Errorf("%s: expected Uncompressed to be set", proto)
			}
			dump, err := httputil.DumpResponse(res, true)
			if err != nil {
				t.Errorf("%s: DumpResponse: %v", proto, err)
				return
			}
			if strings.Contains(string(dump), "Connection: close") {
				t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
			}
			if !strings.Contains(string(dump), "FOO") {
				t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
			}
		},
	}.run(t)
}

// Issue 14607
// TestCloseIdleConnections checks that Transport.CloseIdleConnections drops
// the idle connection: the handler reports the client's RemoteAddr in X-Addr,
// and the request made after CloseIdleConnections must arrive from a
// different address.
func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
func testCloseIdleConnections(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("X-Addr", r.RemoteAddr)
	}))
	// get performs one request and returns the server-observed client address.
	get := func() string {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		v := res.Header.Get("X-Addr")
		if v == "" {
			t.Fatal("didn't get X-Addr")
		}
		return v
	}
	a1 := get()
	cst.tr.CloseIdleConnections()
	a2 := get()
	if a1 == a2 {
		t.Errorf("didn't close connection")
	}
}
1387type noteCloseConn struct { 1388 net.Conn 1389 closeFunc func() 1390} 1391 1392func (x noteCloseConn) Close() error { 1393 x.closeFunc() 1394 return x.Conn.Close() 1395} 1396 1397type testErrorReader struct{ t *testing.T } 1398 1399func (r testErrorReader) Read(p []byte) (n int, err error) { 1400 r.t.Error("unexpected Read call") 1401 return 0, io.EOF 1402} 1403 1404func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } 1405func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { 1406 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1407 w.WriteHeader(StatusUnauthorized) 1408 })) 1409 1410 // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. 1411 cst.tr.ExpectContinueTimeout = 10 * time.Second 1412 1413 req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) 1414 if err != nil { 1415 t.Fatal(err) 1416 } 1417 req.ContentLength = 0 // so transport is tempted to sniff it 1418 req.Header.Set("Expect", "100-continue") 1419 res, err := cst.tr.RoundTrip(req) 1420 if err != nil { 1421 t.Fatal(err) 1422 } 1423 defer res.Body.Close() 1424 if res.StatusCode != StatusUnauthorized { 1425 t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) 1426 } 1427} 1428 1429func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } 1430func testServerUndeclaredTrailers(t *testing.T, mode testMode) { 1431 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1432 w.Header().Set("Foo", "Bar") 1433 w.Header().Set("Trailer:Foo", "Baz") 1434 w.(Flusher).Flush() 1435 w.Header().Add("Trailer:Foo", "Baz2") 1436 w.Header().Set("Trailer:Bar", "Quux") 1437 })) 1438 res, err := cst.c.Get(cst.ts.URL) 1439 if err != nil { 1440 t.Fatal(err) 1441 } 1442 if _, err := io.Copy(io.Discard, res.Body); err != nil { 1443 t.Fatal(err) 1444 } 1445 res.Body.Close() 1446 delete(res.Header, "Date") 1447 delete(res.Header, 
"Content-Type") 1448 1449 if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { 1450 t.Errorf("Header = %#v; want %#v", res.Header, want) 1451 } 1452 if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { 1453 t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) 1454 } 1455} 1456 1457func TestBadResponseAfterReadingBody(t *testing.T) { 1458 run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) 1459} 1460func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { 1461 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1462 _, err := io.Copy(io.Discard, r.Body) 1463 if err != nil { 1464 t.Fatal(err) 1465 } 1466 c, _, err := w.(Hijacker).Hijack() 1467 if err != nil { 1468 t.Fatal(err) 1469 } 1470 defer c.Close() 1471 fmt.Fprintln(c, "some bogus crap") 1472 })) 1473 1474 closes := 0 1475 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) 1476 if err == nil { 1477 res.Body.Close() 1478 t.Fatal("expected an error to be returned from Post") 1479 } 1480 if closes != 1 { 1481 t.Errorf("closes = %d; want 1", closes) 1482 } 1483} 1484 1485func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } 1486func testWriteHeader0(t *testing.T, mode testMode) { 1487 gotpanic := make(chan bool, 1) 1488 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1489 defer close(gotpanic) 1490 defer func() { 1491 if e := recover(); e != nil { 1492 got := fmt.Sprintf("%T, %v", e, e) 1493 want := "string, invalid WriteHeader code 0" 1494 if got != want { 1495 t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want) 1496 } 1497 gotpanic <- true 1498 1499 // Set an explicit 503. This also tests that the WriteHeader call panics 1500 // before it recorded that an explicit value was set and that bogus 1501 // value wasn't stuck. 
1502 w.WriteHeader(503) 1503 } 1504 }() 1505 w.WriteHeader(0) 1506 })) 1507 res, err := cst.c.Get(cst.ts.URL) 1508 if err != nil { 1509 t.Fatal(err) 1510 } 1511 if res.StatusCode != 503 { 1512 t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status) 1513 } 1514 if !<-gotpanic { 1515 t.Error("expected panic in handler") 1516 } 1517} 1518 1519// Issue 23010: don't be super strict checking WriteHeader's code if 1520// it's not even valid to call WriteHeader then anyway. 1521func TestWriteHeaderNoCodeCheck(t *testing.T) { 1522 run(t, func(t *testing.T, mode testMode) { 1523 testWriteHeaderAfterWrite(t, mode, false) 1524 }) 1525} 1526func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { 1527 testWriteHeaderAfterWrite(t, http1Mode, true) 1528} 1529func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { 1530 var errorLog lockedBytesBuffer 1531 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1532 if hijack { 1533 conn, _, _ := w.(Hijacker).Hijack() 1534 defer conn.Close() 1535 conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) 1536 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1537 conn.Write([]byte("bar")) 1538 return 1539 } 1540 io.WriteString(w, "foo") 1541 w.(Flusher).Flush() 1542 w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 1543 io.WriteString(w, "bar") 1544 }), func(ts *httptest.Server) { 1545 ts.Config.ErrorLog = log.New(&errorLog, "", 0) 1546 }) 1547 res, err := cst.c.Get(cst.ts.URL) 1548 if err != nil { 1549 t.Fatal(err) 1550 } 1551 defer res.Body.Close() 1552 body, err := io.ReadAll(res.Body) 1553 if err != nil { 1554 t.Fatal(err) 1555 } 1556 if got, want := string(body), "foobar"; got != want { 1557 t.Errorf("got = %q; want %q", got, want) 1558 } 1559 1560 // Also check the stderr output: 1561 if mode == http2Mode { 1562 // TODO: also emit this log message for HTTP/2? 
1563 // We historically haven't, so don't check. 1564 return 1565 } 1566 gotLog := strings.TrimSpace(errorLog.String()) 1567 wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1568 if hijack { 1569 wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:" 1570 } 1571 if !strings.HasPrefix(gotLog, wantLog) { 1572 t.Errorf("stderr output = %q; want %q", gotLog, wantLog) 1573 } 1574} 1575 1576func TestBidiStreamReverseProxy(t *testing.T) { 1577 run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) 1578} 1579func testBidiStreamReverseProxy(t *testing.T, mode testMode) { 1580 backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1581 if _, err := io.Copy(w, r.Body); err != nil { 1582 log.Printf("bidi backend copy: %v", err) 1583 } 1584 })) 1585 1586 backURL, err := url.Parse(backend.ts.URL) 1587 if err != nil { 1588 t.Fatal(err) 1589 } 1590 rp := httputil.NewSingleHostReverseProxy(backURL) 1591 rp.Transport = backend.tr 1592 proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1593 rp.ServeHTTP(w, r) 1594 })) 1595 1596 bodyRes := make(chan any, 1) // error or hash.Hash 1597 pr, pw := io.Pipe() 1598 req, _ := NewRequest("PUT", proxy.ts.URL, pr) 1599 const size = 4 << 20 1600 go func() { 1601 h := sha1.New() 1602 _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size) 1603 go pw.Close() 1604 if err != nil { 1605 bodyRes <- err 1606 } else { 1607 bodyRes <- h 1608 } 1609 }() 1610 res, err := backend.c.Do(req) 1611 if err != nil { 1612 t.Fatal(err) 1613 } 1614 defer res.Body.Close() 1615 hgot := sha1.New() 1616 n, err := io.Copy(hgot, res.Body) 1617 if err != nil { 1618 t.Fatal(err) 1619 } 1620 if n != size { 1621 t.Fatalf("got %d bytes; want %d", n, size) 1622 } 1623 select { 1624 case v := <-bodyRes: 1625 switch v := v.(type) { 1626 
default: 1627 t.Fatalf("body copy: %v", err) 1628 case hash.Hash: 1629 if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) { 1630 t.Errorf("written bytes didn't match received bytes") 1631 } 1632 } 1633 case <-time.After(10 * time.Second): 1634 t.Fatal("timeout") 1635 } 1636 1637} 1638 1639// Always use HTTP/1.1 for WebSocket upgrades. 1640func TestH12_WebSocketUpgrade(t *testing.T) { 1641 h12Compare{ 1642 Handler: func(w ResponseWriter, r *Request) { 1643 h := w.Header() 1644 h.Set("Foo", "bar") 1645 }, 1646 ReqFunc: func(c *Client, url string) (*Response, error) { 1647 req, _ := NewRequest("GET", url, nil) 1648 req.Header.Set("Connection", "Upgrade") 1649 req.Header.Set("Upgrade", "WebSocket") 1650 return c.Do(req) 1651 }, 1652 EarlyCheckResponse: func(proto string, res *Response) { 1653 if res.Proto != "HTTP/1.1" { 1654 t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto) 1655 } 1656 res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0 1657 }, 1658 }.run(t) 1659} 1660 1661func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } 1662func testIdentityTransferEncoding(t *testing.T, mode testMode) { 1663 const body = "body" 1664 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1665 gotBody, _ := io.ReadAll(r.Body) 1666 if got, want := string(gotBody), body; got != want { 1667 t.Errorf("got request body = %q; want %q", got, want) 1668 } 1669 w.Header().Set("Transfer-Encoding", "identity") 1670 w.WriteHeader(StatusOK) 1671 w.(Flusher).Flush() 1672 io.WriteString(w, body) 1673 })) 1674 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) 1675 res, err := cst.c.Do(req) 1676 if err != nil { 1677 t.Fatal(err) 1678 } 1679 defer res.Body.Close() 1680 gotBody, err := io.ReadAll(res.Body) 1681 if err != nil { 1682 t.Fatal(err) 1683 } 1684 if got, want := string(gotBody), body; got != want { 1685 t.Errorf("got response body = %q; want %q", got, want) 1686 } 1687} 1688 1689func 
TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } 1690func testEarlyHintsRequest(t *testing.T, mode testMode) { 1691 var wg sync.WaitGroup 1692 wg.Add(1) 1693 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { 1694 h := w.Header() 1695 1696 h.Add("Content-Length", "123") // must be ignored 1697 h.Add("Link", "</style.css>; rel=preload; as=style") 1698 h.Add("Link", "</script.js>; rel=preload; as=script") 1699 w.WriteHeader(StatusEarlyHints) 1700 1701 wg.Wait() 1702 1703 h.Add("Link", "</foo.js>; rel=preload; as=script") 1704 w.WriteHeader(StatusEarlyHints) 1705 1706 w.Write([]byte("Hello")) 1707 })) 1708 1709 checkLinkHeaders := func(t *testing.T, expected, got []string) { 1710 t.Helper() 1711 1712 if len(expected) != len(got) { 1713 t.Errorf("got %d expected %d", len(got), len(expected)) 1714 } 1715 1716 for i := range expected { 1717 if expected[i] != got[i] { 1718 t.Errorf("got %q expected %q", got[i], expected[i]) 1719 } 1720 } 1721 } 1722 1723 checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) { 1724 t.Helper() 1725 1726 for _, h := range []string{"Content-Length", "Transfer-Encoding"} { 1727 if v, ok := header[h]; ok { 1728 t.Errorf("%s is %q; must not be sent", h, v) 1729 } 1730 } 1731 } 1732 1733 var respCounter uint8 1734 trace := &httptrace.ClientTrace{ 1735 Got1xxResponse: func(code int, header textproto.MIMEHeader) error { 1736 switch respCounter { 1737 case 0: 1738 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) 1739 checkExcludedHeaders(t, header) 1740 1741 wg.Done() 1742 case 1: 1743 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) 1744 checkExcludedHeaders(t, header) 1745 1746 default: 1747 t.Error("Unexpected 1xx response") 1748 } 1749 1750 respCounter++ 1751 1752 return nil 1753 }, 1754 
} 1755 req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil) 1756 1757 res, err := cst.c.Do(req) 1758 if err != nil { 1759 t.Fatal(err) 1760 } 1761 defer res.Body.Close() 1762 1763 checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) 1764 if cl := res.Header.Get("Content-Length"); cl != "123" { 1765 t.Errorf("Content-Length is %q; want 123", cl) 1766 } 1767 1768 body, _ := io.ReadAll(res.Body) 1769 if string(body) != "Hello" { 1770 t.Errorf("Read body %q; want Hello", body) 1771 } 1772} 1773