1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http_test
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	. "net/http"
12	"os"
13	"sync"
14	"testing"
15	"time"
16)
17
18func TestResponseControllerFlush(t *testing.T) { run(t, testResponseControllerFlush) }
19func testResponseControllerFlush(t *testing.T, mode testMode) {
20	continuec := make(chan struct{})
21	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
22		ctl := NewResponseController(w)
23		w.Write([]byte("one"))
24		if err := ctl.Flush(); err != nil {
25			t.Errorf("ctl.Flush() = %v, want nil", err)
26			return
27		}
28		<-continuec
29		w.Write([]byte("two"))
30	}))
31
32	res, err := cst.c.Get(cst.ts.URL)
33	if err != nil {
34		t.Fatalf("unexpected connection error: %v", err)
35	}
36	defer res.Body.Close()
37
38	buf := make([]byte, 16)
39	n, err := res.Body.Read(buf)
40	close(continuec)
41	if err != nil || string(buf[:n]) != "one" {
42		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(buf[:n]), err, "one")
43	}
44
45	got, err := io.ReadAll(res.Body)
46	if err != nil || string(got) != "two" {
47		t.Fatalf("Body.Read = %q, %v, want %q, nil", string(got), err, "two")
48	}
49}
50
51func TestResponseControllerHijack(t *testing.T) { run(t, testResponseControllerHijack) }
52func testResponseControllerHijack(t *testing.T, mode testMode) {
53	const header = "X-Header"
54	const value = "set"
55	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
56		ctl := NewResponseController(w)
57		c, _, err := ctl.Hijack()
58		if mode == http2Mode {
59			if err == nil {
60				t.Errorf("ctl.Hijack = nil, want error")
61			}
62			w.Header().Set(header, value)
63			return
64		}
65		if err != nil {
66			t.Errorf("ctl.Hijack = _, _, %v, want _, _, nil", err)
67			return
68		}
69		fmt.Fprintf(c, "HTTP/1.0 200 OK\r\n%v: %v\r\nContent-Length: 0\r\n\r\n", header, value)
70	}))
71	res, err := cst.c.Get(cst.ts.URL)
72	if err != nil {
73		t.Fatal(err)
74	}
75	if got, want := res.Header.Get(header), value; got != want {
76		t.Errorf("response header %q = %q, want %q", header, got, want)
77	}
78}
79
80func TestResponseControllerSetPastWriteDeadline(t *testing.T) {
81	run(t, testResponseControllerSetPastWriteDeadline)
82}
83func testResponseControllerSetPastWriteDeadline(t *testing.T, mode testMode) {
84	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
85		ctl := NewResponseController(w)
86		w.Write([]byte("one"))
87		if err := ctl.Flush(); err != nil {
88			t.Errorf("before setting deadline: ctl.Flush() = %v, want nil", err)
89		}
90		if err := ctl.SetWriteDeadline(time.Now().Add(-10 * time.Second)); err != nil {
91			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
92		}
93
94		w.Write([]byte("two"))
95		if err := ctl.Flush(); err == nil {
96			t.Errorf("after setting deadline: ctl.Flush() = nil, want non-nil")
97		}
98		// Connection errors are sticky, so resetting the deadline does not permit
99		// making more progress. We might want to change this in the future, but verify
100		// the current behavior for now. If we do change this, we'll want to make sure
101		// to do so only for writing the response body, not headers.
102		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Hour)); err != nil {
103			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
104		}
105		w.Write([]byte("three"))
106		if err := ctl.Flush(); err == nil {
107			t.Errorf("after resetting deadline: ctl.Flush() = nil, want non-nil")
108		}
109	}))
110
111	res, err := cst.c.Get(cst.ts.URL)
112	if err != nil {
113		t.Fatalf("unexpected connection error: %v", err)
114	}
115	defer res.Body.Close()
116	b, _ := io.ReadAll(res.Body)
117	if string(b) != "one" {
118		t.Errorf("unexpected body: %q", string(b))
119	}
120}
121
122func TestResponseControllerSetFutureWriteDeadline(t *testing.T) {
123	run(t, testResponseControllerSetFutureWriteDeadline)
124}
125func testResponseControllerSetFutureWriteDeadline(t *testing.T, mode testMode) {
126	errc := make(chan error, 1)
127	startwritec := make(chan struct{})
128	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
129		ctl := NewResponseController(w)
130		w.WriteHeader(200)
131		if err := ctl.Flush(); err != nil {
132			t.Errorf("ctl.Flush() = %v, want nil", err)
133		}
134		<-startwritec // don't set the deadline until the client reads response headers
135		if err := ctl.SetWriteDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
136			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
137		}
138		_, err := io.Copy(w, neverEnding('a'))
139		errc <- err
140	}))
141
142	res, err := cst.c.Get(cst.ts.URL)
143	close(startwritec)
144	if err != nil {
145		t.Fatalf("unexpected connection error: %v", err)
146	}
147	defer res.Body.Close()
148	_, err = io.Copy(io.Discard, res.Body)
149	if err == nil {
150		t.Errorf("client reading from truncated request body: got nil error, want non-nil")
151	}
152	err = <-errc // io.Copy error
153	if !errors.Is(err, os.ErrDeadlineExceeded) {
154		t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
155	}
156}
157
158func TestResponseControllerSetPastReadDeadline(t *testing.T) {
159	run(t, testResponseControllerSetPastReadDeadline)
160}
161func testResponseControllerSetPastReadDeadline(t *testing.T, mode testMode) {
162	readc := make(chan struct{})
163	donec := make(chan struct{})
164	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
165		defer close(donec)
166		ctl := NewResponseController(w)
167		b := make([]byte, 3)
168		n, err := io.ReadFull(r.Body, b)
169		b = b[:n]
170		if err != nil || string(b) != "one" {
171			t.Errorf("before setting read deadline: Read = %v, %q, want nil, %q", err, string(b), "one")
172			return
173		}
174		if err := ctl.SetReadDeadline(time.Now()); err != nil {
175			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
176			return
177		}
178		b, err = io.ReadAll(r.Body)
179		if err == nil || string(b) != "" {
180			t.Errorf("after setting read deadline: Read = %q, nil, want error", string(b))
181		}
182		close(readc)
183		// Connection errors are sticky, so resetting the deadline does not permit
184		// making more progress. We might want to change this in the future, but verify
185		// the current behavior for now.
186		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
187			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
188			return
189		}
190		b, err = io.ReadAll(r.Body)
191		if err == nil {
192			t.Errorf("after resetting read deadline: Read = %q, nil, want error", string(b))
193		}
194	}))
195
196	pr, pw := io.Pipe()
197	var wg sync.WaitGroup
198	wg.Add(1)
199	go func() {
200		defer wg.Done()
201		defer pw.Close()
202		pw.Write([]byte("one"))
203		select {
204		case <-readc:
205		case <-donec:
206			select {
207			case <-readc:
208			default:
209				t.Errorf("server handler unexpectedly exited without closing readc")
210				return
211			}
212		}
213		pw.Write([]byte("two"))
214	}()
215	defer wg.Wait()
216	res, err := cst.c.Post(cst.ts.URL, "text/foo", pr)
217	if err == nil {
218		defer res.Body.Close()
219	}
220}
221
222func TestResponseControllerSetFutureReadDeadline(t *testing.T) {
223	run(t, testResponseControllerSetFutureReadDeadline)
224}
225func testResponseControllerSetFutureReadDeadline(t *testing.T, mode testMode) {
226	respBody := "response body"
227	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
228		ctl := NewResponseController(w)
229		if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
230			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
231		}
232		_, err := io.Copy(io.Discard, req.Body)
233		if !errors.Is(err, os.ErrDeadlineExceeded) {
234			t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
235		}
236		w.Write([]byte(respBody))
237	}))
238	pr, pw := io.Pipe()
239	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
240	if err != nil {
241		t.Fatal(err)
242	}
243	defer res.Body.Close()
244	got, err := io.ReadAll(res.Body)
245	if string(got) != respBody || err != nil {
246		t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
247	}
248	pw.Close()
249}
250
// wrapWriter wraps a ResponseWriter by embedding it. Only the ResponseWriter
// interface methods are promoted, so optional methods of the underlying
// writer (such as flushing) are not visible on wrapWriter. They are reachable
// only through Unwrap.
type wrapWriter struct {
	ResponseWriter
}

// Unwrap returns the embedded ResponseWriter.
func (ww wrapWriter) Unwrap() ResponseWriter { return ww.ResponseWriter }
258
259func TestWrappedResponseController(t *testing.T) { run(t, testWrappedResponseController) }
260func testWrappedResponseController(t *testing.T, mode testMode) {
261	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
262		w = wrapWriter{w}
263		ctl := NewResponseController(w)
264		if err := ctl.Flush(); err != nil {
265			t.Errorf("ctl.Flush() = %v, want nil", err)
266		}
267		if err := ctl.SetReadDeadline(time.Time{}); err != nil {
268			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
269		}
270		if err := ctl.SetWriteDeadline(time.Time{}); err != nil {
271			t.Errorf("ctl.SetWriteDeadline() = %v, want nil", err)
272		}
273	}))
274	res, err := cst.c.Get(cst.ts.URL)
275	if err != nil {
276		t.Fatalf("unexpected connection error: %v", err)
277	}
278	io.Copy(io.Discard, res.Body)
279	defer res.Body.Close()
280}
281
282func TestResponseControllerEnableFullDuplex(t *testing.T) {
283	run(t, testResponseControllerEnableFullDuplex)
284}
285func testResponseControllerEnableFullDuplex(t *testing.T, mode testMode) {
286	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
287		ctl := NewResponseController(w)
288		if err := ctl.EnableFullDuplex(); err != nil {
289			// TODO: Drop test for HTTP/2 when x/net is updated to support
290			// EnableFullDuplex. Since HTTP/2 supports full duplex by default,
291			// the rest of the test is fine; it's just the EnableFullDuplex call
292			// that fails.
293			if mode != http2Mode {
294				t.Errorf("ctl.EnableFullDuplex() = %v, want nil", err)
295			}
296		}
297		w.WriteHeader(200)
298		ctl.Flush()
299		for {
300			var buf [1]byte
301			n, err := req.Body.Read(buf[:])
302			if n != 1 || err != nil {
303				break
304			}
305			w.Write(buf[:])
306			ctl.Flush()
307		}
308	}))
309	pr, pw := io.Pipe()
310	res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
311	if err != nil {
312		t.Fatal(err)
313	}
314	defer res.Body.Close()
315	for i := byte(0); i < 10; i++ {
316		if _, err := pw.Write([]byte{i}); err != nil {
317			t.Fatalf("Write: %v", err)
318		}
319		var buf [1]byte
320		if n, err := res.Body.Read(buf[:]); n != 1 || err != nil {
321			t.Fatalf("Read: %v, %v", n, err)
322		}
323		if buf[0] != i {
324			t.Fatalf("read byte %v, want %v", buf[0], i)
325		}
326	}
327	pw.Close()
328}
329
330func TestIssue58237(t *testing.T) {
331	cst := newClientServerTest(t, http2Mode, HandlerFunc(func(w ResponseWriter, req *Request) {
332		ctl := NewResponseController(w)
333		if err := ctl.SetReadDeadline(time.Now().Add(1 * time.Millisecond)); err != nil {
334			t.Errorf("ctl.SetReadDeadline() = %v, want nil", err)
335		}
336		time.Sleep(10 * time.Millisecond)
337	}))
338	res, err := cst.c.Get(cst.ts.URL)
339	if err != nil {
340		t.Fatal(err)
341	}
342	res.Body.Close()
343}
344