Skip to main content

user_driver/
vfio.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Support for accessing a MANA device via VFIO on Linux.
5
6#![cfg(target_os = "linux")]
7#![cfg(feature = "vfio")]
8
9use crate::DeviceBacking;
10use crate::DeviceRegisterIo;
11use crate::DmaClient;
12use crate::interrupt::DeviceInterrupt;
13use crate::interrupt::DeviceInterruptSource;
14use anyhow::Context;
15use futures::FutureExt;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect_counters::SharedCounter;
19use nix::errno::Errno;
20use pal_async::task::Spawn;
21use pal_async::task::Task;
22use pal_async::wait::PolledWait;
23use pal_event::Event;
24use std::os::fd::AsFd;
25use std::os::unix::fs::FileExt;
26use std::path::Path;
27use std::sync::Arc;
28use std::sync::atomic::AtomicU32;
29use std::sync::atomic::Ordering::Relaxed;
30use std::time::Duration;
31use uevent::UeventListener;
32use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
33use vfio_sys::IommuType;
34use vfio_sys::IrqInfo;
35use vmcore::vm_task::VmTaskDriver;
36use vmcore::vm_task::VmTaskDriverSource;
37use zerocopy::FromBytes;
38use zerocopy::Immutable;
39use zerocopy::IntoBytes;
40use zerocopy::KnownLayout;
41
/// The DMA client(s) handed to a [`VfioDevice`] at construction time.
///
/// `DeviceBacking::dma_client` and `DeviceBacking::dma_client_for` select
/// from these when a caller asks for a client.
#[derive(Clone)]
pub enum VfioDmaClients {
    /// A single client that serves `DmaPool::Persistent` requests only;
    /// requests for the ephemeral pool fail.
    PersistentOnly(Arc<dyn DmaClient>),
    /// A single client that serves `DmaPool::Ephemeral` requests only;
    /// requests for the persistent pool fail.
    EphemeralOnly(Arc<dyn DmaClient>),
    /// Separate clients for each pool.
    Split {
        /// Client returned for `DmaPool::Persistent` requests, and the default
        /// client when no pool is specified.
        persistent: Arc<dyn DmaClient>,
        /// Client returned for `DmaPool::Ephemeral` requests.
        ephemeral: Arc<dyn DmaClient>,
    },
}
51
/// A device backend accessed via VFIO.
#[derive(Inspect)]
pub struct VfioDevice {
    /// The PCI ID of the device; also the device's directory name under
    /// `/sys/bus/pci/devices`.
    pci_id: Arc<str>,
    /// The VFIO container the group was attached to. Not read after
    /// construction in this file; presumably held to keep the handle open.
    #[inspect(skip)]
    _container: vfio_sys::Container,
    /// The VFIO group the device was opened from. Not read after construction
    /// in this file; presumably held to keep the handle open.
    #[inspect(skip)]
    _group: vfio_sys::Group,
    /// The opened VFIO device. Shared (via `Arc`) with each mapped BAR region.
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    /// IRQ info for the MSI-X IRQ index, queried when the device was opened.
    #[inspect(skip)]
    msix_info: IrqInfo,
    /// Source used to build a task driver for each interrupt task.
    #[inspect(skip)]
    driver_source: VmTaskDriverSource,
    /// Per-vector interrupt state, indexed by MSI-X vector. `None` means the
    /// vector has not been mapped.
    #[inspect(iter_by_index)]
    interrupts: Vec<Option<InterruptState>>,
    /// Region info (file offset and size) for the PCI config space region.
    #[inspect(skip)]
    config_space: vfio_sys::RegionInfo,
    /// The DMA client(s) handed out by the `DeviceBacking` implementation.
    #[inspect(skip)]
    dma_clients: VfioDmaClients,
}
73
/// Bookkeeping for one mapped MSI-X vector.
#[derive(Inspect)]
struct InterruptState {
    /// The interrupt handle cloned out to callers of `map_interrupt`.
    #[inspect(skip)]
    interrupt: DeviceInterrupt,
    /// The CPU the interrupt should target. Shared with the interrupt task,
    /// which reads it at the top of every loop iteration.
    target_cpu: Arc<AtomicU32>,
    /// Handle to the spawned interrupt task; never read, only held.
    #[inspect(skip)]
    _task: Task<()>,
}
82
83impl Drop for VfioDevice {
84    fn drop(&mut self) {
85        // Just for tracing ...
86        tracing::trace!(pci_id = ?self.pci_id, "dropping vfio device");
87    }
88}
89
90impl VfioDevice {
91    /// Creates a new VFIO-backed device for the PCI device with `pci_id`.
92    pub async fn new(
93        driver_source: &VmTaskDriverSource,
94        pci_id: impl AsRef<str>,
95        dma_clients: VfioDmaClients,
96    ) -> anyhow::Result<Self> {
97        Self::restore(driver_source, pci_id, false, dma_clients).await
98    }
99
100    /// Creates a new VFIO-backed device for the PCI device with `pci_id`.
101    /// or creates a device from the saved state if provided.
102    pub async fn restore(
103        driver_source: &VmTaskDriverSource,
104        pci_id: impl AsRef<str>,
105        keepalive: bool,
106        dma_clients: VfioDmaClients,
107    ) -> anyhow::Result<Self> {
108        let pci_id = pci_id.as_ref();
109        let path = Path::new("/sys/bus/pci/devices").join(pci_id);
110
111        // The vfio device attaches asynchronously after the PCI device is added,
112        // so make sure that it has completed by checking for the vfio-dev subpath.
113        let vmbus_device =
114            std::fs::read_link(&path).context("failed to read link for pci device")?;
115        let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
116        let vfio_arrived_path = instance_path.join("vfio-dev");
117        let uevent_listener = UeventListener::new(&driver_source.simple())?;
118        let wait_for_vfio_device =
119            uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
120        let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
121        // Ignore any errors and always attempt to open.
122        let _ = ctx.until_cancelled(wait_for_vfio_device).await;
123
124        tracing::info!(pci_id, keepalive, "device arrived");
125        vfio_sys::print_relevant_params();
126
127        let driver = driver_source.simple();
128        let retry = vfio_sys::VfioRetry::new(&driver, pci_id);
129
130        let is_not_found = |e: &anyhow::Error| {
131            e.chain().any(|cause| {
132                cause
133                    .downcast_ref::<std::io::Error>()
134                    .is_some_and(|io_err| io_err.kind() == std::io::ErrorKind::NotFound)
135            })
136        };
137        let is_enodev = |e: &anyhow::Error| {
138            e.chain().any(|cause| {
139                cause
140                    .downcast_ref::<Errno>()
141                    .is_some_and(|e| *e == Errno::ENODEV)
142            })
143        };
144
145        let container = vfio_sys::Container::new()?;
146        let group_id = retry
147            .retry(
148                || vfio_sys::Group::find_group_for_device(&path),
149                &is_not_found,
150                "find_group_for_device",
151            )
152            .await?;
153        let group = retry
154            .retry(
155                || vfio_sys::Group::open_noiommu(group_id),
156                &is_not_found,
157                "open_noiommu",
158            )
159            .await?;
160        group.set_container(&container)?;
161        if !group.status()?.viable() {
162            anyhow::bail!("group is not viable");
163        }
164
165        container.set_iommu(IommuType::NoIommu)?;
166        if keepalive {
167            retry
168                .retry(
169                    || group.set_keep_alive(pci_id),
170                    &is_enodev,
171                    "set_keep_alive",
172                )
173                .await?;
174        }
175        tracing::debug!(pci_id, "about to open device");
176        let device = retry
177            .retry(|| group.open_device(pci_id), &is_enodev, "open_device")
178            .await?;
179        let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
180        if msix_info.flags.noresize() {
181            anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
182        }
183
184        let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
185        let this = Self {
186            pci_id: pci_id.into(),
187            _container: container,
188            _group: group,
189            device: Arc::new(device),
190            msix_info,
191            config_space,
192            driver_source: driver_source.clone(),
193            interrupts: Vec::new(),
194            dma_clients,
195        };
196
197        tracing::debug!(pci_id, "enabling device...");
198        // Ensure bus master enable and memory space enable are set, and that
199        // INTx is disabled.
200        this.enable_device()
201            .with_context(|| format!("failed to enable device {pci_id}"))?;
202        Ok(this)
203    }
204
205    fn enable_device(&self) -> anyhow::Result<()> {
206        let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
207        let status_command = self.read_config(offset)?;
208        let command = pci_core::spec::cfg_space::Command::from(status_command as u16);
209
210        let command = command
211            .with_bus_master(true)
212            .with_intx_disable(true)
213            .with_mmio_enabled(true);
214
215        let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
216        self.write_config(offset, status_command)?;
217        Ok(())
218    }
219
220    pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
221        if offset as u64 > self.config_space.size - 4 {
222            anyhow::bail!("invalid config offset");
223        }
224
225        let mut buf = [0u8; 4];
226        self.device
227            .as_ref()
228            .as_ref()
229            .read_at(&mut buf, self.config_space.offset + offset as u64)
230            .context("failed to read config")?;
231
232        Ok(u32::from_ne_bytes(buf))
233    }
234
235    pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
236        if offset as u64 > self.config_space.size - 4 {
237            anyhow::bail!("invalid config offset");
238        }
239
240        tracing::trace!(pci_id = ?self.pci_id, offset, data, "writing config");
241        let buf = data.to_ne_bytes();
242        self.device
243            .as_ref()
244            .as_ref()
245            .write_at(&buf, self.config_space.offset + offset as u64)
246            .context("failed to write config")?;
247
248        Ok(())
249    }
250
251    /// Maps PCI BAR[n] to VA space.
252    fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
253        if n >= 6 {
254            anyhow::bail!("invalid bar");
255        }
256        let info = self.device.region_info(n.into())?;
257        let mapping = self.device.map(info.offset, info.size as usize, true)?;
258        trycopy::initialize_try_copy();
259        Ok(MappedRegionWithFallback {
260            device: self.device.clone(),
261            mapping,
262            len: info.size as usize,
263            offset: info.offset,
264            read_fallback: SharedCounter::new(),
265            write_fallback: SharedCounter::new(),
266        })
267    }
268}
269
/// A mapped region that falls back to read/write if the memory mapped access
/// fails.
///
/// This should only happen for CVM, and only when the MMIO is emulated by the
/// host.
#[derive(Inspect)]
pub struct MappedRegionWithFallback {
    /// The VFIO device, used for file-based reads/writes on the fallback path.
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    /// The memory mapping of the region.
    #[inspect(skip)]
    mapping: vfio_sys::MappedRegion,
    /// Offset of the region within the device file, as reported by the
    /// region info.
    offset: u64,
    /// Size of the region in bytes, as reported by the region info.
    len: usize,
    /// Number of reads that went through the file-based fallback.
    read_fallback: SharedCounter,
    /// Number of writes that went through the file-based fallback.
    write_fallback: SharedCounter,
}
286
287impl DeviceBacking for VfioDevice {
288    type Registers = MappedRegionWithFallback;
289
290    fn id(&self) -> &str {
291        &self.pci_id
292    }
293
294    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
295        (*self).map_bar(n)
296    }
297
298    fn dma_client(&self) -> Arc<dyn DmaClient> {
299        // Default to the only present client, or if both are available default to the
300        // persistent client.
301        match &self.dma_clients {
302            VfioDmaClients::EphemeralOnly(client) => client.clone(),
303            VfioDmaClients::PersistentOnly(client) => client.clone(),
304            VfioDmaClients::Split {
305                persistent,
306                ephemeral: _,
307            } => persistent.clone(),
308        }
309    }
310
311    fn dma_client_for(&self, pool: crate::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
312        match &self.dma_clients {
313            VfioDmaClients::PersistentOnly(client) => match pool {
314                crate::DmaPool::Persistent => Ok(client.clone()),
315                crate::DmaPool::Ephemeral => {
316                    anyhow::bail!(
317                        "ephemeral dma pool requested but only persistent client available"
318                    )
319                }
320            },
321            VfioDmaClients::EphemeralOnly(client) => match pool {
322                crate::DmaPool::Ephemeral => Ok(client.clone()),
323                crate::DmaPool::Persistent => {
324                    anyhow::bail!(
325                        "persistent dma pool requested but only ephemeral client available"
326                    )
327                }
328            },
329            VfioDmaClients::Split {
330                persistent,
331                ephemeral,
332            } => match pool {
333                crate::DmaPool::Persistent => Ok(persistent.clone()),
334                crate::DmaPool::Ephemeral => Ok(ephemeral.clone()),
335            },
336        }
337    }
338
339    fn max_interrupt_count(&self) -> u32 {
340        self.msix_info.count
341    }
342
343    fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
344        if msix >= self.msix_info.count {
345            anyhow::bail!("invalid msix index");
346        }
347        if self.interrupts.len() <= msix as usize {
348            self.interrupts.resize_with(msix as usize + 1, || None);
349        }
350
351        let interrupt = &mut self.interrupts[msix as usize];
352        if let Some(interrupt) = interrupt {
353            // The interrupt has been mapped before. Just retarget it to the new
354            // CPU on the next interrupt, if needed.
355            if interrupt.target_cpu.load(Relaxed) != cpu {
356                interrupt.target_cpu.store(cpu, Relaxed);
357            }
358            return Ok(interrupt.interrupt.clone());
359        }
360
361        let new_interrupt = {
362            let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
363            let driver = self
364                .driver_source
365                .builder()
366                .run_on_target(true)
367                .target_vp(cpu)
368                .build(&name);
369
370            let event =
371                PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;
372
373            let source = DeviceInterruptSource::new();
374            self.device
375                .map_msix(msix, [event.get().as_fd()])
376                .context("failed to map msix")?;
377
378            // The interrupt's CPU affinity will be set by the task when it
379            // starts. This can block the thread briefly, so it's better to do
380            // it on the target CPU.
381            let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
382                .context("failed to find irq for msix")?;
383
384            let target_cpu = Arc::new(AtomicU32::new(cpu));
385
386            let interrupt = source.new_target();
387
388            let task = driver.spawn(
389                name,
390                InterruptTask {
391                    driver: driver.clone(),
392                    target_cpu: target_cpu.clone(),
393                    pci_id: self.pci_id.clone(),
394                    msix,
395                    irq,
396                    event,
397                    source,
398                }
399                .run(),
400            );
401
402            InterruptState {
403                interrupt,
404                target_cpu,
405                _task: task,
406            }
407        };
408
409        Ok(interrupt.insert(new_interrupt).interrupt.clone())
410    }
411
412    fn unmap_all_interrupts(&mut self) -> anyhow::Result<()> {
413        if self.interrupts.is_empty() {
414            return Ok(());
415        }
416
417        let count = self.interrupts.len() as u32;
418        self.device
419            .unmap_msix(0, count)
420            .context("failed to unmap all msix vectors")?;
421
422        // Clear local bookkeeping so re-mapping works correctly later.
423        self.interrupts.clear();
424
425        Ok(())
426    }
427}
428
/// State owned by the per-vector task that forwards interrupts and keeps the
/// IRQ affinity in step with the requested target CPU.
struct InterruptTask {
    /// The task driver the task runs on; retargeted when the CPU changes.
    driver: VmTaskDriver,
    /// The requested target CPU, shared with `InterruptState`.
    target_cpu: Arc<AtomicU32>,
    /// PCI ID of the device; used only for logging.
    pci_id: Arc<str>,
    /// MSI-X vector index; used only for logging.
    msix: u32,
    /// The IRQ number found for this vector, used when setting affinity.
    irq: u32,
    /// The event that was registered with VFIO for this vector.
    event: PolledWait<Event>,
    /// The source signaled each time the event fires.
    source: DeviceInterruptSource,
}
438
impl InterruptTask {
    /// Runs forever: waits on the event and signals `source` each time it
    /// fires, and whenever the requested target CPU differs from the current
    /// one, retargets the driver and updates the IRQ affinity.
    ///
    /// # Panics
    ///
    /// Panics if waiting on the event returns an error.
    async fn run(mut self) {
        // Start with a sentinel so the first iteration takes the retarget
        // path (unless the requested CPU happens to be `u32::MAX`).
        let mut current_cpu = !0;
        loop {
            let next_cpu = self.target_cpu.load(Relaxed);
            let r = if next_cpu == current_cpu {
                self.event.wait().await
            } else {
                self.driver.retarget_vp(next_cpu);
                // Wait until the target CPU is ready before updating affinity,
                // since otherwise the CPU may not be online.
                //
                // Note: this local enum shadows the imported `Event` type
                // within this block only.
                enum Event {
                    TargetVpReady(()),
                    Interrupt(std::io::Result<()>),
                }
                // Race VP readiness against the interrupt so that interrupts
                // arriving in the meantime are still delivered.
                match (
                    self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
                    self.event.wait().map(Event::Interrupt),
                )
                    .race()
                    .await
                {
                    Event::TargetVpReady(()) => {
                        if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
                            // This should only occur due to extreme low resources.
                            // However, it is not a fatal error--it will just result in
                            // worse performance--so do not panic.
                            tracing::error!(
                                pci_id = self.pci_id.as_ref(),
                                msix = self.msix,
                                irq = self.irq,
                                error = &err as &dyn std::error::Error,
                                "failed to set irq affinity"
                            );
                        }
                        // Record the CPU as current even if setting affinity
                        // failed, so the failure is not retried every loop.
                        current_cpu = next_cpu;
                        // No interrupt was observed on this path; skip the
                        // signal below.
                        continue;
                    }
                    Event::Interrupt(r) => {
                        // An interrupt arrived while waiting for the VP to be
                        // ready. Signal and loop around to try again.
                        r
                    }
                }
            };

            r.expect("wait cannot fail on eventfd");
            self.source.signal();
        }
    }
}
490
491fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
492    fs_err::write(
493        format!("/proc/irq/{}/smp_affinity_list", irq),
494        cpu.to_string(),
495    )
496}
497
498impl DeviceRegisterIo for vfio_sys::MappedRegion {
499    fn len(&self) -> usize {
500        self.len()
501    }
502
503    fn read_u32(&self, offset: usize) -> u32 {
504        self.read_u32(offset)
505    }
506
507    fn read_u64(&self, offset: usize) -> u64 {
508        self.read_u64(offset)
509    }
510
511    fn write_u32(&self, offset: usize, data: u32) {
512        self.write_u32(offset, data)
513    }
514
515    fn write_u64(&self, offset: usize, data: u64) {
516        self.write_u64(offset, data)
517    }
518}
519
520impl MappedRegionWithFallback {
521    fn mapping<T>(&self, offset: usize) -> *mut T {
522        assert!(
523            offset <= self.mapping.len() - size_of::<T>() && offset.is_multiple_of(align_of::<T>())
524        );
525        if cfg!(feature = "mmio_simulate_fallback") {
526            return std::ptr::NonNull::dangling().as_ptr();
527        }
528        // SAFETY: the offset is validated to be in bounds.
529        unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
530    }
531
532    fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
533        &self,
534        offset: usize,
535    ) -> Result<T, trycopy::MemoryError> {
536        // SAFETY: the offset is validated to be in bounds and aligned.
537        unsafe { trycopy::try_read_volatile(self.mapping::<T>(offset)) }
538    }
539
540    fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
541        &self,
542        offset: usize,
543        data: T,
544    ) -> Result<(), trycopy::MemoryError> {
545        // SAFETY: the offset is validated to be in bounds and aligned.
546        unsafe { trycopy::try_write_volatile(self.mapping::<T>(offset), &data) }
547    }
548
549    fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
550        tracing::trace!(offset, n = buf.len(), "read");
551        self.read_fallback.increment();
552        let n = self
553            .device
554            .as_ref()
555            .as_ref()
556            .read_at(buf, self.offset + offset as u64)
557            .expect("valid mapping");
558        assert_eq!(n, buf.len());
559    }
560
561    fn write_to_file(&self, offset: usize, buf: &[u8]) {
562        tracing::trace!(offset, n = buf.len(), "write");
563        self.write_fallback.increment();
564        let n = self
565            .device
566            .as_ref()
567            .as_ref()
568            .write_at(buf, self.offset + offset as u64)
569            .expect("valid mapping");
570        assert_eq!(n, buf.len());
571    }
572}
573
574impl DeviceRegisterIo for MappedRegionWithFallback {
575    fn len(&self) -> usize {
576        self.len
577    }
578
579    fn read_u32(&self, offset: usize) -> u32 {
580        self.read_from_mapping(offset).unwrap_or_else(|_| {
581            let mut buf = [0u8; 4];
582            self.read_from_file(offset, &mut buf);
583            u32::from_ne_bytes(buf)
584        })
585    }
586
587    fn read_u64(&self, offset: usize) -> u64 {
588        self.read_from_mapping(offset).unwrap_or_else(|_| {
589            let mut buf = [0u8; 8];
590            self.read_from_file(offset, &mut buf);
591            u64::from_ne_bytes(buf)
592        })
593    }
594
595    fn write_u32(&self, offset: usize, data: u32) {
596        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
597            self.write_to_file(offset, &data.to_ne_bytes());
598        })
599    }
600
601    fn write_u64(&self, offset: usize, data: u64) {
602        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
603            self.write_to_file(offset, &data.to_ne_bytes());
604        })
605    }
606}
607
/// A PCI device reset method, as written to the device's sysfs
/// `reset_method` file by [`vfio_set_device_reset_method`].
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    /// Written as an empty (NUL-only) string.
    NoReset,
    /// Written as `acpi`.
    Acpi,
    /// Written as `flr`.
    Flr,
    /// Written as `af_flr`.
    AfFlr,
    /// Written as `pm`.
    Pm,
    /// Written as `bus`.
    Bus,
}
617
618pub fn vfio_set_device_reset_method(
619    pci_id: impl AsRef<str>,
620    method: PciDeviceResetMethod,
621) -> std::io::Result<()> {
622    let reset_method = match method {
623        PciDeviceResetMethod::NoReset => "\0".as_bytes(),
624        PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
625        PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
626        PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
627        PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
628        PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
629    };
630
631    let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
632        .iter()
633        .collect();
634    fs_err::write(path, reset_method)?;
635    Ok(())
636}