1#![cfg(target_os = "linux")]
7#![cfg(feature = "vfio")]
8
9use crate::DeviceBacking;
10use crate::DeviceRegisterIo;
11use crate::DmaClient;
12use crate::interrupt::DeviceInterrupt;
13use crate::interrupt::DeviceInterruptSource;
14use anyhow::Context;
15use futures::FutureExt;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect_counters::SharedCounter;
19use nix::errno::Errno;
20use pal_async::task::Spawn;
21use pal_async::task::Task;
22use pal_async::wait::PolledWait;
23use pal_event::Event;
24use std::os::fd::AsFd;
25use std::os::unix::fs::FileExt;
26use std::path::Path;
27use std::sync::Arc;
28use std::sync::atomic::AtomicU32;
29use std::sync::atomic::Ordering::Relaxed;
30use std::time::Duration;
31use uevent::UeventListener;
32use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
33use vfio_sys::IommuType;
34use vfio_sys::IrqInfo;
35use vmcore::vm_task::VmTaskDriver;
36use vmcore::vm_task::VmTaskDriverSource;
37use zerocopy::FromBytes;
38use zerocopy::Immutable;
39use zerocopy::IntoBytes;
40use zerocopy::KnownLayout;
41
42#[derive(Clone)]
43pub enum VfioDmaClients {
44 PersistentOnly(Arc<dyn DmaClient>),
45 EphemeralOnly(Arc<dyn DmaClient>),
46 Split {
47 persistent: Arc<dyn DmaClient>,
48 ephemeral: Arc<dyn DmaClient>,
49 },
50}
51
52#[derive(Inspect)]
54pub struct VfioDevice {
55 pci_id: Arc<str>,
56 #[inspect(skip)]
57 _container: vfio_sys::Container,
58 #[inspect(skip)]
59 _group: vfio_sys::Group,
60 #[inspect(skip)]
61 device: Arc<vfio_sys::Device>,
62 #[inspect(skip)]
63 msix_info: IrqInfo,
64 #[inspect(skip)]
65 driver_source: VmTaskDriverSource,
66 #[inspect(iter_by_index)]
67 interrupts: Vec<Option<InterruptState>>,
68 #[inspect(skip)]
69 config_space: vfio_sys::RegionInfo,
70 #[inspect(skip)]
71 dma_clients: VfioDmaClients,
72}
73
74#[derive(Inspect)]
75struct InterruptState {
76 #[inspect(skip)]
77 interrupt: DeviceInterrupt,
78 target_cpu: Arc<AtomicU32>,
79 #[inspect(skip)]
80 _task: Task<()>,
81}
82
83impl Drop for VfioDevice {
84 fn drop(&mut self) {
85 tracing::trace!(pci_id = ?self.pci_id, "dropping vfio device");
87 }
88}
89
90impl VfioDevice {
91 pub async fn new(
93 driver_source: &VmTaskDriverSource,
94 pci_id: impl AsRef<str>,
95 dma_clients: VfioDmaClients,
96 ) -> anyhow::Result<Self> {
97 Self::restore(driver_source, pci_id, false, dma_clients).await
98 }
99
100 pub async fn restore(
103 driver_source: &VmTaskDriverSource,
104 pci_id: impl AsRef<str>,
105 keepalive: bool,
106 dma_clients: VfioDmaClients,
107 ) -> anyhow::Result<Self> {
108 let pci_id = pci_id.as_ref();
109 let path = Path::new("/sys/bus/pci/devices").join(pci_id);
110
111 let vmbus_device =
114 std::fs::read_link(&path).context("failed to read link for pci device")?;
115 let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
116 let vfio_arrived_path = instance_path.join("vfio-dev");
117 let uevent_listener = UeventListener::new(&driver_source.simple())?;
118 let wait_for_vfio_device =
119 uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
120 let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
121 let _ = ctx.until_cancelled(wait_for_vfio_device).await;
123
124 tracing::info!(pci_id, keepalive, "device arrived");
125 vfio_sys::print_relevant_params();
126
127 let driver = driver_source.simple();
128 let retry = vfio_sys::VfioRetry::new(&driver, pci_id);
129
130 let is_not_found = |e: &anyhow::Error| {
131 e.chain().any(|cause| {
132 cause
133 .downcast_ref::<std::io::Error>()
134 .is_some_and(|io_err| io_err.kind() == std::io::ErrorKind::NotFound)
135 })
136 };
137 let is_enodev = |e: &anyhow::Error| {
138 e.chain().any(|cause| {
139 cause
140 .downcast_ref::<Errno>()
141 .is_some_and(|e| *e == Errno::ENODEV)
142 })
143 };
144
145 let container = vfio_sys::Container::new()?;
146 let group_id = retry
147 .retry(
148 || vfio_sys::Group::find_group_for_device(&path),
149 &is_not_found,
150 "find_group_for_device",
151 )
152 .await?;
153 let group = retry
154 .retry(
155 || vfio_sys::Group::open_noiommu(group_id),
156 &is_not_found,
157 "open_noiommu",
158 )
159 .await?;
160 group.set_container(&container)?;
161 if !group.status()?.viable() {
162 anyhow::bail!("group is not viable");
163 }
164
165 container.set_iommu(IommuType::NoIommu)?;
166 if keepalive {
167 retry
168 .retry(
169 || group.set_keep_alive(pci_id),
170 &is_enodev,
171 "set_keep_alive",
172 )
173 .await?;
174 }
175 tracing::debug!(pci_id, "about to open device");
176 let device = retry
177 .retry(|| group.open_device(pci_id), &is_enodev, "open_device")
178 .await?;
179 let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
180 if msix_info.flags.noresize() {
181 anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
182 }
183
184 let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
185 let this = Self {
186 pci_id: pci_id.into(),
187 _container: container,
188 _group: group,
189 device: Arc::new(device),
190 msix_info,
191 config_space,
192 driver_source: driver_source.clone(),
193 interrupts: Vec::new(),
194 dma_clients,
195 };
196
197 tracing::debug!(pci_id, "enabling device...");
198 this.enable_device()
201 .with_context(|| format!("failed to enable device {pci_id}"))?;
202 Ok(this)
203 }
204
205 fn enable_device(&self) -> anyhow::Result<()> {
206 let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
207 let status_command = self.read_config(offset)?;
208 let command = pci_core::spec::cfg_space::Command::from(status_command as u16);
209
210 let command = command
211 .with_bus_master(true)
212 .with_intx_disable(true)
213 .with_mmio_enabled(true);
214
215 let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
216 self.write_config(offset, status_command)?;
217 Ok(())
218 }
219
220 pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
221 if offset as u64 > self.config_space.size - 4 {
222 anyhow::bail!("invalid config offset");
223 }
224
225 let mut buf = [0u8; 4];
226 self.device
227 .as_ref()
228 .as_ref()
229 .read_at(&mut buf, self.config_space.offset + offset as u64)
230 .context("failed to read config")?;
231
232 Ok(u32::from_ne_bytes(buf))
233 }
234
235 pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
236 if offset as u64 > self.config_space.size - 4 {
237 anyhow::bail!("invalid config offset");
238 }
239
240 tracing::trace!(pci_id = ?self.pci_id, offset, data, "writing config");
241 let buf = data.to_ne_bytes();
242 self.device
243 .as_ref()
244 .as_ref()
245 .write_at(&buf, self.config_space.offset + offset as u64)
246 .context("failed to write config")?;
247
248 Ok(())
249 }
250
251 fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
253 if n >= 6 {
254 anyhow::bail!("invalid bar");
255 }
256 let info = self.device.region_info(n.into())?;
257 let mapping = self.device.map(info.offset, info.size as usize, true)?;
258 trycopy::initialize_try_copy();
259 Ok(MappedRegionWithFallback {
260 device: self.device.clone(),
261 mapping,
262 len: info.size as usize,
263 offset: info.offset,
264 read_fallback: SharedCounter::new(),
265 write_fallback: SharedCounter::new(),
266 })
267 }
268}
269
270#[derive(Inspect)]
276pub struct MappedRegionWithFallback {
277 #[inspect(skip)]
278 device: Arc<vfio_sys::Device>,
279 #[inspect(skip)]
280 mapping: vfio_sys::MappedRegion,
281 offset: u64,
282 len: usize,
283 read_fallback: SharedCounter,
284 write_fallback: SharedCounter,
285}
286
287impl DeviceBacking for VfioDevice {
288 type Registers = MappedRegionWithFallback;
289
290 fn id(&self) -> &str {
291 &self.pci_id
292 }
293
294 fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
295 (*self).map_bar(n)
296 }
297
298 fn dma_client(&self) -> Arc<dyn DmaClient> {
299 match &self.dma_clients {
302 VfioDmaClients::EphemeralOnly(client) => client.clone(),
303 VfioDmaClients::PersistentOnly(client) => client.clone(),
304 VfioDmaClients::Split {
305 persistent,
306 ephemeral: _,
307 } => persistent.clone(),
308 }
309 }
310
311 fn dma_client_for(&self, pool: crate::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
312 match &self.dma_clients {
313 VfioDmaClients::PersistentOnly(client) => match pool {
314 crate::DmaPool::Persistent => Ok(client.clone()),
315 crate::DmaPool::Ephemeral => {
316 anyhow::bail!(
317 "ephemeral dma pool requested but only persistent client available"
318 )
319 }
320 },
321 VfioDmaClients::EphemeralOnly(client) => match pool {
322 crate::DmaPool::Ephemeral => Ok(client.clone()),
323 crate::DmaPool::Persistent => {
324 anyhow::bail!(
325 "persistent dma pool requested but only ephemeral client available"
326 )
327 }
328 },
329 VfioDmaClients::Split {
330 persistent,
331 ephemeral,
332 } => match pool {
333 crate::DmaPool::Persistent => Ok(persistent.clone()),
334 crate::DmaPool::Ephemeral => Ok(ephemeral.clone()),
335 },
336 }
337 }
338
339 fn max_interrupt_count(&self) -> u32 {
340 self.msix_info.count
341 }
342
343 fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
344 if msix >= self.msix_info.count {
345 anyhow::bail!("invalid msix index");
346 }
347 if self.interrupts.len() <= msix as usize {
348 self.interrupts.resize_with(msix as usize + 1, || None);
349 }
350
351 let interrupt = &mut self.interrupts[msix as usize];
352 if let Some(interrupt) = interrupt {
353 if interrupt.target_cpu.load(Relaxed) != cpu {
356 interrupt.target_cpu.store(cpu, Relaxed);
357 }
358 return Ok(interrupt.interrupt.clone());
359 }
360
361 let new_interrupt = {
362 let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
363 let driver = self
364 .driver_source
365 .builder()
366 .run_on_target(true)
367 .target_vp(cpu)
368 .build(&name);
369
370 let event =
371 PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;
372
373 let source = DeviceInterruptSource::new();
374 self.device
375 .map_msix(msix, [event.get().as_fd()])
376 .context("failed to map msix")?;
377
378 let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
382 .context("failed to find irq for msix")?;
383
384 let target_cpu = Arc::new(AtomicU32::new(cpu));
385
386 let interrupt = source.new_target();
387
388 let task = driver.spawn(
389 name,
390 InterruptTask {
391 driver: driver.clone(),
392 target_cpu: target_cpu.clone(),
393 pci_id: self.pci_id.clone(),
394 msix,
395 irq,
396 event,
397 source,
398 }
399 .run(),
400 );
401
402 InterruptState {
403 interrupt,
404 target_cpu,
405 _task: task,
406 }
407 };
408
409 Ok(interrupt.insert(new_interrupt).interrupt.clone())
410 }
411
412 fn unmap_all_interrupts(&mut self) -> anyhow::Result<()> {
413 if self.interrupts.is_empty() {
414 return Ok(());
415 }
416
417 let count = self.interrupts.len() as u32;
418 self.device
419 .unmap_msix(0, count)
420 .context("failed to unmap all msix vectors")?;
421
422 self.interrupts.clear();
424
425 Ok(())
426 }
427}
428
429struct InterruptTask {
430 driver: VmTaskDriver,
431 target_cpu: Arc<AtomicU32>,
432 pci_id: Arc<str>,
433 msix: u32,
434 irq: u32,
435 event: PolledWait<Event>,
436 source: DeviceInterruptSource,
437}
438
439impl InterruptTask {
440 async fn run(mut self) {
441 let mut current_cpu = !0;
442 loop {
443 let next_cpu = self.target_cpu.load(Relaxed);
444 let r = if next_cpu == current_cpu {
445 self.event.wait().await
446 } else {
447 self.driver.retarget_vp(next_cpu);
448 enum Event {
451 TargetVpReady(()),
452 Interrupt(std::io::Result<()>),
453 }
454 match (
455 self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
456 self.event.wait().map(Event::Interrupt),
457 )
458 .race()
459 .await
460 {
461 Event::TargetVpReady(()) => {
462 if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
463 tracing::error!(
467 pci_id = self.pci_id.as_ref(),
468 msix = self.msix,
469 irq = self.irq,
470 error = &err as &dyn std::error::Error,
471 "failed to set irq affinity"
472 );
473 }
474 current_cpu = next_cpu;
475 continue;
476 }
477 Event::Interrupt(r) => {
478 r
481 }
482 }
483 };
484
485 r.expect("wait cannot fail on eventfd");
486 self.source.signal();
487 }
488 }
489}
490
491fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
492 fs_err::write(
493 format!("/proc/irq/{}/smp_affinity_list", irq),
494 cpu.to_string(),
495 )
496}
497
498impl DeviceRegisterIo for vfio_sys::MappedRegion {
499 fn len(&self) -> usize {
500 self.len()
501 }
502
503 fn read_u32(&self, offset: usize) -> u32 {
504 self.read_u32(offset)
505 }
506
507 fn read_u64(&self, offset: usize) -> u64 {
508 self.read_u64(offset)
509 }
510
511 fn write_u32(&self, offset: usize, data: u32) {
512 self.write_u32(offset, data)
513 }
514
515 fn write_u64(&self, offset: usize, data: u64) {
516 self.write_u64(offset, data)
517 }
518}
519
520impl MappedRegionWithFallback {
521 fn mapping<T>(&self, offset: usize) -> *mut T {
522 assert!(
523 offset <= self.mapping.len() - size_of::<T>() && offset.is_multiple_of(align_of::<T>())
524 );
525 if cfg!(feature = "mmio_simulate_fallback") {
526 return std::ptr::NonNull::dangling().as_ptr();
527 }
528 unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
530 }
531
532 fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
533 &self,
534 offset: usize,
535 ) -> Result<T, trycopy::MemoryError> {
536 unsafe { trycopy::try_read_volatile(self.mapping::<T>(offset)) }
538 }
539
540 fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
541 &self,
542 offset: usize,
543 data: T,
544 ) -> Result<(), trycopy::MemoryError> {
545 unsafe { trycopy::try_write_volatile(self.mapping::<T>(offset), &data) }
547 }
548
549 fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
550 tracing::trace!(offset, n = buf.len(), "read");
551 self.read_fallback.increment();
552 let n = self
553 .device
554 .as_ref()
555 .as_ref()
556 .read_at(buf, self.offset + offset as u64)
557 .expect("valid mapping");
558 assert_eq!(n, buf.len());
559 }
560
561 fn write_to_file(&self, offset: usize, buf: &[u8]) {
562 tracing::trace!(offset, n = buf.len(), "write");
563 self.write_fallback.increment();
564 let n = self
565 .device
566 .as_ref()
567 .as_ref()
568 .write_at(buf, self.offset + offset as u64)
569 .expect("valid mapping");
570 assert_eq!(n, buf.len());
571 }
572}
573
574impl DeviceRegisterIo for MappedRegionWithFallback {
575 fn len(&self) -> usize {
576 self.len
577 }
578
579 fn read_u32(&self, offset: usize) -> u32 {
580 self.read_from_mapping(offset).unwrap_or_else(|_| {
581 let mut buf = [0u8; 4];
582 self.read_from_file(offset, &mut buf);
583 u32::from_ne_bytes(buf)
584 })
585 }
586
587 fn read_u64(&self, offset: usize) -> u64 {
588 self.read_from_mapping(offset).unwrap_or_else(|_| {
589 let mut buf = [0u8; 8];
590 self.read_from_file(offset, &mut buf);
591 u64::from_ne_bytes(buf)
592 })
593 }
594
595 fn write_u32(&self, offset: usize, data: u32) {
596 self.write_to_mapping(offset, data).unwrap_or_else(|_| {
597 self.write_to_file(offset, &data.to_ne_bytes());
598 })
599 }
600
601 fn write_u64(&self, offset: usize, data: u64) {
602 self.write_to_mapping(offset, data).unwrap_or_else(|_| {
603 self.write_to_file(offset, &data.to_ne_bytes());
604 })
605 }
606}
607
/// The reset method to configure for a PCI device via
/// [`vfio_set_device_reset_method`].
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    /// Configures an empty reset method string.
    NoReset,
    /// Configures the `acpi` reset method.
    Acpi,
    /// Configures the `flr` reset method.
    Flr,
    /// Configures the `af_flr` reset method.
    AfFlr,
    /// Configures the `pm` reset method.
    Pm,
    /// Configures the `bus` reset method.
    Bus,
}
617
618pub fn vfio_set_device_reset_method(
619 pci_id: impl AsRef<str>,
620 method: PciDeviceResetMethod,
621) -> std::io::Result<()> {
622 let reset_method = match method {
623 PciDeviceResetMethod::NoReset => "\0".as_bytes(),
624 PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
625 PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
626 PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
627 PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
628 PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
629 };
630
631 let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
632 .iter()
633 .collect();
634 fs_err::write(path, reset_method)?;
635 Ok(())
636}