// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Support for accessing a MANA device via VFIO on Linux.

#![cfg(target_os = "linux")]
#![cfg(feature = "vfio")]

use crate::DeviceBacking;
use crate::DeviceRegisterIo;
use crate::DmaClient;
use crate::interrupt::DeviceInterrupt;
use crate::interrupt::DeviceInterruptSource;
use anyhow::Context;
use futures::FutureExt;
use futures_concurrency::future::Race;
use inspect::Inspect;
use inspect_counters::SharedCounter;
use pal_async::task::Spawn;
use pal_async::task::Task;
use pal_async::wait::PolledWait;
use pal_event::Event;
use std::os::fd::AsFd;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::Relaxed;
use std::time::Duration;
use uevent::UeventListener;
use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
use vfio_sys::IommuType;
use vfio_sys::IrqInfo;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// The DMA client(s) backing a [`VfioDevice`].
#[derive(Clone)]
pub enum VfioDmaClients {
    /// Only a persistent DMA client is available.
    PersistentOnly(Arc<dyn DmaClient>),
    /// Only an ephemeral DMA client is available.
    EphemeralOnly(Arc<dyn DmaClient>),
    /// Both persistent and ephemeral DMA clients are available.
    Split {
        persistent: Arc<dyn DmaClient>,
        ephemeral: Arc<dyn DmaClient>,
    },
}

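// A minimal construction sketch (not from this file): it assumes `persistent`
// and `ephemeral` are preexisting `Arc<dyn DmaClient>` handles, created by
// whatever DMA client implementation is in use:
//
//     let dma_clients = VfioDmaClients::Split {
//         persistent,
//         ephemeral,
//     };
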
/// A device backend accessed via VFIO.
#[derive(Inspect)]
pub struct VfioDevice {
    pci_id: Arc<str>,
    #[inspect(skip)]
    _container: vfio_sys::Container,
    #[inspect(skip)]
    _group: vfio_sys::Group,
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    msix_info: IrqInfo,
    #[inspect(skip)]
    driver_source: VmTaskDriverSource,
    #[inspect(iter_by_index)]
    interrupts: Vec<Option<InterruptState>>,
    #[inspect(skip)]
    config_space: vfio_sys::RegionInfo,
    #[inspect(skip)]
    dma_clients: VfioDmaClients,
}

#[derive(Inspect)]
struct InterruptState {
    #[inspect(skip)]
    interrupt: DeviceInterrupt,
    target_cpu: Arc<AtomicU32>,
    #[inspect(skip)]
    _task: Task<()>,
}

impl Drop for VfioDevice {
    fn drop(&mut self) {
        // Nothing to release explicitly; trace the drop for diagnostics.
        tracing::trace!(pci_id = ?self.pci_id, "dropping vfio device");
    }
}

impl VfioDevice {
    /// Creates a new VFIO-backed device for the PCI device with `pci_id`.
    pub async fn new(
        driver_source: &VmTaskDriverSource,
        pci_id: impl AsRef<str>,
        dma_clients: VfioDmaClients,
    ) -> anyhow::Result<Self> {
        Self::restore(driver_source, pci_id, false, dma_clients).await
    }

    /// Creates a new VFIO-backed device for the PCI device with `pci_id`, or
    /// restores a previously opened device without touching the hardware if
    /// `keepalive` is set.
    pub async fn restore(
        driver_source: &VmTaskDriverSource,
        pci_id: impl AsRef<str>,
        keepalive: bool,
        dma_clients: VfioDmaClients,
    ) -> anyhow::Result<Self> {
        let pci_id = pci_id.as_ref();
        let path = Path::new("/sys/bus/pci/devices").join(pci_id);

        // The vfio device attaches asynchronously after the PCI device is
        // added, so make sure that it has completed by waiting for the
        // vfio-dev subpath to appear.
        let vmbus_device =
            std::fs::read_link(&path).context("failed to read link for pci device")?;
        let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
        let vfio_arrived_path = instance_path.join("vfio-dev");
        let uevent_listener = UeventListener::new(&driver_source.simple())?;
        let wait_for_vfio_device =
            uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
        let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
        // Ignore any errors and always attempt to open.
        let _ = ctx.until_cancelled(wait_for_vfio_device).await;

        tracing::info!(pci_id, keepalive, "device arrived");
        vfio_sys::print_relevant_params();

        let container = vfio_sys::Container::new()?;
        let group_id = vfio_sys::Group::find_group_for_device(&path)?;
        let group = vfio_sys::Group::open_noiommu(group_id)?;
        group.set_container(&container)?;
        if !group.status()?.viable() {
            anyhow::bail!("group is not viable");
        }

        let driver = driver_source.simple();
        container.set_iommu(IommuType::NoIommu)?;
        if keepalive {
            // Prevent physical hardware interaction when restoring.
            group.set_keep_alive(pci_id, &driver).await?;
        }
        tracing::debug!(pci_id, "about to open device");
        let device = group.open_device(pci_id, &driver).await?;
        let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
        if msix_info.flags.noresize() {
            anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
        }

        let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
        let this = Self {
            pci_id: pci_id.into(),
            _container: container,
            _group: group,
            device: Arc::new(device),
            msix_info,
            config_space,
            driver_source: driver_source.clone(),
            interrupts: Vec::new(),
            dma_clients,
        };

        tracing::debug!(pci_id, "enabling device...");
        // Ensure bus master enable and memory space enable are set, and that
        // INTx is disabled.
        this.enable_device()
            .with_context(|| format!("failed to enable device {pci_id}"))?;
        Ok(this)
    }

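// A usage sketch, assuming `driver_source` is an existing `VmTaskDriverSource`
// and `dma_clients` was built as above; the PCI ID is a hypothetical example:
//
//     let device = VfioDevice::new(&driver_source, "0000:58:00.0", dma_clients)
//         .await
//         .context("failed to open vfio device")?;
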
    /// Sets the bus master and memory space enable bits and masks INTx in the
    /// PCI command register.
    fn enable_device(&self) -> anyhow::Result<()> {
        let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
        let status_command = self.read_config(offset)?;
        let command = pci_core::spec::cfg_space::Command::from(status_command as u16);

        let command = command
            .with_bus_master(true)
            .with_intx_disable(true)
            .with_mmio_enabled(true);

        // Merge the updated command into the low half of the status/command dword.
        let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
        self.write_config(offset, status_command)?;
        Ok(())
    }

    /// Reads a 32-bit value from PCI config space at `offset`.
    pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        let mut buf = [0u8; 4];
        self.device
            .as_ref()
            .as_ref()
            .read_at(&mut buf, self.config_space.offset + offset as u64)
            .context("failed to read config")?;

        Ok(u32::from_ne_bytes(buf))
    }

    /// Writes a 32-bit value to PCI config space at `offset`.
    pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        tracing::trace!(pci_id = ?self.pci_id, offset, data, "writing config");
        let buf = data.to_ne_bytes();
        self.device
            .as_ref()
            .as_ref()
            .write_at(&buf, self.config_space.offset + offset as u64)
            .context("failed to write config")?;

        Ok(())
    }

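// A read-modify-write sketch using the helpers above, assuming `device` is a
// `VfioDevice`. Setting bit 2 (bus master enable) of the command register
// mirrors part of what `enable_device` does:
//
//     let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
//     let val = device.read_config(offset)?;
//     device.write_config(offset, val | 0x4)?;
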
    /// Maps PCI BAR[n] to VA space.
    fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
        if n >= 6 {
            anyhow::bail!("invalid bar");
        }
        let info = self.device.region_info(n.into())?;
        let mapping = self.device.map(info.offset, info.size as usize, true)?;
        trycopy::initialize_try_copy();
        Ok(MappedRegionWithFallback {
            device: self.device.clone(),
            mapping,
            len: info.size as usize,
            offset: info.offset,
            read_fallback: SharedCounter::new(),
            write_fallback: SharedCounter::new(),
        })
    }
}

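// A BAR access sketch: callers typically map BAR 0 through the `DeviceBacking`
// trait and use the returned region via `DeviceRegisterIo`, assuming `device`
// is a mutable `VfioDevice`:
//
//     let regs = DeviceBacking::map_bar(&mut device, 0)?;
//     let value = regs.read_u32(0);
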
/// A mapped region that falls back to read/write through the vfio device file
/// descriptor if the memory-mapped access fails.
///
/// This should only happen for CVM, and only when the MMIO is emulated by the
/// host.
#[derive(Inspect)]
pub struct MappedRegionWithFallback {
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    mapping: vfio_sys::MappedRegion,
    offset: u64,
    len: usize,
    read_fallback: SharedCounter,
    write_fallback: SharedCounter,
}

impl DeviceBacking for VfioDevice {
    type Registers = MappedRegionWithFallback;

    fn id(&self) -> &str {
        &self.pci_id
    }

    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
        (*self).map_bar(n)
    }

    fn dma_client(&self) -> Arc<dyn DmaClient> {
        // Default to the only present client, or to the persistent client if
        // both are available.
        match &self.dma_clients {
            VfioDmaClients::EphemeralOnly(client) => client.clone(),
            VfioDmaClients::PersistentOnly(client) => client.clone(),
            VfioDmaClients::Split {
                persistent,
                ephemeral: _,
            } => persistent.clone(),
        }
    }

    fn dma_client_for(&self, pool: crate::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
        match &self.dma_clients {
            VfioDmaClients::PersistentOnly(client) => match pool {
                crate::DmaPool::Persistent => Ok(client.clone()),
                crate::DmaPool::Ephemeral => {
                    anyhow::bail!(
                        "ephemeral dma pool requested but only persistent client available"
                    )
                }
            },
            VfioDmaClients::EphemeralOnly(client) => match pool {
                crate::DmaPool::Ephemeral => Ok(client.clone()),
                crate::DmaPool::Persistent => {
                    anyhow::bail!(
                        "persistent dma pool requested but only ephemeral client available"
                    )
                }
            },
            VfioDmaClients::Split {
                persistent,
                ephemeral,
            } => match pool {
                crate::DmaPool::Persistent => Ok(persistent.clone()),
                crate::DmaPool::Ephemeral => Ok(ephemeral.clone()),
            },
        }
    }

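// A pool selection sketch, assuming `device` was configured with
// `VfioDmaClients::Split`; with a single-client configuration, requesting the
// missing pool fails instead:
//
//     let persistent = device.dma_client_for(crate::DmaPool::Persistent)?;
//     let ephemeral = device.dma_client_for(crate::DmaPool::Ephemeral)?;
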
    fn max_interrupt_count(&self) -> u32 {
        self.msix_info.count
    }

    fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
        if msix >= self.msix_info.count {
            anyhow::bail!("invalid msix index");
        }
        if self.interrupts.len() <= msix as usize {
            self.interrupts.resize_with(msix as usize + 1, || None);
        }

        let interrupt = &mut self.interrupts[msix as usize];
        if let Some(interrupt) = interrupt {
            // The interrupt has been mapped before. Just retarget it to the new
            // CPU on the next interrupt, if needed.
            if interrupt.target_cpu.load(Relaxed) != cpu {
                interrupt.target_cpu.store(cpu, Relaxed);
            }
            return Ok(interrupt.interrupt.clone());
        }

        let new_interrupt = {
            let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
            let driver = self
                .driver_source
                .builder()
                .run_on_target(true)
                .target_vp(cpu)
                .build(&name);

            let event =
                PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;

            let source = DeviceInterruptSource::new();
            self.device
                .map_msix(msix, [event.get().as_fd()])
                .context("failed to map msix")?;

            // The interrupt's CPU affinity will be set by the task when it
            // starts. This can block the thread briefly, so it's better to do
            // it on the target CPU.
            let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
                .context("failed to find irq for msix")?;

            let target_cpu = Arc::new(AtomicU32::new(cpu));

            let interrupt = source.new_target();

            let task = driver.spawn(
                name,
                InterruptTask {
                    driver: driver.clone(),
                    target_cpu: target_cpu.clone(),
                    pci_id: self.pci_id.clone(),
                    msix,
                    irq,
                    event,
                    source,
                }
                .run(),
            );

            InterruptState {
                interrupt,
                target_cpu,
                _task: task,
            }
        };

        Ok(interrupt.insert(new_interrupt).interrupt.clone())
    }

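// An interrupt mapping sketch, assuming `device` is a mutable `VfioDevice`:
// mapping the same vector again with a different CPU retargets the existing
// interrupt rather than remapping it:
//
//     let interrupt = device.map_interrupt(0, 1)?;
//     let same_interrupt = device.map_interrupt(0, 2)?; // retargets to CPU 2
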
    fn unmap_all_interrupts(&mut self) -> anyhow::Result<()> {
        if self.interrupts.is_empty() {
            return Ok(());
        }

        let count = self.interrupts.len() as u32;
        self.device
            .unmap_msix(0, count)
            .context("failed to unmap all msix vectors")?;

        // Clear local bookkeeping so re-mapping works correctly later.
        self.interrupts.clear();

        Ok(())
    }
}

struct InterruptTask {
    driver: VmTaskDriver,
    target_cpu: Arc<AtomicU32>,
    pci_id: Arc<str>,
    msix: u32,
    irq: u32,
    event: PolledWait<Event>,
    source: DeviceInterruptSource,
}

impl InterruptTask {
    async fn run(mut self) {
        let mut current_cpu = !0;
        loop {
            let next_cpu = self.target_cpu.load(Relaxed);
            let r = if next_cpu == current_cpu {
                self.event.wait().await
            } else {
                self.driver.retarget_vp(next_cpu);
                // Wait until the target CPU is ready before updating affinity,
                // since otherwise the CPU may not be online.
                enum Event {
                    TargetVpReady(()),
                    Interrupt(std::io::Result<()>),
                }
                match (
                    self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
                    self.event.wait().map(Event::Interrupt),
                )
                    .race()
                    .await
                {
                    Event::TargetVpReady(()) => {
                        if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
                            // This should only occur under extreme low-resource
                            // conditions. It is not a fatal error, just a
                            // performance loss, so do not panic.
                            tracing::error!(
                                pci_id = self.pci_id.as_ref(),
                                msix = self.msix,
                                irq = self.irq,
                                error = &err as &dyn std::error::Error,
                                "failed to set irq affinity"
                            );
                        }
                        current_cpu = next_cpu;
                        continue;
                    }
                    Event::Interrupt(r) => {
                        // An interrupt arrived while waiting for the VP to be
                        // ready. Signal and loop around to try again.
                        r
                    }
                }
            };

            r.expect("wait cannot fail on eventfd");
            self.source.signal();
        }
    }
}

/// Sets the affinity of `irq` to `cpu` by writing to the IRQ's
/// `smp_affinity_list` file in procfs.
fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
    fs_err::write(
        format!("/proc/irq/{}/smp_affinity_list", irq),
        cpu.to_string(),
    )
}

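// The write above is equivalent to the following shell command (IRQ and CPU
// numbers are hypothetical):
//
//     echo 3 > /proc/irq/47/smp_affinity_list
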
impl DeviceRegisterIo for vfio_sys::MappedRegion {
    fn len(&self) -> usize {
        self.len()
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_u32(offset)
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_u64(offset)
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_u32(offset, data)
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_u64(offset, data)
    }
}

impl MappedRegionWithFallback {
    /// Returns a pointer into the mapping for a `T` at `offset`, asserting
    /// that the access is in bounds and aligned.
    fn mapping<T>(&self, offset: usize) -> *mut T {
        assert!(
            offset <= self.mapping.len() - size_of::<T>() && offset.is_multiple_of(align_of::<T>())
        );
        if cfg!(feature = "mmio_simulate_fallback") {
            return std::ptr::NonNull::dangling().as_ptr();
        }
        // SAFETY: the offset is validated to be in bounds.
        unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
    }

    fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
    ) -> Result<T, trycopy::MemoryError> {
        // SAFETY: the offset is validated to be in bounds and aligned.
        unsafe { trycopy::try_read_volatile(self.mapping::<T>(offset)) }
    }

    fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
        data: T,
    ) -> Result<(), trycopy::MemoryError> {
        // SAFETY: the offset is validated to be in bounds and aligned.
        unsafe { trycopy::try_write_volatile(self.mapping::<T>(offset), &data) }
    }

    /// Reads through the vfio file descriptor, used when the memory-mapped
    /// access faults.
    fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
        tracing::trace!(offset, n = buf.len(), "read");
        self.read_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .read_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }

    /// Writes through the vfio file descriptor, used when the memory-mapped
    /// access faults.
    fn write_to_file(&self, offset: usize, buf: &[u8]) {
        tracing::trace!(offset, n = buf.len(), "write");
        self.write_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .write_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }
}

impl DeviceRegisterIo for MappedRegionWithFallback {
    fn len(&self) -> usize {
        self.len
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 4];
            self.read_from_file(offset, &mut buf);
            u32::from_ne_bytes(buf)
        })
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 8];
            self.read_from_file(offset, &mut buf);
            u64::from_ne_bytes(buf)
        })
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }
}

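// Callers see a single register interface; a faulted MMIO access transparently
// retries through the vfio file descriptor, and the `read_fallback` and
// `write_fallback` counters record how often that happened. A sketch, assuming
// `regs` came from `map_bar`:
//
//     let value = regs.read_u32(0x10); // MMIO first, pread fallback on fault
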
/// A PCI device reset method, as accepted by the kernel's per-device
/// `reset_method` sysfs attribute.
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    NoReset,
    Acpi,
    Flr,
    AfFlr,
    Pm,
    Bus,
}

/// Configures the reset method the kernel will use for the PCI device with
/// `pci_id` by writing to its `reset_method` sysfs attribute.
pub fn vfio_set_device_reset_method(
    pci_id: impl AsRef<str>,
    method: PciDeviceResetMethod,
) -> std::io::Result<()> {
    let reset_method = match method {
        PciDeviceResetMethod::NoReset => "\0".as_bytes(),
        PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
        PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
        PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
        PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
        PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
    };

    let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
        .iter()
        .collect();
    fs_err::write(path, reset_method)?;
    Ok(())
}
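
// A usage sketch with a hypothetical PCI ID, selecting function-level reset:
//
//     vfio_set_device_reset_method("0000:58:00.0", PciDeviceResetMethod::Flr)?;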