#![cfg(target_os = "linux")]
#![cfg(feature = "vfio")]

use crate::DeviceBacking;
use crate::DeviceRegisterIo;
use crate::DmaClient;
use crate::interrupt::DeviceInterrupt;
use crate::interrupt::DeviceInterruptSource;
use anyhow::Context;
use futures::FutureExt;
use futures_concurrency::future::Race;
use inspect::Inspect;
use inspect_counters::SharedCounter;
use pal_async::task::Spawn;
use pal_async::task::Task;
use pal_async::wait::PolledWait;
use pal_event::Event;
use std::os::fd::AsFd;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::Relaxed;
use std::time::Duration;
use uevent::UeventListener;
use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
use vfio_sys::IommuType;
use vfio_sys::IrqInfo;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

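/// The DMA client configuration for a [`VfioDevice`]: a persistent client, an
/// ephemeral client, or both.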
#[derive(Clone)]
pub enum VfioDmaClients {
    PersistentOnly(Arc<dyn DmaClient>),
    EphemeralOnly(Arc<dyn DmaClient>),
    Split {
        persistent: Arc<dyn DmaClient>,
        ephemeral: Arc<dyn DmaClient>,
    },
}

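/// A PCI device opened through the Linux VFIO framework (in no-IOMMU mode),
/// exposed to user-mode drivers via the [`DeviceBacking`] trait.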
#[derive(Inspect)]
pub struct VfioDevice {
    pci_id: Arc<str>,
    #[inspect(skip)]
    _container: vfio_sys::Container,
    #[inspect(skip)]
    _group: vfio_sys::Group,
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    msix_info: IrqInfo,
    #[inspect(skip)]
    driver_source: VmTaskDriverSource,
    #[inspect(iter_by_index)]
    interrupts: Vec<Option<InterruptState>>,
    #[inspect(skip)]
    config_space: vfio_sys::RegionInfo,
    #[inspect(skip)]
    dma_clients: VfioDmaClients,
}

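/// Per-MSI-X-vector state: the shared interrupt handle, the requested target
/// CPU, and the task servicing the interrupt eventfd.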
#[derive(Inspect)]
struct InterruptState {
    #[inspect(skip)]
    interrupt: DeviceInterrupt,
    target_cpu: Arc<AtomicU32>,
    #[inspect(skip)]
    _task: Task<()>,
}

impl Drop for VfioDevice {
    fn drop(&mut self) {
        tracing::trace!(pci_id = ?self.pci_id, "dropping vfio device");
    }
}

impl VfioDevice {
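    /// Opens the PCI device identified by `pci_id` through VFIO, without
    /// keep-alive.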
    pub async fn new(
        driver_source: &VmTaskDriverSource,
        pci_id: impl AsRef<str>,
        dma_clients: VfioDmaClients,
    ) -> anyhow::Result<Self> {
        Self::restore(driver_source, pci_id, false, dma_clients).await
    }

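    /// Opens the PCI device identified by `pci_id` through VFIO. When
    /// `keepalive` is set, the VFIO group is marked keep-alive before the
    /// device is opened, which is intended to preserve existing device state
    /// when restoring a previously opened device.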
    pub async fn restore(
        driver_source: &VmTaskDriverSource,
        pci_id: impl AsRef<str>,
        keepalive: bool,
        dma_clients: VfioDmaClients,
    ) -> anyhow::Result<Self> {
        let pci_id = pci_id.as_ref();
        let path = Path::new("/sys/bus/pci/devices").join(pci_id);

        let vmbus_device =
            std::fs::read_link(&path).context("failed to read link for pci device")?;
        let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
        let vfio_arrived_path = instance_path.join("vfio-dev");
        let uevent_listener = UeventListener::new(&driver_source.simple())?;
        let wait_for_vfio_device =
            uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
        let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
        let _ = ctx.until_cancelled(wait_for_vfio_device).await;

        tracing::info!(pci_id, keepalive, "device arrived");
        vfio_sys::print_relevant_params();

        let container = vfio_sys::Container::new()?;
        let group_id = vfio_sys::Group::find_group_for_device(&path)?;
        let group = vfio_sys::Group::open_noiommu(group_id)?;
        group.set_container(&container)?;
        if !group.status()?.viable() {
            anyhow::bail!("group is not viable");
        }

        let driver = driver_source.simple();
        container.set_iommu(IommuType::NoIommu)?;
        if keepalive {
            group.set_keep_alive(pci_id, &driver).await?;
        }
        tracing::debug!(pci_id, "about to open device");
        let device = group.open_device(pci_id, &driver).await?;
        let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
        if msix_info.flags.noresize() {
            anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
        }

        let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
        let this = Self {
            pci_id: pci_id.into(),
            _container: container,
            _group: group,
            device: Arc::new(device),
            msix_info,
            config_space,
            driver_source: driver_source.clone(),
            interrupts: Vec::new(),
            dma_clients,
        };

        tracing::debug!(pci_id, "enabling device...");
        this.enable_device()
            .with_context(|| format!("failed to enable device {pci_id}"))?;
        Ok(this)
    }

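    /// Enables the device by setting the bus master and memory space enable
    /// bits (and disabling INTx) in the PCI command register.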
    fn enable_device(&self) -> anyhow::Result<()> {
        let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
        let status_command = self.read_config(offset)?;
        let command = pci_core::spec::cfg_space::Command::from(status_command as u16);

        let command = command
            .with_bus_master(true)
            .with_intx_disable(true)
            .with_mmio_enabled(true);

        let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
        self.write_config(offset, status_command)?;
        Ok(())
    }

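    /// Reads a 32-bit value from the device's PCI configuration space at
    /// `offset`.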
    pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        let mut buf = [0u8; 4];
        self.device
            .as_ref()
            .as_ref()
            .read_at(&mut buf, self.config_space.offset + offset as u64)
            .context("failed to read config")?;

        Ok(u32::from_ne_bytes(buf))
    }

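    /// Writes a 32-bit value to the device's PCI configuration space at
    /// `offset`.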
    pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
        if offset as u64 > self.config_space.size - 4 {
            anyhow::bail!("invalid config offset");
        }

        tracing::trace!(pci_id = ?self.pci_id, offset, data, "writing config");
        let buf = data.to_ne_bytes();
        self.device
            .as_ref()
            .as_ref()
            .write_at(&buf, self.config_space.offset + offset as u64)
            .context("failed to write config")?;

        Ok(())
    }

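    /// Maps BAR `n` and wraps it in a [`MappedRegionWithFallback`], which can
    /// fall back to file-based access if the mapping faults.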
    fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
        if n >= 6 {
            anyhow::bail!("invalid bar");
        }
        let info = self.device.region_info(n.into())?;
        let mapping = self.device.map(info.offset, info.size as usize, true)?;
        trycopy::initialize_try_copy();
        Ok(MappedRegionWithFallback {
            device: self.device.clone(),
            mapping,
            len: info.size as usize,
            offset: info.offset,
            read_fallback: SharedCounter::new(),
            write_fallback: SharedCounter::new(),
        })
    }
}

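/// A mapped BAR region that falls back to reading and writing through the
/// VFIO device file when the memory-mapped access faults, counting how many
/// times each fallback path is taken.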
#[derive(Inspect)]
pub struct MappedRegionWithFallback {
    #[inspect(skip)]
    device: Arc<vfio_sys::Device>,
    #[inspect(skip)]
    mapping: vfio_sys::MappedRegion,
    offset: u64,
    len: usize,
    read_fallback: SharedCounter,
    write_fallback: SharedCounter,
}

impl DeviceBacking for VfioDevice {
    type Registers = MappedRegionWithFallback;

    fn id(&self) -> &str {
        &self.pci_id
    }

    fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
        (*self).map_bar(n)
    }

    fn dma_client(&self) -> Arc<dyn DmaClient> {
        match &self.dma_clients {
            VfioDmaClients::EphemeralOnly(client) => client.clone(),
            VfioDmaClients::PersistentOnly(client) => client.clone(),
            VfioDmaClients::Split {
                persistent,
                ephemeral: _,
            } => persistent.clone(),
        }
    }

    fn dma_client_for(&self, pool: crate::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
        match &self.dma_clients {
            VfioDmaClients::PersistentOnly(client) => match pool {
                crate::DmaPool::Persistent => Ok(client.clone()),
                crate::DmaPool::Ephemeral => {
                    anyhow::bail!(
                        "ephemeral dma pool requested but only persistent client available"
                    )
                }
            },
            VfioDmaClients::EphemeralOnly(client) => match pool {
                crate::DmaPool::Ephemeral => Ok(client.clone()),
                crate::DmaPool::Persistent => {
                    anyhow::bail!(
                        "persistent dma pool requested but only ephemeral client available"
                    )
                }
            },
            VfioDmaClients::Split {
                persistent,
                ephemeral,
            } => match pool {
                crate::DmaPool::Persistent => Ok(persistent.clone()),
                crate::DmaPool::Ephemeral => Ok(ephemeral.clone()),
            },
        }
    }

    fn max_interrupt_count(&self) -> u32 {
        self.msix_info.count
    }

    fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
        if msix >= self.msix_info.count {
            anyhow::bail!("invalid msix index");
        }
        if self.interrupts.len() <= msix as usize {
            self.interrupts.resize_with(msix as usize + 1, || None);
        }

        let interrupt = &mut self.interrupts[msix as usize];
        if let Some(interrupt) = interrupt {
            if interrupt.target_cpu.load(Relaxed) != cpu {
                interrupt.target_cpu.store(cpu, Relaxed);
            }
            return Ok(interrupt.interrupt.clone());
        }

        let new_interrupt = {
            let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
            let driver = self
                .driver_source
                .builder()
                .run_on_target(true)
                .target_vp(cpu)
                .build(&name);

            let event =
                PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;

            let source = DeviceInterruptSource::new();
            self.device
                .map_msix(msix, [event.get().as_fd()])
                .context("failed to map msix")?;

            let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
                .context("failed to find irq for msix")?;

            let target_cpu = Arc::new(AtomicU32::new(cpu));

            let interrupt = source.new_target();

            let task = driver.spawn(
                name,
                InterruptTask {
                    driver: driver.clone(),
                    target_cpu: target_cpu.clone(),
                    pci_id: self.pci_id.clone(),
                    msix,
                    irq,
                    event,
                    source,
                }
                .run(),
            );

            InterruptState {
                interrupt,
                target_cpu,
                _task: task,
            }
        };

        Ok(interrupt.insert(new_interrupt).interrupt.clone())
    }

    fn unmap_all_interrupts(&mut self) -> anyhow::Result<()> {
        if self.interrupts.is_empty() {
            return Ok(());
        }

        let count = self.interrupts.len() as u32;
        self.device
            .unmap_msix(0, count)
            .context("failed to unmap all msix vectors")?;

        self.interrupts.clear();

        Ok(())
    }
}

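/// A task that waits for MSI-X interrupt events and signals the corresponding
/// [`DeviceInterruptSource`], updating the IRQ's affinity when the requested
/// target CPU changes.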
struct InterruptTask {
    driver: VmTaskDriver,
    target_cpu: Arc<AtomicU32>,
    pci_id: Arc<str>,
    msix: u32,
    irq: u32,
    event: PolledWait<Event>,
    source: DeviceInterruptSource,
}

impl InterruptTask {
    async fn run(mut self) {
        let mut current_cpu = !0;
        loop {
            let next_cpu = self.target_cpu.load(Relaxed);
            let r = if next_cpu == current_cpu {
                self.event.wait().await
            } else {
                self.driver.retarget_vp(next_cpu);
                enum Event {
                    TargetVpReady(()),
                    Interrupt(std::io::Result<()>),
                }
                match (
                    self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
                    self.event.wait().map(Event::Interrupt),
                )
                    .race()
                    .await
                {
                    Event::TargetVpReady(()) => {
                        if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
                            tracing::error!(
                                pci_id = self.pci_id.as_ref(),
                                msix = self.msix,
                                irq = self.irq,
                                error = &err as &dyn std::error::Error,
                                "failed to set irq affinity"
                            );
                        }
                        current_cpu = next_cpu;
                        continue;
                    }
                    Event::Interrupt(r) => r,
                }
            };

            r.expect("wait cannot fail on eventfd");
            self.source.signal();
        }
    }
}

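/// Sets the affinity of `irq` to `cpu` by writing
/// `/proc/irq/<irq>/smp_affinity_list`.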
fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
    fs_err::write(
        format!("/proc/irq/{}/smp_affinity_list", irq),
        cpu.to_string(),
    )
}

impl DeviceRegisterIo for vfio_sys::MappedRegion {
    fn len(&self) -> usize {
        self.len()
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_u32(offset)
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_u64(offset)
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_u32(offset, data)
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_u64(offset, data)
    }
}

impl MappedRegionWithFallback {
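    /// Returns a pointer to a `T` at `offset` within the mapping, asserting
    /// that the access is in bounds and aligned. With the
    /// `mmio_simulate_fallback` feature, returns a dangling pointer so that
    /// accesses fault and exercise the file-based fallback path.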
    fn mapping<T>(&self, offset: usize) -> *mut T {
        assert!(
            offset <= self.mapping.len() - size_of::<T>() && offset.is_multiple_of(align_of::<T>())
        );
        if cfg!(feature = "mmio_simulate_fallback") {
            return std::ptr::NonNull::dangling().as_ptr();
        }
        unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
    }

    fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
    ) -> Result<T, trycopy::MemoryError> {
        unsafe { trycopy::try_read_volatile(self.mapping::<T>(offset)) }
    }

    fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
        data: T,
    ) -> Result<(), trycopy::MemoryError> {
        unsafe { trycopy::try_write_volatile(self.mapping::<T>(offset), &data) }
    }

    fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
        tracing::trace!(offset, n = buf.len(), "read");
        self.read_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .read_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }

    fn write_to_file(&self, offset: usize, buf: &[u8]) {
        tracing::trace!(offset, n = buf.len(), "write");
        self.write_fallback.increment();
        let n = self
            .device
            .as_ref()
            .as_ref()
            .write_at(buf, self.offset + offset as u64)
            .expect("valid mapping");
        assert_eq!(n, buf.len());
    }
}

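// Register accesses try the memory mapping first and fall back to going
// through the device file if the mapped access fails.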
impl DeviceRegisterIo for MappedRegionWithFallback {
    fn len(&self) -> usize {
        self.len
    }

    fn read_u32(&self, offset: usize) -> u32 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 4];
            self.read_from_file(offset, &mut buf);
            u32::from_ne_bytes(buf)
        })
    }

    fn read_u64(&self, offset: usize) -> u64 {
        self.read_from_mapping(offset).unwrap_or_else(|_| {
            let mut buf = [0u8; 8];
            self.read_from_file(offset, &mut buf);
            u64::from_ne_bytes(buf)
        })
    }

    fn write_u32(&self, offset: usize, data: u32) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }

    fn write_u64(&self, offset: usize, data: u64) {
        self.write_to_mapping(offset, data).unwrap_or_else(|_| {
            self.write_to_file(offset, &data.to_ne_bytes());
        })
    }
}

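/// A PCI device reset method, matching the names accepted by the kernel's
/// per-device `reset_method` sysfs attribute.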
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    NoReset,
    Acpi,
    Flr,
    AfFlr,
    Pm,
    Bus,
}

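/// Selects the reset method the kernel should use for the device by writing
/// to `/sys/bus/pci/devices/<pci_id>/reset_method`.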
pub fn vfio_set_device_reset_method(
    pci_id: impl AsRef<str>,
    method: PciDeviceResetMethod,
) -> std::io::Result<()> {
    let reset_method = match method {
        PciDeviceResetMethod::NoReset => "\0".as_bytes(),
        PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
        PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
        PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
        PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
        PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
    };

    let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
        .iter()
        .collect();
    fs_err::write(path, reset_method)?;
    Ok(())
}