1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package os_test
6
7import (
8	"bytes"
9	"errors"
10	"internal/poll"
11	"internal/testpty"
12	"io"
13	"math/rand"
14	"net"
15	. "os"
16	"path/filepath"
17	"runtime"
18	"strconv"
19	"strings"
20	"sync"
21	"syscall"
22	"testing"
23	"time"
24
25	"golang.org/x/net/nettest"
26)
27
28func TestCopyFileRange(t *testing.T) {
29	sizes := []int{
30		1,
31		42,
32		1025,
33		syscall.Getpagesize() + 1,
34		32769,
35	}
36	t.Run("Basic", func(t *testing.T) {
37		for _, size := range sizes {
38			t.Run(strconv.Itoa(size), func(t *testing.T) {
39				testCopyFileRange(t, int64(size), -1)
40			})
41		}
42	})
43	t.Run("Limited", func(t *testing.T) {
44		t.Run("OneLess", func(t *testing.T) {
45			for _, size := range sizes {
46				t.Run(strconv.Itoa(size), func(t *testing.T) {
47					testCopyFileRange(t, int64(size), int64(size)-1)
48				})
49			}
50		})
51		t.Run("Half", func(t *testing.T) {
52			for _, size := range sizes {
53				t.Run(strconv.Itoa(size), func(t *testing.T) {
54					testCopyFileRange(t, int64(size), int64(size)/2)
55				})
56			}
57		})
58		t.Run("More", func(t *testing.T) {
59			for _, size := range sizes {
60				t.Run(strconv.Itoa(size), func(t *testing.T) {
61					testCopyFileRange(t, int64(size), int64(size)+7)
62				})
63			}
64		})
65	})
66	t.Run("DoesntTryInAppendMode", func(t *testing.T) {
67		dst, src, data, hook := newCopyFileRangeTest(t, 42)
68
69		dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
70		if err != nil {
71			t.Fatal(err)
72		}
73		defer dst2.Close()
74
75		if _, err := io.Copy(dst2, src); err != nil {
76			t.Fatal(err)
77		}
78		if hook.called {
79			t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
80		}
81		mustSeekStart(t, dst2)
82		mustContainData(t, dst2, data) // through traditional means
83	})
84	t.Run("CopyFileItself", func(t *testing.T) {
85		hook := hookCopyFileRange(t)
86
87		f, err := CreateTemp("", "file-readfrom-itself-test")
88		if err != nil {
89			t.Fatalf("failed to create tmp file: %v", err)
90		}
91		t.Cleanup(func() {
92			f.Close()
93			Remove(f.Name())
94		})
95
96		data := []byte("hello world!")
97		if _, err := f.Write(data); err != nil {
98			t.Fatalf("failed to create and feed the file: %v", err)
99		}
100
101		if err := f.Sync(); err != nil {
102			t.Fatalf("failed to save the file: %v", err)
103		}
104
105		// Rewind it.
106		if _, err := f.Seek(0, io.SeekStart); err != nil {
107			t.Fatalf("failed to rewind the file: %v", err)
108		}
109
110		// Read data from the file itself.
111		if _, err := io.Copy(f, f); err != nil {
112			t.Fatalf("failed to read from the file: %v", err)
113		}
114
115		if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
116			t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err)
117		}
118
119		// Rewind it.
120		if _, err := f.Seek(0, io.SeekStart); err != nil {
121			t.Fatalf("failed to rewind the file: %v", err)
122		}
123
124		data2, err := io.ReadAll(f)
125		if err != nil {
126			t.Fatalf("failed to read from the file: %v", err)
127		}
128
129		// It should wind up a double of the original data.
130		if strings.Repeat(string(data), 2) != string(data2) {
131			t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
132		}
133	})
134	t.Run("NotRegular", func(t *testing.T) {
135		t.Run("BothPipes", func(t *testing.T) {
136			hook := hookCopyFileRange(t)
137
138			pr1, pw1, err := Pipe()
139			if err != nil {
140				t.Fatal(err)
141			}
142			defer pr1.Close()
143			defer pw1.Close()
144
145			pr2, pw2, err := Pipe()
146			if err != nil {
147				t.Fatal(err)
148			}
149			defer pr2.Close()
150			defer pw2.Close()
151
152			// The pipe is empty, and PIPE_BUF is large enough
153			// for this, by (POSIX) definition, so there is no
154			// need for an additional goroutine.
155			data := []byte("hello")
156			if _, err := pw1.Write(data); err != nil {
157				t.Fatal(err)
158			}
159			pw1.Close()
160
161			n, err := io.Copy(pw2, pr1)
162			if err != nil {
163				t.Fatal(err)
164			}
165			if n != int64(len(data)) {
166				t.Fatalf("transferred %d, want %d", n, len(data))
167			}
168			if !hook.called {
169				t.Fatalf("should have called poll.CopyFileRange")
170			}
171			pw2.Close()
172			mustContainData(t, pr2, data)
173		})
174		t.Run("DstPipe", func(t *testing.T) {
175			dst, src, data, hook := newCopyFileRangeTest(t, 255)
176			dst.Close()
177
178			pr, pw, err := Pipe()
179			if err != nil {
180				t.Fatal(err)
181			}
182			defer pr.Close()
183			defer pw.Close()
184
185			n, err := io.Copy(pw, src)
186			if err != nil {
187				t.Fatal(err)
188			}
189			if n != int64(len(data)) {
190				t.Fatalf("transferred %d, want %d", n, len(data))
191			}
192			if !hook.called {
193				t.Fatalf("should have called poll.CopyFileRange")
194			}
195			pw.Close()
196			mustContainData(t, pr, data)
197		})
198		t.Run("SrcPipe", func(t *testing.T) {
199			dst, src, data, hook := newCopyFileRangeTest(t, 255)
200			src.Close()
201
202			pr, pw, err := Pipe()
203			if err != nil {
204				t.Fatal(err)
205			}
206			defer pr.Close()
207			defer pw.Close()
208
209			// The pipe is empty, and PIPE_BUF is large enough
210			// for this, by (POSIX) definition, so there is no
211			// need for an additional goroutine.
212			if _, err := pw.Write(data); err != nil {
213				t.Fatal(err)
214			}
215			pw.Close()
216
217			n, err := io.Copy(dst, pr)
218			if err != nil {
219				t.Fatal(err)
220			}
221			if n != int64(len(data)) {
222				t.Fatalf("transferred %d, want %d", n, len(data))
223			}
224			if !hook.called {
225				t.Fatalf("should have called poll.CopyFileRange")
226			}
227			mustSeekStart(t, dst)
228			mustContainData(t, dst, data)
229		})
230	})
231	t.Run("Nil", func(t *testing.T) {
232		var nilFile *File
233		anyFile, err := CreateTemp("", "")
234		if err != nil {
235			t.Fatal(err)
236		}
237		defer Remove(anyFile.Name())
238		defer anyFile.Close()
239
240		if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid {
241			t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid)
242		}
243		if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid {
244			t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid)
245		}
246		if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid {
247			t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid)
248		}
249
250		if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid {
251			t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
252		}
253		if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid {
254			t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid)
255		}
256		if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid {
257			t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid)
258		}
259	})
260}
261
262func TestSpliceFile(t *testing.T) {
263	sizes := []int{
264		1,
265		42,
266		1025,
267		syscall.Getpagesize() + 1,
268		32769,
269	}
270	t.Run("Basic-TCP", func(t *testing.T) {
271		for _, size := range sizes {
272			t.Run(strconv.Itoa(size), func(t *testing.T) {
273				testSpliceFile(t, "tcp", int64(size), -1)
274			})
275		}
276	})
277	t.Run("Basic-Unix", func(t *testing.T) {
278		for _, size := range sizes {
279			t.Run(strconv.Itoa(size), func(t *testing.T) {
280				testSpliceFile(t, "unix", int64(size), -1)
281			})
282		}
283	})
284	t.Run("TCP-To-TTY", func(t *testing.T) {
285		testSpliceToTTY(t, "tcp", 32768)
286	})
287	t.Run("Unix-To-TTY", func(t *testing.T) {
288		testSpliceToTTY(t, "unix", 32768)
289	})
290	t.Run("Limited", func(t *testing.T) {
291		t.Run("OneLess-TCP", func(t *testing.T) {
292			for _, size := range sizes {
293				t.Run(strconv.Itoa(size), func(t *testing.T) {
294					testSpliceFile(t, "tcp", int64(size), int64(size)-1)
295				})
296			}
297		})
298		t.Run("OneLess-Unix", func(t *testing.T) {
299			for _, size := range sizes {
300				t.Run(strconv.Itoa(size), func(t *testing.T) {
301					testSpliceFile(t, "unix", int64(size), int64(size)-1)
302				})
303			}
304		})
305		t.Run("Half-TCP", func(t *testing.T) {
306			for _, size := range sizes {
307				t.Run(strconv.Itoa(size), func(t *testing.T) {
308					testSpliceFile(t, "tcp", int64(size), int64(size)/2)
309				})
310			}
311		})
312		t.Run("Half-Unix", func(t *testing.T) {
313			for _, size := range sizes {
314				t.Run(strconv.Itoa(size), func(t *testing.T) {
315					testSpliceFile(t, "unix", int64(size), int64(size)/2)
316				})
317			}
318		})
319		t.Run("More-TCP", func(t *testing.T) {
320			for _, size := range sizes {
321				t.Run(strconv.Itoa(size), func(t *testing.T) {
322					testSpliceFile(t, "tcp", int64(size), int64(size)+1)
323				})
324			}
325		})
326		t.Run("More-Unix", func(t *testing.T) {
327			for _, size := range sizes {
328				t.Run(strconv.Itoa(size), func(t *testing.T) {
329					testSpliceFile(t, "unix", int64(size), int64(size)+1)
330				})
331			}
332		})
333	})
334}
335
336func testSpliceFile(t *testing.T, proto string, size, limit int64) {
337	dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
338	defer cleanup()
339
340	// If we have a limit, wrap the reader.
341	var (
342		r  io.Reader
343		lr *io.LimitedReader
344	)
345	if limit >= 0 {
346		lr = &io.LimitedReader{N: limit, R: src}
347		r = lr
348		if limit < int64(len(data)) {
349			data = data[:limit]
350		}
351	} else {
352		r = src
353	}
354	// Now call ReadFrom (through io.Copy), which will hopefully call poll.Splice
355	n, err := io.Copy(dst, r)
356	if err != nil {
357		t.Fatal(err)
358	}
359
360	// We should have called poll.Splice with the right file descriptor arguments.
361	if n > 0 && !hook.called {
362		t.Fatal("expected to called poll.Splice")
363	}
364	if hook.called && hook.dstfd != int(dst.Fd()) {
365		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
366	}
367	sc, ok := src.(syscall.Conn)
368	if !ok {
369		t.Fatalf("server Conn is not a syscall.Conn")
370	}
371	rc, err := sc.SyscallConn()
372	if err != nil {
373		t.Fatalf("server Conn SyscallConn error: %v", err)
374	}
375	if err = rc.Control(func(fd uintptr) {
376		if hook.called && hook.srcfd != int(fd) {
377			t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
378		}
379	}); err != nil {
380		t.Fatalf("server Conn Control error: %v", err)
381	}
382
383	// Check that the offsets after the transfer make sense, that the size
384	// of the transfer was reported correctly, and that the destination
385	// file contains exactly the bytes we expect it to contain.
386	dstoff, err := dst.Seek(0, io.SeekCurrent)
387	if err != nil {
388		t.Fatal(err)
389	}
390	if dstoff != int64(len(data)) {
391		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
392	}
393	if n != int64(len(data)) {
394		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
395	}
396	mustSeekStart(t, dst)
397	mustContainData(t, dst, data)
398
399	// If we had a limit, check that it was updated.
400	if lr != nil {
401		if want := limit - n; lr.N != want {
402			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
403		}
404	}
405}
406
407// Issue #59041.
408func testSpliceToTTY(t *testing.T, proto string, size int64) {
409	var wg sync.WaitGroup
410
411	// Call wg.Wait as the final deferred function,
412	// because the goroutines may block until some of
413	// the deferred Close calls.
414	defer wg.Wait()
415
416	pty, ttyName, err := testpty.Open()
417	if err != nil {
418		t.Skipf("skipping test because pty open failed: %v", err)
419	}
420	defer pty.Close()
421
422	// Open the tty directly, rather than via OpenFile.
423	// This bypasses the non-blocking support and is required
424	// to recreate the problem in the issue (#59041).
425	ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0)
426	if err != nil {
427		t.Skipf("skipping test because failed to open tty: %v", err)
428	}
429	defer syscall.Close(ttyFD)
430
431	tty := NewFile(uintptr(ttyFD), "tty")
432	defer tty.Close()
433
434	client, server := createSocketPair(t, proto)
435
436	data := bytes.Repeat([]byte{'a'}, int(size))
437
438	wg.Add(1)
439	go func() {
440		defer wg.Done()
441		// The problem (issue #59041) occurs when writing
442		// a series of blocks of data. It does not occur
443		// when all the data is written at once.
444		for i := 0; i < len(data); i += 1024 {
445			if _, err := client.Write(data[i : i+1024]); err != nil {
446				// If we get here because the client was
447				// closed, skip the error.
448				if !errors.Is(err, net.ErrClosed) {
449					t.Errorf("error writing to socket: %v", err)
450				}
451				return
452			}
453		}
454		client.Close()
455	}()
456
457	wg.Add(1)
458	go func() {
459		defer wg.Done()
460		buf := make([]byte, 32)
461		for {
462			if _, err := pty.Read(buf); err != nil {
463				if err != io.EOF && !errors.Is(err, ErrClosed) {
464					// An error here doesn't matter for
465					// our test.
466					t.Logf("error reading from pty: %v", err)
467				}
468				return
469			}
470		}
471	}()
472
473	// Close Client to wake up the writing goroutine if necessary.
474	defer client.Close()
475
476	_, err = io.Copy(tty, server)
477	if err != nil {
478		t.Fatal(err)
479	}
480}
481
482func testCopyFileRange(t *testing.T, size int64, limit int64) {
483	dst, src, data, hook := newCopyFileRangeTest(t, size)
484
485	// If we have a limit, wrap the reader.
486	var (
487		realsrc io.Reader
488		lr      *io.LimitedReader
489	)
490	if limit >= 0 {
491		lr = &io.LimitedReader{N: limit, R: src}
492		realsrc = lr
493		if limit < int64(len(data)) {
494			data = data[:limit]
495		}
496	} else {
497		realsrc = src
498	}
499
500	// Now call ReadFrom (through io.Copy), which will hopefully call
501	// poll.CopyFileRange.
502	n, err := io.Copy(dst, realsrc)
503	if err != nil {
504		t.Fatal(err)
505	}
506
507	// If we didn't have a limit, we should have called poll.CopyFileRange
508	// with the right file descriptor arguments.
509	if limit > 0 && !hook.called {
510		t.Fatal("never called poll.CopyFileRange")
511	}
512	if hook.called && hook.dstfd != int(dst.Fd()) {
513		t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
514	}
515	if hook.called && hook.srcfd != int(src.Fd()) {
516		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
517	}
518
519	// Check that the offsets after the transfer make sense, that the size
520	// of the transfer was reported correctly, and that the destination
521	// file contains exactly the bytes we expect it to contain.
522	dstoff, err := dst.Seek(0, io.SeekCurrent)
523	if err != nil {
524		t.Fatal(err)
525	}
526	srcoff, err := src.Seek(0, io.SeekCurrent)
527	if err != nil {
528		t.Fatal(err)
529	}
530	if dstoff != srcoff {
531		t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
532	}
533	if dstoff != int64(len(data)) {
534		t.Errorf("dstoff = %d, want %d", dstoff, len(data))
535	}
536	if n != int64(len(data)) {
537		t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
538	}
539	mustSeekStart(t, dst)
540	mustContainData(t, dst, data)
541
542	// If we had a limit, check that it was updated.
543	if lr != nil {
544		if want := limit - n; lr.N != want {
545			t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
546		}
547	}
548}
549
550// newCopyFileRangeTest initializes a new test for copy_file_range.
551//
552// It creates source and destination files, and populates the source file
553// with random data of the specified size. It also hooks package os' call
554// to poll.CopyFileRange and returns the hook so it can be inspected.
555func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
556	t.Helper()
557
558	hook = hookCopyFileRange(t)
559	tmp := t.TempDir()
560
561	src, err := Create(filepath.Join(tmp, "src"))
562	if err != nil {
563		t.Fatal(err)
564	}
565	t.Cleanup(func() { src.Close() })
566
567	dst, err = Create(filepath.Join(tmp, "dst"))
568	if err != nil {
569		t.Fatal(err)
570	}
571	t.Cleanup(func() { dst.Close() })
572
573	// Populate the source file with data, then rewind it, so it can be
574	// consumed by copy_file_range(2).
575	prng := rand.New(rand.NewSource(time.Now().Unix()))
576	data = make([]byte, size)
577	prng.Read(data)
578	if _, err := src.Write(data); err != nil {
579		t.Fatal(err)
580	}
581	if _, err := src.Seek(0, io.SeekStart); err != nil {
582		t.Fatal(err)
583	}
584
585	return dst, src, data, hook
586}
587
588// newSpliceFileTest initializes a new test for splice.
589//
590// It creates source sockets and destination file, and populates the source sockets
591// with random data of the specified size. It also hooks package os' call
592// to poll.Splice and returns the hook so it can be inspected.
593func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
594	t.Helper()
595
596	hook := hookSpliceFile(t)
597
598	client, server := createSocketPair(t, proto)
599
600	dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
601	if err != nil {
602		t.Fatal(err)
603	}
604	t.Cleanup(func() { dst.Close() })
605
606	randSeed := time.Now().Unix()
607	t.Logf("random data seed: %d\n", randSeed)
608	prng := rand.New(rand.NewSource(randSeed))
609	data := make([]byte, size)
610	prng.Read(data)
611
612	done := make(chan struct{})
613	go func() {
614		client.Write(data)
615		client.Close()
616		close(done)
617	}()
618
619	return dst, server, data, hook, func() { <-done }
620}
621
// mustContainData reads f from its current offset. It fails the test unless
// the next len(data) bytes equal data and a further Read reports io.EOF.
func mustContainData(t *testing.T, f *File, data []byte) {
	t.Helper()

	buf := make([]byte, len(data))
	if _, err := io.ReadFull(f, buf); err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(buf, data) {
		t.Fatalf("didn't get the same data back from %s", f.Name())
	}

	// Anything beyond data means the file holds more than expected.
	var probe [1]byte
	if _, err := f.Read(probe[:]); err != io.EOF {
		t.Fatalf("not at EOF")
	}
}
638
// mustSeekStart rewinds f to offset 0 and fails the test if the seek fails.
// It is marked as a helper so failures are reported at the caller's line,
// as in mustContainData.
func mustSeekStart(t *testing.T, f *File) {
	t.Helper()

	if _, err := f.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
}
644
645func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
646	h := new(copyFileRangeHook)
647	h.install()
648	t.Cleanup(h.uninstall)
649	return h
650}
651
652type copyFileRangeHook struct {
653	called bool
654	dstfd  int
655	srcfd  int
656	remain int64
657
658	written int64
659	handled bool
660	err     error
661
662	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
663}
664
665func (h *copyFileRangeHook) install() {
666	h.original = *PollCopyFileRangeP
667	*PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
668		h.called = true
669		h.dstfd = dst.Sysfd
670		h.srcfd = src.Sysfd
671		h.remain = remain
672		h.written, h.handled, h.err = h.original(dst, src, remain)
673		return h.written, h.handled, h.err
674	}
675}
676
677func (h *copyFileRangeHook) uninstall() {
678	*PollCopyFileRangeP = h.original
679}
680
681func hookSpliceFile(t *testing.T) *spliceFileHook {
682	h := new(spliceFileHook)
683	h.install()
684	t.Cleanup(h.uninstall)
685	return h
686}
687
688type spliceFileHook struct {
689	called bool
690	dstfd  int
691	srcfd  int
692	remain int64
693
694	written int64
695	handled bool
696	err     error
697
698	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
699}
700
701func (h *spliceFileHook) install() {
702	h.original = *PollSpliceFile
703	*PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
704		h.called = true
705		h.dstfd = dst.Sysfd
706		h.srcfd = src.Sysfd
707		h.remain = remain
708		h.written, h.handled, h.err = h.original(dst, src, remain)
709		return h.written, h.handled, h.err
710	}
711}
712
713func (h *spliceFileHook) uninstall() {
714	*PollSpliceFile = h.original
715}
716
// TestProcCopy copies /proc/self/cmdline into a regular file with io.Copy.
// It checks the copy matches what ReadFile returns for the same path. On
// some kernels copy_file_range fails on files in /proc. The test is skipped
// when the /proc file cannot be read.
func TestProcCopy(t *testing.T) {
	t.Parallel()

	const cmdlineFile = "/proc/self/cmdline"
	cmdline, err := ReadFile(cmdlineFile)
	if err != nil {
		t.Skipf("can't read /proc file: %v", err)
	}
	in, err := Open(cmdlineFile)
	if err != nil {
		t.Fatal(err)
	}
	defer in.Close()
	outFile := filepath.Join(t.TempDir(), "cmdline")
	out, err := Create(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if _, err := io.Copy(out, in); err != nil {
		// Close out before failing so its descriptor is not leaked.
		out.Close()
		t.Fatal(err)
	}
	// Close explicitly, checking the error, so any deferred write failure
	// is reported before the file is read back.
	if err := out.Close(); err != nil {
		t.Fatal(err)
	}
	got, err := ReadFile(outFile)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(cmdline, got) {
		t.Errorf("copy of %q got %q want %q\n", cmdlineFile, got, cmdline)
	}
}
750
751func TestGetPollFDAndNetwork(t *testing.T) {
752	t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
753	t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
754}
755
756func testGetPollFDAndNetwork(t *testing.T, proto string) {
757	_, server := createSocketPair(t, proto)
758	sc, ok := server.(syscall.Conn)
759	if !ok {
760		t.Fatalf("server Conn is not a syscall.Conn")
761	}
762	rc, err := sc.SyscallConn()
763	if err != nil {
764		t.Fatalf("server SyscallConn error: %v", err)
765	}
766	if err = rc.Control(func(fd uintptr) {
767		pfd, network := GetPollFDAndNetwork(server)
768		if pfd == nil {
769			t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
770		}
771		if string(network) != proto {
772			t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
773		}
774		if pfd.Sysfd != int(fd) {
775			t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
776		}
777		if !pfd.IsStream {
778			t.Fatalf("expected IsStream to be true")
779		}
780		if err = pfd.Init(proto, true); err == nil {
781			t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
782		}
783	}); err != nil {
784		t.Fatalf("server Control error: %v", err)
785	}
786}
787
788func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
789	t.Helper()
790	if !nettest.TestableNetwork(proto) {
791		t.Skipf("%s does not support %q", runtime.GOOS, proto)
792	}
793
794	ln, err := nettest.NewLocalListener(proto)
795	if err != nil {
796		t.Fatalf("NewLocalListener error: %v", err)
797	}
798	t.Cleanup(func() {
799		if ln != nil {
800			ln.Close()
801		}
802		if client != nil {
803			client.Close()
804		}
805		if server != nil {
806			server.Close()
807		}
808	})
809	ch := make(chan struct{})
810	go func() {
811		var err error
812		server, err = ln.Accept()
813		if err != nil {
814			t.Errorf("Accept new connection error: %v", err)
815		}
816		ch <- struct{}{}
817	}()
818	client, err = net.Dial(proto, ln.Addr().String())
819	<-ch
820	if err != nil {
821		t.Fatalf("Dial new connection error: %v", err)
822	}
823	return client, server
824}
825