#![cfg(target_os = "linux")]
#![cfg(feature = "vfio")]
use crate::DeviceBacking;
use crate::DeviceRegisterIo;
use crate::DmaClient;
use crate::interrupt::DeviceInterrupt;
use crate::interrupt::DeviceInterruptSource;
use anyhow::Context;
use futures::FutureExt;
use futures_concurrency::future::Race;
use inspect::Inspect;
use inspect_counters::SharedCounter;
use pal_async::task::Spawn;
use pal_async::task::Task;
use pal_async::wait::PolledWait;
use pal_event::Event;
use std::os::fd::AsFd;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering::Relaxed;
use std::time::Duration;
use uevent::UeventListener;
use vfio_bindings::bindings::vfio::VFIO_PCI_CONFIG_REGION_INDEX;
use vfio_sys::IommuType;
use vfio_sys::IrqInfo;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
#[derive(Inspect)]
pub struct VfioDevice {
pci_id: Arc<str>,
#[inspect(skip)]
_container: vfio_sys::Container,
#[inspect(skip)]
_group: vfio_sys::Group,
#[inspect(skip)]
device: Arc<vfio_sys::Device>,
#[inspect(skip)]
msix_info: IrqInfo,
#[inspect(skip)]
driver_source: VmTaskDriverSource,
#[inspect(iter_by_index)]
interrupts: Vec<Option<InterruptState>>,
#[inspect(skip)]
config_space: vfio_sys::RegionInfo,
dma_client: Arc<dyn DmaClient>,
}
#[derive(Inspect)]
struct InterruptState {
#[inspect(skip)]
interrupt: DeviceInterrupt,
target_cpu: Arc<AtomicU32>,
#[inspect(skip)]
_task: Task<()>,
}
impl VfioDevice {
pub async fn new(
driver_source: &VmTaskDriverSource,
pci_id: &str,
dma_client: Arc<dyn DmaClient>,
) -> anyhow::Result<Self> {
Self::restore(driver_source, pci_id, false, dma_client).await
}
pub async fn restore(
driver_source: &VmTaskDriverSource,
pci_id: &str,
keepalive: bool,
dma_client: Arc<dyn DmaClient>,
) -> anyhow::Result<Self> {
let path = Path::new("/sys/bus/pci/devices").join(pci_id);
let vmbus_device =
std::fs::read_link(&path).context("failed to read link for pci device")?;
let instance_path = Path::new("/sys").join(vmbus_device.strip_prefix("../../..")?);
let vfio_arrived_path = instance_path.join("vfio-dev");
let uevent_listener = UeventListener::new(&driver_source.simple())?;
let wait_for_vfio_device =
uevent_listener.wait_for_matching_child(&vfio_arrived_path, async |_, _| Some(()));
let mut ctx = mesh::CancelContext::new().with_timeout(Duration::from_secs(1));
let _ = ctx.until_cancelled(wait_for_vfio_device).await;
let container = vfio_sys::Container::new()?;
let group_id = vfio_sys::Group::find_group_for_device(&path)?;
let group = vfio_sys::Group::open_noiommu(group_id)?;
group.set_container(&container)?;
if !group.status()?.viable() {
anyhow::bail!("group is not viable");
}
container.set_iommu(IommuType::NoIommu)?;
if keepalive {
group.set_keep_alive(pci_id)?;
}
let device = group.open_device(pci_id)?;
let msix_info = device.irq_info(vfio_bindings::bindings::vfio::VFIO_PCI_MSIX_IRQ_INDEX)?;
if msix_info.flags.noresize() {
anyhow::bail!("unsupported: kernel does not support dynamic msix allocation");
}
let config_space = device.region_info(VFIO_PCI_CONFIG_REGION_INDEX)?;
let this = Self {
pci_id: pci_id.into(),
_container: container,
_group: group,
device: Arc::new(device),
msix_info,
config_space,
driver_source: driver_source.clone(),
interrupts: Vec::new(),
dma_client,
};
this.enable_device().context("failed to enable device")?;
Ok(this)
}
fn enable_device(&self) -> anyhow::Result<()> {
let offset = pci_core::spec::cfg_space::HeaderType00::STATUS_COMMAND.0;
let status_command = self.read_config(offset)?;
let command = pci_core::spec::cfg_space::Command::from(status_command as u16);
let command = command
.with_bus_master(true)
.with_intx_disable(true)
.with_mmio_enabled(true);
let status_command = (status_command & 0xffff0000) | u16::from(command) as u32;
self.write_config(offset, status_command)?;
Ok(())
}
pub fn read_config(&self, offset: u16) -> anyhow::Result<u32> {
if offset as u64 > self.config_space.size - 4 {
anyhow::bail!("invalid config offset");
}
let mut buf = [0u8; 4];
self.device
.as_ref()
.as_ref()
.read_at(&mut buf, self.config_space.offset + offset as u64)
.context("failed to read config")?;
Ok(u32::from_ne_bytes(buf))
}
pub fn write_config(&self, offset: u16, data: u32) -> anyhow::Result<()> {
if offset as u64 > self.config_space.size - 4 {
anyhow::bail!("invalid config offset");
}
let buf = data.to_ne_bytes();
self.device
.as_ref()
.as_ref()
.write_at(&buf, self.config_space.offset + offset as u64)
.context("failed to write config")?;
Ok(())
}
fn map_bar(&self, n: u8) -> anyhow::Result<MappedRegionWithFallback> {
if n >= 6 {
anyhow::bail!("invalid bar");
}
let info = self.device.region_info(n.into())?;
let mapping = self.device.map(info.offset, info.size as usize, true)?;
sparse_mmap::initialize_try_copy();
Ok(MappedRegionWithFallback {
device: self.device.clone(),
mapping,
len: info.size as usize,
offset: info.offset,
read_fallback: SharedCounter::new(),
write_fallback: SharedCounter::new(),
})
}
}
#[derive(Inspect)]
pub struct MappedRegionWithFallback {
#[inspect(skip)]
device: Arc<vfio_sys::Device>,
#[inspect(skip)]
mapping: vfio_sys::MappedRegion,
offset: u64,
len: usize,
read_fallback: SharedCounter,
write_fallback: SharedCounter,
}
impl DeviceBacking for VfioDevice {
type Registers = MappedRegionWithFallback;
fn id(&self) -> &str {
&self.pci_id
}
fn map_bar(&mut self, n: u8) -> anyhow::Result<Self::Registers> {
(*self).map_bar(n)
}
fn dma_client(&self) -> Arc<dyn DmaClient> {
self.dma_client.clone()
}
fn max_interrupt_count(&self) -> u32 {
self.msix_info.count
}
fn map_interrupt(&mut self, msix: u32, cpu: u32) -> anyhow::Result<DeviceInterrupt> {
if msix >= self.msix_info.count {
anyhow::bail!("invalid msix index");
}
if self.interrupts.len() <= msix as usize {
self.interrupts.resize_with(msix as usize + 1, || None);
}
let interrupt = &mut self.interrupts[msix as usize];
if let Some(interrupt) = interrupt {
if interrupt.target_cpu.load(Relaxed) != cpu {
interrupt.target_cpu.store(cpu, Relaxed);
}
return Ok(interrupt.interrupt.clone());
}
let new_interrupt = {
let name = format!("vfio-interrupt-{pci_id}-{msix}", pci_id = self.pci_id);
let driver = self
.driver_source
.builder()
.run_on_target(true)
.target_vp(cpu)
.build(&name);
let event =
PolledWait::new(&driver, Event::new()).context("failed to allocate polled wait")?;
let source = DeviceInterruptSource::new();
self.device
.map_msix(msix, [event.get().as_fd()])
.context("failed to map msix")?;
let irq = vfio_sys::find_msix_irq(&self.pci_id, msix)
.context("failed to find irq for msix")?;
let target_cpu = Arc::new(AtomicU32::new(cpu));
let interrupt = source.new_target();
let task = driver.spawn(
name,
InterruptTask {
driver: driver.clone(),
target_cpu: target_cpu.clone(),
pci_id: self.pci_id.clone(),
msix,
irq,
event,
source,
}
.run(),
);
InterruptState {
interrupt,
target_cpu,
_task: task,
}
};
Ok(interrupt.insert(new_interrupt).interrupt.clone())
}
}
struct InterruptTask {
driver: VmTaskDriver,
target_cpu: Arc<AtomicU32>,
pci_id: Arc<str>,
msix: u32,
irq: u32,
event: PolledWait<Event>,
source: DeviceInterruptSource,
}
impl InterruptTask {
async fn run(mut self) {
let mut current_cpu = !0;
loop {
let next_cpu = self.target_cpu.load(Relaxed);
let r = if next_cpu == current_cpu {
self.event.wait().await
} else {
self.driver.retarget_vp(next_cpu);
enum Event {
TargetVpReady(()),
Interrupt(std::io::Result<()>),
}
match (
self.driver.wait_target_vp_ready().map(Event::TargetVpReady),
self.event.wait().map(Event::Interrupt),
)
.race()
.await
{
Event::TargetVpReady(()) => {
if let Err(err) = set_irq_affinity(self.irq, next_cpu) {
tracing::error!(
pci_id = self.pci_id.as_ref(),
msix = self.msix,
irq = self.irq,
error = &err as &dyn std::error::Error,
"failed to set irq affinity"
);
}
current_cpu = next_cpu;
continue;
}
Event::Interrupt(r) => {
r
}
}
};
r.expect("wait cannot fail on eventfd");
self.source.signal();
}
}
}
fn set_irq_affinity(irq: u32, cpu: u32) -> std::io::Result<()> {
fs_err::write(
format!("/proc/irq/{}/smp_affinity_list", irq),
cpu.to_string(),
)
}
impl DeviceRegisterIo for vfio_sys::MappedRegion {
fn len(&self) -> usize {
self.len()
}
fn read_u32(&self, offset: usize) -> u32 {
self.read_u32(offset)
}
fn read_u64(&self, offset: usize) -> u64 {
self.read_u64(offset)
}
fn write_u32(&self, offset: usize, data: u32) {
self.write_u32(offset, data)
}
fn write_u64(&self, offset: usize, data: u64) {
self.write_u64(offset, data)
}
}
impl MappedRegionWithFallback {
fn mapping<T>(&self, offset: usize) -> *mut T {
assert!(offset <= self.mapping.len() - size_of::<T>() && offset % align_of::<T>() == 0);
if cfg!(feature = "mmio_simulate_fallback") {
return std::ptr::NonNull::dangling().as_ptr();
}
unsafe { self.mapping.as_ptr().byte_add(offset).cast() }
}
fn read_from_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
&self,
offset: usize,
) -> Result<T, sparse_mmap::MemoryError> {
unsafe { sparse_mmap::try_read_volatile(self.mapping::<T>(offset)) }
}
fn write_to_mapping<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
&self,
offset: usize,
data: T,
) -> Result<(), sparse_mmap::MemoryError> {
unsafe { sparse_mmap::try_write_volatile(self.mapping::<T>(offset), &data) }
}
fn read_from_file(&self, offset: usize, buf: &mut [u8]) {
tracing::trace!(offset, n = buf.len(), "read");
self.read_fallback.increment();
let n = self
.device
.as_ref()
.as_ref()
.read_at(buf, self.offset + offset as u64)
.expect("valid mapping");
assert_eq!(n, buf.len());
}
fn write_to_file(&self, offset: usize, buf: &[u8]) {
tracing::trace!(offset, n = buf.len(), "write");
self.write_fallback.increment();
let n = self
.device
.as_ref()
.as_ref()
.write_at(buf, self.offset + offset as u64)
.expect("valid mapping");
assert_eq!(n, buf.len());
}
}
impl DeviceRegisterIo for MappedRegionWithFallback {
fn len(&self) -> usize {
self.len
}
fn read_u32(&self, offset: usize) -> u32 {
self.read_from_mapping(offset).unwrap_or_else(|_| {
let mut buf = [0u8; 4];
self.read_from_file(offset, &mut buf);
u32::from_ne_bytes(buf)
})
}
fn read_u64(&self, offset: usize) -> u64 {
self.read_from_mapping(offset).unwrap_or_else(|_| {
let mut buf = [0u8; 8];
self.read_from_file(offset, &mut buf);
u64::from_ne_bytes(buf)
})
}
fn write_u32(&self, offset: usize, data: u32) {
self.write_to_mapping(offset, data).unwrap_or_else(|_| {
self.write_to_file(offset, &data.to_ne_bytes());
})
}
fn write_u64(&self, offset: usize, data: u64) {
self.write_to_mapping(offset, data).unwrap_or_else(|_| {
self.write_to_file(offset, &data.to_ne_bytes());
})
}
}
#[derive(Clone, Copy, Debug)]
pub enum PciDeviceResetMethod {
NoReset,
Acpi,
Flr,
AfFlr,
Pm,
Bus,
}
pub fn vfio_set_device_reset_method(
pci_id: impl AsRef<str>,
method: PciDeviceResetMethod,
) -> std::io::Result<()> {
let reset_method = match method {
PciDeviceResetMethod::NoReset => "\0".as_bytes(),
PciDeviceResetMethod::Acpi => "acpi\0".as_bytes(),
PciDeviceResetMethod::Flr => "flr\0".as_bytes(),
PciDeviceResetMethod::AfFlr => "af_flr\0".as_bytes(),
PciDeviceResetMethod::Pm => "pm\0".as_bytes(),
PciDeviceResetMethod::Bus => "bus\0".as_bytes(),
};
let path: std::path::PathBuf = ["/sys/bus/pci/devices", pci_id.as_ref(), "reset_method"]
.iter()
.collect();
fs_err::write(path, reset_method)?;
Ok(())
}