nvme_driver/
driver.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implementation of the device driver core.
5
6use super::spec;
7use crate::NVME_PAGE_SHIFT;
8use crate::NamespaceError;
9use crate::NamespaceHandle;
10use crate::RequestError;
11use crate::driver::save_restore::IoQueueSavedState;
12use crate::namespace::Namespace;
13use crate::queue_pair::AdminAerHandler;
14use crate::queue_pair::Issuer;
15use crate::queue_pair::MAX_CQ_ENTRIES;
16use crate::queue_pair::MAX_SQ_ENTRIES;
17use crate::queue_pair::NoOpAerHandler;
18use crate::queue_pair::QueuePair;
19use crate::queue_pair::admin_cmd;
20use crate::registers::Bar0;
21use crate::registers::DeviceRegisters;
22use crate::save_restore::NvmeDriverSavedState;
23use anyhow::Context as _;
24use futures::StreamExt;
25use futures::future::join_all;
26use inspect::Inspect;
27use mesh::payload::Protobuf;
28use mesh::rpc::Rpc;
29use mesh::rpc::RpcSend;
30use pal_async::task::Spawn;
31use pal_async::task::Task;
32use parking_lot::RwLock;
33use save_restore::NvmeDriverWorkerSavedState;
34use std::collections::HashMap;
35use std::mem::ManuallyDrop;
36use std::ops::Deref;
37use std::sync::Arc;
38use std::sync::OnceLock;
39use std::sync::Weak;
40use task_control::AsyncRun;
41use task_control::InspectTask;
42use task_control::TaskControl;
43use thiserror::Error;
44use tracing::Instrument;
45use tracing::Span;
46use tracing::info_span;
47use user_driver::DeviceBacking;
48use user_driver::backoff::Backoff;
49use user_driver::interrupt::DeviceInterrupt;
50use user_driver::memory::MemoryBlock;
51use vmcore::vm_task::VmTaskDriver;
52use vmcore::vm_task::VmTaskDriverSource;
53use zerocopy::FromBytes;
54use zerocopy::FromZeros;
55use zerocopy::IntoBytes;
56
/// An NVMe driver.
///
/// Note that if this is dropped, the process will abort. Call
/// [`NvmeDriver::shutdown`] to drop this.
///
/// Further, note that this is an internal interface to be used
/// only by `NvmeDisk`! Remove any sanitization in `fuzz_nvm_driver.rs`
/// if this struct is used anywhere else.
#[derive(Inspect)]
pub struct NvmeDriver<D: DeviceBacking> {
    /// The worker task and its state. Taken (via `std::mem::take`) in
    /// `reset`, so it is `None` once teardown has begun.
    #[inspect(flatten)]
    task: Option<TaskControl<DriverWorkerTask<D>, WorkerState>>,
    /// PCI device identifier, as reported by `DeviceBacking::id()`.
    device_id: String,
    /// Cached Identify Controller data; populated in `enable`/`restore`.
    identify: Option<Arc<spec::IdentifyController>>,
    /// Task driver used to spawn the driver's async tasks.
    #[inspect(skip)]
    driver: VmTaskDriver,
    /// Issuer for admin commands; set once the admin queue is up.
    #[inspect(skip)]
    admin: Option<Arc<Issuer>>,
    /// Per-CPU IO issuer table shared with namespaces and the worker.
    #[inspect(skip)]
    io_issuers: Arc<IoIssuers>,
    /// Per-nsid senders used to notify namespaces of rescans; consumed by
    /// `handle_asynchronous_events`.
    #[inspect(skip)]
    rescan_notifiers: Arc<RwLock<HashMap<u32, mesh::Sender<()>>>>,
    /// NVMe namespaces associated with this driver. Mapping nsid to NamespaceHandle.
    #[inspect(skip)]
    namespaces: HashMap<u32, WeakOrStrong<Namespace>>,
    /// Keeps the controller connected (CC.EN==1) while servicing.
    nvme_keepalive: bool,
    /// Passed through to queue pairs; presumably enables bounce buffering of
    /// IO payloads — confirm in `QueuePair`.
    bounce_buffer: bool,
}
86
/// A container that can hold either a weak or strong reference to a value.
///
/// During normal operation, the driver ONLY stores weak references. After restore
/// strong references are temporarily held until the StorageController retrieves them.
/// Once retrieved, the strong reference is downgraded to a weak one, resuming
/// normal behavior.
enum WeakOrStrong<T> {
    /// Normal operation: this entry does not keep the value alive.
    Weak(Weak<T>),
    /// Post-restore only: keeps the value alive until first retrieval.
    Strong(Arc<T>),
}
97
98impl<T> WeakOrStrong<T> {
99    /// Returns a strong reference to the underlying value when possible.
100    /// Implicitly downgrades Strong to Weak when this function is invoked.
101    pub fn get_arc(&mut self) -> Option<Arc<T>> {
102        match self {
103            WeakOrStrong::Strong(arc) => {
104                let strong = arc.clone();
105                *self = WeakOrStrong::Weak(Arc::downgrade(arc));
106                Some(strong)
107            }
108            WeakOrStrong::Weak(weak) => weak.upgrade(),
109        }
110    }
111
112    pub fn is_weak(&self) -> bool {
113        matches!(self, WeakOrStrong::Weak(_))
114    }
115}
116
/// State owned by the driver's worker task, which services
/// [`NvmeWorkerRequest`]s (issuer creation, save) on behalf of [`NvmeDriver`].
#[derive(Inspect)]
struct DriverWorkerTask<D: DeviceBacking> {
    /// The VFIO device backing this driver. For KeepAlive cases, the VFIO handle
    /// is never dropped, otherwise there is a chance that VFIO will reset the
    /// device. We don't want that.
    ///
    /// Dropped in `NvmeDriver::reset`.
    device: ManuallyDrop<D>,
    /// Task driver used to spawn queue-pair tasks.
    #[inspect(skip)]
    driver: VmTaskDriver,
    /// Mapped BAR0 registers plus cached capability values.
    registers: Arc<DeviceRegisters<D>>,
    /// The admin queue pair, once created by `enable`/`restore`.
    admin: Option<QueuePair<AdminAerHandler, D>>,
    /// Live IO queue pairs, in creation order.
    #[inspect(iter_by_index)]
    io: Vec<IoQueue<D>>,
    /// Prototype IO queues for restoring from saved state. These are queues
    /// that were created on the device at some point, but had no pending
    /// IOs at save/restore time. These will be promoted to full IO queues
    /// on demand.
    ///
    /// cpu => queue info
    #[inspect(skip)]
    proto_io: HashMap<u32, ProtoIoQueue>,
    /// The next qid to use when creating an IO queue for a new issuer.
    next_ioq_id: u16,
    /// Shared per-CPU issuer table; slots are filled as queues come up.
    io_issuers: Arc<IoIssuers>,
    /// Receives [`NvmeWorkerRequest`]s from the issuer side.
    #[inspect(skip)]
    recv: mesh::Receiver<NvmeWorkerRequest>,
    /// Forwarded to each queue pair; see the corresponding driver-level flag.
    bounce_buffer: bool,
}
146
/// Runtime state computed during `enable`/`restore` and handed to the
/// worker task when it is started.
#[derive(Inspect)]
struct WorkerState {
    /// Maximum number of IO queues the controller agreed to provide.
    max_io_queues: u16,
    /// Entry count used for both the SQ and CQ of each IO queue (some
    /// hardware requires them to match).
    qsize: u16,
    /// Task servicing asynchronous event notifications from the controller.
    #[inspect(skip)]
    async_event_task: Task<()>,
}
154
/// An error restoring from saved state.
#[derive(Debug, Error)]
pub enum RestoreError {
    /// The saved state could not be parsed into the expected structures.
    #[error("invalid data")]
    InvalidData,
}
161
/// Errors produced while creating or operating device IO queues.
#[derive(Debug, Error)]
pub enum DeviceError {
    /// All queue IDs up to the negotiated maximum are in use.
    #[error("no more io queues available, reached maximum {0}")]
    NoMoreIoQueues(u16),
    /// Mapping an MSI interrupt for a queue failed.
    #[error("failed to map interrupt")]
    InterruptMapFailure(#[source] anyhow::Error),
    /// Creating the queue pair (qid in the payload) failed.
    #[error("failed to create io queue pair {1}")]
    IoQueuePairCreationFailure(#[source] anyhow::Error, u16),
    /// The device rejected the Create IO Completion Queue command.
    #[error("failed to create io completion queue {1}")]
    IoCompletionQueueFailure(#[source] anyhow::Error, u16),
    /// The device rejected the Create IO Submission Queue command.
    #[error("failed to create io submission queue {1}")]
    IoSubmissionQueueFailure(#[source] anyhow::Error, u16),
    // Other device related errors
    #[error(transparent)]
    Other(anyhow::Error),
}
178
/// Saved state plus backing DMA memory for an IO queue that had no pending
/// IOs at save time; promoted to a full [`IoQueue`] on demand.
#[derive(Debug, Clone)]
struct ProtoIoQueue {
    /// The queue's saved state (cpu, interrupt vector, queue data).
    save_state: IoQueueSavedState,
    /// The memory block that backed the queue before servicing.
    mem: MemoryBlock,
}
184
/// An IO queue pair together with the interrupt vector and CPU it serves.
#[derive(Inspect)]
struct IoQueue<D: DeviceBacking> {
    queue: QueuePair<NoOpAerHandler, D>,
    /// Interrupt vector index used when mapping this queue's interrupt.
    iv: u16,
    /// CPU this queue was created for.
    cpu: u32,
}
191
192impl<D: DeviceBacking> IoQueue<D> {
193    pub async fn save(&self) -> anyhow::Result<IoQueueSavedState> {
194        Ok(IoQueueSavedState {
195            cpu: self.cpu,
196            iv: self.iv as u32,
197            queue_data: self.queue.save().await?,
198        })
199    }
200
201    pub fn restore(
202        spawner: VmTaskDriver,
203        interrupt: DeviceInterrupt,
204        registers: Arc<DeviceRegisters<D>>,
205        mem_block: MemoryBlock,
206        device_id: &str,
207        saved_state: &IoQueueSavedState,
208        bounce_buffer: bool,
209    ) -> anyhow::Result<Self> {
210        let IoQueueSavedState {
211            cpu,
212            iv,
213            queue_data,
214        } = saved_state;
215        let queue = QueuePair::restore(
216            spawner,
217            interrupt,
218            registers.clone(),
219            mem_block,
220            device_id,
221            queue_data,
222            bounce_buffer,
223            NoOpAerHandler,
224        )?;
225
226        Ok(Self {
227            queue,
228            iv: *iv as u16,
229            cpu: *cpu,
230        })
231    }
232}
233
/// Table of per-CPU IO issuers plus the channel used to request new ones
/// from the worker task.
#[derive(Debug, Inspect)]
pub(crate) struct IoIssuers {
    /// One lazily-initialized issuer slot per CPU, indexed by CPU number.
    #[inspect(iter_by_index)]
    per_cpu: Vec<OnceLock<IoIssuer>>,
    /// Sends [`NvmeWorkerRequest`]s to the driver worker task.
    #[inspect(skip)]
    send: mesh::Sender<NvmeWorkerRequest>,
}
241
/// An issuer paired with the CPU that owns its backing queue.
#[derive(Debug, Clone, Inspect)]
struct IoIssuer {
    #[inspect(flatten)]
    issuer: Arc<Issuer>,
    /// CPU owning the backing queue; may differ from the per-CPU slot index
    /// when the slot is in fallback mode (see `fallback_cpu_count`).
    cpu: u32,
}
248
/// Requests sent from the issuer side to the driver worker task.
#[derive(Debug)]
enum NvmeWorkerRequest {
    /// Create an IO issuer; the `u32` payload is presumably the target CPU —
    /// confirm against the worker loop.
    CreateIssuer(Rpc<u32, ()>),
    /// Save worker state.
    Save(Rpc<Span, anyhow::Result<NvmeDriverWorkerSavedState>>),
}
255
256impl<D: DeviceBacking> NvmeDriver<D> {
257    /// Initializes the driver.
258    pub async fn new(
259        driver_source: &VmTaskDriverSource,
260        cpu_count: u32,
261        device: D,
262        bounce_buffer: bool,
263    ) -> anyhow::Result<Self> {
264        let pci_id = device.id().to_owned();
265        let mut this = Self::new_disabled(driver_source, cpu_count, device, bounce_buffer)
266            .instrument(tracing::info_span!("nvme_new_disabled", pci_id))
267            .await?;
268        match this
269            .enable(cpu_count as u16)
270            .instrument(tracing::info_span!("nvme_enable", pci_id))
271            .await
272        {
273            Ok(()) => Ok(this),
274            Err(err) => {
275                tracing::error!(
276                    error = err.as_ref() as &dyn std::error::Error,
277                    "device initialization failed, shutting down"
278                );
279                this.shutdown().await;
280                Err(err)
281            }
282        }
283    }
284
    /// Initializes but does not enable the device. DMA memory
    /// is preallocated from backing device.
    async fn new_disabled(
        driver_source: &VmTaskDriverSource,
        cpu_count: u32,
        mut device: D,
        bounce_buffer: bool,
    ) -> anyhow::Result<Self> {
        let driver = driver_source.simple();
        let bar0 = Bar0(
            device
                .map_bar(0)
                .context("failed to map device registers")?,
        );

        // If the controller is already enabled (or still reports ready),
        // reset it first so initialization starts from a clean state.
        let cc = bar0.cc();
        if cc.en() || bar0.csts().rdy() {
            if let Err(e) = bar0
                .reset(&driver)
                .instrument(tracing::info_span!(
                    "nvme_already_enabled",
                    pci_id = device.id().to_owned()
                ))
                .await
            {
                anyhow::bail!("device is gone, csts: {:#x}", e);
            }
        }

        let registers = Arc::new(DeviceRegisters::new(bar0));
        let cap = registers.cap;

        // The driver only supports controllers whose minimum page size is
        // the smallest NVMe page size (CAP.MPSMIN == 0).
        if cap.mpsmin() != 0 {
            anyhow::bail!(
                "unsupported minimum page size: {}",
                cap.mpsmin() + NVME_PAGE_SHIFT
            );
        }

        // Channel by which issuers send requests to the worker task.
        let (send, recv) = mesh::channel();
        let io_issuers = Arc::new(IoIssuers {
            per_cpu: (0..cpu_count).map(|_| OnceLock::new()).collect(),
            send,
        });

        Ok(Self {
            device_id: device.id().to_owned(),
            task: Some(TaskControl::new(DriverWorkerTask {
                device: ManuallyDrop::new(device),
                driver: driver.clone(),
                registers,
                admin: None,
                io: Vec::new(),
                proto_io: HashMap::new(),
                // IO queue ids start at 1; qid 0 is the admin queue.
                next_ioq_id: 1,
                io_issuers: io_issuers.clone(),
                recv,
                bounce_buffer,
            })),
            admin: None,
            identify: None,
            driver,
            io_issuers,
            rescan_notifiers: Default::default(),
            namespaces: Default::default(),
            nvme_keepalive: false,
            bounce_buffer,
        })
    }
354
    /// Enables the device, aliasing the admin queue memory and adding IO queues.
    async fn enable(&mut self, requested_io_queue_count: u16) -> anyhow::Result<()> {
        // The admin queue is always qid 0.
        const ADMIN_QID: u16 = 0;

        let task = &mut self.task.as_mut().unwrap();
        let worker = task.task_mut();

        // Request the admin queue pair be the same size to avoid potential
        // device bugs where differing sizes might be a less common scenario
        //
        // Namely: using differing sizes revealed a bug in the initial NvmeDirectV2 implementation
        let admin_len = std::cmp::min(MAX_SQ_ENTRIES, MAX_CQ_ENTRIES);
        let admin_sqes = admin_len;
        let admin_cqes = admin_len;

        // Interrupt 0 is shared between the admin queue and IO queue 1.
        let interrupt0 = worker
            .device
            .map_interrupt(0, 0)
            .context("failed to map interrupt 0")?;

        // Start the admin queue pair.
        let admin = QueuePair::new(
            self.driver.clone(),
            worker.device.deref(),
            ADMIN_QID,
            admin_sqes,
            admin_cqes,
            interrupt0,
            worker.registers.clone(),
            self.bounce_buffer,
            AdminAerHandler::new(),
        )
        .context("failed to create admin queue pair")?;

        // The queue pair may pick different entry counts; read back the
        // actual sizes before programming AQA.
        let admin_sqes = admin.sq_entries();
        let admin_cqes = admin.cq_entries();

        let admin = worker.admin.insert(admin);

        // Register the admin queue with the controller. The `_z` setters are
        // zero-based, hence the `- 1`.
        worker.registers.bar0.set_aqa(
            spec::Aqa::new()
                .with_acqs_z(admin_cqes - 1)
                .with_asqs_z(admin_sqes - 1),
        );
        worker.registers.bar0.set_asq(admin.sq_addr());
        worker.registers.bar0.set_acq(admin.cq_addr());

        // Enable the controller. IOCQES/IOSQES are log2 entry sizes per the
        // NVMe spec (2^4 = 16-byte CQEs, 2^6 = 64-byte SQEs).
        let span = tracing::info_span!("nvme_ctrl_enable", pci_id = worker.device.id().to_owned());
        let ctrl_enable_span = span.enter();
        worker.registers.bar0.set_cc(
            spec::Cc::new()
                .with_iocqes(4)
                .with_iosqes(6)
                .with_en(true)
                .with_mps(0),
        );

        // Wait for the controller to be ready.
        let mut backoff = Backoff::new(&self.driver);
        loop {
            let csts = worker.registers.bar0.csts();
            // An all-ones register read indicates the device is gone
            // (e.g. surprise removal).
            let csts_val: u32 = csts.into();
            if csts_val == !0 {
                anyhow::bail!("device is gone, csts: {:#x}", csts_val);
            }
            if csts.cfs() {
                // Attempt to leave the device in reset state CC.EN 1 -> 0.
                let after_reset = if let Err(e) = worker.registers.bar0.reset(&self.driver).await {
                    e
                } else {
                    0
                };
                anyhow::bail!(
                    "device had fatal error, csts: {:#x}, after reset: {:#}",
                    csts_val,
                    after_reset
                );
            }
            if csts.rdy() {
                break;
            }
            backoff.back_off().await;
        }
        drop(ctrl_enable_span);

        // Get the controller identify structure.
        let identify = self
            .identify
            .insert(Arc::new(spec::IdentifyController::new_zeroed()));

        admin
            .issuer()
            .issue_out(
                spec::Command {
                    cdw10: spec::Cdw10Identify::new()
                        .with_cns(spec::Cns::CONTROLLER.0)
                        .into(),
                    ..admin_cmd(spec::AdminOpcode::IDENTIFY)
                },
                Arc::get_mut(identify).unwrap().as_mut_bytes(),
            )
            .await
            .context("failed to identify controller")?;

        // Configure the number of IO queues.
        //
        // Note that interrupt zero is shared between IO queue 1 and the admin queue.
        let max_interrupt_count = worker.device.max_interrupt_count();
        if max_interrupt_count == 0 {
            anyhow::bail!("bad device behavior: max_interrupt_count == 0");
        }

        // Each IO queue needs its own interrupt, so clamp to the MSI budget.
        let requested_io_queue_count = if max_interrupt_count < requested_io_queue_count as u32 {
            tracing::warn!(
                max_interrupt_count,
                requested_io_queue_count,
                pci_id = ?worker.device.id(),
                "queue count constrained by msi count"
            );
            max_interrupt_count as u16
        } else {
            requested_io_queue_count
        };

        // Negotiate queue counts; the `_z` fields are zero-based.
        let completion = admin
            .issuer()
            .issue_neither(spec::Command {
                cdw10: spec::Cdw10SetFeatures::new()
                    .with_fid(spec::Feature::NUMBER_OF_QUEUES.0)
                    .into(),
                cdw11: spec::Cdw11FeatureNumberOfQueues::new()
                    .with_ncq_z(requested_io_queue_count - 1)
                    .with_nsq_z(requested_io_queue_count - 1)
                    .into(),
                ..admin_cmd(spec::AdminOpcode::SET_FEATURES)
            })
            .await
            .context("failed to set number of queues")?;

        // See how many queues are actually available.
        let dw0 = spec::Cdw11FeatureNumberOfQueues::from(completion.dw0);
        let sq_count = dw0.nsq_z() + 1;
        let cq_count = dw0.ncq_z() + 1;
        let allocated_io_queue_count = sq_count.min(cq_count);
        if allocated_io_queue_count < requested_io_queue_count {
            tracing::warn!(
                sq_count,
                cq_count,
                requested_io_queue_count,
                pci_id = ?worker.device.id(),
                "queue count constrained by hardware queue count"
            );
        }

        let max_io_queues = allocated_io_queue_count.min(requested_io_queue_count);

        // Clamp the per-queue entry count to the hardware limit (CAP.MQES).
        let qsize = {
            if worker.registers.cap.mqes_z() < 1 {
                anyhow::bail!("bad device behavior. mqes cannot be 0");
            }

            let io_cqsize = (MAX_CQ_ENTRIES - 1).min(worker.registers.cap.mqes_z()) + 1;
            let io_sqsize = (MAX_SQ_ENTRIES - 1).min(worker.registers.cap.mqes_z()) + 1;

            tracing::debug!(
                io_cqsize,
                io_sqsize,
                hw_size = worker.registers.cap.mqes_z(),
                pci_id = ?worker.device.id(),
                "io queue sizes"
            );

            // Some hardware (such as ASAP) require that the sq and cq have the same size.
            io_cqsize.min(io_sqsize)
        };

        // Spawn a task to handle asynchronous events.
        let async_event_task = self.driver.spawn("nvme_async_event", {
            let admin = admin.issuer().clone();
            let rescan_notifiers = self.rescan_notifiers.clone();
            async move {
                if let Err(err) = handle_asynchronous_events(&admin, rescan_notifiers).await {
                    tracing::error!(
                        error = err.as_ref() as &dyn std::error::Error,
                        "asynchronous event failure, not processing any more"
                    );
                }
            }
        });

        let mut state = WorkerState {
            qsize,
            async_event_task,
            max_io_queues,
        };

        self.admin = Some(admin.issuer().clone());

        // Pre-create the IO queue 1 for CPU 0. The other queues will be created
        // lazily. Numbering for I/O queues starts with 1 (0 is Admin).
        let issuer = worker
            .create_io_queue(&mut state, 0)
            .await
            .context("failed to create io queue 1")?;

        self.io_issuers.per_cpu[0].set(issuer).unwrap();
        task.insert(&self.driver, "nvme_worker", state);
        task.start();
        Ok(())
    }
567
568    /// Shuts the device down.
569    pub async fn shutdown(mut self) {
570        tracing::debug!(pci_id = ?self.device_id, "shutting down nvme driver");
571
572        // If nvme_keepalive was requested, return early.
573        // The memory is still aliased as we don't flush pending IOs.
574        if self.nvme_keepalive {
575            return;
576        }
577        self.reset().await;
578        drop(self);
579    }
580
    /// Tears the driver down: stops the worker task, cancels the async event
    /// handler, shuts down all queues, resets the controller, and finally
    /// drops the VFIO device handle.
    ///
    /// Returns a future so the teardown can run without borrowing `self`;
    /// the worker task is taken out of `self` up front.
    fn reset(&mut self) -> impl Send + Future<Output = ()> + use<D> {
        let driver = self.driver.clone();
        let id = self.device_id.clone();
        // Take the task so `self` no longer owns any running state.
        let mut task = std::mem::take(&mut self.task).unwrap();
        async move {
            task.stop().await;
            let (worker, state) = task.into_inner();
            if let Some(state) = state {
                state.async_event_task.cancel().await;
            }
            // Hold onto responses until the reset completes so that waiting IOs do
            // not think the memory is unaliased by the device.
            let _io_responses = join_all(worker.io.into_iter().map(|io| io.queue.shutdown())).await;
            let _admin_responses;
            if let Some(admin) = worker.admin {
                _admin_responses = admin.shutdown().await;
            }
            if let Err(e) = worker.registers.bar0.reset(&driver).await {
                tracing::info!(csts = e, "device reset failed");
            }

            // Only now is it safe to release the VFIO handle (releasing it
            // earlier could let VFIO reset the device underneath us).
            let _vfio = ManuallyDrop::into_inner(worker.device);
            tracing::debug!(pci_id = ?id, "dropping vfio handle to device");
        }
    }
606
    /// Gets the namespace with namespace ID `nsid`.
    ///
    /// Returns an existing namespace only in the post-restore case; during
    /// normal operation an upgradeable weak entry for an active namespace is
    /// reported as [`NamespaceError::Duplicate`].
    pub async fn namespace(&mut self, nsid: u32) -> Result<NamespaceHandle, NamespaceError> {
        if let Some(namespace) = self.namespaces.get_mut(&nsid) {
            // After restore we will have a strong ref -> downgrade and return.
            // If we have a weak ref, make sure it is not upgradeable (that means we have a duplicate somewhere).
            let is_weak = namespace.is_weak(); // This value will change after invoking get_arc().
            let namespace = namespace.get_arc();
            if let Some(namespace) = namespace {
                if is_weak && namespace.check_active().is_ok() {
                    return Err(NamespaceError::Duplicate(nsid));
                }

                tracing::debug!(
                    "reusing existing namespace nsid={}. This should only happen after restore.",
                    nsid
                );
                return Ok(NamespaceHandle::new(namespace));
            }
        }

        // `recv` lets the namespace observe rescan notifications sent through
        // the matching sender stored below.
        let (send, recv) = mesh::channel::<()>();
        let namespace = Arc::new(
            Namespace::new(
                &self.driver,
                self.admin.as_ref().unwrap().clone(),
                recv,
                self.identify.clone().unwrap(),
                &self.io_issuers,
                nsid,
            )
            .await?,
        );
        // Track only a weak reference; the caller's handle keeps it alive.
        self.namespaces
            .insert(nsid, WeakOrStrong::Weak(Arc::downgrade(&namespace)));

        // Append the sender to the list of notifiers for this nsid.
        let mut notifiers = self.rescan_notifiers.write();
        notifiers.insert(nsid, send);
        Ok(NamespaceHandle::new(namespace))
    }
647
648    /// Returns the number of CPUs that are in fallback mode (that are using a
649    /// remote CPU's queue due to a failure or resource limitation).
650    pub fn fallback_cpu_count(&self) -> usize {
651        self.io_issuers
652            .per_cpu
653            .iter()
654            .enumerate()
655            .filter(|&(cpu, c)| c.get().is_some_and(|c| c.cpu != cpu as u32))
656            .count()
657    }
658
    /// Saves the NVMe driver state during servicing.
    ///
    /// Sets `nvme_keepalive` so the subsequent shutdown leaves the
    /// controller enabled and the queue memory aliased by the device.
    pub async fn save(&mut self) -> anyhow::Result<NvmeDriverSavedState> {
        // Nothing to save if Identify Controller was never queried.
        if self.identify.is_none() {
            return Err(save_restore::Error::InvalidState.into());
        }
        let span = tracing::info_span!("nvme_driver_save", pci_id = self.device_id);
        self.nvme_keepalive = true;
        // Ask the worker task to snapshot queue state.
        match self
            .io_issuers
            .send
            .call(NvmeWorkerRequest::Save, span.clone())
            .instrument(span.clone())
            .await?
        {
            Ok(s) => {
                let _e = span.entered();
                tracing::info!(
                    namespaces = self
                        .namespaces
                        .keys()
                        .map(|nsid| nsid.to_string())
                        .collect::<Vec<_>>()
                        .join(", "),
                    "saving namespaces",
                );
                let mut saved_namespaces = vec![];
                for (nsid, namespace) in self.namespaces.iter_mut() {
                    // Save only namespaces that are active and still held by
                    // a caller (weak entry that upgrades successfully).
                    let is_weak = namespace.is_weak(); // This value will change after invoking get_arc().
                    if let Some(ns) = namespace.get_arc()
                        && ns.check_active().is_ok()
                        && is_weak
                    {
                        saved_namespaces.push(ns.save().with_context(|| {
                            format!(
                                "failed to save namespace nsid {} device {}",
                                nsid, self.device_id
                            )
                        })?);
                    }
                }
                Ok(NvmeDriverSavedState {
                    // Round-trip through bytes to clone the identify data out
                    // of the shared Arc.
                    identify_ctrl: spec::IdentifyController::read_from_bytes(
                        self.identify.as_ref().unwrap().as_bytes(),
                    )
                    .unwrap(),
                    device_id: self.device_id.clone(),
                    namespaces: saved_namespaces,
                    worker_data: s,
                })
            }
            Err(e) => Err(e),
        }
    }
713
714    /// This should only be called during restore if keepalive is no longer
715    /// supported and the previously enabled device needs to be reset. It
716    /// performs a controller reset by setting cc.en to 0. It will then also
717    /// drop the given device instance.
718    pub async fn clear_existing_state(
719        driver_source: &VmTaskDriverSource,
720        mut device: D,
721    ) -> anyhow::Result<()> {
722        let driver = driver_source.simple();
723        let bar0_mapping = device
724            .map_bar(0)
725            .context("failed to map device registers to clear existing state")?;
726        let bar0 = Bar0(bar0_mapping);
727        bar0.reset(&driver)
728            .await
729            .map_err(|e| anyhow::anyhow!("failed to reset device during clear: {:#x}", e))?;
730        Ok(())
731    }
732
733    /// Restores NVMe driver state after servicing.
734    pub async fn restore(
735        driver_source: &VmTaskDriverSource,
736        cpu_count: u32,
737        mut device: D,
738        saved_state: &NvmeDriverSavedState,
739        bounce_buffer: bool,
740    ) -> anyhow::Result<Self> {
741        let pci_id = device.id().to_owned();
742        let driver = driver_source.simple();
743        let bar0_mapping = device
744            .map_bar(0)
745            .context("failed to map device registers")?;
746        let bar0 = Bar0(bar0_mapping);
747
748        // It is expected for the device to be alive when restoring.
749        let csts = bar0.csts();
750        if !csts.rdy() {
751            tracing::error!(
752                csts = u32::from(csts),
753                ?pci_id,
754                "device is not ready during restore"
755            );
756            anyhow::bail!(
757                "device is not ready during restore, csts: {:#x}",
758                u32::from(csts)
759            );
760        }
761
762        let registers = Arc::new(DeviceRegisters::new(bar0));
763
764        let (send, recv) = mesh::channel();
765        let io_issuers = Arc::new(IoIssuers {
766            per_cpu: (0..cpu_count).map(|_| OnceLock::new()).collect(),
767            send,
768        });
769
770        let mut this = Self {
771            device_id: device.id().to_owned(),
772            task: Some(TaskControl::new(DriverWorkerTask {
773                device: ManuallyDrop::new(device),
774                driver: driver.clone(),
775                registers: registers.clone(),
776                admin: None, // Updated below.
777                io: Vec::new(),
778                proto_io: HashMap::new(),
779                next_ioq_id: 1,
780                io_issuers: io_issuers.clone(),
781                recv,
782                bounce_buffer,
783            })),
784            admin: None, // Updated below.
785            identify: Some(Arc::new(
786                spec::IdentifyController::read_from_bytes(saved_state.identify_ctrl.as_bytes())
787                    .map_err(|_| RestoreError::InvalidData)?,
788            )),
789            driver: driver.clone(),
790            io_issuers,
791            rescan_notifiers: Default::default(),
792            namespaces: Default::default(),
793            nvme_keepalive: true,
794            bounce_buffer,
795        };
796
797        let task = &mut this.task.as_mut().unwrap();
798        let worker = task.task_mut();
799
800        // Interrupt 0 is shared between admin queue and I/O queue 1.
801        let interrupt0 = worker
802            .device
803            .map_interrupt(0, 0)
804            .with_context(|| format!("failed to map interrupt 0 for {}", pci_id))?;
805
806        let dma_client = worker.device.dma_client();
807        let restored_memory = dma_client
808            .attach_pending_buffers()
809            .with_context(|| format!("failed to restore allocations for {}", pci_id))?;
810
811        // Restore the admin queue pair.
812        let admin = saved_state
813            .worker_data
814            .admin
815            .as_ref()
816            .map(|a| {
817                tracing::info!(
818                    id = a.qid,
819                    pending_commands_count = a.handler_data.pending_cmds.commands.len(),
820                    ?pci_id,
821                    "restoring admin queue",
822                );
823                // Restore memory block for admin queue pair.
824                let mem_block = restored_memory
825                    .iter()
826                    .find(|mem| mem.len() == a.mem_len && a.base_pfn == mem.pfns()[0])
827                    .expect("unable to find restored mem block")
828                    .to_owned();
829                QueuePair::restore(
830                    driver.clone(),
831                    interrupt0,
832                    registers.clone(),
833                    mem_block,
834                    &pci_id,
835                    a,
836                    bounce_buffer,
837                    AdminAerHandler::new(),
838                )
839                .expect("failed to restore admin queue pair")
840            })
841            .expect("attempted to restore admin queue from empty state");
842
843        let admin = worker.admin.insert(admin);
844
845        // Spawn a task to handle asynchronous events.
846        let async_event_task = this.driver.spawn("nvme_async_event", {
847            let admin = admin.issuer().clone();
848            let rescan_notifiers = this.rescan_notifiers.clone();
849            async move {
850                if let Err(err) = handle_asynchronous_events(&admin, rescan_notifiers)
851                    .instrument(tracing::info_span!("async_event_handler"))
852                    .await
853                {
854                    tracing::error!(
855                        error = err.as_ref() as &dyn std::error::Error,
856                        "asynchronous event failure, not processing any more"
857                    );
858                }
859            }
860        });
861
862        let state = WorkerState {
863            qsize: saved_state.worker_data.qsize,
864            async_event_task,
865            max_io_queues: saved_state.worker_data.max_io_queues,
866        };
867
868        this.admin = Some(admin.issuer().clone());
869
870        tracing::info!(
871            state = saved_state
872                .worker_data
873                .io
874                .iter()
875                .map(|io_state| format!(
876                    "{{qid={}, pending_commands_count={}}}",
877                    io_state.queue_data.qid,
878                    io_state.queue_data.handler_data.pending_cmds.commands.len()
879                ))
880                .collect::<Vec<_>>()
881                .join(", "),
882            ?pci_id,
883            "restoring io queues",
884        );
885
886        // Restore I/O queues.
887        // (1) Restore qid1 and any queues that have pending commands.
888        // Interrupt vector 0 is shared between Admin queue and I/O queue #1.
889        let mut max_seen_qid = 1;
890        worker.io = saved_state
891            .worker_data
892            .io
893            .iter()
894            .filter(|q| {
895                q.queue_data.qid == 1 || !q.queue_data.handler_data.pending_cmds.commands.is_empty()
896            })
897            .flat_map(|q| -> Result<IoQueue<D>, anyhow::Error> {
898                let qid = q.queue_data.qid;
899                let cpu = q.cpu;
900                tracing::info!(qid, cpu, ?pci_id, "restoring queue");
901                max_seen_qid = max_seen_qid.max(qid);
902                let interrupt = worker.device.map_interrupt(q.iv, q.cpu).with_context(|| {
903                    format!(
904                        "failed to map interrupt for {}, cpu {}, iv {}",
905                        pci_id, q.cpu, q.iv
906                    )
907                })?;
908                tracing::info!(qid, cpu, ?pci_id, "restoring queue: search for mem block");
909                let mem_block = restored_memory
910                    .iter()
911                    .find(|mem| {
912                        mem.len() == q.queue_data.mem_len && q.queue_data.base_pfn == mem.pfns()[0]
913                    })
914                    .expect("unable to find restored mem block")
915                    .to_owned();
916                tracing::info!(qid, cpu, ?pci_id, "restoring queue: restore IoQueue");
917                let q = IoQueue::restore(
918                    driver.clone(),
919                    interrupt,
920                    registers.clone(),
921                    mem_block,
922                    &pci_id,
923                    q,
924                    bounce_buffer,
925                )?;
926                tracing::info!(qid, cpu, ?pci_id, "restoring queue: create issuer");
927                let issuer = IoIssuer {
928                    issuer: q.queue.issuer().clone(),
929                    cpu: q.cpu,
930                };
931                this.io_issuers.per_cpu[q.cpu as usize].set(issuer).unwrap();
932                Ok(q)
933            })
934            .collect();
935
936        // (2) Create prototype entries for any queues that don't currently have outstanding commands.
937        // They will be restored on demand later.
938        worker.proto_io = saved_state
939            .worker_data
940            .io
941            .iter()
942            .filter(|q| {
943                q.queue_data.qid != 1 && q.queue_data.handler_data.pending_cmds.commands.is_empty()
944            })
945            .map(|q| {
946                // Create a prototype IO queue entry.
947                tracing::info!(
948                    qid = q.queue_data.qid,
949                    cpu = q.cpu,
950                    ?pci_id,
951                    "creating prototype io queue entry",
952                );
953                max_seen_qid = max_seen_qid.max(q.queue_data.qid);
954                let mem_block = restored_memory
955                    .iter()
956                    .find(|mem| {
957                        mem.len() == q.queue_data.mem_len && q.queue_data.base_pfn == mem.pfns()[0]
958                    })
959                    .expect("unable to find restored mem block")
960                    .to_owned();
961                (
962                    q.cpu,
963                    ProtoIoQueue {
964                        save_state: q.clone(),
965                        mem: mem_block,
966                    },
967                )
968            })
969            .collect();
970
971        // Update next_ioq_id to avoid reusing qids.
972        worker.next_ioq_id = max_seen_qid + 1;
973
974        tracing::info!(
975            namespaces = saved_state
976                .namespaces
977                .iter()
978                .map(|ns| format!("{{nsid={}, size={}}}", ns.nsid, ns.identify_ns.nsze))
979                .collect::<Vec<_>>()
980                .join(", "),
981            ?pci_id,
982            "restoring namespaces",
983        );
984
985        // Restore namespace(s).
986        for ns in &saved_state.namespaces {
987            let (send, recv) = mesh::channel::<()>();
988            this.namespaces.insert(
989                ns.nsid,
990                WeakOrStrong::Strong(Arc::new(Namespace::restore(
991                    &driver,
992                    admin.issuer().clone(),
993                    recv,
994                    this.identify.clone().unwrap(),
995                    &this.io_issuers,
996                    ns,
997                )?)),
998            );
999            this.rescan_notifiers.write().insert(ns.nsid, send);
1000        }
1001
1002        task.insert(&this.driver, "nvme_worker", state);
1003        task.start();
1004
1005        Ok(this)
1006    }
1007
1008    /// Change device's behavior when servicing.
1009    pub fn update_servicing_flags(&mut self, nvme_keepalive: bool) {
1010        tracing::debug!(nvme_keepalive, "updating nvme servicing flags");
1011        self.nvme_keepalive = nvme_keepalive;
1012    }
1013}
1014
/// Background task that repeatedly parks an Asynchronous Event Request (AER)
/// on the admin queue and dispatches the controller's notifications.
///
/// Only NOTICE events are acted upon: the Changed Namespace List log page is
/// read and each registered rescan notifier for a changed namespace is
/// signaled. All other event types are logged and otherwise ignored.
///
/// Returns an error — ending the task — when any admin command fails.
async fn handle_asynchronous_events(
    admin: &Issuer,
    rescan_notifiers: Arc<RwLock<HashMap<u32, mesh::Sender<()>>>>,
) -> anyhow::Result<()> {
    tracing::info!("starting asynchronous event handler task");
    loop {
        // Wait for the controller to complete the outstanding AER; the event
        // payload arrives in completion dword 0.
        let dw0 = admin
            .issue_get_aen()
            .await
            .context("asynchronous event request failed")?;

        match spec::AsynchronousEventType(dw0.event_type()) {
            spec::AsynchronousEventType::NOTICE => {
                tracing::info!("received an async notice event (aen) from the controller");

                // Clear the namespace list.
                //
                // The buffer holds one full log page: 1024 dwords. NUMDL is a
                // zero-based dword count, hence 1023 for 1024 dwords.
                let mut list = [0u32; 1024];
                admin
                    .issue_out(
                        spec::Command {
                            cdw10: spec::Cdw10GetLogPage::new()
                                .with_lid(spec::LogPageIdentifier::CHANGED_NAMESPACE_LIST.0)
                                .with_numdl_z(1023)
                                .into(),
                            ..admin_cmd(spec::AdminOpcode::GET_LOG_PAGE)
                        },
                        list.as_mut_bytes(),
                    )
                    .await
                    .context("failed to query changed namespace list")?;

                // Notify only the namespaces that have changed.

                // NOTE: The nvme spec states - If more than 1,024 namespaces have
                // changed attributes since the last time the log page was read,
                // the first entry in the log page shall be set to
                // FFFFFFFFh and the remainder of the list shall be zero filled.
                let notifier_guard = rescan_notifiers.read();
                if list[0] == 0xFFFFFFFF && list[1] == 0 {
                    // More than 1024 namespaces changed - notify all registered namespaces
                    tracing::info!("more than 1024 namespaces changed, notifying all listeners");
                    for notifiers in notifier_guard.values() {
                        notifiers.send(());
                    }
                } else {
                    // Notify specific namespaces that have changed.
                    // Zero entries are unused slots in the zero-filled page
                    // (NSID 0 is never a valid namespace), so skip them.
                    for nsid in list.iter().filter(|&&nsid| nsid != 0) {
                        tracing::info!(nsid, "notifying listeners of changed namespace");
                        if let Some(notifier) = notifier_guard.get(nsid) {
                            notifier.send(());
                        }
                    }
                }
            }
            event_type => {
                // Not a NOTICE event; record it for diagnostics and move on.
                tracing::info!(
                    ?event_type,
                    information = dw0.information(),
                    log_page_identifier = dw0.log_page_identifier(),
                    "unhandled asynchronous event"
                );
            }
        }
    }
}
1080
1081impl<D: DeviceBacking> Drop for NvmeDriver<D> {
1082    fn drop(&mut self) {
1083        tracing::trace!(pci_id = ?self.device_id, ka = self.nvme_keepalive, task = self.task.is_some(), "dropping nvme driver");
1084        if self.task.is_some() {
1085            // Do not reset NVMe device when nvme_keepalive is requested.
1086            tracing::debug!(nvme_keepalive = self.nvme_keepalive, pci_id = ?self.device_id, "dropping nvme driver");
1087            if !self.nvme_keepalive {
1088                // Reset the device asynchronously so that pending IOs are not
1089                // dropped while their memory is aliased.
1090                let reset = self.reset();
1091                self.driver.spawn("nvme_drop", reset).detach();
1092            }
1093        }
1094    }
1095}
1096
1097impl IoIssuers {
1098    pub async fn get(&self, cpu: u32) -> Result<&Issuer, RequestError> {
1099        if let Some(v) = self.per_cpu[cpu as usize].get() {
1100            return Ok(&v.issuer);
1101        }
1102
1103        self.send
1104            .call(NvmeWorkerRequest::CreateIssuer, cpu)
1105            .await
1106            .map_err(RequestError::Gone)?;
1107
1108        Ok(self.per_cpu[cpu as usize]
1109            .get()
1110            .expect("issuer was set by rpc")
1111            .issuer
1112            .as_ref())
1113    }
1114}
1115
1116impl<D: DeviceBacking> AsyncRun<WorkerState> for DriverWorkerTask<D> {
1117    async fn run(
1118        &mut self,
1119        stop: &mut task_control::StopTask<'_>,
1120        state: &mut WorkerState,
1121    ) -> Result<(), task_control::Cancelled> {
1122        let r = stop
1123            .until_stopped(async {
1124                loop {
1125                    match self.recv.next().await {
1126                        Some(NvmeWorkerRequest::CreateIssuer(rpc)) => {
1127                            rpc.handle(async |cpu| self.create_io_issuer(state, cpu).await)
1128                                .await
1129                        }
1130                        Some(NvmeWorkerRequest::Save(rpc)) => {
1131                            rpc.handle(async |span| {
1132                                let child_span = tracing::info_span!(
1133                                    parent: &span,
1134                                    "nvme_worker_save",
1135                                    pci_id = %self.device.id()
1136                                );
1137                                self.save(state).instrument(child_span).await
1138                            })
1139                            .await
1140                        }
1141                        None => break,
1142                    }
1143                }
1144            })
1145            .await;
1146        tracing::info!(pci_id = %self.device.id(), "nvme worker task exiting");
1147        r
1148    }
1149}
1150
impl<D: DeviceBacking> DriverWorkerTask<D> {
    /// Rehydrates a saved-but-not-yet-restored IO queue (a [`ProtoIoQueue`])
    /// into a live queue and installs its issuer for the queue's cpu.
    ///
    /// On success the restored queue is appended to `self.io`. On failure the
    /// prototype — including its memory block — is dropped by the caller.
    fn restore_io_issuer(&mut self, proto: ProtoIoQueue) -> anyhow::Result<()> {
        let pci_id = self.device.id().to_owned();
        let qid = proto.save_state.queue_data.qid;
        let cpu = proto.save_state.cpu;

        tracing::info!(
            qid,
            cpu,
            ?pci_id,
            "restoring queue from prototype: mapping interrupt"
        );
        // Re-map the same interrupt vector the queue was saved with.
        let interrupt = self
            .device
            .map_interrupt(proto.save_state.iv, proto.save_state.cpu)
            .with_context(|| {
                format!(
                    "failed to map interrupt for {}, cpu {}, iv {}",
                    pci_id, proto.save_state.cpu, proto.save_state.iv
                )
            })?;

        tracing::info!(
            qid,
            cpu,
            ?pci_id,
            "restoring queue from prototype: restore IoQueue"
        );
        let queue = IoQueue::restore(
            self.driver.clone(),
            interrupt,
            self.registers.clone(),
            proto.mem,
            &pci_id,
            &proto.save_state,
            self.bounce_buffer,
        )
        .with_context(|| format!("failed to restore io queue for {}, cpu {}", pci_id, cpu))?;

        tracing::info!(
            qid,
            cpu,
            ?pci_id,
            "restoring queue from prototype: restore complete"
        );

        let issuer = IoIssuer {
            issuer: queue.queue.issuer().clone(),
            cpu,
        };

        // The slot was empty when the caller decided to restore (it checked
        // before removing the prototype), so this set should not fail.
        self.io_issuers.per_cpu[cpu as usize]
            .set(issuer)
            .expect("issuer already set for this cpu");
        self.io.push(queue);

        Ok(())
    }

    /// Ensures an issuer exists for `cpu`, creating or restoring the backing
    /// IO queue as needed. Resolution order:
    ///
    /// 1. Slot already populated: nothing to do.
    /// 2. A saved prototype queue exists for this cpu: restore it.
    /// 3. Otherwise create a fresh IO queue; if that fails, fall back to the
    ///    issuer of the nearest lower-numbered cpu that has one.
    async fn create_io_issuer(&mut self, state: &mut WorkerState, cpu: u32) {
        tracing::debug!(cpu, pci_id = ?self.device.id(), "issuer request");
        if self.io_issuers.per_cpu[cpu as usize].get().is_some() {
            return;
        }

        if let Some(proto) = self.proto_io.remove(&cpu) {
            match self.restore_io_issuer(proto) {
                Ok(()) => return,
                Err(err) => {
                    // The memory block will be dropped as `proto` goes out of scope.
                    //
                    // TODO: in future work, consider trying to issue the NVMe command to delete
                    // the prior IO queue pair. Given that restore failed, and crucially, why
                    // restore failed, that may or may not be the right thing to do. It is probably
                    // the "right" protocol thing to do, though.

                    // Fall through to create a brand-new queue instead.
                    tracing::error!(
                        pci_id = ?self.device.id(),
                        cpu,
                        error = ?err,
                        "failed to restore io queue from prototype, creating new queue"
                    );
                }
            }
        }

        let issuer = match self
            .create_io_queue(state, cpu)
            .instrument(info_span!("create_nvme_io_queue", cpu))
            .await
        {
            Ok(issuer) => issuer,
            Err(err) => {
                // Find a fallback queue close in index to the failed queue.
                // NOTE(review): this panics if no lower-numbered cpu has an
                // issuer yet (e.g. cpu 0 fails first) — assumed unreachable
                // in practice; confirm against callers.
                let (fallback_cpu, fallback) = self.io_issuers.per_cpu[..cpu as usize]
                    .iter()
                    .enumerate()
                    .rev()
                    .find_map(|(i, issuer)| issuer.get().map(|issuer| (i, issuer)))
                    .expect("unable to find an io issuer for fallback");

                // Log the error as informational only when there is a lack of
                // hardware resources from the device.
                match err {
                    DeviceError::NoMoreIoQueues(_) => {
                        tracing::info!(
                            pci_id = ?self.device.id(),
                            cpu,
                            fallback_cpu,
                            error = &err as &dyn std::error::Error,
                            "failed to create io queue, falling back"
                        );
                    }
                    _ => {
                        tracing::error!(
                            pci_id = ?self.device.id(),
                            cpu,
                            fallback_cpu,
                            error = &err as &dyn std::error::Error,
                            "failed to create io queue, falling back"
                        );
                    }
                }

                fallback.clone()
            }
        };

        // Cannot fail: the slot was verified empty at function entry and this
        // worker task is the only writer on this path.
        self.io_issuers.per_cpu[cpu as usize]
            .set(issuer)
            .ok()
            .unwrap();
    }

    /// Creates a new IO completion/submission queue pair for `cpu` and
    /// registers it with the controller via admin commands.
    ///
    /// Fails with [`DeviceError::NoMoreIoQueues`] once `max_io_queues` pairs
    /// exist. On a controller-side failure the partially created queue pair
    /// is torn down before returning the error.
    async fn create_io_queue(
        &mut self,
        state: &mut WorkerState,
        cpu: u32,
    ) -> Result<IoIssuer, DeviceError> {
        if self.io.len() >= state.max_io_queues as usize {
            return Err(DeviceError::NoMoreIoQueues(state.max_io_queues));
        }

        // qid is 1-based, iv is 0-based.
        // And, IO queue 1 shares interrupt vector 0 with the admin queue.
        let qid = self.next_ioq_id;
        let iv = qid - 1;
        self.next_ioq_id += 1;

        tracing::debug!(cpu, qid, iv, pci_id = ?self.device.id(), "creating io queue");

        let interrupt = self
            .device
            .map_interrupt(iv.into(), cpu)
            .map_err(DeviceError::InterruptMapFailure)?;

        let queue = QueuePair::new(
            self.driver.clone(),
            self.device.deref(),
            qid,
            state.qsize,
            state.qsize,
            interrupt,
            self.registers.clone(),
            self.bounce_buffer,
            NoOpAerHandler,
        )
        .map_err(|err| DeviceError::IoQueuePairCreationFailure(err, qid))?;

        // The pair may have been sized smaller than requested; remember the
        // granted size so subsequent queues are created consistently.
        assert_eq!(queue.sq_entries(), queue.cq_entries());
        state.qsize = queue.sq_entries();

        let io_sq_addr = queue.sq_addr();
        let io_cq_addr = queue.cq_addr();

        // Add the queue pair before aliasing its memory with the device so
        // that it can be torn down correctly on failure.
        self.io.push(IoQueue { queue, iv, cpu });
        let io_queue = self.io.last_mut().unwrap();

        let admin = self.admin.as_ref().unwrap().issuer().as_ref();

        // The CQ must exist before the SQ that references it. Track whether
        // the CQ was created so the error path can delete it again.
        let mut created_completion_queue = false;
        let r = async {
            admin
                .issue_raw(spec::Command {
                    cdw10: spec::Cdw10CreateIoQueue::new()
                        .with_qid(qid)
                        .with_qsize_z(state.qsize - 1)
                        .into(),
                    cdw11: spec::Cdw11CreateIoCompletionQueue::new()
                        .with_ien(true)
                        .with_iv(iv)
                        .with_pc(true)
                        .into(),
                    dptr: [io_cq_addr, 0],
                    ..admin_cmd(spec::AdminOpcode::CREATE_IO_COMPLETION_QUEUE)
                })
                .await
                .map_err(|err| DeviceError::IoCompletionQueueFailure(err.into(), qid))?;

            created_completion_queue = true;

            admin
                .issue_raw(spec::Command {
                    cdw10: spec::Cdw10CreateIoQueue::new()
                        .with_qid(qid)
                        .with_qsize_z(state.qsize - 1)
                        .into(),
                    cdw11: spec::Cdw11CreateIoSubmissionQueue::new()
                        .with_cqid(qid)
                        .with_pc(true)
                        .into(),
                    dptr: [io_sq_addr, 0],
                    ..admin_cmd(spec::AdminOpcode::CREATE_IO_SUBMISSION_QUEUE)
                })
                .await
                .map_err(|err| DeviceError::IoSubmissionQueueFailure(err.into(), qid))?;

            Ok(())
        };

        if let Err(err) = r.await {
            // Teardown: best-effort delete of the CQ (if it was created),
            // then drop the local queue pair so its memory is unaliased.
            if created_completion_queue {
                if let Err(err) = admin
                    .issue_raw(spec::Command {
                        cdw10: spec::Cdw10DeleteIoQueue::new().with_qid(qid).into(),
                        ..admin_cmd(spec::AdminOpcode::DELETE_IO_COMPLETION_QUEUE)
                    })
                    .await
                {
                    tracing::error!(
                        pci_id = ?self.device.id(),
                        error = &err as &dyn std::error::Error,
                        "failed to delete completion queue in teardown path"
                    );
                }
            }
            let io = self.io.pop().unwrap();
            io.queue.shutdown().await;
            return Err(DeviceError::Other(err));
        }

        Ok(IoIssuer {
            issuer: io_queue.queue.issuer().clone(),
            cpu,
        })
    }

    /// Save NVMe driver state for servicing.
    ///
    /// Saves the admin queue, then drains and saves every live IO queue, and
    /// finally carries forward any prototype queues that were saved by a
    /// previous servicing pass but never restored. Fails if the admin queue
    /// or any IO queue fails to save.
    pub async fn save(
        &mut self,
        worker_state: &mut WorkerState,
    ) -> anyhow::Result<NvmeDriverWorkerSavedState> {
        tracing::info!(pci_id = ?self.device.id(), "saving nvme driver worker state: admin queue");
        let admin = match self.admin.as_ref() {
            Some(a) => match a.save().await {
                Ok(admin_state) => {
                    tracing::info!(
                        pci_id = ?self.device.id(),
                        id = admin_state.qid,
                        pending_commands_count = admin_state.handler_data.pending_cmds.commands.len(),
                        "saved admin queue",
                    );
                    Some(admin_state)
                }
                Err(e) => {
                    tracing::error!(
                            pci_id = ?self.device.id(),
                            error = e.as_ref() as &dyn std::error::Error,
                            "failed to save admin queue",
                    );
                    return Err(e);
                }
            },
            None => {
                tracing::warn!(pci_id = ?self.device.id(), "no admin queue saved");
                None
            }
        };

        tracing::info!(pci_id = ?self.device.id(), "saving nvme driver worker state: io queues");
        // Save all IO queues concurrently, then split results so every
        // failure can be logged before bailing out.
        let (ok, errs): (Vec<_>, Vec<_>) =
            join_all(self.io.drain(..).map(async |q| q.save().await))
                .await
                .into_iter()
                .partition(Result::is_ok);
        if !errs.is_empty() {
            for e in errs.into_iter().map(Result::unwrap_err) {
                tracing::error!(
                    pci_id = ?self.device.id(),
                    error = e.as_ref() as &dyn std::error::Error,
                    "failed to save io queue",
                );
            }
            return Err(anyhow::anyhow!("failed to save one or more io queues"));
        }

        let io: Vec<IoQueueSavedState> = ok
            .into_iter()
            .map(Result::unwrap)
            // Don't forget to include any queues that were saved from a _previous_ save, but were never restored
            // because they didn't see any IO.
            .chain(
                self.proto_io
                    .drain()
                    .map(|(_cpu, proto_queue)| proto_queue.save_state),
            )
            .collect();

        match io.is_empty() {
            true => tracing::warn!(pci_id = ?self.device.id(), "no io queues saved"),
            false => tracing::info!(
                pci_id = ?self.device.id(),
                state = io
                    .iter()
                    .map(|io_state| format!(
                        "{{qid={}, pending_commands_count={}}}",
                        io_state.queue_data.qid,
                        io_state.queue_data.handler_data.pending_cmds.commands.len()
                    ))
                    .collect::<Vec<_>>()
                    .join(", "),
                "saved io queues",
            ),
        }

        Ok(NvmeDriverWorkerSavedState {
            admin,
            io,
            qsize: worker_state.qsize,
            max_io_queues: worker_state.max_io_queues,
        })
    }
}
1486
1487impl<D: DeviceBacking> InspectTask<WorkerState> for DriverWorkerTask<D> {
1488    fn inspect(&self, req: inspect::Request<'_>, state: Option<&WorkerState>) {
1489        req.respond().merge(self).merge(state);
1490    }
1491}
1492
1493/// Save/restore data structures exposed by the NVMe driver.
1494#[expect(missing_docs)]
1495pub mod save_restore {
1496    use super::*;
1497
1498    /// Save and Restore errors for this module.
1499    #[derive(Debug, Error)]
1500    pub enum Error {
1501        /// No data to save.
1502        #[error("invalid object state")]
1503        InvalidState,
1504    }
1505
1506    /// Save/restore state for NVMe driver.
1507    #[derive(Protobuf, Clone, Debug)]
1508    #[mesh(package = "nvme_driver")]
1509    pub struct NvmeDriverSavedState {
1510        /// Copy of the controller's IDENTIFY structure.
1511        /// It is defined as Option<> in original structure.
1512        #[mesh(1, encoding = "mesh::payload::encoding::ZeroCopyEncoding")]
1513        pub identify_ctrl: spec::IdentifyController,
1514        /// Device ID string.
1515        #[mesh(2)]
1516        pub device_id: String,
1517        /// Namespace data.
1518        #[mesh(3)]
1519        pub namespaces: Vec<SavedNamespaceData>,
1520        /// NVMe driver worker task data.
1521        #[mesh(4)]
1522        pub worker_data: NvmeDriverWorkerSavedState,
1523    }
1524
1525    /// Save/restore state for NVMe driver worker task.
1526    #[derive(Protobuf, Clone, Debug)]
1527    #[mesh(package = "nvme_driver")]
1528    pub struct NvmeDriverWorkerSavedState {
1529        /// Admin queue state.
1530        #[mesh(1)]
1531        pub admin: Option<QueuePairSavedState>,
1532        /// IO queue states.
1533        #[mesh(2)]
1534        pub io: Vec<IoQueueSavedState>,
1535        /// Queue size as determined by CAP.MQES.
1536        #[mesh(3)]
1537        pub qsize: u16,
1538        /// Max number of IO queue pairs.
1539        #[mesh(4)]
1540        pub max_io_queues: u16,
1541    }
1542
1543    /// Save/restore state for QueuePair.
1544    #[derive(Protobuf, Clone, Debug)]
1545    #[mesh(package = "nvme_driver")]
1546    pub struct QueuePairSavedState {
1547        /// Allocated memory size in bytes.
1548        #[mesh(1)]
1549        pub mem_len: usize,
1550        /// First PFN of the physically contiguous block.
1551        #[mesh(2)]
1552        pub base_pfn: u64,
1553        /// Queue ID used when creating the pair
1554        /// (SQ and CQ IDs are using same number).
1555        #[mesh(3)]
1556        pub qid: u16,
1557        /// Submission queue entries.
1558        #[mesh(4)]
1559        pub sq_entries: u16,
1560        /// Completion queue entries.
1561        #[mesh(5)]
1562        pub cq_entries: u16,
1563        /// QueueHandler task data.
1564        #[mesh(6)]
1565        pub handler_data: QueueHandlerSavedState,
1566    }
1567
1568    /// Save/restore state for IoQueue.
1569    #[derive(Protobuf, Clone, Debug)]
1570    #[mesh(package = "nvme_driver")]
1571    pub struct IoQueueSavedState {
1572        #[mesh(1)]
1573        /// Which CPU handles requests.
1574        pub cpu: u32,
1575        #[mesh(2)]
1576        /// Interrupt vector (MSI-X)
1577        pub iv: u32,
1578        #[mesh(3)]
1579        pub queue_data: QueuePairSavedState,
1580    }
1581
1582    /// Save/restore state for QueueHandler task.
1583    #[derive(Protobuf, Clone, Debug)]
1584    #[mesh(package = "nvme_driver")]
1585    pub struct QueueHandlerSavedState {
1586        #[mesh(1)]
1587        pub sq_state: SubmissionQueueSavedState,
1588        #[mesh(2)]
1589        pub cq_state: CompletionQueueSavedState,
1590        #[mesh(3)]
1591        pub pending_cmds: PendingCommandsSavedState,
1592        #[mesh(4)]
1593        pub aer_handler: Option<AerHandlerSavedState>,
1594    }
1595
1596    /// Snapshot of submission queue metadata captured during save.
1597    #[derive(Protobuf, Clone, Debug)]
1598    #[mesh(package = "nvme_driver")]
1599    pub struct SubmissionQueueSavedState {
1600        #[mesh(1)]
1601        pub sqid: u16,
1602        #[mesh(2)]
1603        pub head: u32,
1604        #[mesh(3)]
1605        pub tail: u32,
1606        #[mesh(4)]
1607        pub committed_tail: u32,
1608        #[mesh(5)]
1609        pub len: u32,
1610    }
1611
1612    /// Snapshot of completion queue metadata captured during save.
1613    #[derive(Protobuf, Clone, Debug)]
1614    #[mesh(package = "nvme_driver")]
1615    pub struct CompletionQueueSavedState {
1616        #[mesh(1)]
1617        pub cqid: u16,
1618        #[mesh(2)]
1619        pub head: u32,
1620        #[mesh(3)]
1621        pub committed_head: u32,
1622        #[mesh(4)]
1623        pub len: u32,
1624        #[mesh(5)]
1625        /// NVMe completion tag.
1626        pub phase: bool,
1627    }
1628
1629    /// Pending command entry captured from a queue handler.
1630    #[derive(Protobuf, Clone, Debug)]
1631    #[mesh(package = "nvme_driver")]
1632    pub struct PendingCommandSavedState {
1633        #[mesh(1, encoding = "mesh::payload::encoding::ZeroCopyEncoding")]
1634        pub command: spec::Command,
1635    }
1636
1637    /// Collection of pending commands indexed by CID.
1638    #[derive(Protobuf, Clone, Debug)]
1639    #[mesh(package = "nvme_driver")]
1640    pub struct PendingCommandsSavedState {
1641        #[mesh(1)]
1642        pub commands: Vec<PendingCommandSavedState>,
1643        #[mesh(2)]
1644        pub next_cid_high_bits: u16,
1645        #[mesh(3)]
1646        pub cid_key_bits: u32,
1647    }
1648
1649    /// NVMe namespace data.
1650    #[derive(Protobuf, Clone, Debug)]
1651    #[mesh(package = "nvme_driver")]
1652    pub struct SavedNamespaceData {
1653        #[mesh(1)]
1654        pub nsid: u32,
1655        #[mesh(2, encoding = "mesh::payload::encoding::ZeroCopyEncoding")]
1656        pub identify_ns: nvme_spec::nvm::IdentifyNamespace,
1657    }
1658
1659    /// Saved Async Event Request handler metadata.
1660    #[derive(Clone, Debug, Protobuf)]
1661    #[mesh(package = "nvme_driver")]
1662    pub struct AerHandlerSavedState {
1663        #[mesh(1)]
1664        pub last_aen: Option<u32>,
1665        #[mesh(2)]
1666        pub await_aen_cid: Option<u16>,
1667    }
1668}