1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build linux
6
7package net
8
9import (
10	"internal/poll"
11	"io"
12	"os"
13	"strconv"
14	"sync"
15	"syscall"
16	"testing"
17)
18
19func TestSplice(t *testing.T) {
20	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
21	if !testableNetwork("unixgram") {
22		t.Skip("skipping unix-to-tcp tests")
23	}
24	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
25	t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
26	t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
27	t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
28	t.Run("no-unixpacket", testSpliceNoUnixpacket)
29	t.Run("no-unixgram", testSpliceNoUnixgram)
30}
31
32func testSpliceToFile(t *testing.T, upNet, downNet string) {
33	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
34	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
35	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
36	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
37	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
38	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
39}
40
41func testSplice(t *testing.T, upNet, downNet string) {
42	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
43	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
44	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
45	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
46	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
47	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
48	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
49	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
50}
51
// spliceTestCase describes one copy scenario: a peer writes into an upNet
// socket and the data is copied onward (into a downNet socket or a file).
// Composite literals in this file are positional, so field order matters.
type spliceTestCase struct {
	upNet   string // network of the source connection, e.g. "tcp" or "unix"
	downNet string // network of the destination connection

	chunkSize     int // chunk size handed to startTestSocketPeer
	totalSize     int // total bytes the writing peer is asked to send
	limitReadSize int // when > 0, the source is wrapped in an io.LimitedReader with this N
}
58
59func (tc spliceTestCase) test(t *testing.T) {
60	hook := hookSplice(t)
61
62	// We need to use the actual size for startTestSocketPeer when testing with LimitedReader,
63	// otherwise the child process created in startTestSocketPeer will hang infinitely because of
64	// the mismatch of data size to transfer.
65	size := tc.totalSize
66	if tc.limitReadSize > 0 {
67		if tc.limitReadSize < size {
68			size = tc.limitReadSize
69		}
70	}
71
72	clientUp, serverUp := spawnTestSocketPair(t, tc.upNet)
73	defer serverUp.Close()
74	cleanup, err := startTestSocketPeer(t, clientUp, "w", tc.chunkSize, size)
75	if err != nil {
76		t.Fatal(err)
77	}
78	defer cleanup(t)
79	clientDown, serverDown := spawnTestSocketPair(t, tc.downNet)
80	defer serverDown.Close()
81	cleanup, err = startTestSocketPeer(t, clientDown, "r", tc.chunkSize, size)
82	if err != nil {
83		t.Fatal(err)
84	}
85	defer cleanup(t)
86
87	var r io.Reader = serverUp
88	if tc.limitReadSize > 0 {
89		r = &io.LimitedReader{
90			N: int64(tc.limitReadSize),
91			R: serverUp,
92		}
93		defer serverUp.Close()
94	}
95	n, err := io.Copy(serverDown, r)
96	if err != nil {
97		t.Fatal(err)
98	}
99
100	if want := int64(size); want != n {
101		t.Errorf("want %d bytes spliced, got %d", want, n)
102	}
103
104	if tc.limitReadSize > 0 {
105		wantN := 0
106		if tc.limitReadSize > size {
107			wantN = tc.limitReadSize - size
108		}
109
110		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
111			t.Errorf("r.N = %d, want %d", n, wantN)
112		}
113	}
114
115	// poll.Splice is expected to be called when the source is not
116	// a wrapper or the destination is TCPConn.
117	if tc.limitReadSize == 0 || tc.downNet == "tcp" {
118		// We should have called poll.Splice with the right file descriptor arguments.
119		if n > 0 && !hook.called {
120			t.Fatal("expected poll.Splice to be called")
121		}
122
123		verifySpliceFds(t, serverDown, hook, "dst")
124		verifySpliceFds(t, serverUp, hook, "src")
125
126		// poll.Splice is expected to handle the data transmission successfully.
127		if !hook.handled || hook.written != int64(size) || hook.err != nil {
128			t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v",
129				size, hook.handled, hook.written, hook.err)
130		}
131	} else if hook.called {
132		// poll.Splice will certainly not be called when the source
133		// is a wrapper and the destination is not TCPConn.
134		t.Errorf("expected poll.Splice not be called")
135	}
136}
137
138func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) {
139	t.Helper()
140
141	sc, ok := c.(syscall.Conn)
142	if !ok {
143		t.Fatalf("expected syscall.Conn")
144	}
145	rc, err := sc.SyscallConn()
146	if err != nil {
147		t.Fatalf("syscall.Conn.SyscallConn error: %v", err)
148	}
149	var hookFd int
150	switch fdType {
151	case "src":
152		hookFd = hook.srcfd
153	case "dst":
154		hookFd = hook.dstfd
155	default:
156		t.Fatalf("unknown fdType %q", fdType)
157	}
158	if err := rc.Control(func(fd uintptr) {
159		if hook.called && hookFd != int(fd) {
160			t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd))
161		}
162	}); err != nil {
163		t.Fatalf("syscall.RawConn.Control error: %v", err)
164	}
165}
166
167func (tc spliceTestCase) testFile(t *testing.T) {
168	hook := hookSplice(t)
169
170	// We need to use the actual size for startTestSocketPeer when testing with LimitedReader,
171	// otherwise the child process created in startTestSocketPeer will hang infinitely because of
172	// the mismatch of data size to transfer.
173	actualSize := tc.totalSize
174	if tc.limitReadSize > 0 {
175		if tc.limitReadSize < actualSize {
176			actualSize = tc.limitReadSize
177		}
178	}
179
180	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
181	if err != nil {
182		t.Fatal(err)
183	}
184	defer f.Close()
185
186	client, server := spawnTestSocketPair(t, tc.upNet)
187	defer server.Close()
188
189	cleanup, err := startTestSocketPeer(t, client, "w", tc.chunkSize, actualSize)
190	if err != nil {
191		client.Close()
192		t.Fatal("failed to start splice client:", err)
193	}
194	defer cleanup(t)
195
196	var r io.Reader = server
197	if tc.limitReadSize > 0 {
198		r = &io.LimitedReader{
199			N: int64(tc.limitReadSize),
200			R: r,
201		}
202	}
203
204	got, err := io.Copy(f, r)
205	if err != nil {
206		t.Fatalf("failed to ReadFrom with error: %v", err)
207	}
208
209	// We shouldn't have called poll.Splice in TCPConn.WriteTo,
210	// it's supposed to be called from File.ReadFrom.
211	if got > 0 && hook.called {
212		t.Error("expected not poll.Splice to be called")
213	}
214
215	if want := int64(actualSize); got != want {
216		t.Errorf("got %d bytes, want %d", got, want)
217	}
218	if tc.limitReadSize > 0 {
219		wantN := 0
220		if tc.limitReadSize > actualSize {
221			wantN = tc.limitReadSize - actualSize
222		}
223
224		if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
225			t.Errorf("r.N = %d, want %d", gotN, wantN)
226		}
227	}
228}
229
230func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
231	// UnixConn doesn't implement io.ReaderFrom, which will fail
232	// the following test in asserting a UnixConn to be an io.ReaderFrom,
233	// so skip this test.
234	if downNet == "unix" {
235		t.Skip("skipping test on unix socket")
236	}
237
238	hook := hookSplice(t)
239
240	clientUp, serverUp := spawnTestSocketPair(t, upNet)
241	defer clientUp.Close()
242	clientDown, serverDown := spawnTestSocketPair(t, downNet)
243	defer clientDown.Close()
244	defer serverDown.Close()
245
246	serverUp.Close()
247
248	// We'd like to call net.spliceFrom here and check the handled return
249	// value, but we disable splice on old Linux kernels.
250	//
251	// In that case, poll.Splice and net.spliceFrom return a non-nil error
252	// and handled == false. We'd ideally like to see handled == true
253	// because the source reader is at EOF, but if we're running on an old
254	// kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
255	// because we won't touch the reader at all.
256	//
257	// Trying to untangle the errors from net.spliceFrom and match them
258	// against the errors created by the poll package would be brittle,
259	// so this is a higher level test.
260	//
261	// The following ReadFrom should return immediately, regardless of
262	// whether splice is disabled or not. The other side should then
263	// get a goodbye signal. Test for the goodbye signal.
264	msg := "bye"
265	go func() {
266		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
267		io.WriteString(serverDown, msg)
268	}()
269
270	buf := make([]byte, 3)
271	n, err := io.ReadFull(clientDown, buf)
272	if err != nil {
273		t.Errorf("clientDown: %v", err)
274	}
275	if string(buf) != msg {
276		t.Errorf("clientDown got %q, want %q", buf, msg)
277	}
278
279	// We should have called poll.Splice with the right file descriptor arguments.
280	if n > 0 && !hook.called {
281		t.Fatal("expected poll.Splice to be called")
282	}
283
284	verifySpliceFds(t, serverDown, hook, "dst")
285
286	// poll.Splice is expected to handle the data transmission but fail
287	// when working with a closed endpoint, return an error.
288	if !hook.handled || hook.written > 0 || hook.err == nil {
289		t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v",
290			hook.handled, hook.written, hook.err)
291	}
292}
293
294func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
295	front := newLocalListener(t, upNet)
296	defer front.Close()
297	back := newLocalListener(t, downNet)
298	defer back.Close()
299
300	var wg sync.WaitGroup
301	wg.Add(2)
302
303	proxy := func() {
304		src, err := front.Accept()
305		if err != nil {
306			return
307		}
308		dst, err := Dial(downNet, back.Addr().String())
309		if err != nil {
310			return
311		}
312		defer dst.Close()
313		defer src.Close()
314		go func() {
315			io.Copy(src, dst)
316			wg.Done()
317		}()
318		go func() {
319			io.Copy(dst, src)
320			wg.Done()
321		}()
322	}
323
324	go proxy()
325
326	toFront, err := Dial(upNet, front.Addr().String())
327	if err != nil {
328		t.Fatal(err)
329	}
330
331	io.WriteString(toFront, "foo")
332	toFront.Close()
333
334	fromProxy, err := back.Accept()
335	if err != nil {
336		t.Fatal(err)
337	}
338	defer fromProxy.Close()
339
340	_, err = io.ReadAll(fromProxy)
341	if err != nil {
342		t.Fatal(err)
343	}
344
345	wg.Wait()
346}
347
348func testSpliceNoUnixpacket(t *testing.T) {
349	clientUp, serverUp := spawnTestSocketPair(t, "unixpacket")
350	defer clientUp.Close()
351	defer serverUp.Close()
352	clientDown, serverDown := spawnTestSocketPair(t, "tcp")
353	defer clientDown.Close()
354	defer serverDown.Close()
355	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
356	// and handled == false.  If poll.Splice gets an EINVAL on the first
357	// try, it assumes the kernel it's running on doesn't support splice
358	// for unix sockets and returns handled == false. This works for our
359	// purposes by somewhat of an accident, but is not entirely correct.
360	//
361	// What we want is err == nil and handled == false, i.e. we never
362	// called poll.Splice, because we know the unix socket's network.
363	_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
364	if err != nil || handled != false {
365		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
366	}
367}
368
369func testSpliceNoUnixgram(t *testing.T) {
370	addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t))
371	if err != nil {
372		t.Fatal(err)
373	}
374	defer os.Remove(addr.Name)
375	up, err := ListenUnixgram("unixgram", addr)
376	if err != nil {
377		t.Fatal(err)
378	}
379	defer up.Close()
380	clientDown, serverDown := spawnTestSocketPair(t, "tcp")
381	defer clientDown.Close()
382	defer serverDown.Close()
383	// Analogous to testSpliceNoUnixpacket.
384	_, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
385	if err != nil || handled != false {
386		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
387	}
388}
389
390func BenchmarkSplice(b *testing.B) {
391	testHookUninstaller.Do(uninstallTestHooks)
392
393	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
394	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
395	b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
396}
397
398func benchSplice(b *testing.B, upNet, downNet string) {
399	for i := 0; i <= 10; i++ {
400		chunkSize := 1 << uint(i+10)
401		tc := spliceTestCase{
402			upNet:     upNet,
403			downNet:   downNet,
404			chunkSize: chunkSize,
405		}
406
407		b.Run(strconv.Itoa(chunkSize), tc.bench)
408	}
409}
410
411func (tc spliceTestCase) bench(b *testing.B) {
412	// To benchmark the genericReadFrom code path, set this to false.
413	useSplice := true
414
415	clientUp, serverUp := spawnTestSocketPair(b, tc.upNet)
416	defer serverUp.Close()
417
418	cleanup, err := startTestSocketPeer(b, clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
419	if err != nil {
420		b.Fatal(err)
421	}
422	defer cleanup(b)
423
424	clientDown, serverDown := spawnTestSocketPair(b, tc.downNet)
425	defer serverDown.Close()
426
427	cleanup, err = startTestSocketPeer(b, clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
428	if err != nil {
429		b.Fatal(err)
430	}
431	defer cleanup(b)
432
433	b.SetBytes(int64(tc.chunkSize))
434	b.ResetTimer()
435
436	if useSplice {
437		_, err := io.Copy(serverDown, serverUp)
438		if err != nil {
439			b.Fatal(err)
440		}
441	} else {
442		type onlyReader struct {
443			io.Reader
444		}
445		_, err := io.Copy(serverDown, onlyReader{serverUp})
446		if err != nil {
447			b.Fatal(err)
448		}
449	}
450}
451
452func BenchmarkSpliceFile(b *testing.B) {
453	b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
454	b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
455}
456
457func benchmarkSpliceFile(b *testing.B, proto string) {
458	for i := 0; i <= 10; i++ {
459		size := 1 << (i + 10)
460		bench := spliceFileBench{
461			proto:     proto,
462			chunkSize: size,
463		}
464		b.Run(strconv.Itoa(size), bench.benchSpliceFile)
465	}
466}
467
// spliceFileBench parameterizes one socket-to-file sub-benchmark.
type spliceFileBench struct {
	proto     string // network of the source socket, e.g. "tcp" or "unix"
	chunkSize int    // chunk size for the writing peer; also the bytes counted per b.N iteration
}
472
473func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
474	f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
475	if err != nil {
476		b.Fatal(err)
477	}
478	defer f.Close()
479
480	totalSize := b.N * bench.chunkSize
481
482	client, server := spawnTestSocketPair(b, bench.proto)
483	defer server.Close()
484
485	cleanup, err := startTestSocketPeer(b, client, "w", bench.chunkSize, totalSize)
486	if err != nil {
487		client.Close()
488		b.Fatalf("failed to start splice client: %v", err)
489	}
490	defer cleanup(b)
491
492	b.ReportAllocs()
493	b.SetBytes(int64(bench.chunkSize))
494	b.ResetTimer()
495
496	got, err := io.Copy(f, server)
497	if err != nil {
498		b.Fatalf("failed to ReadFrom with error: %v", err)
499	}
500	if want := int64(totalSize); got != want {
501		b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
502	}
503}
504
505func hookSplice(t *testing.T) *spliceHook {
506	t.Helper()
507
508	h := new(spliceHook)
509	h.install()
510	t.Cleanup(h.uninstall)
511	return h
512}
513
514type spliceHook struct {
515	called bool
516	dstfd  int
517	srcfd  int
518	remain int64
519
520	written int64
521	handled bool
522	err     error
523
524	original func(dst, src *poll.FD, remain int64) (int64, bool, error)
525}
526
527func (h *spliceHook) install() {
528	h.original = pollSplice
529	pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
530		h.called = true
531		h.dstfd = dst.Sysfd
532		h.srcfd = src.Sysfd
533		h.remain = remain
534		h.written, h.handled, h.err = h.original(dst, src, remain)
535		return h.written, h.handled, h.err
536	}
537}
538
539func (h *spliceHook) uninstall() {
540	pollSplice = h.original
541}
542