1use crate::bus_range::AssignedBusRange;
7use pal_event::Event;
8use parking_lot::RwLock;
9use std::sync::Arc;
10use vmcore::irqfd::IrqFd;
11use vmcore::irqfd::IrqFdRoute;
12
/// Sink that MSIs from an [`MsiTarget`] are delivered to.
///
/// The sink is stored as an `Arc<dyn SignalMsi>` shared by every target of a
/// connection, hence the `Send + Sync` bound.
pub trait SignalMsi: Send + Sync {
    /// Signals one MSI described by the `address`/`data` pair.
    ///
    /// `devid` identifies the requester. Callers in this module always pass
    /// `Some`: either the target's default BDF (`bus << 8 | devfn`) or an
    /// explicit, range-checked RID.
    fn signal_msi(&self, devid: Option<u32>, address: u64, data: u32);
}
23
/// A direct MSI route backed by an [`IrqFdRoute`], created by
/// [`MsiTarget::new_route`].
pub struct MsiRoute {
    // The underlying irqfd route that `enable`/`disable` forward to.
    inner: Box<dyn IrqFdRoute>,
    // Default BDF of the target that created this route. It tags `enable`
    // calls and bounds-checks RIDs passed to `enable_with_rid`.
    default_bdf: DefaultBdf,
}
34
35impl MsiRoute {
36 pub fn event(&self) -> &Event {
40 self.inner.event()
41 }
42
43 pub fn enable(&self, address: u64, data: u32) {
46 let resolved = resolve_default_bdf(&self.default_bdf);
47 self.inner.enable(address, data, Some(resolved))
48 }
49
50 pub fn enable_with_rid(&self, rid: u16, address: u64, data: u32) {
62 let bus = (rid >> 8) as u8;
63 if !self.default_bdf.bus_range.contains_bus(bus) {
64 let (secondary, subordinate) = self.default_bdf.bus_range.bus_range();
65 tracelimit::warn_ratelimited!(
66 rid,
67 secondary,
68 subordinate,
69 "refusing to enable MSI route: rid bus outside assigned bus range"
70 );
71 self.inner.disable();
72 return;
73 }
74 self.inner.enable(address, data, Some(rid.into()))
75 }
76
77 pub fn disable(&self) {
82 self.inner.disable()
83 }
84
85 pub fn consume_pending(&self) -> bool {
88 self.event().try_wait()
89 }
90}
91
92struct DisconnectedMsiTarget;
93
94impl SignalMsi for DisconnectedMsiTarget {
95 fn signal_msi(&self, _devid: Option<u32>, _address: u64, _data: u32) {
96 tracelimit::warn_ratelimited!("dropped MSI interrupt to disconnected target");
97 }
98}
99
/// Inputs for computing a target's default BDF: the assigned bus range (whose
/// secondary bus supplies the bus number) and the device/function number.
#[derive(Clone, Debug)]
struct DefaultBdf {
    // Bus range assigned to the device; also used to validate explicit RIDs.
    bus_range: AssignedBusRange,
    // Device/function byte placed in the low 8 bits of the default BDF.
    devfn: u8,
}
109
110fn resolve_default_bdf(default: &DefaultBdf) -> u32 {
112 let (secondary, _) = default.bus_range.bus_range();
113 (secondary as u32) << 8 | default.devfn as u32
114}
115
/// Owning handle for an MSI target.
///
/// It installs the signal sink and irqfd that the target, and every target
/// derived from it, share.
#[derive(Debug)]
pub struct MsiConnection {
    target: MsiTarget,
}
121
/// A cloneable handle used to signal MSIs and create direct MSI routes.
///
/// Clones, and targets derived via `with_devfn`/`with_bus_range`, share the
/// same inner state (signal sink and optional irqfd). Derived targets may
/// carry a different default BDF.
#[derive(Clone)]
pub struct MsiTarget {
    // Connection state shared with the owning `MsiConnection`.
    inner: Arc<RwLock<MsiTargetInner>>,
    // Tags MSIs that carry no explicit RID and bounds-checks those that do.
    default_bdf: DefaultBdf,
}
128
129impl MsiTarget {
130 pub fn disconnected() -> Self {
134 Self {
135 inner: Arc::new(RwLock::new(MsiTargetInner {
136 signal_msi: Arc::new(DisconnectedMsiTarget),
137 irqfd: None,
138 })),
139 default_bdf: DefaultBdf {
140 bus_range: AssignedBusRange::new(),
141 devfn: 0,
142 },
143 }
144 }
145}
146
147impl std::fmt::Debug for MsiTarget {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("MsiTarget")
150 .field("default_bdf", &self.default_bdf)
151 .finish()
152 }
153}
154
/// State shared by an `MsiConnection` and every `MsiTarget` derived from it.
struct MsiTargetInner {
    // Where MSIs are delivered. It starts as `DisconnectedMsiTarget`, and
    // `MsiConnection::connect` replaces it.
    signal_msi: Arc<dyn SignalMsi>,
    // Set by `MsiConnection::connect_irqfd`; required by `MsiTarget::new_route`.
    irqfd: Option<Arc<dyn IrqFd>>,
}
159
160impl std::fmt::Debug for MsiTargetInner {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 let Self {
163 signal_msi: _,
164 irqfd,
165 } = self;
166 f.debug_struct("MsiTargetInner")
167 .field("has_irqfd", &irqfd.is_some())
168 .finish()
169 }
170}
171
172impl MsiConnection {
173 pub fn new(bus_range: AssignedBusRange, devfn: u8) -> Self {
180 Self {
181 target: MsiTarget {
182 inner: Arc::new(RwLock::new(MsiTargetInner {
183 signal_msi: Arc::new(DisconnectedMsiTarget),
184 irqfd: None,
185 })),
186 default_bdf: DefaultBdf { bus_range, devfn },
187 },
188 }
189 }
190
191 pub fn connect(&self, signal_msi: Arc<dyn SignalMsi>) {
193 let mut inner = self.target.inner.write();
194 inner.signal_msi = signal_msi;
195 }
196
197 pub fn connect_irqfd(&self, irqfd: Arc<dyn IrqFd>) {
202 let mut inner = self.target.inner.write();
203 inner.irqfd = Some(irqfd);
204 }
205
206 pub fn target(&self) -> &MsiTarget {
208 &self.target
209 }
210}
211
212impl MsiTarget {
213 pub fn with_devfn(&self, devfn: u8) -> MsiTarget {
220 MsiTarget {
221 inner: self.inner.clone(),
222 default_bdf: DefaultBdf {
223 bus_range: self.default_bdf.bus_range.clone(),
224 devfn,
225 },
226 }
227 }
228
229 pub fn with_bus_range(&self, bus_range: AssignedBusRange, devfn: u8) -> MsiTarget {
235 MsiTarget {
236 inner: self.inner.clone(),
237 default_bdf: DefaultBdf { bus_range, devfn },
238 }
239 }
240
241 pub fn signal_msi(&self, address: u64, data: u32) {
244 let resolved = resolve_default_bdf(&self.default_bdf);
245 let inner = self.inner.read();
246 inner.signal_msi.signal_msi(Some(resolved), address, data);
247 }
248
249 pub fn signal_msi_with_rid(&self, rid: u16, address: u64, data: u32) {
261 let bus = (rid >> 8) as u8;
262 if !self.default_bdf.bus_range.contains_bus(bus) {
263 let (secondary, subordinate) = self.default_bdf.bus_range.bus_range();
264 tracelimit::warn_ratelimited!(
265 rid,
266 secondary,
267 subordinate,
268 "dropping MSI: rid bus outside assigned bus range"
269 );
270 return;
271 }
272 let inner = self.inner.read();
273 inner.signal_msi.signal_msi(Some(rid.into()), address, data);
274 }
275
276 pub fn new_route(&self) -> Option<anyhow::Result<MsiRoute>> {
285 let inner = self.inner.read();
286 inner.irqfd.as_ref().map(|fd| {
287 Ok(MsiRoute {
288 inner: fd.new_irqfd_route()?,
289 default_bdf: self.default_bdf.clone(),
290 })
291 })
292 }
293
294 pub fn default_bdf(&self) -> u32 {
297 resolve_default_bdf(&self.default_bdf)
298 }
299
300 pub fn supports_direct_msi(&self) -> bool {
302 let inner = self.inner.read();
303 inner.irqfd.is_some()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bus_range::AssignedBusRange;
    use pal_event::Event;
    use parking_lot::Mutex;
    use std::collections::VecDeque;

    /// Call log shared between a mock route and the test that inspects it.
    type CallLog = Arc<Mutex<Vec<RouteCall>>>;

    /// `SignalMsi` sink that records every call in FIFO order.
    struct RecordingSignalMsi {
        calls: Mutex<VecDeque<(Option<u32>, u64, u32)>>,
    }

    impl RecordingSignalMsi {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                calls: Mutex::new(VecDeque::new()),
            })
        }

        /// Removes and returns the oldest recorded call, if any.
        fn pop(&self) -> Option<(Option<u32>, u64, u32)> {
            self.calls.lock().pop_front()
        }
    }

    impl SignalMsi for RecordingSignalMsi {
        fn signal_msi(&self, devid: Option<u32>, address: u64, data: u32) {
            self.calls.lock().push_back((devid, address, data));
        }
    }

    /// A call observed by `MockIrqFdRoute`.
    #[derive(Debug, Clone, PartialEq)]
    enum RouteCall {
        Enable {
            address: u64,
            data: u32,
            devid: Option<u32>,
        },
        Disable,
    }

    /// `IrqFdRoute` that records `enable`/`disable` calls into a shared log.
    struct MockIrqFdRoute {
        event: Event,
        calls: CallLog,
    }

    impl IrqFdRoute for MockIrqFdRoute {
        fn event(&self) -> &Event {
            &self.event
        }

        fn enable(&self, address: u64, data: u32, devid: Option<u32>) {
            self.calls.lock().push(RouteCall::Enable {
                address,
                data,
                devid,
            });
        }

        fn disable(&self) {
            self.calls.lock().push(RouteCall::Disable);
        }
    }

    /// `IrqFd` that hands out pre-allocated call logs, one per created route.
    struct MockIrqFd {
        routes: Mutex<VecDeque<CallLog>>,
    }

    impl IrqFd for MockIrqFd {
        fn new_irqfd_route(&self) -> anyhow::Result<Box<dyn IrqFdRoute>> {
            // Report exhaustion as an error rather than panicking, so tests
            // can exercise the error path of `MsiTarget::new_route`.
            let calls = self
                .routes
                .lock()
                .pop_front()
                .ok_or_else(|| anyhow::anyhow!("mock irqfd: no routes left"))?;
            Ok(Box::new(MockIrqFdRoute {
                event: Event::new(),
                calls,
            }))
        }
    }

    /// Builds a mock irqfd able to create `count` routes.
    ///
    /// Also returns the call log of each route, in creation order.
    fn mock_irqfd(count: usize) -> (Arc<dyn IrqFd>, Vec<CallLog>) {
        let call_logs: Vec<CallLog> = (0..count)
            .map(|_| Arc::new(Mutex::new(Vec::new())))
            .collect();
        let irqfd: Arc<dyn IrqFd> = Arc::new(MockIrqFd {
            routes: Mutex::new(call_logs.iter().cloned().collect()),
        });
        (irqfd, call_logs)
    }

    #[test]
    fn signal_msi_resolves_default_bdf() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(5, 10);
        let msi_conn = MsiConnection::new(bus_range, 0x18);
        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());

        msi_conn.target().signal_msi(0xFEE0_0000, 42);

        let (devid, addr, data) = recorder.pop().unwrap();
        assert_eq!(devid, Some((5 << 8) | 0x18));
        assert_eq!(addr, 0xFEE0_0000);
        assert_eq!(data, 42);
    }

    #[test]
    fn signal_msi_with_rid_accepts_bus_in_range() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(5, 10);
        let msi_conn = MsiConnection::new(bus_range, 0);
        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());

        let rid: u16 = (7 << 8) | 0x0A;
        msi_conn.target().signal_msi_with_rid(rid, 0xABCD, 99);

        let (devid, addr, data) = recorder.pop().unwrap();
        assert_eq!(devid, Some(rid as u32));
        assert_eq!(addr, 0xABCD);
        assert_eq!(data, 99);
    }

    #[test]
    fn signal_msi_with_rid_drops_bus_outside_range() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(5, 10);
        let msi_conn = MsiConnection::new(bus_range, 0);
        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());

        let rid_above: u16 = 11 << 8;
        msi_conn.target().signal_msi_with_rid(rid_above, 0xABCD, 1);
        assert!(recorder.pop().is_none());

        let rid_below: u16 = 4 << 8;
        msi_conn.target().signal_msi_with_rid(rid_below, 0xABCD, 2);
        assert!(recorder.pop().is_none());
    }

    #[test]
    fn signal_msi_with_rid_accepts_boundary_buses() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(5, 10);
        let msi_conn = MsiConnection::new(bus_range, 0);
        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());

        msi_conn.target().signal_msi_with_rid(5 << 8, 0x1000, 10);
        assert!(recorder.pop().is_some());

        msi_conn.target().signal_msi_with_rid(10 << 8, 0x2000, 20);
        assert!(recorder.pop().is_some());
    }

    #[test]
    fn route_enable_resolves_default_bdf() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(3, 8);
        let (irqfd, calls) = mock_irqfd(1);
        let msi_conn = MsiConnection::new(bus_range, 0x10);
        msi_conn.connect_irqfd(irqfd);

        let route = msi_conn.target().new_route().unwrap().unwrap();
        route.enable(0xFEE0_0000, 55);

        let log = calls[0].lock();
        assert_eq!(log.len(), 1);
        assert_eq!(
            log[0],
            RouteCall::Enable {
                address: 0xFEE0_0000,
                data: 55,
                devid: Some((3 << 8) | 0x10),
            }
        );
    }

    #[test]
    fn route_enable_with_rid_accepts_bus_in_range() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(5, 10);
        let (irqfd, calls) = mock_irqfd(1);
        let msi_conn = MsiConnection::new(bus_range, 0);
        msi_conn.connect_irqfd(irqfd);

        let route = msi_conn.target().new_route().unwrap().unwrap();
        let rid: u16 = (7 << 8) | 0x0A;
        route.enable_with_rid(rid, 0xBEEF, 77);

        let log = calls[0].lock();
        assert_eq!(log.len(), 1);
        assert_eq!(
            log[0],
            RouteCall::Enable {
                address: 0xBEEF,
                data: 77,
                devid: Some(rid as u32),
            }
        );
    }

    #[test]
    fn route_enable_with_rid_disables_when_bus_outside_range() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(5, 10);
        let (irqfd, calls) = mock_irqfd(1);
        let msi_conn = MsiConnection::new(bus_range, 0);
        msi_conn.connect_irqfd(irqfd);

        let route = msi_conn.target().new_route().unwrap().unwrap();
        let rid: u16 = 11 << 8;
        route.enable_with_rid(rid, 0xBEEF, 77);

        let log = calls[0].lock();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], RouteCall::Disable);
    }

    #[test]
    fn route_disable_forwards_and_consume_pending_polls_event() {
        let (irqfd, calls) = mock_irqfd(1);
        let msi_conn = MsiConnection::new(AssignedBusRange::new(), 0);
        msi_conn.connect_irqfd(irqfd);
        let route = msi_conn.target().new_route().unwrap().unwrap();

        assert!(!route.consume_pending());
        route.event().signal();
        assert!(route.consume_pending());

        route.disable();
        assert_eq!(*calls[0].lock(), vec![RouteCall::Disable]);
    }

    #[test]
    fn supports_direct_msi_tracks_irqfd_connection() {
        let msi_conn = MsiConnection::new(AssignedBusRange::new(), 0);
        assert!(!msi_conn.target().supports_direct_msi());
        assert!(msi_conn.target().new_route().is_none());

        // An irqfd with no routes left makes route creation fail.
        let (irqfd, _calls) = mock_irqfd(0);
        msi_conn.connect_irqfd(irqfd);
        assert!(msi_conn.target().supports_direct_msi());
        assert!(matches!(msi_conn.target().new_route(), Some(Err(_))));
    }

    #[test]
    fn disconnected_target_drops_msis_and_has_no_routes() {
        let target = MsiTarget::disconnected();
        assert!(!target.supports_direct_msi());
        assert!(target.new_route().is_none());
        // Nothing is connected; signaling must simply not panic.
        target.signal_msi(0xFEE0_0000, 1);
        target.signal_msi_with_rid(0, 0xFEE0_0000, 1);
    }

    #[test]
    fn derived_targets_observe_later_connect() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(2, 5);
        let msi_conn = MsiConnection::new(bus_range, 0);

        // Derive before connecting; the derived target shares the connection state.
        let derived = msi_conn.target().with_devfn(0x08);
        derived.signal_msi(0x1000, 1);

        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());
        derived.signal_msi(0x1000, 2);

        // Only the MSI sent after `connect` reaches the recorder.
        assert_eq!(recorder.pop(), Some((Some((2 << 8) | 0x08), 0x1000, 2)));
        assert!(recorder.pop().is_none());
    }

    #[test]
    fn with_devfn_derives_target_with_new_devfn() {
        let bus_range = AssignedBusRange::new();
        bus_range.set_bus_range(2, 5);
        let msi_conn = MsiConnection::new(bus_range, 0);
        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());

        let derived = msi_conn.target().with_devfn(0x18);
        derived.signal_msi(0x1000, 1);

        let (devid, _, _) = recorder.pop().unwrap();
        assert_eq!(devid, Some((2 << 8) | 0x18));
    }

    #[test]
    fn with_bus_range_derives_target_with_new_range() {
        let parent_range = AssignedBusRange::new();
        parent_range.set_bus_range(1, 20);
        let msi_conn = MsiConnection::new(parent_range, 0);
        let recorder = RecordingSignalMsi::new();
        msi_conn.connect(recorder.clone());

        let child_range = AssignedBusRange::new();
        child_range.set_bus_range(10, 15);
        let derived = msi_conn.target().with_bus_range(child_range, 0x08);
        derived.signal_msi(0x2000, 2);

        let (devid, _, _) = recorder.pop().unwrap();
        assert_eq!(devid, Some((10 << 8) | 0x08));

        // Bus 16 is inside the parent range but outside the child range.
        derived.signal_msi_with_rid(16 << 8, 0x3000, 3);
        assert!(recorder.pop().is_none());
    }
}