use super::spec;
use crate::driver::save_restore::Error;
use crate::driver::save_restore::PendingCommandSavedState;
use crate::driver::save_restore::PendingCommandsSavedState;
use crate::driver::save_restore::QueueHandlerSavedState;
use crate::driver::save_restore::QueuePairSavedState;
use crate::queues::CompletionQueue;
use crate::queues::SubmissionQueue;
use crate::registers::DeviceRegisters;
use anyhow::Context;
use futures::StreamExt;
use guestmem::GuestMemory;
use guestmem::GuestMemoryError;
use guestmem::ranges::PagedRange;
use inspect::Inspect;
use inspect_counters::Counter;
use mesh::Cancel;
use mesh::CancelContext;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::RpcSend;
use pal_async::driver::SpawnDriver;
use pal_async::task::Task;
use safeatomic::AtomicSliceOps;
use slab::Slab;
use std::future::poll_fn;
use std::num::Wrapping;
use std::sync::Arc;
use std::task::Poll;
use thiserror::Error;
use user_driver::DeviceBacking;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::memory::MemoryBlock;
use user_driver::memory::PAGE_SIZE;
use user_driver::memory::PAGE_SIZE64;
use user_driver::page_allocator::PageAllocator;
use user_driver::page_allocator::ScopedPages;
use zerocopy::FromZeros;

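/// Sentinel address used to fill PRP/DPTR entries that a command does not use.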
const INVALID_PAGE_ADDR: u64 = !(PAGE_SIZE as u64 - 1);

pub(crate) struct QueuePair {
    task: Task<QueueHandler>,
    cancel: Cancel,
    issuer: Arc<Issuer>,
    mem: MemoryBlock,
    qid: u16,
    sq_entries: u16,
    cq_entries: u16,
}

impl Inspect for QueuePair {
    fn inspect(&self, req: inspect::Request<'_>) {
        let Self {
            task: _,
            cancel: _,
            issuer,
            mem: _,
            qid: _,
            sq_entries: _,
            cq_entries: _,
        } = self;
        issuer.send.send(Req::Inspect(req.defer()));
    }
}

impl PendingCommands {
    const CID_KEY_BITS: u32 = 10;
    const CID_KEY_MASK: u16 = (1 << Self::CID_KEY_BITS) - 1;
    const MAX_CIDS: usize = 1 << Self::CID_KEY_BITS;
    const CID_SEQ_OFFSET: Wrapping<u16> = Wrapping(1 << Self::CID_KEY_BITS);
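
    // Layout of a command identifier (CID), as built by `insert` below:
    //   bits 0..10  - key of the entry in the `commands` slab
    //   bits 10..16 - sequence bits from a counter that advances by
    //                 CID_SEQ_OFFSET on every insert (wrapping)
    // Example: if the counter holds 2 << 10 when slab slot 5 is used, the
    // command gets CID 5 | (2 << 10) = 0x805. `remove` checks the full CID,
    // so a completion carrying a stale sequence number panics rather than
    // completing the wrong request.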

    fn new() -> Self {
        Self {
            commands: Slab::new(),
            next_cid_high_bits: Wrapping(0),
        }
    }

    fn is_full(&self) -> bool {
        self.commands.len() >= Self::MAX_CIDS
    }

    fn is_empty(&self) -> bool {
        self.commands.is_empty()
    }

    fn insert(&mut self, command: &mut spec::Command, respond: Rpc<(), spec::Completion>) {
        let entry = self.commands.vacant_entry();
        assert!(entry.key() < Self::MAX_CIDS);
        assert_eq!(self.next_cid_high_bits % Self::CID_SEQ_OFFSET, Wrapping(0));
        let cid = entry.key() as u16 | self.next_cid_high_bits.0;
        self.next_cid_high_bits += Self::CID_SEQ_OFFSET;
        command.cdw0.set_cid(cid);
        entry.insert(PendingCommand {
            command: *command,
            respond,
        });
    }

    fn remove(&mut self, cid: u16) -> Rpc<(), spec::Completion> {
        let command = self
            .commands
            .try_remove((cid & Self::CID_KEY_MASK) as usize)
            .expect("completion for unknown cid");
        assert_eq!(
            command.command.cdw0.cid(),
            cid,
            "cid sequence number mismatch"
        );
        command.respond
    }

    pub fn save(&self) -> PendingCommandsSavedState {
        let commands: Vec<PendingCommandSavedState> = self
            .commands
            .iter()
            .map(|(_index, cmd)| PendingCommandSavedState {
                command: cmd.command,
            })
            .collect();
        PendingCommandsSavedState {
            commands,
            next_cid_high_bits: self.next_cid_high_bits.0,
            cid_key_bits: Self::CID_KEY_BITS,
        }
    }

    pub fn restore(saved_state: &PendingCommandsSavedState) -> anyhow::Result<Self> {
        let PendingCommandsSavedState {
            commands,
            next_cid_high_bits,
            cid_key_bits: _,
        } = saved_state;

        Ok(Self {
            commands: commands
                .iter()
                .map(|state| {
                    (
                        (state.command.cdw0.cid() & Self::CID_KEY_MASK) as usize,
                        PendingCommand {
                            command: state.command,
                            respond: Rpc::detached(()),
                        },
                    )
                })
                .collect::<Slab<PendingCommand>>(),
            next_cid_high_bits: Wrapping(*next_cid_high_bits),
        })
    }
}

impl QueuePair {
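    // Each queue lives in a single page: submission entries are 64 bytes and
    // completion entries are 16 bytes, so at most PAGE_SIZE / 64 and
    // PAGE_SIZE / 16 entries respectively.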
    pub const MAX_SQ_ENTRIES: u16 = (PAGE_SIZE / 64) as u16;
    pub const MAX_CQ_ENTRIES: u16 = (PAGE_SIZE / 16) as u16;
    const SQ_SIZE: usize = PAGE_SIZE;
    const CQ_SIZE: usize = PAGE_SIZE;
    const PER_QUEUE_PAGES_BOUNCE_BUFFER: usize = 128;
    const PER_QUEUE_PAGES_NO_BOUNCE_BUFFER: usize = 64;

    pub fn new(
        spawner: impl SpawnDriver,
        device: &impl DeviceBacking,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<impl DeviceBacking>>,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let total_size = QueuePair::SQ_SIZE
            + QueuePair::CQ_SIZE
            + if bounce_buffer {
                QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER * PAGE_SIZE
            } else {
                QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER * PAGE_SIZE
            };
        let dma_client = device.dma_client();
        let mem = dma_client
            .allocate_dma_buffer(total_size)
            .context("failed to allocate memory for queues")?;

        assert!(sq_entries <= Self::MAX_SQ_ENTRIES);
        assert!(cq_entries <= Self::MAX_CQ_ENTRIES);

        QueuePair::new_or_restore(
            spawner,
            qid,
            sq_entries,
            cq_entries,
            interrupt,
            registers,
            mem,
            None,
            bounce_buffer,
        )
    }

    fn new_or_restore(
        spawner: impl SpawnDriver,
        qid: u16,
        sq_entries: u16,
        cq_entries: u16,
        mut interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<impl DeviceBacking>>,
        mem: MemoryBlock,
        saved_state: Option<&QueueHandlerSavedState>,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let sq_mem_block = mem.subblock(0, QueuePair::SQ_SIZE);
        let cq_mem_block = mem.subblock(QueuePair::SQ_SIZE, QueuePair::CQ_SIZE);
        let data_offset = QueuePair::SQ_SIZE + QueuePair::CQ_SIZE;

        let mut queue_handler = match saved_state {
            Some(s) => QueueHandler::restore(sq_mem_block, cq_mem_block, s)?,
            None => {
                QueueHandler {
                    sq: SubmissionQueue::new(qid, sq_entries, sq_mem_block),
                    cq: CompletionQueue::new(qid, cq_entries, cq_mem_block),
                    commands: PendingCommands::new(),
                    stats: Default::default(),
                    drain_after_restore: false,
                }
            }
        };

        let (send, recv) = mesh::channel();
        let (mut ctx, cancel) = CancelContext::new().with_cancel();
        let task = spawner.spawn("nvme-queue", {
            async move {
                ctx.until_cancelled(async {
                    queue_handler.run(&registers, recv, &mut interrupt).await;
                })
                .await
                .ok();
                queue_handler
            }
        });

        const fn pages_to_size_bytes(pages: usize) -> usize {
            let size = pages * PAGE_SIZE;
            assert!(
                size >= 128 * 1024 + PAGE_SIZE,
                "not enough room for an ATAPI IO plus a PRP list"
            );
            size
        }

        let alloc_len = if bounce_buffer {
            const { pages_to_size_bytes(QueuePair::PER_QUEUE_PAGES_BOUNCE_BUFFER) }
        } else {
            const { pages_to_size_bytes(QueuePair::PER_QUEUE_PAGES_NO_BOUNCE_BUFFER) }
        };

        let alloc = PageAllocator::new(mem.subblock(data_offset, alloc_len));

        Ok(Self {
            task,
            cancel,
            issuer: Arc::new(Issuer { send, alloc }),
            mem,
            qid,
            sq_entries,
            cq_entries,
        })
    }

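    // The SQ occupies the first page of the shared allocation and the CQ the
    // second; the remaining pages back the per-queue PageAllocator.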
    pub fn sq_addr(&self) -> u64 {
        self.mem.pfns()[0] * PAGE_SIZE64
    }

    pub fn cq_addr(&self) -> u64 {
        self.mem.pfns()[1] * PAGE_SIZE64
    }

    pub fn issuer(&self) -> &Arc<Issuer> {
        &self.issuer
    }

    pub async fn shutdown(mut self) -> impl Send {
        self.cancel.cancel();
        self.task.await
    }

    pub async fn save(&self) -> anyhow::Result<QueuePairSavedState> {
        if self.mem.pfns().is_empty() {
            return Err(Error::InvalidState.into());
        }
        let handler_data = self.issuer.send.call(Req::Save, ()).await??;

        Ok(QueuePairSavedState {
            mem_len: self.mem.len(),
            base_pfn: self.mem.pfns()[0],
            qid: self.qid,
            sq_entries: self.sq_entries,
            cq_entries: self.cq_entries,
            handler_data,
        })
    }

    pub fn restore(
        spawner: impl SpawnDriver,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<impl DeviceBacking>>,
        mem: MemoryBlock,
        saved_state: &QueuePairSavedState,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let QueuePairSavedState {
            mem_len: _,
            base_pfn: _,
            qid,
            sq_entries,
            cq_entries,
            handler_data,
        } = saved_state;

        QueuePair::new_or_restore(
            spawner,
            *qid,
            *sq_entries,
            *cq_entries,
            interrupt,
            registers,
            mem,
            Some(handler_data),
            bounce_buffer,
        )
    }
}

#[derive(Debug, Error)]
#[expect(missing_docs)]
pub enum RequestError {
    #[error("queue pair is gone")]
    Gone(#[source] RpcError),
    #[error("nvme error")]
    Nvme(#[source] NvmeError),
    #[error("memory error")]
    Memory(#[source] GuestMemoryError),
    #[error("i/o too large for double buffering")]
    TooLarge,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct NvmeError(spec::Status);

impl NvmeError {
    pub fn status(&self) -> spec::Status {
        self.0
    }
}

impl From<spec::Status> for NvmeError {
    fn from(value: spec::Status) -> Self {
        Self(value)
    }
}

impl std::error::Error for NvmeError {}

impl std::fmt::Display for NvmeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0.status_code_type() {
            spec::StatusCodeType::GENERIC => write!(f, "general error {:#x?}", self.0),
            spec::StatusCodeType::COMMAND_SPECIFIC => {
                write!(f, "command-specific error {:#x?}", self.0)
            }
            spec::StatusCodeType::MEDIA_ERROR => {
                write!(f, "media error {:#x?}", self.0)
            }
            _ => write!(f, "{:#x?}", self.0),
        }
    }
}

#[derive(Debug, Inspect)]
pub struct Issuer {
    #[inspect(skip)]
    send: mesh::Sender<Req>,
    alloc: PageAllocator,
}

impl Issuer {
    pub async fn issue_raw(
        &self,
        command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        match self.send.call(Req::Command, command).await {
            Ok(completion) if completion.status.status() == 0 => Ok(completion),
            Ok(completion) => Err(RequestError::Nvme(NvmeError(spec::Status(
                completion.status.status(),
            )))),
            Err(err) => Err(RequestError::Gone(err)),
        }
    }

    pub async fn issue_external(
        &self,
        mut command: spec::Command,
        guest_memory: &GuestMemory,
        mem: PagedRange<'_>,
    ) -> Result<spec::Completion, RequestError> {
        let mut double_buffer_pages = None;
        let opcode = spec::Opcode(command.cdw0.opcode());
        assert!(
            opcode.transfer_controller_to_host()
                || opcode.transfer_host_to_controller()
                || mem.is_empty()
        );

        guest_memory
            .probe_gpns(mem.gpns())
            .map_err(RequestError::Memory)?;

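        // If every guest page is already mapped for device DMA, build the PRP
        // list directly from the guest addresses. Otherwise, stage the
        // transfer through pages from this queue's bounce-buffer pool.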
        let prp = if mem
            .gpns()
            .iter()
            .all(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).is_some())
        {
            self.make_prp(
                mem.offset() as u64,
                mem.gpns()
                    .iter()
                    .map(|&gpn| guest_memory.iova(gpn * PAGE_SIZE64).unwrap()),
            )
            .await
        } else {
            tracing::debug!(opcode = opcode.0, size = mem.len(), "double buffering");

            let double_buffer_pages = double_buffer_pages.insert(
                self.alloc
                    .alloc_bytes(mem.len())
                    .await
                    .ok_or(RequestError::TooLarge)?,
            );

            if opcode.transfer_host_to_controller() {
                double_buffer_pages
                    .copy_from_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }

            self.make_prp(
                0,
                (0..double_buffer_pages.page_count())
                    .map(|i| double_buffer_pages.physical_address(i)),
            )
            .await
        };

        command.dptr = prp.dptr;
        let r = self.issue_raw(command).await;
        if let Some(double_buffer_pages) = double_buffer_pages {
            if r.is_ok() && opcode.transfer_controller_to_host() {
                double_buffer_pages
                    .copy_to_guest_memory(guest_memory, mem)
                    .map_err(RequestError::Memory)?;
            }
        }
        r
    }

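    // Builds the data pointer (DPTR) for a command. One or two pages are
    // described directly in the two DPTR entries; for longer transfers the
    // second entry points at a single PRP list page allocated from the
    // per-queue pool, which bounds the transfer size this path can describe.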
    async fn make_prp(
        &self,
        offset: u64,
        mut iovas: impl ExactSizeIterator<Item = u64>,
    ) -> Prp<'_> {
        let mut prp_pages = None;
        let dptr = match iovas.len() {
            0 => [INVALID_PAGE_ADDR; 2],
            1 => [iovas.next().unwrap() + offset, INVALID_PAGE_ADDR],
            2 => [iovas.next().unwrap() + offset, iovas.next().unwrap()],
            _ => {
                let a = iovas.next().unwrap();
                // A single PRP list page holds PAGE_SIZE / 8 entries; anything
                // larger would be silently truncated by the zip below.
                assert!(iovas.len() <= PAGE_SIZE / 8);
                let prp = self
                    .alloc
                    .alloc_pages(1)
                    .await
                    .expect("pool cap is >= 1 page");

                let prp_addr = prp.physical_address(0);
                let page = prp.page_as_slice(0);
                for (iova, dest) in iovas.zip(page.chunks_exact(8)) {
                    dest.atomic_write_obj(&iova.to_le_bytes());
                }
                prp_pages = Some(prp);
                [a + offset, prp_addr]
            }
        };
        Prp {
            dptr,
            _pages: prp_pages,
        }
    }

    pub async fn issue_neither(
        &self,
        mut command: spec::Command,
    ) -> Result<spec::Completion, RequestError> {
        command.dptr = [INVALID_PAGE_ADDR; 2];
        self.issue_raw(command).await
    }

    pub async fn issue_in(
        &self,
        mut command: spec::Command,
        data: &[u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is >= 1 page");

        mem.write(data);
        assert_eq!(
            mem.page_count(),
            1,
            "larger requests not currently supported"
        );
        let prp = Prp {
            dptr: [mem.physical_address(0), INVALID_PAGE_ADDR],
            _pages: None,
        };
        command.dptr = prp.dptr;
        self.issue_raw(command).await
    }

    pub async fn issue_out(
        &self,
        mut command: spec::Command,
        data: &mut [u8],
    ) -> Result<spec::Completion, RequestError> {
        let mem = self
            .alloc
            .alloc_bytes(data.len())
            .await
            .expect("pool cap is sufficient");

        assert_eq!(
            mem.page_count(),
            1,
            "larger requests not currently supported"
        );
        let prp = Prp {
            dptr: [mem.physical_address(0), INVALID_PAGE_ADDR],
            _pages: None,
        };
        command.dptr = prp.dptr;
        let completion = self.issue_raw(command).await;
        mem.read(data);
        completion
    }
}

struct Prp<'a> {
    dptr: [u64; 2],
    _pages: Option<ScopedPages<'a>>,
}

#[derive(Inspect)]
struct PendingCommands {
    #[inspect(iter_by_key)]
    commands: Slab<PendingCommand>,
    #[inspect(hex)]
    next_cid_high_bits: Wrapping<u16>,
}

#[derive(Inspect)]
struct PendingCommand {
    command: spec::Command,
    #[inspect(skip)]
    respond: Rpc<(), spec::Completion>,
}

enum Req {
    Command(Rpc<spec::Command, spec::Completion>),
    Inspect(inspect::Deferred),
    Save(Rpc<(), Result<QueueHandlerSavedState, anyhow::Error>>),
}

#[derive(Inspect)]
struct QueueHandler {
    sq: SubmissionQueue,
    cq: CompletionQueue,
    commands: PendingCommands,
    stats: QueueStats,
    drain_after_restore: bool,
}

#[derive(Inspect, Default)]
struct QueueStats {
    issued: Counter,
    completed: Counter,
    interrupts: Counter,
}

impl QueueHandler {
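    // Event loop for the queue pair: accept new requests only while both the
    // submission queue and the pending-command table have room, poll the
    // completion queue while commands are outstanding, and ring the doorbells
    // before going back to sleep. While `drain_after_restore` is set (commands
    // were restored from a saved state), no new submissions are accepted until
    // every restored command has completed.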
    async fn run(
        &mut self,
        registers: &DeviceRegisters<impl DeviceBacking>,
        mut recv: mesh::Receiver<Req>,
        interrupt: &mut DeviceInterrupt,
    ) {
        loop {
            enum Event {
                Request(Req),
                Completion(spec::Completion),
            }

            let event = if !self.drain_after_restore {
                poll_fn(|cx| {
                    if !self.sq.is_full() && !self.commands.is_full() {
                        if let Poll::Ready(Some(req)) = recv.poll_next_unpin(cx) {
                            return Event::Request(req).into();
                        }
                    }
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.sq.commit(registers);
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            } else {
                poll_fn(|cx| {
                    while !self.commands.is_empty() {
                        if let Some(completion) = self.cq.read() {
                            return Event::Completion(completion).into();
                        }
                        if interrupt.poll(cx).is_pending() {
                            break;
                        }
                        self.stats.interrupts.increment();
                    }
                    self.cq.commit(registers);
                    Poll::Pending
                })
                .await
            };

            match event {
                Event::Request(req) => match req {
                    Req::Command(rpc) => {
                        let (mut command, respond) = rpc.split();
                        self.commands.insert(&mut command, respond);
                        self.sq.write(command).unwrap();
                        self.stats.issued.increment();
                    }
                    Req::Inspect(deferred) => deferred.inspect(&self),
                    Req::Save(queue_state) => {
                        queue_state.complete(self.save().await);
                        break;
                    }
                },
                Event::Completion(completion) => {
                    assert_eq!(completion.sqid, self.sq.id());
                    let respond = self.commands.remove(completion.cid);
                    if self.drain_after_restore && self.commands.is_empty() {
                        self.drain_after_restore = false;
                    }
                    self.sq.update_head(completion.sqhd);
                    respond.complete(completion);
                    self.stats.completed.increment();
                }
            }
        }
    }

    pub async fn save(&self) -> anyhow::Result<QueueHandlerSavedState> {
        Ok(QueueHandlerSavedState {
            sq_state: self.sq.save(),
            cq_state: self.cq.save(),
            pending_cmds: self.commands.save(),
        })
    }

    pub fn restore(
        sq_mem_block: MemoryBlock,
        cq_mem_block: MemoryBlock,
        saved_state: &QueueHandlerSavedState,
    ) -> anyhow::Result<Self> {
        let QueueHandlerSavedState {
            sq_state,
            cq_state,
            pending_cmds,
        } = saved_state;

        Ok(Self {
            sq: SubmissionQueue::restore(sq_mem_block, sq_state)?,
            cq: CompletionQueue::restore(cq_mem_block, cq_state)?,
            commands: PendingCommands::restore(pending_cmds)?,
            stats: Default::default(),
            drain_after_restore: sq_state.sqid != 0 && !pending_cmds.commands.is_empty(),
        })
    }
}

pub(crate) fn admin_cmd(opcode: spec::AdminOpcode) -> spec::Command {
    spec::Command {
        cdw0: spec::Cdw0::new().with_opcode(opcode.0),
        ..FromZeros::new_zeroed()
    }
}
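
// A minimal sketch exercising the CID encoding implemented by
// `PendingCommands`: the slab key lives in the low bits and a rotating
// sequence number in the high bits, so a reused slot still yields a fresh CID.
#[cfg(test)]
mod cid_encoding_sketch {
    use super::*;

    #[test]
    fn cid_carries_slab_key_and_sequence_number() {
        let mut pending = PendingCommands::new();

        // First command: slab key 0, sequence bits 0.
        let mut command: spec::Command = FromZeros::new_zeroed();
        pending.insert(&mut command, Rpc::detached(()));
        let first_cid = command.cdw0.cid();
        assert_eq!(first_cid & PendingCommands::CID_KEY_MASK, 0);
        drop(pending.remove(first_cid));
        assert!(pending.is_empty());

        // The slot is reused, but the sequence bits advance, so the new CID
        // differs from the old one and stale completions are detectable.
        let mut command: spec::Command = FromZeros::new_zeroed();
        pending.insert(&mut command, Rpc::detached(()));
        let second_cid = command.cdw0.cid();
        assert_eq!(second_cid & PendingCommands::CID_KEY_MASK, 0);
        assert_eq!(second_cid, PendingCommands::CID_SEQ_OFFSET.0);
    }
}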