1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package httptest
6
7import (
8	"bufio"
9	"io"
10	"net"
11	"net/http"
12	"sync"
13	"testing"
14)
15
// newServerFunc is the signature shared by the Server constructors exercised
// in this file: it takes the handler to serve and returns a *Server.
type newServerFunc func(http.Handler) *Server
17
18var newServers = map[string]newServerFunc{
19	"NewServer":    NewServer,
20	"NewTLSServer": NewTLSServer,
21
22	// The manual variants of newServer create a Server manually by only filling
23	// in the exported fields of Server.
24	"NewServerManual": func(h http.Handler) *Server {
25		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
26		ts.Start()
27		return ts
28	},
29	"NewTLSServerManual": func(h http.Handler) *Server {
30		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
31		ts.StartTLS()
32		return ts
33	},
34}
35
36func TestServer(t *testing.T) {
37	for _, name := range []string{"NewServer", "NewServerManual"} {
38		t.Run(name, func(t *testing.T) {
39			newServer := newServers[name]
40			t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
41			t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
42			t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
43			t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
44			t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
45		})
46	}
47	for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
48		t.Run(name, func(t *testing.T) {
49			newServer := newServers[name]
50			t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
51			t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
52		})
53	}
54}
55
56func testServer(t *testing.T, newServer newServerFunc) {
57	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58		w.Write([]byte("hello"))
59	}))
60	defer ts.Close()
61	res, err := http.Get(ts.URL)
62	if err != nil {
63		t.Fatal(err)
64	}
65	got, err := io.ReadAll(res.Body)
66	res.Body.Close()
67	if err != nil {
68		t.Fatal(err)
69	}
70	if string(got) != "hello" {
71		t.Errorf("got %q, want hello", string(got))
72	}
73}
74
75// Issue 12781
76func testGetAfterClose(t *testing.T, newServer newServerFunc) {
77	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78		w.Write([]byte("hello"))
79	}))
80
81	res, err := http.Get(ts.URL)
82	if err != nil {
83		t.Fatal(err)
84	}
85	got, err := io.ReadAll(res.Body)
86	res.Body.Close()
87	if err != nil {
88		t.Fatal(err)
89	}
90	if string(got) != "hello" {
91		t.Fatalf("got %q, want hello", string(got))
92	}
93
94	ts.Close()
95
96	res, err = http.Get(ts.URL)
97	if err == nil {
98		body, _ := io.ReadAll(res.Body)
99		t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
100	}
101}
102
103func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
104	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105		w.Write([]byte("hello"))
106	}))
107	dial := func() net.Conn {
108		c, err := net.Dial("tcp", ts.Listener.Addr().String())
109		if err != nil {
110			t.Fatal(err)
111		}
112		return c
113	}
114
115	// Keep one connection in StateNew (connected, but not sending anything)
116	cnew := dial()
117	defer cnew.Close()
118
119	// Keep one connection in StateIdle (idle after a request)
120	cidle := dial()
121	defer cidle.Close()
122	cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
123	_, err := http.ReadResponse(bufio.NewReader(cidle), nil)
124	if err != nil {
125		t.Fatal(err)
126	}
127
128	ts.Close() // test we don't hang here forever.
129}
130
131// Issue 14290
132func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
133	var s *Server
134	s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135		s.CloseClientConnections()
136	}))
137	defer s.Close()
138	res, err := http.Get(s.URL)
139	if err == nil {
140		res.Body.Close()
141		t.Fatalf("Unexpected response: %#v", res)
142	}
143}
144
145// Tests that the Server.Client method works and returns an http.Client that can hit
146// NewTLSServer without cert warnings.
147func testServerClient(t *testing.T, newTLSServer newServerFunc) {
148	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
149		w.Write([]byte("hello"))
150	}))
151	defer ts.Close()
152	client := ts.Client()
153	res, err := client.Get(ts.URL)
154	if err != nil {
155		t.Fatal(err)
156	}
157	got, err := io.ReadAll(res.Body)
158	res.Body.Close()
159	if err != nil {
160		t.Fatal(err)
161	}
162	if string(got) != "hello" {
163		t.Errorf("got %q, want hello", string(got))
164	}
165}
166
167// Tests that the Server.Client.Transport interface is implemented
168// by a *http.Transport.
169func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
170	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171	}))
172	defer ts.Close()
173	client := ts.Client()
174	if _, ok := client.Transport.(*http.Transport); !ok {
175		t.Errorf("got %T, want *http.Transport", client.Transport)
176	}
177}
178
179// Tests that the TLS Server.Client.Transport interface is implemented
180// by a *http.Transport.
181func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
182	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
183	}))
184	defer ts.Close()
185	client := ts.Client()
186	if _, ok := client.Transport.(*http.Transport); !ok {
187		t.Errorf("got %T, want *http.Transport", client.Transport)
188	}
189}
190
// onlyCloseListener embeds a net.Listener and overrides only its Close
// method; every other Listener method is promoted from the embedded value
// unchanged.
type onlyCloseListener struct {
	net.Listener
}

// Close always returns nil and never calls the embedded Listener's Close, so
// it is safe to call even when the embedded Listener is nil.
func (onlyCloseListener) Close() error { return nil }
196
197// Issue 19729: panic in Server.Close for values created directly
198// without a constructor (so the unexported client field is nil).
199func TestServerZeroValueClose(t *testing.T) {
200	ts := &Server{
201		Listener: onlyCloseListener{},
202		Config:   &http.Server{},
203	}
204
205	ts.Close() // tests that it doesn't panic
206}
207
208// Issue 51799: test hijacking a connection and then closing it
209// concurrently with closing the server.
210func TestCloseHijackedConnection(t *testing.T) {
211	hijacked := make(chan net.Conn)
212	ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
213		defer close(hijacked)
214		hj, ok := w.(http.Hijacker)
215		if !ok {
216			t.Fatal("failed to hijack")
217		}
218		c, _, err := hj.Hijack()
219		if err != nil {
220			t.Fatal(err)
221		}
222		hijacked <- c
223	}))
224
225	var wg sync.WaitGroup
226	wg.Add(1)
227	go func() {
228		defer wg.Done()
229		req, err := http.NewRequest("GET", ts.URL, nil)
230		if err != nil {
231			t.Log(err)
232		}
233		// Use a client not associated with the Server.
234		var c http.Client
235		resp, err := c.Do(req)
236		if err != nil {
237			t.Log(err)
238			return
239		}
240		resp.Body.Close()
241	}()
242
243	wg.Add(1)
244	conn := <-hijacked
245	go func(conn net.Conn) {
246		defer wg.Done()
247		// Close the connection and then inform the Server that
248		// we closed it.
249		conn.Close()
250		ts.Config.ConnState(conn, http.StateClosed)
251	}(conn)
252
253	wg.Add(1)
254	go func() {
255		defer wg.Done()
256		ts.Close()
257	}()
258	wg.Wait()
259}
260
261func TestTLSServerWithHTTP2(t *testing.T) {
262	modes := []struct {
263		name      string
264		wantProto string
265	}{
266		{"http1", "HTTP/1.1"},
267		{"http2", "HTTP/2.0"},
268	}
269
270	for _, tt := range modes {
271		t.Run(tt.name, func(t *testing.T) {
272			cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
273				w.Header().Set("X-Proto", r.Proto)
274			}))
275
276			switch tt.name {
277			case "http2":
278				cst.EnableHTTP2 = true
279				cst.StartTLS()
280			default:
281				cst.Start()
282			}
283
284			defer cst.Close()
285
286			res, err := cst.Client().Get(cst.URL)
287			if err != nil {
288				t.Fatalf("Failed to make request: %v", err)
289			}
290			if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
291				t.Fatalf("X-Proto header mismatch:\n\tgot:  %q\n\twant: %q", g, w)
292			}
293		})
294	}
295}
296