1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5//go:build linux 6 7package net 8 9import ( 10 "internal/poll" 11 "io" 12 "os" 13 "strconv" 14 "sync" 15 "syscall" 16 "testing" 17) 18 19func TestSplice(t *testing.T) { 20 t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") }) 21 if !testableNetwork("unixgram") { 22 t.Skip("skipping unix-to-tcp tests") 23 } 24 t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) 25 t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") }) 26 t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") }) 27 t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") }) 28 t.Run("no-unixpacket", testSpliceNoUnixpacket) 29 t.Run("no-unixgram", testSpliceNoUnixgram) 30} 31 32func testSpliceToFile(t *testing.T, upNet, downNet string) { 33 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile) 34 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile) 35 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile) 36 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile) 37 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile) 38 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile) 39} 40 41func testSplice(t *testing.T, upNet, downNet string) { 42 t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test) 43 t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test) 44 t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test) 45 t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test) 46 t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test) 47 t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test) 48 t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) }) 49 t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) }) 50} 51 52type spliceTestCase struct { 53 upNet, downNet string 54 55 chunkSize, totalSize int 56 limitReadSize int 57} 58 59func (tc spliceTestCase) test(t *testing.T) { 60 hook := hookSplice(t) 61 62 // We need to use the actual size for startTestSocketPeer when testing with LimitedReader, 63 // otherwise the child process created in startTestSocketPeer will hang infinitely because of 64 // the mismatch of data size to transfer. 65 size := tc.totalSize 66 if tc.limitReadSize > 0 { 67 if tc.limitReadSize < size { 68 size = tc.limitReadSize 69 } 70 } 71 72 clientUp, serverUp := spawnTestSocketPair(t, tc.upNet) 73 defer serverUp.Close() 74 cleanup, err := startTestSocketPeer(t, clientUp, "w", tc.chunkSize, size) 75 if err != nil { 76 t.Fatal(err) 77 } 78 defer cleanup(t) 79 clientDown, serverDown := spawnTestSocketPair(t, tc.downNet) 80 defer serverDown.Close() 81 cleanup, err = startTestSocketPeer(t, clientDown, "r", tc.chunkSize, size) 82 if err != nil { 83 t.Fatal(err) 84 } 85 defer cleanup(t) 86 87 var r io.Reader = serverUp 88 if tc.limitReadSize > 0 { 89 r = &io.LimitedReader{ 90 N: int64(tc.limitReadSize), 91 R: serverUp, 92 } 93 defer serverUp.Close() 94 } 95 n, err := io.Copy(serverDown, r) 96 if err != nil { 97 t.Fatal(err) 98 } 99 100 if want := int64(size); want != n { 101 t.Errorf("want %d bytes spliced, got %d", want, n) 102 } 103 104 if tc.limitReadSize > 0 { 105 wantN := 0 106 if tc.limitReadSize > size { 107 wantN = tc.limitReadSize - size 108 } 109 110 if n := r.(*io.LimitedReader).N; n != int64(wantN) { 111 t.Errorf("r.N = %d, want %d", n, wantN) 112 } 113 } 114 115 // poll.Splice is expected to be called when the source is not 116 // a wrapper or the destination is TCPConn. 117 if tc.limitReadSize == 0 || tc.downNet == "tcp" { 118 // We should have called poll.Splice with the right file descriptor arguments. 119 if n > 0 && !hook.called { 120 t.Fatal("expected poll.Splice to be called") 121 } 122 123 verifySpliceFds(t, serverDown, hook, "dst") 124 verifySpliceFds(t, serverUp, hook, "src") 125 126 // poll.Splice is expected to handle the data transmission successfully. 127 if !hook.handled || hook.written != int64(size) || hook.err != nil { 128 t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v", 129 size, hook.handled, hook.written, hook.err) 130 } 131 } else if hook.called { 132 // poll.Splice will certainly not be called when the source 133 // is a wrapper and the destination is not TCPConn. 134 t.Errorf("expected poll.Splice not be called") 135 } 136} 137 138func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) { 139 t.Helper() 140 141 sc, ok := c.(syscall.Conn) 142 if !ok { 143 t.Fatalf("expected syscall.Conn") 144 } 145 rc, err := sc.SyscallConn() 146 if err != nil { 147 t.Fatalf("syscall.Conn.SyscallConn error: %v", err) 148 } 149 var hookFd int 150 switch fdType { 151 case "src": 152 hookFd = hook.srcfd 153 case "dst": 154 hookFd = hook.dstfd 155 default: 156 t.Fatalf("unknown fdType %q", fdType) 157 } 158 if err := rc.Control(func(fd uintptr) { 159 if hook.called && hookFd != int(fd) { 160 t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd)) 161 } 162 }); err != nil { 163 t.Fatalf("syscall.RawConn.Control error: %v", err) 164 } 165} 166 167func (tc spliceTestCase) testFile(t *testing.T) { 168 hook := hookSplice(t) 169 170 // We need to use the actual size for startTestSocketPeer when testing with LimitedReader, 171 // otherwise the child process created in startTestSocketPeer will hang infinitely because of 172 // the mismatch of data size to transfer. 173 actualSize := tc.totalSize 174 if tc.limitReadSize > 0 { 175 if tc.limitReadSize < actualSize { 176 actualSize = tc.limitReadSize 177 } 178 } 179 180 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) 181 if err != nil { 182 t.Fatal(err) 183 } 184 defer f.Close() 185 186 client, server := spawnTestSocketPair(t, tc.upNet) 187 defer server.Close() 188 189 cleanup, err := startTestSocketPeer(t, client, "w", tc.chunkSize, actualSize) 190 if err != nil { 191 client.Close() 192 t.Fatal("failed to start splice client:", err) 193 } 194 defer cleanup(t) 195 196 var r io.Reader = server 197 if tc.limitReadSize > 0 { 198 r = &io.LimitedReader{ 199 N: int64(tc.limitReadSize), 200 R: r, 201 } 202 } 203 204 got, err := io.Copy(f, r) 205 if err != nil { 206 t.Fatalf("failed to ReadFrom with error: %v", err) 207 } 208 209 // We shouldn't have called poll.Splice in TCPConn.WriteTo, 210 // it's supposed to be called from File.ReadFrom. 211 if got > 0 && hook.called { 212 t.Error("expected not poll.Splice to be called") 213 } 214 215 if want := int64(actualSize); got != want { 216 t.Errorf("got %d bytes, want %d", got, want) 217 } 218 if tc.limitReadSize > 0 { 219 wantN := 0 220 if tc.limitReadSize > actualSize { 221 wantN = tc.limitReadSize - actualSize 222 } 223 224 if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) { 225 t.Errorf("r.N = %d, want %d", gotN, wantN) 226 } 227 } 228} 229 230func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { 231 // UnixConn doesn't implement io.ReaderFrom, which will fail 232 // the following test in asserting a UnixConn to be an io.ReaderFrom, 233 // so skip this test. 234 if downNet == "unix" { 235 t.Skip("skipping test on unix socket") 236 } 237 238 hook := hookSplice(t) 239 240 clientUp, serverUp := spawnTestSocketPair(t, upNet) 241 defer clientUp.Close() 242 clientDown, serverDown := spawnTestSocketPair(t, downNet) 243 defer clientDown.Close() 244 defer serverDown.Close() 245 246 serverUp.Close() 247 248 // We'd like to call net.spliceFrom here and check the handled return 249 // value, but we disable splice on old Linux kernels. 250 // 251 // In that case, poll.Splice and net.spliceFrom return a non-nil error 252 // and handled == false. We'd ideally like to see handled == true 253 // because the source reader is at EOF, but if we're running on an old 254 // kernel, and splice is disabled, we won't see EOF from net.spliceFrom, 255 // because we won't touch the reader at all. 256 // 257 // Trying to untangle the errors from net.spliceFrom and match them 258 // against the errors created by the poll package would be brittle, 259 // so this is a higher level test. 260 // 261 // The following ReadFrom should return immediately, regardless of 262 // whether splice is disabled or not. The other side should then 263 // get a goodbye signal. Test for the goodbye signal. 264 msg := "bye" 265 go func() { 266 serverDown.(io.ReaderFrom).ReadFrom(serverUp) 267 io.WriteString(serverDown, msg) 268 }() 269 270 buf := make([]byte, 3) 271 n, err := io.ReadFull(clientDown, buf) 272 if err != nil { 273 t.Errorf("clientDown: %v", err) 274 } 275 if string(buf) != msg { 276 t.Errorf("clientDown got %q, want %q", buf, msg) 277 } 278 279 // We should have called poll.Splice with the right file descriptor arguments. 280 if n > 0 && !hook.called { 281 t.Fatal("expected poll.Splice to be called") 282 } 283 284 verifySpliceFds(t, serverDown, hook, "dst") 285 286 // poll.Splice is expected to handle the data transmission but fail 287 // when working with a closed endpoint, return an error. 288 if !hook.handled || hook.written > 0 || hook.err == nil { 289 t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v", 290 hook.handled, hook.written, hook.err) 291 } 292} 293 294func testSpliceIssue25985(t *testing.T, upNet, downNet string) { 295 front := newLocalListener(t, upNet) 296 defer front.Close() 297 back := newLocalListener(t, downNet) 298 defer back.Close() 299 300 var wg sync.WaitGroup 301 wg.Add(2) 302 303 proxy := func() { 304 src, err := front.Accept() 305 if err != nil { 306 return 307 } 308 dst, err := Dial(downNet, back.Addr().String()) 309 if err != nil { 310 return 311 } 312 defer dst.Close() 313 defer src.Close() 314 go func() { 315 io.Copy(src, dst) 316 wg.Done() 317 }() 318 go func() { 319 io.Copy(dst, src) 320 wg.Done() 321 }() 322 } 323 324 go proxy() 325 326 toFront, err := Dial(upNet, front.Addr().String()) 327 if err != nil { 328 t.Fatal(err) 329 } 330 331 io.WriteString(toFront, "foo") 332 toFront.Close() 333 334 fromProxy, err := back.Accept() 335 if err != nil { 336 t.Fatal(err) 337 } 338 defer fromProxy.Close() 339 340 _, err = io.ReadAll(fromProxy) 341 if err != nil { 342 t.Fatal(err) 343 } 344 345 wg.Wait() 346} 347 348func testSpliceNoUnixpacket(t *testing.T) { 349 clientUp, serverUp := spawnTestSocketPair(t, "unixpacket") 350 defer clientUp.Close() 351 defer serverUp.Close() 352 clientDown, serverDown := spawnTestSocketPair(t, "tcp") 353 defer clientDown.Close() 354 defer serverDown.Close() 355 // If splice called poll.Splice here, we'd get err == syscall.EINVAL 356 // and handled == false. If poll.Splice gets an EINVAL on the first 357 // try, it assumes the kernel it's running on doesn't support splice 358 // for unix sockets and returns handled == false. This works for our 359 // purposes by somewhat of an accident, but is not entirely correct. 360 // 361 // What we want is err == nil and handled == false, i.e. we never 362 // called poll.Splice, because we know the unix socket's network. 363 _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp) 364 if err != nil || handled != false { 365 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) 366 } 367} 368 369func testSpliceNoUnixgram(t *testing.T) { 370 addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t)) 371 if err != nil { 372 t.Fatal(err) 373 } 374 defer os.Remove(addr.Name) 375 up, err := ListenUnixgram("unixgram", addr) 376 if err != nil { 377 t.Fatal(err) 378 } 379 defer up.Close() 380 clientDown, serverDown := spawnTestSocketPair(t, "tcp") 381 defer clientDown.Close() 382 defer serverDown.Close() 383 // Analogous to testSpliceNoUnixpacket. 384 _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up) 385 if err != nil || handled != false { 386 t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) 387 } 388} 389 390func BenchmarkSplice(b *testing.B) { 391 testHookUninstaller.Do(uninstallTestHooks) 392 393 b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) 394 b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) 395 b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") }) 396} 397 398func benchSplice(b *testing.B, upNet, downNet string) { 399 for i := 0; i <= 10; i++ { 400 chunkSize := 1 << uint(i+10) 401 tc := spliceTestCase{ 402 upNet: upNet, 403 downNet: downNet, 404 chunkSize: chunkSize, 405 } 406 407 b.Run(strconv.Itoa(chunkSize), tc.bench) 408 } 409} 410 411func (tc spliceTestCase) bench(b *testing.B) { 412 // To benchmark the genericReadFrom code path, set this to false. 413 useSplice := true 414 415 clientUp, serverUp := spawnTestSocketPair(b, tc.upNet) 416 defer serverUp.Close() 417 418 cleanup, err := startTestSocketPeer(b, clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) 419 if err != nil { 420 b.Fatal(err) 421 } 422 defer cleanup(b) 423 424 clientDown, serverDown := spawnTestSocketPair(b, tc.downNet) 425 defer serverDown.Close() 426 427 cleanup, err = startTestSocketPeer(b, clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) 428 if err != nil { 429 b.Fatal(err) 430 } 431 defer cleanup(b) 432 433 b.SetBytes(int64(tc.chunkSize)) 434 b.ResetTimer() 435 436 if useSplice { 437 _, err := io.Copy(serverDown, serverUp) 438 if err != nil { 439 b.Fatal(err) 440 } 441 } else { 442 type onlyReader struct { 443 io.Reader 444 } 445 _, err := io.Copy(serverDown, onlyReader{serverUp}) 446 if err != nil { 447 b.Fatal(err) 448 } 449 } 450} 451 452func BenchmarkSpliceFile(b *testing.B) { 453 b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") }) 454 b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") }) 455} 456 457func benchmarkSpliceFile(b *testing.B, proto string) { 458 for i := 0; i <= 10; i++ { 459 size := 1 << (i + 10) 460 bench := spliceFileBench{ 461 proto: proto, 462 chunkSize: size, 463 } 464 b.Run(strconv.Itoa(size), bench.benchSpliceFile) 465 } 466} 467 468type spliceFileBench struct { 469 proto string 470 chunkSize int 471} 472 473func (bench spliceFileBench) benchSpliceFile(b *testing.B) { 474 f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) 475 if err != nil { 476 b.Fatal(err) 477 } 478 defer f.Close() 479 480 totalSize := b.N * bench.chunkSize 481 482 client, server := spawnTestSocketPair(b, bench.proto) 483 defer server.Close() 484 485 cleanup, err := startTestSocketPeer(b, client, "w", bench.chunkSize, totalSize) 486 if err != nil { 487 client.Close() 488 b.Fatalf("failed to start splice client: %v", err) 489 } 490 defer cleanup(b) 491 492 b.ReportAllocs() 493 b.SetBytes(int64(bench.chunkSize)) 494 b.ResetTimer() 495 496 got, err := io.Copy(f, server) 497 if err != nil { 498 b.Fatalf("failed to ReadFrom with error: %v", err) 499 } 500 if want := int64(totalSize); got != want { 501 b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want) 502 } 503} 504 505func hookSplice(t *testing.T) *spliceHook { 506 t.Helper() 507 508 h := new(spliceHook) 509 h.install() 510 t.Cleanup(h.uninstall) 511 return h 512} 513 514type spliceHook struct { 515 called bool 516 dstfd int 517 srcfd int 518 remain int64 519 520 written int64 521 handled bool 522 err error 523 524 original func(dst, src *poll.FD, remain int64) (int64, bool, error) 525} 526 527func (h *spliceHook) install() { 528 h.original = pollSplice 529 pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, error) { 530 h.called = true 531 h.dstfd = dst.Sysfd 532 h.srcfd = src.Sysfd 533 h.remain = remain 534 h.written, h.handled, h.err = h.original(dst, src, remain) 535 return h.written, h.handled, h.err 536 } 537} 538 539func (h *spliceHook) uninstall() { 540 pollSplice = h.original 541} 542