1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
import (
	"bytes"
	"context"
	"errors"
	"io"
	"syscall"
	"testing"
)
14
15func newLocalListenerMPTCP(t *testing.T, envVar bool) Listener {
16	lc := &ListenConfig{}
17
18	if envVar {
19		if !lc.MultipathTCP() {
20			t.Fatal("MultipathTCP Listen is not on despite GODEBUG=multipathtcp=1")
21		}
22	} else {
23		if lc.MultipathTCP() {
24			t.Error("MultipathTCP should be off by default")
25		}
26
27		lc.SetMultipathTCP(true)
28		if !lc.MultipathTCP() {
29			t.Fatal("MultipathTCP is not on after having been forced to on")
30		}
31	}
32
33	ln, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
34	if err != nil {
35		t.Fatal(err)
36	}
37	return ln
38}
39
40func postAcceptMPTCP(ls *localServer, ch chan<- error) {
41	defer close(ch)
42
43	if len(ls.cl) == 0 {
44		ch <- errors.New("no accepted stream")
45		return
46	}
47
48	c := ls.cl[0]
49
50	tcp, ok := c.(*TCPConn)
51	if !ok {
52		ch <- errors.New("struct is not a TCPConn")
53		return
54	}
55
56	mptcp, err := tcp.MultipathTCP()
57	if err != nil {
58		ch <- err
59		return
60	}
61
62	if !mptcp {
63		ch <- errors.New("incoming connection is not with MPTCP")
64		return
65	}
66
67	// Also check the method for the older kernels if not tested before
68	if hasSOLMPTCP && !isUsingMPTCPProto(tcp.fd) {
69		ch <- errors.New("incoming connection is not an MPTCP proto")
70		return
71	}
72}
73
74func dialerMPTCP(t *testing.T, addr string, envVar bool) {
75	d := &Dialer{}
76
77	if envVar {
78		if !d.MultipathTCP() {
79			t.Fatal("MultipathTCP Dialer is not on despite GODEBUG=multipathtcp=1")
80		}
81	} else {
82		if d.MultipathTCP() {
83			t.Error("MultipathTCP should be off by default")
84		}
85
86		d.SetMultipathTCP(true)
87		if !d.MultipathTCP() {
88			t.Fatal("MultipathTCP is not on after having been forced to on")
89		}
90	}
91
92	c, err := d.Dial("tcp", addr)
93	if err != nil {
94		t.Fatal(err)
95	}
96	defer c.Close()
97
98	tcp, ok := c.(*TCPConn)
99	if !ok {
100		t.Fatal("struct is not a TCPConn")
101	}
102
103	// Transfer a bit of data to make sure everything is still OK
104	snt := []byte("MPTCP TEST")
105	if _, err := c.Write(snt); err != nil {
106		t.Fatal(err)
107	}
108	b := make([]byte, len(snt))
109	if _, err := c.Read(b); err != nil {
110		t.Fatal(err)
111	}
112	if !bytes.Equal(snt, b) {
113		t.Errorf("sent bytes (%s) are different from received ones (%s)", snt, b)
114	}
115
116	mptcp, err := tcp.MultipathTCP()
117	if err != nil {
118		t.Fatal(err)
119	}
120
121	t.Logf("outgoing connection from %s with mptcp: %t", addr, mptcp)
122
123	if !mptcp {
124		t.Error("outgoing connection is not with MPTCP")
125	}
126
127	// Also check the method for the older kernels if not tested before
128	if hasSOLMPTCP && !isUsingMPTCPProto(tcp.fd) {
129		t.Error("outgoing connection is not an MPTCP proto")
130	}
131}
132
133func canCreateMPTCPSocket() bool {
134	// We want to know if we can create an MPTCP socket, not just if it is
135	// available (mptcpAvailable()): it could be blocked by the admin
136	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, _IPPROTO_MPTCP)
137	if err != nil {
138		return false
139	}
140
141	syscall.Close(fd)
142	return true
143}
144
145func testMultiPathTCP(t *testing.T, envVar bool) {
146	if envVar {
147		t.Log("Test with GODEBUG=multipathtcp=1")
148		t.Setenv("GODEBUG", "multipathtcp=1")
149	} else {
150		t.Log("Test with GODEBUG=multipathtcp=0")
151		t.Setenv("GODEBUG", "multipathtcp=0")
152	}
153
154	ln := newLocalListenerMPTCP(t, envVar)
155
156	// similar to tcpsock_test:TestIPv6LinkLocalUnicastTCP
157	ls := (&streamListener{Listener: ln}).newLocalServer()
158	defer ls.teardown()
159
160	if g, w := ls.Listener.Addr().Network(), "tcp"; g != w {
161		t.Fatalf("Network type mismatch: got %q, want %q", g, w)
162	}
163
164	genericCh := make(chan error)
165	mptcpCh := make(chan error)
166	handler := func(ls *localServer, ln Listener) {
167		ls.transponder(ln, genericCh)
168		postAcceptMPTCP(ls, mptcpCh)
169	}
170	if err := ls.buildup(handler); err != nil {
171		t.Fatal(err)
172	}
173
174	dialerMPTCP(t, ln.Addr().String(), envVar)
175
176	if err := <-genericCh; err != nil {
177		t.Error(err)
178	}
179	if err := <-mptcpCh; err != nil {
180		t.Error(err)
181	}
182}
183
184func TestMultiPathTCP(t *testing.T) {
185	if !canCreateMPTCPSocket() {
186		t.Skip("Cannot create MPTCP sockets")
187	}
188
189	for _, envVar := range []bool{false, true} {
190		testMultiPathTCP(t, envVar)
191	}
192}
193