virtio/transport/
pci.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! PCI transport for virtio devices
5
6use self::capabilities::*;
7use crate::QUEUE_MAX_SIZE;
8use crate::QueueResources;
9use crate::Resources;
10use crate::VirtioDevice;
11use crate::VirtioDoorbells;
12use crate::queue::QueueParams;
13use crate::spec::pci::VIRTIO_PCI_COMMON_CFG_SIZE;
14use crate::spec::pci::VIRTIO_PCI_DEVICE_ID_BASE;
15use crate::spec::pci::VIRTIO_VENDOR_ID;
16use crate::spec::pci::VirtioPciCapType;
17use crate::spec::pci::VirtioPciCommonCfg;
18use crate::spec::*;
19use chipset_device::ChipsetDevice;
20use chipset_device::io::IoResult;
21use chipset_device::mmio::MmioIntercept;
22use chipset_device::mmio::RegisterMmioIntercept;
23use chipset_device::pci::PciConfigSpace;
24use chipset_device::poll_device::PollDevice;
25use device_emulators::ReadWriteRequestType;
26use device_emulators::read_as_u32_chunks;
27use device_emulators::write_as_u32_chunks;
28use guestmem::DoorbellRegistration;
29use guestmem::MappedMemoryRegion;
30use guestmem::MemoryMapper;
31use inspect::InspectMut;
32use parking_lot::Mutex;
33use pci_core::PciInterruptPin;
34use pci_core::capabilities::PciCapability;
35use pci_core::capabilities::ReadOnlyCapability;
36use pci_core::capabilities::msix::MsixEmulator;
37use pci_core::cfg_space_emu::BarMemoryKind;
38use pci_core::cfg_space_emu::ConfigSpaceType0Emulator;
39use pci_core::cfg_space_emu::DeviceBars;
40use pci_core::cfg_space_emu::IntxInterrupt;
41use pci_core::msi::MsiTarget;
42use pci_core::spec::hwid::ClassCode;
43use pci_core::spec::hwid::HardwareIds;
44use pci_core::spec::hwid::ProgrammingInterface;
45use pci_core::spec::hwid::Subclass;
46use std::io;
47use std::sync::Arc;
48use vmcore::device_state::ChangeDeviceState;
49use vmcore::interrupt::Interrupt;
50use vmcore::line_interrupt::LineInterrupt;
51use vmcore::save_restore::NoSavedState;
52use vmcore::save_restore::RestoreError;
53use vmcore::save_restore::SaveError;
54use vmcore::save_restore::SaveRestore;
55
/// What kind of PCI interrupts [`VirtioPciDevice`] should use.
pub enum PciInterruptModel<'a> {
    /// Deliver interrupts through MSI-X targeting the given MSI target.
    ///
    /// The MSI-X table is exposed to the guest through BAR2.
    Msix(&'a MsiTarget),
    /// Deliver interrupts through a level-triggered legacy INTx line attached
    /// to the given interrupt pin.
    IntX(PciInterruptPin, LineInterrupt),
}
61
/// The interrupt mechanism selected at construction time, holding the state
/// needed to raise interrupts toward the guest.
enum InterruptKind {
    /// MSI-X emulator whose table is mapped through BAR2.
    Msix(MsixEmulator),
    /// INTx line registered with the PCI config space emulator; shared with
    /// the per-queue interrupt closures handed to the device.
    IntX(Arc<IntxInterrupt>),
}
66
/// BAR0 layout: common cfg is at offset 0, followed by notify, ISR, and
/// device-specific config regions.
///
/// Offset of the queue-notify register, directly after the common config.
const BAR0_NOTIFY_OFFSET: u16 = VIRTIO_PCI_COMMON_CFG_SIZE;
/// Size of the notify region. All queues share this single register because
/// the notify capability is advertised with an offset multiplier of 0.
const BAR0_NOTIFY_SIZE: u16 = 4;
/// Offset of the read-to-clear ISR status register.
const BAR0_ISR_OFFSET: u16 = BAR0_NOTIFY_OFFSET + BAR0_NOTIFY_SIZE;
/// Size of the ISR status region.
const BAR0_ISR_SIZE: u16 = 4;
/// Offset of the device-specific config registers, which run to the end of
/// BAR0.
const BAR0_DEVICE_CFG_OFFSET: u16 = BAR0_ISR_OFFSET + BAR0_ISR_SIZE;
74
/// Run a virtio device over PCI
///
/// Emulates the virtio PCI transport (common config, notify, ISR, and
/// device-specific config in BAR0) on top of a [`VirtioDevice`].
#[derive(InspectMut)]
pub struct VirtioPciDevice {
    /// The wrapped virtio device implementation.
    #[inspect(mut)]
    device: Box<dyn VirtioDevice>,
    /// Feature bits offered to the driver: the device's own features plus
    /// the transport-provided ring_event_idx, ring_indirect_desc, and
    /// version_1 bits.
    #[inspect(skip)]
    device_feature: VirtioDeviceFeatures,
    /// Which 32-bit bank of `device_feature` the DEVICE_FEATURE register shows.
    #[inspect(hex)]
    device_feature_select: u32,
    /// Feature bits accepted by the driver, masked by `device_feature`.
    #[inspect(skip)]
    driver_feature: VirtioDeviceFeatures,
    /// Which 32-bit bank of `driver_feature` the DRIVER_FEATURE register targets.
    #[inspect(hex)]
    driver_feature_select: u32,
    /// MSI-X vector used for configuration-change interrupts.
    msix_config_vector: u16,
    /// Queue addressed by the queue_* registers; written through the high 16
    /// bits of the DEVICE_STATUS register.
    queue_select: u32,
    /// One event per queue, signaled when the driver notifies that queue.
    #[inspect(skip)]
    events: Vec<pal_event::Event>,
    /// Per-queue parameters programmed by the driver and handed to the
    /// device when it is enabled.
    #[inspect(skip)]
    queues: Vec<QueueParams>,
    /// Per-queue MSI-X vectors programmed by the driver.
    #[inspect(skip)]
    msix_vectors: Vec<u16>,
    /// ISR status bits (bit 0: queue interrupt in INTx mode, bit 1: config
    /// change). Cleared on read; shared with the INTx queue interrupt closures.
    #[inspect(skip)]
    interrupt_status: Arc<Mutex<u32>>,
    /// Virtio device status as negotiated with the driver.
    #[inspect(hex)]
    device_status: VirtioDeviceStatus,
    /// True while a driver-initiated reset waits for the device to finish
    /// disabling asynchronously.
    disabling: bool,
    /// Waker from the most recent `poll_device` call, woken to finish an
    /// asynchronous disable.
    #[inspect(skip)]
    poll_waker: Option<std::task::Waker>,
    /// Configuration generation counter; bumped on config changes and reset
    /// to 0 on device reset.
    config_generation: u32,
    /// Emulated PCI type-0 configuration space (BARs and capabilities).
    config_space: ConfigSpaceType0Emulator,

    /// Interrupt delivery mechanism (MSI-X or INTx).
    #[inspect(skip)]
    interrupt_kind: InterruptKind,
    /// Queue-notify doorbells, registered while the device is running.
    #[inspect(skip)]
    doorbells: VirtioDoorbells,
    /// Shared memory region exposed through BAR4, if the device asked for one.
    #[inspect(skip)]
    shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size of the shared memory region (0 when there is none).
    #[inspect(hex)]
    shared_memory_size: u64,
}
115
116impl VirtioPciDevice {
117    pub fn new(
118        device: Box<dyn VirtioDevice>,
119        interrupt_model: PciInterruptModel<'_>,
120        doorbell_registration: Option<Arc<dyn DoorbellRegistration>>,
121        mmio_registration: &mut dyn RegisterMmioIntercept,
122        shared_mem_mapper: Option<&dyn MemoryMapper>,
123    ) -> io::Result<Self> {
124        let traits = device.traits();
125        let queues = (0..traits.max_queues)
126            .map(|_| QueueParams {
127                size: QUEUE_MAX_SIZE,
128                ..Default::default()
129            })
130            .collect();
131        let events = (0..traits.max_queues)
132            .map(|_| pal_event::Event::new())
133            .collect();
134        let msix_vectors = vec![0; traits.max_queues.into()];
135
136        let hardware_ids = HardwareIds {
137            vendor_id: VIRTIO_VENDOR_ID,
138            device_id: VIRTIO_PCI_DEVICE_ID_BASE + traits.device_id,
139            revision_id: 1,
140            prog_if: ProgrammingInterface::NONE,
141            base_class: ClassCode::BASE_SYSTEM_PERIPHERAL,
142            sub_class: Subclass::BASE_SYSTEM_PERIPHERAL_OTHER,
143            type0_sub_vendor_id: VIRTIO_VENDOR_ID,
144            type0_sub_system_id: 0x40,
145        };
146
147        let mut caps: Vec<Box<dyn PciCapability>> = vec![
148            Box::new(ReadOnlyCapability::new(
149                "virtio-common",
150                VirtioCapability::new(
151                    VirtioPciCapType::COMMON_CFG.0,
152                    0,
153                    0,
154                    0,
155                    VIRTIO_PCI_COMMON_CFG_SIZE as u32,
156                ),
157            )),
158            Box::new(ReadOnlyCapability::new(
159                "virtio-notify",
160                VirtioNotifyCapability::new(
161                    0,
162                    0,
163                    BAR0_NOTIFY_OFFSET as u32,
164                    BAR0_NOTIFY_SIZE as u32,
165                ),
166            )),
167            Box::new(ReadOnlyCapability::new(
168                "virtio-pci-isr",
169                VirtioCapability::new(
170                    VirtioPciCapType::ISR_CFG.0,
171                    0,
172                    0,
173                    BAR0_ISR_OFFSET as u32,
174                    BAR0_ISR_SIZE as u32,
175                ),
176            )),
177            Box::new(ReadOnlyCapability::new(
178                "virtio-pci-device",
179                VirtioCapability::new(
180                    VirtioPciCapType::DEVICE_CFG.0,
181                    0,
182                    0,
183                    BAR0_DEVICE_CFG_OFFSET as u32,
184                    traits.device_register_length,
185                ),
186            )),
187        ];
188
189        let mut bars = DeviceBars::new().bar0(
190            BAR0_DEVICE_CFG_OFFSET as u64 + traits.device_register_length as u64,
191            BarMemoryKind::Intercept(mmio_registration.new_io_region(
192                "config",
193                BAR0_DEVICE_CFG_OFFSET as u64 + traits.device_register_length as u64,
194            )),
195        );
196
197        let msix: Option<MsixEmulator> = if let PciInterruptModel::Msix(msi_target) =
198            interrupt_model
199        {
200            let (msix, msix_capability) = MsixEmulator::new(2, 64, msi_target);
201            // setting msix as the first cap so that we don't have to update unit tests
202            // i.e: there's no reason why this can't be a .push() instead of .insert()
203            caps.insert(0, Box::new(msix_capability));
204            bars = bars.bar2(
205                msix.bar_len(),
206                BarMemoryKind::Intercept(mmio_registration.new_io_region("msix", msix.bar_len())),
207            );
208            Some(msix)
209        } else {
210            None
211        };
212
213        let shared_memory_size = traits.shared_memory.size;
214        let mut shared_memory_region = None;
215        if shared_memory_size > 0 {
216            let (control, region) = shared_mem_mapper
217                .expect("must provide mapper for shmem")
218                .new_region(
219                    shared_memory_size.try_into().expect("region too big"),
220                    "virtio-pci-shmem".into(),
221                )?;
222
223            caps.push(Box::new(ReadOnlyCapability::new(
224                "virtio-pci-shm",
225                VirtioCapability64::new(
226                    VirtioPciCapType::SHARED_MEMORY_CFG.0,
227                    4, // BAR 4
228                    traits.shared_memory.id,
229                    0,
230                    shared_memory_size,
231                ),
232            )));
233
234            bars = bars.bar4(shared_memory_size, BarMemoryKind::SharedMem(control));
235            shared_memory_region = Some(region);
236        }
237
238        let mut config_space = ConfigSpaceType0Emulator::new(hardware_ids, caps, bars);
239        let interrupt_kind = match interrupt_model {
240            PciInterruptModel::Msix(_) => InterruptKind::Msix(msix.unwrap()),
241            PciInterruptModel::IntX(pin, line) => {
242                InterruptKind::IntX(config_space.set_interrupt_pin(pin, line))
243            }
244        };
245
246        let mut device_feature = traits.device_features.clone();
247        device_feature.set_bank(
248            0,
249            device_feature
250                .bank0()
251                .with_ring_event_idx(true)
252                .with_ring_indirect_desc(true)
253                .into_bits(),
254        );
255        device_feature.set_bank(1, device_feature.bank1().with_version_1(true).into_bits());
256        Ok(VirtioPciDevice {
257            device,
258            device_feature,
259            device_feature_select: 0,
260            driver_feature: VirtioDeviceFeatures::new(),
261            driver_feature_select: 0,
262            msix_config_vector: 0,
263            queue_select: 0,
264            events,
265            queues,
266            msix_vectors,
267            interrupt_status: Arc::new(Mutex::new(0)),
268            device_status: VirtioDeviceStatus::new(),
269            disabling: false,
270            poll_waker: None,
271            config_generation: 0,
272            interrupt_kind,
273            config_space,
274            doorbells: VirtioDoorbells::new(doorbell_registration),
275            shared_memory_region,
276            shared_memory_size,
277        })
278    }
279
280    fn update_config_generation(&mut self) {
281        self.config_generation = self.config_generation.wrapping_add(1);
282        if self.device_status.driver_ok() {
283            *self.interrupt_status.lock() |= 2;
284            match &self.interrupt_kind {
285                InterruptKind::Msix(msix) => {
286                    if let Some(interrupt) = msix.interrupt(self.msix_config_vector) {
287                        interrupt.deliver();
288                    }
289                }
290                InterruptKind::IntX(line) => line.set_level(true),
291            }
292        }
293    }
294
295    fn read_u32(&mut self, offset: u16) -> u32 {
296        assert!(offset & 3 == 0);
297        let queue_select = self.queue_select as usize;
298        match VirtioPciCommonCfg(offset) {
299            VirtioPciCommonCfg::DEVICE_FEATURE_SELECT => self.device_feature_select,
300            VirtioPciCommonCfg::DEVICE_FEATURE => {
301                let feature_select = self.device_feature_select as usize;
302                self.device_feature.bank(feature_select)
303            }
304            VirtioPciCommonCfg::DRIVER_FEATURE_SELECT => self.driver_feature_select,
305            VirtioPciCommonCfg::DRIVER_FEATURE => {
306                let feature_select = self.driver_feature_select as usize;
307                self.driver_feature.bank(feature_select)
308            }
309            VirtioPciCommonCfg::MSIX_CONFIG => {
310                (self.queues.len() as u32) << 16 | self.msix_config_vector as u32
311            }
312            VirtioPciCommonCfg::DEVICE_STATUS => {
313                self.queue_select << 24 | self.config_generation << 8 | self.device_status.as_u32()
314            }
315            VirtioPciCommonCfg::QUEUE_SIZE => {
316                let size = if queue_select < self.queues.len() {
317                    self.queues[queue_select].size
318                } else {
319                    0
320                };
321                let msix_vector = self.msix_vectors.get(queue_select).copied().unwrap_or(0);
322                (msix_vector as u32) << 16 | size as u32
323            }
324            VirtioPciCommonCfg::QUEUE_ENABLE => {
325                let enable = if queue_select < self.queues.len() {
326                    if self.queues[queue_select].enable {
327                        1
328                    } else {
329                        0
330                    }
331                } else {
332                    0
333                };
334                #[expect(clippy::if_same_then_else)] // fix when TODO is resolved
335                let notify_offset = if queue_select < self.queues.len() {
336                    0 // TODO: when should this be non-zero? ever?
337                } else {
338                    0
339                };
340                (notify_offset as u32) << 16 | enable as u32
341            }
342            VirtioPciCommonCfg::QUEUE_DESC_LO => {
343                if queue_select < self.queues.len() {
344                    self.queues[queue_select].desc_addr as u32
345                } else {
346                    0
347                }
348            }
349            VirtioPciCommonCfg::QUEUE_DESC_HI => {
350                if queue_select < self.queues.len() {
351                    (self.queues[queue_select].desc_addr >> 32) as u32
352                } else {
353                    0
354                }
355            }
356            VirtioPciCommonCfg::QUEUE_AVAIL_LO => {
357                if queue_select < self.queues.len() {
358                    self.queues[queue_select].avail_addr as u32
359                } else {
360                    0
361                }
362            }
363            VirtioPciCommonCfg::QUEUE_AVAIL_HI => {
364                if queue_select < self.queues.len() {
365                    (self.queues[queue_select].avail_addr >> 32) as u32
366                } else {
367                    0
368                }
369            }
370            VirtioPciCommonCfg::QUEUE_USED_LO => {
371                if queue_select < self.queues.len() {
372                    self.queues[queue_select].used_addr as u32
373                } else {
374                    0
375                }
376            }
377            VirtioPciCommonCfg::QUEUE_USED_HI => {
378                if queue_select < self.queues.len() {
379                    (self.queues[queue_select].used_addr >> 32) as u32
380                } else {
381                    0
382                }
383            }
384            VirtioPciCommonCfg(BAR0_NOTIFY_OFFSET) => 0,
385            VirtioPciCommonCfg(BAR0_ISR_OFFSET) => {
386                let mut interrupt_status = self.interrupt_status.lock();
387                let status = *interrupt_status;
388                *interrupt_status = 0;
389                if let InterruptKind::IntX(line) = &self.interrupt_kind {
390                    line.set_level(false)
391                }
392                status
393            }
394            VirtioPciCommonCfg(offset) if offset >= BAR0_DEVICE_CFG_OFFSET => self
395                .device
396                .read_registers_u32(offset - BAR0_DEVICE_CFG_OFFSET),
397            _ => {
398                tracing::warn!(offset, "unknown bar read");
399                0xffffffff
400            }
401        }
402    }
403
404    fn write_u32(&mut self, address: u64, offset: u16, val: u32) {
405        assert!(offset & 3 == 0);
406        let queues_locked = self.device_status.driver_ok();
407        let features_locked = queues_locked || self.device_status.features_ok();
408        let queue_select = self.queue_select as usize;
409        match VirtioPciCommonCfg(offset) {
410            VirtioPciCommonCfg::DEVICE_FEATURE_SELECT => self.device_feature_select = val,
411            VirtioPciCommonCfg::DRIVER_FEATURE_SELECT => self.driver_feature_select = val,
412            VirtioPciCommonCfg::DRIVER_FEATURE => {
413                let bank = self.driver_feature_select as usize;
414                if features_locked || bank >= self.device_feature.len() {
415                    // Update is not persisted.
416                } else {
417                    self.driver_feature
418                        .set_bank(bank, val & self.device_feature.bank(bank));
419                }
420            }
421            VirtioPciCommonCfg::MSIX_CONFIG => self.msix_config_vector = val as u16,
422            VirtioPciCommonCfg::DEVICE_STATUS => {
423                self.queue_select = val >> 16;
424                let val = val & 0xff;
425                if val == 0 {
426                    if self.disabling {
427                        return;
428                    }
429                    let started = self.device_status.driver_ok();
430                    self.config_generation = 0;
431                    if started {
432                        self.doorbells.clear();
433                        // Try the fast path: poll with a noop waker to see if
434                        // the device can disable synchronously.
435                        let waker = std::task::Waker::noop();
436                        let mut cx = std::task::Context::from_waker(waker);
437                        if self.device.poll_disable(&mut cx).is_pending() {
438                            self.disabling = true;
439                            // Wake the real poll waker so that poll_device will
440                            // re-poll with a real waker, replacing the noop one.
441                            if let Some(waker) = self.poll_waker.take() {
442                                waker.wake();
443                            }
444                            return;
445                        }
446                    }
447                    // Fast path: disable completed synchronously.
448                    self.device_status = VirtioDeviceStatus::new();
449                    *self.interrupt_status.lock() = 0;
450                    return;
451                }
452
453                let new_status = VirtioDeviceStatus::from(val as u8);
454                if new_status.acknowledge() {
455                    self.device_status.set_acknowledge(true);
456                }
457                if new_status.driver() {
458                    self.device_status.set_driver(true);
459                }
460                if new_status.failed() {
461                    self.device_status.set_failed(true);
462                }
463
464                if !self.device_status.features_ok() && new_status.features_ok() {
465                    self.device_status.set_features_ok(true);
466                    self.update_config_generation();
467                }
468
469                if !self.device_status.driver_ok() && new_status.driver_ok() {
470                    let notification_address = (address & !0xfff) + BAR0_NOTIFY_OFFSET as u64;
471                    for i in 0..self.events.len() {
472                        self.doorbells.add(
473                            notification_address,
474                            Some(i as u64),
475                            Some(2),
476                            &self.events[i],
477                        );
478                    }
479                    let queues = self
480                        .queues
481                        .iter()
482                        .zip(self.msix_vectors.iter().copied())
483                        .zip(self.events.iter().cloned())
484                        .map(|((queue, vector), event)| {
485                            let notify = match &self.interrupt_kind {
486                                InterruptKind::Msix(msix) => {
487                                    if let Some(interrupt) = msix.interrupt(vector) {
488                                        interrupt
489                                    } else {
490                                        tracing::warn!(vector, "invalid MSIx vector specified");
491                                        Interrupt::null()
492                                    }
493                                }
494                                InterruptKind::IntX(line) => {
495                                    let interrupt_status = self.interrupt_status.clone();
496                                    let line = line.clone();
497                                    Interrupt::from_fn(move || {
498                                        *interrupt_status.lock() |= 1;
499                                        line.set_level(true);
500                                    })
501                                }
502                            };
503
504                            QueueResources {
505                                params: *queue,
506                                notify,
507                                event,
508                            }
509                        })
510                        .collect();
511
512                    match self.device.enable(Resources {
513                        features: self.driver_feature.clone(),
514                        queues,
515                        shared_memory_region: self.shared_memory_region.clone(),
516                        shared_memory_size: self.shared_memory_size,
517                    }) {
518                        Ok(()) => {
519                            self.device_status.set_driver_ok(true);
520                        }
521                        Err(err) => {
522                            self.doorbells.clear();
523                            // FUTURE: consider setting DEVICE_NEEDS_RESET and
524                            // delivering a config change interrupt so the guest
525                            // can detect the failure proactively instead of
526                            // waiting for IO timeouts.
527                            tracelimit::error_ratelimited!(
528                                error = &*err as &dyn std::error::Error,
529                                "virtio device enable failed"
530                            );
531                        }
532                    }
533                    self.update_config_generation();
534                }
535            }
536            VirtioPciCommonCfg::QUEUE_SIZE => {
537                let msix_vector = (val >> 16) as u16;
538                if !queues_locked && queue_select < self.queues.len() {
539                    let val = val as u16;
540                    let queue = &mut self.queues[queue_select];
541                    if val > QUEUE_MAX_SIZE {
542                        queue.size = QUEUE_MAX_SIZE;
543                    } else {
544                        queue.size = val;
545                    }
546                    self.msix_vectors[queue_select] = msix_vector;
547                }
548            }
549            VirtioPciCommonCfg::QUEUE_ENABLE => {
550                let val = val & 0xffff;
551                if !queues_locked && queue_select < self.queues.len() {
552                    let queue = &mut self.queues[queue_select];
553                    queue.enable = val != 0;
554                }
555            }
556            VirtioPciCommonCfg::QUEUE_DESC_LO => {
557                if !queues_locked && queue_select < self.queues.len() {
558                    let queue = &mut self.queues[queue_select];
559                    queue.desc_addr = queue.desc_addr & 0xffffffff00000000 | val as u64;
560                }
561            }
562            VirtioPciCommonCfg::QUEUE_DESC_HI => {
563                if !queues_locked && queue_select < self.queues.len() {
564                    let queue = &mut self.queues[queue_select];
565                    queue.desc_addr = (val as u64) << 32 | queue.desc_addr & 0xffffffff;
566                }
567            }
568            VirtioPciCommonCfg::QUEUE_AVAIL_LO => {
569                if !queues_locked && queue_select < self.queues.len() {
570                    let queue = &mut self.queues[queue_select];
571                    queue.avail_addr = queue.avail_addr & 0xffffffff00000000 | val as u64;
572                }
573            }
574            VirtioPciCommonCfg::QUEUE_AVAIL_HI => {
575                if !queues_locked && queue_select < self.queues.len() {
576                    let queue = &mut self.queues[queue_select];
577                    queue.avail_addr = (val as u64) << 32 | queue.avail_addr & 0xffffffff;
578                }
579            }
580            VirtioPciCommonCfg::QUEUE_USED_LO => {
581                if !queues_locked && (queue_select) < self.queues.len() {
582                    let queue = &mut self.queues[queue_select];
583                    queue.used_addr = queue.used_addr & 0xffffffff00000000 | val as u64;
584                }
585            }
586            VirtioPciCommonCfg::QUEUE_USED_HI => {
587                if !queues_locked && queue_select < self.queues.len() {
588                    let queue = &mut self.queues[queue_select];
589                    queue.used_addr = (val as u64) << 32 | queue.used_addr & 0xffffffff;
590                }
591            }
592            VirtioPciCommonCfg(BAR0_NOTIFY_OFFSET) => {
593                if (val as usize) < self.events.len() {
594                    self.events[val as usize].signal();
595                }
596            }
597            VirtioPciCommonCfg(offset) if offset >= BAR0_DEVICE_CFG_OFFSET => self
598                .device
599                .write_registers_u32(offset - BAR0_DEVICE_CFG_OFFSET, val),
600            _ => {
601                tracing::warn!(offset, "unknown bar write at offset");
602            }
603        }
604    }
605}
606
607impl VirtioPciDevice {
608    fn read_bar_u32(&mut self, bar: u8, offset: u16) -> u32 {
609        match bar {
610            0 => self.read_u32(offset),
611            2 => {
612                if let InterruptKind::Msix(msix) = &self.interrupt_kind {
613                    msix.read_u32(offset)
614                } else {
615                    !0
616                }
617            }
618            _ => !0,
619        }
620    }
621
622    fn write_bar_u32(&mut self, address: u64, bar: u8, offset: u16, value: u32) {
623        match bar {
624            0 => self.write_u32(address, offset, value),
625            2 => {
626                if let InterruptKind::Msix(msix) = &mut self.interrupt_kind {
627                    msix.write_u32(offset, value)
628                }
629            }
630            _ => tracing::warn!(bar, offset, "Unknown write"),
631        }
632    }
633}
634
635impl ChangeDeviceState for VirtioPciDevice {
636    fn start(&mut self) {}
637
638    async fn stop(&mut self) {}
639
640    async fn reset(&mut self) {
641        if self.device_status.driver_ok() || self.disabling {
642            self.doorbells.clear();
643            std::future::poll_fn(|cx| self.device.poll_disable(cx)).await;
644        }
645        self.device_status = VirtioDeviceStatus::new();
646        self.disabling = false;
647        self.config_generation = 0;
648        *self.interrupt_status.lock() = 0;
649    }
650}
651
652impl PollDevice for VirtioPciDevice {
653    fn poll_device(&mut self, cx: &mut std::task::Context<'_>) {
654        self.poll_waker = Some(cx.waker().clone());
655        if self.disabling {
656            if self.device.poll_disable(cx).is_ready() {
657                self.device_status = VirtioDeviceStatus::new();
658                self.disabling = false;
659                *self.interrupt_status.lock() = 0;
660            }
661        }
662    }
663}
664
665impl ChipsetDevice for VirtioPciDevice {
666    fn supports_mmio(&mut self) -> Option<&mut dyn MmioIntercept> {
667        Some(self)
668    }
669
670    fn supports_pci(&mut self) -> Option<&mut dyn PciConfigSpace> {
671        Some(self)
672    }
673
674    fn supports_poll_device(&mut self) -> Option<&mut dyn PollDevice> {
675        Some(self)
676    }
677}
678
679impl SaveRestore for VirtioPciDevice {
680    type SavedState = NoSavedState; // TODO
681
682    fn save(&mut self) -> Result<Self::SavedState, SaveError> {
683        Ok(NoSavedState)
684    }
685
686    fn restore(&mut self, NoSavedState: Self::SavedState) -> Result<(), RestoreError> {
687        Ok(())
688    }
689}
690
691impl MmioIntercept for VirtioPciDevice {
692    fn mmio_read(&mut self, address: u64, data: &mut [u8]) -> IoResult {
693        if let Some((bar, offset)) = self.config_space.find_bar(address) {
694            read_as_u32_chunks(offset, data, |offset| self.read_bar_u32(bar, offset))
695        }
696        IoResult::Ok
697    }
698
699    fn mmio_write(&mut self, address: u64, data: &[u8]) -> IoResult {
700        if let Some((bar, offset)) = self.config_space.find_bar(address) {
701            write_as_u32_chunks(offset, data, |offset, request_type| match request_type {
702                ReadWriteRequestType::Write(value) => {
703                    self.write_bar_u32(address, bar, offset, value);
704                    None
705                }
706                ReadWriteRequestType::Read => Some(self.read_bar_u32(bar, offset)),
707            })
708        }
709        IoResult::Ok
710    }
711}
712
713impl PciConfigSpace for VirtioPciDevice {
714    fn pci_cfg_read(&mut self, offset: u16, value: &mut u32) -> IoResult {
715        self.config_space.read_u32(offset, value)
716    }
717
718    fn pci_cfg_write(&mut self, offset: u16, value: u32) -> IoResult {
719        self.config_space.write_u32(offset, value)
720    }
721}
722
723pub(crate) mod capabilities {
724    use crate::spec::pci::VirtioPciCapType;
725    use pci_core::spec::caps::CapabilityId;
726
727    use zerocopy::Immutable;
728    use zerocopy::IntoBytes;
729    use zerocopy::KnownLayout;
730
731    #[repr(C)]
732    #[derive(Debug, IntoBytes, Immutable, KnownLayout)]
733    pub struct VirtioCapabilityCommon {
734        cap_id: u8,
735        cap_next: u8,
736        len: u8,
737        typ: u8,
738        bar: u8,
739        unique_id: u8,
740        padding: [u8; 2],
741        offset: u32,
742        length: u32,
743    }
744
745    impl VirtioCapabilityCommon {
746        pub fn new(len: u8, typ: u8, bar: u8, unique_id: u8, addr_off: u32, addr_len: u32) -> Self {
747            Self {
748                cap_id: CapabilityId::VENDOR_SPECIFIC.0,
749                cap_next: 0,
750                len,
751                typ,
752                bar,
753                unique_id,
754                padding: [0; 2],
755                offset: addr_off,
756                length: addr_len,
757            }
758        }
759    }
760
761    #[repr(C)]
762    #[derive(Debug, IntoBytes, Immutable, KnownLayout)]
763    pub struct VirtioCapability {
764        common: VirtioCapabilityCommon,
765    }
766
767    impl VirtioCapability {
768        pub fn new(typ: u8, bar: u8, unique_id: u8, addr_off: u32, addr_len: u32) -> Self {
769            Self {
770                common: VirtioCapabilityCommon::new(
771                    size_of::<Self>() as u8,
772                    typ,
773                    bar,
774                    unique_id,
775                    addr_off,
776                    addr_len,
777                ),
778            }
779        }
780    }
781
782    #[repr(C)]
783    #[derive(Debug, IntoBytes, Immutable, KnownLayout)]
784    pub struct VirtioCapability64 {
785        common: VirtioCapabilityCommon,
786        offset_hi: u32,
787        length_hi: u32,
788    }
789
790    impl VirtioCapability64 {
791        pub fn new(typ: u8, bar: u8, unique_id: u8, addr_off: u64, addr_len: u64) -> Self {
792            Self {
793                common: VirtioCapabilityCommon::new(
794                    size_of::<Self>() as u8,
795                    typ,
796                    bar,
797                    unique_id,
798                    addr_off as u32,
799                    addr_len as u32,
800                ),
801                offset_hi: (addr_off >> 32) as u32,
802                length_hi: (addr_len >> 32) as u32,
803            }
804        }
805    }
806
807    #[repr(C)]
808    #[derive(Debug, IntoBytes, Immutable, KnownLayout)]
809    pub struct VirtioNotifyCapability {
810        common: VirtioCapabilityCommon,
811        offset_multiplier: u32,
812    }
813
814    impl VirtioNotifyCapability {
815        pub fn new(offset_multiplier: u32, bar: u8, addr_off: u32, addr_len: u32) -> Self {
816            Self {
817                common: VirtioCapabilityCommon::new(
818                    size_of::<Self>() as u8,
819                    VirtioPciCapType::NOTIFY_CFG.0,
820                    bar,
821                    0,
822                    addr_off,
823                    addr_len,
824                ),
825                offset_multiplier,
826            }
827        }
828    }
829
830    #[cfg(test)]
831    mod tests {
832        use super::*;
833        use pci_core::capabilities::PciCapability;
834        use pci_core::capabilities::ReadOnlyCapability;
835
836        #[test]
837        fn common_check() {
838            let common =
839                ReadOnlyCapability::new("common", VirtioCapability::new(0x13, 2, 0, 0x100, 0x200));
840            assert_eq!(common.read_u32(0), 0x13100009);
841            assert_eq!(common.read_u32(4), 2);
842            assert_eq!(common.read_u32(8), 0x100);
843            assert_eq!(common.read_u32(12), 0x200);
844        }
845
846        #[test]
847        fn notify_check() {
848            let notify = ReadOnlyCapability::new(
849                "notify",
850                VirtioNotifyCapability::new(0x123, 2, 0x100, 0x200),
851            );
852            assert_eq!(notify.read_u32(0), 0x2140009);
853            assert_eq!(notify.read_u32(4), 2);
854            assert_eq!(notify.read_u32(8), 0x100);
855            assert_eq!(notify.read_u32(12), 0x200);
856        }
857    }
858}