nvme_driver/queue_pair.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Implementation of an admin or IO queue pair.

use super::spec;
use crate::driver::save_restore::AerHandlerSavedState;
use crate::driver::save_restore::Error;
use crate::driver::save_restore::PendingCommandSavedState;
use crate::driver::save_restore::PendingCommandsSavedState;
use crate::driver::save_restore::QueueHandlerSavedState;
use crate::driver::save_restore::QueuePairSavedState;
use crate::queues::CompletionQueue;
use crate::queues::SubmissionQueue;
use crate::registers::DeviceRegisters;
use anyhow::Context;
use futures::StreamExt;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::ranges::PagedRange;
use inspect::Inspect;
use inspect_counters::Counter;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::RpcSend;
use nvme_spec::AsynchronousEventRequestDw0;
use pal_async::driver::SpawnDriver;
use safeatomic::AtomicSliceOps;
use slab::Slab;
use std::future::poll_fn;
use std::num::Wrapping;
use std::sync::Arc;
use std::task::Poll;
use task_control::AsyncRun;
use task_control::TaskControl;
use thiserror::Error;
use user_driver::DeviceBacking;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::memory::MemoryBlock;
use user_driver::memory::PAGE_SIZE;
use user_driver::memory::PAGE_SIZE64;
use user_driver::page_allocator::PageAllocator;
use user_driver::page_allocator::ScopedPages;
use zerocopy::FromZeros;

/// Value for unused PRP entries, to catch/mitigate buffer size mismatches.
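/// (With 4 KiB pages this is `0xffff_ffff_ffff_f000`, the highest page-aligned
/// 64-bit address, which no legitimate allocation should ever hand out.)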
const INVALID_PAGE_ADDR: u64 = !(PAGE_SIZE as u64 - 1);

const SQ_ENTRY_SIZE: usize = size_of::<spec::Command>();
const CQ_ENTRY_SIZE: usize = size_of::<spec::Completion>();
/// Submission Queue size in bytes.
const SQ_SIZE: usize = PAGE_SIZE * 4;
/// Completion Queue size in bytes.
const CQ_SIZE: usize = PAGE_SIZE;
/// Maximum SQ size in entries.
pub const MAX_SQ_ENTRIES: u16 = (SQ_SIZE / SQ_ENTRY_SIZE) as u16;
/// Maximum CQ size in entries.
pub const MAX_CQ_ENTRIES: u16 = (CQ_SIZE / CQ_ENTRY_SIZE) as u16;
/// Number of pages per queue if bounce buffering.
const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
/// Number of pages per queue if not bounce buffering.
const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;
/// Number of SQ entries per page (64).
const SQ_ENTRIES_PER_PAGE: usize = PAGE_SIZE / SQ_ENTRY_SIZE;

#[derive(Inspect)]
pub(crate) struct QueuePair<A: AerHandler, D: DeviceBacking> {
    #[inspect(skip)]
    task: TaskControl<QueueHandlerLoop<A, D>, ()>,
    #[inspect(flatten, with = "|x| inspect::send(&x.send, Req::Inspect)")]
    issuer: Arc<Issuer>,
    #[inspect(skip)]
    mem: MemoryBlock,
    #[inspect(skip)]
    qid: u16,
    #[inspect(skip)]
    sq_entries: u16,
    #[inspect(skip)]
    cq_entries: u16,
    sq_addr: u64,
    cq_addr: u64,
}

impl PendingCommands {
    const CID_KEY_BITS: u32 = 10;
    const CID_KEY_MASK: u16 = (1 << Self::CID_KEY_BITS) - 1;
    const MAX_CIDS: usize = 1 << Self::CID_KEY_BITS;
    const CID_SEQ_OFFSET: Wrapping<u16> = Wrapping(1 << Self::CID_KEY_BITS);
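
    // The 16-bit CID is `next_cid_high_bits | slab key`: the low CID_KEY_BITS
    // bits index into `commands`, and the remaining high bits form a wrapping
    // sequence number, so a stale or duplicated completion trips the CID check
    // in `remove`. For example, a command placed in slot 5 while the counter
    // is at 0x0800 is issued with CID 0x0805.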

    fn new(qid: u16) -> Self {
        Self {
            commands: Slab::new(),
            next_cid_high_bits: Wrapping(0),
            qid,
        }
    }

    fn is_full(&self) -> bool {
        self.commands.len() >= Self::MAX_CIDS
    }

    fn is_empty(&self) -> bool {
        self.commands.is_empty()
    }

    /// Inserts a command into the pending list, updating it with a new CID.
    fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) {
        let entry = self.commands.vacant_entry();
        assert!(entry.key() < Self::MAX_CIDS);
        assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0));
        let cid = entry.key() as u16 | self.next_cid_high_bits.0;
        self.next_cid_high_bits += Self::CID_SEQ_OFFSET;
        command.cdw0.set_cid(cid);
        entry.insert(PendingCommand {
            command: *command,
            respond,
        });
    }

    fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> {
        let command = self
            .commands
            .try_remove((cid & Self::CID_KEY_MASK) as usize)
            .unwrap_or_else(|| panic!("completion for unknown cid: qid={}, cid={}", self.qid, cid));
        assert_eq!(
            command.command.cdw0.cid(),
            cid,
            "cid sequence number mismatch: qid={}, command_opcode={:#x}",
            self.qid,
            command.command.cdw0.opcode(),
        );
        command.respond
    }

    /// Save pending commands into a buffer.
    pub fn save(&self) -> PendingCommandsSavedState {
        let commands: Vec<PendingCommandSavedState> = self
            .commands
            .iter()
            .map(|(_index, cmd)| PendingCommandSavedState {
                command: cmd.command,
            })
            .collect();
        PendingCommandsSavedState {
            commands,
            next_cid_high_bits: self.next_cid_high_bits.0,
            // TODO: Not used today, added for future compatibility.
            cid_key_bits: Self::CID_KEY_BITS,
        }
    }

    /// Restore pending commands from the saved state.
    pub fn restore(saved_state: &PendingCommandsSavedState, qid: u16) -> anyhow::Result<Self> {
        let PendingCommandsSavedState {
            commands,
            next_cid_high_bits,
            cid_key_bits: _, // TODO: For future use.
        } = saved_state;

        Ok(Self {
            // Re-create an identical Slab where CIDs are correctly mapped.
            commands: commands
                .iter()
                .map(|state| {
                    // To correctly restore the Slab we need both the command
                    // index, derived from the command's CID, and the command itself.
                    (
                        // Mask off the high CID bits; the low bits are the slab key.
                        (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize,
                        PendingCommand {
                            command: state.command,
                            respond: Rpc::detached(()),
                        },
                    )
                })
                .collect::<Slab<PendingCommand>>(),
            next_cid_high_bits: Wrapping(*next_cid_high_bits),
            qid,
        })
    }
}

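/// Task state that drives a [`QueueHandler`]'s processing loop until stopped.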
struct QueueHandlerLoop<A: AerHandler, D: DeviceBacking> {
    queue_handler: QueueHandler<A>,
    registers: Arc<DeviceRegisters<D>>,
    recv: Option<mesh::Receiver<Req>>,
    interrupt: DeviceInterrupt,
}

impl<A: AerHandler, D: DeviceBacking> AsyncRun<()> for QueueHandlerLoop<A, D> {
    async fn run(
        &mut self,
        stop: &mut task_control::StopTask<'_>,
        _: &mut (),
    ) -> Result<(), task_control::Cancelled> {
        stop.until_stopped(async {
            self.queue_handler
                .run(
                    &self.registers,
                    self.recv.take().unwrap(),
                    &mut self.interrupt,
                )
                .await;
        })
        .await
    }
}

impl<A: AerHandler, D: DeviceBacking> QueuePair<A, D> {
    /// Create a new queue pair.
    ///
    /// `sq_entries` and `cq_entries` are the requested sizes in entries.
    /// Calling code should request the largest size it thinks the device
    /// will support (see `CAP.MQES`). These may be clamped down to what will
    /// fit in one page should this routine fail to allocate physically
    /// contiguous memory to back the queues.
    /// IMPORTANT: Calling code should check the actual sizes via corresponding
    /// calls to [`QueuePair::sq_entries`] and [`QueuePair::cq_entries`] AFTER calling this routine.
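    ///
    /// A sketch of the expected calling pattern (arguments elided, so the
    /// example is not compiled):
    ///
    /// ```ignore
    /// let pair = QueuePair::new(/* spawner, device, qid, sq/cq sizes, ... */)?;
    /// let sq_entries = pair.sq_entries(); // may be smaller than requested
    /// let cq_entries = pair.cq_entries();
    /// ```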
    pub fn new(
        spawner: impl SpawnDriver,
        device: &D,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        // FUTURE: Consider splitting this into several allocations, rather than
        // allocating the sum total together. This can increase the likelihood
        // of getting contiguous memory when falling back to the LockedMem
        // allocator, but this is not the expected path. Be careful that any
        // changes you make here work with already established save state.
        let total_size = SQ_SIZE
            + CQ_SIZE
            + if bounce_buffer {
                PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
            } else {
                PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
            };
        let dma_client = device.dma_client();

        let mem = dma_client
            .allocate_dma_buffer(total_size)
            .context("failed to allocate memory for queues")?;

        assert!(sq_entries <= MAX_SQ_ENTRIES);
        assert!(cq_entries <= MAX_CQ_ENTRIES);

        QueuePair::new_or_restore(
            spawner,
            qid,
            sq_entries,
            cq_entries,
            interrupt,
            registers,
            mem,
            None,
            bounce_buffer,
            aer_handler,
        )
    }

    fn new_or_restore(
        spawner: impl SpawnDriver,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        saved_state: Option<&QueueHandlerSavedState>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        // MemoryBlock is either allocated or restored prior to calling here.
        let sq_mem_block = mem.subblock(0, SQ_SIZE);
        let cq_mem_block = mem.subblock(SQ_SIZE, CQ_SIZE);
        let data_offset = SQ_SIZE + CQ_SIZE;

        // Make sure that the queue memory is physically contiguous. While the
        // NVMe spec allows for some provisions of queue memory to be
        // non-contiguous, this depends on device support. At least one device
        // that we must support requires that the memory is contiguous (via the
        // CAP.CQR bit). Because of that, just simplify the code paths to use
        // contiguous memory.
        //
        // We could also seek through the memory block to find contiguous pages
        // (for example, if the first 4 pages are not contiguous, but pages 5-8
        // are, use those), but other parts of this driver already assume the
        // math to get the correct offsets.
        //
        // N.B. It is expected that allocations from the private pool will
        // always be contiguous, and that is the normal path. That can fail in
        // some cases (e.g. if we got some guesses about memory size wrong), and
        // we prefer to operate in a perf degraded state rather than fail
        // completely.

        let (sq_is_contiguous, cq_is_contiguous) = (
            sq_mem_block.contiguous_pfns(),
            cq_mem_block.contiguous_pfns(),
        );

        let (sq_entries, cq_entries) = if !sq_is_contiguous || !cq_is_contiguous {
            tracing::warn!(
                qid,
                sq_is_contiguous,
                sq_mem_block.pfns = ?sq_mem_block.pfns(),
                cq_is_contiguous,
                cq_mem_block.pfns = ?cq_mem_block.pfns(),
                "non-contiguous queue memory detected, falling back to single page queues"
            );
            // Clamp both queues to the number of entries that fit in a single
            // SQ page (the SQ holds fewer entries per page than the CQ, so it
            // is the limiting size).
            (SQ_ENTRIES_PER_PAGE as u16, SQ_ENTRIES_PER_PAGE as u16)
        } else {
            (sq_entries, cq_entries)
        };

        let sq_addr = sq_mem_block.pfns()[0] * PAGE_SIZE64;
        let cq_addr = cq_mem_block.pfns()[0] * PAGE_SIZE64;

        let queue_handler = match saved_state {
            Some(s) => QueueHandler::restore(sq_mem_block, cq_mem_block, s, aer_handler)?,
            None => {
                // Create a new one.
                QueueHandler {
                    sq: SubmissionQueue::new(qid, sq_entries, sq_mem_block),
                    cq: CompletionQueue::new(qid, cq_entries, cq_mem_block),
                    commands: PendingCommands::new(qid),
                    stats: Default::default(),
                    drain_after_restore: false,
                    aer_handler,
                }
            }
        };

        let (send, recv) = mesh::channel();
        let mut task = TaskControl::new(QueueHandlerLoop {
            queue_handler,
            registers,
            recv: Some(recv),
            interrupt,
        });
        task.insert(spawner, "nvme-queue", ());
        task.start();

        // Convert the queue pages to bytes, and assert that queue size is large
        // enough.
        const fn pages_to_size_bytes(pages: usize) -> usize {
            let size = pages * PAGE_SIZE;
            assert!(
                size >= 128 * 1024 + PAGE_SIZE,
                "not enough room for an ATAPI IO plus a PRP list"
            );
            size
        }

        // The page allocator uses the remaining part of the buffer for dynamic
        // allocation. Its length depends on whether bounce buffering (double
        // buffering) is needed.
        //
        // NOTE: Do not remove the `const` blocks below. They force compile-time
        // evaluation of the assertion described above.
        let alloc_len = if bounce_buffer {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_BOUNCE_BUFFER) }
        } else {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
        };

        let alloc = PageAllocator::new(mem.subblock(data_offset, alloc_len));

        Ok(Self {
            task,
            issuer: Arc::new(Issuer { send, alloc }),
            mem,
            qid,
            sq_entries,
            cq_entries,
            sq_addr,
            cq_addr,
        })
    }

    /// Returns the actual number of SQ entries supported by this queue pair.
    pub fn sq_entries(&self) -> u16 {
        self.sq_entries
    }

    /// Returns the actual number of CQ entries supported by this queue pair.
    pub fn cq_entries(&self) -> u16 {
        self.cq_entries
    }

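    /// Returns the physical (DMA) address of the submission queue memory.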
    pub fn sq_addr(&self) -> u64 {
        self.sq_addr
    }

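    /// Returns the physical (DMA) address of the completion queue memory.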
    pub fn cq_addr(&self) -> u64 {
        self.cq_addr
    }

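    /// Returns the issuer used to submit commands to this queue pair.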
    pub fn issuer(&self) -> &Arc<Issuer> {
        &self.issuer
    }

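    /// Stops the queue handler task and returns its internal state.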
    pub async fn shutdown(mut self) -> impl Send {
        self.task.stop().await;
        self.task.into_inner().0.queue_handler
    }

    /// Save queue pair state for servicing.
    pub async fn save(&self) -> anyhow::Result<QueuePairSavedState> {
        // Return error if the queue does not have any memory allocated.
        if self.mem.pfns().is_empty() {
            return Err(Error::InvalidState.into());
        }
        // Send an RPC request to the QueueHandler task to save its data.
        // The QueueHandler stops any other processing after completing the Save request.
        let handler_data = self.issuer.send.call(Req::Save, ()).await??;

        Ok(QueuePairSavedState {
            mem_len: self.mem.len(),
            base_pfn: self.mem.pfns()[0],
            qid: self.qid,
            sq_entries: self.sq_entries,
            cq_entries: self.cq_entries,
            handler_data,
        })
    }

    /// Restore queue pair state after servicing.
    pub fn restore(
        spawner: impl SpawnDriver,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        saved_state: &QueuePairSavedState,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let QueuePairSavedState {
            mem_len: _,  // Used to restore DMA buffer before calling this.
            base_pfn: _, // Used to restore DMA buffer before calling this.
            qid,
            sq_entries,
            cq_entries,
            handler_data,
        } = saved_state;

        QueuePair::new_or_restore(
            spawner,
            *qid,
            *sq_entries,
            *cq_entries,
            interrupt,
            registers,
            mem,
            Some(handler_data),
            bounce_buffer,
            aer_handler,
        )
    }
}

/// An error issuing an NVMe request.
#[derive(Debug, Error)]
#[expect(missing_docs)]
pub enum RequestError {
    #[error("queue pair is gone")]
    Gone(#[source] RpcError),
    #[error("nvme error")]
    Nvme(#[source] NvmeError),
    #[error("memory error")]
    Memory(#[source] GuestMemoryError),
    #[error("i/o too large for double buffering")]
    TooLarge,
}

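/// An NVMe error status returned by the controller.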
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct NvmeError(spec::Status);

impl NvmeError {
    pub fn status(&self) -> spec::Status {
        self.0
    }
}

impl From<spec::Status> for NvmeError {
    fn from(value: spec::Status) -> Self {
        Self(value)
    }
}

impl std::error::Error for NvmeError {}

impl std::fmt::Display for NvmeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.status_code_type() {
            spec::StatusCodeType::GENERIC => write!(
                f,
                "NVMe SCT general error, SC: {:#x?}",
                self.0.status_code()
            ),
            spec::StatusCodeType::COMMAND_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT command-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::MEDIA_ERROR => {
                write!(f, "NVMe SCT media error, SC: {:#x?}", self.0.status_code())
            }
            spec::StatusCodeType::PATH_RELATED => {
                write!(
                    f,
                    "NVMe SCT path-related error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::VENDOR_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT vendor-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            _ => write!(
                f,
                "NVMe SCT unknown ({:#x?}), SC: {:#x?} (raw: {:#x?})",
                self.0.status_code_type(),
                self.0.status_code(),
                self.0
            ),
        }
    }
}

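/// Issues commands to a queue pair by sending requests to its handler task.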
#[derive(Debug, Inspect)]
pub struct Issuer {
    #[inspect(skip)]
    send: mesh::Sender<Req>,
    alloc: PageAllocator,
}

impl Issuer {
    pub async fn issue_raw(
        &self,
        command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        match self.send.call(Req::Command, command).await {
            Ok(completion) if completion.status.status() == 0 => Ok(completion),
            Ok(completion) => Err(RequestError::Nvme(NvmeError(spec::Status(
                completion.status.status(),
            )))),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

    pub async fn issue_get_aen(&self) -> Result<AsynchronousEventRequestDw0, RequestError> {
        match self.send.call(Req::NextAen, ()).await {
            Ok(aen_completion) => Ok(aen_completion),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

    pub async fn issue_external(
        &self,
        mut command: spec::Command,
        guest_memory: &GuestMemory,
        mem: PagedRange<'_>,
    ) -> Result<spec::Completion, RequestError> {
        let mut double_buffer_pages = None;
        let opcode = spec::Opcode(command.cdw0.opcode());
        assert!(
            opcode.transfer_controller_to_host()
                || opcode.transfer_host_to_controller()
                || mem.is_empty()
        );

        // Ensure the memory is currently mapped.
        guest_memory
            .probe_gpns(mem.gpns())
            .map_err(RequestError::Memory)?;

        let prp = if mem
            .gpns()
            .iter()
            .all(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).is_some())
        {
            // Guest memory is available to the device, so issue the IO directly.
            self.make_prp(
                mem.offset() as u64,
                mem.gpns()
                    .iter()
                    .map(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).unwrap()),
            )
            .await
        } else {
            tracing::debug!(opcode = opcode.0, size = mem.len(), "double buffering");

            // Guest memory is not accessible by the device. Double buffer
            // through an allocation.
            let double_buffer_pages = double_buffer_pages.insert(
                self.alloc
                    .alloc_bytes(mem.len())
                    .await
                    .ok_or(RequestError::TooLarge)?,
            );

            if opcode.transfer_host_to_controller() {
                double_buffer_pages
                    .copy_from_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }

            self.make_prp(
                0,
                (0..double_buffer_pages.page_count())
                    .map(|i| double_buffer_pages.physical_address(i)),
            )
            .await
        };

        command.dptr = prp.dptr;
        let r = self.issue_raw(command).await;
        if let Some(double_buffer_pages) = double_buffer_pages {
            if r.is_ok() && opcode.transfer_controller_to_host() {
                double_buffer_pages
                    .copy_to_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }
        }
        r
    }

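    /// Builds the PRP (Physical Region Page) entries for a command's `dptr`:
    /// zero, one, or two direct page pointers, or, for larger transfers, the
    /// first page pointer plus the address of a single PRP list page holding
    /// the remaining entries.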
    async fn make_prp(
        &self,
        offset: u64,
        mut iovas: impl ExactSizeIterator<Item = u64>,
    ) -> Prp<'_> {
        let mut prp_pages = None;
        let dptr = match iovas.len() {
            0 => [INVALID_PAGE_ADDR; 2],
            1 => [iovas.next().unwrap() + offset, INVALID_PAGE_ADDR],
            2 => [iovas.next().unwrap() + offset, iovas.next().unwrap()],
            _ => {
                let a = iovas.next().unwrap();
                assert!(iovas.len() <= 4096);
                let prp = self
                    .alloc
                    .alloc_pages(1)
                    .await
                    .expect("pool cap is >= 1 page");

                let prp_addr = prp.physical_address(0);
                let page = prp.page_as_slice(0);
                for (iova, dest) in iovas.zip(page.chunks_exact(8)) {
                    dest.atomic_write_obj(&iova.to_le_bytes());
                }
                prp_pages = Some(prp);
                [a + offset, prp_addr]
            }
        };
        Prp {
            dptr,
            _pages: prp_pages,
        }
    }

    pub async fn issue_neither(
        &self,
        mut command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        command.dptr = [INVALID_PAGE_ADDR; 2];
        self.issue_raw(command).await
    }

    pub async fn issue_in(
        &self,
        mut command: spec::Command,
        data: &[u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is >= 1 page");

        mem.write(data);
        assert_eq!(
            mem.page_count(),
            1,
            "larger requests not currently supported"
        );
        let prp = Prp {
            dptr: [mem.physical_address(0), INVALID_PAGE_ADDR],
            _pages: None,
        };
        command.dptr = prp.dptr;
        self.issue_raw(command).await
    }

    pub async fn issue_out(
        &self,
        mut command: spec::Command,
        data: &mut [u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is sufficient");

        let prp = self
            .make_prp(0, (0..mem.page_count()).map(|i| mem.physical_address(i)))
            .await;

        command.dptr = prp.dptr;
        let completion = self.issue_raw(command).await;
        mem.read(data);
        completion
    }
}

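/// A PRP pair for a command's `dptr`, plus the PRP list page (if any) that
/// must stay allocated until the request completes.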
struct Prp<'a> {
    dptr: [u64; 2],
    _pages: Option<ScopedPages<'a>>,
}

#[derive(Inspect)]
struct PendingCommands {
    /// Mapping from the low bits of cid to pending command.
    #[inspect(iter_by_key)]
    commands: Slab<PendingCommand>,
    #[inspect(hex)]
    next_cid_high_bits: Wrapping<u16>,
    qid: u16,
}

#[derive(Inspect)]
struct PendingCommand {
    // Keep the command around for diagnostics.
    command: spec::Command,
    #[inspect(skip)]
    respond: Rpc<(), spec::Completion>,
}

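/// Requests processed by the queue handler task.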
enum Req {
    Command(Rpc<spec::Command, spec::Completion>),
    Inspect(inspect::Deferred),
    Save(Rpc<(), Result<QueueHandlerSavedState, anyhow::Error>>),
    SendAer(),
    NextAen(Rpc<(), AsynchronousEventRequestDw0>),
}

/// Functionality for an AER handler. The default method implementations form
/// a no-op handler whose critical-path functions compile out for efficiency;
/// it should be used for IO queues.
pub trait AerHandler: Send + Sync + 'static {
    /// Given a completion, if it pertains to a pending AEN, process it.
    #[inline]
    fn handle_completion(&mut self, _completion: &nvme_spec::Completion) {}
    /// Handle a request from the driver to get the most-recent undelivered AEN
    /// or wait for the next one.
    fn handle_aen_request(&mut self, _rpc: Rpc<(), AsynchronousEventRequestDw0>) {}
    /// Update the CID that the handler is awaiting an AEN on.
    fn update_awaiting_cid(&mut self, _cid: u16) {}
    /// Returns whether an AER needs to be sent to the controller. Since this
    /// is the only function on the critical path, attempt to inline it.
    #[inline]
    fn poll_send_aer(&self) -> bool {
        false
    }
    fn save(&self) -> Option<AerHandlerSavedState> {
        None
    }
    fn restore(&mut self, _state: &Option<AerHandlerSavedState>) {}
}

/// Admin queue AER handler. Ensures a single outstanding AER and persists state
/// across save/restore to process AENs received during servicing.
pub struct AdminAerHandler {
    last_aen: Option<AsynchronousEventRequestDw0>,
    await_aen_cid: Option<u16>,
    send_aen: Option<Rpc<(), AsynchronousEventRequestDw0>>, // Channel to return AENs on.
    failed: bool, // Once failed, the handler stops issuing AERs until the next save/restore.
}

impl AdminAerHandler {
    pub fn new() -> Self {
        Self {
            last_aen: None,
            await_aen_cid: None,
            send_aen: None,
            failed: false,
        }
    }
}

impl AerHandler for AdminAerHandler {
    fn handle_completion(&mut self, completion: &nvme_spec::Completion) {
        if let Some(await_aen_cid) = self.await_aen_cid
            && completion.cid == await_aen_cid
            && !self.failed
        {
            self.await_aen_cid = None;

            // If error, cleanup and stop processing AENs.
            if completion.status.status() != 0 {
                self.failed = true;
                self.last_aen = None;
                return;
            }
            // Complete the AEN or pend it.
            let aen = AsynchronousEventRequestDw0::from_bits(completion.dw0);
            if let Some(send_aen) = self.send_aen.take() {
                send_aen.complete(aen);
            } else {
                self.last_aen = Some(aen);
            }
        }
    }

    fn handle_aen_request(&mut self, rpc: Rpc<(), AsynchronousEventRequestDw0>) {
        if let Some(aen) = self.last_aen.take() {
            rpc.complete(aen);
        } else {
            self.send_aen = Some(rpc); // Save driver request to be completed later.
        }
    }

    fn poll_send_aer(&self) -> bool {
        self.await_aen_cid.is_none() && !self.failed
    }

    fn update_awaiting_cid(&mut self, cid: u16) {
        if self.await_aen_cid.is_some() {
            panic!(
                "already awaiting on AEN with cid {}",
                self.await_aen_cid.unwrap()
            );
        }
        self.await_aen_cid = Some(cid);
    }

    fn save(&self) -> Option<AerHandlerSavedState> {
        Some(AerHandlerSavedState {
            last_aen: self.last_aen.map(AsynchronousEventRequestDw0::into_bits), // Save as u32
            await_aen_cid: self.await_aen_cid,
        })
    }

    fn restore(&mut self, state: &Option<AerHandlerSavedState>) {
        if let Some(state) = state {
            let AerHandlerSavedState {
                last_aen,
                await_aen_cid,
            } = state;
            self.last_aen = last_aen.map(AsynchronousEventRequestDw0::from_bits); // Restore from u32
            self.await_aen_cid = *await_aen_cid;
        }
    }
}

/// No-op AER handler. Should only be used for IO queues.
pub struct NoOpAerHandler;
impl AerHandler for NoOpAerHandler {
    fn handle_aen_request(&mut self, _rpc: Rpc<(), AsynchronousEventRequestDw0>) {
        panic!(
            "no-op aer handler should never receive an aen request. This is likely a bug in the driver."
        );
    }

    fn update_awaiting_cid(&mut self, _cid: u16) {
        panic!(
            "no-op aer handler should never be passed a cid to await. This is likely a bug in the driver."
        );
    }
}

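/// Processes submissions and completions for a single queue pair.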
#[derive(Inspect)]
struct QueueHandler<A: AerHandler> {
    sq: SubmissionQueue,
    cq: CompletionQueue,
    commands: PendingCommands,
    stats: QueueStats,
    drain_after_restore: bool,
    #[inspect(skip)]
    aer_handler: A,
}

#[derive(Inspect, Default)]
struct QueueStats {
    issued: Counter,
    completed: Counter,
    interrupts: Counter,
}

impl<A: AerHandler> QueueHandler<A> {
    async fn run(
        &mut self,
        registers: &DeviceRegisters<impl DeviceBacking>,
        mut recv: mesh::Receiver<Req>,
        interrupt: &mut DeviceInterrupt,
    ) {
        loop {
            enum Event {
                Request(Req),
                Completion(spec::Completion),
            }

            let event = if !self.drain_after_restore {
                // Normal processing of the requests and completions.
                poll_fn(|cx| {
                    if !self.sq.is_full() && !self.commands.is_full() {
                        // Prioritize sending AERs to keep the cycle going
                        if self.aer_handler.poll_send_aer() {
                            return Poll::Ready(Event::Request(Req::SendAer()));
                        }
                        if let Poll::Ready(Some(req)) = recv.poll_next_unpin(cx) {
                            return Event::Request(req).into();
                        }
                    }
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.sq.commit(registers);
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            } else {
                // Only process in-flight completions.
                poll_fn(|cx| {
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            };

            match event {
                Event::Request(req) => match req {
                    Req::Command(rpc) => {
                        let (mut command, respond) = rpc.split();
                        self.commands.insert(&mut command, respond);
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                    Req::Inspect(deferred) => deferred.inspect(&self),
                    Req::Save(queue_state) => {
                        queue_state.complete(self.save().await);
                        // Do not allow any more processing after the save completes.
                        break;
                    }
                    Req::NextAen(rpc) => {
                        self.aer_handler.handle_aen_request(rpc);
                    }
                    Req::SendAer() => {
                        let mut command = admin_cmd(spec::AdminOpcode::ASYNCHRONOUS_EVENT_REQUEST);
                        self.commands.insert(&mut command, Rpc::detached(()));
                        self.aer_handler.update_awaiting_cid(command.cdw0.cid());
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                },
                Event::Completion(completion) => {
                    assert_eq!(completion.sqid, self.sq.id());
                    let respond = self.commands.remove(completion.cid);
                    if self.drain_after_restore && self.commands.is_empty() {
                        // Switch to normal processing mode once all in-flight commands have completed.
                        self.drain_after_restore = false;
                    }
                    self.sq.update_head(completion.sqhd);
                    self.aer_handler.handle_completion(&completion);
                    respond.complete(completion);
                    self.stats.completed.increment();
                }
            }
        }
    }

    /// Save queue data for servicing.
    pub async fn save(&self) -> anyhow::Result<QueueHandlerSavedState> {
        // The data is collected from both QueuePair and QueueHandler.
        Ok(QueueHandlerSavedState {
            sq_state: self.sq.save(),
            cq_state: self.cq.save(),
            pending_cmds: self.commands.save(),
            aer_handler: self.aer_handler.save(),
        })
    }

    /// Restore queue data after servicing.
    pub fn restore(
        sq_mem_block: MemoryBlock,
        cq_mem_block: MemoryBlock,
        saved_state: &QueueHandlerSavedState,
        mut aer_handler: A,
    ) -> anyhow::Result<Self> {
        let QueueHandlerSavedState {
            sq_state,
            cq_state,
            pending_cmds,
            aer_handler: aer_handler_saved_state,
        } = saved_state;

        aer_handler.restore(aer_handler_saved_state);

        Ok(Self {
            sq: SubmissionQueue::restore(sq_mem_block, sq_state)?,
            cq: CompletionQueue::restore(cq_mem_block, cq_state)?,
            commands: PendingCommands::restore(pending_cmds, sq_state.sqid)?,
            stats: Default::default(),
            // Only drain pending commands for I/O queues.
            // Admin queue is expected to have pending Async Event requests.
            drain_after_restore: sq_state.sqid != 0 && !pending_cmds.commands.is_empty(),
            aer_handler,
        })
    }
}

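/// Builds an admin command with the given opcode and all other fields zeroed.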
pub(crate) fn admin_cmd(opcode: spec::AdminOpcode) -> spec::Command {
    spec::Command {
        cdw0: spec::Cdw0::new().with_opcode(opcode.0),
        ..FromZeros::new_zeroed()
    }
}