1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build unix || js || wasip1 || windows
6
7package net
8
9import (
10	"context"
11	"net/netip"
12	"syscall"
13)
14
15func sockaddrToUDP(sa syscall.Sockaddr) Addr {
16	switch sa := sa.(type) {
17	case *syscall.SockaddrInet4:
18		return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port}
19	case *syscall.SockaddrInet6:
20		return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))}
21	}
22	return nil
23}
24
25func (a *UDPAddr) family() int {
26	if a == nil || len(a.IP) <= IPv4len {
27		return syscall.AF_INET
28	}
29	if a.IP.To4() != nil {
30		return syscall.AF_INET
31	}
32	return syscall.AF_INET6
33}
34
35func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
36	if a == nil {
37		return nil, nil
38	}
39	return ipToSockaddr(family, a.IP, a.Port, a.Zone)
40}
41
42func (a *UDPAddr) toLocal(net string) sockaddr {
43	return &UDPAddr{loopbackIP(net), a.Port, a.Zone}
44}
45
46func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) {
47	var n int
48	var err error
49	switch c.fd.family {
50	case syscall.AF_INET:
51		var from syscall.SockaddrInet4
52		n, err = c.fd.readFromInet4(b, &from)
53		if err == nil {
54			ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes
55			*addr = UDPAddr{IP: ip[:], Port: from.Port}
56		}
57	case syscall.AF_INET6:
58		var from syscall.SockaddrInet6
59		n, err = c.fd.readFromInet6(b, &from)
60		if err == nil {
61			ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes
62			*addr = UDPAddr{IP: ip[:], Port: from.Port, Zone: zoneCache.name(int(from.ZoneId))}
63		}
64	}
65	if err != nil {
66		// No sockaddr, so don't return UDPAddr.
67		addr = nil
68	}
69	return n, addr, err
70}
71
72func (c *UDPConn) readFromAddrPort(b []byte) (n int, addr netip.AddrPort, err error) {
73	var ip netip.Addr
74	var port int
75	switch c.fd.family {
76	case syscall.AF_INET:
77		var from syscall.SockaddrInet4
78		n, err = c.fd.readFromInet4(b, &from)
79		if err == nil {
80			ip = netip.AddrFrom4(from.Addr)
81			port = from.Port
82		}
83	case syscall.AF_INET6:
84		var from syscall.SockaddrInet6
85		n, err = c.fd.readFromInet6(b, &from)
86		if err == nil {
87			ip = netip.AddrFrom16(from.Addr).WithZone(zoneCache.name(int(from.ZoneId)))
88			port = from.Port
89		}
90	}
91	if err == nil {
92		addr = netip.AddrPortFrom(ip, uint16(port))
93	}
94	return n, addr, err
95}
96
97func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) {
98	switch c.fd.family {
99	case syscall.AF_INET:
100		var sa syscall.SockaddrInet4
101		n, oobn, flags, err = c.fd.readMsgInet4(b, oob, 0, &sa)
102		ip := netip.AddrFrom4(sa.Addr)
103		addr = netip.AddrPortFrom(ip, uint16(sa.Port))
104	case syscall.AF_INET6:
105		var sa syscall.SockaddrInet6
106		n, oobn, flags, err = c.fd.readMsgInet6(b, oob, 0, &sa)
107		ip := netip.AddrFrom16(sa.Addr).WithZone(zoneCache.name(int(sa.ZoneId)))
108		addr = netip.AddrPortFrom(ip, uint16(sa.Port))
109	}
110	return
111}
112
113func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) {
114	if c.fd.isConnected {
115		return 0, ErrWriteToConnected
116	}
117	if addr == nil {
118		return 0, errMissingAddress
119	}
120
121	switch c.fd.family {
122	case syscall.AF_INET:
123		sa, err := ipToSockaddrInet4(addr.IP, addr.Port)
124		if err != nil {
125			return 0, err
126		}
127		return c.fd.writeToInet4(b, &sa)
128	case syscall.AF_INET6:
129		sa, err := ipToSockaddrInet6(addr.IP, addr.Port, addr.Zone)
130		if err != nil {
131			return 0, err
132		}
133		return c.fd.writeToInet6(b, &sa)
134	default:
135		return 0, &AddrError{Err: "invalid address family", Addr: addr.IP.String()}
136	}
137}
138
139func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) {
140	if c.fd.isConnected {
141		return 0, ErrWriteToConnected
142	}
143	if !addr.IsValid() {
144		return 0, errMissingAddress
145	}
146
147	switch c.fd.family {
148	case syscall.AF_INET:
149		sa, err := addrPortToSockaddrInet4(addr)
150		if err != nil {
151			return 0, err
152		}
153		return c.fd.writeToInet4(b, &sa)
154	case syscall.AF_INET6:
155		sa, err := addrPortToSockaddrInet6(addr)
156		if err != nil {
157			return 0, err
158		}
159		return c.fd.writeToInet6(b, &sa)
160	default:
161		return 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
162	}
163}
164
165func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) {
166	if c.fd.isConnected && addr != nil {
167		return 0, 0, ErrWriteToConnected
168	}
169	if !c.fd.isConnected && addr == nil {
170		return 0, 0, errMissingAddress
171	}
172	sa, err := addr.sockaddr(c.fd.family)
173	if err != nil {
174		return 0, 0, err
175	}
176	return c.fd.writeMsg(b, oob, sa)
177}
178
179func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) {
180	if c.fd.isConnected && addr.IsValid() {
181		return 0, 0, ErrWriteToConnected
182	}
183	if !c.fd.isConnected && !addr.IsValid() {
184		return 0, 0, errMissingAddress
185	}
186
187	switch c.fd.family {
188	case syscall.AF_INET:
189		sa, err := addrPortToSockaddrInet4(addr)
190		if err != nil {
191			return 0, 0, err
192		}
193		return c.fd.writeMsgInet4(b, oob, &sa)
194	case syscall.AF_INET6:
195		sa, err := addrPortToSockaddrInet6(addr)
196		if err != nil {
197			return 0, 0, err
198		}
199		return c.fd.writeMsgInet6(b, oob, &sa)
200	default:
201		return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()}
202	}
203}
204
205func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
206	ctrlCtxFn := sd.Dialer.ControlContext
207	if ctrlCtxFn == nil && sd.Dialer.Control != nil {
208		ctrlCtxFn = func(ctx context.Context, network, address string, c syscall.RawConn) error {
209			return sd.Dialer.Control(network, address, c)
210		}
211	}
212	fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", ctrlCtxFn)
213	if err != nil {
214		return nil, err
215	}
216	return newUDPConn(fd), nil
217}
218
219func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
220	var ctrlCtxFn func(ctx context.Context, network, address string, c syscall.RawConn) error
221	if sl.ListenConfig.Control != nil {
222		ctrlCtxFn = func(ctx context.Context, network, address string, c syscall.RawConn) error {
223			return sl.ListenConfig.Control(network, address, c)
224		}
225	}
226	fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
227	if err != nil {
228		return nil, err
229	}
230	return newUDPConn(fd), nil
231}
232
233func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
234	var ctrlCtxFn func(ctx context.Context, network, address string, c syscall.RawConn) error
235	if sl.ListenConfig.Control != nil {
236		ctrlCtxFn = func(ctx context.Context, network, address string, c syscall.RawConn) error {
237			return sl.ListenConfig.Control(network, address, c)
238		}
239	}
240	fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", ctrlCtxFn)
241	if err != nil {
242		return nil, err
243	}
244	c := newUDPConn(fd)
245	if ip4 := gaddr.IP.To4(); ip4 != nil {
246		if err := listenIPv4MulticastUDP(c, ifi, ip4); err != nil {
247			c.Close()
248			return nil, err
249		}
250	} else {
251		if err := listenIPv6MulticastUDP(c, ifi, gaddr.IP); err != nil {
252			c.Close()
253			return nil, err
254		}
255	}
256	return c, nil
257}
258
259func listenIPv4MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
260	if ifi != nil {
261		if err := setIPv4MulticastInterface(c.fd, ifi); err != nil {
262			return err
263		}
264	}
265	if err := setIPv4MulticastLoopback(c.fd, false); err != nil {
266		return err
267	}
268	if err := joinIPv4Group(c.fd, ifi, ip); err != nil {
269		return err
270	}
271	return nil
272}
273
274func listenIPv6MulticastUDP(c *UDPConn, ifi *Interface, ip IP) error {
275	if ifi != nil {
276		if err := setIPv6MulticastInterface(c.fd, ifi); err != nil {
277			return err
278		}
279	}
280	if err := setIPv6MulticastLoopback(c.fd, false); err != nil {
281		return err
282	}
283	if err := joinIPv6Group(c.fd, ifi, ip); err != nil {
284		return err
285	}
286	return nil
287}
288