xref: /aosp_15_r20/external/crosvm/base/src/sys/linux/netlink.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::alloc::Layout;
6 use std::mem::MaybeUninit;
7 use std::os::unix::io::AsRawFd;
8 use std::str;
9 
10 use libc::EINVAL;
11 use log::error;
12 use zerocopy::AsBytes;
13 use zerocopy::FromBytes;
14 use zerocopy::FromZeroes;
15 use zerocopy::Ref;
16 
17 use super::errno_result;
18 use super::getpid;
19 use super::Error;
20 use super::RawDescriptor;
21 use super::Result;
22 use crate::alloc::LayoutAllocation;
23 use crate::descriptor::AsRawDescriptor;
24 use crate::descriptor::FromRawDescriptor;
25 use crate::descriptor::SafeDescriptor;
26 
macro_rules! debug_pr {
    // Debug output is compiled out by default: the arm below discards its arguments without
    // evaluating them. To enable debug printing, replace the arm with:
    //     ($($args:tt)+) => (println!($($args)*));
    ($($args:tt)+) => {};
}
32 
33 const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
34 const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
35 const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
36 const NLATTR_ALIGN_TO: usize = 4;
37 
38 #[repr(C)]
39 #[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
40 struct NlMsgHdr {
41     pub nlmsg_len: u32,
42     pub nlmsg_type: u16,
43     pub nlmsg_flags: u16,
44     pub nlmsg_seq: u32,
45     pub nlmsg_pid: u32,
46 }
47 
48 /// Netlink attribute struct, can be used by netlink consumer
49 #[repr(C)]
50 #[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
51 pub struct NlAttr {
52     pub len: u16,
53     pub _type: u16,
54 }
55 
56 /// Generic netlink header struct, can be used by netlink consumer
57 #[repr(C)]
58 #[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
59 pub struct GenlMsgHdr {
60     pub cmd: u8,
61     pub version: u8,
62     pub reserved: u16,
63 }
/// A single netlink message, including its header and data.
///
/// Produced by [`NetlinkMessageIter`]. The header fields are copied out of the message's
/// `nlmsghdr`, while `data` borrows the payload that follows the header.
#[derive(Clone, Copy, Debug)]
pub struct NetlinkMessage<'a> {
    /// Message type (`nlmsg_type`).
    pub _type: u16,
    /// Message flags (`nlmsg_flags`).
    pub flags: u16,
    /// Sequence number (`nlmsg_seq`).
    pub seq: u32,
    /// Sender port id (`nlmsg_pid`).
    pub pid: u32,
    /// Payload between the end of the header and `nlmsg_len` (alignment padding excluded).
    pub data: &'a [u8],
}
72 
/// A single netlink attribute decoded by [`NetlinkGenericDataIter`].
#[derive(Clone, Copy, Debug)]
pub struct NlAttrWithData<'a> {
    /// Attribute length from the header, counting the `NlAttr` header itself but not any
    /// alignment padding that follows the attribute.
    pub len: u16,
    /// Attribute type exactly as found in the header (no flag bits are masked off).
    pub _type: u16,
    /// Attribute payload: the bytes between the `NlAttr` header and `len`.
    pub data: &'a [u8],
}
78 
nlattr_align(offset: usize) -> usize79 fn nlattr_align(offset: usize) -> usize {
80     (offset + NLATTR_ALIGN_TO - 1) & !(NLATTR_ALIGN_TO - 1)
81 }
82 
/// Iterator over the netlink attributes (`struct nlattr`) packed in a byte buffer, such as the
/// payload of a generic netlink message or of a nested attribute.
pub struct NetlinkGenericDataIter<'a> {
    // Bytes not yet decoded. Must be properly aligned for `NlAttr`.
    data: &'a [u8],
}
88 
89 impl<'a> Iterator for NetlinkGenericDataIter<'a> {
90     type Item = NlAttrWithData<'a>;
91 
next(&mut self) -> Option<Self::Item>92     fn next(&mut self) -> Option<Self::Item> {
93         if self.data.len() < NLA_HDRLEN {
94             return None;
95         }
96         let nl_hdr = NlAttr::read_from(&self.data[..NLA_HDRLEN])?;
97 
98         // Make sure NlAtrr fits
99         let nl_data_len = nl_hdr.len as usize;
100         if nl_data_len < NLA_HDRLEN || nl_data_len > self.data.len() {
101             return None;
102         }
103 
104         // Get data related to processed NlAttr
105         let data_start = NLA_HDRLEN;
106         let data = &self.data[data_start..nl_data_len];
107 
108         // Get next NlAttr
109         let next_hdr = nlattr_align(nl_data_len);
110         if next_hdr >= self.data.len() {
111             self.data = &[];
112         } else {
113             self.data = &self.data[next_hdr..];
114         }
115 
116         Some(NlAttrWithData {
117             _type: nl_hdr._type,
118             len: nl_hdr.len,
119             data,
120         })
121     }
122 }
123 
/// Iterator over the netlink messages (`struct nlmsghdr` plus payload) contained in a buffer
/// received from a netlink socket.
pub struct NetlinkMessageIter<'a> {
    // Bytes not yet decoded. Must be properly aligned for `NlMsgHdr`.
    data: &'a [u8],
}
129 
130 impl<'a> Iterator for NetlinkMessageIter<'a> {
131     type Item = NetlinkMessage<'a>;
132 
next(&mut self) -> Option<Self::Item>133     fn next(&mut self) -> Option<Self::Item> {
134         if self.data.len() < NLMSGHDR_SIZE {
135             return None;
136         }
137         let hdr = NlMsgHdr::read_from(&self.data[..NLMSGHDR_SIZE])?;
138 
139         // NLMSG_OK
140         let msg_len = hdr.nlmsg_len as usize;
141         if msg_len < NLMSGHDR_SIZE || msg_len > self.data.len() {
142             return None;
143         }
144 
145         // NLMSG_DATA
146         let data_start = NLMSGHDR_SIZE;
147         let data = &self.data[data_start..msg_len];
148 
149         // NLMSG_NEXT
150         let align_to = std::mem::align_of::<NlMsgHdr>();
151         let next_hdr = (msg_len + align_to - 1) & !(align_to - 1);
152         if next_hdr >= self.data.len() {
153             self.data = &[];
154         } else {
155             self.data = &self.data[next_hdr..];
156         }
157 
158         Some(NetlinkMessage {
159             _type: hdr.nlmsg_type,
160             flags: hdr.nlmsg_flags,
161             seq: hdr.nlmsg_seq,
162             pid: hdr.nlmsg_pid,
163             data,
164         })
165     }
166 }
167 
168 /// Safe wrapper for `NETLINK_GENERIC` netlink sockets.
169 pub struct NetlinkGenericSocket {
170     sock: SafeDescriptor,
171 }
172 
173 impl AsRawDescriptor for NetlinkGenericSocket {
as_raw_descriptor(&self) -> RawDescriptor174     fn as_raw_descriptor(&self) -> RawDescriptor {
175         self.sock.as_raw_descriptor()
176     }
177 }
178 
179 impl NetlinkGenericSocket {
180     /// Create and bind a new `NETLINK_GENERIC` socket.
new(nl_groups: u32) -> Result<Self>181     pub fn new(nl_groups: u32) -> Result<Self> {
182         // SAFETY:
183         // Safe because we check the return value and convert the raw fd into a SafeDescriptor.
184         let sock = unsafe {
185             let fd = libc::socket(
186                 libc::AF_NETLINK,
187                 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
188                 libc::NETLINK_GENERIC,
189             );
190             if fd < 0 {
191                 return errno_result();
192             }
193 
194             SafeDescriptor::from_raw_descriptor(fd)
195         };
196 
197         // SAFETY:
198         // This MaybeUninit dance is needed because sockaddr_nl has a private padding field and
199         // doesn't implement Default. Safe because all 0s is valid data for sockaddr_nl.
200         let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
201         sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
202         sa.nl_groups = nl_groups;
203 
204         // SAFETY:
205         // Safe because we pass a descriptor that we own and valid pointer/size for sockaddr.
206         unsafe {
207             let res = libc::bind(
208                 sock.as_raw_fd(),
209                 &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
210                 std::mem::size_of_val(&sa) as libc::socklen_t,
211             );
212             if res < 0 {
213                 return errno_result();
214             }
215         }
216 
217         Ok(NetlinkGenericSocket { sock })
218     }
219 
220     /// Receive messages from the netlink socket.
recv(&self) -> Result<NetlinkGenericRead>221     pub fn recv(&self) -> Result<NetlinkGenericRead> {
222         let buf_size = 8192; // TODO(dverkamp): make this configurable?
223 
224         // Create a buffer with sufficient alignment for nlmsghdr.
225         let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
226             .map_err(|_| Error::new(EINVAL))?;
227         let allocation = LayoutAllocation::uninitialized(layout);
228 
229         // SAFETY:
230         // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
231         let bytes_read = unsafe {
232             let res = libc::recv(self.sock.as_raw_fd(), allocation.as_ptr(), buf_size, 0);
233             if res < 0 {
234                 return errno_result();
235             }
236             res as usize
237         };
238 
239         Ok(NetlinkGenericRead {
240             allocation,
241             len: bytes_read,
242         })
243     }
244 
family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead>245     pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
246         let buf_size = 1024;
247         debug_pr!(
248             "preparing query for family name {}, len {}",
249             family_name,
250             family_name.len()
251         );
252 
253         // Create a buffer with sufficient alignment for nlmsghdr.
254         let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
255             .map_err(|_| Error::new(EINVAL))
256             .unwrap();
257         let mut allocation = LayoutAllocation::zeroed(layout);
258 
259         // SAFETY:
260         // Safe because the data in allocation was initialized up to `buf_size` and is
261         // sufficiently aligned.
262         let data = unsafe { allocation.as_mut_slice(buf_size) };
263 
264         // Prepare the netlink message header
265         let hdr = Ref::<_, NlMsgHdr>::new(&mut data[..NLMSGHDR_SIZE])
266             .expect("failed to unwrap")
267             .into_mut();
268         hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
269         hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
270         hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
271         hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
272         hdr.nlmsg_pid = getpid() as u32;
273 
274         // Prepare generic netlink message header
275         let genl_hdr_end = NLMSGHDR_SIZE + GENL_HDRLEN;
276         let genl_hdr = Ref::<_, GenlMsgHdr>::new(&mut data[NLMSGHDR_SIZE..genl_hdr_end])
277             .expect("unable to get GenlMsgHdr from slice")
278             .into_mut();
279         genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
280         genl_hdr.version = 0x1;
281 
282         // Netlink attributes
283         let nlattr_start = genl_hdr_end;
284         let nlattr_end = nlattr_start + NLA_HDRLEN;
285         let nl_attr = Ref::<_, NlAttr>::new(&mut data[nlattr_start..nlattr_end])
286             .expect("unable to get NlAttr from slice")
287             .into_mut();
288         nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
289         nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
290 
291         // Fill the message payload with the family name
292         let payload_start = nlattr_end;
293         let payload_end = payload_start + family_name.len();
294         data[payload_start..payload_end].copy_from_slice(family_name.as_bytes());
295 
296         // SAFETY:
297         // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
298         unsafe {
299             let res = libc::send(
300                 self.sock.as_raw_fd(),
301                 allocation.as_ptr(),
302                 payload_end + 1,
303                 0,
304             );
305             if res < 0 {
306                 error!("failed to send get_family_cmd");
307                 return errno_result();
308             }
309         };
310 
311         // Return the answer
312         match self.recv() {
313             Ok(msg) => Ok(msg),
314             Err(e) => {
315                 error!("recv get_family returned with error {}", e);
316                 Err(e)
317             }
318         }
319     }
320 }
321 
parse_ctrl_group_name_and_id( nested_nl_attr_data: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>322 fn parse_ctrl_group_name_and_id(
323     nested_nl_attr_data: NetlinkGenericDataIter,
324     group_name: &str,
325 ) -> Option<u32> {
326     let mut mcast_group_id: Option<u32> = None;
327 
328     for nested_nl_attr in nested_nl_attr_data {
329         debug_pr!(
330             "\t\tmcast_grp: nlattr type {}, len {}",
331             nested_nl_attr._type,
332             nested_nl_attr.len
333         );
334 
335         if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
336             mcast_group_id = Some(u32::from_ne_bytes(nested_nl_attr.data.try_into().unwrap()));
337             debug_pr!("\t\t mcast group_id {}", mcast_group_id?);
338         }
339 
340         if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
341             debug_pr!(
342                 "\t\t mcast group name {}",
343                 strip_padding(&nested_nl_attr.data)
344             );
345 
346             // If the group name match and the group_id was set in previous iteration, return,
347             // valid for group_name, group_id
348             if group_name.eq(strip_padding(nested_nl_attr.data)) && mcast_group_id.is_some() {
349                 debug_pr!(
350                     "\t\t Got what we were looking for group_id = {} for {}",
351                     mcast_group_id?,
352                     group_name
353                 );
354 
355                 return mcast_group_id;
356             }
357         }
358     }
359 
360     None
361 }
362 
363 /// Parse CTRL_ATTR_MCAST_GROUPS data in order to get multicast group id
364 ///
365 /// On success, returns group_id for a given `group_name`
366 ///
367 /// # Arguments
368 ///
369 /// * `nl_attr_area`
370 ///
371 ///     Nested attributes area (CTRL_ATTR_MCAST_GROUPS data), where nl_attr's corresponding to
372 ///     specific groups are embed
373 ///
374 /// * `group_name`
375 ///
376 ///     String with group_name for which we are looking group_id
377 ///
378 /// the CTRL_ATTR_MCAST_GROUPS data has nested attributes. Each of nested attribute is per
379 /// multicast group attributes, which have another nested attributes: CTRL_ATTR_MCAST_GRP_NAME and
380 /// CTRL_ATTR_MCAST_GRP_ID. Need to parse all of them to get mcast group id for a given group_name..
381 ///
382 /// Illustrated layout:
383 /// CTRL_ATTR_MCAST_GROUPS:
384 ///   GR1 (nl_attr._type = 1):
385 ///       CTRL_ATTR_MCAST_GRP_ID,
386 ///       CTRL_ATTR_MCAST_GRP_NAME,
387 ///   GR2 (nl_attr._type = 2):
388 ///       CTRL_ATTR_MCAST_GRP_ID,
389 ///       CTRL_ATTR_MCAST_GRP_NAME,
390 ///   ..
391 ///
392 /// Unfortunately kernel implementation uses `nla_nest_start_noflag` for that
393 /// purpose, which means that it never marked their nest attributes with NLA_F_NESTED flag.
394 /// Therefore all this nesting stages need to be deduced based on specific nl_attr type.
parse_ctrl_mcast_group_id( nl_attr_area: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>395 fn parse_ctrl_mcast_group_id(
396     nl_attr_area: NetlinkGenericDataIter,
397     group_name: &str,
398 ) -> Option<u32> {
399     // There may be multiple nested multicast groups, go through all of them.
400     // Each of nested group, has other nested nlattr:
401     //  CTRL_ATTR_MCAST_GRP_ID
402     //  CTRL_ATTR_MCAST_GRP_NAME
403     //
404     //  which are further proceed by parse_ctrl_group_name_and_id
405     for nested_gr_nl_attr in nl_attr_area {
406         debug_pr!(
407             "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
408             nested_gr_nl_attr._type,
409             nested_gr_nl_attr.len
410         );
411 
412         let netlink_nested_attr = NetlinkGenericDataIter {
413             data: nested_gr_nl_attr.data,
414         };
415 
416         if let Some(mcast_group_id) = parse_ctrl_group_name_and_id(netlink_nested_attr, group_name)
417         {
418             return Some(mcast_group_id);
419         }
420     }
421 
422     None
423 }
424 
/// Returns the text of `b` up to, but not including, its first NUL byte, discarding any padding
/// after it. This is similar to `CStr::from_bytes_until_nul` but yields a `&str`.
///
/// # Panics
///
/// Panics if `b` contains no NUL byte, or if the bytes before the first NUL are not valid UTF-8.
fn strip_padding(b: &[u8]) -> &str {
    let nul_index = b
        .iter()
        .position(|&c| c == 0)
        .expect("`b` doesn't contain any nul bytes");
    let (text, _padding) = b.split_at(nul_index);
    str::from_utf8(text).unwrap()
}
436 
437 pub struct NetlinkGenericRead {
438     allocation: LayoutAllocation,
439     len: usize,
440 }
441 
442 impl NetlinkGenericRead {
iter(&self) -> NetlinkMessageIter443     pub fn iter(&self) -> NetlinkMessageIter {
444         // SAFETY:
445         // Safe because the data in allocation was initialized up to `self.len` by `recv()` and is
446         // sufficiently aligned.
447         let data = unsafe { &self.allocation.as_slice(self.len) };
448         NetlinkMessageIter { data }
449     }
450 
451     /// Parse NetlinkGeneric response in order to get multicast group id
452     ///
453     /// On success, returns group_id for a given `group_name`
454     ///
455     /// # Arguments
456     ///
457     /// * `group_name` - String with group_name for which we are looking group_id
458     ///
459     /// Response from family_name_query (CTRL_CMD_GETFAMILY) is a netlink message with multiple
460     /// attributes encapsulated (some of them are nested). An example response layout is
461     /// illustrated below:
462     ///
463     ///  {
464     ///    CTRL_ATTR_FAMILY_NAME
465     ///    CTRL_ATTR_FAMILY_ID
466     ///    CTRL_ATTR_VERSION
467     ///    ...
468     ///    CTRL_ATTR_MCAST_GROUPS {
469     ///      GR1 (nl_attr._type = 1) {
470     ///          CTRL_ATTR_MCAST_GRP_ID    *we need parse this attr to obtain group id used for
471     ///                                     the group mask
472     ///          CTRL_ATTR_MCAST_GRP_NAME  *group_name that we need to match with
473     ///      }
474     ///      GR2 (nl_attr._type = 2) {
475     ///          CTRL_ATTR_MCAST_GRP_ID
476     ///          CTRL_ATTR_MCAST_GRP_NAME
477     ///      }
478     ///      ...
479     ///     }
480     ///   }
get_multicast_group_id(&self, group_name: String) -> Option<u32>481     pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
482         for netlink_msg in self.iter() {
483             debug_pr!(
484                 "received type: {}, flags {}, pid {}, data {:?}",
485                 netlink_msg._type,
486                 netlink_msg.flags,
487                 netlink_msg.pid,
488                 netlink_msg.data
489             );
490 
491             if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
492                 error!("Received not a generic netlink controller msg");
493                 return None;
494             }
495 
496             let netlink_data = NetlinkGenericDataIter {
497                 data: &netlink_msg.data[GENL_HDRLEN..],
498             };
499             for nl_attr in netlink_data {
500                 debug_pr!("nl_attr type {}, len {}", nl_attr._type, nl_attr.len);
501 
502                 if nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16 {
503                     let netlink_nested_attr = NetlinkGenericDataIter { data: nl_attr.data };
504 
505                     if let Some(mcast_group_id) =
506                         parse_ctrl_mcast_group_id(netlink_nested_attr, &group_name)
507                     {
508                         return Some(mcast_group_id);
509                     }
510                 }
511             }
512         }
513         None
514     }
515 }
516