1#![cfg(target_os = "linux")]
7#![cfg(feature = "vfio")]
8
9use crate::DeviceBacking;
10use crate::DeviceRegisterIo;
11use crate::DmaClient;
12use crate::interrupt::DeviceInterrupt;
13use crate::interrupt::DeviceInterruptSource;
14use anyhow::Context;
15use futures::FutureExt;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect_counters::SharedCounter;
19use pal_async::task::Spawn;
20use pal_async::task::Task;
21use pal_async::wait::PolledWait;
22use pal_event::Event;
23use std::os::fd::AsFd;
24use std::os::unix::fs::FileExt;
25use std::path::Path;
26use std::sync::Arc;
27use std::sync::atomic::AtomicU32;
28use std::sync::atomic::Ordering::Relaxed;
29use std::time::Duration;
30use uevent::UeventListener;
31use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
32use vfio_sys::IommuType;
33use vfio_sys::IrqInfo;
34use vmcore::vm_task::VmTaskDriver;
35use vmcore::vm_task::VmTaskDriverSource;
36use zerocopy::FromBytes;
37use zerocopy::Immutable;
38use zerocopy::IntoBytes;
39use zerocopy::KnownLayout;
40
41#[derive(Clone)]
42pub enum VfioDmaClients {
43 PersistentOnly(Arc<dyn DmaClient>),
44 EphemeralOnly(Arc<dyn DmaClient>),
45 Split {
46 persistent: Arc<dyn DmaClient>,
47 ephemeral: Arc<dyn DmaClient>,
48 },
49}
50
51#[derive(Inspect)]
53pub struct VfioDevice {
54 pci_id: Arc<str>,
55 #[inspect(skip)]
56 _container: vfio_sys::Container,
57 #[inspect(skip)]
58 _group: vfio_sys::Group,
59 #[inspect(skip)]
60 device: Arc<vfio_sys::Device>,
61 #[inspect(skip)]
62 msix_info: IrqInfo,
63 #[inspect(skip)]
64 driver_source: VmTaskDriverSource,
65 #[inspect(iter_by_index)]
66 interrupts: Vec<Option<InterruptState>>,
67 #[inspect(skip)]
68 config_space: vfio_sys::RegionInfo,
69 #[inspect(skip)]
70 dma_clients: VfioDmaClients,
71}
72
73#[derive(Inspect)]
74struct InterruptState {
75 #[inspect(skip)]
76 interrupt: DeviceInterrupt,
77 target_cpu: Arc<AtomicU32>,
78 #[inspect(skip)]
79 _task: Task<()>,
80}
81
82impl Drop for VfioDevice {
83 fn drop(&mut self) {
84 tracing::trace!(pci_id = ?self.pci_id, "dropping vfio device");
86 }
87}
88
89impl VfioDevice {
90 pub async fn new(
92 driver_source: &VmTaskDriverSource,
93 pci_id: impl AsRef<str>,
94 dma_clients: VfioDmaClients,
95 ) -> anyhow::Result<Self> {
96 Self::restore(driver_source, pci_id, false, dma_clients).await
97 }
98
99 pub async fn restore(
102 driver_source: &VmTaskDriverSource,
103 pci_id: impl AsRef<str>,
104 keepalive: bool,
105 dma_clients: VfioDmaClients,
106 ) -> anyhow::Result<Self> {
107 let pci_id = pci_id.as_ref();
108 let path = Path::new("/sys/bus/pci/devices").join(pci_id);
109
110 let vmbus_device =
113 std::fs::read_link(&path).context("failed to read link for pci device")?;
114 let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
115 let vfio_arrived_path = instance_path.join("vfio-dev");
116 let uevent_listener = UeventListener::new(&driver_source.simple())?;
117 let wait_for_vfio_device =
118 uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
119 let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
120 let _ = ctx.until_cancelled(wait_for_vfio_device).await;
122
123 tracing::info!(pci_id, keepalive, "device arrived");
124
125 let container = vfio_sys::Container::new()?;
126 let group_id = vfio_sys::Group::find_group_for_device(&path)?;
127 let group = vfio_sys::Group::open_noiommu(group_id)?;
128 group.set_container(&container)?;
129 if !group.status()?.viable() {
130 anyhow::bail!("group is not viable");
131 }
132
133 let driver = driver_source.simple();
134 container.set_iommu(IommuType::NoIommu)?;
135 if keepalive {
136 group.set_keep_alive(pci_id, &driver).await?;
138 }
139 tracing::debug!(pci_id, "about to open device");
140 let device = group.open_device(pci_id, &driver).await?;
141 let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
142 if msix_info.flags.noresize() {
143 anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
144 }
145
146 let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
147 let this = Self {
148 pci_id: pci_id.into(),
149 _container: container,
150 _group: group,
151 device: Arc::new(device),
152 msix_info,
153 config_space,
154 driver_source: driver_source.clone(),
155 interrupts: Vec::new(),
156 dma_clients,
157 };
158
159 tracing::debug!(pci_id, "enabling device...");
160 this.enable_device()
163 .with_context(|| format!("failed to enable device {pci_id}"))?;
164 Ok(this)
165 }
166
167 fn enable_device(&self) -> anyhow::Result<()> {
168 let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
169 let status_command = self.read_config(offset)?;
170 let command = pci_core::spec::cfg_space::Command::from(status_command as u16);
171
172 let command = command
173 .with_bus_master(true)
174 .with_intx_disable(true)
175 .with_mmio_enabled(true);
176
177 let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
178 self.write_config(offset, status_command)?;
179 Ok(())
180 }
181
182 pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
183 if offset as u64 > self.config_space.size - 4 {
184 anyhow::bail!("invalid config offset");
185 }
186
187 let mut buf = [0u8; 4];
188 self.device
189 .as_ref()
190 .as_ref()
191 .read_at(&mut buf, self.config_space.offset + offset as u64)
192 .context("failed to read config")?;
193
194 Ok(u32::from_ne_bytes(buf))
195 }
196
197 pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
198 if offset as u64 > self.config_space.size - 4 {
199 anyhow::bail!("invalid config offset");
200 }
201
202 tracing::trace!(pci_id = ?self.pci_id, offset, data, "writing config");
203 let buf = data.to_ne_bytes();
204 self.device
205 .as_ref()
206 .as_ref()
207 .write_at(&buf, self.config_space.offset + offset as u64)
208 .context("failed to write config")?;
209
210 Ok(())
211 }
212
213 fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
215 if n >= 6 {
216 anyhow::bail!("invalid bar");
217 }
218 let info = self.device.region_info(n.into())?;
219 let mapping = self.device.map(info.offset, info.size as usize, true)?;
220 trycopy::initialize_try_copy();
221 Ok(MappedRegionWithFallback {
222 device: self.device.clone(),
223 mapping,
224 len: info.size as usize,
225 offset: info.offset,
226 read_fallback: SharedCounter::new(),
227 write_fallback: SharedCounter::new(),
228 })
229 }
230}
231
232#[derive(Inspect)]
238pub struct MappedRegionWithFallback {
239 #[inspect(skip)]
240 device: Arc<vfio_sys::Device>,
241 #[inspect(skip)]
242 mapping: vfio_sys::MappedRegion,
243 offset: u64,
244 len: usize,
245 read_fallback: SharedCounter,
246 write_fallback: SharedCounter,
247}
248
249impl DeviceBacking for VfioDevice {
250 type Registers = MappedRegionWithFallback;
251
252 fn id(&self) -> &str {
253 &self.pci_id
254 }
255
256 fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
257 (*self).map_bar(n)
258 }
259
260 fn dma_client(&self) -> Arc<dyn DmaClient> {
261 match &self.dma_clients {
264 VfioDmaClients::EphemeralOnly(client) => client.clone(),
265 VfioDmaClients::PersistentOnly(client) => client.clone(),
266 VfioDmaClients::Split {
267 persistent,
268 ephemeral: _,
269 } => persistent.clone(),
270 }
271 }
272
273 fn dma_client_for(&self, pool: crate::DmaPool) -> anyhow::Result<Arc<dyn DmaClient>> {
274 match &self.dma_clients {
275 VfioDmaClients::PersistentOnly(client) => match pool {
276 crate::DmaPool::Persistent => Ok(client.clone()),
277 crate::DmaPool::Ephemeral => {
278 anyhow::bail!(
279 "ephemeral dma pool requested but only persistent client available"
280 )
281 }
282 },
283 VfioDmaClients::EphemeralOnly(client) => match pool {
284 crate::DmaPool::Ephemeral => Ok(client.clone()),
285 crate::DmaPool::Persistent => {
286 anyhow::bail!(
287 "persistent dma pool requested but only ephemeral client available"
288 )
289 }
290 },
291 VfioDmaClients::Split {
292 persistent,
293 ephemeral,
294 } => match pool {
295 crate::DmaPool::Persistent => Ok(persistent.clone()),
296 crate::DmaPool::Ephemeral => Ok(ephemeral.clone()),
297 },
298 }
299 }
300
301 fn max_interrupt_count(&self) -> u32 {
302 self.msix_info.count
303 }
304
305 fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
306 if msix >= self.msix_info.count {
307 anyhow::bail!("invalid msix index");
308 }
309 if self.interrupts.len() <= msix as usize {
310 self.interrupts.resize_with(msix as usize + 1, || None);
311 }
312
313 let interrupt = &mut self.interrupts[msix as usize];
314 if let Some(interrupt) = interrupt {
315 if interrupt.target_cpu.load(Relaxed) != cpu {
318 interrupt.target_cpu.store(cpu, Relaxed);
319 }
320 return Ok(interrupt.interrupt.clone());
321 }
322
323 let new_interrupt = {
324 let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
325 let driver = self
326 .driver_source
327 .builder()
328 .run_on_target(true)
329 .target_vp(cpu)
330 .build(&name);
331
332 let event =
333 PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;
334
335 let source = DeviceInterruptSource::new();
336 self.device
337 .map_msix(msix, [event.get().as_fd()])
338 .context("failed to map msix")?;
339
340 let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
344 .context("failed to find irq for msix")?;
345
346 let target_cpu = Arc::new(AtomicU32::new(cpu));
347
348 let interrupt = source.new_target();
349
350 let task = driver.spawn(
351 name,
352 InterruptTask {
353 driver: driver.clone(),
354 target_cpu: target_cpu.clone(),
355 pci_id: self.pci_id.clone(),
356 msix,
357 irq,
358 event,
359 source,
360 }
361 .run(),
362 );
363
364 InterruptState {
365 interrupt,
366 target_cpu,
367 _task: task,
368 }
369 };
370
371 Ok(interrupt.insert(new_interrupt).interrupt.clone())
372 }
373
374 fn unmap_all_interrupts(&mut self) -> anyhow::Result<()> {
375 if self.interrupts.is_empty() {
376 return Ok(());
377 }
378
379 let count = self.interrupts.len() as u32;
380 self.device
381 .unmap_msix(0, count)
382 .context("failed to unmap all msix vectors")?;
383
384 self.interrupts.clear();
386
387 Ok(())
388 }
389}
390
391struct InterruptTask {
392 driver: VmTaskDriver,
393 target_cpu: Arc<AtomicU32>,
394 pci_id: Arc<str>,
395 msix: u32,
396 irq: u32,
397 event: PolledWait<Event>,
398 source: DeviceInterruptSource,
399}
400
401impl InterruptTask {
402 async fn run(mut self) {
403 let mut current_cpu = !0;
404 loop {
405 let next_cpu = self.target_cpu.load(Relaxed);
406 let r = if next_cpu == current_cpu {
407 self.event.wait().await
408 } else {
409 self.driver.retarget_vp(next_cpu);
410 enum Event {
413 TargetVpReady(()),
414 Interrupt(std::io::Result<()>),
415 }
416 match (
417 self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
418 self.event.wait().map(Event::Interrupt),
419 )
420 .race()
421 .await
422 {
423 Event::TargetVpReady(()) => {
424 if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
425 tracing::error!(
429 pci_id = self.pci_id.as_ref(),
430 msix = self.msix,
431 irq = self.irq,
432 error = &err as &dyn std::error::Error,
433 "failed to set irq affinity"
434 );
435 }
436 current_cpu = next_cpu;
437 continue;
438 }
439 Event::Interrupt(r) => {
440 r
443 }
444 }
445 };
446
447 r.expect("wait cannot fail on eventfd");
448 self.source.signal();
449 }
450 }
451}
452
453fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
454 fs_err::write(
455 format!("/proc/irq/{}/smp_affinity_list", irq),
456 cpu.to_string(),
457 )
458}
459
460impl DeviceRegisterIo for vfio_sys::MappedRegion {
461 fn len(&self) -> usize {
462 self.len()
463 }
464
465 fn read_u32(&self, offset: usize) -> u32 {
466 self.read_u32(offset)
467 }
468
469 fn read_u64(&self, offset: usize) -> u64 {
470 self.read_u64(offset)
471 }
472
473 fn write_u32(&self, offset: usize, data: u32) {
474 self.write_u32(offset, data)
475 }
476
477 fn write_u64(&self, offset: usize, data: u64) {
478 self.write_u64(offset, data)
479 }
480}
481
482impl MappedRegionWithFallback {
483 fn mapping<T>(&self, offset: usize) -> *mut T {
484 assert!(
485 offset <= self.mapping.len() - size_of::<T>() && offset.is_multiple_of(align_of::<T>())
486 );
487 if cfg!(feature = "mmio_simulate_fallback") {
488 return std::ptr::NonNull::dangling().as_ptr();
489 }
490 unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
492 }
493
494 fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
495 &self,
496 offset: usize,
497 ) -> Result<T, trycopy::MemoryError> {
498 unsafe { trycopy::try_read_volatile(self.mapping::<T>(offset)) }
500 }
501
502 fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
503 &self,
504 offset: usize,
505 data: T,
506 ) -> Result<(), trycopy::MemoryError> {
507 unsafe { trycopy::try_write_volatile(self.mapping::<T>(offset), &data) }
509 }
510
511 fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
512 tracing::trace!(offset, n = buf.len(), "read");
513 self.read_fallback.increment();
514 let n = self
515 .device
516 .as_ref()
517 .as_ref()
518 .read_at(buf, self.offset + offset as u64)
519 .expect("valid mapping");
520 assert_eq!(n, buf.len());
521 }
522
523 fn write_to_file(&self, offset: usize, buf: &[u8]) {
524 tracing::trace!(offset, n = buf.len(), "write");
525 self.write_fallback.increment();
526 let n = self
527 .device
528 .as_ref()
529 .as_ref()
530 .write_at(buf, self.offset + offset as u64)
531 .expect("valid mapping");
532 assert_eq!(n, buf.len());
533 }
534}
535
536impl DeviceRegisterIo for MappedRegionWithFallback {
537 fn len(&self) -> usize {
538 self.len
539 }
540
541 fn read_u32(&self, offset: usize) -> u32 {
542 self.read_from_mapping(offset).unwrap_or_else(|_| {
543 let mut buf = [0u8; 4];
544 self.read_from_file(offset, &mut buf);
545 u32::from_ne_bytes(buf)
546 })
547 }
548
549 fn read_u64(&self, offset: usize) -> u64 {
550 self.read_from_mapping(offset).unwrap_or_else(|_| {
551 let mut buf = [0u8; 8];
552 self.read_from_file(offset, &mut buf);
553 u64::from_ne_bytes(buf)
554 })
555 }
556
557 fn write_u32(&self, offset: usize, data: u32) {
558 self.write_to_mapping(offset, data).unwrap_or_else(|_| {
559 self.write_to_file(offset, &data.to_ne_bytes());
560 })
561 }
562
563 fn write_u64(&self, offset: usize, data: u64) {
564 self.write_to_mapping(offset, data).unwrap_or_else(|_| {
565 self.write_to_file(offset, &data.to_ne_bytes());
566 })
567 }
568}
569
/// Reset method that can be selected for a PCI device through its sysfs
/// `reset_method` attribute.
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
    /// Select no reset method (an empty value is written).
    NoReset,
    /// Select the `acpi` reset method.
    Acpi,
    /// Select the `flr` reset method.
    Flr,
    /// Select the `af_flr` reset method.
    AfFlr,
    /// Select the `pm` reset method.
    Pm,
    /// Select the `bus` reset method.
    Bus,
}
579
580pub fn vfio_set_device_reset_method(
581 pci_id: impl AsRef<str>,
582 method: PciDeviceResetMethod,
583) -> std::io::Result<()> {
584 let reset_method = match method {
585 PciDeviceResetMethod::NoReset => "\0".as_bytes(),
586 PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
587 PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
588 PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
589 PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
590 PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
591 };
592
593 let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
594 .iter()
595 .collect();
596 fs_err::write(path, reset_method)?;
597 Ok(())
598}