1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package io_test
6
7import (
8	"bytes"
9	"fmt"
10	. "io"
11	"slices"
12	"strings"
13	"testing"
14	"time"
15)
16
// checkWrite writes data to w and reports a test error if the write fails
// or is short. It always sends 0 on c once the write has returned, so the
// caller can wait for completion.
func checkWrite(t *testing.T, w Writer, data []byte, c chan int) {
	written, werr := w.Write(data)
	if werr != nil {
		t.Errorf("write: %v", werr)
	}
	if want := len(data); written != want {
		t.Errorf("short write: %d != %d", written, want)
	}
	c <- 0
}
27
28// Test a single read/write pair.
29func TestPipe1(t *testing.T) {
30	c := make(chan int)
31	r, w := Pipe()
32	var buf = make([]byte, 64)
33	go checkWrite(t, w, []byte("hello, world"), c)
34	n, err := r.Read(buf)
35	if err != nil {
36		t.Errorf("read: %v", err)
37	} else if n != 12 || string(buf[0:12]) != "hello, world" {
38		t.Errorf("bad read: got %q", buf[0:n])
39	}
40	<-c
41	r.Close()
42	w.Close()
43}
44
// reader repeatedly reads from r into a 64-byte buffer and sends the size of
// each read on c. When r reports EOF it sends a final 0 and returns. Any
// other read error is reported through t, after which the byte count is
// still sent and reading continues.
func reader(t *testing.T, r Reader, c chan int) {
	scratch := make([]byte, 64)
	for {
		got, rerr := r.Read(scratch)
		switch {
		case rerr == EOF:
			c <- 0
			return
		case rerr != nil:
			t.Errorf("read: %v", rerr)
		}
		c <- got
	}
}
59
60// Test a sequence of read/write pairs.
61func TestPipe2(t *testing.T) {
62	c := make(chan int)
63	r, w := Pipe()
64	go reader(t, r, c)
65	var buf = make([]byte, 64)
66	for i := 0; i < 5; i++ {
67		p := buf[0 : 5+i*10]
68		n, err := w.Write(p)
69		if n != len(p) {
70			t.Errorf("wrote %d, got %d", len(p), n)
71		}
72		if err != nil {
73			t.Errorf("write: %v", err)
74		}
75		nn := <-c
76		if nn != n {
77			t.Errorf("wrote %d, read got %d", n, nn)
78		}
79	}
80	w.Close()
81	nn := <-c
82	if nn != 0 {
83		t.Errorf("final read got %d", nn)
84	}
85}
86
// pipeReturn carries the result of a Write call (byte count and error)
// from the writer helper back to the test that started it.
type pipeReturn struct {
	n   int   // byte count returned by Write
	err error // error returned by Write
}
91
92// Test a large write that requires multiple reads to satisfy.
93func writer(w WriteCloser, buf []byte, c chan pipeReturn) {
94	n, err := w.Write(buf)
95	w.Close()
96	c <- pipeReturn{n, err}
97}
98
99func TestPipe3(t *testing.T) {
100	c := make(chan pipeReturn)
101	r, w := Pipe()
102	var wdat = make([]byte, 128)
103	for i := 0; i < len(wdat); i++ {
104		wdat[i] = byte(i)
105	}
106	go writer(w, wdat, c)
107	var rdat = make([]byte, 1024)
108	tot := 0
109	for n := 1; n <= 256; n *= 2 {
110		nn, err := r.Read(rdat[tot : tot+n])
111		if err != nil && err != EOF {
112			t.Fatalf("read: %v", err)
113		}
114
115		// only final two reads should be short - 1 byte, then 0
116		expect := n
117		if n == 128 {
118			expect = 1
119		} else if n == 256 {
120			expect = 0
121			if err != EOF {
122				t.Fatalf("read at end: %v", err)
123			}
124		}
125		if nn != expect {
126			t.Fatalf("read %d, expected %d, got %d", n, expect, nn)
127		}
128		tot += nn
129	}
130	pr := <-c
131	if pr.n != 128 || pr.err != nil {
132		t.Fatalf("write 128: %d, %v", pr.n, pr.err)
133	}
134	if tot != 128 {
135		t.Fatalf("total read %d != 128", tot)
136	}
137	for i := 0; i < 128; i++ {
138		if rdat[i] != byte(i) {
139			t.Fatalf("rdat[%d] = %d", i, rdat[i])
140		}
141	}
142}
143
144// Test read after/before writer close.
145
// closer is the pair of close methods that delayClose needs. Both the read
// and the write end of a pipe are passed to delayClose as a closer.
type closer interface {
	CloseWithError(error) error
	Close() error
}
150
// pipeTest describes one close scenario used by TestPipeReadClose and
// TestPipeWriteClose.
type pipeTest struct {
	async          bool  // run delayClose in its own goroutine instead of inline
	err            error // error passed to CloseWithError; when non-nil, also the error the test expects back
	closeWithError bool  // close via CloseWithError(err) rather than Close
}
156
157func (p pipeTest) String() string {
158	return fmt.Sprintf("async=%v err=%v closeWithError=%v", p.async, p.err, p.closeWithError)
159}
160
161var pipeTests = []pipeTest{
162	{true, nil, false},
163	{true, nil, true},
164	{true, ErrShortWrite, true},
165	{false, nil, false},
166	{false, nil, true},
167	{false, ErrShortWrite, true},
168}
169
170func delayClose(t *testing.T, cl closer, ch chan int, tt pipeTest) {
171	time.Sleep(1 * time.Millisecond)
172	var err error
173	if tt.closeWithError {
174		err = cl.CloseWithError(tt.err)
175	} else {
176		err = cl.Close()
177	}
178	if err != nil {
179		t.Errorf("delayClose: %v", err)
180	}
181	ch <- 0
182}
183
184func TestPipeReadClose(t *testing.T) {
185	for _, tt := range pipeTests {
186		c := make(chan int, 1)
187		r, w := Pipe()
188		if tt.async {
189			go delayClose(t, w, c, tt)
190		} else {
191			delayClose(t, w, c, tt)
192		}
193		var buf = make([]byte, 64)
194		n, err := r.Read(buf)
195		<-c
196		want := tt.err
197		if want == nil {
198			want = EOF
199		}
200		if err != want {
201			t.Errorf("read from closed pipe: %v want %v", err, want)
202		}
203		if n != 0 {
204			t.Errorf("read on closed pipe returned %d", n)
205		}
206		if err = r.Close(); err != nil {
207			t.Errorf("r.Close: %v", err)
208		}
209	}
210}
211
212// Test close on Read side during Read.
213func TestPipeReadClose2(t *testing.T) {
214	c := make(chan int, 1)
215	r, _ := Pipe()
216	go delayClose(t, r, c, pipeTest{})
217	n, err := r.Read(make([]byte, 64))
218	<-c
219	if n != 0 || err != ErrClosedPipe {
220		t.Errorf("read from closed pipe: %v, %v want %v, %v", n, err, 0, ErrClosedPipe)
221	}
222}
223
224// Test write after/before reader close.
225
226func TestPipeWriteClose(t *testing.T) {
227	for _, tt := range pipeTests {
228		c := make(chan int, 1)
229		r, w := Pipe()
230		if tt.async {
231			go delayClose(t, r, c, tt)
232		} else {
233			delayClose(t, r, c, tt)
234		}
235		n, err := WriteString(w, "hello, world")
236		<-c
237		expect := tt.err
238		if expect == nil {
239			expect = ErrClosedPipe
240		}
241		if err != expect {
242			t.Errorf("write on closed pipe: %v want %v", err, expect)
243		}
244		if n != 0 {
245			t.Errorf("write on closed pipe returned %d", n)
246		}
247		if err = w.Close(); err != nil {
248			t.Errorf("w.Close: %v", err)
249		}
250	}
251}
252
253// Test close on Write side during Write.
254func TestPipeWriteClose2(t *testing.T) {
255	c := make(chan int, 1)
256	_, w := Pipe()
257	go delayClose(t, w, c, pipeTest{})
258	n, err := w.Write(make([]byte, 64))
259	<-c
260	if n != 0 || err != ErrClosedPipe {
261		t.Errorf("write to closed pipe: %v, %v want %v, %v", n, err, 0, ErrClosedPipe)
262	}
263}
264
// TestWriteEmpty checks that a zero-length (non-nil) Write does not wedge
// the pipe. The Write must succeed with 0 bytes, and a ReadFull on the
// other end must see no data followed by EOF once the writer closes.
func TestWriteEmpty(t *testing.T) {
	r, w := Pipe()
	defer r.Close()
	go func() {
		// The Write result is checked before Close, and the reader cannot
		// see EOF until Close, so this Errorf never outlives the test.
		if n, err := w.Write([]byte{}); n != 0 || err != nil {
			t.Errorf("Write() = (%d, %v); want (0, nil)", n, err)
		}
		w.Close()
	}()
	var b [2]byte
	if n, err := ReadFull(r, b[:]); n != 0 || err != EOF {
		t.Errorf("ReadFull() = (%d, %v); want (0, EOF)", n, err)
	}
}
275
// TestWriteNil checks that a Write of a nil slice does not wedge the pipe.
// The Write must succeed with 0 bytes, and a ReadFull on the other end must
// see no data followed by EOF once the writer closes.
func TestWriteNil(t *testing.T) {
	r, w := Pipe()
	defer r.Close()
	go func() {
		// The Write result is checked before Close, and the reader cannot
		// see EOF until Close, so this Errorf never outlives the test.
		if n, err := w.Write(nil); n != 0 || err != nil {
			t.Errorf("Write() = (%d, %v); want (0, nil)", n, err)
		}
		w.Close()
	}()
	var b [2]byte
	if n, err := ReadFull(r, b[:]); n != 0 || err != EOF {
		t.Errorf("ReadFull() = (%d, %v); want (0, EOF)", n, err)
	}
}
286
// TestWriteAfterWriterClose checks that data written before the writer is
// closed still reaches the reader, and that a Write issued after the close
// fails with ErrClosedPipe.
func TestWriteAfterWriterClose(t *testing.T) {
	r, w := Pipe()
	defer r.Close()

	var secondErr error // set by the goroutine; read only after <-finished
	finished := make(chan bool)
	go func() {
		if _, err := w.Write([]byte("hello")); err != nil {
			t.Errorf("got error: %q; expected none", err)
		}
		w.Close()
		_, secondErr = w.Write([]byte("world"))
		finished <- true
	}()

	buf := make([]byte, 100)
	n, err := ReadFull(r, buf)
	if !(err == nil || err == ErrUnexpectedEOF) {
		t.Fatalf("got: %q; want: %q", err, ErrUnexpectedEOF)
	}
	got := string(buf[:n])
	<-finished

	if got != "hello" {
		t.Errorf("got: %q; want: %q", got, "hello")
	}
	if secondErr != ErrClosedPipe {
		t.Errorf("got: %q; want: %q", secondErr, ErrClosedPipe)
	}
}
318
// TestPipeCloseError checks that the first error given to CloseWithError
// sticks: a second CloseWithError with a different error must not change
// what the opposite end of the pipe reports.
func TestPipeCloseError(t *testing.T) {
	type testError1 struct{ error }
	type testError2 struct{ error }
	var first error = testError1{}

	// Close the read side twice; writes must keep returning the first error.
	r, w := Pipe()
	expectWrite := func() {
		if _, err := w.Write(nil); err != first {
			t.Errorf("Write error: got %T, want testError1", err)
		}
	}
	r.CloseWithError(testError1{})
	expectWrite()
	r.CloseWithError(testError2{})
	expectWrite()

	// Close the write side twice; reads must keep returning the first error.
	r, w = Pipe()
	expectRead := func() {
		if _, err := r.Read(nil); err != first {
			t.Errorf("Read error: got %T, want testError1", err)
		}
	}
	w.CloseWithError(testError1{})
	expectRead()
	w.CloseWithError(testError2{})
	expectRead()
}
343
344func TestPipeConcurrent(t *testing.T) {
345	const (
346		input    = "0123456789abcdef"
347		count    = 8
348		readSize = 2
349	)
350
351	t.Run("Write", func(t *testing.T) {
352		r, w := Pipe()
353
354		for i := 0; i < count; i++ {
355			go func() {
356				time.Sleep(time.Millisecond) // Increase probability of race
357				if n, err := w.Write([]byte(input)); n != len(input) || err != nil {
358					t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input))
359				}
360			}()
361		}
362
363		buf := make([]byte, count*len(input))
364		for i := 0; i < len(buf); i += readSize {
365			if n, err := r.Read(buf[i : i+readSize]); n != readSize || err != nil {
366				t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize)
367			}
368		}
369
370		// Since each Write is fully gated, if multiple Read calls were needed,
371		// the contents of Write should still appear together in the output.
372		got := string(buf)
373		want := strings.Repeat(input, count)
374		if got != want {
375			t.Errorf("got: %q; want: %q", got, want)
376		}
377	})
378
379	t.Run("Read", func(t *testing.T) {
380		r, w := Pipe()
381
382		c := make(chan []byte, count*len(input)/readSize)
383		for i := 0; i < cap(c); i++ {
384			go func() {
385				time.Sleep(time.Millisecond) // Increase probability of race
386				buf := make([]byte, readSize)
387				if n, err := r.Read(buf); n != readSize || err != nil {
388					t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize)
389				}
390				c <- buf
391			}()
392		}
393
394		for i := 0; i < count; i++ {
395			if n, err := w.Write([]byte(input)); n != len(input) || err != nil {
396				t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input))
397			}
398		}
399
400		// Since each read is independent, the only guarantee about the output
401		// is that it is a permutation of the input in readSized groups.
402		got := make([]byte, 0, count*len(input))
403		for i := 0; i < cap(c); i++ {
404			got = append(got, (<-c)...)
405		}
406		got = sortBytesInGroups(got, readSize)
407		want := bytes.Repeat([]byte(input), count)
408		want = sortBytesInGroups(want, readSize)
409		if string(got) != string(want) {
410			t.Errorf("got: %q; want: %q", got, want)
411		}
412	})
413}
414
// sortBytesInGroups splits b into consecutive groups of n bytes, sorts the
// groups lexicographically, and returns them concatenated in a newly
// allocated slice; b itself is not modified. If len(b) is not a multiple of
// n, the final group holds the remaining bytes. An empty b yields an empty
// result. It panics if b is non-empty and n is not positive, since no
// grouping is possible in that case.
func sortBytesInGroups(b []byte, n int) []byte {
	if len(b) == 0 {
		return []byte{}
	}
	if n <= 0 {
		panic("sortBytesInGroups: group size must be positive")
	}
	groups := make([][]byte, 0, (len(b)+n-1)/n)
	for len(b) > 0 {
		size := n
		if size > len(b) {
			size = len(b) // short trailing group
		}
		groups = append(groups, b[:size])
		b = b[size:]
	}
	slices.SortFunc(groups, bytes.Compare)
	return bytes.Join(groups, nil)
}
424
// rSink and wSink receive the results of Pipe in TestPipeAllocations.
// Storing the results in package-level variables is meant to make both
// values escape to the heap, as that test's allocation budget assumes.
var (
	rSink *PipeReader
	wSink *PipeWriter
)
429
430func TestPipeAllocations(t *testing.T) {
431	numAllocs := testing.AllocsPerRun(10, func() {
432		rSink, wSink = Pipe()
433	})
434
435	// go.dev/cl/473535 claimed Pipe() should only do 2 allocations,
436	// plus the 2 escaping to heap for simulating real world usages.
437	expectedAllocs := 4
438	if int(numAllocs) > expectedAllocs {
439		t.Fatalf("too many allocations for io.Pipe() call: %f", numAllocs)
440	}
441}
442