1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	"net/internal/socktest"
12	"os"
13	"runtime"
14	"testing"
15	"time"
16)
17
18func TestCloseRead(t *testing.T) {
19	switch runtime.GOOS {
20	case "plan9":
21		t.Skipf("not supported on %s", runtime.GOOS)
22	}
23	t.Parallel()
24
25	for _, network := range []string{"tcp", "unix", "unixpacket"} {
26		network := network
27		t.Run(network, func(t *testing.T) {
28			if !testableNetwork(network) {
29				t.Skipf("network %s is not testable on the current platform", network)
30			}
31			t.Parallel()
32
33			ln := newLocalListener(t, network)
34			switch network {
35			case "unix", "unixpacket":
36				defer os.Remove(ln.Addr().String())
37			}
38			defer ln.Close()
39
40			c, err := Dial(ln.Addr().Network(), ln.Addr().String())
41			if err != nil {
42				t.Fatal(err)
43			}
44			switch network {
45			case "unix", "unixpacket":
46				defer os.Remove(c.LocalAddr().String())
47			}
48			defer c.Close()
49
50			switch c := c.(type) {
51			case *TCPConn:
52				err = c.CloseRead()
53			case *UnixConn:
54				err = c.CloseRead()
55			}
56			if err != nil {
57				if perr := parseCloseError(err, true); perr != nil {
58					t.Error(perr)
59				}
60				t.Fatal(err)
61			}
62			var b [1]byte
63			n, err := c.Read(b[:])
64			if n != 0 || err == nil {
65				t.Fatalf("got (%d, %v); want (0, error)", n, err)
66			}
67		})
68	}
69}
70
71func TestCloseWrite(t *testing.T) {
72	switch runtime.GOOS {
73	case "plan9":
74		t.Skipf("not supported on %s", runtime.GOOS)
75	}
76
77	t.Parallel()
78	deadline, _ := t.Deadline()
79	if !deadline.IsZero() {
80		// Leave 10% headroom on the deadline to report errors and clean up.
81		deadline = deadline.Add(-time.Until(deadline) / 10)
82	}
83
84	for _, network := range []string{"tcp", "unix", "unixpacket"} {
85		network := network
86		t.Run(network, func(t *testing.T) {
87			if !testableNetwork(network) {
88				t.Skipf("network %s is not testable on the current platform", network)
89			}
90			t.Parallel()
91
92			handler := func(ls *localServer, ln Listener) {
93				c, err := ln.Accept()
94				if err != nil {
95					t.Error(err)
96					return
97				}
98
99				// Workaround for https://go.dev/issue/49352.
100				// On arm64 macOS (current as of macOS 12.4),
101				// reading from a socket at the same time as the client
102				// is closing it occasionally hangs for 60 seconds before
103				// returning ECONNRESET. Sleep for a bit to give the
104				// socket time to close before trying to read from it.
105				if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
106					time.Sleep(10 * time.Millisecond)
107				}
108
109				if !deadline.IsZero() {
110					c.SetDeadline(deadline)
111				}
112				defer c.Close()
113
114				var b [1]byte
115				n, err := c.Read(b[:])
116				if n != 0 || err != io.EOF {
117					t.Errorf("got (%d, %v); want (0, io.EOF)", n, err)
118					return
119				}
120				switch c := c.(type) {
121				case *TCPConn:
122					err = c.CloseWrite()
123				case *UnixConn:
124					err = c.CloseWrite()
125				}
126				if err != nil {
127					if perr := parseCloseError(err, true); perr != nil {
128						t.Error(perr)
129					}
130					t.Error(err)
131					return
132				}
133				n, err = c.Write(b[:])
134				if err == nil {
135					t.Errorf("got (%d, %v); want (any, error)", n, err)
136					return
137				}
138			}
139
140			ls := newLocalServer(t, network)
141			defer ls.teardown()
142			if err := ls.buildup(handler); err != nil {
143				t.Fatal(err)
144			}
145
146			c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
147			if err != nil {
148				t.Fatal(err)
149			}
150			if !deadline.IsZero() {
151				c.SetDeadline(deadline)
152			}
153			switch network {
154			case "unix", "unixpacket":
155				defer os.Remove(c.LocalAddr().String())
156			}
157			defer c.Close()
158
159			switch c := c.(type) {
160			case *TCPConn:
161				err = c.CloseWrite()
162			case *UnixConn:
163				err = c.CloseWrite()
164			}
165			if err != nil {
166				if perr := parseCloseError(err, true); perr != nil {
167					t.Error(perr)
168				}
169				t.Fatal(err)
170			}
171			var b [1]byte
172			n, err := c.Read(b[:])
173			if n != 0 || err != io.EOF {
174				t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err)
175			}
176			n, err = c.Write(b[:])
177			if err == nil {
178				t.Fatalf("got (%d, %v); want (any, error)", n, err)
179			}
180		})
181	}
182}
183
184func TestConnClose(t *testing.T) {
185	t.Parallel()
186	for _, network := range []string{"tcp", "unix", "unixpacket"} {
187		network := network
188		t.Run(network, func(t *testing.T) {
189			if !testableNetwork(network) {
190				t.Skipf("network %s is not testable on the current platform", network)
191			}
192			t.Parallel()
193
194			ln := newLocalListener(t, network)
195			switch network {
196			case "unix", "unixpacket":
197				defer os.Remove(ln.Addr().String())
198			}
199			defer ln.Close()
200
201			c, err := Dial(ln.Addr().Network(), ln.Addr().String())
202			if err != nil {
203				t.Fatal(err)
204			}
205			switch network {
206			case "unix", "unixpacket":
207				defer os.Remove(c.LocalAddr().String())
208			}
209			defer c.Close()
210
211			if err := c.Close(); err != nil {
212				if perr := parseCloseError(err, false); perr != nil {
213					t.Error(perr)
214				}
215				t.Fatal(err)
216			}
217			var b [1]byte
218			n, err := c.Read(b[:])
219			if n != 0 || err == nil {
220				t.Fatalf("got (%d, %v); want (0, error)", n, err)
221			}
222		})
223	}
224}
225
226func TestListenerClose(t *testing.T) {
227	t.Parallel()
228	for _, network := range []string{"tcp", "unix", "unixpacket"} {
229		network := network
230		t.Run(network, func(t *testing.T) {
231			if !testableNetwork(network) {
232				t.Skipf("network %s is not testable on the current platform", network)
233			}
234			t.Parallel()
235
236			ln := newLocalListener(t, network)
237			switch network {
238			case "unix", "unixpacket":
239				defer os.Remove(ln.Addr().String())
240			}
241
242			if err := ln.Close(); err != nil {
243				if perr := parseCloseError(err, false); perr != nil {
244					t.Error(perr)
245				}
246				t.Fatal(err)
247			}
248			c, err := ln.Accept()
249			if err == nil {
250				c.Close()
251				t.Fatal("should fail")
252			}
253
254			// Note: we cannot ensure that a subsequent Dial does not succeed, because
255			// we do not in general have any guarantee that ln.Addr is not immediately
256			// reused. (TCP sockets enter a TIME_WAIT state when closed, but that only
257			// applies to existing connections for the port — it does not prevent the
258			// port itself from being used for entirely new connections in the
259			// meantime.)
260		})
261	}
262}
263
264func TestPacketConnClose(t *testing.T) {
265	t.Parallel()
266	for _, network := range []string{"udp", "unixgram"} {
267		network := network
268		t.Run(network, func(t *testing.T) {
269			if !testableNetwork(network) {
270				t.Skipf("network %s is not testable on the current platform", network)
271			}
272			t.Parallel()
273
274			c := newLocalPacketListener(t, network)
275			switch network {
276			case "unixgram":
277				defer os.Remove(c.LocalAddr().String())
278			}
279			defer c.Close()
280
281			if err := c.Close(); err != nil {
282				if perr := parseCloseError(err, false); perr != nil {
283					t.Error(perr)
284				}
285				t.Fatal(err)
286			}
287			var b [1]byte
288			n, _, err := c.ReadFrom(b[:])
289			if n != 0 || err == nil {
290				t.Fatalf("got (%d, %v); want (0, error)", n, err)
291			}
292		})
293	}
294}
295
296// See golang.org/issue/6163, golang.org/issue/6987.
297func TestAcceptIgnoreAbortedConnRequest(t *testing.T) {
298	switch runtime.GOOS {
299	case "plan9":
300		t.Skipf("%s does not have full support of socktest", runtime.GOOS)
301	}
302
303	syserr := make(chan error)
304	go func() {
305		defer close(syserr)
306		for _, err := range abortedConnRequestErrors {
307			syserr <- err
308		}
309	}()
310	sw.Set(socktest.FilterAccept, func(so *socktest.Status) (socktest.AfterFilter, error) {
311		if err, ok := <-syserr; ok {
312			return nil, err
313		}
314		return nil, nil
315	})
316	defer sw.Set(socktest.FilterAccept, nil)
317
318	operr := make(chan error, 1)
319	handler := func(ls *localServer, ln Listener) {
320		defer close(operr)
321		c, err := ln.Accept()
322		if err != nil {
323			if perr := parseAcceptError(err); perr != nil {
324				operr <- perr
325			}
326			operr <- err
327			return
328		}
329		c.Close()
330	}
331	ls := newLocalServer(t, "tcp")
332	defer ls.teardown()
333	if err := ls.buildup(handler); err != nil {
334		t.Fatal(err)
335	}
336
337	c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
338	if err != nil {
339		t.Fatal(err)
340	}
341	c.Close()
342
343	for err := range operr {
344		t.Error(err)
345	}
346}
347
348func TestZeroByteRead(t *testing.T) {
349	t.Parallel()
350	for _, network := range []string{"tcp", "unix", "unixpacket"} {
351		network := network
352		t.Run(network, func(t *testing.T) {
353			if !testableNetwork(network) {
354				t.Skipf("network %s is not testable on the current platform", network)
355			}
356			t.Parallel()
357
358			ln := newLocalListener(t, network)
359			connc := make(chan Conn, 1)
360			defer func() {
361				ln.Close()
362				for c := range connc {
363					if c != nil {
364						c.Close()
365					}
366				}
367			}()
368			go func() {
369				defer close(connc)
370				c, err := ln.Accept()
371				if err != nil {
372					t.Error(err)
373				}
374				connc <- c // might be nil
375			}()
376			c, err := Dial(network, ln.Addr().String())
377			if err != nil {
378				t.Fatal(err)
379			}
380			defer c.Close()
381			sc := <-connc
382			if sc == nil {
383				return
384			}
385			defer sc.Close()
386
387			if runtime.GOOS == "windows" {
388				// A zero byte read on Windows caused a wait for readability first.
389				// Rather than change that behavior, satisfy it in this test.
390				// See Issue 15735.
391				go io.WriteString(sc, "a")
392			}
393
394			n, err := c.Read(nil)
395			if n != 0 || err != nil {
396				t.Errorf("%s: zero byte client read = %v, %v; want 0, nil", network, n, err)
397			}
398
399			if runtime.GOOS == "windows" {
400				// Same as comment above.
401				go io.WriteString(c, "a")
402			}
403			n, err = sc.Read(nil)
404			if n != 0 || err != nil {
405				t.Errorf("%s: zero byte server read = %v, %v; want 0, nil", network, n, err)
406			}
407		})
408	}
409}
410
411// withTCPConnPair sets up a TCP connection between two peers, then
412// runs peer1 and peer2 concurrently. withTCPConnPair returns when
413// both have completed.
414func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
415	t.Helper()
416	ln := newLocalListener(t, "tcp")
417	defer ln.Close()
418	errc := make(chan error, 2)
419	go func() {
420		c1, err := ln.Accept()
421		if err != nil {
422			errc <- err
423			return
424		}
425		err = peer1(c1.(*TCPConn))
426		c1.Close()
427		errc <- err
428	}()
429	go func() {
430		c2, err := Dial("tcp", ln.Addr().String())
431		if err != nil {
432			errc <- err
433			return
434		}
435		err = peer2(c2.(*TCPConn))
436		c2.Close()
437		errc <- err
438	}()
439	for i := 0; i < 2; i++ {
440		if err := <-errc; err != nil {
441			t.Error(err)
442		}
443	}
444}
445
446// Tests that a blocked Read is interrupted by a concurrent SetReadDeadline
447// modifying that Conn's read deadline to the past.
448// See golang.org/cl/30164 which documented this. The net/http package
449// depends on this.
450func TestReadTimeoutUnblocksRead(t *testing.T) {
451	serverDone := make(chan struct{})
452	server := func(cs *TCPConn) error {
453		defer close(serverDone)
454		errc := make(chan error, 1)
455		go func() {
456			defer close(errc)
457			go func() {
458				// TODO: find a better way to wait
459				// until we're blocked in the cs.Read
460				// call below. Sleep is lame.
461				time.Sleep(100 * time.Millisecond)
462
463				// Interrupt the upcoming Read, unblocking it:
464				cs.SetReadDeadline(time.Unix(123, 0)) // time in the past
465			}()
466			var buf [1]byte
467			n, err := cs.Read(buf[:1])
468			if n != 0 || err == nil {
469				errc <- fmt.Errorf("Read = %v, %v; want 0, non-nil", n, err)
470			}
471		}()
472		select {
473		case err := <-errc:
474			return err
475		case <-time.After(5 * time.Second):
476			buf := make([]byte, 2<<20)
477			buf = buf[:runtime.Stack(buf, true)]
478			println("Stacks at timeout:\n", string(buf))
479			return errors.New("timeout waiting for Read to finish")
480		}
481
482	}
483	// Do nothing in the client. Never write. Just wait for the
484	// server's half to be done.
485	client := func(*TCPConn) error {
486		<-serverDone
487		return nil
488	}
489	withTCPConnPair(t, client, server)
490}
491
492// Issue 17695: verify that a blocked Read is woken up by a Close.
493func TestCloseUnblocksRead(t *testing.T) {
494	t.Parallel()
495	server := func(cs *TCPConn) error {
496		// Give the client time to get stuck in a Read:
497		time.Sleep(20 * time.Millisecond)
498		cs.Close()
499		return nil
500	}
501	client := func(ss *TCPConn) error {
502		n, err := ss.Read([]byte{0})
503		if n != 0 || err != io.EOF {
504			return fmt.Errorf("Read = %v, %v; want 0, EOF", n, err)
505		}
506		return nil
507	}
508	withTCPConnPair(t, client, server)
509}
510
511// Issue 24808: verify that ECONNRESET is not temporary for read.
512func TestNotTemporaryRead(t *testing.T) {
513	t.Parallel()
514
515	ln := newLocalListener(t, "tcp")
516	serverDone := make(chan struct{})
517	dialed := make(chan struct{})
518	go func() {
519		defer close(serverDone)
520
521		cs, err := ln.Accept()
522		if err != nil {
523			return
524		}
525		<-dialed
526		cs.(*TCPConn).SetLinger(0)
527		cs.Close()
528	}()
529	defer func() {
530		ln.Close()
531		<-serverDone
532	}()
533
534	ss, err := Dial("tcp", ln.Addr().String())
535	close(dialed)
536	if err != nil {
537		t.Fatal(err)
538	}
539	defer ss.Close()
540
541	_, err = ss.Read([]byte{0})
542	if err == nil {
543		t.Fatal("Read succeeded unexpectedly")
544	} else if err == io.EOF {
545		// This happens on Plan 9, but for some reason (prior to CL 385314) it was
546		// accepted everywhere else too.
547		if runtime.GOOS == "plan9" {
548			return
549		}
550		t.Fatal("Read unexpectedly returned io.EOF after socket was abruptly closed")
551	}
552	if ne, ok := err.(Error); !ok {
553		t.Errorf("Read error does not implement net.Error: %v", err)
554	} else if ne.Temporary() {
555		t.Errorf("Read error is unexpectedly temporary: %v", err)
556	}
557}
558
559// The various errors should implement the Error interface.
560func TestErrors(t *testing.T) {
561	var (
562		_ Error = &OpError{}
563		_ Error = &ParseError{}
564		_ Error = &AddrError{}
565		_ Error = UnknownNetworkError("")
566		_ Error = InvalidAddrError("")
567		_ Error = &timeoutError{}
568		_ Error = &DNSConfigError{}
569		_ Error = &DNSError{}
570	)
571
572	// ErrClosed was introduced as type error, so we can't check
573	// it using a declaration.
574	if _, ok := ErrClosed.(Error); !ok {
575		t.Fatal("ErrClosed does not implement Error")
576	}
577}
578