nvme_driver/queue_pair.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Implementation of an admin or IO queue pair.

use super::spec;
use crate::driver::save_restore::AerHandlerSavedState;
use crate::driver::save_restore::Error;
use crate::driver::save_restore::PendingCommandSavedState;
use crate::driver::save_restore::PendingCommandsSavedState;
use crate::driver::save_restore::QueueHandlerSavedState;
use crate::driver::save_restore::QueuePairSavedState;
use crate::queues::CompletionQueue;
use crate::queues::SubmissionQueue;
use crate::registers::DeviceRegisters;
use anyhow::Context;
use futures::StreamExt;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::ranges::PagedRange;
use inspect::Inspect;
use inspect_counters::Counter;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::RpcSend;
use nvme_spec::AsynchronousEventRequestDw0;
use pal_async::driver::SpawnDriver;
use safeatomic::AtomicSliceOps;
use slab::Slab;
use std::future::poll_fn;
use std::num::Wrapping;
use std::sync::Arc;
use std::task::Poll;
use task_control::AsyncRun;
use task_control::TaskControl;
use thiserror::Error;
use user_driver::DeviceBacking;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::memory::MemoryBlock;
use user_driver::memory::PAGE_SIZE;
use user_driver::memory::PAGE_SIZE64;
use user_driver::page_allocator::PageAllocator;
use user_driver::page_allocator::ScopedPages;
use zerocopy::FromZeros;

/// Value for unused PRP entries, to catch/mitigate buffer size mismatches.
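/// (Illustrative note, assuming 4 KiB pages: this evaluates to
/// 0xffff_ffff_ffff_f000.)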
const INVALID_PAGE_ADDR: u64 = !(PAGE_SIZE as u64 - 1);

const SQ_ENTRY_SIZE: usize = size_of::<spec::Command>();
const CQ_ENTRY_SIZE: usize = size_of::<spec::Completion>();
/// Submission Queue size in bytes.
const SQ_SIZE: usize = PAGE_SIZE * 4;
/// Completion Queue size in bytes.
const CQ_SIZE: usize = PAGE_SIZE;
/// Maximum SQ size in entries.
pub const MAX_SQ_ENTRIES: u16 = (SQ_SIZE / SQ_ENTRY_SIZE) as u16;
/// Maximum CQ size in entries.
pub const MAX_CQ_ENTRIES: u16 = (CQ_SIZE / CQ_ENTRY_SIZE) as u16;
/// Number of pages per queue if bounce buffering.
const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
/// Number of pages per queue if not bounce buffering.
const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;
/// Number of SQ entries per page (64).
const SQ_ENTRIES_PER_PAGE: usize = PAGE_SIZE / SQ_ENTRY_SIZE;
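
// Illustrative sanity check (an addition, not load-bearing): with 4 KiB pages
// and the 64-byte SQ / 16-byte CQ entry sizes mandated by the NVMe spec, the
// constants above evaluate to 256 SQ entries, 256 CQ entries, and 64 SQ
// entries per page.
const _: () = assert!(
    MAX_SQ_ENTRIES == 256 && MAX_CQ_ENTRIES == 256 && SQ_ENTRIES_PER_PAGE == 64
);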

#[derive(Inspect)]
pub(crate) struct QueuePair<A: AerHandler, D: DeviceBacking> {
    #[inspect(skip)]
    task: TaskControl<QueueHandlerLoop<A, D>, ()>,
    #[inspect(flatten, with = "|x| inspect::send(&x.send_req, Req::Inspect)")]
    issuer: Arc<Issuer>,
    #[inspect(skip)]
    mem: MemoryBlock,
    #[inspect(skip)]
    device_id: String,
    #[inspect(skip)]
    qid: u16,
    #[inspect(skip)]
    sq_entries: u16,
    #[inspect(skip)]
    cq_entries: u16,
    sq_addr: u64,
    cq_addr: u64,
}

impl PendingCommands {
    const CID_KEY_BITS: u32 = 10;
    const CID_KEY_MASK: u16 = (1 << Self::CID_KEY_BITS) - 1;
    const MAX_CIDS: usize = 1 << Self::CID_KEY_BITS;
    const CID_SEQ_OFFSET: Wrapping<u16> = Wrapping(1 << Self::CID_KEY_BITS);
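
    // Illustrative example (added commentary, not used at runtime): the low
    // CID_KEY_BITS of a CID index the slab slot, and the high bits carry a
    // wrapping sequence number. For instance, CID 0x1403 maps to slab key
    // 0x003 with sequence bits 0x1400; the next command issued in slot 0x003
    // would get CID 0x1803.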

    fn new(qid: u16) -> Self {
        Self {
            commands: Slab::new(),
            next_cid_high_bits: Wrapping(0),
            qid,
        }
    }

    fn is_full(&self) -> bool {
        self.commands.len() >= Self::MAX_CIDS
    }

    fn is_empty(&self) -> bool {
        self.commands.is_empty()
    }

    fn len(&self) -> usize {
        self.commands.len()
    }

    /// Inserts a command into the pending list, updating it with a new CID.
    fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) {
        let entry = self.commands.vacant_entry();
        assert!(entry.key() < Self::MAX_CIDS);
        assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0));
        let cid = entry.key() as u16 | self.next_cid_high_bits.0;
        self.next_cid_high_bits += Self::CID_SEQ_OFFSET;
        command.cdw0.set_cid(cid);
        entry.insert(PendingCommand {
            command: *command,
            respond,
        });
    }

    fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> {
        let command = self
            .commands
            .try_remove((cid & Self::CID_KEY_MASK) as usize)
            .unwrap_or_else(|| panic!("completion for unknown cid: qid={}, cid={}", self.qid, cid));
        assert_eq!(
            command.command.cdw0.cid(),
            cid,
            "cid sequence number mismatch: qid={}, command_opcode={:#x}",
            self.qid,
            command.command.cdw0.opcode(),
        );
        command.respond
    }

    /// Save pending commands into a buffer.
    pub fn save(&self) -> PendingCommandsSavedState {
        let commands: Vec<PendingCommandSavedState> = self
            .commands
            .iter()
            .map(|(_index, cmd)| PendingCommandSavedState {
                command: cmd.command,
            })
            .collect();
        PendingCommandsSavedState {
            commands,
            next_cid_high_bits: self.next_cid_high_bits.0,
            // TODO: Not used today, added for future compatibility.
            cid_key_bits: Self::CID_KEY_BITS,
        }
    }

    /// Restore pending commands from the saved state.
    pub fn restore(saved_state: &PendingCommandsSavedState, qid: u16) -> anyhow::Result<Self> {
        let PendingCommandsSavedState {
            commands,
            next_cid_high_bits,
            cid_key_bits: _, // TODO: For future use.
        } = saved_state;

        Ok(Self {
            // Re-create an identical Slab where CIDs are correctly mapped.
            commands: commands
                .iter()
                .map(|state| {
                    // To correctly restore the Slab we need both the command index,
                    // inherited from the command's CID, and the command itself.
                    (
                        // Remove the high CID bits to use the low bits as the key.
                        (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize,
                        PendingCommand {
                            command: state.command,
                            respond: Rpc::detached(()),
                        },
                    )
                })
                .collect::<Slab<PendingCommand>>(),
            next_cid_high_bits: Wrapping(*next_cid_high_bits),
            qid,
        })
    }
}

struct QueueHandlerLoop<A: AerHandler, D: DeviceBacking> {
    queue_handler: QueueHandler<A>,
    registers: Arc<DeviceRegisters<D>>,
    recv_req: Option<mesh::Receiver<Req>>,
    recv_cmd: Option<mesh::Receiver<Cmd>>,
    interrupt: DeviceInterrupt,
}

impl<A: AerHandler, D: DeviceBacking> AsyncRun<()> for QueueHandlerLoop<A, D> {
    async fn run(
        &mut self,
        stop: &mut task_control::StopTask<'_>,
        _: &mut (),
    ) -> Result<(), task_control::Cancelled> {
        stop.until_stopped(async {
            self.queue_handler
                .run(
                    &self.registers,
                    self.recv_req.take().unwrap(),
                    self.recv_cmd.take().unwrap(),
                    &mut self.interrupt,
                )
                .await;
        })
        .await
    }
}

impl<A: AerHandler, D: DeviceBacking> QueuePair<A, D> {
    /// Create a new queue pair.
    ///
    /// `sq_entries` and `cq_entries` are the requested sizes in entries.
    /// Calling code should request the largest size it thinks the device
    /// will support (see `CAP.MQES`). These may be clamped down to what will
    /// fit in one page should this routine fail to allocate physically
    /// contiguous memory to back the queues.
    /// IMPORTANT: Calling code should check the actual sizes via corresponding
    /// calls to [`QueuePair::sq_entries`] and [`QueuePair::cq_entries`] AFTER calling this routine.
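    ///
    /// A minimal usage sketch (illustrative; assumes a `device`, `spawner`,
    /// `qid`, `interrupt`, and `registers` are already in hand, and uses the
    /// no-op AER handler as an IO queue would):
    ///
    /// ```ignore
    /// let pair = QueuePair::new(
    ///     spawner, &device, qid, MAX_SQ_ENTRIES, MAX_CQ_ENTRIES,
    ///     interrupt, registers.clone(), false, NoOpAerHandler,
    /// )?;
    /// // The requested sizes may have been clamped; read back the actual ones.
    /// let (sq_len, cq_len) = (pair.sq_entries(), pair.cq_entries());
    /// ```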
    pub fn new(
        spawner: impl SpawnDriver,
        device: &D,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        // FUTURE: Consider splitting this into several allocations, rather than
        // allocating the sum total together. This can increase the likelihood
        // of getting contiguous memory when falling back to the LockedMem
        // allocator, but this is not the expected path. Be careful that any
        // changes you make here work with already established save state.
        let total_size = SQ_SIZE
            + CQ_SIZE
            + if bounce_buffer {
                PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
            } else {
                PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
            };
        let dma_client = device.dma_client();

        let mem = dma_client
            .allocate_dma_buffer(total_size)
            .context("failed to allocate memory for queues")?;

        assert!(sq_entries <= MAX_SQ_ENTRIES);
        assert!(cq_entries <= MAX_CQ_ENTRIES);

        QueuePair::new_or_restore(
            spawner,
            device.id(),
            qid,
            sq_entries,
            cq_entries,
            interrupt,
            registers,
            mem,
            None,
            bounce_buffer,
            aer_handler,
        )
    }

    fn new_or_restore(
        spawner: impl SpawnDriver,
        device_id: &str,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        saved_state: Option<&QueueHandlerSavedState>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        // MemoryBlock is either allocated or restored prior to calling here.
        let sq_mem_block = mem.subblock(0, SQ_SIZE);
        let cq_mem_block = mem.subblock(SQ_SIZE, CQ_SIZE);
        let data_offset = SQ_SIZE + CQ_SIZE;

        // Make sure that the queue memory is physically contiguous. While the
        // NVMe spec allows for some provisions of queue memory to be
        // non-contiguous, this depends on device support. At least one device
        // that we must support requires that the memory is contiguous (via the
        // CAP.CQR bit). Because of that, just simplify the code paths to use
        // contiguous memory.
        //
        // We could also seek through the memory block to find contiguous pages
        // (for example, if the first 4 pages are not contiguous, but pages 5-8
        // are, use those), but other parts of this driver already assume the
        // math to get the correct offsets.
        //
        // N.B. It is expected that allocations from the private pool will
        // always be contiguous, and that is the normal path. That can fail in
        // some cases (e.g. if we got some guesses about memory size wrong), and
        // we prefer to operate in a perf degraded state rather than fail
        // completely.

        let (sq_is_contiguous, cq_is_contiguous) = (
            sq_mem_block.contiguous_pfns(),
            cq_mem_block.contiguous_pfns(),
        );

        let (sq_entries, cq_entries) = if !sq_is_contiguous || !cq_is_contiguous {
            tracing::warn!(
                qid,
                sq_is_contiguous,
                sq_mem_block.pfns = ?sq_mem_block.pfns(),
                cq_is_contiguous,
                cq_mem_block.pfns = ?cq_mem_block.pfns(),
                "non-contiguous queue memory detected, falling back to single page queues"
            );
            // Clamp both queues to the number of entries that will fit in a
            // single SQ page (since this will be the smaller between the SQ and
            // CQ capacity).
            (SQ_ENTRIES_PER_PAGE as u16, SQ_ENTRIES_PER_PAGE as u16)
        } else {
            (sq_entries, cq_entries)
        };

        let sq_addr = sq_mem_block.pfns()[0] * PAGE_SIZE64;
        let cq_addr = cq_mem_block.pfns()[0] * PAGE_SIZE64;

        let queue_handler = match saved_state {
            Some(s) => {
                QueueHandler::restore(sq_mem_block, cq_mem_block, s, aer_handler, device_id, qid)?
            }
            None => {
                // Create a new one.
                QueueHandler {
                    sq: SubmissionQueue::new(qid, sq_entries, sq_mem_block),
                    cq: CompletionQueue::new(qid, cq_entries, cq_mem_block),
                    commands: PendingCommands::new(qid),
                    stats: Default::default(),
                    drain_after_restore: false,
                    aer_handler,
                    device_id: device_id.into(),
                    qid,
                }
            }
        };

        let (send_req, recv_req) = mesh::channel();
        let (send_cmd, recv_cmd) = mesh::channel();
        let mut task = TaskControl::new(QueueHandlerLoop {
            queue_handler,
            registers,
            recv_req: Some(recv_req),
            recv_cmd: Some(recv_cmd),
            interrupt,
        });
        task.insert(spawner, "nvme-queue", ());
        task.start();

        // Convert the queue pages to bytes, and assert that the queue size is
        // large enough.
        const fn pages_to_size_bytes(pages: usize) -> usize {
            let size = pages * PAGE_SIZE;
            assert!(
                size >= 128 * 1024 + PAGE_SIZE,
                "not enough room for an ATAPI IO plus a PRP list"
            );
            size
        }

        // The page allocator uses the remaining part of the buffer for dynamic
        // allocation. Its length depends on whether bounce buffering / double
        // buffering is needed.
        //
        // NOTE: Do not remove the `const` blocks below. This is to force
        // compile time evaluation of the assertion described above.
        let alloc_len = if bounce_buffer {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_BOUNCE_BUFFER) }
        } else {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
        };

        let alloc = PageAllocator::new(mem.subblock(data_offset, alloc_len));

        Ok(Self {
            task,
            issuer: Arc::new(Issuer {
                send_req,
                send_cmd,
                alloc,
            }),
            mem,
            device_id: device_id.into(),
            qid,
            sq_entries,
            cq_entries,
            sq_addr,
            cq_addr,
        })
    }

    /// Returns the actual number of SQ entries supported by this queue pair.
    pub fn sq_entries(&self) -> u16 {
        self.sq_entries
    }

    /// Returns the actual number of CQ entries supported by this queue pair.
    pub fn cq_entries(&self) -> u16 {
        self.cq_entries
    }

    pub fn sq_addr(&self) -> u64 {
        self.sq_addr
    }

    pub fn cq_addr(&self) -> u64 {
        self.cq_addr
    }

    pub fn issuer(&self) -> &Arc<Issuer> {
        &self.issuer
    }

    pub async fn shutdown(mut self) -> impl Send {
        self.task.stop().await;
        self.task.into_inner().0.queue_handler
    }

    /// Save queue pair state for servicing.
    pub async fn save(&self) -> anyhow::Result<QueuePairSavedState> {
        tracing::info!(qid = self.qid, pci_id = ?self.device_id, "saving queue pair state");
        // Return an error if the queue does not have any memory allocated.
        if self.mem.pfns().is_empty() {
            return Err(Error::InvalidState.into());
        }
        // Send an RPC request to the QueueHandler task to save its data.
        // The QueueHandler stops any other processing after completing the Save request.
        let handler_data = self.issuer.send_req.call(Req::Save, ()).await??;

        Ok(QueuePairSavedState {
            mem_len: self.mem.len(),
            base_pfn: self.mem.pfns()[0],
            qid: self.qid,
            sq_entries: self.sq_entries,
            cq_entries: self.cq_entries,
            handler_data,
        })
    }

    /// Restore queue pair state after servicing.
    pub fn restore(
        spawner: impl SpawnDriver,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        device_id: &str,
        saved_state: &QueuePairSavedState,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let QueuePairSavedState {
            mem_len: _,  // Used to restore DMA buffer before calling this.
            base_pfn: _, // Used to restore DMA buffer before calling this.
            qid,
            sq_entries,
            cq_entries,
            handler_data,
        } = saved_state;

        QueuePair::new_or_restore(
            spawner,
            device_id,
            *qid,
            *sq_entries,
            *cq_entries,
            interrupt,
            registers,
            mem,
            Some(handler_data),
            bounce_buffer,
            aer_handler,
        )
    }
}

/// An error issuing an NVMe request.
#[derive(Debug, Error)]
#[expect(missing_docs)]
pub enum RequestError {
    #[error("queue pair is gone")]
    Gone(#[source] RpcError),
    #[error("nvme error")]
    Nvme(#[source] NvmeError),
    #[error("memory error")]
    Memory(#[source] GuestMemoryError),
    #[error("i/o too large for double buffering")]
    TooLarge,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct NvmeError(spec::Status);

impl NvmeError {
    pub fn status(&self) -> spec::Status {
        self.0
    }
}

impl From<spec::Status> for NvmeError {
    fn from(value: spec::Status) -> Self {
        Self(value)
    }
}

impl std::error::Error for NvmeError {}

impl std::fmt::Display for NvmeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.status_code_type() {
            spec::StatusCodeType::GENERIC => write!(
                f,
                "NVMe SCT general error, SC: {:#x?}",
                self.0.status_code()
            ),
            spec::StatusCodeType::COMMAND_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT command-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::MEDIA_ERROR => {
                write!(f, "NVMe SCT media error, SC: {:#x?}", self.0.status_code())
            }
            spec::StatusCodeType::PATH_RELATED => {
                write!(
                    f,
                    "NVMe SCT path-related error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::VENDOR_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT vendor-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            _ => write!(
                f,
                "NVMe SCT unknown ({:#x?}), SC: {:#x?} (raw: {:#x?})",
                self.0.status_code_type(),
                self.0.status_code(),
                self.0
            ),
        }
    }
}

#[derive(Debug, Inspect)]
pub struct Issuer {
    #[inspect(skip)]
    send_cmd: mesh::Sender<Cmd>,
    #[inspect(skip)]
    send_req: mesh::Sender<Req>,
    alloc: PageAllocator,
}

impl Issuer {
    pub async fn issue_raw(
        &self,
        command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        match self.send_cmd.call(Cmd::Command, command).await {
            Ok(completion) if completion.status.status() == 0 => Ok(completion),
            Ok(completion) => Err(RequestError::Nvme(NvmeError(spec::Status(
                completion.status.status(),
            )))),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

    pub async fn issue_get_aen(&self) -> Result<AsynchronousEventRequestDw0, RequestError> {
        match self.send_req.call_failable(Req::NextAen, ()).await {
            Ok(aen_completion) => Ok(aen_completion),
            Err(RpcError::Call(e)) => Err(e),
            Err(RpcError::Channel(e)) => Err(RequestError::Gone(RpcError::Channel(e))),
        }
    }

    pub async fn issue_external(
        &self,
        mut command: spec::Command,
        guest_memory: &GuestMemory,
        mem: PagedRange<'_>,
    ) -> Result<spec::Completion, RequestError> {
        let mut double_buffer_pages = None;
        let opcode = spec::Opcode(command.cdw0.opcode());
        assert!(
            opcode.transfer_controller_to_host()
                || opcode.transfer_host_to_controller()
                || mem.is_empty()
        );

        // Ensure the memory is currently mapped.
        guest_memory
            .probe_gpns(mem.gpns())
            .map_err(RequestError::Memory)?;

        let prp = if mem
            .gpns()
            .iter()
            .all(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).is_some())
        {
            // Guest memory is available to the device, so issue the IO directly.
            self.make_prp(
                mem.offset() as u64,
                mem.gpns()
                    .iter()
                    .map(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).unwrap()),
            )
            .await
        } else {
            tracing::debug!(opcode = opcode.0, size = mem.len(), "double buffering");

            // Guest memory is not accessible by the device. Double buffer
            // through an allocation.
            let double_buffer_pages = double_buffer_pages.insert(
                self.alloc
                    .alloc_bytes(mem.len())
                    .await
                    .ok_or(RequestError::TooLarge)?,
            );

            if opcode.transfer_host_to_controller() {
                double_buffer_pages
                    .copy_from_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }

            self.make_prp(
                0,
                (0..double_buffer_pages.page_count())
                    .map(|i| double_buffer_pages.physical_address(i)),
            )
            .await
        };

        command.dptr = prp.dptr;
        let r = self.issue_raw(command).await;
        if let Some(double_buffer_pages) = double_buffer_pages {
            if r.is_ok() && opcode.transfer_controller_to_host() {
                double_buffer_pages
                    .copy_to_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }
        }
        r
    }

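    // Note on PRP construction (added commentary): dptr[0] points at the first
    // data page plus the byte offset into it, and dptr[1] is either the second
    // data page or, for larger transfers, the address of a PRP list page that
    // holds the remaining page addresses as little-endian u64 entries. Unused
    // slots are filled with INVALID_PAGE_ADDR.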
    async fn make_prp(
        &self,
        offset: u64,
        mut iovas: impl ExactSizeIterator<Item = u64>,
    ) -> Prp<'_> {
        let mut prp_pages = None;
        let dptr = match iovas.len() {
            0 => [INVALID_PAGE_ADDR; 2],
            1 => [iovas.next().unwrap() + offset, INVALID_PAGE_ADDR],
            2 => [iovas.next().unwrap() + offset, iovas.next().unwrap()],
            _ => {
                let a = iovas.next().unwrap();
                assert!(iovas.len() <= 4096);
                let prp = self
                    .alloc
                    .alloc_pages(1)
                    .await
                    .expect("pool capacity is >= 1 page");

                let prp_addr = prp.physical_address(0);
                let page = prp.page_as_slice(0);
                for (iova, dest) in iovas.zip(page.chunks_exact(8)) {
                    dest.atomic_write_obj(&iova.to_le_bytes());
                }
                prp_pages = Some(prp);
                [a + offset, prp_addr]
            }
        };
        Prp {
            dptr,
            _pages: prp_pages,
        }
    }

    pub async fn issue_neither(
        &self,
        mut command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        command.dptr = [INVALID_PAGE_ADDR; 2];
        self.issue_raw(command).await
    }

    pub async fn issue_in(
        &self,
        mut command: spec::Command,
        data: &[u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is >= 1 page");

        mem.write(data);
        assert_eq!(
            mem.page_count(),
            1,
            "larger requests not currently supported"
        );
        let prp = Prp {
            dptr: [mem.physical_address(0), INVALID_PAGE_ADDR],
            _pages: None,
        };
        command.dptr = prp.dptr;
        self.issue_raw(command).await
    }

    pub async fn issue_out(
        &self,
        mut command: spec::Command,
        data: &mut [u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool capacity is sufficient");

        let prp = self
            .make_prp(0, (0..mem.page_count()).map(|i| mem.physical_address(i)))
            .await;

        command.dptr = prp.dptr;
        let completion = self.issue_raw(command).await;
        mem.read(data);
        completion
    }
}

struct Prp<'a> {
    dptr: [u64; 2],
    _pages: Option<ScopedPages<'a>>,
}

#[derive(Inspect)]
struct PendingCommands {
    /// Mapping from the low bits of cid to pending command.
    #[inspect(iter_by_key)]
    commands: Slab<PendingCommand>,
    #[inspect(hex)]
    next_cid_high_bits: Wrapping<u16>,
    qid: u16,
}

#[derive(Inspect)]
struct PendingCommand {
    // Keep the command around for diagnostics.
    command: spec::Command,
    #[inspect(skip)]
    respond: Rpc<(), spec::Completion>,
}

// "ControlPlane" requests sent to the QueueHandler. These can be processed at
// any time, regardless of whether the submission queue is full, and are
// prioritized over IO completions to keep the save path responsive.
enum Req {
    Save(Rpc<(), Result<QueueHandlerSavedState, anyhow::Error>>),
    Inspect(inspect::Deferred),
    NextAen(Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>),
}

// "DataPlane" commands sent to the QueueHandler. Actual NVMe commands that
// require space in the submission queue.
enum Cmd {
    Command(Rpc<spec::Command, spec::Completion>),
    SendAer(),
}

/// Functionality for an AER handler. The default implementation is a no-op
/// handler whose critical-path functions compile away for efficiency; it
/// should be used for IO queues.
pub trait AerHandler: Send + Sync + 'static {
    /// Given a completion command, if the command pertains to a pending AEN,
    /// process it.
    #[inline]
    fn handle_completion(&mut self, _completion: &nvme_spec::Completion) {}
    /// Handle a request from the driver to get the most-recent undelivered AEN
    /// or wait for the next one.
    fn handle_aen_request(
        &mut self,
        _rpc: Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>,
    ) {
    }
    /// Update the CID that the handler is awaiting an AEN on.
    fn update_awaiting_cid(&mut self, _cid: u16) {}
    /// Returns whether an AER needs to be sent to the controller. Since this
    /// is the only function on the critical path, attempt to inline it.
    #[inline]
    fn poll_send_aer(&self) -> bool {
        false
    }
    fn save(&self) -> Option<AerHandlerSavedState> {
        None
    }
    fn restore(&mut self, _state: &Option<AerHandlerSavedState>) {}
}

/// Admin queue AER handler. Ensures a single outstanding AER and persists state
/// across save/restore to process AENs received during servicing.
pub struct AdminAerHandler {
    last_aen: Option<AsynchronousEventRequestDw0>,
    await_aen_cid: Option<u16>,
    send_aen: Option<Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>>, // Channel to return AENs on.
    failed_status: Option<spec::Status>, // If the failed state is reached, it will stop looping until save/restore.
}

impl AdminAerHandler {
    pub fn new() -> Self {
        Self {
            last_aen: None,
            await_aen_cid: None,
            send_aen: None,
            failed_status: None,
        }
    }
}

impl AerHandler for AdminAerHandler {
    fn handle_completion(&mut self, completion: &nvme_spec::Completion) {
        if let Some(await_aen_cid) = self.await_aen_cid
            && completion.cid == await_aen_cid
            && self.failed_status.is_none()
        {
            self.await_aen_cid = None;

            // If error, clean up and stop processing AENs.
            if completion.status.status() != 0 {
                self.last_aen = None;
                let failed_status = spec::Status(completion.status.status());
                self.failed_status = Some(failed_status);
                if let Some(send_aen) = self.send_aen.take() {
                    send_aen.complete(Err(RequestError::Nvme(NvmeError(failed_status))));
                }
                return;
            }
            // Complete the AEN or pend it.
            let aen = AsynchronousEventRequestDw0::from_bits(completion.dw0);
            if let Some(send_aen) = self.send_aen.take() {
                send_aen.complete(Ok(aen));
            } else {
                self.last_aen = Some(aen);
            }
        }
    }

    fn handle_aen_request(
        &mut self,
        rpc: Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>,
    ) {
        if let Some(aen) = self.last_aen.take() {
            rpc.complete(Ok(aen));
        } else if let Some(failed_status) = self.failed_status {
            rpc.complete(Err(RequestError::Nvme(NvmeError(failed_status))));
        } else {
            self.send_aen = Some(rpc); // Save driver request to be completed later.
        }
    }

    fn poll_send_aer(&self) -> bool {
        self.await_aen_cid.is_none() && self.failed_status.is_none()
    }

    fn update_awaiting_cid(&mut self, cid: u16) {
        if let Some(await_aen_cid) = self.await_aen_cid {
            panic!("already awaiting on AEN with cid {}", await_aen_cid);
        }
        self.await_aen_cid = Some(cid);
    }

    fn save(&self) -> Option<AerHandlerSavedState> {
        Some(AerHandlerSavedState {
            last_aen: self.last_aen.map(AsynchronousEventRequestDw0::into_bits), // Save as u32
            await_aen_cid: self.await_aen_cid,
        })
    }

    fn restore(&mut self, state: &Option<AerHandlerSavedState>) {
        if let Some(state) = state {
            let AerHandlerSavedState {
                last_aen,
                await_aen_cid,
            } = state;
            self.last_aen = last_aen.map(AsynchronousEventRequestDw0::from_bits); // Restore from u32
            self.await_aen_cid = *await_aen_cid;
        }
    }
}

/// No-op AER handler. Should only be used for IO queues.
pub struct NoOpAerHandler;
impl AerHandler for NoOpAerHandler {
    fn handle_aen_request(
        &mut self,
        _rpc: Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>,
    ) {
        panic!(
            "no-op aer handler should never receive an aen request. This is likely a bug in the driver."
        );
    }

    fn update_awaiting_cid(&mut self, _cid: u16) {
        panic!(
            "no-op aer handler should never be passed a cid to await. This is likely a bug in the driver."
        );
    }
}

#[derive(Inspect)]
struct QueueHandler<A: AerHandler> {
    sq: SubmissionQueue,
    cq: CompletionQueue,
    commands: PendingCommands,
    stats: QueueStats,
    drain_after_restore: bool,
    #[inspect(skip)]
    aer_handler: A,
    device_id: String,
    qid: u16,
}

#[derive(Inspect, Default)]
struct QueueStats {
    issued: Counter,
    completed: Counter,
    interrupts: Counter,
}

impl<A: AerHandler> QueueHandler<A> {
    async fn run(
        &mut self,
        registers: &DeviceRegisters<impl DeviceBacking>,
        mut recv_req: mesh::Receiver<Req>,
        mut recv_cmd: mesh::Receiver<Cmd>,
        interrupt: &mut DeviceInterrupt,
    ) {
        if self.drain_after_restore {
            tracing::info!(pci_id = ?self.device_id, qid = self.qid, "Have {} outstanding IOs from before save, draining them before allowing new IO...", self.commands.len());
        }

        loop {
            enum Event {
                Request(Req),
                Command(Cmd),
                Completion(spec::Completion),
            }

            let event = if !self.drain_after_restore {
                // Normal processing of the requests and completions.
                poll_fn(|cx| {
                    // Look for NVMe commands
                    if !self.sq.is_full() && !self.commands.is_full() {
                        // Prioritize sending AERs to keep the cycle going
                        if self.aer_handler.poll_send_aer() {
                            return Event::Command(Cmd::SendAer()).into();
                        }
                        if let Poll::Ready(Some(cmd)) = recv_cmd.poll_next_unpin(cx) {
                            return Event::Command(cmd).into();
                        }
                    }
                    // Look for control plane requests like Save/Inspect
                    if let Poll::Ready(Some(req)) = recv_req.poll_next_unpin(cx) {
                        return Event::Request(req).into();
                    }
                    // Look for completions
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.sq.commit(registers);
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            } else {
                // Only process in-flight completions.
                poll_fn(|cx| {
                    // Look for control plane requests like Save/Inspect
                    if let Poll::Ready(Some(req)) = recv_req.poll_next_unpin(cx) {
                        return Event::Request(req).into();
                    }

                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            };

            match event {
                Event::Request(req) => match req {
                    Req::Save(queue_state) => {
                        tracing::info!(pci_id = ?self.device_id, qid = ?self.qid, "received save request, shutting down ...");
                        queue_state.complete(self.save().await);
                        // Do not allow any more processing after save completed.
                        break;
                    }
                    Req::Inspect(deferred) => deferred.inspect(&self),
                    Req::NextAen(rpc) => {
                        self.aer_handler.handle_aen_request(rpc);
                    }
                },
                Event::Command(cmd) => match cmd {
                    Cmd::Command(rpc) => {
                        let (mut command, respond) = rpc.split();
                        self.commands.insert(&mut command, respond);
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                    Cmd::SendAer() => {
                        let mut command = admin_cmd(spec::AdminOpcode::ASYNCHRONOUS_EVENT_REQUEST);
                        self.commands.insert(&mut command, Rpc::detached(()));
                        self.aer_handler.update_awaiting_cid(command.cdw0.cid());
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                },
                Event::Completion(completion) => {
                    assert_eq!(completion.sqid, self.sq.id());
                    let respond = self.commands.remove(completion.cid);
                    if self.drain_after_restore && self.commands.is_empty() {
                        // Switch to normal processing mode once all in-flight commands completed.
                        tracing::info!(pci_id = ?self.device_id, qid = ?self.qid, "done with drain-after-restore");
                        self.drain_after_restore = false;
                    }
                    self.sq.update_head(completion.sqhd);
                    self.aer_handler.handle_completion(&completion);
                    respond.complete(completion);
                    self.stats.completed.increment();
                }
            }
        }
    }

    /// Save queue data for servicing.
    pub async fn save(&self) -> anyhow::Result<QueueHandlerSavedState> {
        // The data is collected from both QueuePair and QueueHandler.
        Ok(QueueHandlerSavedState {
            sq_state: self.sq.save(),
            cq_state: self.cq.save(),
            pending_cmds: self.commands.save(),
            aer_handler: self.aer_handler.save(),
        })
    }

    /// Restore queue data after servicing.
    pub fn restore(
        sq_mem_block: MemoryBlock,
        cq_mem_block: MemoryBlock,
        saved_state: &QueueHandlerSavedState,
        mut aer_handler: A,
        device_id: &str,
        qid: u16,
    ) -> anyhow::Result<Self> {
        let QueueHandlerSavedState {
            sq_state,
            cq_state,
            pending_cmds,
            aer_handler: aer_handler_saved_state,
        } = saved_state;

        aer_handler.restore(aer_handler_saved_state);

        Ok(Self {
            sq: SubmissionQueue::restore(sq_mem_block, sq_state)?,
            cq: CompletionQueue::restore(cq_mem_block, cq_state)?,
            commands: PendingCommands::restore(pending_cmds, sq_state.sqid)?,
            stats: Default::default(),
            // Only drain pending commands for I/O queues.
            // The admin queue is expected to have pending Async Event requests.
            drain_after_restore: sq_state.sqid != 0 && !pending_cmds.commands.is_empty(),
            aer_handler,
            device_id: device_id.into(),
            qid,
        })
    }
}

pub(crate) fn admin_cmd(opcode: spec::AdminOpcode) -> spec::Command {
    spec::Command {
        cdw0: spec::Cdw0::new().with_opcode(opcode.0),
        ..FromZeros::new_zeroed()
    }
}