1#![cfg(target_os = "linux")]
7#![cfg(feature = "vfio")]
8
9use crate::DeviceBacking;
10use crate::DeviceRegisterIo;
11use crate::DmaClient;
12use crate::interrupt::DeviceInterrupt;
13use crate::interrupt::DeviceInterruptSource;
14use anyhow::Context;
15use futures::FutureExt;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect_counters::SharedCounter;
19use pal_async::task::Spawn;
20use pal_async::task::Task;
21use pal_async::wait::PolledWait;
22use pal_event::Event;
23use std::os::fd::AsFd;
24use std::os::unix::fs::FileExt;
25use std::path::Path;
26use std::sync::Arc;
27use std::sync::atomic::AtomicU32;
28use std::sync::atomic::Ordering::Relaxed;
29use std::time::Duration;
30use uevent::UeventListener;
31use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
32use vfio_sys::IommuType;
33use vfio_sys::IrqInfo;
34use vmcore::vm_task::VmTaskDriver;
35use vmcore::vm_task::VmTaskDriverSource;
36use zerocopy::FromBytes;
37use zerocopy::Immutable;
38use zerocopy::IntoBytes;
39use zerocopy::KnownLayout;
40
41#[derive(Inspect)]
43pub struct VfioDevice {
44 pci_id: Arc<str>,
45 #[inspect(skip)]
46 _container: vfio_sys::Container,
47 #[inspect(skip)]
48 _group: vfio_sys::Group,
49 #[inspect(skip)]
50 device: Arc<vfio_sys::Device>,
51 #[inspect(skip)]
52 msix_info: IrqInfo,
53 #[inspect(skip)]
54 driver_source: VmTaskDriverSource,
55 #[inspect(iter_by_index)]
56 interrupts: Vec<Option<InterruptState>>,
57 #[inspect(skip)]
58 config_space: vfio_sys::RegionInfo,
59 dma_client: Arc<dyn DmaClient>,
60}
61
62#[derive(Inspect)]
63struct InterruptState {
64 #[inspect(skip)]
65 interrupt: DeviceInterrupt,
66 target_cpu: Arc<AtomicU32>,
67 #[inspect(skip)]
68 _task: Task<()>,
69}
70
71impl VfioDevice {
72 pub async fn new(
74 driver_source: &VmTaskDriverSource,
75 pci_id: &str,
76 dma_client: Arc<dyn DmaClient>,
77 ) -> anyhow::Result<Self> {
78 Self::restore(driver_source, pci_id, false, dma_client).await
79 }
80
81 pub async fn restore(
84 driver_source: &VmTaskDriverSource,
85 pci_id: &str,
86 keepalive: bool,
87 dma_client: Arc<dyn DmaClient>,
88 ) -> anyhow::Result<Self> {
89 let path = Path::new("/sys/bus/pci/devices").join(pci_id);
90
91 let vmbus_device =
94 std::fs::read_link(&path).context("failed to read link for pci device")?;
95 let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
96 let vfio_arrived_path = instance_path.join("vfio-dev");
97 let uevent_listener = UeventListener::new(&driver_source.simple())?;
98 let wait_for_vfio_device =
99 uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
100 let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
101 let _ = ctx.until_cancelled(wait_for_vfio_device).await;
103
104 let container = vfio_sys::Container::new()?;
105 let group_id = vfio_sys::Group::find_group_for_device(&path)?;
106 let group = vfio_sys::Group::open_noiommu(group_id)?;
107 group.set_container(&container)?;
108 if !group.status()?.viable() {
109 anyhow::bail!("group is not viable");
110 }
111
112 container.set_iommu(IommuType::NoIommu)?;
113 if keepalive {
114 group.set_keep_alive(pci_id)?;
116 }
117 let device = group.open_device(pci_id)?;
118 let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
119 if msix_info.flags.noresize() {
120 anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
121 }
122
123 let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
124 let this = Self {
125 pci_id: pci_id.into(),
126 _container: container,
127 _group: group,
128 device: Arc::new(device),
129 msix_info,
130 config_space,
131 driver_source: driver_source.clone(),
132 interrupts: Vec::new(),
133 dma_client,
134 };
135
136 this.enable_device().context("failed to enable device")?;
139 Ok(this)
140 }
141
142 fn enable_device(&self) -> anyhow::Result<()> {
143 let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
144 let status_command = self.read_config(offset)?;
145 let command = pci_core::spec::cfg_space::Command::from(status_command as u16);
146
147 let command = command
148 .with_bus_master(true)
149 .with_intx_disable(true)
150 .with_mmio_enabled(true);
151
152 let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
153 self.write_config(offset, status_command)?;
154 Ok(())
155 }
156
157 pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
158 if offset as u64 > self.config_space.size - 4 {
159 anyhow::bail!("invalid config offset");
160 }
161
162 let mut buf = [0u8; 4];
163 self.device
164 .as_ref()
165 .as_ref()
166 .read_at(&mut buf, self.config_space.offset + offset as u64)
167 .context("failed to read config")?;
168
169 Ok(u32::from_ne_bytes(buf))
170 }
171
172 pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
173 if offset as u64 > self.config_space.size - 4 {
174 anyhow::bail!("invalid config offset");
175 }
176
177 let buf = data.to_ne_bytes();
178 self.device
179 .as_ref()
180 .as_ref()
181 .write_at(&buf, self.config_space.offset + offset as u64)
182 .context("failed to write config")?;
183
184 Ok(())
185 }
186
187 fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
189 if n >= 6 {
190 anyhow::bail!("invalid bar");
191 }
192 let info = self.device.region_info(n.into())?;
193 let mapping = self.device.map(info.offset, info.size as usize, true)?;
194 sparse_mmap::initialize_try_copy();
195 Ok(MappedRegionWithFallback {
196 device: self.device.clone(),
197 mapping,
198 len: info.size as usize,
199 offset: info.offset,
200 read_fallback: SharedCounter::new(),
201 write_fallback: SharedCounter::new(),
202 })
203 }
204}
205
206#[derive(Inspect)]
212pub struct MappedRegionWithFallback {
213 #[inspect(skip)]
214 device: Arc<vfio_sys::Device>,
215 #[inspect(skip)]
216 mapping: vfio_sys::MappedRegion,
217 offset: u64,
218 len: usize,
219 read_fallback: SharedCounter,
220 write_fallback: SharedCounter,
221}
222
223impl DeviceBacking for VfioDevice {
224 type Registers = MappedRegionWithFallback;
225
226 fn id(&self) -> &str {
227 &self.pci_id
228 }
229
230 fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
231 (*self).map_bar(n)
232 }
233
234 fn dma_client(&self) -> Arc<dyn DmaClient> {
235 self.dma_client.clone()
236 }
237
238 fn max_interrupt_count(&self) -> u32 {
239 self.msix_info.count
240 }
241
242 fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
243 if msix >= self.msix_info.count {
244 anyhow::bail!("invalid msix index");
245 }
246 if self.interrupts.len() <= msix as usize {
247 self.interrupts.resize_with(msix as usize + 1, || None);
248 }
249
250 let interrupt = &mut self.interrupts[msix as usize];
251 if let Some(interrupt) = interrupt {
252 if interrupt.target_cpu.load(Relaxed) != cpu {
255 interrupt.target_cpu.store(cpu, Relaxed);
256 }
257 return Ok(interrupt.interrupt.clone());
258 }
259
260 let new_interrupt = {
261 let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
262 let driver = self
263 .driver_source
264 .builder()
265 .run_on_target(true)
266 .target_vp(cpu)
267 .build(&name);
268
269 let event =
270 PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;
271
272 let source = DeviceInterruptSource::new();
273 self.device
274 .map_msix(msix, [event.get().as_fd()])
275 .context("failed to map msix")?;
276
277 let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
281 .context("failed to find irq for msix")?;
282
283 let target_cpu = Arc::new(AtomicU32::new(cpu));
284
285 let interrupt = source.new_target();
286
287 let task = driver.spawn(
288 name,
289 InterruptTask {
290 driver: driver.clone(),
291 target_cpu: target_cpu.clone(),
292 pci_id: self.pci_id.clone(),
293 msix,
294 irq,
295 event,
296 source,
297 }
298 .run(),
299 );
300
301 InterruptState {
302 interrupt,
303 target_cpu,
304 _task: task,
305 }
306 };
307
308 Ok(interrupt.insert(new_interrupt).interrupt.clone())
309 }
310}
311
312struct InterruptTask {
313 driver: VmTaskDriver,
314 target_cpu: Arc<AtomicU32>,
315 pci_id: Arc<str>,
316 msix: u32,
317 irq: u32,
318 event: PolledWait<Event>,
319 source: DeviceInterruptSource,
320}
321
322impl InterruptTask {
323 async fn run(mut self) {
324 let mut current_cpu = !0;
325 loop {
326 let next_cpu = self.target_cpu.load(Relaxed);
327 let r = if next_cpu == current_cpu {
328 self.event.wait().await
329 } else {
330 self.driver.retarget_vp(next_cpu);
331 enum Event {
334 TargetVpReady(()),
335 Interrupt(std::io::Result<()>),
336 }
337 match (
338 self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
339 self.event.wait().map(Event::Interrupt),
340 )
341 .race()
342 .await
343 {
344 Event::TargetVpReady(()) => {
345 if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
346 tracing::error!(
350 pci_id = self.pci_id.as_ref(),
351 msix = self.msix,
352 irq = self.irq,
353 error = &err as &dyn std::error::Error,
354 "failed to set irq affinity"
355 );
356 }
357 current_cpu = next_cpu;
358 continue;
359 }
360 Event::Interrupt(r) => {
361 r
364 }
365 }
366 };
367
368 r.expect("wait cannot fail on eventfd");
369 self.source.signal();
370 }
371 }
372}
373
374fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
375 fs_err::write(
376 format!("/proc/irq/{}/smp_affinity_list", irq),
377 cpu.to_string(),
378 )
379}
380
381impl DeviceRegisterIo for vfio_sys::MappedRegion {
382 fn len(&self) -> usize {
383 self.len()
384 }
385
386 fn read_u32(&self, offset: usize) -> u32 {
387 self.read_u32(offset)
388 }
389
390 fn read_u64(&self, offset: usize) -> u64 {
391 self.read_u64(offset)
392 }
393
394 fn write_u32(&self, offset: usize, data: u32) {
395 self.write_u32(offset, data)
396 }
397
398 fn write_u64(&self, offset: usize, data: u64) {
399 self.write_u64(offset, data)
400 }
401}
402
403impl MappedRegionWithFallback {
404 fn mapping<T>(&self, offset: usize) -> *mut T {
405 assert!(offset <= self.mapping.len() - size_of::<T>() && offset % align_of::<T>() == 0);
406 if cfg!(feature = "mmio_simulate_fallback") {
407 return std::ptr::NonNull::dangling().as_ptr();
408 }
409 unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
411 }
412
413 fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
414 &self,
415 offset: usize,
416 ) -> Result<T, sparse_mmap::MemoryError> {
417 unsafe { sparse_mmap::try_read_volatile(self.mapping::<T>(offset)) }
419 }
420
421 fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
422 &self,
423 offset: usize,
424 data: T,
425 ) -> Result<(), sparse_mmap::MemoryError> {
426 unsafe { sparse_mmap::try_write_volatile(self.mapping::<T>(offset), &data) }
428 }
429
430 fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
431 tracing::trace!(offset, n = buf.len(), "read");
432 self.read_fallback.increment();
433 let n = self
434 .device
435 .as_ref()
436 .as_ref()
437 .read_at(buf, self.offset + offset as u64)
438 .expect("valid mapping");
439 assert_eq!(n, buf.len());
440 }
441
442 fn write_to_file(&self, offset: usize, buf: &[u8]) {
443 tracing::trace!(offset, n = buf.len(), "write");
444 self.write_fallback.increment();
445 let n = self
446 .device
447 .as_ref()
448 .as_ref()
449 .write_at(buf, self.offset + offset as u64)
450 .expect("valid mapping");
451 assert_eq!(n, buf.len());
452 }
453}
454
455impl DeviceRegisterIo for MappedRegionWithFallback {
456 fn len(&self) -> usize {
457 self.len
458 }
459
460 fn read_u32(&self, offset: usize) -> u32 {
461 self.read_from_mapping(offset).unwrap_or_else(|_| {
462 let mut buf = [0u8; 4];
463 self.read_from_file(offset, &mut buf);
464 u32::from_ne_bytes(buf)
465 })
466 }
467
468 fn read_u64(&self, offset: usize) -> u64 {
469 self.read_from_mapping(offset).unwrap_or_else(|_| {
470 let mut buf = [0u8; 8];
471 self.read_from_file(offset, &mut buf);
472 u64::from_ne_bytes(buf)
473 })
474 }
475
476 fn write_u32(&self, offset: usize, data: u32) {
477 self.write_to_mapping(offset, data).unwrap_or_else(|_| {
478 self.write_to_file(offset, &data.to_ne_bytes());
479 })
480 }
481
482 fn write_u64(&self, offset: usize, data: u64) {
483 self.write_to_mapping(offset, data).unwrap_or_else(|_| {
484 self.write_to_file(offset, &data.to_ne_bytes());
485 })
486 }
487}
488
/// Reset method that can be selected for a PCI device through its sysfs
/// `reset_method` attribute (see `vfio_set_device_reset_method`).
///
/// The enum is a plain tag, so it is comparable (`PartialEq`/`Eq`) in
/// addition to being `Copy` and `Debug`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PciDeviceResetMethod {
    /// Writes an empty method name.
    NoReset,
    /// Selects the `acpi` method.
    Acpi,
    /// Selects the `flr` method.
    Flr,
    /// Selects the `af_flr` method.
    AfFlr,
    /// Selects the `pm` method.
    Pm,
    /// Selects the `bus` method.
    Bus,
}
498
499pub fn vfio_set_device_reset_method(
500 pci_id: impl AsRef<str>,
501 method: PciDeviceResetMethod,
502) -> std::io::Result<()> {
503 let reset_method = match method {
504 PciDeviceResetMethod::NoReset => "\0".as_bytes(),
505 PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
506 PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
507 PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
508 PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
509 PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
510 };
511
512 let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
513 .iter()
514 .collect();
515 fs_err::write(path, reset_method)?;
516 Ok(())
517}