use super::spec;
use crate::driver::save_restore::AerHandlerSavedState;
use crate::driver::save_restore::Error;
use crate::driver::save_restore::PendingCommandSavedState;
use crate::driver::save_restore::PendingCommandsSavedState;
use crate::driver::save_restore::QueueHandlerSavedState;
use crate::driver::save_restore::QueuePairSavedState;
use crate::queues::CompletionQueue;
use crate::queues::SubmissionQueue;
use crate::registers::DeviceRegisters;
use anyhow::Context;
use futures::StreamExt;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::ranges::PagedRange;
use inspect::Inspect;
use inspect_counters::Counter;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::RpcSend;
use nvme_spec::AsynchronousEventRequestDw0;
use pal_async::driver::SpawnDriver;
use safeatomic::AtomicSliceOps;
use slab::Slab;
use std::future::poll_fn;
use std::num::Wrapping;
use std::sync::Arc;
use std::task::Poll;
use task_control::AsyncRun;
use task_control::TaskControl;
use thiserror::Error;
use user_driver::DeviceBacking;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::memory::MemoryBlock;
use user_driver::memory::PAGE_SIZE;
use user_driver::memory::PAGE_SIZE64;
use user_driver::page_allocator::PageAllocator;
use user_driver::page_allocator::ScopedPages;
use zerocopy::FromZeros;

/// Sentinel value used for unused PRP entries.
const INVALID_PAGE_ADDR: u64 = !(PAGE_SIZE as u64 - 1);

const SQ_ENTRY_SIZE: usize = size_of::<spec::Command>();
const CQ_ENTRY_SIZE: usize = size_of::<spec::Completion>();
/// Submission queue memory: four pages.
const SQ_SIZE: usize = PAGE_SIZE * 4;
/// Completion queue memory: one page.
const CQ_SIZE: usize = PAGE_SIZE;
pub const MAX_SQ_ENTRIES: u16 = (SQ_SIZE / SQ_ENTRY_SIZE) as u16;
pub const MAX_CQ_ENTRIES: u16 = (CQ_SIZE / CQ_ENTRY_SIZE) as u16;
/// Data pages reserved per queue when bounce buffering through the allocator.
const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
/// Data pages reserved per queue when the device can reach guest memory directly.
const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;
const SQ_ENTRIES_PER_PAGE: usize = PAGE_SIZE / SQ_ENTRY_SIZE;

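/// An NVMe submission/completion queue pair, plus the task that services it.
///
/// A rough usage sketch (not compiled; error handling elided, and the
/// surrounding driver objects are assumed to already exist):
///
/// ```ignore
/// let qp = QueuePair::new(
///     spawner, &device, qid, sq_entries, cq_entries, interrupt, registers,
///     bounce_buffer, aer_handler,
/// )?;
/// let completion = qp.issuer().issue_raw(command).await?;
/// let handler = qp.shutdown().await;
/// ```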
#[derive(Inspect)]
pub(crate) struct QueuePair<A: AerHandler, D: DeviceBacking> {
    #[inspect(skip)]
    task: TaskControl<QueueHandlerLoop<A, D>, ()>,
    #[inspect(flatten, with = "|x| inspect::send(&x.send_req, Req::Inspect)")]
    issuer: Arc<Issuer>,
    #[inspect(skip)]
    mem: MemoryBlock,
    #[inspect(skip)]
    device_id: String,
    #[inspect(skip)]
    qid: u16,
    #[inspect(skip)]
    sq_entries: u16,
    #[inspect(skip)]
    cq_entries: u16,
    sq_addr: u64,
    cq_addr: u64,
}

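/// Tracking for commands that have been submitted but not yet completed.
///
/// CIDs are 16 bits: the low `CID_KEY_BITS` bits index into the slab, and the
/// high bits carry a generation sequence so a stale or duplicated completion
/// is detected instead of being matched to the wrong command.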
impl PendingCommands {
    const CID_KEY_BITS: u32 = 10;
    const CID_KEY_MASK: u16 = (1 << Self::CID_KEY_BITS) - 1;
    const MAX_CIDS: usize = 1 << Self::CID_KEY_BITS;
    const CID_SEQ_OFFSET: Wrapping<u16> = Wrapping(1 << Self::CID_KEY_BITS);

    fn new(qid: u16) -> Self {
        Self {
            commands: Slab::new(),
            next_cid_high_bits: Wrapping(0),
            qid,
        }
    }

    fn is_full(&self) -> bool {
        self.commands.len() >= Self::MAX_CIDS
    }

    fn is_empty(&self) -> bool {
        self.commands.is_empty()
    }

    fn len(&self) -> usize {
        self.commands.len()
    }

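    /// Inserts a command into the pending table, assigning it a fresh CID.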
    fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) {
        let entry = self.commands.vacant_entry();
        assert!(entry.key() < Self::MAX_CIDS);
        assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0));
        let cid = entry.key() as u16 | self.next_cid_high_bits.0;
        self.next_cid_high_bits += Self::CID_SEQ_OFFSET;
        command.cdw0.set_cid(cid);
        entry.insert(PendingCommand {
            command: *command,
            respond,
        });
    }

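    /// Removes the command with the given CID, panicking on an unknown CID or
    /// a sequence-number mismatch (both indicate a misbehaving controller).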
    fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> {
        let command = self
            .commands
            .try_remove((cid & Self::CID_KEY_MASK) as usize)
            .unwrap_or_else(|| panic!("completion for unknown cid: qid={}, cid={}", self.qid, cid));
        assert_eq!(
            command.command.cdw0.cid(),
            cid,
            "cid sequence number mismatch: qid={}, command_opcode={:#x}",
            self.qid,
            command.command.cdw0.opcode(),
        );
        command.respond
    }

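    /// Captures the pending commands for servicing.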
    pub fn save(&self) -> PendingCommandsSavedState {
        let commands: Vec<PendingCommandSavedState> = self
            .commands
            .iter()
            .map(|(_index, cmd)| PendingCommandSavedState {
                command: cmd.command,
            })
            .collect();
        PendingCommandsSavedState {
            commands,
            next_cid_high_bits: self.next_cid_high_bits.0,
            cid_key_bits: Self::CID_KEY_BITS,
        }
    }

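    /// Rebuilds the pending-command table from saved state. Responders are
    /// recreated detached, since the original callers are gone after restore.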
    pub fn restore(saved_state: &PendingCommandsSavedState, qid: u16) -> anyhow::Result<Self> {
        let PendingCommandsSavedState {
            commands,
            next_cid_high_bits,
            cid_key_bits: _,
        } = saved_state;

        Ok(Self {
            commands: commands
                .iter()
                .map(|state| {
                    (
                        (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize,
                        PendingCommand {
                            command: state.command,
                            respond: Rpc::detached(()),
                        },
                    )
                })
                .collect::<Slab<PendingCommand>>(),
            next_cid_high_bits: Wrapping(*next_cid_high_bits),
            qid,
        })
    }
}

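/// State for the task that services a queue pair.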
struct QueueHandlerLoop<A: AerHandler, D: DeviceBacking> {
    queue_handler: QueueHandler<A>,
    registers: Arc<DeviceRegisters<D>>,
    recv_req: Option<mesh::Receiver<Req>>,
    recv_cmd: Option<mesh::Receiver<Cmd>>,
    interrupt: DeviceInterrupt,
}

impl<A: AerHandler, D: DeviceBacking> AsyncRun<()> for QueueHandlerLoop<A, D> {
    async fn run(
        &mut self,
        stop: &mut task_control::StopTask<'_>,
        _: &mut (),
    ) -> Result<(), task_control::Cancelled> {
        stop.until_stopped(async {
            self.queue_handler
                .run(
                    &self.registers,
                    self.recv_req.take().unwrap(),
                    self.recv_cmd.take().unwrap(),
                    &mut self.interrupt,
                )
                .await;
        })
        .await
    }
}

impl<A: AerHandler, D: DeviceBacking> QueuePair<A, D> {
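    /// Allocates DMA memory for the queues and spawns the handler task.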
    pub fn new(
        spawner: impl SpawnDriver,
        device: &D,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let total_size = SQ_SIZE
            + CQ_SIZE
            + if bounce_buffer {
                PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
            } else {
                PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
            };
        let dma_client = device.dma_client();

        let mem = dma_client
            .allocate_dma_buffer(total_size)
            .context("failed to allocate memory for queues")?;

        assert!(sq_entries <= MAX_SQ_ENTRIES);
        assert!(cq_entries <= MAX_CQ_ENTRIES);

        QueuePair::new_or_restore(
            spawner,
            device.id(),
            qid,
            sq_entries,
            cq_entries,
            interrupt,
            registers,
            mem,
            None,
            bounce_buffer,
            aer_handler,
        )
    }

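    /// Shared path for `new` and `restore`. The memory block is laid out as
    /// SQ pages, then CQ pages, then per-queue data pages for PRP lists and
    /// bounce buffers.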
    fn new_or_restore(
        spawner: impl SpawnDriver,
        device_id: &str,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        saved_state: Option<&QueueHandlerSavedState>,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let sq_mem_block = mem.subblock(0, SQ_SIZE);
        let cq_mem_block = mem.subblock(SQ_SIZE, CQ_SIZE);
        let data_offset = SQ_SIZE + CQ_SIZE;

        let (sq_is_contiguous, cq_is_contiguous) = (
            sq_mem_block.contiguous_pfns(),
            cq_mem_block.contiguous_pfns(),
        );

        // Multi-page queues require physically contiguous memory, since the
        // controller is given only a base address. If either block is not
        // contiguous, shrink both queues to a single page's worth of entries.
        let (sq_entries, cq_entries) = if !sq_is_contiguous || !cq_is_contiguous {
            tracing::warn!(
                qid,
                sq_is_contiguous,
                sq_mem_block.pfns = ?sq_mem_block.pfns(),
                cq_is_contiguous,
                cq_mem_block.pfns = ?cq_mem_block.pfns(),
                "non-contiguous queue memory detected, falling back to single page queues"
            );
            (SQ_ENTRIES_PER_PAGE as u16, SQ_ENTRIES_PER_PAGE as u16)
        } else {
            (sq_entries, cq_entries)
        };

        let sq_addr = sq_mem_block.pfns()[0] * PAGE_SIZE64;
        let cq_addr = cq_mem_block.pfns()[0] * PAGE_SIZE64;

        let queue_handler = match saved_state {
            Some(s) => {
                QueueHandler::restore(sq_mem_block, cq_mem_block, s, aer_handler, device_id, qid)?
            }
            None => {
                QueueHandler {
                    sq: SubmissionQueue::new(qid, sq_entries, sq_mem_block),
                    cq: CompletionQueue::new(qid, cq_entries, cq_mem_block),
                    commands: PendingCommands::new(qid),
                    stats: Default::default(),
                    drain_after_restore: false,
                    aer_handler,
                    device_id: device_id.into(),
                    qid,
                }
            }
        };

        let (send_req, recv_req) = mesh::channel();
        let (send_cmd, recv_cmd) = mesh::channel();
        let mut task = TaskControl::new(QueueHandlerLoop {
            queue_handler,
            registers,
            recv_req: Some(recv_req),
            recv_cmd: Some(recv_cmd),
            interrupt,
        });
        task.insert(spawner, "nvme-queue", ());
        task.start();

        const fn pages_to_size_bytes(pages: usize) -> usize {
            let size = pages * PAGE_SIZE;
            assert!(
                size >= 128 * 1024 + PAGE_SIZE,
                "not enough room for an ATAPI IO plus a PRP list"
            );
            size
        }

        let alloc_len = if bounce_buffer {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_BOUNCE_BUFFER) }
        } else {
            const { pages_to_size_bytes(PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
        };

        let alloc = PageAllocator::new(mem.subblock(data_offset, alloc_len));

        Ok(Self {
            task,
            issuer: Arc::new(Issuer {
                send_req,
                send_cmd,
                alloc,
            }),
            mem,
            device_id: device_id.into(),
            qid,
            sq_entries,
            cq_entries,
            sq_addr,
            cq_addr,
        })
    }

    pub fn sq_entries(&self) -> u16 {
        self.sq_entries
    }

    pub fn cq_entries(&self) -> u16 {
        self.cq_entries
    }

    pub fn sq_addr(&self) -> u64 {
        self.sq_addr
    }

    pub fn cq_addr(&self) -> u64 {
        self.cq_addr
    }

    pub fn issuer(&self) -> &Arc<Issuer> {
        &self.issuer
    }

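    /// Stops the handler task, returning the inner queue handler.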
    pub async fn shutdown(mut self) -> impl Send {
        self.task.stop().await;
        self.task.into_inner().0.queue_handler
    }

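    /// Saves the queue pair state for servicing.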
    pub async fn save(&self) -> anyhow::Result<QueuePairSavedState> {
        tracing::info!(qid = self.qid, pci_id = ?self.device_id, "saving queue pair state");
        if self.mem.pfns().is_empty() {
            return Err(Error::InvalidState.into());
        }
        let handler_data = self.issuer.send_req.call(Req::Save, ()).await??;

        Ok(QueuePairSavedState {
            mem_len: self.mem.len(),
            base_pfn: self.mem.pfns()[0],
            qid: self.qid,
            sq_entries: self.sq_entries,
            cq_entries: self.cq_entries,
            handler_data,
        })
    }

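    /// Restores a queue pair from saved state.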
    pub fn restore(
        spawner: impl SpawnDriver,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem: MemoryBlock,
        device_id: &str,
        saved_state: &QueuePairSavedState,
        bounce_buffer: bool,
        aer_handler: A,
    ) -> anyhow::Result<Self> {
        let QueuePairSavedState {
            mem_len: _,
            base_pfn: _,
            qid,
            sq_entries,
            cq_entries,
            handler_data,
        } = saved_state;

        QueuePair::new_or_restore(
            spawner,
            device_id,
            *qid,
            *sq_entries,
            *cq_entries,
            interrupt,
            registers,
            mem,
            Some(handler_data),
            bounce_buffer,
            aer_handler,
        )
    }
}

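// Errors returned by `Issuer` when a request cannot be issued or fails.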
#[derive(Debug, Error)]
#[expect(missing_docs)]
pub enum RequestError {
    #[error("queue pair is gone")]
    Gone(#[source] RpcError),
    #[error("nvme error")]
    Nvme(#[source] NvmeError),
    #[error("memory error")]
    Memory(#[source] GuestMemoryError),
    #[error("i/o too large for double buffering")]
    TooLarge,
}

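/// An NVMe command failure, carrying the spec-defined status code.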
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct NvmeError(spec::Status);

impl NvmeError {
    pub fn status(&self) -> spec::Status {
        self.0
    }
}

impl From<spec::Status> for NvmeError {
    fn from(value: spec::Status) -> Self {
        Self(value)
    }
}

impl std::error::Error for NvmeError {}

impl std::fmt::Display for NvmeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.status_code_type() {
            spec::StatusCodeType::GENERIC => write!(
                f,
                "NVMe SCT general error, SC: {:#x?}",
                self.0.status_code()
            ),
            spec::StatusCodeType::COMMAND_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT command-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::MEDIA_ERROR => {
                write!(f, "NVMe SCT media error, SC: {:#x?}", self.0.status_code())
            }
            spec::StatusCodeType::PATH_RELATED => {
                write!(
                    f,
                    "NVMe SCT path-related error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            spec::StatusCodeType::VENDOR_SPECIFIC => {
                write!(
                    f,
                    "NVMe SCT vendor-specific error, SC: {:#x?}",
                    self.0.status_code()
                )
            }
            _ => write!(
                f,
                "NVMe SCT unknown ({:#x?}), SC: {:#x?} (raw: {:#x?})",
                self.0.status_code_type(),
                self.0.status_code(),
                self.0
            ),
        }
    }
}

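/// Issues commands to a queue pair via the handler task's channels.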
#[derive(Debug, Inspect)]
pub struct Issuer {
    #[inspect(skip)]
    send_cmd: mesh::Sender<Cmd>,
    #[inspect(skip)]
    send_req: mesh::Sender<Req>,
    alloc: PageAllocator,
}

impl Issuer {
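    /// Issues a command as-is, mapping a nonzero completion status to an error.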
    pub async fn issue_raw(
        &self,
        command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        match self.send_cmd.call(Cmd::Command, command).await {
            Ok(completion) if completion.status.status() == 0 => Ok(completion),
            Ok(completion) => Err(RequestError::Nvme(NvmeError(spec::Status(
                completion.status.status(),
            )))),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

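    /// Waits for the next asynchronous event notification from the handler.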
    pub async fn issue_get_aen(&self) -> Result<AsynchronousEventRequestDw0, RequestError> {
        match self.send_req.call_failable(Req::NextAen, ()).await {
            Ok(aen_completion) => Ok(aen_completion),
            Err(RpcError::Call(e)) => Err(e),
            Err(RpcError::Channel(e)) => Err(RequestError::Gone(RpcError::Channel(e))),
        }
    }

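    /// Issues a command whose data buffer is in guest memory. If every page is
    /// visible to the device, PRPs point at guest memory directly; otherwise
    /// the transfer is double buffered through the per-queue allocator.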
    pub async fn issue_external(
        &self,
        mut command: spec::Command,
        guest_memory: &GuestMemory,
        mem: PagedRange<'_>,
    ) -> Result<spec::Completion, RequestError> {
        let mut double_buffer_pages = None;
        let opcode = spec::Opcode(command.cdw0.opcode());
        assert!(
            opcode.transfer_controller_to_host()
                || opcode.transfer_host_to_controller()
                || mem.is_empty()
        );

        guest_memory
            .probe_gpns(mem.gpns())
            .map_err(RequestError::Memory)?;

        let prp = if mem
            .gpns()
            .iter()
            .all(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).is_some())
        {
            self.make_prp(
                mem.offset() as u64,
                mem.gpns()
                    .iter()
                    .map(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).unwrap()),
            )
            .await
        } else {
            tracing::debug!(opcode = opcode.0, size = mem.len(), "double buffering");

            let double_buffer_pages = double_buffer_pages.insert(
                self.alloc
                    .alloc_bytes(mem.len())
                    .await
                    .ok_or(RequestError::TooLarge)?,
            );

            if opcode.transfer_host_to_controller() {
                double_buffer_pages
                    .copy_from_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }

            self.make_prp(
                0,
                (0..double_buffer_pages.page_count())
                    .map(|i| double_buffer_pages.physical_address(i)),
            )
            .await
        };

        command.dptr = prp.dptr;
        let r = self.issue_raw(command).await;
        if let Some(double_buffer_pages) = double_buffer_pages {
            if r.is_ok() && opcode.transfer_controller_to_host() {
                double_buffer_pages
                    .copy_to_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }
        }
        r
    }

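    /// Builds the PRP entries for a transfer: one or two direct pointers, or a
    /// first page plus a single PRP list page for larger transfers.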
    async fn make_prp(
        &self,
        offset: u64,
        mut iovas: impl ExactSizeIterator<Item = u64>,
    ) -> Prp<'_> {
        let mut prp_pages = None;
        let dptr = match iovas.len() {
            0 => [INVALID_PAGE_ADDR; 2],
            1 => [iovas.next().unwrap() + offset, INVALID_PAGE_ADDR],
            2 => [iovas.next().unwrap() + offset, iovas.next().unwrap()],
            _ => {
                let a = iovas.next().unwrap();
                // The remaining entries must fit in the single PRP list page
                // allocated below (PAGE_SIZE / 8 eight-byte entries); a looser
                // bound would let the zip below silently drop entries.
                assert!(iovas.len() <= PAGE_SIZE / 8);
                let prp = self
                    .alloc
                    .alloc_pages(1)
                    .await
                    .expect("pool capacity is >= 1 page");

                let prp_addr = prp.physical_address(0);
                let page = prp.page_as_slice(0);
                for (iova, dest) in iovas.zip(page.chunks_exact(8)) {
                    dest.atomic_write_obj(&iova.to_le_bytes());
                }
                prp_pages = Some(prp);
                [a + offset, prp_addr]
            }
        };
        Prp {
            dptr,
            _pages: prp_pages,
        }
    }

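    /// Issues a command that transfers no data.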
    pub async fn issue_neither(
        &self,
        mut command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        command.dptr = [INVALID_PAGE_ADDR; 2];
        self.issue_raw(command).await
    }

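    /// Issues a command that sends `data` to the controller, staged through a
    /// single page from the per-queue allocator.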
    pub async fn issue_in(
        &self,
        mut command: spec::Command,
        data: &[u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is >= 1 page");

        mem.write(data);
        assert_eq!(
            mem.page_count(),
            1,
            "larger requests not currently supported"
        );
        let prp = Prp {
            dptr: [mem.physical_address(0), INVALID_PAGE_ADDR],
            _pages: None,
        };
        command.dptr = prp.dptr;
        self.issue_raw(command).await
    }

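    /// Issues a command that reads `data` back from the controller via
    /// allocator-owned pages.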
    pub async fn issue_out(
        &self,
        mut command: spec::Command,
        data: &mut [u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool capacity is sufficient");

        let prp = self
            .make_prp(0, (0..mem.page_count()).map(|i| mem.physical_address(i)))
            .await;

        command.dptr = prp.dptr;
        let completion = self.issue_raw(command).await;
        mem.read(data);
        completion
    }
}

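/// A PRP descriptor pair, holding any allocated PRP list pages alive for the
/// lifetime of the command.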
struct Prp<'a> {
    dptr: [u64; 2],
    _pages: Option<ScopedPages<'a>>,
}

#[derive(Inspect)]
struct PendingCommands {
    #[inspect(iter_by_key)]
    commands: Slab<PendingCommand>,
    #[inspect(hex)]
    next_cid_high_bits: Wrapping<u16>,
    qid: u16,
}

#[derive(Inspect)]
struct PendingCommand {
    command: spec::Command,
    #[inspect(skip)]
    respond: Rpc<(), spec::Completion>,
}

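/// Control-plane requests, handled even while the queue is draining.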
enum Req {
    Save(Rpc<(), Result<QueueHandlerSavedState, anyhow::Error>>),
    Inspect(inspect::Deferred),
    NextAen(Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>),
}

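/// Commands that consume a submission queue slot.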
enum Cmd {
    Command(Rpc<spec::Command, spec::Completion>),
    SendAer(),
}

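/// Hook for handling NVMe asynchronous event requests (AERs) on a queue.
/// The admin queue uses [`AdminAerHandler`]; queues that never issue AERs can
/// use [`NoOpAerHandler`].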
pub trait AerHandler: Send + Sync + 'static {
    /// Examines a completion, handling it if it belongs to an in-flight AER.
    #[inline]
    fn handle_completion(&mut self, _completion: &nvme_spec::Completion) {}
    /// Handles a caller's request for the next asynchronous event.
    fn handle_aen_request(
        &mut self,
        _rpc: Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>,
    ) {
    }
    /// Records the CID of an AER command that was just submitted.
    fn update_awaiting_cid(&mut self, _cid: u16) {}
    /// Returns true if a new AER command should be submitted to the device.
    #[inline]
    fn poll_send_aer(&self) -> bool {
        false
    }
    fn save(&self) -> Option<AerHandlerSavedState> {
        None
    }
    fn restore(&mut self, _state: &Option<AerHandlerSavedState>) {}
}

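/// AER handler for the admin queue: keeps one AER outstanding and forwards
/// events (or the failing status) to whoever asks for the next AEN.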
pub struct AdminAerHandler {
    last_aen: Option<AsynchronousEventRequestDw0>,
    await_aen_cid: Option<u16>,
    send_aen: Option<Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>>,
    failed_status: Option<spec::Status>,
}

impl AdminAerHandler {
    pub fn new() -> Self {
        Self {
            last_aen: None,
            await_aen_cid: None,
            send_aen: None,
            failed_status: None,
        }
    }
}

impl AerHandler for AdminAerHandler {
    fn handle_completion(&mut self, completion: &nvme_spec::Completion) {
        if let Some(await_aen_cid) = self.await_aen_cid
            && completion.cid == await_aen_cid
            && self.failed_status.is_none()
        {
            self.await_aen_cid = None;

            // A failed AER is not retried: remember the failure and report it
            // to the current and all future waiters.
            if completion.status.status() != 0 {
                self.last_aen = None;
                let failed_status = spec::Status(completion.status.status());
                self.failed_status = Some(failed_status);
                if let Some(send_aen) = self.send_aen.take() {
                    send_aen.complete(Err(RequestError::Nvme(NvmeError(failed_status))));
                }
                return;
            }
            let aen = AsynchronousEventRequestDw0::from_bits(completion.dw0);
            if let Some(send_aen) = self.send_aen.take() {
                send_aen.complete(Ok(aen));
            } else {
                self.last_aen = Some(aen);
            }
        }
    }

    fn handle_aen_request(
        &mut self,
        rpc: Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>,
    ) {
        if let Some(aen) = self.last_aen.take() {
            rpc.complete(Ok(aen));
        } else if let Some(failed_status) = self.failed_status {
            rpc.complete(Err(RequestError::Nvme(NvmeError(failed_status))));
        } else {
            self.send_aen = Some(rpc);
        }
    }

    fn poll_send_aer(&self) -> bool {
        self.await_aen_cid.is_none() && self.failed_status.is_none()
    }

    fn update_awaiting_cid(&mut self, cid: u16) {
        if let Some(await_aen_cid) = self.await_aen_cid {
            panic!("already awaiting on AEN with cid {}", await_aen_cid);
        }
        self.await_aen_cid = Some(cid);
    }

    fn save(&self) -> Option<AerHandlerSavedState> {
        Some(AerHandlerSavedState {
            last_aen: self.last_aen.map(AsynchronousEventRequestDw0::into_bits),
            await_aen_cid: self.await_aen_cid,
        })
    }

    fn restore(&mut self, state: &Option<AerHandlerSavedState>) {
        if let Some(state) = state {
            let AerHandlerSavedState {
                last_aen,
                await_aen_cid,
            } = state;
            self.last_aen = last_aen.map(AsynchronousEventRequestDw0::from_bits);
            self.await_aen_cid = *await_aen_cid;
        }
    }
}

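/// AER handler for queues that never issue asynchronous event requests; any
/// AER traffic reaching it indicates a driver bug.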
pub struct NoOpAerHandler;

impl AerHandler for NoOpAerHandler {
    fn handle_aen_request(
        &mut self,
        _rpc: Rpc<(), Result<AsynchronousEventRequestDw0, RequestError>>,
    ) {
        panic!(
            "no-op aer handler should never receive an aen request. This is likely a bug in the driver."
        );
    }

    fn update_awaiting_cid(&mut self, _cid: u16) {
        panic!(
            "no-op aer handler should never be passed a cid to await. This is likely a bug in the driver."
        );
    }
}

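/// Services one queue pair: submits commands, processes completions, and
/// tracks pending commands.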
#[derive(Inspect)]
struct QueueHandler<A: AerHandler> {
    sq: SubmissionQueue,
    cq: CompletionQueue,
    commands: PendingCommands,
    stats: QueueStats,
    drain_after_restore: bool,
    #[inspect(skip)]
    aer_handler: A,
    device_id: String,
    qid: u16,
}

#[derive(Inspect, Default)]
struct QueueStats {
    issued: Counter,
    completed: Counter,
    interrupts: Counter,
}

impl<A: AerHandler> QueueHandler<A> {
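    /// The queue service loop: polls for new requests, submits commands, and
    /// dispatches completions until stopped or a save request arrives.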
    async fn run(
        &mut self,
        registers: &DeviceRegisters<impl DeviceBacking>,
        mut recv_req: mesh::Receiver<Req>,
        mut recv_cmd: mesh::Receiver<Cmd>,
        interrupt: &mut DeviceInterrupt,
    ) {
        if self.drain_after_restore {
            tracing::info!(pci_id = ?self.device_id, qid = self.qid, "Have {} outstanding IOs from before save, draining them before allowing new IO...", self.commands.len());
        }

        loop {
            enum Event {
                Request(Req),
                Command(Cmd),
                Completion(spec::Completion),
            }

            let event = if !self.drain_after_restore {
                // Normal operation: accept new commands whenever there is room
                // in both the submission queue and the pending-command table.
                poll_fn(|cx| {
                    if !self.sq.is_full() && !self.commands.is_full() {
                        if self.aer_handler.poll_send_aer() {
                            return Event::Command(Cmd::SendAer()).into();
                        }
                        if let Poll::Ready(Some(cmd)) = recv_cmd.poll_next_unpin(cx) {
                            return Event::Command(cmd).into();
                        }
                    }
                    if let Poll::Ready(Some(req)) = recv_req.poll_next_unpin(cx) {
                        return Event::Request(req).into();
                    }
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.sq.commit(registers);
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            } else {
                // Draining commands issued before a restore: process requests
                // and completions only, submitting no new commands.
                poll_fn(|cx| {
                    if let Poll::Ready(Some(req)) = recv_req.poll_next_unpin(cx) {
                        return Event::Request(req).into();
                    }

                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            };

            match event {
                Event::Request(req) => match req {
                    Req::Save(queue_state) => {
                        tracing::info!(pci_id = ?self.device_id, qid = ?self.qid, "received save request, shutting down ...");
                        queue_state.complete(self.save().await);
                        break;
                    }
                    Req::Inspect(deferred) => deferred.inspect(&self),
                    Req::NextAen(rpc) => {
                        self.aer_handler.handle_aen_request(rpc);
                    }
                },
                Event::Command(cmd) => match cmd {
                    Cmd::Command(rpc) => {
                        let (mut command, respond) = rpc.split();
                        self.commands.insert(&mut command, respond);
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                    Cmd::SendAer() => {
                        let mut command = admin_cmd(spec::AdminOpcode::ASYNCHRONOUS_EVENT_REQUEST);
                        self.commands.insert(&mut command, Rpc::detached(()));
                        self.aer_handler.update_awaiting_cid(command.cdw0.cid());
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                },
                Event::Completion(completion) => {
                    assert_eq!(completion.sqid, self.sq.id());
                    let respond = self.commands.remove(completion.cid);
                    if self.drain_after_restore && self.commands.is_empty() {
                        tracing::info!(pci_id = ?self.device_id, qid = ?self.qid, "done with drain-after-restore");
                        self.drain_after_restore = false;
                    }
                    self.sq.update_head(completion.sqhd);
                    self.aer_handler.handle_completion(&completion);
                    respond.complete(completion);
                    self.stats.completed.increment();
                }
            }
        }
    }

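    /// Captures the queue and pending-command state for servicing.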
    pub async fn save(&self) -> anyhow::Result<QueueHandlerSavedState> {
        Ok(QueueHandlerSavedState {
            sq_state: self.sq.save(),
            cq_state: self.cq.save(),
            pending_cmds: self.commands.save(),
            aer_handler: self.aer_handler.save(),
        })
    }

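    /// Rebuilds a queue handler from saved state over existing queue memory.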
    pub fn restore(
        sq_mem_block: MemoryBlock,
        cq_mem_block: MemoryBlock,
        saved_state: &QueueHandlerSavedState,
        mut aer_handler: A,
        device_id: &str,
        qid: u16,
    ) -> anyhow::Result<Self> {
        let QueueHandlerSavedState {
            sq_state,
            cq_state,
            pending_cmds,
            aer_handler: aer_handler_saved_state,
        } = saved_state;

        aer_handler.restore(aer_handler_saved_state);

        Ok(Self {
            sq: SubmissionQueue::restore(sq_mem_block, sq_state)?,
            cq: CompletionQueue::restore(cq_mem_block, cq_state)?,
            commands: PendingCommands::restore(pending_cmds, sq_state.sqid)?,
            stats: Default::default(),
            // Only IO queues (sqid != 0) with commands in flight need to be
            // drained before accepting new submissions.
            drain_after_restore: sq_state.sqid != 0 && !pending_cmds.commands.is_empty(),
            aer_handler,
            device_id: device_id.into(),
            qid,
        })
    }
}

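/// Builds an admin command with the given opcode and all other fields zeroed.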
pub(crate) fn admin_cmd(opcode: spec::AdminOpcode) -> spec::Command {
    spec::Command {
        cdw0: spec::Cdw0::new().with_opcode(opcode.0),
        ..FromZeros::new_zeroed()
    }
}