xref: /aosp_15_r20/external/crosvm/devices/src/virtio/iommu.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

pub mod ipc_memory_mapper;
pub mod memory_mapper;
pub mod protocol;
pub(crate) mod sys;

use std::cell::RefCell;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::io;
use std::io::Write;
use std::mem::size_of;
use std::ops::RangeInclusive;
use std::rc::Rc;
use std::result;
use std::sync::Arc;

#[cfg(target_arch = "x86_64")]
use acpi_tables::sdt::SDT;
use anyhow::anyhow;
use anyhow::Context;
use base::debug;
use base::error;
use base::pagesize;
#[cfg(target_arch = "x86_64")]
use base::warn;
use base::AsRawDescriptor;
use base::Error as SysError;
use base::Event;
use base::MappedRegion;
use base::MemoryMapping;
use base::Protection;
use base::RawDescriptor;
use base::Result as SysResult;
use base::Tube;
use base::TubeError;
use base::WorkerThread;
use cros_async::AsyncError;
use cros_async::AsyncTube;
use cros_async::EventAsync;
use cros_async::Executor;
use data_model::Le64;
use futures::select;
use futures::FutureExt;
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_control::VmMemoryRegionId;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::GuestMemoryError;
use zerocopy::AsBytes;
#[cfg(target_arch = "x86_64")]
use zerocopy::FromBytes;
#[cfg(target_arch = "x86_64")]
use zerocopy::FromZeroes;

#[cfg(target_arch = "x86_64")]
use crate::pci::PciAddress;
use crate::virtio::async_utils;
use crate::virtio::copy_config;
use crate::virtio::iommu::memory_mapper::*;
use crate::virtio::iommu::protocol::*;
use crate::virtio::DescriptorChain;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::Reader;
use crate::virtio::VirtioDevice;
#[cfg(target_arch = "x86_64")]
use crate::virtio::Writer;

const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];
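
// Note: per the virtio-iommu spec, queue 0 is the request queue and queue 1
// is the event queue. Only the request queue is serviced here; `run()` below
// takes queue 0 and leaves the event queue unused.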

// Size of the virtio_iommu_probe_resv_mem property, the only probe property
// this device reports (see process_probe_request)
#[cfg(target_arch = "x86_64")]
const IOMMU_PROBE_SIZE: usize = size_of::<virtio_iommu_probe_resv_mem>();

#[cfg(target_arch = "x86_64")]
const VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE: u8 = 1;
#[cfg(target_arch = "x86_64")]
const VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI: u8 = 3;

#[derive(Copy, Clone, Debug, Default, FromZeroes, FromBytes, AsBytes)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotHeader {
    node_count: u16,
    node_offset: u16,
    reserved: [u8; 8],
}

#[derive(Copy, Clone, Debug, Default, FromZeroes, FromBytes, AsBytes)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotVirtioPciNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    segment: u16,
    bdf: u16,
    reserved2: [u8; 8],
}

#[derive(Copy, Clone, Debug, Default, FromZeroes, FromBytes, AsBytes)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotPciRangeNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    endpoint_start: u32,
    segment_start: u16,
    segment_end: u16,
    bdf_start: u16,
    bdf_end: u16,
    output_node: u16,
    reserved2: [u8; 2],
    reserved3: [u8; 4],
}
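
// Layout sketch (illustrative) of the VIOT table that generate_acpi() below
// builds out of these nodes:
//
//   [ACPI SDT header][VirtioIommuViotHeader]
//   [VirtioIommuViotVirtioPciNode]        <- this vIOMMU device
//   [VirtioIommuViotPciRangeNode] ...     <- one per static endpoint
//   [VirtioIommuViotPciRangeNode] ...     <- one per hot-pluggable range
//
// Each range node's output_node field holds the byte offset of the
// virtio-pci node.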

type Result<T> = result::Result<T, IommuError>;

#[sorted]
#[derive(Error, Debug)]
pub enum IommuError {
    #[error("async executor error: {0}")]
    AsyncExec(AsyncError),
    #[error("failed to create wait context: {0}")]
    CreateWaitContext(SysError),
    #[error("failed getting host address: {0}")]
    GetHostAddress(GuestMemoryError),
    #[error("failed to read from guest address: {0}")]
    GuestMemoryRead(io::Error),
    #[error("failed to write to guest address: {0}")]
    GuestMemoryWrite(io::Error),
    #[error("memory mapper failed: {0}")]
    MemoryMapper(anyhow::Error),
    #[error("Failed to read descriptor asynchronously: {0}")]
    ReadAsyncDesc(AsyncError),
    #[error("failed to read from virtio queue Event: {0}")]
    ReadQueueEvent(SysError),
    #[error("tube error: {0}")]
    Tube(TubeError),
    #[error("unexpected descriptor error")]
    UnexpectedDescriptor,
    #[error("failed to receive virtio-iommu control request: {0}")]
    VirtioIOMMUReqError(TubeError),
    #[error("failed to send virtio-iommu control response: {0}")]
    VirtioIOMMUResponseError(TubeError),
    #[error("failed to wait for events: {0}")]
    WaitError(SysError),
    #[error("write buffer length too small")]
    WriteBufferTooSmall,
}

// key: domain ID
// value: reference counter and MemoryMapperTrait
type DomainMap = BTreeMap<u32, (u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>)>;
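// For example (illustrative values): after two endpoints attach to domain 1,
// domain_map contains `1 => (2, mapper)`; each successful detach decrements
// the reference count, and the entry is removed when the last endpoint
// detaches.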

struct DmabufRegionEntry {
    mmap: MemoryMapping,
    region_id: VmMemoryRegionId,
    size: u64,
}

// Shared state for the virtio-iommu device.
struct State {
    mem: GuestMemory,
    page_mask: u64,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    #[cfg_attr(windows, allow(dead_code))]
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    // All PCI endpoints that are attached to an IOMMU domain
    // key: endpoint PCI address
    // value: attached domain ID
    endpoint_map: BTreeMap<u32, u32>,
    // All attached domains
    domain_map: DomainMap,
    // Contains all pass-through endpoints that attach to this IOMMU device
    // key: endpoint PCI address
    // value: MemoryMapperTrait
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Contains dmabuf regions
    // key: guest physical address
    dmabuf_mem: BTreeMap<u64, DmabufRegionEntry>,
}

impl State {
    // Detach the given endpoint if possible, and return whether or not the endpoint
    // was actually detached. If a successfully detached endpoint has exported
    // memory, returns an event that will be signaled once all exported memory is released.
    //
    // The device MUST ensure that after being detached from a domain, the endpoint
    // cannot access any mapping from that domain.
    //
    // Currently, we only support detaching an endpoint if it is the only endpoint attached
    // to its domain.
    fn detach_endpoint(
        endpoint_map: &mut BTreeMap<u32, u32>,
        domain_map: &mut DomainMap,
        endpoint: u32,
    ) -> (bool, Option<EventAsync>) {
        let mut evt = None;
        // The endpoint has attached to an IOMMU domain
        if let Some(attached_domain) = endpoint_map.get(&endpoint) {
            // Remove the entry or update the domain reference count
            if let Entry::Occupied(o) = domain_map.entry(*attached_domain) {
                let (refs, mapper) = o.get();
                if !mapper.lock().supports_detach() {
                    return (false, None);
                }

                match refs {
                    0 => unreachable!(),
                    1 => {
                        evt = mapper.lock().reset_domain();
                        o.remove();
                    }
                    _ => return (false, None),
                }
            }
        }

        endpoint_map.remove(&endpoint);
        (true, evt)
    }
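
    // For example (illustrative values): if endpoint 0x8 is the only endpoint
    // attached to domain 1, a successful detach removes the domain entry and
    // returns (true, evt), where evt is Some only if the mapper had exported
    // memory that must be released; if another endpoint is still attached to
    // domain 1, the call leaves the domain intact and returns (false, None).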

    // Processes an attach request. This may require detaching the endpoint from
    // its current domain before attaching it to the new domain. If that happens
    // while the endpoint has exported memory, this function returns an event that
    // will be signaled once all exported memory is released.
    //
    // Note: if a VFIO group contains multiple devices, it could violate the following
    // requirement from the virtio IOMMU spec: If the VIRTIO_IOMMU_F_BYPASS feature
    // is negotiated, all accesses from unattached endpoints are allowed and translated
    // by the IOMMU using the identity function. If the feature is not negotiated, any
    // memory access from an unattached endpoint fails.
    //
    // This is because after the virtio-iommu device receives a VIRTIO_IOMMU_T_ATTACH
    // request for the first endpoint in a VFIO group, any not-yet-attached endpoints
    // in the VFIO group will be able to access the domain.
    //
    // This violation is benign for current virtualization use cases. Since device
    // topology in the guest matches topology in the host, the guest doesn't expect
    // devices in the same VFIO group to be isolated from each other in the first place.
    fn process_attach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_attach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let mut fault_resolved_event = None;

        // If the reserved field of an ATTACH request is not zero,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_INVAL.
        if req.reserved.iter().any(|&x| x != 0) {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok((0, None));
        }

        let domain: u32 = req.domain.into();
        let endpoint: u32 = req.endpoint.into();

        if let Some(mapper) = self.endpoints.get(&endpoint) {
            // The same mapper can't be used for two domains at the same time,
            // since that would result in conflicts/permission leaks between
            // the two domains.
            let mapper_id = {
                let m = mapper.lock();
                ((**m).type_id(), m.id())
            };
            for (other_endpoint, other_mapper) in self.endpoints.iter() {
                if *other_endpoint == endpoint {
                    continue;
                }
                let other_id = {
                    let m = other_mapper.lock();
                    ((**m).type_id(), m.id())
                };
                if mapper_id == other_id {
                    if !self
                        .endpoint_map
                        .get(other_endpoint)
                        .map_or(true, |d| d == &domain)
                    {
                        tail.status = VIRTIO_IOMMU_S_UNSUPP;
                        return Ok((0, None));
                    }
                }
            }

            // If the endpoint identified by `endpoint` is already attached
            // to another domain, then the device SHOULD first detach it
            // from that domain and attach it to the one identified by domain.
            if self.endpoint_map.contains_key(&endpoint) {
                // In that case the device SHOULD behave as if the driver issued
                // a DETACH request with this endpoint, followed by the ATTACH
                // request. If the device cannot do so, it MUST reject the request
                // and set status to VIRTIO_IOMMU_S_UNSUPP.
                let (detached, evt) =
                    Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
                if !detached {
                    tail.status = VIRTIO_IOMMU_S_UNSUPP;
                    return Ok((0, None));
                }
                fault_resolved_event = evt;
            }

            let new_ref = match self.domain_map.get(&domain) {
                None => 1,
                Some(val) => val.0 + 1,
            };

            self.endpoint_map.insert(endpoint, domain);
            self.domain_map.insert(domain, (new_ref, mapper.clone()));
        } else {
            // If the endpoint identified by endpoint doesn’t exist,
            // the device MUST reject the request and set status to
            // VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        Ok((0, fault_resolved_event))
    }

    fn process_detach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_detach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        // If the endpoint identified by |req.endpoint| doesn’t exist,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_NOENT.
        let endpoint: u32 = req.endpoint.into();
        if !self.endpoints.contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok((0, None));
        }

        let (detached, evt) =
            Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
        if !detached {
            tail.status = VIRTIO_IOMMU_S_UNSUPP;
        }
        Ok((0, evt))
    }

    fn process_dma_map_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_map = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let phys_start = u64::from(req.phys_start);
        let virt_start = u64::from(req.virt_start);
        let virt_end = u64::from(req.virt_end);

        // Enforce the driver requirement: virt_end MUST be strictly greater than virt_start.
        if virt_start >= virt_end {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        // If virt_start, phys_start or (virt_end + 1) is not aligned
        // on the page granularity, the device SHOULD reject the
        // request and set status to VIRTIO_IOMMU_S_RANGE
        if self.page_mask & phys_start != 0
            || self.page_mask & virt_start != 0
            || self.page_mask & (virt_end + 1) != 0
        {
            tail.status = VIRTIO_IOMMU_S_RANGE;
            return Ok(0);
        }

        // If the device doesn’t recognize a flags bit, it MUST reject
        // the request and set status to VIRTIO_IOMMU_S_INVAL.
        if u32::from(req.flags) & !VIRTIO_IOMMU_MAP_F_MASK != 0 {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        let domain: u32 = req.domain.into();
        if !self.domain_map.contains_key(&domain) {
            // If domain does not exist, the device SHOULD reject
            // the request and set status to VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // The device MUST NOT allow writes to a range mapped
        // without the VIRTIO_IOMMU_MAP_F_WRITE flag.
        let write_en = u32::from(req.flags) & VIRTIO_IOMMU_MAP_F_WRITE != 0;

        if let Some(mapper) = self.domain_map.get(&domain) {
            let gpa = phys_start;
            let iova = virt_start;
            let Some(size) = u64::checked_add(virt_end - virt_start, 1) else {
                // This implementation doesn't support the unlikely request for
                // size == u64::MAX + 1.
                tail.status = VIRTIO_IOMMU_S_DEVERR;
                return Ok(0);
            };
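
            // `range(..=gpa).next_back()` finds the dmabuf region with the
            // largest base address <= gpa; the host mapping below is used
            // only if [gpa, gpa + size) falls entirely within that region.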
            let dmabuf_map =
                self.dmabuf_mem
                    .range(..=gpa)
                    .next_back()
                    .and_then(|(base_gpa, region)| {
                        if gpa + size <= base_gpa + region.size {
                            let offset = gpa - base_gpa;
                            Some(region.mmap.as_ptr() as u64 + offset)
                        } else {
                            None
                        }
                    });

            let prot = match write_en {
                true => Protection::read_write(),
                false => Protection::read(),
            };

            let vfio_map_result = match dmabuf_map {
                // SAFETY:
                // Safe because [dmabuf_map, dmabuf_map + size) refers to an external mmap'ed
                // region.
                Some(dmabuf_map) => unsafe {
                    mapper.1.lock().vfio_dma_map(iova, dmabuf_map, size, prot)
                },
                None => mapper.1.lock().add_map(MappingInfo {
                    iova,
                    gpa: GuestAddress(gpa),
                    size,
                    prot,
                }),
            };

            match vfio_map_result {
                Ok(AddMapResult::Ok) => (),
                Ok(AddMapResult::OverlapFailure) => {
                    // If a mapping already exists in the requested range,
                    // the device SHOULD reject the request and set status
                    // to VIRTIO_IOMMU_S_INVAL.
                    tail.status = VIRTIO_IOMMU_S_INVAL;
                }
                Err(e) => return Err(IommuError::MemoryMapper(e)),
            }
        }

        Ok(0)
    }
    fn process_dma_unmap_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_unmap = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let domain: u32 = req.domain.into();
        let fault_resolved_event = if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1;
            let res = mapper
                .1
                .lock()
                .remove_map(u64::from(req.virt_start), size)
                .map_err(IommuError::MemoryMapper)?;
            match res {
                RemoveMapResult::Success(evt) => evt,
                RemoveMapResult::OverlapFailure => {
                    // If a mapping affected by the range is not covered in its entirety by the
                    // range (the UNMAP request would split the mapping), then the device SHOULD
                    // set the request `status` to VIRTIO_IOMMU_S_RANGE, and SHOULD NOT remove
                    // any mapping.
                    tail.status = VIRTIO_IOMMU_S_RANGE;
                    None
                }
            }
        } else {
            // If domain does not exist, the device SHOULD set the
            // request status to VIRTIO_IOMMU_S_NOENT
            tail.status = VIRTIO_IOMMU_S_NOENT;
            None
        };

        Ok((0, fault_resolved_event))
    }

    #[cfg(target_arch = "x86_64")]
    fn process_probe_request(
        &mut self,
        reader: &mut Reader,
        writer: &mut Writer,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_probe = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let endpoint: u32 = req.endpoint.into();

        // If the endpoint identified by endpoint doesn’t exist,
        // then the device SHOULD reject the request and set status
        // to VIRTIO_IOMMU_S_NOENT.
        if !self.endpoints.contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        let properties_size = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        // It's OK if properties_size is larger than probe_size, and a
        // properties_size of 0 is also handled: it is rejected just below.
        if properties_size < IOMMU_PROBE_SIZE {
            // If the properties list is smaller than probe_size, the device
            // SHOULD NOT write any property. It SHOULD reject the request
            // and set status to VIRTIO_IOMMU_S_INVAL.
            tail.status = VIRTIO_IOMMU_S_INVAL;
        } else if tail.status == VIRTIO_IOMMU_S_OK {
            const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
            const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
            const PROBE_PROPERTY_SIZE: u16 = 4;
            const X86_MSI_IOVA_START: u64 = 0xfee0_0000;
            const X86_MSI_IOVA_END: u64 = 0xfeef_ffff;
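
            // The x86 MSI doorbell window [0xfee0_0000, 0xfeef_ffff] is
            // advertised as a reserved MSI region so that the guest does not
            // allocate IOVAs overlapping it.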
            let properties = virtio_iommu_probe_resv_mem {
                head: virtio_iommu_probe_property {
                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM.into(),
                    length: (IOMMU_PROBE_SIZE as u16 - PROBE_PROPERTY_SIZE).into(),
                },
                subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                start: X86_MSI_IOVA_START.into(),
                end: X86_MSI_IOVA_END.into(),
                ..Default::default()
            };
            writer
                .write_all(properties.as_bytes())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        // If the device doesn’t fill all probe_size bytes with properties,
        // it SHOULD fill the remaining bytes of properties with zeroes.
        let remaining_bytes = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        if remaining_bytes > 0 {
            let buffer: Vec<u8> = vec![0; remaining_bytes];
            writer
                .write_all(buffer.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        Ok(properties_size)
    }

    fn execute_request(
        &mut self,
        avail_desc: &mut DescriptorChain,
    ) -> Result<(usize, Option<EventAsync>)> {
        let reader = &mut avail_desc.reader;
        let writer = &mut avail_desc.writer;

        // We need at least enough space to write a virtio_iommu_req_tail.
        if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
            return Err(IommuError::WriteBufferTooSmall);
        }

        let req_head: virtio_iommu_req_head =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let mut tail = virtio_iommu_req_tail {
            status: VIRTIO_IOMMU_S_OK,
            ..Default::default()
        };

        let (reply_len, fault_resolved_event) = match req_head.type_ {
            VIRTIO_IOMMU_T_ATTACH => self.process_attach_request(reader, &mut tail)?,
            VIRTIO_IOMMU_T_DETACH => self.process_detach_request(reader, &mut tail)?,
            VIRTIO_IOMMU_T_MAP => (self.process_dma_map_request(reader, &mut tail)?, None),
            VIRTIO_IOMMU_T_UNMAP => self.process_dma_unmap_request(reader, &mut tail)?,
            #[cfg(target_arch = "x86_64")]
            VIRTIO_IOMMU_T_PROBE => (self.process_probe_request(reader, writer, &mut tail)?, None),
            _ => return Err(IommuError::UnexpectedDescriptor),
        };

        writer
            .write_all(tail.as_bytes())
            .map_err(IommuError::GuestMemoryWrite)?;
        Ok((
            reply_len + size_of::<virtio_iommu_req_tail>(),
            fault_resolved_event,
        ))
    }
}

async fn request_queue(
    state: &Rc<RefCell<State>>,
    mut queue: Queue,
    mut queue_event: EventAsync,
) -> Result<()> {
    loop {
        let mut avail_desc = queue
            .next_async(&mut queue_event)
            .await
            .map_err(IommuError::ReadAsyncDesc)?;

        let (len, fault_resolved_event) = match state.borrow_mut().execute_request(&mut avail_desc)
        {
            Ok(res) => res,
            Err(e) => {
                error!("execute_request failed: {}", e);

                // If a request type is not recognized, the device SHOULD NOT write
                // the buffer and SHOULD set the used length to zero
                (0, None)
            }
        };

        if let Some(fault_resolved_event) = fault_resolved_event {
            debug!("waiting for iommu fault resolution");
            fault_resolved_event
                .next_val()
                .await
                .expect("failed waiting for fault");
            debug!("iommu fault resolved");
        }

        queue.add_used(avail_desc, len as u32);
        queue.trigger_interrupt();
    }
}

fn run(
    state: State,
    iommu_device_tube: Tube,
    mut queues: BTreeMap<usize, Queue>,
    kill_evt: Event,
    interrupt: Interrupt,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
) -> Result<()> {
    let state = Rc::new(RefCell::new(state));
    let ex = Executor::new().expect("Failed to create an executor");

    let req_queue = queues.remove(&0).unwrap();
    let req_evt = req_queue
        .event()
        .try_clone()
        .expect("Failed to clone queue event");
    let req_evt = EventAsync::new(req_evt, &ex).expect("Failed to create async event for queue");

    let f_resample = async_utils::handle_irq_resample(&ex, interrupt);
    let f_kill = async_utils::await_and_exit(&ex, kill_evt);

    let request_tube = translate_request_rx
        .map(|t| AsyncTube::new(&ex, t).expect("Failed to create async tube for rx"));
    let response_tubes = translate_response_senders.map(|m| {
        m.into_iter()
            .map(|x| {
                (
                    x.0,
                    AsyncTube::new(&ex, x.1).expect("Failed to create async tube"),
                )
            })
            .collect()
    });

    let f_handle_translate_request =
        sys::handle_translate_request(&ex, &state, request_tube, response_tubes);
    let f_request = request_queue(&state, req_queue, req_evt);

    let command_tube = AsyncTube::new(&ex, iommu_device_tube).unwrap();
    // Future to handle command messages from host, such as passing vfio containers.
    let f_cmd = sys::handle_command_tube(&state, command_tube);

    let done = async {
        select! {
            res = f_request.fuse() => res.context("error in handling request queue"),
            res = f_resample.fuse() => res.context("error in handle_irq_resample"),
            res = f_kill.fuse() => res.context("error in await_and_exit"),
            res = f_handle_translate_request.fuse() => {
                res.context("error in handle_translate_request")
            }
            res = f_cmd.fuse() => res.context("error in handling host request"),
        }
    };
    match ex.run_until(done) {
        Ok(Ok(())) => {}
        Ok(Err(e)) => error!("Error in worker: {:#}", e),
        Err(e) => return Err(IommuError::AsyncExec(e)),
    }

    Ok(())
}

/// Virtio device for IOMMU memory management.
pub struct Iommu {
    worker_thread: Option<WorkerThread<()>>,
    config: virtio_iommu_config,
    avail_features: u64,
    // Attached endpoints
    // key: endpoint PCI address
    // value: MemoryMapperTrait
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
    iommu_device_tube: Option<Tube>,
}

impl Iommu {
    /// Create a new virtio IOMMU device.
    pub fn new(
        base_features: u64,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        iova_max_addr: u64,
        hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
        iommu_device_tube: Option<Tube>,
    ) -> SysResult<Iommu> {
        let mut page_size_mask = !((pagesize() as u64) - 1);
        for (_, container) in endpoints.iter() {
            page_size_mask &= container
                .lock()
                .get_mask()
                .map_err(|_e| SysError::new(libc::EIO))?;
        }
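
        // For example, with 4 KiB host pages, page_size_mask starts as
        // !0xfff = 0xffff_ffff_ffff_f000 and is then narrowed by each
        // endpoint mapper's supported mask above.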

        if page_size_mask == 0 {
            return Err(SysError::new(libc::EIO));
        }

        let input_range = virtio_iommu_range_64 {
            start: Le64::from(0),
            end: iova_max_addr.into(),
        };

        let config = virtio_iommu_config {
            page_size_mask: page_size_mask.into(),
            input_range,
            #[cfg(target_arch = "x86_64")]
            probe_size: (IOMMU_PROBE_SIZE as u32).into(),
            ..Default::default()
        };

        let mut avail_features: u64 = base_features;
        avail_features |= 1 << VIRTIO_IOMMU_F_MAP_UNMAP
            | 1 << VIRTIO_IOMMU_F_INPUT_RANGE
            | 1 << VIRTIO_IOMMU_F_MMIO;

        if cfg!(target_arch = "x86_64") {
            avail_features |= 1 << VIRTIO_IOMMU_F_PROBE;
        }

        Ok(Iommu {
            worker_thread: None,
            config,
            avail_features,
            endpoints,
            hp_endpoints_ranges,
            translate_response_senders,
            translate_request_rx,
            iommu_device_tube,
        })
    }
}

impl VirtioDevice for Iommu {
    fn keep_rds(&self) -> Vec<RawDescriptor> {
        let mut rds = Vec::new();

        for (_, mapper) in self.endpoints.iter() {
            rds.append(&mut mapper.lock().as_raw_descriptors());
        }
        if let Some(senders) = &self.translate_response_senders {
            for (_, tube) in senders.iter() {
                rds.push(tube.as_raw_descriptor());
            }
        }
        if let Some(rx) = &self.translate_request_rx {
            rds.push(rx.as_raw_descriptor());
        }

        if let Some(iommu_device_tube) = &self.iommu_device_tube {
            rds.push(iommu_device_tube.as_raw_descriptor());
        }

        rds
    }

    fn device_type(&self) -> DeviceType {
        DeviceType::Iommu
    }

    fn queue_max_sizes(&self) -> &[u16] {
        QUEUE_SIZES
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        let mut config: Vec<u8> = Vec::new();
        config.extend_from_slice(self.config.as_bytes());
        copy_config(data, 0, config.as_slice(), offset);
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        queues: BTreeMap<usize, Queue>,
    ) -> anyhow::Result<()> {
        if queues.len() != QUEUE_SIZES.len() {
            return Err(anyhow!(
                "expected {} queues, got {}",
                QUEUE_SIZES.len(),
                queues.len()
            ));
        }

        // The least significant set bit of page_size_mask defines the page
        // granularity of IOMMU mappings
        let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1;
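        // For example, if the negotiated granularity is 4 KiB, page_size_mask
        // ends in ...f000, trailing_zeros() is 12, and page_mask becomes 0xfff
        // (the offset bits that must be clear in any mapped address).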
        let eps = self.endpoints.clone();
        let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned();

        let translate_response_senders = self.translate_response_senders.take();
        let translate_request_rx = self.translate_request_rx.take();

        let iommu_device_tube = self
            .iommu_device_tube
            .take()
            .context("failed to start virtio-iommu worker: No control tube")?;

        self.worker_thread = Some(WorkerThread::start("v_iommu", move |kill_evt| {
            let state = State {
                mem,
                page_mask,
                hp_endpoints_ranges,
                endpoint_map: BTreeMap::new(),
                domain_map: BTreeMap::new(),
                endpoints: eps,
                dmabuf_mem: BTreeMap::new(),
            };
            let result = run(
                state,
                iommu_device_tube,
                queues,
                kill_evt,
                interrupt,
                translate_response_senders,
                translate_request_rx,
            );
            if let Err(e) = result {
                error!("virtio-iommu worker thread exited with error: {}", e);
            }
        }));
        Ok(())
    }

    #[cfg(target_arch = "x86_64")]
    fn generate_acpi(
        &mut self,
        pci_address: &Option<PciAddress>,
        mut sdts: Vec<SDT>,
    ) -> Option<Vec<SDT>> {
        const OEM_REVISION: u32 = 1;
        const VIOT_REVISION: u8 = 0;

        for sdt in sdts.iter() {
            // there should only be one VIOT table
            if sdt.is_signature(b"VIOT") {
                warn!("vIOMMU: duplicate VIOT table detected");
                return None;
            }
        }

        let mut viot = SDT::new(
            *b"VIOT",
            acpi_tables::HEADER_LEN,
            VIOT_REVISION,
            *b"CROSVM",
            *b"CROSVMDT",
            OEM_REVISION,
        );
        viot.append(VirtioIommuViotHeader {
            // # of PCI range nodes + 1 virtio-pci node
            node_count: (self.endpoints.len() + self.hp_endpoints_ranges.len() + 1) as u16,
            node_offset: (viot.len() + std::mem::size_of::<VirtioIommuViotHeader>()) as u16,
            ..Default::default()
        });

        let bdf = pci_address
            .or_else(|| {
                error!("vIOMMU device has no PCI address");
                None
            })?
            .to_u32() as u16;
        let iommu_offset = viot.len();

        viot.append(VirtioIommuViotVirtioPciNode {
            type_: VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI,
            length: size_of::<VirtioIommuViotVirtioPciNode>() as u16,
            bdf,
            ..Default::default()
        });

        for (endpoint, _) in self.endpoints.iter() {
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start: *endpoint,
                bdf_start: *endpoint as u16,
                bdf_end: *endpoint as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        for endpoints_range in self.hp_endpoints_ranges.iter() {
            let (endpoint_start, endpoint_end) = endpoints_range.clone().into_inner();
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start,
                bdf_start: endpoint_start as u16,
                bdf_end: endpoint_end as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        sdts.push(viot);
        Some(sdts)
    }
}