use super::spec;
use crate::driver::save_restore::AerHandlerSavedState;
use crate::driver::save_restore::Error;
use crate::driver::save_restore::PendingCommandSavedState;
use crate::driver::save_restore::PendingCommandsSavedState;
use crate::driver::save_restore::QueueHandlerSavedState;
use crate::driver::save_restore::QueuePairSavedState;
use crate::queues::CompletionQueue;
use crate::queues::SubmissionQueue;
use crate::registers::DeviceRegisters;
use anyhow::Context;
use futures::StreamExt;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::ranges::PagedRange;
use inspect::Inspect;
use inspect_counters::Counter;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::RpcSend;
use nvme_spec::AsynchronousEventRequestDw0;
use pal_async::driver::SpawnDriver;
use safeatomic::AtomicSliceOps;
use slab::Slab;
use std::future::poll_fn;
use std::num::Wrapping;
use std::sync::Arc;
use std::task::Poll;
use task_control::AsyncRun;
use task_control::TaskControl;
use thiserror::Error;
use user_driver::DeviceBacking;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::memory::MemoryBlock;
use user_driver::memory::PAGE_SIZE;
use user_driver::memory::PAGE_SIZE64;
use user_driver::page_allocator::PageAllocator;
use user_driver::page_allocator::ScopedPages;
use zerocopy::FromZeros;

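// Page-aligned sentinel written into PRP/data-pointer entries that are intentionally unused.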
const INVALID_PAGE_ADDR: u64 = !(PAGE_SIZE as u64 - 1);

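// Queue sizing: SQ_SIZE and CQ_SIZE bound the DMA allocation for each queue, and
// MAX_SQ_ENTRIES/MAX_CQ_ENTRIES are the most entries that fit in that space.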
const SQ_ENTRY_SIZE: usize = size_of::<spec::Command>();
const CQ_ENTRY_SIZE: usize = size_of::<spec::Completion>();
const SQ_SIZE: usize = PAGE_SIZE * 4;
const CQ_SIZE: usize = PAGE_SIZE;
pub const MAX_SQ_ENTRIES: u16 = (SQ_SIZE / SQ_ENTRY_SIZE) as u16;
pub const MAX_CQ_ENTRIES: u16 = (CQ_SIZE / CQ_ENTRY_SIZE) as u16;
const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;
const SQ_ENTRIES_PER_PAGE: usize = PAGE_SIZE / SQ_ENTRY_SIZE;

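/// A submission/completion queue pair, the task that services it, and the
/// issuer used to submit commands to it.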
#[derive(Inspect)]
pub(crate) struct QueuePair<A: AerHandler, D: DeviceBacking> {
    #[inspect(skip)]
    task: TaskControl<QueueHandlerLoop<A, D>, ()>,
    #[inspect(flatten, with = "|x| inspect::send(&x.send, Req::Inspect)")]
    issuer: Arc<Issuer>,
    #[inspect(skip)]
    mem: MemoryBlock,
    #[inspect(skip)]
    qid: u16,
    #[inspect(skip)]
    sq_entries: u16,
    #[inspect(skip)]
    cq_entries: u16,
    sq_addr: u64,
    cq_addr: u64,
}

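// A command identifier (CID) encodes the slab key of the pending command in its
// low CID_KEY_BITS bits and a wrapping sequence number in the high bits, so
// stale or duplicated completions can be detected.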
impl PendingCommands {
    const CID_KEY_BITS: u32 = 10;
    const CID_KEY_MASK: u16 = (1 << Self::CID_KEY_BITS) - 1;
    const MAX_CIDS: usize = 1 << Self::CID_KEY_BITS;
    const CID_SEQ_OFFSET: Wrapping<u16> = Wrapping(1 << Self::CID_KEY_BITS);

    fn new(qid: u16) -> Self {
        Self {
            commands: Slab::new(),
            next_cid_high_bits: Wrapping(0),
            qid,
        }
    }

    fn is_full(&self) -> bool {
        self.commands.len() >= Self::MAX_CIDS
    }

    fn is_empty(&self) -> bool {
        self.commands.is_empty()
    }

    fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) {
        let entry = self.commands.vacant_entry();
        assert!(entry.key() < Self::MAX_CIDS);
        assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0));
        let cid = entry.key() as u16 | self.next_cid_high_bits.0;
        self.next_cid_high_bits += Self::CID_SEQ_OFFSET;
        command.cdw0.set_cid(cid);
        entry.insert(PendingCommand {
            command: *command,
            respond,
        });
    }

    fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> {
        let command = self
            .commands
            .try_remove((cid & Self::CID_KEY_MASK) as usize)
            .unwrap_or_else(|| panic!("completion for unknown cid: qid={}, cid={}", self.qid, cid));
        assert_eq!(
            command.command.cdw0.cid(),
            cid,
            "cid sequence number mismatch: qid={}, command_opcode={:#x}",
            self.qid,
            command.command.cdw0.opcode(),
        );
        command.respond
    }

    pub fn save(&self) -> PendingCommandsSavedState {
        let commands: Vec<PendingCommandSavedState> = self
            .commands
            .iter()
            .map(|(_index, cmd)| PendingCommandSavedState {
                command: cmd.command,
            })
            .collect();
        PendingCommandsSavedState {
            commands,
            next_cid_high_bits: self.next_cid_high_bits.0,
            cid_key_bits: Self::CID_KEY_BITS,
        }
    }

    pub fn restore(saved_state: &PendingCommandsSavedState, qid: u16) -> anyhow::Result<Self> {
        let PendingCommandsSavedState {
            commands,
            next_cid_high_bits,
            cid_key_bits: _,
        } = saved_state;

        Ok(Self {
            commands: commands
                .iter()
                .map(|state| {
                    (
                        (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize,
                        PendingCommand {
                            command: state.command,
                            respond: Rpc::detached(()),
                        },
                    )
                })
                .collect::<Slab<PendingCommand>>(),
            next_cid_high_bits: Wrapping(*next_cid_high_bits),
            qid,
        })
    }
}

struct QueueHandlerLoop<A: AerHandler, D: DeviceBacking> {
    queue_handler: QueueHandler<A>,
    registers: Arc<DeviceRegisters<D>>,
    recv: Option<mesh::Receiver<Req>>,
    interrupt: DeviceInterrupt,
}

impl<A: AerHandler, D: DeviceBacking> AsyncRun<()> for QueueHandlerLoop<A, D> {
    async fn run(
        &mut self,
        stop: &mut task_control::StopTask<'_>,
        _: &mut (),
    ) -> Result<(), task_control::Cancelled> {
        stop.until_stopped(async {
            self.queue_handler
                .run(
                    &self.registers,
                    self.recv.take().unwrap(),
                    &mut self.interrupt,
                )
                .await;
        })
        .await
    }
}

impl<A: AerHandler, D: DeviceBacking> QueuePair<A, D> {
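    /// Creates a new queue pair, allocating DMA memory for the submission and
    /// completion queues plus the per-queue transfer buffer, and spawns the
    /// task that services them.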
    pub fn new(
        spawner: impl SpawnDriver,
        device: &D,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let total_size = SQ_SIZE
            + CQ_SIZE
            + if bounce_buffer {
                PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
            } else {
                PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
            };
        let dma_client = device.dma_client();

        let mem = dma_client
            .allocate_dma_buffer(total_size)
            .context("failed to allocate memory for queues")?;

        assert!(sq_entries <= MAX_SQ_ENTRIES);
        assert!(cq_entries <= MAX_CQ_ENTRIES);

        QueuePair::new_or_restore(
            spawner,
            qid,
            sq_entries,
            cq_entries,
            interrupt,
            registers,
            mem,
            None,
            bounce_buffer,
            aer_handler,
        )
    }

    fn new_or_restore(
        spawner: impl SpawnDriver,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        saved_state: Option<&QueueHandlerSavedState>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let sq_mem_block = mem.subblock(0, SQ_SIZE);
        let cq_mem_block = mem.subblock(SQ_SIZE, CQ_SIZE);
        let data_offset = SQ_SIZE + CQ_SIZE;

        let (sq_is_contiguous, cq_is_contiguous) = (
            sq_mem_block.contiguous_pfns(),
            cq_mem_block.contiguous_pfns(),
        );

        let (sq_entries, cq_entries) = if !sq_is_contiguous || !cq_is_contiguous {
            tracing::warn!(
                qid,
                sq_is_contiguous,
                sq_mem_block.pfns = ?sq_mem_block.pfns(),
                cq_is_contiguous,
                cq_mem_block.pfns = ?cq_mem_block.pfns(),
                "non-contiguous queue memory detected, falling back to single page queues"
            );
            (SQ_ENTRIES_PER_PAGE as u16, SQ_ENTRIES_PER_PAGE as u16)
        } else {
            (sq_entries, cq_entries)
        };

        let sq_addr = sq_mem_block.pfns()[0] * PAGE_SIZE64;
        let cq_addr = cq_mem_block.pfns()[0] * PAGE_SIZE64;

        let queue_handler = match saved_state {
            Some(s) => QueueHandler::restore(sq_mem_block, cq_mem_block, s, aer_handler)?,
            None => {
                QueueHandler {
                    sq: SubmissionQueue::new(qid, sq_entries, sq_mem_block),
                    cq: CompletionQueue::new(qid, cq_entries, cq_mem_block),
                    commands: PendingCommands::new(qid),
                    stats: Default::default(),
                    drain_after_restore: false,
                    aer_handler,
                }
            }
        };

        let (send, recv) = mesh::channel();
        let mut task = TaskControl::new(QueueHandlerLoop {
            queue_handler,
            registers,
            recv: Some(recv),
            interrupt,
        });
        task.insert(spawner, "nvme-queue", ());
        task.start();

        const fn pages_to_size_bytes(pages: usize) -> usize {
            let size = pages * PAGE_SIZE;
            assert!(
                size >= 128 * 1024 + PAGE_SIZE,
                "not enough room for an ATAPI IO plus a PRP list"
            );
            size
        }

        let alloc_len = if bounce_buffer {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_BOUNCE_BUFFER) }
        } else {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
        };

        let alloc = PageAllocator::new(mem.subblock(data_offset, alloc_len));

        Ok(Self {
            task,
            issuer: Arc::new(Issuer { send, alloc }),
            mem,
            qid,
            sq_entries,
            cq_entries,
            sq_addr,
            cq_addr,
        })
    }

    pub fn sq_entries(&self) -> u16 {
        self.sq_entries
    }

    pub fn cq_entries(&self) -> u16 {
        self.cq_entries
    }

    pub fn sq_addr(&self) -> u64 {
        self.sq_addr
    }

    pub fn cq_addr(&self) -> u64 {
        self.cq_addr
    }

    pub fn issuer(&self) -> &Arc<Issuer> {
        &self.issuer
    }

    pub async fn shutdown(mut self) -> impl Send {
        self.task.stop().await;
        self.task.into_inner().0.queue_handler
    }

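    /// Saves the queue pair state (queue memory layout and handler state) so it
    /// can be restored after servicing.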
    pub async fn save(&self) -> anyhow::Result<QueuePairSavedState> {
        if self.mem.pfns().is_empty() {
            return Err(Error::InvalidState.into());
        }
        let handler_data = self.issuer.send.call(Req::Save, ()).await??;

        Ok(QueuePairSavedState {
            mem_len: self.mem.len(),
            base_pfn: self.mem.pfns()[0],
            qid: self.qid,
            sq_entries: self.sq_entries,
            cq_entries: self.cq_entries,
            handler_data,
        })
    }

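    /// Restores a queue pair from a previously saved state, reusing the
    /// provided queue memory.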
    pub fn restore(
        spawner: impl SpawnDriver,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        saved_state: &QueuePairSavedState,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let QueuePairSavedState {
            mem_len: _,
            base_pfn: _,
            qid,
            sq_entries,
            cq_entries,
            handler_data,
        } = saved_state;

        QueuePair::new_or_restore(
            spawner,
            *qid,
            *sq_entries,
            *cq_entries,
            interrupt,
            registers,
            mem,
            Some(handler_data),
            bounce_buffer,
            aer_handler,
        )
    }
}

#[derive(Debug, Error)]
#[expect(missing_docs)]
pub enum RequestError {
    #[error("queue pair is gone")]
    Gone(#[source] RpcError),
    #[error("nvme error")]
    Nvme(#[source] NvmeError),
    #[error("memory error")]
    Memory(#[source] GuestMemoryError),
    #[error("i/o too large for double buffering")]
    TooLarge,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct NvmeError(spec::Status);

impl NvmeError {
    pub fn status(&self) -> spec::Status {
        self.0
    }
}

impl From<spec::Status> for NvmeError {
    fn from(value: spec::Status) -> Self {
        Self(value)
    }
}

impl std::error::Error for NvmeError {}

impl std::fmt::Display for NvmeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.status_code_type() {
            spec::StatusCodeType::GENERIC => write!(
                f,
                "NVMe SCT general error, SC: {:#x?}",
                self.0.status_code()
            ),
            spec::StatusCodeType::COMMAND_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT command-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::MEDIA_ERROR => {
                write!(f, "NVMe SCT media error, SC: {:#x?}", self.0.status_code())
            }
            spec::StatusCodeType::PATH_RELATED => {
                write!(
                    f,
                    "NVMe SCT path-related error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::VENDOR_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT vendor-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            _ => write!(
                f,
                "NVMe SCT unknown ({:#x?}), SC: {:#x?} (raw: {:#x?})",
                self.0.status_code_type(),
                self.0.status_code(),
                self.0
            ),
        }
    }
}

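/// Issues NVMe commands to a queue pair and owns the page allocator used for
/// PRP lists and bounce buffers.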
#[derive(Debug, Inspect)]
pub struct Issuer {
    #[inspect(skip)]
    send: mesh::Sender<Req>,
    alloc: PageAllocator,
}

impl Issuer {
    pub async fn issue_raw(
        &self,
        command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        match self.send.call(Req::Command, command).await {
            Ok(completion) if completion.status.status() == 0 => Ok(completion),
            Ok(completion) => Err(RequestError::Nvme(NvmeError(spec::Status(
                completion.status.status(),
            )))),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

    pub async fn issue_get_aen(&self) -> Result<AsynchronousEventRequestDw0, RequestError> {
        match self.send.call(Req::NextAen, ()).await {
            Ok(aen_completion) => Ok(aen_completion),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

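    /// Issues a command whose data buffer lives in guest memory, building a PRP
    /// list from the guest pages when they are DMA-addressable and double
    /// buffering through the per-queue allocator when they are not.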
    pub async fn issue_external(
        &self,
        mut command: spec::Command,
        guest_memory: &GuestMemory,
        mem: PagedRange<'_>,
    ) -> Result<spec::Completion, RequestError> {
        let mut double_buffer_pages = None;
        let opcode = spec::Opcode(command.cdw0.opcode());
        assert!(
            opcode.transfer_controller_to_host()
                || opcode.transfer_host_to_controller()
                || mem.is_empty()
        );

        guest_memory
            .probe_gpns(mem.gpns())
            .map_err(RequestError::Memory)?;

        let prp = if mem
            .gpns()
            .iter()
            .all(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).is_some())
        {
            self.make_prp(
                mem.offset() as u64,
                mem.gpns()
                    .iter()
                    .map(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).unwrap()),
            )
            .await
        } else {
            tracing::debug!(opcode = opcode.0, size = mem.len(), "double buffering");

            let double_buffer_pages = double_buffer_pages.insert(
                self.alloc
                    .alloc_bytes(mem.len())
                    .await
                    .ok_or(RequestError::TooLarge)?,
            );

            if opcode.transfer_host_to_controller() {
                double_buffer_pages
                    .copy_from_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }

            self.make_prp(
                0,
                (0..double_buffer_pages.page_count())
                    .map(|i| double_buffer_pages.physical_address(i)),
            )
            .await
        };

        command.dptr = prp.dptr;
        let r = self.issue_raw(command).await;
        if let Some(double_buffer_pages) = double_buffer_pages {
            if r.is_ok() && opcode.transfer_controller_to_host() {
                double_buffer_pages
                    .copy_to_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }
        }
        r
    }

    async fn make_prp(
        &self,
        offset: u64,
        mut iovas: impl ExactSizeIterator<Item = u64>,
    ) -> Prp<'_> {
        let mut prp_pages = None;
        let dptr = match iovas.len() {
            0 => [INVALID_PAGE_ADDR; 2],
            1 => [iovas.next().unwrap() + offset, INVALID_PAGE_ADDR],
            2 => [iovas.next().unwrap() + offset, iovas.next().unwrap()],
            _ => {
                let a = iovas.next().unwrap();
                assert!(iovas.len() <= 4096);
                let prp = self
                    .alloc
                    .alloc_pages(1)
                    .await
                    .expect("pool cap is >= 1 page");

                let prp_addr = prp.physical_address(0);
                let page = prp.page_as_slice(0);
                for (iova, dest) in iovas.zip(page.chunks_exact(8)) {
                    dest.atomic_write_obj(&iova.to_le_bytes());
                }
                prp_pages = Some(prp);
                [a + offset, prp_addr]
            }
        };
        Prp {
            dptr,
            _pages: prp_pages,
        }
    }

    pub async fn issue_neither(
        &self,
        mut command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        command.dptr = [INVALID_PAGE_ADDR; 2];
        self.issue_raw(command).await
    }

    pub async fn issue_in(
        &self,
        mut command: spec::Command,
        data: &[u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is >= 1 page");

        mem.write(data);
        assert_eq!(
            mem.page_count(),
            1,
            "larger requests not currently supported"
        );
        let prp = Prp {
            dptr: [mem.physical_address(0), INVALID_PAGE_ADDR],
            _pages: None,
        };
        command.dptr = prp.dptr;
        self.issue_raw(command).await
    }

    pub async fn issue_out(
        &self,
        mut command: spec::Command,
        data: &mut [u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is sufficient");

        let prp = self
            .make_prp(0, (0..mem.page_count()).map(|i| mem.physical_address(i)))
            .await;

        command.dptr = prp.dptr;
        let completion = self.issue_raw(command).await;
        mem.read(data);
        completion
    }
}

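/// Data pointers for a command, keeping any allocated PRP list pages alive
/// until the command completes.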
struct Prp<'a> {
    dptr: [u64; 2],
    _pages: Option<ScopedPages<'a>>,
}

#[derive(Inspect)]
struct PendingCommands {
    #[inspect(iter_by_key)]
    commands: Slab<PendingCommand>,
    #[inspect(hex)]
    next_cid_high_bits: Wrapping<u16>,
    qid: u16,
}

#[derive(Inspect)]
struct PendingCommand {
    command: spec::Command,
    #[inspect(skip)]
    respond: Rpc<(), spec::Completion>,
}

enum Req {
    Command(Rpc<spec::Command, spec::Completion>),
    Inspect(inspect::Deferred),
    Save(Rpc<(), Result<QueueHandlerSavedState, anyhow::Error>>),
    SendAer(),
    NextAen(Rpc<(), AsynchronousEventRequestDw0>),
}

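/// Handles Asynchronous Event Requests (AERs) on behalf of a queue; the default
/// method implementations are no-ops for queues that never issue AERs.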
pub trait AerHandler: Send + Sync + 'static {
    #[inline]
    fn handle_completion(&mut self, _completion: &nvme_spec::Completion) {}
    fn handle_aen_request(&mut self, _rpc: Rpc<(), AsynchronousEventRequestDw0>) {}
    fn update_awaiting_cid(&mut self, _cid: u16) {}
    #[inline]
    fn poll_send_aer(&self) -> bool {
        false
    }
    fn save(&self) -> Option<AerHandlerSavedState> {
        None
    }
    fn restore(&mut self, _state: &Option<AerHandlerSavedState>) {}
}

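/// AER handler for the admin queue: keeps one AER outstanding, caches the most
/// recently delivered event, and hands it to the next `NextAen` request.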
pub struct AdminAerHandler {
    last_aen: Option<AsynchronousEventRequestDw0>,
    await_aen_cid: Option<u16>,
    send_aen: Option<Rpc<(), AsynchronousEventRequestDw0>>,
    failed: bool,
}

impl AdminAerHandler {
    pub fn new() -> Self {
        Self {
            last_aen: None,
            await_aen_cid: None,
            send_aen: None,
            failed: false,
        }
    }
}

impl AerHandler for AdminAerHandler {
    fn handle_completion(&mut self, completion: &nvme_spec::Completion) {
        if let Some(await_aen_cid) = self.await_aen_cid
            && completion.cid == await_aen_cid
            && !self.failed
        {
            self.await_aen_cid = None;

            if completion.status.status() != 0 {
                self.failed = true;
                self.last_aen = None;
                return;
            }
            let aen = AsynchronousEventRequestDw0::from_bits(completion.dw0);
            if let Some(send_aen) = self.send_aen.take() {
                send_aen.complete(aen);
            } else {
                self.last_aen = Some(aen);
            }
        }
    }

    fn handle_aen_request(&mut self, rpc: Rpc<(), AsynchronousEventRequestDw0>) {
        if let Some(aen) = self.last_aen.take() {
            rpc.complete(aen);
        } else {
            self.send_aen = Some(rpc);
        }
    }

    fn poll_send_aer(&self) -> bool {
        self.await_aen_cid.is_none() && !self.failed
    }

    fn update_awaiting_cid(&mut self, cid: u16) {
        if self.await_aen_cid.is_some() {
            panic!(
                "already awaiting on AEN with cid {}",
                self.await_aen_cid.unwrap()
            );
        }
        self.await_aen_cid = Some(cid);
    }

    fn save(&self) -> Option<AerHandlerSavedState> {
        Some(AerHandlerSavedState {
            last_aen: self.last_aen.map(AsynchronousEventRequestDw0::into_bits),
            await_aen_cid: self.await_aen_cid,
        })
    }

    fn restore(&mut self, state: &Option<AerHandlerSavedState>) {
        if let Some(state) = state {
            let AerHandlerSavedState {
                last_aen,
                await_aen_cid,
            } = state;
            self.last_aen = last_aen.map(AsynchronousEventRequestDw0::from_bits);
            self.await_aen_cid = *await_aen_cid;
        }
    }
}

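/// AER handler for I/O queues, which never issue asynchronous event requests;
/// the AEN-specific methods panic if they are ever called.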
pub struct NoOpAerHandler;

impl AerHandler for NoOpAerHandler {
    fn handle_aen_request(&mut self, _rpc: Rpc<(), AsynchronousEventRequestDw0>) {
        panic!(
            "no-op aer handler should never receive an aen request. This is likely a bug in the driver."
        );
    }

    fn update_awaiting_cid(&mut self, _cid: u16) {
        panic!(
            "no-op aer handler should never be passed a cid to await. This is likely a bug in the driver."
        );
    }
}

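/// State for the task that services a queue pair: the submission and completion
/// queues, the pending command table, and the AER handler.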
#[derive(Inspect)]
struct QueueHandler<A: AerHandler> {
    sq: SubmissionQueue,
    cq: CompletionQueue,
    commands: PendingCommands,
    stats: QueueStats,
    drain_after_restore: bool,
    #[inspect(skip)]
    aer_handler: A,
}

#[derive(Inspect, Default)]
struct QueueStats {
    issued: Counter,
    completed: Counter,
    interrupts: Counter,
}

impl<A: AerHandler> QueueHandler<A> {
    async fn run(
        &mut self,
        registers: &DeviceRegisters<impl DeviceBacking>,
        mut recv: mesh::Receiver<Req>,
        interrupt: &mut DeviceInterrupt,
    ) {
        loop {
            enum Event {
                Request(Req),
                Completion(spec::Completion),
            }

            let event = if !self.drain_after_restore {
                poll_fn(|cx| {
                    if !self.sq.is_full() && !self.commands.is_full() {
                        if self.aer_handler.poll_send_aer() {
                            return Poll::Ready(Event::Request(Req::SendAer()));
                        }
                        if let Poll::Ready(Some(req)) = recv.poll_next_unpin(cx) {
                            return Event::Request(req).into();
                        }
                    }
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.sq.commit(registers);
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            } else {
                poll_fn(|cx| {
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            };

            match event {
                Event::Request(req) => match req {
                    Req::Command(rpc) => {
                        let (mut command, respond) = rpc.split();
                        self.commands.insert(&mut command, respond);
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                    Req::Inspect(deferred) => deferred.inspect(&self),
                    Req::Save(queue_state) => {
                        queue_state.complete(self.save().await);
                        break;
                    }
                    Req::NextAen(rpc) => {
                        self.aer_handler.handle_aen_request(rpc);
                    }
                    Req::SendAer() => {
                        let mut command = admin_cmd(spec::AdminOpcode::ASYNCHRONOUS_EVENT_REQUEST);
                        self.commands.insert(&mut command, Rpc::detached(()));
                        self.aer_handler.update_awaiting_cid(command.cdw0.cid());
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                },
                Event::Completion(completion) => {
                    assert_eq!(completion.sqid, self.sq.id());
                    let respond = self.commands.remove(completion.cid);
                    if self.drain_after_restore && self.commands.is_empty() {
                        self.drain_after_restore = false;
                    }
                    self.sq.update_head(completion.sqhd);
                    self.aer_handler.handle_completion(&completion);
                    respond.complete(completion);
                    self.stats.completed.increment();
                }
            }
        }
    }

    pub async fn save(&self) -> anyhow::Result<QueueHandlerSavedState> {
        Ok(QueueHandlerSavedState {
            sq_state: self.sq.save(),
            cq_state: self.cq.save(),
            pending_cmds: self.commands.save(),
            aer_handler: self.aer_handler.save(),
        })
    }

    pub fn restore(
        sq_mem_block: MemoryBlock,
        cq_mem_block: MemoryBlock,
        saved_state: &QueueHandlerSavedState,
        mut aer_handler: A,
    ) -> anyhow::Result<Self> {
        let QueueHandlerSavedState {
            sq_state,
            cq_state,
            pending_cmds,
            aer_handler: aer_handler_saved_state,
        } = saved_state;

        aer_handler.restore(aer_handler_saved_state);

        Ok(Self {
            sq: SubmissionQueue::restore(sq_mem_block, sq_state)?,
            cq: CompletionQueue::restore(cq_mem_block, cq_state)?,
            commands: PendingCommands::restore(pending_cmds, sq_state.sqid)?,
            stats: Default::default(),
            drain_after_restore: sq_state.sqid != 0 && !pending_cmds.commands.is_empty(),
            aer_handler,
        })
    }
}

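/// Builds an admin command with the given opcode and all other fields zeroed.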
pub(crate) fn admin_cmd(opcode: spec::AdminOpcode) -> spec::Command {
    spec::Command {
        cdw0: spec::Cdw0::new().with_opcode(opcode.0),
        ..FromZeros::new_zeroed()
    }
}