1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"bytes"
9	"fmt"
10	"internal/poll"
11	"io"
12	"reflect"
13	"runtime"
14	"sync"
15	"testing"
16)
17
18func TestBuffers_read(t *testing.T) {
19	const story = "once upon a time in Gopherland ... "
20	buffers := Buffers{
21		[]byte("once "),
22		[]byte("upon "),
23		[]byte("a "),
24		[]byte("time "),
25		[]byte("in "),
26		[]byte("Gopherland ... "),
27	}
28	got, err := io.ReadAll(&buffers)
29	if err != nil {
30		t.Fatal(err)
31	}
32	if string(got) != story {
33		t.Errorf("read %q; want %q", got, story)
34	}
35	if len(buffers) != 0 {
36		t.Errorf("len(buffers) = %d; want 0", len(buffers))
37	}
38}
39
40func TestBuffers_consume(t *testing.T) {
41	tests := []struct {
42		in      Buffers
43		consume int64
44		want    Buffers
45	}{
46		{
47			in:      Buffers{[]byte("foo"), []byte("bar")},
48			consume: 0,
49			want:    Buffers{[]byte("foo"), []byte("bar")},
50		},
51		{
52			in:      Buffers{[]byte("foo"), []byte("bar")},
53			consume: 2,
54			want:    Buffers{[]byte("o"), []byte("bar")},
55		},
56		{
57			in:      Buffers{[]byte("foo"), []byte("bar")},
58			consume: 3,
59			want:    Buffers{[]byte("bar")},
60		},
61		{
62			in:      Buffers{[]byte("foo"), []byte("bar")},
63			consume: 4,
64			want:    Buffers{[]byte("ar")},
65		},
66		{
67			in:      Buffers{nil, nil, nil, []byte("bar")},
68			consume: 1,
69			want:    Buffers{[]byte("ar")},
70		},
71		{
72			in:      Buffers{nil, nil, nil, []byte("foo")},
73			consume: 0,
74			want:    Buffers{[]byte("foo")},
75		},
76		{
77			in:      Buffers{nil, nil, nil},
78			consume: 0,
79			want:    Buffers{},
80		},
81	}
82	for i, tt := range tests {
83		in := tt.in
84		in.consume(tt.consume)
85		if !reflect.DeepEqual(in, tt.want) {
86			t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
87		}
88	}
89}
90
91func TestBuffers_WriteTo(t *testing.T) {
92	for _, name := range []string{"WriteTo", "Copy"} {
93		for _, size := range []int{0, 10, 1023, 1024, 1025} {
94			t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
95				testBuffer_writeTo(t, size, name == "Copy")
96			})
97		}
98	}
99}
100
101func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
102	oldHook := poll.TestHookDidWritev
103	defer func() { poll.TestHookDidWritev = oldHook }()
104	var writeLog struct {
105		sync.Mutex
106		log []int
107	}
108	poll.TestHookDidWritev = func(size int) {
109		writeLog.Lock()
110		writeLog.log = append(writeLog.log, size)
111		writeLog.Unlock()
112	}
113	var want bytes.Buffer
114	for i := 0; i < chunks; i++ {
115		want.WriteByte(byte(i))
116	}
117
118	withTCPConnPair(t, func(c *TCPConn) error {
119		buffers := make(Buffers, chunks)
120		for i := range buffers {
121			buffers[i] = want.Bytes()[i : i+1]
122		}
123		var n int64
124		var err error
125		if useCopy {
126			n, err = io.Copy(c, &buffers)
127		} else {
128			n, err = buffers.WriteTo(c)
129		}
130		if err != nil {
131			return err
132		}
133		if len(buffers) != 0 {
134			return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
135		}
136		if n != int64(want.Len()) {
137			return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
138		}
139		return nil
140	}, func(c *TCPConn) error {
141		all, err := io.ReadAll(c)
142		if !bytes.Equal(all, want.Bytes()) || err != nil {
143			return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
144		}
145
146		writeLog.Lock() // no need to unlock
147		var gotSum int
148		for _, v := range writeLog.log {
149			gotSum += v
150		}
151
152		var wantSum int
153		switch runtime.GOOS {
154		case "aix", "android", "darwin", "ios", "dragonfly", "freebsd", "illumos", "linux", "netbsd", "openbsd", "solaris":
155			var wantMinCalls int
156			wantSum = want.Len()
157			v := chunks
158			for v > 0 {
159				wantMinCalls++
160				v -= 1024
161			}
162			if len(writeLog.log) < wantMinCalls {
163				t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
164			}
165		case "windows":
166			var wantCalls int
167			wantSum = want.Len()
168			if wantSum > 0 {
169				wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer
170			}
171			if len(writeLog.log) != wantCalls {
172				t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
173			}
174		}
175		if gotSum != wantSum {
176			t.Errorf("writev call sum  = %v; want %v", gotSum, wantSum)
177		}
178		return nil
179	})
180}
181
182func TestWritevError(t *testing.T) {
183	if runtime.GOOS == "windows" {
184		t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
185	}
186
187	ln := newLocalListener(t, "tcp")
188
189	ch := make(chan Conn, 1)
190	defer func() {
191		ln.Close()
192		for c := range ch {
193			c.Close()
194		}
195	}()
196
197	go func() {
198		defer close(ch)
199		c, err := ln.Accept()
200		if err != nil {
201			t.Error(err)
202			return
203		}
204		ch <- c
205	}()
206	c1, err := Dial("tcp", ln.Addr().String())
207	if err != nil {
208		t.Fatal(err)
209	}
210	defer c1.Close()
211	c2 := <-ch
212	if c2 == nil {
213		t.Fatal("no server side connection")
214	}
215	c2.Close()
216
217	// 1 GB of data should be enough to notice the connection is gone.
218	// Just a few bytes is not enough.
219	// Arrange to reuse the same 1 MB buffer so that we don't allocate much.
220	buf := make([]byte, 1<<20)
221	buffers := make(Buffers, 1<<10)
222	for i := range buffers {
223		buffers[i] = buf
224	}
225	if _, err := buffers.WriteTo(c1); err == nil {
226		t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
227	}
228}
229