user_driver/vfio.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Support for accessing a MANA device via VFIO on Linux.
#![cfg(target_os = "linux")]
#![cfg(feature = "vfio")]

use crate::DeviceBacking;
use crate::DeviceRegisterIo;
use crate::DmaClient;
use crate::interrupt::DeviceInterrupt;
use crate::interrupt::DeviceInterruptSource;
use anyhow::Context;
use futures::FutureExt;
use futures_concurrency::future::Race;
use inspect::Inspect;
use inspect_counters::SharedCounter;
use pal_async::task::Spawn;
use pal_async::task::Task;
use pal_async::wait::PolledWait;
use pal_event::Event;
use std::os::fd::AsFd;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::Relaxed;
use std::time::Duration;
use uevent::UeventListener;
use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
use vfio_sys::IommuType;
use vfio_sys::IrqInfo;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// A device backend accessed via VFIO.
#[derive(Inspect)]
pub struct VfioDevice {
    pci_id: Arc<str>,
    #[inspect(skip)]
    _container: vfio_sys::Container,
    #[inspect(skip)]
    _group: vfio_sys::Group,
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    msix_info: IrqInfo,
    #[inspect(skip)]
    driver_source: VmTaskDriverSource,
    #[inspect(iter_by_index)]
    interrupts: Vec<Option<InterruptState>>,
    #[inspect(skip)]
    config_space: vfio_sys::RegionInfo,
    dma_client: Arc<dyn DmaClient>,
}

#[derive(Inspect)]
struct InterruptState {
    #[inspect(skip)]
    interrupt: DeviceInterrupt,
    target_cpu: Arc<AtomicU32>,
    #[inspect(skip)]
    _task: Task<()>,
}

impl VfioDevice {
    /// Creates a new VFIO-backed device for the PCI device with `pci_id`.
    pub async fn new(
        driver_source: &VmTaskDriverSource,
        pci_id: &str,
        dma_client: Arc<dyn DmaClient>,
    ) -> anyhow::Result<Self> {
        Self::restore(driver_source, pci_id, false, dma_client).await
    }
    /// Creates a new VFIO-backed device for the PCI device with `pci_id`, or
    /// restores an existing device, skipping hardware reinitialization, when
    /// `keepalive` is set.
    pub async fn restore(
        driver_source: &VmTaskDriverSource,
        pci_id: &str,
        keepalive: bool,
        dma_client: Arc<dyn DmaClient>,
    ) -> anyhow::Result<Self> {
        let path = Path::new("/sys/bus/pci/devices").join(pci_id);

        // The vfio device attaches asynchronously after the PCI device is added,
        // so make sure that it has completed by checking for the vfio-dev subpath.
        let vmbus_device =
            std::fs::read_link(&path).context("failed to read link for pci device")?;
        let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
        let vfio_arrived_path = instance_path.join("vfio-dev");
        let uevent_listener = UeventListener::new(&driver_source.simple())?;
        let wait_for_vfio_device =
            uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
        let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
        // Ignore any errors and always attempt to open.
        let _ = ctx.until_cancelled(wait_for_vfio_device).await;

        let container = vfio_sys::Container::new()?;
        let group_id = vfio_sys::Group::find_group_for_device(&path)?;
        let group = vfio_sys::Group::open_noiommu(group_id)?;
        group.set_container(&container)?;
        if !group.status()?.viable() {
            anyhow::bail!("group is not viable");
        }

        container.set_iommu(IommuType::NoIommu)?;
        if keepalive {
            // Prevent physical hardware interaction when restoring.
            group.set_keep_alive(pci_id)?;
        }
        let device = group.open_device(pci_id)?;
        let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
        if msix_info.flags.noresize() {
            anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
        }

        let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
        let this = Self {
            pci_id: pci_id.into(),
            _container: container,
            _group: group,
            device: Arc::new(device),
            msix_info,
            config_space,
            driver_source: driver_source.clone(),
            interrupts: Vec::new(),
            dma_client,
        };

        // Ensure bus master enable and memory space enable are set, and that
        // INTx is disabled.
        this.enable_device().context("failed to enable device")?;
        Ok(this)
    }

    fn enable_device(&self) -> anyhow::Result<()> {
        let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
        let status_command = self.read_config(offset)?;
        let command = pci_core::spec::cfg_space::Command::from(status_command as u16);

        let command = command
            .with_bus_master(true)
            .with_intx_disable(true)
            .with_mmio_enabled(true);

        let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
        self.write_config(offset, status_command)?;
        Ok(())
    }

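    /// Reads a 32-bit value from the device's PCI configuration space at byte
    /// `offset`.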
    pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        let mut buf = [0u8; 4];
        self.device
            .as_ref()
            .as_ref()
            .read_at(&mut buf, self.config_space.offset + offset as u64)
            .context("failed to read config")?;

        Ok(u32::from_ne_bytes(buf))
    }

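    /// Writes a 32-bit value to the device's PCI configuration space at byte
    /// `offset`.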
    pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        let buf = data.to_ne_bytes();
        self.device
            .as_ref()
            .as_ref()
            .write_at(&buf, self.config_space.offset + offset as u64)
            .context("failed to write config")?;

        Ok(())
    }

    /// Maps PCI BAR[n] to VA space.
    fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
        if n >= 6 {
            anyhow::bail!("invalid bar");
        }
        let info = self.device.region_info(n.into())?;
        let mapping = self.device.map(info.offset, info.size as usize, true)?;
        // Initialize support for catching access violations so that mapped
        // accesses can fail gracefully and fall back to file I/O (see
        // `MappedRegionWithFallback`).
        sparse_mmap::initialize_try_copy();
        Ok(MappedRegionWithFallback {
            device: self.device.clone(),
            mapping,
            len: info.size as usize,
            offset: info.offset,
            read_fallback: SharedCounter::new(),
            write_fallback: SharedCounter::new(),
        })
    }
}

/// A mapped region that falls back to file reads/writes if a memory-mapped
/// access fails.
///
/// This should only happen for a CVM, and only when the MMIO is emulated by
/// the host.
#[derive(Inspect)]
pub struct MappedRegionWithFallback {
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    mapping: vfio_sys::MappedRegion,
    offset: u64,
    len: usize,
    read_fallback: SharedCounter,
    write_fallback: SharedCounter,
}

impl DeviceBacking for VfioDevice {
    type Registers = MappedRegionWithFallback;

    fn id(&self) -> &str {
        &self.pci_id
    }

    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
        (*self).map_bar(n)
    }

    fn dma_client(&self) -> Arc<dyn DmaClient> {
        self.dma_client.clone()
    }

    fn max_interrupt_count(&self) -> u32 {
        self.msix_info.count
    }

    fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
        if msix >= self.msix_info.count {
            anyhow::bail!("invalid msix index");
        }
        if self.interrupts.len() <= msix as usize {
            self.interrupts.resize_with(msix as usize + 1, || None);
        }

        let interrupt = &mut self.interrupts[msix as usize];
        if let Some(interrupt) = interrupt {
            // The interrupt has been mapped before. Just retarget it to the new
            // CPU on the next interrupt, if needed.
            if interrupt.target_cpu.load(Relaxed) != cpu {
                interrupt.target_cpu.store(cpu, Relaxed);
            }
            return Ok(interrupt.interrupt.clone());
        }

        let new_interrupt = {
            let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
            let driver = self
                .driver_source
                .builder()
                .run_on_target(true)
                .target_vp(cpu)
                .build(&name);

            let event =
                PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;

            let source = DeviceInterruptSource::new();
            self.device
                .map_msix(msix, [event.get().as_fd()])
                .context("failed to map msix")?;

            // The interrupt's CPU affinity will be set by the task when it
            // starts. This can block the thread briefly, so it's better to do
            // it on the target CPU.
            let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
                .context("failed to find irq for msix")?;

            let target_cpu = Arc::new(AtomicU32::new(cpu));

            let interrupt = source.new_target();

            let task = driver.spawn(
                name,
                InterruptTask {
                    driver: driver.clone(),
                    target_cpu: target_cpu.clone(),
                    pci_id: self.pci_id.clone(),
                    msix,
                    irq,
                    event,
                    source,
                }
                .run(),
            );

            InterruptState {
                interrupt,
                target_cpu,
                _task: task,
            }
        };

        Ok(interrupt.insert(new_interrupt).interrupt.clone())
    }
}

struct InterruptTask {
    driver: VmTaskDriver,
    target_cpu: Arc<AtomicU32>,
    pci_id: Arc<str>,
    msix: u32,
    irq: u32,
    event: PolledWait<Event>,
    source: DeviceInterruptSource,
}

impl InterruptTask {
    async fn run(mut self) {
        let mut current_cpu = !0;
        loop {
            let next_cpu = self.target_cpu.load(Relaxed);
            let r = if next_cpu == current_cpu {
                self.event.wait().await
            } else {
                self.driver.retarget_vp(next_cpu);
                // Wait until the target CPU is ready before updating affinity,
                // since otherwise the CPU may not be online.
                enum Event {
                    TargetVpReady(()),
                    Interrupt(std::io::Result<()>),
                }
                match (
                    self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
                    self.event.wait().map(Event::Interrupt),
                )
                    .race()
                    .await
                {
                    Event::TargetVpReady(()) => {
                        if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
                            // This should only occur under extremely low
                            // resources. However, it is not a fatal error; it
                            // will just result in worse performance, so do not
                            // panic.
                            tracing::error!(
                                pci_id = self.pci_id.as_ref(),
                                msix = self.msix,
                                irq = self.irq,
                                error = &err as &dyn std::error::Error,
                                "failed to set irq affinity"
                            );
                        }
                        current_cpu = next_cpu;
                        continue;
                    }
                    Event::Interrupt(r) => {
                        // An interrupt arrived while waiting for the VP to be
                        // ready. Signal it and loop around to try again.
                        r
                    }
                }
            };

            r.expect("wait cannot fail on eventfd");
            self.source.signal();
        }
    }
}

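/// Sets the CPU affinity of `irq` by writing `cpu` to the IRQ's
/// `smp_affinity_list` file in procfs.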
fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
    fs_err::write(
        format!("/proc/irq/{}/smp_affinity_list", irq),
        cpu.to_string(),
    )
}

impl DeviceRegisterIo for vfio_sys::MappedRegion {
    fn len(&self) -> usize {
        self.len()
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_u32(offset)
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_u64(offset)
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_u32(offset, data)
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_u64(offset, data)
    }
}

impl MappedRegionWithFallback {
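    /// Returns a pointer to a `T` at `offset` in the mapping, asserting that
    /// the access is in bounds and naturally aligned. Under the
    /// `mmio_simulate_fallback` feature, returns a dangling pointer so that
    /// accesses fault and exercise the file fallback path.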
    fn mapping<T>(&self, offset: usize) -> *mut T {
        assert!(offset <= self.mapping.len() - size_of::<T>() && offset % align_of::<T>() == 0);
        if cfg!(feature = "mmio_simulate_fallback") {
            return std::ptr::NonNull::dangling().as_ptr();
        }
        // SAFETY: the offset is validated to be in bounds.
        unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
    }

    fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
    ) -> Result<T, sparse_mmap::MemoryError> {
        // SAFETY: the offset is validated to be in bounds and aligned.
        unsafe { sparse_mmap::try_read_volatile(self.mapping::<T>(offset)) }
    }

    fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
        data: T,
    ) -> Result<(), sparse_mmap::MemoryError> {
        // SAFETY: the offset is validated to be in bounds and aligned.
        unsafe { sparse_mmap::try_write_volatile(self.mapping::<T>(offset), &data) }
    }

    fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
        tracing::trace!(offset, n = buf.len(), "read");
        self.read_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .read_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }

    fn write_to_file(&self, offset: usize, buf: &[u8]) {
        tracing::trace!(offset, n = buf.len(), "write");
        self.write_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .write_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }
}

impl DeviceRegisterIo for MappedRegionWithFallback {
    fn len(&self) -> usize {
        self.len
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 4];
            self.read_from_file(offset, &mut buf);
            u32::from_ne_bytes(buf)
        })
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 8];
            self.read_from_file(offset, &mut buf);
            u64::from_ne_bytes(buf)
        })
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }
}

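/// A PCI device reset mechanism, as named in the kernel's per-device
/// `reset_method` sysfs file.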
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    NoReset,
    Acpi,
    Flr,
    AfFlr,
    Pm,
    Bus,
}

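/// Sets the reset method the kernel will use for the PCI device with `pci_id`
/// by writing `method` to the device's sysfs `reset_method` file.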
pub fn vfio_set_device_reset_method(
    pci_id: impl AsRef<str>,
    method: PciDeviceResetMethod,
) -> std::io::Result<()> {
    let reset_method = match method {
        PciDeviceResetMethod::NoReset => "\0".as_bytes(),
        PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
        PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
        PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
        PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
        PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
    };

    let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
        .iter()
        .collect();
    fs_err::write(path, reset_method)?;
    Ok(())
}