use super::spec;
use crate::NVME_PAGE_SHIFT;
use crate::Namespace;
use crate::NamespaceError;
use crate::RequestError;
use crate::driver::save_restore::IoQueueSavedState;
use crate::queue_pair::AdminAerHandler;
use crate::queue_pair::Issuer;
use crate::queue_pair::MAX_CQ_ENTRIES;
use crate::queue_pair::MAX_SQ_ENTRIES;
use crate::queue_pair::NoOpAerHandler;
use crate::queue_pair::QueuePair;
use crate::queue_pair::admin_cmd;
use crate::registers::Bar0;
use crate::registers::DeviceRegisters;
use crate::save_restore::NvmeDriverSavedState;
use anyhow::Context as _;
use futures::StreamExt;
use futures::future::join_all;
use inspect::Inspect;
use mesh::payload::Protobuf;
use mesh::rpc::Rpc;
use mesh::rpc::RpcSend;
use pal_async::task::Spawn;
use pal_async::task::Task;
use parking_lot::RwLock;
use save_restore::NvmeDriverWorkerSavedState;
use std::collections::HashMap;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::sync::Arc;
use std::sync::OnceLock;
use task_control::AsyncRun;
use task_control::InspectTask;
use task_control::TaskControl;
use thiserror::Error;
use tracing::Instrument;
use tracing::Span;
use tracing::info_span;
use user_driver::DeviceBacking;
use user_driver::backoff::Backoff;
use user_driver::interrupt::DeviceInterrupt;
use user_driver::memory::MemoryBlock;
use vmcore::vm_task::VmTaskDriver;
use vmcore::vm_task::VmTaskDriverSource;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::IntoBytes;

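/// A driver for a single NVMe controller, accessed through a
/// [`DeviceBacking`] implementation.
///
/// A sketch of typical usage (illustrative only, not a compiled doctest;
/// `driver_source`, `cpu_count`, and `device` are assumed to be provided by
/// the embedding code):
///
/// ```ignore
/// let mut driver = NvmeDriver::new(&driver_source, cpu_count, device, false).await?;
/// let namespace = driver.namespace(1).await?;
/// // ... issue I/O through the namespace ...
/// driver.shutdown().await;
/// ```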
#[derive(Inspect)]
pub struct NvmeDriver<D: DeviceBacking> {
    #[inspect(flatten)]
    task: Option<TaskControl<DriverWorkerTask<D>, WorkerState>>,
    device_id: String,
    identify: Option<Arc<spec::IdentifyController>>,
    #[inspect(skip)]
    driver: VmTaskDriver,
    #[inspect(skip)]
    admin: Option<Arc<Issuer>>,
    #[inspect(skip)]
    io_issuers: Arc<IoIssuers>,
    #[inspect(skip)]
    rescan_notifiers: Arc<RwLock<HashMap<u32, mesh::Sender<()>>>>,
    #[inspect(skip)]
    namespaces: HashMap<u32, NamespaceHandle>,
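    /// When set, the controller is kept alive (not reset) across drop and
    /// shutdown so that it can be reclaimed after servicing; see
    /// [`NvmeDriver::update_servicing_flags`].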
    nvme_keepalive: bool,
    bounce_buffer: bool,
}

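/// A cached namespace, along with whether a caller currently holds it.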
struct NamespaceHandle {
    namespace: Arc<Namespace>,
    in_use: bool,
}

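/// State owned by the worker task, which services issuer creation and save
/// requests on behalf of [`NvmeDriver`].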
#[derive(Inspect)]
struct DriverWorkerTask<D: DeviceBacking> {
    device: ManuallyDrop<D>,
    #[inspect(skip)]
    driver: VmTaskDriver,
    registers: Arc<DeviceRegisters<D>>,
    admin: Option<QueuePair<AdminAerHandler, D>>,
    #[inspect(iter_by_index)]
    io: Vec<IoQueue<D>>,
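    /// Saved I/O queue state, keyed by CPU, that has not yet been rehydrated
    /// into a live queue pair after a restore.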
    #[inspect(skip)]
    proto_io: HashMap<u32, ProtoIoQueue>,
    next_ioq_id: u16,
    io_issuers: Arc<IoIssuers>,
    #[inspect(skip)]
    recv: mesh::Receiver<NvmeWorkerRequest>,
    bounce_buffer: bool,
}

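/// Worker state that only exists once the controller has been enabled.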
#[derive(Inspect)]
struct WorkerState {
    max_io_queues: u16,
    qsize: u16,
    #[inspect(skip)]
    async_event_task: Task<()>,
}

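/// An error that occurred while restoring the driver from saved state.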
#[derive(Debug, Error)]
pub enum RestoreError {
    #[error("invalid data")]
    InvalidData,
}

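/// A device-level error encountered while creating an I/O queue.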
#[derive(Debug, Error)]
pub enum DeviceError {
    #[error("no more io queues available, reached maximum {0}")]
    NoMoreIoQueues(u16),
    #[error("failed to map interrupt")]
    InterruptMapFailure(#[source] anyhow::Error),
    #[error("failed to create io queue pair {1}")]
    IoQueuePairCreationFailure(#[source] anyhow::Error, u16),
    #[error("failed to create io completion queue {1}")]
    IoCompletionQueueFailure(#[source] anyhow::Error, u16),
    #[error("failed to create io submission queue {1}")]
    IoSubmissionQueueFailure(#[source] anyhow::Error, u16),
    #[error(transparent)]
    Other(anyhow::Error),
}

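/// The saved state and backing memory for an I/O queue that is restored
/// lazily, on first use, rather than at restore time.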
#[derive(Debug, Clone)]
struct ProtoIoQueue {
    save_state: IoQueueSavedState,
    mem: MemoryBlock,
}

#[derive(Inspect)]
struct IoQueue<D: DeviceBacking> {
    queue: QueuePair<NoOpAerHandler, D>,
    iv: u16,
    cpu: u32,
}

impl<D: DeviceBacking> IoQueue<D> {
    pub async fn save(&self) -> anyhow::Result<IoQueueSavedState> {
        Ok(IoQueueSavedState {
            cpu: self.cpu,
            iv: self.iv as u32,
            queue_data: self.queue.save().await?,
        })
    }

    pub fn restore(
        spawner: VmTaskDriver,
        interrupt: DeviceInterrupt,
        registers: Arc<DeviceRegisters<D>>,
        mem_block: MemoryBlock,
        saved_state: &IoQueueSavedState,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let IoQueueSavedState {
            cpu,
            iv,
            queue_data,
        } = saved_state;
        let queue = QueuePair::restore(
            spawner,
            interrupt,
            registers.clone(),
            mem_block,
            queue_data,
            bounce_buffer,
            NoOpAerHandler,
        )?;

        Ok(Self {
            queue,
            iv: *iv as u16,
            cpu: *cpu,
        })
    }
}

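/// The set of per-CPU issuers. An issuer is created lazily on first use by
/// asking the worker task to build (or fall back to) an I/O queue.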
#[derive(Debug, Inspect)]
pub(crate) struct IoIssuers {
    #[inspect(iter_by_index)]
    per_cpu: Vec<OnceLock<IoIssuer>>,
    #[inspect(skip)]
    send: mesh::Sender<NvmeWorkerRequest>,
}

#[derive(Debug, Clone, Inspect)]
struct IoIssuer {
    #[inspect(flatten)]
    issuer: Arc<Issuer>,
    cpu: u32,
}

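/// A request sent to the worker task.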
#[derive(Debug)]
enum NvmeWorkerRequest {
    CreateIssuer(Rpc<u32, ()>),
    Save(Rpc<Span, anyhow::Result<NvmeDriverWorkerSavedState>>),
}

impl<D: DeviceBacking> NvmeDriver<D> {
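    /// Initializes the device, enabling the controller, the admin queue, and
    /// the first I/O queue. On failure, the device is shut down.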
    pub async fn new(
        driver_source: &VmTaskDriverSource,
        cpu_count: u32,
        device: D,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let pci_id = device.id().to_owned();
        let mut this = Self::new_disabled(driver_source, cpu_count, device, bounce_buffer)
            .instrument(tracing::info_span!("nvme_new_disabled", pci_id))
            .await?;
        match this
            .enable(cpu_count as u16)
            .instrument(tracing::info_span!("nvme_enable", pci_id))
            .await
        {
            Ok(()) => Ok(this),
            Err(err) => {
                tracing::error!(
                    error = err.as_ref() as &dyn std::error::Error,
                    "device initialization failed, shutting down"
                );
                this.shutdown().await;
                Err(err)
            }
        }
    }

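    /// Maps BAR0 and resets the controller if it was left enabled, but does
    /// not enable it. [`Self::enable`] completes initialization.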
    async fn new_disabled(
        driver_source: &VmTaskDriverSource,
        cpu_count: u32,
        mut device: D,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let driver = driver_source.simple();
        let bar0 = Bar0(
            device
                .map_bar(0)
                .context("failed to map device registers")?,
        );

        let cc = bar0.cc();
        if cc.en() || bar0.csts().rdy() {
            if let Err(e) = bar0
                .reset(&driver)
                .instrument(tracing::info_span!(
                    "nvme_already_enabled",
                    pci_id = device.id().to_owned()
                ))
                .await
            {
                anyhow::bail!("device is gone, csts: {:#x}", e);
            }
        }

        let registers = Arc::new(DeviceRegisters::new(bar0));
        let cap = registers.cap;

        if cap.mpsmin() != 0 {
            anyhow::bail!(
                "unsupported minimum page size: {}",
                cap.mpsmin() + NVME_PAGE_SHIFT
            );
        }

        let (send, recv) = mesh::channel();
        let io_issuers = Arc::new(IoIssuers {
            per_cpu: (0..cpu_count).map(|_| OnceLock::new()).collect(),
            send,
        });

        Ok(Self {
            device_id: device.id().to_owned(),
            task: Some(TaskControl::new(DriverWorkerTask {
                device: ManuallyDrop::new(device),
                driver: driver.clone(),
                registers,
                admin: None,
                io: Vec::new(),
                proto_io: HashMap::new(),
                next_ioq_id: 1,
                io_issuers: io_issuers.clone(),
                recv,
                bounce_buffer,
            })),
            admin: None,
            identify: None,
            driver,
            io_issuers,
            rescan_notifiers: Default::default(),
            namespaces: Default::default(),
            nvme_keepalive: false,
            bounce_buffer,
        })
    }

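    /// Enables the controller: brings up the admin queue, waits for CSTS.RDY,
    /// identifies the controller, negotiates the I/O queue count and size,
    /// spawns the asynchronous event handler, and creates I/O queue 1.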
    async fn enable(&mut self, requested_io_queue_count: u16) -> anyhow::Result<()> {
        const ADMIN_QID: u16 = 0;

        let task = self.task.as_mut().unwrap();
        let worker = task.task_mut();

        let admin_len = std::cmp::min(MAX_SQ_ENTRIES, MAX_CQ_ENTRIES);
        let admin_sqes = admin_len;
        let admin_cqes = admin_len;

        let interrupt0 = worker
            .device
            .map_interrupt(0, 0)
            .context("failed to map interrupt 0")?;

        let admin = QueuePair::new(
            self.driver.clone(),
            worker.device.deref(),
            ADMIN_QID,
            admin_sqes,
            admin_cqes,
            interrupt0,
            worker.registers.clone(),
            self.bounce_buffer,
            AdminAerHandler::new(),
        )
        .context("failed to create admin queue pair")?;

        let admin_sqes = admin.sq_entries();
        let admin_cqes = admin.cq_entries();

        let admin = worker.admin.insert(admin);

        worker.registers.bar0.set_aqa(
            spec::Aqa::new()
                .with_acqs_z(admin_cqes - 1)
                .with_asqs_z(admin_sqes - 1),
        );
        worker.registers.bar0.set_asq(admin.sq_addr());
        worker.registers.bar0.set_acq(admin.cq_addr());

        let span = tracing::info_span!("nvme_ctrl_enable", pci_id = worker.device.id().to_owned());
        let ctrl_enable_span = span.enter();
        worker.registers.bar0.set_cc(
            spec::Cc::new()
                .with_iocqes(4)
                .with_iosqes(6)
                .with_en(true)
                .with_mps(0),
        );

        let mut backoff = Backoff::new(&self.driver);
        loop {
            let csts = worker.registers.bar0.csts();
            let csts_val: u32 = csts.into();
            if csts_val == !0 {
                anyhow::bail!("device is gone, csts: {:#x}", csts_val);
            }
            if csts.cfs() {
                let after_reset = if let Err(e) = worker.registers.bar0.reset(&self.driver).await {
                    e
                } else {
                    0
                };
                anyhow::bail!(
                    "device had fatal error, csts: {:#x}, after reset: {:#}",
                    csts_val,
                    after_reset
                );
            }
            if csts.rdy() {
                break;
            }
            backoff.back_off().await;
        }
        drop(ctrl_enable_span);

        let identify = self
            .identify
            .insert(Arc::new(spec::IdentifyController::new_zeroed()));

        admin
            .issuer()
            .issue_out(
                spec::Command {
                    cdw10: spec::Cdw10Identify::new()
                        .with_cns(spec::Cns::CONTROLLER.0)
                        .into(),
                    ..admin_cmd(spec::AdminOpcode::IDENTIFY)
                },
                Arc::get_mut(identify).unwrap().as_mut_bytes(),
            )
            .await
            .context("failed to identify controller")?;

        let max_interrupt_count = worker.device.max_interrupt_count();
        if max_interrupt_count == 0 {
            anyhow::bail!("bad device behavior: max_interrupt_count == 0");
        }

        let requested_io_queue_count = if max_interrupt_count < requested_io_queue_count as u32 {
            tracing::warn!(
                max_interrupt_count,
                requested_io_queue_count,
                pci_id = ?worker.device.id(),
                "queue count constrained by msi count"
            );
            max_interrupt_count as u16
        } else {
            requested_io_queue_count
        };

        let completion = admin
            .issuer()
            .issue_neither(spec::Command {
                cdw10: spec::Cdw10SetFeatures::new()
                    .with_fid(spec::Feature::NUMBER_OF_QUEUES.0)
                    .into(),
                cdw11: spec::Cdw11FeatureNumberOfQueues::new()
                    .with_ncq_z(requested_io_queue_count - 1)
                    .with_nsq_z(requested_io_queue_count - 1)
                    .into(),
                ..admin_cmd(spec::AdminOpcode::SET_FEATURES)
            })
            .await
            .context("failed to set number of queues")?;

        let dw0 = spec::Cdw11FeatureNumberOfQueues::from(completion.dw0);
        let sq_count = dw0.nsq_z() + 1;
        let cq_count = dw0.ncq_z() + 1;
        let allocated_io_queue_count = sq_count.min(cq_count);
        if allocated_io_queue_count < requested_io_queue_count {
            tracing::warn!(
                sq_count,
                cq_count,
                requested_io_queue_count,
                pci_id = ?worker.device.id(),
                "queue count constrained by hardware queue count"
            );
        }

        let max_io_queues = allocated_io_queue_count.min(requested_io_queue_count);

        let qsize = {
            if worker.registers.cap.mqes_z() < 1 {
                anyhow::bail!("bad device behavior: mqes cannot be 0");
            }

            let io_cqsize = (MAX_CQ_ENTRIES - 1).min(worker.registers.cap.mqes_z()) + 1;
            let io_sqsize = (MAX_SQ_ENTRIES - 1).min(worker.registers.cap.mqes_z()) + 1;

            tracing::debug!(
                io_cqsize,
                io_sqsize,
                hw_size = worker.registers.cap.mqes_z(),
                pci_id = ?worker.device.id(),
                "io queue sizes"
            );

            io_cqsize.min(io_sqsize)
        };

        let async_event_task = self.driver.spawn("nvme_async_event", {
            let admin = admin.issuer().clone();
            let rescan_notifiers = self.rescan_notifiers.clone();
            async move {
                if let Err(err) = handle_asynchronous_events(&admin, rescan_notifiers).await {
                    tracing::error!(
                        error = err.as_ref() as &dyn std::error::Error,
                        "asynchronous event failure, not processing any more"
                    );
                }
            }
        });

        let mut state = WorkerState {
            qsize,
            async_event_task,
            max_io_queues,
        };

        self.admin = Some(admin.issuer().clone());

        let issuer = worker
            .create_io_queue(&mut state, 0)
            .await
            .context("failed to create io queue 1")?;

        self.io_issuers.per_cpu[0].set(issuer).unwrap();
        task.insert(&self.driver, "nvme_worker", state);
        task.start();
        Ok(())
    }

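    /// Shuts the device down, unless `nvme_keepalive` was requested, in which
    /// case the controller is deliberately left running so that it can be
    /// reclaimed by [`NvmeDriver::restore`] after servicing.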
    pub async fn shutdown(mut self) {
        tracing::debug!(pci_id = ?self.device_id, "shutting down nvme driver");

        if self.nvme_keepalive {
            return;
        }
        self.reset().await;
        drop(self);
    }

    fn reset(&mut self) -> impl Send + Future<Output = ()> + use<D> {
        let driver = self.driver.clone();
        let id = self.device_id.clone();
        let mut task = std::mem::take(&mut self.task).unwrap();
        async move {
            task.stop().await;
            let (worker, state) = task.into_inner();
            if let Some(state) = state {
                state.async_event_task.cancel().await;
            }
            let _io_responses = join_all(worker.io.into_iter().map(|io| io.queue.shutdown())).await;
            let _admin_responses;
            if let Some(admin) = worker.admin {
                _admin_responses = admin.shutdown().await;
            }
            if let Err(e) = worker.registers.bar0.reset(&driver).await {
                tracing::info!(csts = e, "device reset failed");
            }

            let _vfio = ManuallyDrop::into_inner(worker.device);
            tracing::debug!(pci_id = ?id, "dropping vfio handle to device");
        }
    }

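    /// Gets the namespace with ID `nsid`, creating it on first use. Each
    /// namespace can only be handed out once at a time; a second request for
    /// the same `nsid` fails with [`NamespaceError::DuplicateRequest`].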
    pub async fn namespace(&mut self, nsid: u32) -> Result<Arc<Namespace>, NamespaceError> {
        if let Some(handle) = self.namespaces.get_mut(&nsid) {
            if !handle.in_use {
                handle.in_use = true;
                return Ok(handle.namespace.clone());
            }

            return Err(NamespaceError::DuplicateRequest { nsid });
        }

        let (send, recv) = mesh::channel::<()>();
        let namespace = Arc::new(
            Namespace::new(
                &self.driver,
                self.admin.as_ref().unwrap().clone(),
                recv,
                self.identify.clone().unwrap(),
                &self.io_issuers,
                nsid,
            )
            .await?,
        );
        self.namespaces.insert(
            nsid,
            NamespaceHandle {
                namespace: namespace.clone(),
                in_use: true,
            },
        );

        let mut notifiers = self.rescan_notifiers.write();
        notifiers.insert(nsid, send);
        Ok(namespace)
    }

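    /// Returns the number of CPUs that are using a fallback issuer, i.e. one
    /// whose I/O queue belongs to a different CPU.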
    pub fn fallback_cpu_count(&self) -> usize {
        self.io_issuers
            .per_cpu
            .iter()
            .enumerate()
            .filter(|&(cpu, c)| c.get().is_some_and(|c| c.cpu != cpu as u32))
            .count()
    }

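    /// Saves the driver state for servicing. This also sets `nvme_keepalive`
    /// so the controller is not reset when this object is dropped.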
    pub async fn save(&mut self) -> anyhow::Result<NvmeDriverSavedState> {
        if self.identify.is_none() {
            return Err(save_restore::Error::InvalidState.into());
        }
        let span = tracing::info_span!("nvme_driver_save", pci_id = self.device_id);
        self.nvme_keepalive = true;
        match self
            .io_issuers
            .send
            .call(NvmeWorkerRequest::Save, span.clone())
            .instrument(span.clone())
            .await?
        {
            Ok(s) => {
                let _e = span.entered();
                tracing::info!(
                    namespaces = self
                        .namespaces
                        .keys()
                        .map(|nsid| nsid.to_string())
                        .collect::<Vec<_>>()
                        .join(", "),
                    "saving namespaces",
                );
                let mut saved_namespaces = vec![];
                for (nsid, handle) in self.namespaces.iter() {
                    saved_namespaces.push(handle.namespace.save().with_context(|| {
                        format!(
                            "failed to save namespace nsid {} device {}",
                            nsid, self.device_id
                        )
                    })?);
                }
                Ok(NvmeDriverSavedState {
                    identify_ctrl: spec::IdentifyController::read_from_bytes(
                        self.identify.as_ref().unwrap().as_bytes(),
                    )
                    .unwrap(),
                    device_id: self.device_id.clone(),
                    namespaces: saved_namespaces,
                    worker_data: s,
                })
            }
            Err(e) => Err(e),
        }
    }

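    /// Resets the controller to clear out any state left behind by a previous
    /// driver instance, without constructing a new driver.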
    pub async fn clear_existing_state(
        driver_source: &VmTaskDriverSource,
        mut device: D,
    ) -> anyhow::Result<()> {
        let driver = driver_source.simple();
        let bar0_mapping = device
            .map_bar(0)
            .context("failed to map device registers to clear existing state")?;
        let bar0 = Bar0(bar0_mapping);
        bar0.reset(&driver)
            .await
            .map_err(|e| anyhow::anyhow!("failed to reset device during clear: {:#x}", e))?;
        Ok(())
    }

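    /// Restores the driver from saved state, reattaching to a still-running
    /// controller. The admin queue is restored eagerly; I/O queues are split
    /// between eager restoration and lazy prototypes below.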
    pub async fn restore(
        driver_source: &VmTaskDriverSource,
        cpu_count: u32,
        mut device: D,
        saved_state: &NvmeDriverSavedState,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let pci_id = device.id().to_owned();
        let driver = driver_source.simple();
        let bar0_mapping = device
            .map_bar(0)
            .context("failed to map device registers")?;
        let bar0 = Bar0(bar0_mapping);

        let csts = bar0.csts();
        if !csts.rdy() {
            tracing::error!(
                csts = u32::from(csts),
                ?pci_id,
                "device is not ready during restore"
            );
            anyhow::bail!(
                "device is not ready during restore, csts: {:#x}",
                u32::from(csts)
            );
        }

        let registers = Arc::new(DeviceRegisters::new(bar0));

        let (send, recv) = mesh::channel();
        let io_issuers = Arc::new(IoIssuers {
            per_cpu: (0..cpu_count).map(|_| OnceLock::new()).collect(),
            send,
        });

        let mut this = Self {
            device_id: device.id().to_owned(),
            task: Some(TaskControl::new(DriverWorkerTask {
                device: ManuallyDrop::new(device),
                driver: driver.clone(),
                registers: registers.clone(),
                admin: None,
                io: Vec::new(),
                proto_io: HashMap::new(),
                next_ioq_id: 1,
                io_issuers: io_issuers.clone(),
                recv,
                bounce_buffer,
            })),
            admin: None,
            identify: Some(Arc::new(
                spec::IdentifyController::read_from_bytes(saved_state.identify_ctrl.as_bytes())
                    .map_err(|_| RestoreError::InvalidData)?,
            )),
            driver: driver.clone(),
            io_issuers,
            rescan_notifiers: Default::default(),
            namespaces: Default::default(),
            nvme_keepalive: true,
            bounce_buffer,
        };

        let task = this.task.as_mut().unwrap();
        let worker = task.task_mut();

        let interrupt0 = worker
            .device
            .map_interrupt(0, 0)
            .with_context(|| format!("failed to map interrupt 0 for {}", pci_id))?;

        let dma_client = worker.device.dma_client();
        let restored_memory = dma_client
            .attach_pending_buffers()
            .with_context(|| format!("failed to restore allocations for {}", pci_id))?;

        let admin = saved_state
            .worker_data
            .admin
            .as_ref()
            .map(|a| {
                tracing::info!(
                    id = a.qid,
                    pending_commands_count = a.handler_data.pending_cmds.commands.len(),
                    ?pci_id,
                    "restoring admin queue",
                );
                let mem_block = restored_memory
                    .iter()
                    .find(|mem| mem.len() == a.mem_len && a.base_pfn == mem.pfns()[0])
                    .expect("unable to find restored mem block")
                    .to_owned();
                QueuePair::restore(
                    driver.clone(),
                    interrupt0,
                    registers.clone(),
                    mem_block,
                    a,
                    bounce_buffer,
                    AdminAerHandler::new(),
                )
                .expect("failed to restore admin queue pair")
            })
            .expect("attempted to restore admin queue from empty state");

        let admin = worker.admin.insert(admin);

        let async_event_task = this.driver.spawn("nvme_async_event", {
            let admin = admin.issuer().clone();
            let rescan_notifiers = this.rescan_notifiers.clone();
            async move {
                if let Err(err) = handle_asynchronous_events(&admin, rescan_notifiers)
                    .instrument(tracing::info_span!("async_event_handler"))
                    .await
                {
                    tracing::error!(
                        error = err.as_ref() as &dyn std::error::Error,
                        "asynchronous event failure, not processing any more"
                    );
                }
            }
        });

        let state = WorkerState {
            qsize: saved_state.worker_data.qsize,
            async_event_task,
            max_io_queues: saved_state.worker_data.max_io_queues,
        };

        this.admin = Some(admin.issuer().clone());

        tracing::info!(
            state = saved_state
                .worker_data
                .io
                .iter()
                .map(|io_state| format!(
                    "{{qid={}, pending_commands_count={}}}",
                    io_state.queue_data.qid,
                    io_state.queue_data.handler_data.pending_cmds.commands.len()
                ))
                .collect::<Vec<_>>()
                .join(", "),
            ?pci_id,
            "restoring io queues",
        );

        let mut max_seen_qid = 1;
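        // Restore eagerly only queue 1 and any queue that still has pending
        // commands. The remaining queues become prototypes in `proto_io` and
        // are rehydrated when their CPU first asks for an issuer.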
        worker.io = saved_state
            .worker_data
            .io
            .iter()
            .filter(|q| {
                q.queue_data.qid == 1 || !q.queue_data.handler_data.pending_cmds.commands.is_empty()
            })
            .flat_map(|q| -> Result<IoQueue<D>, anyhow::Error> {
                let qid = q.queue_data.qid;
                let cpu = q.cpu;
                tracing::info!(qid, cpu, ?pci_id, "restoring queue");
                max_seen_qid = max_seen_qid.max(qid);
                let interrupt = worker.device.map_interrupt(q.iv, q.cpu).with_context(|| {
                    format!(
                        "failed to map interrupt for {}, cpu {}, iv {}",
                        pci_id, q.cpu, q.iv
                    )
                })?;
                tracing::info!(qid, cpu, ?pci_id, "restoring queue: search for mem block");
                let mem_block = restored_memory
                    .iter()
                    .find(|mem| {
                        mem.len() == q.queue_data.mem_len && q.queue_data.base_pfn == mem.pfns()[0]
                    })
                    .expect("unable to find restored mem block")
                    .to_owned();
                tracing::info!(qid, cpu, ?pci_id, "restoring queue: restore IoQueue");
                let q = IoQueue::restore(
                    driver.clone(),
                    interrupt,
                    registers.clone(),
                    mem_block,
                    q,
                    bounce_buffer,
                )?;
                tracing::info!(qid, cpu, ?pci_id, "restoring queue: create issuer");
                let issuer = IoIssuer {
                    issuer: q.queue.issuer().clone(),
                    cpu: q.cpu,
                };
                this.io_issuers.per_cpu[q.cpu as usize].set(issuer).unwrap();
                Ok(q)
            })
            .collect();

        worker.proto_io = saved_state
            .worker_data
            .io
            .iter()
            .filter(|q| {
                q.queue_data.qid != 1 && q.queue_data.handler_data.pending_cmds.commands.is_empty()
            })
            .map(|q| {
                tracing::info!(
                    qid = q.queue_data.qid,
                    cpu = q.cpu,
                    ?pci_id,
                    "creating prototype io queue entry",
                );
                max_seen_qid = max_seen_qid.max(q.queue_data.qid);
                let mem_block = restored_memory
                    .iter()
                    .find(|mem| {
                        mem.len() == q.queue_data.mem_len && q.queue_data.base_pfn == mem.pfns()[0]
                    })
                    .expect("unable to find restored mem block")
                    .to_owned();
                (
                    q.cpu,
                    ProtoIoQueue {
                        save_state: q.clone(),
                        mem: mem_block,
                    },
                )
            })
            .collect();

        worker.next_ioq_id = max_seen_qid + 1;

        tracing::info!(
            namespaces = saved_state
                .namespaces
                .iter()
                .map(|ns| format!("{{nsid={}, size={}}}", ns.nsid, ns.identify_ns.nsze))
                .collect::<Vec<_>>()
                .join(", "),
            ?pci_id,
            "restoring namespaces",
        );

        for ns in &saved_state.namespaces {
            let (send, recv) = mesh::channel::<()>();
            this.namespaces.insert(
                ns.nsid,
                NamespaceHandle {
                    namespace: Arc::new(Namespace::restore(
                        &driver,
                        admin.issuer().clone(),
                        recv,
                        this.identify.clone().unwrap(),
                        &this.io_issuers,
                        ns,
                    )?),
                    in_use: false,
                },
            );
            this.rescan_notifiers.write().insert(ns.nsid, send);
        }

        task.insert(&this.driver, "nvme_worker", state);
        task.start();

        Ok(this)
    }

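    /// Changes the behavior on drop and shutdown: when `nvme_keepalive` is
    /// true, the controller is left running instead of being reset, so its
    /// state can be saved and restored across servicing.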
    pub fn update_servicing_flags(&mut self, nvme_keepalive: bool) {
        tracing::debug!(nvme_keepalive, "updating nvme servicing flags");
        self.nvme_keepalive = nvme_keepalive;
    }
}

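/// Services Asynchronous Event Notifications (AENs) from the controller,
/// notifying the registered rescan listeners when namespace attributes
/// change.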
async fn handle_asynchronous_events(
    admin: &Issuer,
    rescan_notifiers: Arc<RwLock<HashMap<u32, mesh::Sender<()>>>>,
) -> anyhow::Result<()> {
    tracing::info!("starting asynchronous event handler task");
    loop {
        let dw0 = admin
            .issue_get_aen()
            .await
            .context("asynchronous event request failed")?;

        match spec::AsynchronousEventType(dw0.event_type()) {
            spec::AsynchronousEventType::NOTICE => {
                tracing::info!("received an async notice event (aen) from the controller");

                let mut list = [0u32; 1024];
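                // Read the Changed Namespace List log page, which holds up to
                // 1024 changed NSIDs.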
                admin
                    .issue_out(
                        spec::Command {
                            cdw10: spec::Cdw10GetLogPage::new()
                                .with_lid(spec::LogPageIdentifier::CHANGED_NAMESPACE_LIST.0)
                                .with_numdl_z(1023)
                                .into(),
                            ..admin_cmd(spec::AdminOpcode::GET_LOG_PAGE)
                        },
                        list.as_mut_bytes(),
                    )
                    .await
                    .context("failed to query changed namespace list")?;

                let notifier_guard = rescan_notifiers.read();
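                // Per the NVMe spec, a first entry of 0xffffffff means more
                // than 1024 namespaces changed (the list overflowed), so
                // notify every registered listener.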
                if list[0] == 0xFFFFFFFF && list[1] == 0 {
                    tracing::info!("more than 1024 namespaces changed, notifying all listeners");
                    for notifiers in notifier_guard.values() {
                        notifiers.send(());
                    }
                } else {
                    for nsid in list.iter().filter(|&&nsid| nsid != 0) {
                        tracing::info!(nsid, "notifying listeners of changed namespace");
                        if let Some(notifier) = notifier_guard.get(nsid) {
                            notifier.send(());
                        }
                    }
                }
            }
            event_type => {
                tracing::info!(
                    ?event_type,
                    information = dw0.information(),
                    log_page_identifier = dw0.log_page_identifier(),
                    "unhandled asynchronous event"
                );
            }
        }
    }
}

impl<D: DeviceBacking> Drop for NvmeDriver<D> {
    fn drop(&mut self) {
        tracing::trace!(pci_id = ?self.device_id, ka = self.nvme_keepalive, task = self.task.is_some(), "dropping nvme driver");
        if self.task.is_some() {
            tracing::debug!(nvme_keepalive = self.nvme_keepalive, pci_id = ?self.device_id, "dropping nvme driver");
            if !self.nvme_keepalive {
                let reset = self.reset();
                self.driver.spawn("nvme_drop", reset).detach();
            }
        }
    }
}

impl IoIssuers {
    pub async fn get(&self, cpu: u32) -> Result<&Issuer, RequestError> {
        if let Some(v) = self.per_cpu[cpu as usize].get() {
            return Ok(&v.issuer);
        }

        self.send
            .call(NvmeWorkerRequest::CreateIssuer, cpu)
            .await
            .map_err(RequestError::Gone)?;

        Ok(self.per_cpu[cpu as usize]
            .get()
            .expect("issuer was set by rpc")
            .issuer
            .as_ref())
    }
}

impl<D: DeviceBacking> AsyncRun<WorkerState> for DriverWorkerTask<D> {
    async fn run(
        &mut self,
        stop: &mut task_control::StopTask<'_>,
        state: &mut WorkerState,
    ) -> Result<(), task_control::Cancelled> {
        let r = stop
            .until_stopped(async {
                loop {
                    match self.recv.next().await {
                        Some(NvmeWorkerRequest::CreateIssuer(rpc)) => {
                            rpc.handle(async |cpu| self.create_io_issuer(state, cpu).await)
                                .await
                        }
                        Some(NvmeWorkerRequest::Save(rpc)) => {
                            rpc.handle(async |span| {
                                let child_span = tracing::info_span!(
                                    parent: &span,
                                    "nvme_worker_save",
                                    pci_id = %self.device.id()
                                );
                                self.save(state).instrument(child_span).await
                            })
                            .await
                        }
                        None => break,
                    }
                }
            })
            .await;
        tracing::info!(pci_id = %self.device.id(), "nvme worker task exiting");
        r
    }
}

impl<D: DeviceBacking> DriverWorkerTask<D> {
    fn restore_io_issuer(&mut self, proto: ProtoIoQueue) -> anyhow::Result<()> {
        let pci_id = self.device.id().to_owned();
        let qid = proto.save_state.queue_data.qid;
        let cpu = proto.save_state.cpu;

        tracing::info!(
            qid,
            cpu,
            ?pci_id,
            "restoring queue from prototype: mapping interrupt"
        );
        let interrupt = self
            .device
            .map_interrupt(proto.save_state.iv, proto.save_state.cpu)
            .with_context(|| {
                format!(
                    "failed to map interrupt for {}, cpu {}, iv {}",
                    pci_id, proto.save_state.cpu, proto.save_state.iv
                )
            })?;

        tracing::info!(
            qid,
            cpu,
            ?pci_id,
            "restoring queue from prototype: restore IoQueue"
        );
        let queue = IoQueue::restore(
            self.driver.clone(),
            interrupt,
            self.registers.clone(),
            proto.mem,
            &proto.save_state,
            self.bounce_buffer,
        )
        .with_context(|| format!("failed to restore io queue for {}, cpu {}", pci_id, cpu))?;

        tracing::info!(
            qid,
            cpu,
            ?pci_id,
            "restoring queue from prototype: restore complete"
        );

        let issuer = IoIssuer {
            issuer: queue.queue.issuer().clone(),
            cpu,
        };

        self.io_issuers.per_cpu[cpu as usize]
            .set(issuer)
            .expect("issuer already set for this cpu");
        self.io.push(queue);

        Ok(())
    }

    async fn create_io_issuer(&mut self, state: &mut WorkerState, cpu: u32) {
        tracing::debug!(cpu, pci_id = ?self.device.id(), "issuer request");
        if self.io_issuers.per_cpu[cpu as usize].get().is_some() {
            return;
        }

        if let Some(proto) = self.proto_io.remove(&cpu) {
            match self.restore_io_issuer(proto) {
                Ok(()) => return,
                Err(err) => {
                    tracing::error!(
                        pci_id = ?self.device.id(),
                        cpu,
                        error = ?err,
                        "failed to restore io queue from prototype, creating new queue"
                    );
                }
            }
        }

        let issuer = match self
            .create_io_queue(state, cpu)
            .instrument(info_span!("create_nvme_io_queue", cpu))
            .await
        {
            Ok(issuer) => issuer,
            Err(err) => {
                let (fallback_cpu, fallback) = self.io_issuers.per_cpu[..cpu as usize]
                    .iter()
                    .enumerate()
                    .rev()
                    .find_map(|(i, issuer)| issuer.get().map(|issuer| (i, issuer)))
                    .expect("unable to find an io issuer for fallback");

                match err {
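                    // Running out of hardware queues is expected once the
                    // controller's limit is reached, so log that case at a
                    // lower severity than other failures.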
                    DeviceError::NoMoreIoQueues(_) => {
                        tracing::info!(
                            pci_id = ?self.device.id(),
                            cpu,
                            fallback_cpu,
                            error = &err as &dyn std::error::Error,
                            "failed to create io queue, falling back"
                        );
                    }
                    _ => {
                        tracing::error!(
                            pci_id = ?self.device.id(),
                            cpu,
                            fallback_cpu,
                            error = &err as &dyn std::error::Error,
                            "failed to create io queue, falling back"
                        );
                    }
                }

                fallback.clone()
            }
        };

        self.io_issuers.per_cpu[cpu as usize]
            .set(issuer)
            .ok()
            .unwrap();
    }

    async fn create_io_queue(
        &mut self,
        state: &mut WorkerState,
        cpu: u32,
    ) -> Result<IoIssuer, DeviceError> {
        if self.io.len() >= state.max_io_queues as usize {
            return Err(DeviceError::NoMoreIoQueues(state.max_io_queues));
        }

        let qid = self.next_ioq_id;
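        // Interrupt vectors are assigned as qid - 1, so I/O queue 1 shares
        // vector 0 with the admin queue.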
        let iv = qid - 1;
        self.next_ioq_id += 1;

        tracing::debug!(cpu, qid, iv, pci_id = ?self.device.id(), "creating io queue");

        let interrupt = self
            .device
            .map_interrupt(iv.into(), cpu)
            .map_err(DeviceError::InterruptMapFailure)?;

        let queue = QueuePair::new(
            self.driver.clone(),
            self.device.deref(),
            qid,
            state.qsize,
            state.qsize,
            interrupt,
            self.registers.clone(),
            self.bounce_buffer,
            NoOpAerHandler,
        )
        .map_err(|err| DeviceError::IoQueuePairCreationFailure(err, qid))?;

        assert_eq!(queue.sq_entries(), queue.cq_entries());
        state.qsize = queue.sq_entries();

        let io_sq_addr = queue.sq_addr();
        let io_cq_addr = queue.cq_addr();

        self.io.push(IoQueue { queue, iv, cpu });
        let io_queue = self.io.last_mut().unwrap();

        let admin = self.admin.as_ref().unwrap().issuer().as_ref();

        let mut created_completion_queue = false;
        let r = async {
            admin
                .issue_raw(spec::Command {
                    cdw10: spec::Cdw10CreateIoQueue::new()
                        .with_qid(qid)
                        .with_qsize_z(state.qsize - 1)
                        .into(),
                    cdw11: spec::Cdw11CreateIoCompletionQueue::new()
                        .with_ien(true)
                        .with_iv(iv)
                        .with_pc(true)
                        .into(),
                    dptr: [io_cq_addr, 0],
                    ..admin_cmd(spec::AdminOpcode::CREATE_IO_COMPLETION_QUEUE)
                })
                .await
                .map_err(|err| DeviceError::IoCompletionQueueFailure(err.into(), qid))?;

            created_completion_queue = true;

            admin
                .issue_raw(spec::Command {
                    cdw10: spec::Cdw10CreateIoQueue::new()
                        .with_qid(qid)
                        .with_qsize_z(state.qsize - 1)
                        .into(),
                    cdw11: spec::Cdw11CreateIoSubmissionQueue::new()
                        .with_cqid(qid)
                        .with_pc(true)
                        .into(),
                    dptr: [io_sq_addr, 0],
                    ..admin_cmd(spec::AdminOpcode::CREATE_IO_SUBMISSION_QUEUE)
                })
                .await
                .map_err(|err| DeviceError::IoSubmissionQueueFailure(err.into(), qid))?;

            Ok(())
        };

        if let Err(err) = r.await {
            if created_completion_queue {
                if let Err(err) = admin
                    .issue_raw(spec::Command {
                        cdw10: spec::Cdw10DeleteIoQueue::new().with_qid(qid).into(),
                        ..admin_cmd(spec::AdminOpcode::DELETE_IO_COMPLETION_QUEUE)
                    })
                    .await
                {
                    tracing::error!(
                        pci_id = ?self.device.id(),
                        error = &err as &dyn std::error::Error,
                        "failed to delete completion queue in teardown path"
                    );
                }
            }
            let io = self.io.pop().unwrap();
            io.queue.shutdown().await;
            return Err(DeviceError::Other(err.into()));
        }

        Ok(IoIssuer {
            issuer: io_queue.queue.issuer().clone(),
            cpu,
        })
    }

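    /// Saves the worker's state during servicing, draining both live and
    /// prototype I/O queues into their saved representations.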
    pub async fn save(
        &mut self,
        worker_state: &mut WorkerState,
    ) -> anyhow::Result<NvmeDriverWorkerSavedState> {
        let admin = match self.admin.as_ref() {
            Some(a) => Some(a.save().await?),
            None => None,
        };

        let io: Vec<IoQueueSavedState> = join_all(self.io.drain(..).map(async |q| q.save().await))
            .await
            .into_iter()
            .flatten()
            .chain(
                self.proto_io
                    .drain()
                    .map(|(_cpu, proto_queue)| proto_queue.save_state),
            )
            .collect();

        match admin {
            None => tracing::warn!(pci_id = ?self.device.id(), "no admin queue saved"),
            Some(ref admin_state) => tracing::info!(
                pci_id = ?self.device.id(),
                id = admin_state.qid,
                pending_commands_count = admin_state.handler_data.pending_cmds.commands.len(),
                "saved admin queue",
            ),
        }

        match io.is_empty() {
            true => tracing::warn!(pci_id = ?self.device.id(), "no io queues saved"),
            false => tracing::info!(
                pci_id = ?self.device.id(),
                state = io
                    .iter()
                    .map(|io_state| format!(
                        "{{qid={}, pending_commands_count={}}}",
                        io_state.queue_data.qid,
                        io_state.queue_data.handler_data.pending_cmds.commands.len()
                    ))
                    .collect::<Vec<_>>()
                    .join(", "),
                "saved io queues",
            ),
        }

        Ok(NvmeDriverWorkerSavedState {
            admin,
            io,
            qsize: worker_state.qsize,
            max_io_queues: worker_state.max_io_queues,
        })
    }
}

impl<D: DeviceBacking> InspectTask<WorkerState> for DriverWorkerTask<D> {
    fn inspect(&self, req: inspect::Request<'_>, state: Option<&WorkerState>) {
        req.respond().merge(self).merge(state);
    }
}

#[expect(missing_docs)]
pub mod save_restore {
    use super::*;

    #[derive(Debug, Error)]
    pub enum Error {
        #[error("invalid object state")]
        InvalidState,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct NvmeDriverSavedState {
        #[mesh(1, encoding = "mesh::payload::encoding::ZeroCopyEncoding")]
        pub identify_ctrl: spec::IdentifyController,
        #[mesh(2)]
        pub device_id: String,
        #[mesh(3)]
        pub namespaces: Vec<SavedNamespaceData>,
        #[mesh(4)]
        pub worker_data: NvmeDriverWorkerSavedState,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct NvmeDriverWorkerSavedState {
        #[mesh(1)]
        pub admin: Option<QueuePairSavedState>,
        #[mesh(2)]
        pub io: Vec<IoQueueSavedState>,
        #[mesh(3)]
        pub qsize: u16,
        #[mesh(4)]
        pub max_io_queues: u16,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct QueuePairSavedState {
        #[mesh(1)]
        pub mem_len: usize,
        #[mesh(2)]
        pub base_pfn: u64,
        #[mesh(3)]
        pub qid: u16,
        #[mesh(4)]
        pub sq_entries: u16,
        #[mesh(5)]
        pub cq_entries: u16,
        #[mesh(6)]
        pub handler_data: QueueHandlerSavedState,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct IoQueueSavedState {
        #[mesh(1)]
        pub cpu: u32,
        #[mesh(2)]
        pub iv: u32,
        #[mesh(3)]
        pub queue_data: QueuePairSavedState,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct QueueHandlerSavedState {
        #[mesh(1)]
        pub sq_state: SubmissionQueueSavedState,
        #[mesh(2)]
        pub cq_state: CompletionQueueSavedState,
        #[mesh(3)]
        pub pending_cmds: PendingCommandsSavedState,
        #[mesh(4)]
        pub aer_handler: Option<AerHandlerSavedState>,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct SubmissionQueueSavedState {
        #[mesh(1)]
        pub sqid: u16,
        #[mesh(2)]
        pub head: u32,
        #[mesh(3)]
        pub tail: u32,
        #[mesh(4)]
        pub committed_tail: u32,
        #[mesh(5)]
        pub len: u32,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct CompletionQueueSavedState {
        #[mesh(1)]
        pub cqid: u16,
        #[mesh(2)]
        pub head: u32,
        #[mesh(3)]
        pub committed_head: u32,
        #[mesh(4)]
        pub len: u32,
        #[mesh(5)]
        pub phase: bool,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct PendingCommandSavedState {
        #[mesh(1, encoding = "mesh::payload::encoding::ZeroCopyEncoding")]
        pub command: spec::Command,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct PendingCommandsSavedState {
        #[mesh(1)]
        pub commands: Vec<PendingCommandSavedState>,
        #[mesh(2)]
        pub next_cid_high_bits: u16,
        #[mesh(3)]
        pub cid_key_bits: u32,
    }

    #[derive(Protobuf, Clone, Debug)]
    #[mesh(package = "nvme_driver")]
    pub struct SavedNamespaceData {
        #[mesh(1)]
        pub nsid: u32,
        #[mesh(2, encoding = "mesh::payload::encoding::ZeroCopyEncoding")]
        pub identify_ns: nvme_spec::nvm::IdentifyNamespace,
    }

    #[derive(Clone, Debug, Protobuf)]
    #[mesh(package = "nvme_driver")]
    pub struct AerHandlerSavedState {
        #[mesh(1)]
        pub last_aen: Option<u32>,
        #[mesh(2)]
        pub await_aen_cid: Option<u16>,
    }
}