user_driver/vfio.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Support for accessing a MANA device via VFIO on Linux.
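//!
//! A minimal usage sketch; the driver source, PCI ID, and DMA clients below
//! are illustrative placeholders:
//!
//! ```ignore
//! let mut device = VfioDevice::new(&driver_source, "0000:00:00.0", dma_clients).await?;
//! // Map BAR 0 and read its first register through the `DeviceRegisterIo` trait.
//! let bar0 = DeviceBacking::map_bar(&mut device, 0)?;
//! let value = bar0.read_u32(0);
//! ```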

#![cfg(target_os = "linux")]
#![cfg(feature = "vfio")]

use crate::DeviceBacking;
use crate::DeviceRegisterIo;
use crate::DmaClient;
use crate::interrupt::DeviceInterrupt;
use crate::interrupt::DeviceInterruptSource;
use anyhow::Context;
use futures::FutureExt;
use futures_concurrency::future::Race;
use inspect::Inspect;
use inspect_counters::SharedCounter;
use pal_async::task::Spawn;
use pal_async::task::Task;
use pal_async::wait::PolledWait;
use pal_event::Event;
use std::os::fd::AsFd;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::Relaxed;
use std::time::Duration;
use uevent::UeventListener;
use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
use vfio_sys::IommuType;
use vfio_sys::IrqInfo;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

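/// The DMA clients available to a [`VfioDevice`]: a persistent client only, an
/// ephemeral client only, or both.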
#[derive(Clone)]
pub enum VfioDmaClients {
    PersistentOnly(Arc<dyn DmaClient>),
    EphemeralOnly(Arc<dyn DmaClient>),
    Split {
        persistent: Arc<dyn DmaClient>,
        ephemeral: Arc<dyn DmaClient>,
    },
}

/// A device backend accessed via VFIO.
#[derive(Inspect)]
pub struct VfioDevice {
    pci_id: Arc<str>,
    #[inspect(skip)]
    _container: vfio_sys::Container,
    #[inspect(skip)]
    _group: vfio_sys::Group,
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    msix_info: IrqInfo,
    #[inspect(skip)]
    driver_source: VmTaskDriverSource,
    #[inspect(iter_by_index)]
    interrupts: Vec<Option<InterruptState>>,
    #[inspect(skip)]
    config_space: vfio_sys::RegionInfo,
    #[inspect(skip)]
    dma_clients: VfioDmaClients,
}

#[derive(Inspect)]
struct InterruptState {
    #[inspect(skip)]
    interrupt: DeviceInterrupt,
    target_cpu: Arc<AtomicU32>,
    #[inspect(skip)]
    _task: Task<()>,
}

impl Drop for VfioDevice {
    fn drop(&mut self) {
        // Nothing to release explicitly; trace the drop for diagnosability.
        tracing::trace!(pci_id = ?self.pci_id, "dropping vfio device");
    }
}

impl VfioDevice {
    /// Creates a new VFIO-backed device for the PCI device with `pci_id`.
    pub async fn new(
        driver_source: &VmTaskDriverSource,
        pci_id: impl AsRef<str>,
        dma_clients: VfioDmaClients,
    ) -> anyhow::Result<Self> {
        Self::restore(driver_source, pci_id, false, dma_clients).await
    }

    /// Creates a new VFIO-backed device for the PCI device with `pci_id`, or,
    /// when `keepalive` is set, reattaches to the device during a restore
    /// without touching the physical hardware.
    pub async fn restore(
        driver_source: &VmTaskDriverSource,
        pci_id: impl AsRef<str>,
        keepalive: bool,
        dma_clients: VfioDmaClients,
    ) -> anyhow::Result<Self> {
        let pci_id = pci_id.as_ref();
        let path = Path::new("/sys/bus/pci/devices").join(pci_id);

        // The vfio device attaches asynchronously after the PCI device is added,
        // so make sure that it has completed by checking for the vfio-dev subpath.
        let vmbus_device =
            std::fs::read_link(&path).context("failed to read link for pci device")?;
        let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
        let vfio_arrived_path = instance_path.join("vfio-dev");
        let uevent_listener = UeventListener::new(&driver_source.simple())?;
        let wait_for_vfio_device =
            uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
        let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
        // Ignore any errors and always attempt to open.
        let _ = ctx.until_cancelled(wait_for_vfio_device).await;

        tracing::info!(pci_id, keepalive, "device arrived");

        let container = vfio_sys::Container::new()?;
        let group_id = vfio_sys::Group::find_group_for_device(&path)?;
        let group = vfio_sys::Group::open_noiommu(group_id)?;
        group.set_container(&container)?;
        if !group.status()?.viable() {
            anyhow::bail!("group is not viable");
        }

        let driver = driver_source.simple();
        container.set_iommu(IommuType::NoIommu)?;
        if keepalive {
            // Prevent physical hardware interaction when restoring.
            group.set_keep_alive(pci_id, &driver).await?;
        }
        tracing::debug!(pci_id, "about to open device");
        let device = group.open_device(pci_id, &driver).await?;
        let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
        if msix_info.flags.noresize() {
            anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
        }

        let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
        let this = Self {
            pci_id: pci_id.into(),
            _container: container,
            _group: group,
            device: Arc::new(device),
            msix_info,
            config_space,
            driver_source: driver_source.clone(),
            interrupts: Vec::new(),
            dma_clients,
        };

        tracing::debug!(pci_id, "enabling device...");
        // Ensure bus master enable and memory space enable are set, and that
        // INTx is disabled.
        this.enable_device()
            .with_context(|| format!("failed to enable device {pci_id}"))?;
        Ok(this)
    }

    fn enable_device(&self) -> anyhow::Result<()> {
        let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
        let status_command = self.read_config(offset)?;
        let command = pci_core::spec::cfg_space::Command::from(status_command as u16);

        let command = command
            .with_bus_master(true)
            .with_intx_disable(true)
            .with_mmio_enabled(true);

        let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
        self.write_config(offset, status_command)?;
        Ok(())
    }

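    /// Reads a 32-bit value from the device's PCI config space at byte
    /// `offset`.
    ///
    /// A minimal call sketch; offset 0 is the standard vendor/device ID
    /// register:
    ///
    /// ```ignore
    /// let vendor_device = device.read_config(0)?;
    /// ```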
    pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        let mut buf = [0u8; 4];
        self.device
            .as_ref()
            .as_ref()
            .read_at(&mut buf, self.config_space.offset + offset as u64)
            .context("failed to read config")?;

        Ok(u32::from_ne_bytes(buf))
    }

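    /// Writes a 32-bit value to the device's PCI config space at byte
    /// `offset`.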
    pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        tracing::trace!(pci_id = ?self.pci_id, offset, data, "writing config");
        let buf = data.to_ne_bytes();
        self.device
            .as_ref()
            .as_ref()
            .write_at(&buf, self.config_space.offset + offset as u64)
            .context("failed to write config")?;

        Ok(())
    }

    /// Maps PCI BAR[n] to VA space.
    fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
        if n >= 6 {
            anyhow::bail!("invalid bar");
        }
        let info = self.device.region_info(n.into())?;
        let mapping = self.device.map(info.offset, info.size as usize, true)?;
        trycopy::initialize_try_copy();
        Ok(MappedRegionWithFallback {
            device: self.device.clone(),
            mapping,
            len: info.size as usize,
            offset: info.offset,
            read_fallback: SharedCounter::new(),
            write_fallback: SharedCounter::new(),
        })
    }
}

/// A mapped region that falls back to `read`/`write` on the underlying device
/// file if the memory-mapped access fails.
///
/// This should only happen for a CVM (confidential VM), and only when the MMIO
/// is emulated by the host.
#[derive(Inspect)]
pub struct MappedRegionWithFallback {
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    mapping: vfio_sys::MappedRegion,
    offset: u64,
    len: usize,
    read_fallback: SharedCounter,
    write_fallback: SharedCounter,
}

impl DeviceBacking for VfioDevice {
    type Registers = MappedRegionWithFallback;

    fn id(&self) -> &str {
        &self.pci_id
    }

    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
        (*self).map_bar(n)
    }

    fn dma_client(&self) -> Arc<dyn DmaClient> {
        // Default to the only available client; if both are available, prefer
        // the persistent client.
        match &self.dma_clients {
            VfioDmaClients::EphemeralOnly(client) => client.clone(),
            VfioDmaClients::PersistentOnly(client) => client.clone(),
            VfioDmaClients::Split {
                persistent,
                ephemeral: _,
            } => persistent.clone(),
        }
    }

    fn dma_client_for(&self, pool: crate::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
        match &self.dma_clients {
            VfioDmaClients::PersistentOnly(client) => match pool {
                crate::DmaPool::Persistent => Ok(client.clone()),
                crate::DmaPool::Ephemeral => {
                    anyhow::bail!(
                        "ephemeral dma pool requested but only persistent client available"
                    )
                }
            },
            VfioDmaClients::EphemeralOnly(client) => match pool {
                crate::DmaPool::Ephemeral => Ok(client.clone()),
                crate::DmaPool::Persistent => {
                    anyhow::bail!(
                        "persistent dma pool requested but only ephemeral client available"
                    )
                }
            },
            VfioDmaClients::Split {
                persistent,
                ephemeral,
            } => match pool {
                crate::DmaPool::Persistent => Ok(persistent.clone()),
                crate::DmaPool::Ephemeral => Ok(ephemeral.clone()),
            },
        }
    }

    fn max_interrupt_count(&self) -> u32 {
        self.msix_info.count
    }

    fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
        if msix >= self.msix_info.count {
            anyhow::bail!("invalid msix index");
        }
        if self.interrupts.len() <= msix as usize {
            self.interrupts.resize_with(msix as usize + 1, || None);
        }

        let interrupt = &mut self.interrupts[msix as usize];
        if let Some(interrupt) = interrupt {
            // The interrupt has been mapped before. Just retarget it to the new
            // CPU on the next interrupt, if needed.
            if interrupt.target_cpu.load(Relaxed) != cpu {
                interrupt.target_cpu.store(cpu, Relaxed);
            }
            return Ok(interrupt.interrupt.clone());
        }

        let new_interrupt = {
            let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
            let driver = self
                .driver_source
                .builder()
                .run_on_target(true)
                .target_vp(cpu)
                .build(&name);

            let event =
                PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;

            let source = DeviceInterruptSource::new();
            self.device
                .map_msix(msix, [event.get().as_fd()])
                .context("failed to map msix")?;

            // The interrupt's CPU affinity will be set by the task when it
            // starts. This can block the thread briefly, so it's better to do
            // it on the target CPU.
            let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
                .context("failed to find irq for msix")?;

            let target_cpu = Arc::new(AtomicU32::new(cpu));

            let interrupt = source.new_target();

            let task = driver.spawn(
                name,
                InterruptTask {
                    driver: driver.clone(),
                    target_cpu: target_cpu.clone(),
                    pci_id: self.pci_id.clone(),
                    msix,
                    irq,
                    event,
                    source,
                }
                .run(),
            );

            InterruptState {
                interrupt,
                target_cpu,
                _task: task,
            }
        };

        Ok(interrupt.insert(new_interrupt).interrupt.clone())
    }

    fn unmap_all_interrupts(&mut self) -> anyhow::Result<()> {
        if self.interrupts.is_empty() {
            return Ok(());
        }

        let count = self.interrupts.len() as u32;
        self.device
            .unmap_msix(0, count)
            .context("failed to unmap all msix vectors")?;

        // Clear local bookkeeping so re-mapping works correctly later.
        self.interrupts.clear();

        Ok(())
    }
}

struct InterruptTask {
    driver: VmTaskDriver,
    target_cpu: Arc<AtomicU32>,
    pci_id: Arc<str>,
    msix: u32,
    irq: u32,
    event: PolledWait<Event>,
    source: DeviceInterruptSource,
}

impl InterruptTask {
    async fn run(mut self) {
        let mut current_cpu = !0;
        loop {
            let next_cpu = self.target_cpu.load(Relaxed);
            let r = if next_cpu == current_cpu {
                self.event.wait().await
            } else {
                self.driver.retarget_vp(next_cpu);
                // Wait until the target CPU is ready before updating affinity,
                // since otherwise the CPU may not be online.
                enum Event {
                    TargetVpReady(()),
                    Interrupt(std::io::Result<()>),
                }
                match (
                    self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
                    self.event.wait().map(Event::Interrupt),
                )
                    .race()
                    .await
                {
                    Event::TargetVpReady(()) => {
                        if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
                            // This should only occur due to extremely low
                            // resources. However, it is not a fatal error; it
                            // just results in worse performance, so do not
                            // panic.
                            tracing::error!(
                                pci_id = self.pci_id.as_ref(),
                                msix = self.msix,
                                irq = self.irq,
                                error = &err as &dyn std::error::Error,
                                "failed to set irq affinity"
                            );
                        }
                        current_cpu = next_cpu;
                        continue;
                    }
                    Event::Interrupt(r) => {
                        // An interrupt arrived while waiting for the VP to be
                        // ready. Signal and loop around to try again.
                        r
                    }
                }
            };

            r.expect("wait cannot fail on eventfd");
            self.source.signal();
        }
    }
}

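/// Sets the CPU affinity of `irq` by writing to
/// `/proc/irq/<irq>/smp_affinity_list`.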
fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
    fs_err::write(
        format!("/proc/irq/{}/smp_affinity_list", irq),
        cpu.to_string(),
    )
}

impl DeviceRegisterIo for vfio_sys::MappedRegion {
    fn len(&self) -> usize {
        self.len()
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_u32(offset)
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_u64(offset)
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_u32(offset, data)
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_u64(offset, data)
    }
}

impl MappedRegionWithFallback {
    fn mapping<T>(&self, offset: usize) -> *mut T {
        assert!(
            offset <= self.mapping.len() - size_of::<T>() && offset.is_multiple_of(align_of::<T>())
        );
        if cfg!(feature = "mmio_simulate_fallback") {
            return std::ptr::NonNull::dangling().as_ptr();
        }
        // SAFETY: the offset is validated to be in bounds.
        unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
    }

    fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
    ) -> Result<T, trycopy::MemoryError> {
        // SAFETY: the offset is validated to be in bounds and aligned.
        unsafe { trycopy::try_read_volatile(self.mapping::<T>(offset)) }
    }

    fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
        data: T,
    ) -> Result<(), trycopy::MemoryError> {
        // SAFETY: the offset is validated to be in bounds and aligned.
        unsafe { trycopy::try_write_volatile(self.mapping::<T>(offset), &data) }
    }

    fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
        tracing::trace!(offset, n = buf.len(), "read");
        self.read_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .read_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }

    fn write_to_file(&self, offset: usize, buf: &[u8]) {
        tracing::trace!(offset, n = buf.len(), "write");
        self.write_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .write_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }
}

impl DeviceRegisterIo for MappedRegionWithFallback {
    fn len(&self) -> usize {
        self.len
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 4];
            self.read_from_file(offset, &mut buf);
            u32::from_ne_bytes(buf)
        })
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 8];
            self.read_from_file(offset, &mut buf);
            u64::from_ne_bytes(buf)
        })
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }
}

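/// Reset methods that can be requested for a PCI device via sysfs.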
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    NoReset,
    Acpi,
    Flr,
    AfFlr,
    Pm,
    Bus,
}

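/// Sets the preferred reset method for the PCI device `pci_id` by writing to
/// its sysfs `reset_method` attribute.
///
/// A minimal call sketch; the PCI ID is an illustrative placeholder:
///
/// ```ignore
/// vfio_set_device_reset_method("0000:00:00.0", PciDeviceResetMethod::Flr)?;
/// ```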
pub fn vfio_set_device_reset_method(
    pci_id: impl AsRef<str>,
    method: PciDeviceResetMethod,
) -> std::io::Result<()> {
    let reset_method = match method {
        PciDeviceResetMethod::NoReset => "\0".as_bytes(),
        PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
        PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
        PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
        PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
        PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
    };

    let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
        .iter()
        .collect();
    fs_err::write(path, reset_method)?;
    Ok(())
}