1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package socktest
6
7import (
8	"internal/syscall/windows"
9	"syscall"
10)
11
12// WSASocket wraps [syscall.WSASocket].
13func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
14	sw.once.Do(sw.init)
15
16	so := &Status{Cookie: cookie(int(family), int(sotype), int(proto))}
17	sw.fmu.RLock()
18	f, _ := sw.fltab[FilterSocket]
19	sw.fmu.RUnlock()
20
21	af, err := f.apply(so)
22	if err != nil {
23		return syscall.InvalidHandle, err
24	}
25	s, so.Err = windows.WSASocket(family, sotype, proto, protinfo, group, flags)
26	if err = af.apply(so); err != nil {
27		if so.Err == nil {
28			syscall.Closesocket(s)
29		}
30		return syscall.InvalidHandle, err
31	}
32
33	sw.smu.Lock()
34	defer sw.smu.Unlock()
35	if so.Err != nil {
36		sw.stats.getLocked(so.Cookie).OpenFailed++
37		return syscall.InvalidHandle, so.Err
38	}
39	nso := sw.addLocked(s, int(family), int(sotype), int(proto))
40	sw.stats.getLocked(nso.Cookie).Opened++
41	return s, nil
42}
43
44// Closesocket wraps [syscall.Closesocket].
45func (sw *Switch) Closesocket(s syscall.Handle) (err error) {
46	so := sw.sockso(s)
47	if so == nil {
48		return syscall.Closesocket(s)
49	}
50	sw.fmu.RLock()
51	f, _ := sw.fltab[FilterClose]
52	sw.fmu.RUnlock()
53
54	af, err := f.apply(so)
55	if err != nil {
56		return err
57	}
58	so.Err = syscall.Closesocket(s)
59	if err = af.apply(so); err != nil {
60		return err
61	}
62
63	sw.smu.Lock()
64	defer sw.smu.Unlock()
65	if so.Err != nil {
66		sw.stats.getLocked(so.Cookie).CloseFailed++
67		return so.Err
68	}
69	delete(sw.sotab, s)
70	sw.stats.getLocked(so.Cookie).Closed++
71	return nil
72}
73
74// Connect wraps [syscall.Connect].
75func (sw *Switch) Connect(s syscall.Handle, sa syscall.Sockaddr) (err error) {
76	so := sw.sockso(s)
77	if so == nil {
78		return syscall.Connect(s, sa)
79	}
80	sw.fmu.RLock()
81	f, _ := sw.fltab[FilterConnect]
82	sw.fmu.RUnlock()
83
84	af, err := f.apply(so)
85	if err != nil {
86		return err
87	}
88	so.Err = syscall.Connect(s, sa)
89	if err = af.apply(so); err != nil {
90		return err
91	}
92
93	sw.smu.Lock()
94	defer sw.smu.Unlock()
95	if so.Err != nil {
96		sw.stats.getLocked(so.Cookie).ConnectFailed++
97		return so.Err
98	}
99	sw.stats.getLocked(so.Cookie).Connected++
100	return nil
101}
102
103// ConnectEx wraps [syscall.ConnectEx].
104func (sw *Switch) ConnectEx(s syscall.Handle, sa syscall.Sockaddr, b *byte, n uint32, nwr *uint32, o *syscall.Overlapped) (err error) {
105	so := sw.sockso(s)
106	if so == nil {
107		return syscall.ConnectEx(s, sa, b, n, nwr, o)
108	}
109	sw.fmu.RLock()
110	f, _ := sw.fltab[FilterConnect]
111	sw.fmu.RUnlock()
112
113	af, err := f.apply(so)
114	if err != nil {
115		return err
116	}
117	so.Err = syscall.ConnectEx(s, sa, b, n, nwr, o)
118	if err = af.apply(so); err != nil {
119		return err
120	}
121
122	sw.smu.Lock()
123	defer sw.smu.Unlock()
124	if so.Err != nil {
125		sw.stats.getLocked(so.Cookie).ConnectFailed++
126		return so.Err
127	}
128	sw.stats.getLocked(so.Cookie).Connected++
129	return nil
130}
131
132// Listen wraps [syscall.Listen].
133func (sw *Switch) Listen(s syscall.Handle, backlog int) (err error) {
134	so := sw.sockso(s)
135	if so == nil {
136		return syscall.Listen(s, backlog)
137	}
138	sw.fmu.RLock()
139	f, _ := sw.fltab[FilterListen]
140	sw.fmu.RUnlock()
141
142	af, err := f.apply(so)
143	if err != nil {
144		return err
145	}
146	so.Err = syscall.Listen(s, backlog)
147	if err = af.apply(so); err != nil {
148		return err
149	}
150
151	sw.smu.Lock()
152	defer sw.smu.Unlock()
153	if so.Err != nil {
154		sw.stats.getLocked(so.Cookie).ListenFailed++
155		return so.Err
156	}
157	sw.stats.getLocked(so.Cookie).Listened++
158	return nil
159}
160
161// AcceptEx wraps [syscall.AcceptEx].
162func (sw *Switch) AcceptEx(ls syscall.Handle, as syscall.Handle, b *byte, rxdatalen uint32, laddrlen uint32, raddrlen uint32, rcvd *uint32, overlapped *syscall.Overlapped) error {
163	so := sw.sockso(ls)
164	if so == nil {
165		return syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped)
166	}
167	sw.fmu.RLock()
168	f, _ := sw.fltab[FilterAccept]
169	sw.fmu.RUnlock()
170
171	af, err := f.apply(so)
172	if err != nil {
173		return err
174	}
175	so.Err = syscall.AcceptEx(ls, as, b, rxdatalen, laddrlen, raddrlen, rcvd, overlapped)
176	if err = af.apply(so); err != nil {
177		return err
178	}
179
180	sw.smu.Lock()
181	defer sw.smu.Unlock()
182	if so.Err != nil {
183		sw.stats.getLocked(so.Cookie).AcceptFailed++
184		return so.Err
185	}
186	nso := sw.addLocked(as, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
187	sw.stats.getLocked(nso.Cookie).Accepted++
188	return nil
189}
190