Skip to main content

pci_core/capabilities/
msix.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! MSI-X Capability.
5
6use super::PciCapability;
7use crate::msi::MsiTarget;
8use crate::spec::caps::CapabilityId;
9use crate::spec::caps::msix::MsixCapabilityHeader;
10use crate::spec::caps::msix::MsixTableEntryIdx;
11use inspect::Inspect;
12use inspect::InspectMut;
13use parking_lot::Mutex;
14use std::fmt::Debug;
15use std::sync::Arc;
16use vmcore::interrupt::Interrupt;
17use vmcore::irqfd::IrqFd;
18use vmcore::irqfd::IrqFdRoute;
19
20#[derive(Debug, Inspect)]
21struct MsiTableLocation {
22    #[inspect(hex)]
23    // MSI-X table offsets are, per spec, no larger than 32 bits.
24    offset: u32,
25    bar: u8,
26}
27
28impl MsiTableLocation {
29    fn new(bar: u8, offset: u32) -> Self {
30        assert!(bar < 6);
31        assert!(offset & 7 == 0);
32        Self { offset, bar }
33    }
34
35    fn read_u32(&self) -> u32 {
36        self.offset | self.bar as u32
37    }
38}
39
40#[derive(Inspect)]
41struct MsixCapability {
42    count: u16,
43    #[inspect(with = "|x| inspect::adhoc(|req| x.lock().inspect_mut(req))")]
44    state: Arc<Mutex<MsixState>>,
45    config_table_location: MsiTableLocation,
46    pending_bits_location: MsiTableLocation,
47}
48
49impl PciCapability for MsixCapability {
50    fn label(&self) -> &str {
51        "msi-x"
52    }
53
54    fn capability_id(&self) -> CapabilityId {
55        CapabilityId::MSIX
56    }
57
58    fn len(&self) -> usize {
59        12
60    }
61
62    fn read_u32(&self, offset: u16) -> u32 {
63        match MsixCapabilityHeader(offset) {
64            MsixCapabilityHeader::CONTROL_CAPS => {
65                CapabilityId::MSIX.0 as u32
66                    | ((self.count as u32 - 1) | if self.state.lock().enabled { 0x8000 } else { 0 })
67                        << 16
68            }
69            MsixCapabilityHeader::OFFSET_TABLE => self.config_table_location.read_u32(),
70            MsixCapabilityHeader::OFFSET_PBA => self.pending_bits_location.read_u32(),
71            _ => panic!("Unreachable read offset {}", offset),
72        }
73    }
74
75    fn write_u32(&mut self, offset: u16, val: u32) {
76        match MsixCapabilityHeader(offset) {
77            MsixCapabilityHeader::CONTROL_CAPS => {
78                let enabled = val & 0x80000000 != 0;
79                let mut state = self.state.lock();
80                let was_enabled = state.enabled;
81                state.enabled = enabled;
82                if was_enabled && !enabled {
83                    for entry in &mut state.vectors {
84                        if entry.is_enabled(true) {
85                            entry.msi.disable();
86                        }
87                    }
88                } else if enabled && !was_enabled {
89                    for entry in &mut state.vectors {
90                        if entry.is_enabled(true) {
91                            entry.msi.enable(
92                                entry.state.address,
93                                entry.state.data,
94                                entry.state.is_pending,
95                            );
96                            entry.state.is_pending = false;
97                        }
98                    }
99                }
100            }
101            MsixCapabilityHeader::OFFSET_TABLE | MsixCapabilityHeader::OFFSET_PBA => {
102                tracelimit::warn_ratelimited!(
103                    "Unexpected write offset {:?}",
104                    MsixCapabilityHeader(offset)
105                )
106            }
107            _ => panic!("Unreachable write offset {}", offset),
108        }
109    }
110
111    fn reset(&mut self) {
112        let mut state = self.state.lock();
113        state.enabled = false;
114        for vector in &mut state.vectors {
115            vector.msi.disable();
116            vector.state = EntryState::new();
117        }
118    }
119}
120
121#[derive(Clone, Inspect, Debug)]
122pub(crate) struct MsiInterrupt(#[inspect(flatten)] Arc<Mutex<MsiInterruptInner>>);
123
124#[derive(Inspect)]
125struct MsiInterruptInner {
126    target: MsiTarget,
127    /// Optional kernel-mediated route for fast interrupt delivery.
128    /// When present, `enable()`/`disable()` automatically program the
129    /// kernel's MSI routing alongside the userspace `MsiTarget` path.
130    #[inspect(skip)]
131    route: Option<Box<dyn IrqFdRoute>>,
132    pending: bool,
133    enabled: bool,
134    address: u64,
135    data: u32,
136}
137
138impl Debug for MsiInterruptInner {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("MsiInterruptInner")
141            .field("pending", &self.pending)
142            .field("enabled", &self.enabled)
143            .field("address", &self.address)
144            .field("data", &self.data)
145            .field("has_route", &self.route.is_some())
146            .finish()
147    }
148}
149
150impl MsiInterrupt {
151    pub fn new(target: MsiTarget) -> Self {
152        Self(Arc::new(Mutex::new(MsiInterruptInner {
153            target,
154            route: None,
155            pending: false,
156            enabled: false,
157            address: 0,
158            data: 0,
159        })))
160    }
161
162    pub fn enable(&self, address: u64, data: u32, set_pending: bool) {
163        let mut state = self.0.lock();
164        state.pending |= set_pending;
165        state.address = address;
166        state.data = data;
167        state.enabled = true;
168
169        // Program the kernel route if present.
170        if state.route.is_some() {
171            let route = state.route.as_ref().unwrap();
172            if route.consume_pending() {
173                state.pending = true;
174            }
175        }
176        if let Some(route) = &state.route {
177            if address != 0 || data != 0 {
178                if let Err(e) = route.set_msi(address, data) {
179                    tracelimit::warn_ratelimited!(
180                        error = ?e,
181                        "failed to program MSI-X route on enable"
182                    );
183                }
184            } else if let Err(e) = route.clear_msi() {
185                tracelimit::warn_ratelimited!(
186                    error = ?e,
187                    "failed to clear MSI-X route on enable"
188                );
189            }
190        }
191
192        if state.pending {
193            state.target.signal_msi(0, address, data);
194            state.pending = false;
195        }
196    }
197
198    pub fn disable(&self) {
199        let mut state = self.0.lock();
200        state.enabled = false;
201        if let Some(route) = &state.route {
202            if let Err(e) = route.mask() {
203                tracelimit::warn_ratelimited!(
204                    error = ?e,
205                    "failed to mask MSI-X route on disable"
206                );
207            }
208        }
209    }
210
211    pub fn drain_pending(&self) -> bool {
212        let mut state = self.0.lock();
213        if let Some(route) = &state.route {
214            state.pending |= route.consume_pending();
215        }
216        let was_pending = state.pending;
217        state.pending = false;
218        was_pending
219    }
220
221    /// Install a kernel-mediated route for fast interrupt delivery.
222    pub(crate) fn set_route(&self, route: Box<dyn IrqFdRoute>) {
223        self.0.lock().route = Some(route);
224    }
225
226    /// Remove the kernel-mediated route.
227    pub(crate) fn clear_route(&self) {
228        self.0.lock().route = None;
229    }
230
231    pub fn interrupt(&self) -> Interrupt {
232        let state = self.0.clone();
233        Interrupt::from_fn(move || {
234            let mut state = state.lock();
235            if state.enabled {
236                state.target.signal_msi(0, state.address, state.data);
237            } else {
238                state.pending = true;
239            }
240        })
241    }
242}
243
244struct MsixMessageTableEntry {
245    msi: MsiInterrupt,
246    state: EntryState,
247}
248
249impl InspectMut for MsixMessageTableEntry {
250    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
251        req.respond()
252            .hex("address", self.state.address)
253            .hex("data", self.state.data)
254            .hex("control", self.state.control)
255            .field("enabled", self.state.control & 1 == 0)
256            .field("is_pending", self.check_is_pending(true));
257    }
258}
259
/// Guest-visible contents of one MSI-X table entry, plus its pending bit.
#[derive(Debug)]
struct EntryState {
    /// 64-bit message address.
    address: u64,
    /// Message data.
    data: u32,
    /// Vector control dword; bit 0 set means the vector is masked.
    control: u32,
    /// An interrupt is waiting for the vector to become enabled.
    is_pending: bool,
}

impl EntryState {
    /// Returns the initial state of an entry: zeroed message, masked, and
    /// nothing pending.
    fn new() -> Self {
        // Entries start out masked.
        const MASKED: u32 = 1;
        Self {
            control: MASKED,
            is_pending: false,
            address: 0,
            data: 0,
        }
    }
}
278
279impl MsixMessageTableEntry {
280    fn new(msi: MsiInterrupt) -> Self {
281        Self {
282            msi,
283            state: EntryState::new(),
284        }
285    }
286
287    fn read_u32(&self, offset: u64) -> u32 {
288        match MsixTableEntryIdx(offset) {
289            MsixTableEntryIdx::MSG_ADDR_LO => self.state.address as u32,
290            MsixTableEntryIdx::MSG_ADDR_HI => (self.state.address >> 32) as u32,
291            MsixTableEntryIdx::MSG_DATA => self.state.data,
292            MsixTableEntryIdx::VECTOR_CTL => self.state.control,
293            _ => panic!("Unexpected read offset {}", offset),
294        }
295    }
296
297    fn write_u32(&mut self, offset: u64, val: u32) {
298        match MsixTableEntryIdx(offset) {
299            MsixTableEntryIdx::MSG_ADDR_LO => {
300                self.state.address = (self.state.address & 0xffffffff00000000) | val as u64
301            }
302            MsixTableEntryIdx::MSG_ADDR_HI => {
303                self.state.address = (val as u64) << 32 | self.state.address & 0xffffffff
304            }
305            MsixTableEntryIdx::MSG_DATA => self.state.data = val,
306            MsixTableEntryIdx::VECTOR_CTL => self.state.control = val,
307            _ => panic!("Unexpected write offset {}", offset),
308        }
309    }
310
311    fn is_enabled(&self, global_enabled: bool) -> bool {
312        global_enabled && self.state.control & 1 == 0
313    }
314
315    fn check_is_pending(&mut self, global_enabled: bool) -> bool {
316        if !self.state.is_pending && !self.is_enabled(global_enabled) {
317            self.state.is_pending = self.msi.drain_pending();
318        }
319        self.state.is_pending
320    }
321}
322
323#[derive(InspectMut)]
324struct MsixState {
325    enabled: bool,
326    #[inspect(mut, with = "inspect_entries")]
327    vectors: Vec<MsixMessageTableEntry>,
328}
329
330fn inspect_entries(entries: &mut [MsixMessageTableEntry]) -> impl '_ + InspectMut {
331    inspect::adhoc_mut(|req| {
332        let mut resp = req.respond();
333        for (i, entry) in entries.iter_mut().enumerate() {
334            resp.field_mut(&i.to_string(), entry);
335        }
336    })
337}
338
/// Emulator for the hardware-level interface required to configure and trigger
/// MSI-X interrupts on a PCI device.
///
/// Clones share the same underlying vector state.
#[derive(Clone)]
pub struct MsixEmulator {
    /// Vector table and global enable, shared behind a mutex.
    state: Arc<Mutex<MsixState>>,
    // PBA offsets, per spec, are no larger than 32 bits.
    /// Byte offset of the pending bit array within the BAR; everything below
    /// it is treated as the vector table.
    pending_bits_offset: u32,
    /// Length of the pending bit array in 32-bit words.
    pending_bits_dword_count: u16,
}
348
349impl MsixEmulator {
350    /// Create a new [`MsixEmulator`] instance, along with with its associated
351    /// [`PciCapability`] structure.
352    ///
353    /// This implementation of MSI-X expects a dedicated BAR to store the vector
354    /// and pending tables.
355    ///
356    /// * * *
357    ///
358    /// DEVNOTE: This current implementation of MSI-X isn't particularly
359    /// "flexible" with respect to the various ways the PCI spec allows MSI-X to
360    /// be implemented. e.g: it uses a shared BAR for the table and BPA, with
361    /// fixed offsets into the BAR for both of those tables. It would be nice to
362    /// re-visit this code and make it more flexible.
363    pub fn new(bar: u8, count: u16, msi_target: &MsiTarget) -> (Self, impl PciCapability + use<>) {
364        let state = MsixState {
365            enabled: false,
366            vectors: (0..count)
367                .map(|_| MsixMessageTableEntry::new(MsiInterrupt::new(msi_target.clone())))
368                .collect(),
369        };
370        let state = Arc::new(Mutex::new(state));
371        let pending_bits_offset = count as u32 * 16;
372        (
373            Self {
374                state: state.clone(),
375                pending_bits_offset,
376                pending_bits_dword_count: count.div_ceil(32),
377            },
378            MsixCapability {
379                count,
380                state,
381                config_table_location: MsiTableLocation::new(bar, 0),
382                pending_bits_location: MsiTableLocation::new(bar, pending_bits_offset),
383            },
384        )
385    }
386
387    /// Return the total length of the MSI-X BAR
388    /// (Actually, the notion that there is an "MSI-X BAR" is an issue to fix sometime.
389    /// MSI-X tables are often in the same bar as other things.)
390    pub fn bar_len(&self) -> u64 {
391        self.pending_bits_offset as u64 + self.pending_bits_dword_count as u64 * 4
392    }
393
394    /// Read a `u32` from the MSI-X BAR at the given offset.
395    pub fn read_u32(&self, offset: u64) -> u32 {
396        let mut state = self.state.lock();
397        let state: &mut MsixState = &mut state;
398        if offset < self.pending_bits_offset as u64 {
399            let index = offset / 16;
400            if let Some(entry) = state.vectors.get(index as usize) {
401                return entry.read_u32(offset & 0xf);
402            }
403        } else {
404            let dword = (offset - self.pending_bits_offset as u64) / 4;
405            let start = dword as usize * 32;
406            if start < state.vectors.len() {
407                let end = (start + 32).min(state.vectors.len());
408                let mut val = 0u32;
409                for (i, entry) in state.vectors[start..end].iter_mut().enumerate() {
410                    if entry.check_is_pending(state.enabled) {
411                        val |= 1 << i;
412                    }
413                }
414                return val;
415            }
416        }
417        tracelimit::warn_ratelimited!(offset, "Unexpected read offset");
418        0
419    }
420
421    /// Write a `u32` to the MSI-X BAR at the given offset.
422    pub fn write_u32(&mut self, offset: u64, val: u32) {
423        let mut state = self.state.lock();
424        if offset < self.pending_bits_offset as u64 {
425            let index = offset / 16;
426            let global = state.enabled;
427            if let Some(entry) = state.vectors.get_mut(index as usize) {
428                let was_enabled = entry.is_enabled(global);
429                entry.write_u32(offset & 0xf, val);
430                let is_enabled = entry.is_enabled(global);
431                if is_enabled && !was_enabled {
432                    // Vector just unmasked.
433                    entry.msi.enable(
434                        entry.state.address,
435                        entry.state.data,
436                        entry.state.is_pending,
437                    );
438                    entry.state.is_pending = false;
439                } else if was_enabled && !is_enabled {
440                    // Vector just masked.
441                    entry.msi.disable();
442                } else if is_enabled {
443                    // Still enabled. addr/data may have changed — enable()
444                    // will update the route if present.
445                    entry
446                        .msi
447                        .enable(entry.state.address, entry.state.data, false);
448                }
449                return;
450            }
451        } else if offset - (self.pending_bits_offset as u64)
452            < self.pending_bits_dword_count as u64 * 4
453        {
454            return;
455        }
456        tracelimit::warn_ratelimited!(offset, "Unexpected write offset");
457    }
458
459    /// Return an [`Interrupt`] associated with the particular MSI-X vector, or
460    /// `None` if the index is out of bounds.
461    pub fn interrupt(&self, index: u16) -> Option<Interrupt> {
462        Some(
463            self.state
464                .lock()
465                .vectors
466                .get_mut(index as usize)?
467                .msi
468                .interrupt(),
469        )
470    }
471
472    #[cfg(test)]
473    fn clear_pending_bit(&self, index: u8) {
474        let mut state = self.state.lock();
475        state.vectors[index as usize].state.is_pending = false;
476    }
477
478    /// Sets the pending bit for the given vector index.
479    ///
480    /// Used by device passthrough (e.g., VFIO with irqfd) to record that an
481    /// interrupt arrived while the vector was masked, so PBA reads return
482    /// the correct pending state.
483    pub fn set_pending_bit(&self, index: u16) {
484        let mut state = self.state.lock();
485        if let Some(entry) = state.vectors.get_mut(index as usize) {
486            entry.state.is_pending = true;
487        } else {
488            tracelimit::warn_ratelimited!(
489                index,
490                count = state.vectors.len(),
491                "set_pending_bit: vector index out of range"
492            );
493        }
494    }
495
496    /// Enable kernel-mediated interrupt delivery for all vectors.
497    ///
498    /// Creates one irqfd route per vector using the provided [`IrqFd`]
499    /// interface. The `register` callback receives references to the events
500    /// for all vectors. The caller uses this to pass events to the
501    /// interrupt source (e.g., VFIO `map_msix` or vhost-user
502    /// `SET_VRING_CALL`). After the callback returns, the routes are
503    /// installed in the emulator. When the guest programs MSI-X table
504    /// entries, the emulator automatically updates the kernel's MSI routing.
505    ///
506    /// Call [`disable_irqfd`](Self::disable_irqfd) to tear down the routes.
507    pub fn enable_irqfd(
508        &self,
509        irqfd: &dyn IrqFd,
510        register: impl FnOnce(&[&pal_event::Event]) -> anyhow::Result<()>,
511    ) -> anyhow::Result<()> {
512        let state = self.state.lock();
513        let mut routes: Vec<Box<dyn IrqFdRoute>> = Vec::with_capacity(state.vectors.len());
514        for _ in state.vectors.iter() {
515            routes.push(irqfd.new_irqfd_route()?);
516        }
517
518        // Collect event references while routes are still in our local Vec.
519        let events: Vec<&pal_event::Event> = routes.iter().map(|r| r.event()).collect();
520        register(&events)?;
521
522        // Move routes into the emulator's MsiInterrupts.
523        for (entry, route) in state.vectors.iter().zip(routes) {
524            entry.msi.set_route(route);
525        }
526        Ok(())
527    }
528
529    /// Tear down kernel-mediated interrupt delivery.
530    ///
531    /// Drops all irqfd routes, unregistering them from the hypervisor and
532    /// freeing GSI allocations.
533    pub fn disable_irqfd(&self) {
534        let state = self.state.lock();
535        for entry in &state.vectors {
536            entry.msi.clear_route();
537        }
538    }
539}
540
541mod save_restore {
542    use super::*;
543    use thiserror::Error;
544    use vmcore::save_restore::RestoreError;
545    use vmcore::save_restore::SaveError;
546    use vmcore::save_restore::SaveRestore;
547
548    mod state {
549        use mesh::payload::Protobuf;
550        use vmcore::save_restore::SavedStateRoot;
551
552        #[derive(Debug, Protobuf)]
553        #[mesh(package = "pci.caps.msix")]
554        pub struct SavedMsixMessageTableEntryState {
555            #[mesh(1)]
556            pub address: u64,
557            #[mesh(2)]
558            pub data: u32,
559            #[mesh(3)]
560            pub control: u32,
561            #[mesh(4)]
562            pub is_pending: bool,
563        }
564
565        #[derive(Debug, Protobuf, SavedStateRoot)]
566        #[mesh(package = "pci.caps.msix")]
567        pub struct SavedState {
568            #[mesh(2)]
569            pub enabled: bool,
570            #[mesh(3)]
571            pub vectors: Vec<SavedMsixMessageTableEntryState>,
572        }
573    }
574
575    #[derive(Debug, Error)]
576    enum MsixRestoreError {
577        #[error("mismatched vector lengths: current:{0}, saved:{1}")]
578        MismatchedTableLengths(usize, usize),
579    }
580
581    impl SaveRestore for MsixCapability {
582        type SavedState = state::SavedState;
583
584        fn save(&mut self) -> Result<Self::SavedState, SaveError> {
585            let state = self.state.lock();
586            let saved_state = state::SavedState {
587                enabled: state.enabled,
588                vectors: {
589                    state
590                        .vectors
591                        .iter()
592                        .map(|vec| {
593                            let EntryState {
594                                address,
595                                data,
596                                control,
597                                is_pending,
598                            } = vec.state;
599
600                            state::SavedMsixMessageTableEntryState {
601                                address,
602                                data,
603                                control,
604                                is_pending,
605                            }
606                        })
607                        .collect()
608                },
609            };
610            Ok(saved_state)
611        }
612
613        fn restore(&mut self, state: Self::SavedState) -> Result<(), RestoreError> {
614            let state::SavedState { enabled, vectors } = state;
615
616            let mut state = self.state.lock();
617            state.enabled = enabled;
618
619            if vectors.len() != state.vectors.len() {
620                return Err(RestoreError::InvalidSavedState(
621                    MsixRestoreError::MismatchedTableLengths(vectors.len(), state.vectors.len())
622                        .into(),
623                ));
624            }
625
626            for (new_vec, vec) in vectors.into_iter().zip(state.vectors.iter_mut()) {
627                vec.state = EntryState {
628                    address: new_vec.address,
629                    data: new_vec.data,
630                    control: new_vec.control,
631                    is_pending: new_vec.is_pending,
632                }
633            }
634
635            Ok(())
636        }
637    }
638}
639
640#[cfg(test)]
641mod tests {
642    use super::*;
643    use crate::{msi::MsiConnection, test_helpers::TestPciInterruptController};
644
645    #[test]
646    fn msix_check() {
647        let msi_conn = MsiConnection::new();
648        let (mut msix, mut cap) = MsixEmulator::new(2, 64, msi_conn.target());
649        let msi_controller = TestPciInterruptController::new();
650        msi_conn.connect(msi_controller.signal_msi());
651        // check capabilities
652        assert_eq!(cap.read_u32(0), 0x3f0011);
653        assert_eq!(cap.read_u32(4), 2);
654        assert_eq!(cap.read_u32(8), 0x402);
655        cap.write_u32(0, 0xffffffff);
656        assert_eq!(cap.read_u32(0), 0x803f0011);
657        // check BAR
658        // Vector[0]
659        assert_eq!(msix.read_u32(0), 0);
660        assert_eq!(msix.read_u32(4), 0);
661        assert_eq!(msix.read_u32(8), 0);
662        assert_eq!(msix.read_u32(12), 1);
663        msix.write_u32(0, 0x12345678);
664        msix.write_u32(4, 0x9abcdef0);
665        msix.write_u32(8, 0x123);
666        msix.write_u32(12, 0x456);
667        assert_eq!(msix.read_u32(0), 0x12345678);
668        assert_eq!(msix.read_u32(4), 0x9abcdef0);
669        assert_eq!(msix.read_u32(8), 0x123);
670        assert_eq!(msix.read_u32(12), 0x456);
671        // Vector[63]
672        assert_eq!(msix.read_u32(0x3f0), 0);
673        assert_eq!(msix.read_u32(0x3f4), 0);
674        assert_eq!(msix.read_u32(0x3f8), 0);
675        assert_eq!(msix.read_u32(0x3fc), 1);
676        msix.write_u32(0x3f0, 0x12345678);
677        msix.write_u32(0x3f4, 0x9abcdef0);
678        msix.write_u32(0x3f8, 0x123);
679        msix.write_u32(0x3fc, 0x456);
680        assert_eq!(msix.read_u32(0x3f0), 0x12345678);
681        assert_eq!(msix.read_u32(0x3f4), 0x9abcdef0);
682        assert_eq!(msix.read_u32(0x3f8), 0x123);
683        assert_eq!(msix.read_u32(0x3fc), 0x456);
684        // Pending Bit Array
685        assert_eq!(msix.read_u32(0x400), 0);
686        assert_eq!(msix.read_u32(0x404), 0);
687        msix.set_pending_bit(1);
688        assert_eq!(msix.read_u32(0x400), 2);
689        assert_eq!(msix.read_u32(0x404), 0);
690        msix.set_pending_bit(33);
691        assert_eq!(msix.read_u32(0x400), 2);
692        assert_eq!(msix.read_u32(0x404), 2);
693        msix.set_pending_bit(63);
694        msix.set_pending_bit(31);
695        assert_eq!(msix.read_u32(0x400), 0x80000002);
696        assert_eq!(msix.read_u32(0x404), 0x80000002);
697        msix.clear_pending_bit(1);
698        assert_eq!(msix.read_u32(0x400), 0x80000000);
699        assert_eq!(msix.read_u32(0x404), 0x80000002);
700    }
701
702    use pal_event::Event;
703    use parking_lot::Mutex;
704
705    /// Record of a call made to a [`MockIrqFdRoute`].
706    #[derive(Debug, Clone, PartialEq)]
707    enum RouteCall {
708        SetMsi { address: u64, data: u32 },
709        ClearMsi,
710        Mask,
711        ConsumePending,
712    }
713
714    /// Mock irqfd route that records calls.
715    struct MockIrqFdRoute {
716        event: Event,
717        calls: Arc<Mutex<Vec<RouteCall>>>,
718        pending: Arc<Mutex<bool>>,
719    }
720
721    impl IrqFdRoute for MockIrqFdRoute {
722        fn event(&self) -> &Event {
723            &self.event
724        }
725
726        fn set_msi(&self, address: u64, data: u32) -> anyhow::Result<()> {
727            self.calls.lock().push(RouteCall::SetMsi { address, data });
728            Ok(())
729        }
730
731        fn clear_msi(&self) -> anyhow::Result<()> {
732            self.calls.lock().push(RouteCall::ClearMsi);
733            Ok(())
734        }
735
736        fn mask(&self) -> anyhow::Result<()> {
737            self.calls.lock().push(RouteCall::Mask);
738            Ok(())
739        }
740
741        fn unmask(&self) -> anyhow::Result<()> {
742            Ok(())
743        }
744
745        fn consume_pending(&self) -> bool {
746            self.calls.lock().push(RouteCall::ConsumePending);
747            let mut p = self.pending.lock();
748            let was = *p;
749            *p = false;
750            was
751        }
752    }
753
754    /// Mock IrqFd that creates MockIrqFdRoutes and returns shared call logs.
755    struct MockIrqFd {
756        routes: Mutex<Vec<(Arc<Mutex<Vec<RouteCall>>>, Arc<Mutex<bool>>)>>,
757    }
758
759    impl MockIrqFd {
760        fn new(count: usize) -> (Self, Vec<Arc<Mutex<Vec<RouteCall>>>>, Vec<Arc<Mutex<bool>>>) {
761            let mut call_logs = Vec::new();
762            let mut pendings = Vec::new();
763            let mut route_params = Vec::new();
764            for _ in 0..count {
765                let calls = Arc::new(Mutex::new(Vec::new()));
766                let pending = Arc::new(Mutex::new(false));
767                call_logs.push(calls.clone());
768                pendings.push(pending.clone());
769                route_params.push((calls, pending));
770            }
771            (
772                Self {
773                    routes: Mutex::new(route_params),
774                },
775                call_logs,
776                pendings,
777            )
778        }
779    }
780
781    impl IrqFd for MockIrqFd {
782        fn new_irqfd_route(&self) -> anyhow::Result<Box<dyn IrqFdRoute>> {
783            let (calls, pending) = self.routes.lock().remove(0);
784            Ok(Box::new(MockIrqFdRoute {
785                event: Event::new(),
786                calls,
787                pending,
788            }))
789        }
790    }
791
792    #[test]
793    fn route_set_msi_on_unmask() {
794        let msi_conn = MsiConnection::new();
795        let (mut msix, mut cap) = MsixEmulator::new(2, 2, msi_conn.target());
796        let msi_controller = TestPciInterruptController::new();
797        msi_conn.connect(msi_controller.signal_msi());
798
799        let (mock_irqfd, calls, _pendings) = MockIrqFd::new(2);
800        msix.enable_irqfd(&mock_irqfd, |_| Ok(())).unwrap();
801
802        // Enable MSI-X globally.
803        cap.write_u32(0, 0x80000000);
804
805        // Program vector 0 addr/data (still masked — control starts at 1).
806        msix.write_u32(0, 0xFEE00000); // addr_lo
807        msix.write_u32(4, 0); // addr_hi
808        msix.write_u32(8, 0x42); // data
809
810        // No set_msi yet because vector is still masked.
811        assert!(
812            !calls[0]
813                .lock()
814                .iter()
815                .any(|c| matches!(c, RouteCall::SetMsi { .. }))
816        );
817
818        // Unmask vector 0 (write control = 0).
819        calls[0].lock().clear();
820        msix.write_u32(12, 0);
821
822        // Should have called consume_pending then set_msi.
823        let log = calls[0].lock().clone();
824        assert!(log.contains(&RouteCall::ConsumePending));
825        assert!(log.contains(&RouteCall::SetMsi {
826            address: 0xFEE00000,
827            data: 0x42
828        }));
829    }
830
831    #[test]
832    fn route_mask_on_vector_mask() {
833        let msi_conn = MsiConnection::new();
834        let (mut msix, mut cap) = MsixEmulator::new(2, 2, msi_conn.target());
835        let msi_controller = TestPciInterruptController::new();
836        msi_conn.connect(msi_controller.signal_msi());
837
838        let (mock_irqfd, calls, _pendings) = MockIrqFd::new(2);
839        msix.enable_irqfd(&mock_irqfd, |_| Ok(())).unwrap();
840
841        // Enable MSI-X, program and unmask vector 0.
842        cap.write_u32(0, 0x80000000);
843        msix.write_u32(0, 0xFEE00000);
844        msix.write_u32(8, 0x42);
845        msix.write_u32(12, 0); // unmask
846
847        calls[0].lock().clear();
848
849        // Mask vector 0 (write control = 1).
850        msix.write_u32(12, 1);
851
852        let log = calls[0].lock().clone();
853        assert!(log.contains(&RouteCall::Mask));
854    }
855
856    #[test]
857    fn route_global_disable_masks_all() {
858        let msi_conn = MsiConnection::new();
859        let (mut msix, mut cap) = MsixEmulator::new(2, 2, msi_conn.target());
860        let msi_controller = TestPciInterruptController::new();
861        msi_conn.connect(msi_controller.signal_msi());
862
863        let (mock_irqfd, calls, _pendings) = MockIrqFd::new(2);
864        msix.enable_irqfd(&mock_irqfd, |_| Ok(())).unwrap();
865
866        // Enable, program, and unmask both vectors.
867        cap.write_u32(0, 0x80000000);
868        for v in 0..2u64 {
869            msix.write_u32(v * 16, 0xFEE00000);
870            msix.write_u32(v * 16 + 8, (v + 1) as u32);
871            msix.write_u32(v * 16 + 12, 0); // unmask
872        }
873        calls[0].lock().clear();
874        calls[1].lock().clear();
875
876        // Disable MSI-X globally.
877        cap.write_u32(0, 0);
878
879        // Both vectors should have been masked.
880        assert!(calls[0].lock().contains(&RouteCall::Mask));
881        assert!(calls[1].lock().contains(&RouteCall::Mask));
882    }
883
884    #[test]
885    fn route_consume_pending_on_pba_read() {
886        let msi_conn = MsiConnection::new();
887        let (msix, mut cap) = MsixEmulator::new(2, 2, msi_conn.target());
888        let msi_controller = TestPciInterruptController::new();
889        msi_conn.connect(msi_controller.signal_msi());
890
891        let (mock_irqfd, calls, pendings) = MockIrqFd::new(2);
892        msix.enable_irqfd(&mock_irqfd, |_| Ok(())).unwrap();
893
894        // Enable MSI-X but leave vectors masked (control = 1 by default).
895        cap.write_u32(0, 0x80000000);
896
897        // Simulate a pending interrupt on vector 0.
898        *pendings[0].lock() = true;
899        calls[0].lock().clear();
900
901        // PBA is at offset = vector_count * 16 = 32.
902        let pba = msix.read_u32(32);
903
904        // Should have called consume_pending and returned bit 0 set.
905        assert!(calls[0].lock().contains(&RouteCall::ConsumePending));
906        assert_eq!(pba & 1, 1);
907    }
908
909    #[test]
910    fn route_set_msi_on_addr_data_change_while_unmasked() {
911        let msi_conn = MsiConnection::new();
912        let (mut msix, mut cap) = MsixEmulator::new(2, 1, msi_conn.target());
913        let msi_controller = TestPciInterruptController::new();
914        msi_conn.connect(msi_controller.signal_msi());
915
916        let (mock_irqfd, calls, _pendings) = MockIrqFd::new(1);
917        msix.enable_irqfd(&mock_irqfd, |_| Ok(())).unwrap();
918
919        // Enable, program, unmask.
920        cap.write_u32(0, 0x80000000);
921        msix.write_u32(0, 0xFEE00000);
922        msix.write_u32(8, 0x42);
923        msix.write_u32(12, 0);
924        calls[0].lock().clear();
925
926        // Change data while still unmasked.
927        msix.write_u32(8, 0x99);
928
929        let log = calls[0].lock().clone();
930        assert!(log.contains(&RouteCall::SetMsi {
931            address: 0xFEE00000,
932            data: 0x99
933        }));
934    }
935}