1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::alloc::Layout;
6 use std::mem::MaybeUninit;
7 use std::os::unix::io::AsRawFd;
8 use std::str;
9
10 use libc::EINVAL;
11 use log::error;
12 use zerocopy::AsBytes;
13 use zerocopy::FromBytes;
14 use zerocopy::FromZeroes;
15 use zerocopy::Ref;
16
17 use super::errno_result;
18 use super::getpid;
19 use super::Error;
20 use super::RawDescriptor;
21 use super::Result;
22 use crate::alloc::LayoutAllocation;
23 use crate::descriptor::AsRawDescriptor;
24 use crate::descriptor::FromRawDescriptor;
25 use crate::descriptor::SafeDescriptor;
26
// Debug-print helper used throughout this module. The active arm expands to nothing, so the
// arguments are discarded entirely: they are never evaluated, and operators such as `x?`
// inside them have no effect.
macro_rules! debug_pr {
    // By default debug output is suppressed. To enable it, replace the macro body with:
    // ($($args:tt)+) => (println!($($args)*))
    ($($args:tt)+) => {};
}
32
/// Size in bytes of the netlink message header (`NlMsgHdr`, i.e. `struct nlmsghdr`).
const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
/// Size in bytes of the generic netlink header (`GenlMsgHdr`) that follows the netlink header.
const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
/// Size in bytes of a netlink attribute header (`NlAttr`).
const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
/// Netlink attributes are padded so that each one starts on a multiple of this many bytes
/// (see `nlattr_align`).
const NLATTR_ALIGN_TO: usize = 4;
37
/// Netlink message header, laid out like the kernel's `struct nlmsghdr`.
#[repr(C)]
#[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
struct NlMsgHdr {
    /// Total length of the message in bytes, including this header.
    pub nlmsg_len: u32,
    /// Message type; for requests to the generic netlink controller this is `GENL_ID_CTRL`.
    pub nlmsg_type: u16,
    /// `NLM_F_*` flags (e.g. `NLM_F_REQUEST`).
    pub nlmsg_flags: u16,
    /// Sequence number.
    pub nlmsg_seq: u32,
    /// Sender port ID; `family_name_query` fills this with the process ID.
    pub nlmsg_pid: u32,
}
47
/// Netlink attribute struct, can be used by netlink consumer.
///
/// This is the attribute header; the attribute payload follows it directly in the buffer.
#[repr(C)]
#[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
pub struct NlAttr {
    /// Length of the attribute in bytes, including this header but not trailing alignment
    /// padding.
    pub len: u16,
    /// Attribute type (e.g. `CTRL_ATTR_FAMILY_NAME`).
    pub _type: u16,
}
55
/// Generic netlink header struct, can be used by netlink consumer.
///
/// It sits immediately after the netlink message header (`NlMsgHdr`).
#[repr(C)]
#[derive(Copy, Clone, FromZeroes, FromBytes, AsBytes)]
pub struct GenlMsgHdr {
    /// Generic netlink command (e.g. `CTRL_CMD_GETFAMILY`).
    pub cmd: u8,
    /// Version of the command interface.
    pub version: u8,
    /// Reserved; left zero by this module.
    pub reserved: u16,
}
/// A single netlink message, including its header and data.
///
/// The header fields are copied out of the message's `NlMsgHdr`; `data` borrows the bytes
/// that follow the header, up to `nlmsg_len` (trailing alignment padding is not included).
pub struct NetlinkMessage<'a> {
    /// Message type (`nlmsg_type`).
    pub _type: u16,
    /// Message flags (`nlmsg_flags`).
    pub flags: u16,
    /// Sequence number (`nlmsg_seq`).
    pub seq: u32,
    /// Sender port ID (`nlmsg_pid`).
    pub pid: u32,
    /// Message payload following the netlink header.
    pub data: &'a [u8],
}
72
/// A netlink attribute as yielded by `NetlinkGenericDataIter`: the header fields plus a
/// borrowed view of the attribute payload.
pub struct NlAttrWithData<'a> {
    /// Attribute length from the header, including the attribute header itself.
    pub len: u16,
    /// Attribute type from the header.
    pub _type: u16,
    /// Attribute payload: the bytes after the header, excluding alignment padding.
    pub data: &'a [u8],
}
78
nlattr_align(offset: usize) -> usize79 fn nlattr_align(offset: usize) -> usize {
80 (offset + NLATTR_ALIGN_TO - 1) & !(NLATTR_ALIGN_TO - 1)
81 }
82
/// Iterator over `struct NlAttr` as received from a netlink socket.
///
/// Yields one `NlAttrWithData` per attribute. It stops at the end of the buffer or at the first
/// attribute whose length is inconsistent with the remaining bytes.
pub struct NetlinkGenericDataIter<'a> {
    // Remaining, not-yet-parsed attribute bytes.
    // `data` must be properly aligned for NlAttr.
    // NOTE(review): the header is copied out with `NlAttr::read_from`, which may not actually
    // require alignment — confirm before relying on this requirement.
    data: &'a [u8],
}
88
89 impl<'a> Iterator for NetlinkGenericDataIter<'a> {
90 type Item = NlAttrWithData<'a>;
91
next(&mut self) -> Option<Self::Item>92 fn next(&mut self) -> Option<Self::Item> {
93 if self.data.len() < NLA_HDRLEN {
94 return None;
95 }
96 let nl_hdr = NlAttr::read_from(&self.data[..NLA_HDRLEN])?;
97
98 // Make sure NlAtrr fits
99 let nl_data_len = nl_hdr.len as usize;
100 if nl_data_len < NLA_HDRLEN || nl_data_len > self.data.len() {
101 return None;
102 }
103
104 // Get data related to processed NlAttr
105 let data_start = NLA_HDRLEN;
106 let data = &self.data[data_start..nl_data_len];
107
108 // Get next NlAttr
109 let next_hdr = nlattr_align(nl_data_len);
110 if next_hdr >= self.data.len() {
111 self.data = &[];
112 } else {
113 self.data = &self.data[next_hdr..];
114 }
115
116 Some(NlAttrWithData {
117 _type: nl_hdr._type,
118 len: nl_hdr.len,
119 data,
120 })
121 }
122 }
123
/// Iterator over `struct nlmsghdr` as received from a netlink socket.
///
/// Yields one `NetlinkMessage` per message. It stops at the end of the buffer or at the first
/// header whose `nlmsg_len` is inconsistent with the remaining bytes.
pub struct NetlinkMessageIter<'a> {
    // Remaining, not-yet-parsed message bytes.
    // `data` must be properly aligned for nlmsghdr.
    data: &'a [u8],
}
129
130 impl<'a> Iterator for NetlinkMessageIter<'a> {
131 type Item = NetlinkMessage<'a>;
132
next(&mut self) -> Option<Self::Item>133 fn next(&mut self) -> Option<Self::Item> {
134 if self.data.len() < NLMSGHDR_SIZE {
135 return None;
136 }
137 let hdr = NlMsgHdr::read_from(&self.data[..NLMSGHDR_SIZE])?;
138
139 // NLMSG_OK
140 let msg_len = hdr.nlmsg_len as usize;
141 if msg_len < NLMSGHDR_SIZE || msg_len > self.data.len() {
142 return None;
143 }
144
145 // NLMSG_DATA
146 let data_start = NLMSGHDR_SIZE;
147 let data = &self.data[data_start..msg_len];
148
149 // NLMSG_NEXT
150 let align_to = std::mem::align_of::<NlMsgHdr>();
151 let next_hdr = (msg_len + align_to - 1) & !(align_to - 1);
152 if next_hdr >= self.data.len() {
153 self.data = &[];
154 } else {
155 self.data = &self.data[next_hdr..];
156 }
157
158 Some(NetlinkMessage {
159 _type: hdr.nlmsg_type,
160 flags: hdr.nlmsg_flags,
161 seq: hdr.nlmsg_seq,
162 pid: hdr.nlmsg_pid,
163 data,
164 })
165 }
166 }
167
/// Safe wrapper for `NETLINK_GENERIC` netlink sockets.
///
/// The socket descriptor is owned by a `SafeDescriptor`.
pub struct NetlinkGenericSocket {
    // The bound AF_NETLINK / NETLINK_GENERIC socket created by `new`.
    sock: SafeDescriptor,
}
172
173 impl AsRawDescriptor for NetlinkGenericSocket {
as_raw_descriptor(&self) -> RawDescriptor174 fn as_raw_descriptor(&self) -> RawDescriptor {
175 self.sock.as_raw_descriptor()
176 }
177 }
178
179 impl NetlinkGenericSocket {
180 /// Create and bind a new `NETLINK_GENERIC` socket.
new(nl_groups: u32) -> Result<Self>181 pub fn new(nl_groups: u32) -> Result<Self> {
182 // SAFETY:
183 // Safe because we check the return value and convert the raw fd into a SafeDescriptor.
184 let sock = unsafe {
185 let fd = libc::socket(
186 libc::AF_NETLINK,
187 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
188 libc::NETLINK_GENERIC,
189 );
190 if fd < 0 {
191 return errno_result();
192 }
193
194 SafeDescriptor::from_raw_descriptor(fd)
195 };
196
197 // SAFETY:
198 // This MaybeUninit dance is needed because sockaddr_nl has a private padding field and
199 // doesn't implement Default. Safe because all 0s is valid data for sockaddr_nl.
200 let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
201 sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
202 sa.nl_groups = nl_groups;
203
204 // SAFETY:
205 // Safe because we pass a descriptor that we own and valid pointer/size for sockaddr.
206 unsafe {
207 let res = libc::bind(
208 sock.as_raw_fd(),
209 &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
210 std::mem::size_of_val(&sa) as libc::socklen_t,
211 );
212 if res < 0 {
213 return errno_result();
214 }
215 }
216
217 Ok(NetlinkGenericSocket { sock })
218 }
219
220 /// Receive messages from the netlink socket.
recv(&self) -> Result<NetlinkGenericRead>221 pub fn recv(&self) -> Result<NetlinkGenericRead> {
222 let buf_size = 8192; // TODO(dverkamp): make this configurable?
223
224 // Create a buffer with sufficient alignment for nlmsghdr.
225 let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
226 .map_err(|_| Error::new(EINVAL))?;
227 let allocation = LayoutAllocation::uninitialized(layout);
228
229 // SAFETY:
230 // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
231 let bytes_read = unsafe {
232 let res = libc::recv(self.sock.as_raw_fd(), allocation.as_ptr(), buf_size, 0);
233 if res < 0 {
234 return errno_result();
235 }
236 res as usize
237 };
238
239 Ok(NetlinkGenericRead {
240 allocation,
241 len: bytes_read,
242 })
243 }
244
family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead>245 pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
246 let buf_size = 1024;
247 debug_pr!(
248 "preparing query for family name {}, len {}",
249 family_name,
250 family_name.len()
251 );
252
253 // Create a buffer with sufficient alignment for nlmsghdr.
254 let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
255 .map_err(|_| Error::new(EINVAL))
256 .unwrap();
257 let mut allocation = LayoutAllocation::zeroed(layout);
258
259 // SAFETY:
260 // Safe because the data in allocation was initialized up to `buf_size` and is
261 // sufficiently aligned.
262 let data = unsafe { allocation.as_mut_slice(buf_size) };
263
264 // Prepare the netlink message header
265 let hdr = Ref::<_, NlMsgHdr>::new(&mut data[..NLMSGHDR_SIZE])
266 .expect("failed to unwrap")
267 .into_mut();
268 hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
269 hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
270 hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
271 hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
272 hdr.nlmsg_pid = getpid() as u32;
273
274 // Prepare generic netlink message header
275 let genl_hdr_end = NLMSGHDR_SIZE + GENL_HDRLEN;
276 let genl_hdr = Ref::<_, GenlMsgHdr>::new(&mut data[NLMSGHDR_SIZE..genl_hdr_end])
277 .expect("unable to get GenlMsgHdr from slice")
278 .into_mut();
279 genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
280 genl_hdr.version = 0x1;
281
282 // Netlink attributes
283 let nlattr_start = genl_hdr_end;
284 let nlattr_end = nlattr_start + NLA_HDRLEN;
285 let nl_attr = Ref::<_, NlAttr>::new(&mut data[nlattr_start..nlattr_end])
286 .expect("unable to get NlAttr from slice")
287 .into_mut();
288 nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
289 nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
290
291 // Fill the message payload with the family name
292 let payload_start = nlattr_end;
293 let payload_end = payload_start + family_name.len();
294 data[payload_start..payload_end].copy_from_slice(family_name.as_bytes());
295
296 // SAFETY:
297 // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
298 unsafe {
299 let res = libc::send(
300 self.sock.as_raw_fd(),
301 allocation.as_ptr(),
302 payload_end + 1,
303 0,
304 );
305 if res < 0 {
306 error!("failed to send get_family_cmd");
307 return errno_result();
308 }
309 };
310
311 // Return the answer
312 match self.recv() {
313 Ok(msg) => Ok(msg),
314 Err(e) => {
315 error!("recv get_family returned with error {}", e);
316 Err(e)
317 }
318 }
319 }
320 }
321
parse_ctrl_group_name_and_id( nested_nl_attr_data: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>322 fn parse_ctrl_group_name_and_id(
323 nested_nl_attr_data: NetlinkGenericDataIter,
324 group_name: &str,
325 ) -> Option<u32> {
326 let mut mcast_group_id: Option<u32> = None;
327
328 for nested_nl_attr in nested_nl_attr_data {
329 debug_pr!(
330 "\t\tmcast_grp: nlattr type {}, len {}",
331 nested_nl_attr._type,
332 nested_nl_attr.len
333 );
334
335 if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
336 mcast_group_id = Some(u32::from_ne_bytes(nested_nl_attr.data.try_into().unwrap()));
337 debug_pr!("\t\t mcast group_id {}", mcast_group_id?);
338 }
339
340 if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
341 debug_pr!(
342 "\t\t mcast group name {}",
343 strip_padding(&nested_nl_attr.data)
344 );
345
346 // If the group name match and the group_id was set in previous iteration, return,
347 // valid for group_name, group_id
348 if group_name.eq(strip_padding(nested_nl_attr.data)) && mcast_group_id.is_some() {
349 debug_pr!(
350 "\t\t Got what we were looking for group_id = {} for {}",
351 mcast_group_id?,
352 group_name
353 );
354
355 return mcast_group_id;
356 }
357 }
358 }
359
360 None
361 }
362
363 /// Parse CTRL_ATTR_MCAST_GROUPS data in order to get multicast group id
364 ///
365 /// On success, returns group_id for a given `group_name`
366 ///
367 /// # Arguments
368 ///
369 /// * `nl_attr_area`
370 ///
371 /// Nested attributes area (CTRL_ATTR_MCAST_GROUPS data), where nl_attr's corresponding to
372 /// specific groups are embed
373 ///
374 /// * `group_name`
375 ///
376 /// String with group_name for which we are looking group_id
377 ///
378 /// the CTRL_ATTR_MCAST_GROUPS data has nested attributes. Each of nested attribute is per
379 /// multicast group attributes, which have another nested attributes: CTRL_ATTR_MCAST_GRP_NAME and
380 /// CTRL_ATTR_MCAST_GRP_ID. Need to parse all of them to get mcast group id for a given group_name..
381 ///
382 /// Illustrated layout:
383 /// CTRL_ATTR_MCAST_GROUPS:
384 /// GR1 (nl_attr._type = 1):
385 /// CTRL_ATTR_MCAST_GRP_ID,
386 /// CTRL_ATTR_MCAST_GRP_NAME,
387 /// GR2 (nl_attr._type = 2):
388 /// CTRL_ATTR_MCAST_GRP_ID,
389 /// CTRL_ATTR_MCAST_GRP_NAME,
390 /// ..
391 ///
392 /// Unfortunately kernel implementation uses `nla_nest_start_noflag` for that
393 /// purpose, which means that it never marked their nest attributes with NLA_F_NESTED flag.
394 /// Therefore all this nesting stages need to be deduced based on specific nl_attr type.
parse_ctrl_mcast_group_id( nl_attr_area: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>395 fn parse_ctrl_mcast_group_id(
396 nl_attr_area: NetlinkGenericDataIter,
397 group_name: &str,
398 ) -> Option<u32> {
399 // There may be multiple nested multicast groups, go through all of them.
400 // Each of nested group, has other nested nlattr:
401 // CTRL_ATTR_MCAST_GRP_ID
402 // CTRL_ATTR_MCAST_GRP_NAME
403 //
404 // which are further proceed by parse_ctrl_group_name_and_id
405 for nested_gr_nl_attr in nl_attr_area {
406 debug_pr!(
407 "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
408 nested_gr_nl_attr._type,
409 nested_gr_nl_attr.len
410 );
411
412 let netlink_nested_attr = NetlinkGenericDataIter {
413 data: nested_gr_nl_attr.data,
414 };
415
416 if let Some(mcast_group_id) = parse_ctrl_group_name_and_id(netlink_nested_attr, group_name)
417 {
418 return Some(mcast_group_id);
419 }
420 }
421
422 None
423 }
424
// Returns the text in `b` that precedes the first NUL byte, dropping the terminator and any
// padding after it (similar to `CStr::from_bytes_until_nul` followed by `to_str`).
// Panics if `b` contains no NUL byte, or if the bytes before it are not valid UTF-8.
fn strip_padding(b: &[u8]) -> &str {
    let nul_index = match b.iter().position(|&byte| byte == 0) {
        Some(index) => index,
        None => panic!("`b` doesn't contain any nul bytes"),
    };
    let (text, _terminator_and_padding) = b.split_at(nul_index);
    str::from_utf8(text).unwrap()
}
436
/// Bytes received by `NetlinkGenericSocket::recv`, ready to be parsed as netlink messages.
pub struct NetlinkGenericRead {
    // Buffer the socket data was received into; allocated with the alignment of `NlMsgHdr`.
    allocation: LayoutAllocation,
    // Number of bytes actually received (and therefore initialized) in `allocation`.
    len: usize,
}
441
442 impl NetlinkGenericRead {
iter(&self) -> NetlinkMessageIter443 pub fn iter(&self) -> NetlinkMessageIter {
444 // SAFETY:
445 // Safe because the data in allocation was initialized up to `self.len` by `recv()` and is
446 // sufficiently aligned.
447 let data = unsafe { &self.allocation.as_slice(self.len) };
448 NetlinkMessageIter { data }
449 }
450
451 /// Parse NetlinkGeneric response in order to get multicast group id
452 ///
453 /// On success, returns group_id for a given `group_name`
454 ///
455 /// # Arguments
456 ///
457 /// * `group_name` - String with group_name for which we are looking group_id
458 ///
459 /// Response from family_name_query (CTRL_CMD_GETFAMILY) is a netlink message with multiple
460 /// attributes encapsulated (some of them are nested). An example response layout is
461 /// illustrated below:
462 ///
463 /// {
464 /// CTRL_ATTR_FAMILY_NAME
465 /// CTRL_ATTR_FAMILY_ID
466 /// CTRL_ATTR_VERSION
467 /// ...
468 /// CTRL_ATTR_MCAST_GROUPS {
469 /// GR1 (nl_attr._type = 1) {
470 /// CTRL_ATTR_MCAST_GRP_ID *we need parse this attr to obtain group id used for
471 /// the group mask
472 /// CTRL_ATTR_MCAST_GRP_NAME *group_name that we need to match with
473 /// }
474 /// GR2 (nl_attr._type = 2) {
475 /// CTRL_ATTR_MCAST_GRP_ID
476 /// CTRL_ATTR_MCAST_GRP_NAME
477 /// }
478 /// ...
479 /// }
480 /// }
get_multicast_group_id(&self, group_name: String) -> Option<u32>481 pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
482 for netlink_msg in self.iter() {
483 debug_pr!(
484 "received type: {}, flags {}, pid {}, data {:?}",
485 netlink_msg._type,
486 netlink_msg.flags,
487 netlink_msg.pid,
488 netlink_msg.data
489 );
490
491 if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
492 error!("Received not a generic netlink controller msg");
493 return None;
494 }
495
496 let netlink_data = NetlinkGenericDataIter {
497 data: &netlink_msg.data[GENL_HDRLEN..],
498 };
499 for nl_attr in netlink_data {
500 debug_pr!("nl_attr type {}, len {}", nl_attr._type, nl_attr.len);
501
502 if nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16 {
503 let netlink_nested_attr = NetlinkGenericDataIter { data: nl_attr.data };
504
505 if let Some(mcast_group_id) =
506 parse_ctrl_mcast_group_id(netlink_nested_attr, &group_name)
507 {
508 return Some(mcast_group_id);
509 }
510 }
511 }
512 }
513 None
514 }
515 }
516