1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Implementation of Server
6
7package httptest
8
9import (
10	"crypto/tls"
11	"crypto/x509"
12	"flag"
13	"fmt"
14	"log"
15	"net"
16	"net/http"
17	"net/http/internal/testcert"
18	"os"
19	"strings"
20	"sync"
21	"time"
22)
23
// A Server is an HTTP server listening on a system-chosen port on the
// local loopback interface, for use in end-to-end HTTP tests.
//
// Create one with NewServer, NewTLSServer, or NewUnstartedServer, and shut
// it down with Close.
type Server struct {
	// URL is empty until Start or StartTLS sets it from Listener's
	// address (using an "https://" scheme in the StartTLS case).
	// Listener is opened by NewUnstartedServer; StartTLS replaces it with
	// a tls.NewListener wrapper around the original listener.
	URL      string // base URL of form http://ipaddr:port with no trailing slash
	Listener net.Listener

	// EnableHTTP2 controls whether HTTP/2 is enabled
	// on the server. It must be set between calling
	// NewUnstartedServer and calling Server.StartTLS.
	// When set, StartTLS advertises "h2" (if TLS.NextProtos is unset) and
	// the Client's transport sets ForceAttemptHTTP2.
	EnableHTTP2 bool

	// TLS is the optional TLS configuration, populated with a new config
	// after TLS is started. If set on an unstarted server before StartTLS
	// is called, existing fields are copied into the new config.
	// StartTLS works on a clone, so the caller's original *tls.Config is
	// not modified.
	TLS *tls.Config

	// Config may be changed after calling NewUnstartedServer and
	// before Start or StartTLS.
	// Start and StartTLS install their own ConnState hook on it (see wrap),
	// which still calls any hook already set here.
	Config *http.Server

	// certificate is a parsed version of the TLS config certificate, if present.
	// StartTLS sets it from the first certificate of TLS.Certificates[0];
	// it stays nil for a server started with Start.
	certificate *x509.Certificate

	// wg counts the goroutine running Config.Serve (see goServe) plus every
	// connection currently tracked in conns.
	// Close blocks until all of them are finished.
	wg sync.WaitGroup

	// closed is set once by Close. conns maps each tracked connection to its
	// most recent non-terminal http.ConnState; the ConnState hook installed
	// by wrap removes entries on StateHijacked and StateClosed.
	mu     sync.Mutex // guards closed and conns
	closed bool
	conns  map[net.Conn]http.ConnState // except terminal states

	// client is configured for use with the server.
	// Start or StartTLS creates it if it is nil. Close calls
	// CloseIdleConnections on its transport.
	client *http.Client
}
59
60func newLocalListener() net.Listener {
61	if serveFlag != "" {
62		l, err := net.Listen("tcp", serveFlag)
63		if err != nil {
64			panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
65		}
66		return l
67	}
68	l, err := net.Listen("tcp", "127.0.0.1:0")
69	if err != nil {
70		if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
71			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
72		}
73	}
74	return l
75}
76
77// When debugging a particular http server-based test,
78// this flag lets you run
79//
80//	go test -run='^BrokenTest$' -httptest.serve=127.0.0.1:8000
81//
82// to start the broken server so you can interact with it manually.
83// We only register this flag if it looks like the caller knows about it
84// and is trying to use it as we don't want to pollute flags and this
85// isn't really part of our API. Don't depend on this.
86var serveFlag string
87
88func init() {
89	if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
90		flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
91	}
92}
93
// strSliceContainsPrefix reports whether at least one element of v starts
// with pre. An empty or nil v never matches.
func strSliceContainsPrefix(v []string, pre string) bool {
	found := false
	for i := 0; i < len(v) && !found; i++ {
		found = strings.HasPrefix(v[i], pre)
	}
	return found
}
102
103// NewServer starts and returns a new [Server].
104// The caller should call Close when finished, to shut it down.
105func NewServer(handler http.Handler) *Server {
106	ts := NewUnstartedServer(handler)
107	ts.Start()
108	return ts
109}
110
111// NewUnstartedServer returns a new [Server] but doesn't start it.
112//
113// After changing its configuration, the caller should call Start or
114// StartTLS.
115//
116// The caller should call Close when finished, to shut it down.
117func NewUnstartedServer(handler http.Handler) *Server {
118	return &Server{
119		Listener: newLocalListener(),
120		Config:   &http.Server{Handler: handler},
121	}
122}
123
124// Start starts a server from NewUnstartedServer.
125func (s *Server) Start() {
126	if s.URL != "" {
127		panic("Server already started")
128	}
129	if s.client == nil {
130		s.client = &http.Client{Transport: &http.Transport{}}
131	}
132	s.URL = "http://" + s.Listener.Addr().String()
133	s.wrap()
134	s.goServe()
135	if serveFlag != "" {
136		fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
137		select {}
138	}
139}
140
141// StartTLS starts TLS on a server from NewUnstartedServer.
142func (s *Server) StartTLS() {
143	if s.URL != "" {
144		panic("Server already started")
145	}
146	if s.client == nil {
147		s.client = &http.Client{}
148	}
149	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
150	if err != nil {
151		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
152	}
153
154	existingConfig := s.TLS
155	if existingConfig != nil {
156		s.TLS = existingConfig.Clone()
157	} else {
158		s.TLS = new(tls.Config)
159	}
160	if s.TLS.NextProtos == nil {
161		nextProtos := []string{"http/1.1"}
162		if s.EnableHTTP2 {
163			nextProtos = []string{"h2"}
164		}
165		s.TLS.NextProtos = nextProtos
166	}
167	if len(s.TLS.Certificates) == 0 {
168		s.TLS.Certificates = []tls.Certificate{cert}
169	}
170	s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
171	if err != nil {
172		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
173	}
174	certpool := x509.NewCertPool()
175	certpool.AddCert(s.certificate)
176	s.client.Transport = &http.Transport{
177		TLSClientConfig: &tls.Config{
178			RootCAs: certpool,
179		},
180		ForceAttemptHTTP2: s.EnableHTTP2,
181	}
182	s.Listener = tls.NewListener(s.Listener, s.TLS)
183	s.URL = "https://" + s.Listener.Addr().String()
184	s.wrap()
185	s.goServe()
186}
187
188// NewTLSServer starts and returns a new [Server] using TLS.
189// The caller should call Close when finished, to shut it down.
190func NewTLSServer(handler http.Handler) *Server {
191	ts := NewUnstartedServer(handler)
192	ts.StartTLS()
193	return ts
194}
195
// closeIdleTransport matches transports, such as *http.Transport, that can
// drop their idle connections. Close uses it on http.DefaultTransport and on
// the transport of the server's Client.
type closeIdleTransport interface{ CloseIdleConnections() }
199
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
//
// On the first call only, Close:
//   - marks the server closed, closes the Listener, and disables keep-alives;
//   - force-closes tracked connections in StateIdle or StateNew, leaving
//     StateActive connections to finish their requests;
//   - arms a 5-second timer that logs the still-tracked connections if Close
//     is still blocked when it fires.
//
// On every call, Close then asks http.DefaultTransport and the transport of
// s.client (when they implement CloseIdleConnections) to drop idle
// connections, and waits on s.wg for the Serve goroutine and all tracked
// connections to finish.
func (s *Server) Close() {
	s.mu.Lock()
	if !s.closed {
		s.closed = true
		s.Listener.Close()
		s.Config.SetKeepAlivesEnabled(false)
		for c, st := range s.conns {
			// Force-close any idle connections (those between
			// requests) and new connections (those which connected
			// but never sent a request). StateNew connections are
			// super rare and have only been seen (in
			// previously-flaky tests) in the case of
			// socket-late-binding races from the http Client
			// dialing this server and then getting an idle
			// connection before the dial completed. There is thus
			// a connected connection in StateNew with no
			// associated Request. We only close StateIdle and
			// StateNew because they're not doing anything. It's
			// possible StateNew is about to do something in a few
			// milliseconds, but a previous CL to check again in a
			// few milliseconds wasn't liked (early versions of
			// https://golang.org/cl/15151) so now we just
			// forcefully close StateNew. The docs for Server.Close say
			// we wait for "outstanding requests", so we don't close things
			// in StateActive.
			if st == http.StateIdle || st == http.StateNew {
				s.closeConn(c)
			}
		}
		// If this server doesn't shut down in 5 seconds, tell the user why.
		// The deferred Stop runs when Close returns, i.e. after s.wg.Wait
		// below, so the timer only fires if that wait exceeds 5 seconds.
		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
		defer t.Stop()
	}
	// Release s.mu before waiting: the ConnState hook installed by wrap needs
	// it to record connections reaching StateClosed and release s.wg.
	s.mu.Unlock()

	// Not part of httptest.Server's correctness, but assume most
	// users of httptest.Server will be using the standard
	// transport, so help them out and close any idle connections for them.
	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
		t.CloseIdleConnections()
	}

	// Also close the client idle connections.
	if s.client != nil {
		if t, ok := s.client.Transport.(closeIdleTransport); ok {
			t.CloseIdleConnections()
		}
	}

	s.wg.Wait()
}
253
254func (s *Server) logCloseHangDebugInfo() {
255	s.mu.Lock()
256	defer s.mu.Unlock()
257	var buf strings.Builder
258	buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
259	for c, st := range s.conns {
260		fmt.Fprintf(&buf, "  %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
261	}
262	log.Print(buf.String())
263}
264
265// CloseClientConnections closes any open HTTP connections to the test Server.
266func (s *Server) CloseClientConnections() {
267	s.mu.Lock()
268	nconn := len(s.conns)
269	ch := make(chan struct{}, nconn)
270	for c := range s.conns {
271		go s.closeConnChan(c, ch)
272	}
273	s.mu.Unlock()
274
275	// Wait for outstanding closes to finish.
276	//
277	// Out of paranoia for making a late change in Go 1.6, we
278	// bound how long this can wait, since golang.org/issue/14291
279	// isn't fully understood yet. At least this should only be used
280	// in tests.
281	timer := time.NewTimer(5 * time.Second)
282	defer timer.Stop()
283	for i := 0; i < nconn; i++ {
284		select {
285		case <-ch:
286		case <-timer.C:
287			// Too slow. Give up.
288			return
289		}
290	}
291}
292
293// Certificate returns the certificate used by the server, or nil if
294// the server doesn't use TLS.
295func (s *Server) Certificate() *x509.Certificate {
296	return s.certificate
297}
298
299// Client returns an HTTP client configured for making requests to the server.
300// It is configured to trust the server's TLS test certificate and will
301// close its idle connections on [Server.Close].
302// Use Server.URL as the base URL to send requests to the server.
303func (s *Server) Client() *http.Client {
304	return s.client
305}
306
307func (s *Server) goServe() {
308	s.wg.Add(1)
309	go func() {
310		defer s.wg.Done()
311		s.Config.Serve(s.Listener)
312	}()
313}
314
// wrap installs the connection state-tracking hook to know which
// connections are idle.
//
// The new hook replaces s.Config.ConnState. It still calls any hook that was
// already installed (oldHook), after updating s.conns. The whole hook,
// including oldHook, runs with s.mu held. A user-supplied ConnState hook that
// calls a Server method which locks s.mu (such as Close or
// CloseClientConnections) would therefore deadlock.
//
// A connection adds one to s.wg when it first appears (StateNew) and
// releases it when it reaches StateHijacked or StateClosed. Transitions the
// tracking does not expect panic with "invalid state transition".
func (s *Server) wrap() {
	oldHook := s.Config.ConnState
	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
		s.mu.Lock()
		defer s.mu.Unlock()

		switch cs {
		case http.StateNew:
			if _, exists := s.conns[c]; exists {
				panic("invalid state transition")
			}
			if s.conns == nil {
				s.conns = make(map[net.Conn]http.ConnState)
			}
			// Add c to the set of tracked conns and increment it to the
			// waitgroup.
			s.wg.Add(1)
			s.conns[c] = cs
			if s.closed {
				// Probably just a socket-late-binding dial from
				// the default transport that lost the race (and
				// thus this connection is now idle and will
				// never be used).
				s.closeConn(c)
			}
		case http.StateActive:
			// Only connections already tracked are updated; an untracked
			// connection is ignored here.
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateNew && oldState != http.StateIdle {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
		case http.StateIdle:
			if oldState, ok := s.conns[c]; ok {
				if oldState != http.StateActive {
					panic("invalid state transition")
				}
				s.conns[c] = cs
			}
			// Once Close has run, a connection that goes idle is closed
			// right away.
			if s.closed {
				s.closeConn(c)
			}
		case http.StateHijacked, http.StateClosed:
			// Remove c from the set of tracked conns and decrement it from the
			// waitgroup, unless it was previously removed.
			if _, ok := s.conns[c]; ok {
				delete(s.conns, c)
				// Keep Close from returning until the user's ConnState hook
				// (if any) finishes.
				// Deferred calls run LIFO: this Done runs after oldHook
				// returns and before the deferred s.mu.Unlock above.
				defer s.wg.Done()
			}
		}
		if oldHook != nil {
			oldHook(c, cs)
		}
	}
}
374
375// closeConn closes c.
376// s.mu must be held.
377func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
378
379// closeConnChan is like closeConn, but takes an optional channel to receive a value
380// when the goroutine closing c is done.
381func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
382	c.Close()
383	if done != nil {
384		done <- struct{}{}
385	}
386}
387