1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"context"
9	"errors"
10	"fmt"
11	"internal/testenv"
12	"log"
13	"os"
14	"path/filepath"
15	"runtime"
16	"strconv"
17	"sync"
18	"testing"
19	"time"
20)
21
// testUnixAddr returns the path of a Unix-domain socket file named "sock"
// inside a freshly created temporary directory, so every call yields a
// distinct path. The directory and its contents are removed when t
// finishes; a removal failure is reported with t.Error.
func testUnixAddr(t testing.TB) string {
	// An empty pattern keeps the generated directory name as short as
	// possible: the full socket path has to fit in the sun_path field of
	// sockaddr_un, or the syscall that opens the socket cannot be made.
	dir, mkErr := os.MkdirTemp("", "")
	if mkErr != nil {
		t.Fatal(mkErr)
	}
	t.Cleanup(func() {
		if rmErr := os.RemoveAll(dir); rmErr != nil {
			t.Error(rmErr)
		}
	})
	sockPath := filepath.Join(dir, "sock")
	return sockPath
}
38
39func newLocalListener(t testing.TB, network string, lcOpt ...*ListenConfig) Listener {
40	var lc *ListenConfig
41	switch len(lcOpt) {
42	case 0:
43		lc = new(ListenConfig)
44	case 1:
45		lc = lcOpt[0]
46	default:
47		t.Helper()
48		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
49	}
50
51	listen := func(net, addr string) Listener {
52		ln, err := lc.Listen(context.Background(), net, addr)
53		if err != nil {
54			t.Helper()
55			t.Fatal(err)
56		}
57		return ln
58	}
59
60	switch network {
61	case "tcp":
62		if supportsIPv4() {
63			return listen("tcp4", "127.0.0.1:0")
64		}
65		if supportsIPv6() {
66			return listen("tcp6", "[::1]:0")
67		}
68	case "tcp4":
69		if supportsIPv4() {
70			return listen("tcp4", "127.0.0.1:0")
71		}
72	case "tcp6":
73		if supportsIPv6() {
74			return listen("tcp6", "[::1]:0")
75		}
76	case "unix", "unixpacket":
77		return listen(network, testUnixAddr(t))
78	}
79
80	t.Helper()
81	t.Fatalf("%s is not supported", network)
82	return nil
83}
84
85func newDualStackListener() (lns []*TCPListener, err error) {
86	var args = []struct {
87		network string
88		TCPAddr
89	}{
90		{"tcp4", TCPAddr{IP: IPv4(127, 0, 0, 1)}},
91		{"tcp6", TCPAddr{IP: IPv6loopback}},
92	}
93	for i := 0; i < 64; i++ {
94		var port int
95		var lns []*TCPListener
96		for _, arg := range args {
97			arg.TCPAddr.Port = port
98			ln, err := ListenTCP(arg.network, &arg.TCPAddr)
99			if err != nil {
100				continue
101			}
102			port = ln.Addr().(*TCPAddr).Port
103			lns = append(lns, ln)
104		}
105		if len(lns) != len(args) {
106			for _, ln := range lns {
107				ln.Close()
108			}
109			continue
110		}
111		return lns, nil
112	}
113	return nil, errors.New("no dualstack port available")
114}
115
// localServer runs a handler against one Listener in a background goroutine
// (see buildup). cl holds connections recorded by handlers such as
// transponder so that teardown can close them.
type localServer struct {
	lnmu sync.RWMutex // locked by teardown for the whole shutdown sequence
	Listener
	done chan bool // closed by buildup's goroutine once the handler returns
	cl   []Conn    // accepted connection list; closed by teardown
}
122
123func (ls *localServer) buildup(handler func(*localServer, Listener)) error {
124	go func() {
125		handler(ls, ls.Listener)
126		close(ls.done)
127	}()
128	return nil
129}
130
131func (ls *localServer) teardown() error {
132	ls.lnmu.Lock()
133	defer ls.lnmu.Unlock()
134	if ls.Listener != nil {
135		network := ls.Listener.Addr().Network()
136		address := ls.Listener.Addr().String()
137		ls.Listener.Close()
138		for _, c := range ls.cl {
139			if err := c.Close(); err != nil {
140				return err
141			}
142		}
143		<-ls.done
144		ls.Listener = nil
145		switch network {
146		case "unix", "unixpacket":
147			os.Remove(address)
148		}
149	}
150	return nil
151}
152
153func newLocalServer(t testing.TB, network string) *localServer {
154	t.Helper()
155	ln := newLocalListener(t, network)
156	return &localServer{Listener: ln, done: make(chan bool)}
157}
158
// streamListener pairs a Listener with the network and address it was
// opened on. done is closed by dualStackServer.buildup's goroutine once the
// handler serving this listener returns.
type streamListener struct {
	network, address string
	Listener
	done chan bool // signal that indicates server stopped
}
164
165func (sl *streamListener) newLocalServer() *localServer {
166	return &localServer{Listener: sl.Listener, done: make(chan bool)}
167}
168
// dualStackServer serves several stream listeners that share one port
// number (newDualStackServer builds a tcp4/tcp6 pair), plus the connections
// that handlers record in cs; teardown closes both.
type dualStackServer struct {
	lnmu sync.RWMutex // locked by teardownNetwork and teardown while they shut listeners down
	lns  []streamListener
	port string // port of lns[0] as returned by SplitHostPort in newDualStackServer

	cmu sync.RWMutex // locked by teardown while it closes and clears cs
	cs  []Conn // established connections at the passive open side
}
177
178func (dss *dualStackServer) buildup(handler func(*dualStackServer, Listener)) error {
179	for i := range dss.lns {
180		go func(i int) {
181			handler(dss, dss.lns[i].Listener)
182			close(dss.lns[i].done)
183		}(i)
184	}
185	return nil
186}
187
188func (dss *dualStackServer) teardownNetwork(network string) error {
189	dss.lnmu.Lock()
190	for i := range dss.lns {
191		if network == dss.lns[i].network && dss.lns[i].Listener != nil {
192			dss.lns[i].Listener.Close()
193			<-dss.lns[i].done
194			dss.lns[i].Listener = nil
195		}
196	}
197	dss.lnmu.Unlock()
198	return nil
199}
200
201func (dss *dualStackServer) teardown() error {
202	dss.lnmu.Lock()
203	for i := range dss.lns {
204		if dss.lns[i].Listener != nil {
205			dss.lns[i].Listener.Close()
206			<-dss.lns[i].done
207		}
208	}
209	dss.lns = dss.lns[:0]
210	dss.lnmu.Unlock()
211	dss.cmu.Lock()
212	for _, c := range dss.cs {
213		c.Close()
214	}
215	dss.cs = dss.cs[:0]
216	dss.cmu.Unlock()
217	return nil
218}
219
220func newDualStackServer() (*dualStackServer, error) {
221	lns, err := newDualStackListener()
222	if err != nil {
223		return nil, err
224	}
225	_, port, err := SplitHostPort(lns[0].Addr().String())
226	if err != nil {
227		lns[0].Close()
228		lns[1].Close()
229		return nil, err
230	}
231	return &dualStackServer{
232		lns: []streamListener{
233			{network: "tcp4", address: lns[0].Addr().String(), Listener: lns[0], done: make(chan bool)},
234			{network: "tcp6", address: lns[1].Addr().String(), Listener: lns[1], done: make(chan bool)},
235		},
236		port: port,
237	}, nil
238}
239
240func (ls *localServer) transponder(ln Listener, ch chan<- error) {
241	defer close(ch)
242
243	switch ln := ln.(type) {
244	case *TCPListener:
245		ln.SetDeadline(time.Now().Add(someTimeout))
246	case *UnixListener:
247		ln.SetDeadline(time.Now().Add(someTimeout))
248	}
249	c, err := ln.Accept()
250	if err != nil {
251		if perr := parseAcceptError(err); perr != nil {
252			ch <- perr
253		}
254		ch <- err
255		return
256	}
257	ls.cl = append(ls.cl, c)
258
259	network := ln.Addr().Network()
260	if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network {
261		ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
262		return
263	}
264	c.SetDeadline(time.Now().Add(someTimeout))
265	c.SetReadDeadline(time.Now().Add(someTimeout))
266	c.SetWriteDeadline(time.Now().Add(someTimeout))
267
268	b := make([]byte, 256)
269	n, err := c.Read(b)
270	if err != nil {
271		if perr := parseReadError(err); perr != nil {
272			ch <- perr
273		}
274		ch <- err
275		return
276	}
277	if _, err := c.Write(b[:n]); err != nil {
278		if perr := parseWriteError(err); perr != nil {
279			ch <- perr
280		}
281		ch <- err
282		return
283	}
284}
285
286func transceiver(c Conn, wb []byte, ch chan<- error) {
287	defer close(ch)
288
289	c.SetDeadline(time.Now().Add(someTimeout))
290	c.SetReadDeadline(time.Now().Add(someTimeout))
291	c.SetWriteDeadline(time.Now().Add(someTimeout))
292
293	n, err := c.Write(wb)
294	if err != nil {
295		if perr := parseWriteError(err); perr != nil {
296			ch <- perr
297		}
298		ch <- err
299		return
300	}
301	if n != len(wb) {
302		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
303	}
304	rb := make([]byte, len(wb))
305	n, err = c.Read(rb)
306	if err != nil {
307		if perr := parseReadError(err); perr != nil {
308			ch <- perr
309		}
310		ch <- err
311		return
312	}
313	if n != len(wb) {
314		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
315	}
316}
317
318func newLocalPacketListener(t testing.TB, network string, lcOpt ...*ListenConfig) PacketConn {
319	var lc *ListenConfig
320	switch len(lcOpt) {
321	case 0:
322		lc = new(ListenConfig)
323	case 1:
324		lc = lcOpt[0]
325	default:
326		t.Helper()
327		t.Fatal("too many ListenConfigs passed to newLocalListener: want 0 or 1")
328	}
329
330	listenPacket := func(net, addr string) PacketConn {
331		c, err := lc.ListenPacket(context.Background(), net, addr)
332		if err != nil {
333			t.Helper()
334			t.Fatal(err)
335		}
336		return c
337	}
338
339	t.Helper()
340	switch network {
341	case "udp":
342		if supportsIPv4() {
343			return listenPacket("udp4", "127.0.0.1:0")
344		}
345		if supportsIPv6() {
346			return listenPacket("udp6", "[::1]:0")
347		}
348	case "udp4":
349		if supportsIPv4() {
350			return listenPacket("udp4", "127.0.0.1:0")
351		}
352	case "udp6":
353		if supportsIPv6() {
354			return listenPacket("udp6", "[::1]:0")
355		}
356	case "unixgram":
357		return listenPacket(network, testUnixAddr(t))
358	}
359
360	t.Fatalf("%s is not supported", network)
361	return nil
362}
363
364func newDualStackPacketListener() (cs []*UDPConn, err error) {
365	var args = []struct {
366		network string
367		UDPAddr
368	}{
369		{"udp4", UDPAddr{IP: IPv4(127, 0, 0, 1)}},
370		{"udp6", UDPAddr{IP: IPv6loopback}},
371	}
372	for i := 0; i < 64; i++ {
373		var port int
374		var cs []*UDPConn
375		for _, arg := range args {
376			arg.UDPAddr.Port = port
377			c, err := ListenUDP(arg.network, &arg.UDPAddr)
378			if err != nil {
379				continue
380			}
381			port = c.LocalAddr().(*UDPAddr).Port
382			cs = append(cs, c)
383		}
384		if len(cs) != len(args) {
385			for _, c := range cs {
386				c.Close()
387			}
388			continue
389		}
390		return cs, nil
391	}
392	return nil, errors.New("no dualstack port available")
393}
394
// localPacketServer runs a handler against one PacketConn in a background
// goroutine (see buildup); teardown closes the PacketConn and waits for
// that goroutine to signal on done.
type localPacketServer struct {
	pcmu sync.RWMutex // locked by teardown while it shuts the server down
	PacketConn
	done chan bool // signal that indicates server stopped
}
400
401func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
402	go func() {
403		handler(ls, ls.PacketConn)
404		close(ls.done)
405	}()
406	return nil
407}
408
409func (ls *localPacketServer) teardown() error {
410	ls.pcmu.Lock()
411	if ls.PacketConn != nil {
412		network := ls.PacketConn.LocalAddr().Network()
413		address := ls.PacketConn.LocalAddr().String()
414		ls.PacketConn.Close()
415		<-ls.done
416		ls.PacketConn = nil
417		switch network {
418		case "unixgram":
419			os.Remove(address)
420		}
421	}
422	ls.pcmu.Unlock()
423	return nil
424}
425
426func newLocalPacketServer(t testing.TB, network string) *localPacketServer {
427	t.Helper()
428	c := newLocalPacketListener(t, network)
429	return &localPacketServer{PacketConn: c, done: make(chan bool)}
430}
431
// packetListener wraps a PacketConn so that newLocalServer can build a
// localPacketServer around it.
type packetListener struct {
	PacketConn
}
435
436func (pl *packetListener) newLocalServer() *localPacketServer {
437	return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}
438}
439
440func packetTransponder(c PacketConn, ch chan<- error) {
441	defer close(ch)
442
443	c.SetDeadline(time.Now().Add(someTimeout))
444	c.SetReadDeadline(time.Now().Add(someTimeout))
445	c.SetWriteDeadline(time.Now().Add(someTimeout))
446
447	b := make([]byte, 256)
448	n, peer, err := c.ReadFrom(b)
449	if err != nil {
450		if perr := parseReadError(err); perr != nil {
451			ch <- perr
452		}
453		ch <- err
454		return
455	}
456	if peer == nil { // for connected-mode sockets
457		switch c.LocalAddr().Network() {
458		case "udp":
459			peer, err = ResolveUDPAddr("udp", string(b[:n]))
460		case "unixgram":
461			peer, err = ResolveUnixAddr("unixgram", string(b[:n]))
462		}
463		if err != nil {
464			ch <- err
465			return
466		}
467	}
468	if _, err := c.WriteTo(b[:n], peer); err != nil {
469		if perr := parseWriteError(err); perr != nil {
470			ch <- perr
471		}
472		ch <- err
473		return
474	}
475}
476
477func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
478	defer close(ch)
479
480	c.SetDeadline(time.Now().Add(someTimeout))
481	c.SetReadDeadline(time.Now().Add(someTimeout))
482	c.SetWriteDeadline(time.Now().Add(someTimeout))
483
484	n, err := c.WriteTo(wb, dst)
485	if err != nil {
486		if perr := parseWriteError(err); perr != nil {
487			ch <- perr
488		}
489		ch <- err
490		return
491	}
492	if n != len(wb) {
493		ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
494	}
495	rb := make([]byte, len(wb))
496	n, _, err = c.ReadFrom(rb)
497	if err != nil {
498		if perr := parseReadError(err); perr != nil {
499			ch <- perr
500		}
501		ch <- err
502		return
503	}
504	if n != len(wb) {
505		ch <- fmt.Errorf("read %d; want %d", n, len(wb))
506	}
507}
508
509func spawnTestSocketPair(t testing.TB, net string) (client, server Conn) {
510	t.Helper()
511
512	ln := newLocalListener(t, net)
513	defer ln.Close()
514	var cerr, serr error
515	acceptDone := make(chan struct{})
516	go func() {
517		server, serr = ln.Accept()
518		acceptDone <- struct{}{}
519	}()
520	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
521	<-acceptDone
522	if cerr != nil {
523		if server != nil {
524			server.Close()
525		}
526		t.Fatal(cerr)
527	}
528	if serr != nil {
529		if client != nil {
530			client.Close()
531		}
532		t.Fatal(serr)
533	}
534	return client, server
535}
536
537func startTestSocketPeer(t testing.TB, conn Conn, op string, chunkSize, totalSize int) (func(t testing.TB), error) {
538	t.Helper()
539
540	if runtime.GOOS == "windows" {
541		// TODO(panjf2000): Windows has not yet implemented FileConn,
542		//		remove this when it's implemented in https://go.dev/issues/9503.
543		t.Fatalf("startTestSocketPeer is not supported on %s", runtime.GOOS)
544	}
545
546	f, err := conn.(interface{ File() (*os.File, error) }).File()
547	if err != nil {
548		return nil, err
549	}
550
551	cmd := testenv.Command(t, os.Args[0])
552	cmd.Env = []string{
553		"GO_NET_TEST_TRANSFER=1",
554		"GO_NET_TEST_TRANSFER_OP=" + op,
555		"GO_NET_TEST_TRANSFER_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
556		"GO_NET_TEST_TRANSFER_TOTAL_SIZE=" + strconv.Itoa(totalSize),
557		"TMPDIR=" + os.Getenv("TMPDIR"),
558	}
559	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
560	cmd.Stdout = os.Stdout
561	cmd.Stderr = os.Stderr
562
563	if err := cmd.Start(); err != nil {
564		return nil, err
565	}
566
567	cmdCh := make(chan error, 1)
568	go func() {
569		err := cmd.Wait()
570		conn.Close()
571		f.Close()
572		cmdCh <- err
573	}()
574
575	return func(tb testing.TB) {
576		err := <-cmdCh
577		if err != nil {
578			tb.Errorf("process exited with error: %v", err)
579		}
580	}, nil
581}
582
// init turns the test binary into the peer process started by
// startTestSocketPeer. It does nothing unless GO_NET_TEST_TRANSFER is set.
// When it is set, init:
//   - wraps file descriptor 3 as a Conn (startTestSocketPeer passes the
//     socket as cmd.ExtraFiles[0]; os/exec maps ExtraFiles[i] to fd 3+i),
//   - reads the chunk size and total size from
//     GO_NET_TEST_TRANSFER_CHUNK_SIZE and GO_NET_TEST_TRANSFER_TOTAL_SIZE,
//   - repeatedly calls conn.Read (op "r") or conn.Write (op "w") with a
//     chunk-sized buffer, shrinking the final chunk so that no more than
//     the total size is requested, and
//   - exits the process with status 0 when the total is reached or any
//     Read/Write error occurs, so the regular test main never runs.
//
// Setup failures (FileConn error, unparsable sizes, unknown op) exit via
// log.Fatal.
func init() {
	if os.Getenv("GO_NET_TEST_TRANSFER") == "" {
		return
	}
	// Deferred first so it runs last: the other deferred closes (f, and
	// conn for op "w") happen before the process exits.
	defer os.Exit(0)

	f := os.NewFile(uintptr(3), "splice-test-conn")
	defer f.Close()

	conn, err := FileConn(f)
	if err != nil {
		log.Fatal(err)
	}

	var chunkSize int
	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_TRANSFER_CHUNK_SIZE")); err != nil {
		log.Fatal(err)
	}
	buf := make([]byte, chunkSize)

	var totalSize int
	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_TRANSFER_TOTAL_SIZE")); err != nil {
		log.Fatal(err)
	}

	// fn is the per-chunk transfer operation selected by op.
	var fn func([]byte) (int, error)
	switch op := os.Getenv("GO_NET_TEST_TRANSFER_OP"); op {
	case "r":
		fn = conn.Read
	case "w":
		// Closing conn after writing lets the reading side observe the end
		// of the data.
		defer conn.Close()

		fn = conn.Write
	default:
		log.Fatalf("unknown op %q", op)
	}

	// NOTE(review): if fn keeps returning (0, nil) — e.g. chunkSize is 0 —
	// count never advances and this loop does not terminate.
	var n int
	for count := 0; count < totalSize; count += n {
		// Trim the final chunk so the total requested is exactly totalSize.
		if count+chunkSize > totalSize {
			buf = buf[:totalSize-count]
		}

		// NOTE(review): this err shadows the outer err; harmless because the
		// outer one is not used after setup.
		var err error
		// Any error (io.EOF included) ends the transfer. The deferred calls
		// above then run, and the process exits with status 0.
		if n, err = fn(buf); err != nil {
			return
		}
	}
}
632