nvme/
queue.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! NVMe submission and completion queue types.

6use crate::DOORBELL_STRIDE_BITS;
7use crate::spec;
8use guestmem::GuestMemory;
9use guestmem::GuestMemoryError;
10use inspect::Inspect;
11use parking_lot::RwLock;
12use std::sync::Arc;
13use std::sync::atomic::Ordering;
14use std::task::Context;
15use std::task::Poll;
16use std::task::Waker;
17use std::task::ready;
18use thiserror::Error;
19use vmcore::interrupt::Interrupt;
20
21pub struct DoorbellMemory {
22    mem: GuestMemory,
23    offset: u64,
24    event_idx_offset: Option<u64>,
25    wakers: Vec<Option<Waker>>,
26}
27
/// Error returned by [`DoorbellMemory::try_write`] when the doorbell ID is
/// outside the range of configured doorbell slots.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidDoorbell;
29
30impl DoorbellMemory {
31    pub fn new(num_qids: u16) -> Self {
32        Self {
33            mem: GuestMemory::allocate((num_qids as usize) << DOORBELL_STRIDE_BITS),
34            offset: 0,
35            event_idx_offset: None,
36            wakers: (0..num_qids).map(|_| None).collect(),
37        }
38    }
39
40    /// Update the memory used to store the doorbell values. This is used to
41    /// support shadow doorbells, where the values are directly in guest memory.
42    pub fn replace_mem(
43        &mut self,
44        mem: GuestMemory,
45        offset: u64,
46        event_idx_offset: Option<u64>,
47    ) -> Result<(), GuestMemoryError> {
48        // Copy the current doorbell values into the new memory.
49        let len = self.wakers.len() << DOORBELL_STRIDE_BITS;
50        let mut current = vec![0; len];
51        self.mem.read_at(self.offset, &mut current)?;
52        mem.write_at(offset, &current)?;
53        if let Some(event_idx_offset) = event_idx_offset {
54            // Catch eventidx up to the current doorbell value.
55            mem.write_at(event_idx_offset, &current)?;
56        }
57        self.mem = mem;
58        self.offset = offset;
59        self.event_idx_offset = event_idx_offset;
60        Ok(())
61    }
62
63    pub fn try_write(&self, db_id: u16, value: u32) -> Result<(), InvalidDoorbell> {
64        if (db_id as usize) >= self.wakers.len() {
65            return Err(InvalidDoorbell);
66        }
67        self.write(db_id, value);
68        Ok(())
69    }
70
71    fn write(&self, db_id: u16, value: u32) {
72        assert!((db_id as usize) < self.wakers.len());
73        let addr = self
74            .offset
75            .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS);
76        if let Err(err) = self.mem.write_plain(addr, &value) {
77            tracelimit::error_ratelimited!(
78                error = &err as &dyn std::error::Error,
79                "failed to write doorbell memory"
80            );
81        }
82        if let Some(waker) = &self.wakers[db_id as usize] {
83            waker.wake_by_ref();
84        }
85    }
86
87    fn read(&self, db_id: u16) -> Option<u32> {
88        assert!((db_id as usize) < self.wakers.len());
89        self.mem
90            .read_plain(
91                self.offset
92                    .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS),
93            )
94            .inspect_err(|err| {
95                tracelimit::error_ratelimited!(
96                    error = err as &dyn std::error::Error,
97                    "failed to read doorbell memory"
98                );
99            })
100            .ok()
101    }
102
103    fn has_event_idx(&self) -> bool {
104        self.event_idx_offset.is_some()
105    }
106
107    fn write_event_idx(&self, db_id: u16, val: u32) {
108        assert!((db_id as usize) < self.wakers.len());
109        if let Err(err) = self.mem.write_plain(
110            self.event_idx_offset
111                .unwrap()
112                .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS),
113            &val,
114        ) {
115            tracelimit::error_ratelimited!(
116                error = &err as &dyn std::error::Error,
117                "failed to read event_idx memory"
118            )
119        }
120    }
121
122    fn read_event_idx(&self, db_id: u16) -> Option<u32> {
123        assert!((db_id as usize) < self.wakers.len());
124        self.mem
125            .read_plain(
126                self.event_idx_offset?
127                    .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS),
128            )
129            .inspect_err(|err| {
130                tracelimit::error_ratelimited!(
131                    error = err as &dyn std::error::Error,
132                    "failed to read doorbell memory"
133                );
134            })
135            .ok()
136    }
137}
138
/// Per-queue view of one doorbell slot in the shared [`DoorbellMemory`].
///
/// Tracks the last doorbell value this side has accepted (`current`) and the
/// last value it published to the shadow event index (`event_idx`).
#[derive(Inspect)]
#[inspect(extra = "Self::inspect_shadow")]
struct DoorbellState {
    /// Last doorbell value accepted by `probe`. `probe` rejects values that
    /// are `>= len` before storing them here.
    #[inspect(hex)]
    current: u32,
    /// Last value written to this slot's event index by `probe_inner`.
    #[inspect(hex)]
    event_idx: u32,
    /// Index of this queue's slot in the doorbell memory.
    db_id: u16,
    /// Byte offset of this slot, `db_id << DOORBELL_STRIDE_BITS`.
    db_offset: u64,
    /// Number of queue entries; valid doorbell values are `0..len`.
    #[inspect(hex)]
    len: u32,
    /// Shared doorbell storage; `SubmissionQueue::new` reuses the completion
    /// queue's handle.
    #[inspect(skip)]
    doorbells: Arc<RwLock<DoorbellMemory>>,
    /// Waker most recently installed in the doorbell memory by `poll`. Used to
    /// skip re-registering when the task's waker has not changed.
    #[inspect(skip)]
    registered_waker: Option<Waker>,
}
155
impl DoorbellState {
    /// Adds the raw doorbell value and shadow event index, read live from
    /// doorbell memory, to the inspect output.
    ///
    /// Each field is absent if its read fails. The event index is also absent
    /// when no event index region is configured.
    fn inspect_shadow(&self, resp: &mut inspect::Response<'_>) {
        resp.field_with("doorbell", || {
            self.doorbells.read().read(self.db_id).map(inspect::AsHex)
        })
        .field_with("shadow_event_idx", || {
            self.doorbells
                .read()
                .read_event_idx(self.db_id)
                .map(inspect::AsHex)
        });
    }

    /// Creates state for doorbell `db_id` of a queue with `len` entries.
    /// `current` and `event_idx` start at zero.
    fn new(doorbells: Arc<RwLock<DoorbellMemory>>, db_id: u16, len: u32) -> Self {
        Self {
            current: 0,
            event_idx: 0,
            len,
            doorbells,
            registered_waker: None,
            db_id,
            db_offset: (db_id as u64) << DOORBELL_STRIDE_BITS,
        }
    }

    /// Returns the doorbell value if it differs from `current`, without
    /// validating it.
    ///
    /// If the value is unchanged and all of the following hold, the value is
    /// published as the event index and the doorbell is read once more after a
    /// `SeqCst` fence, in case the guest updated it concurrently:
    /// - `update_event_idx` is set;
    /// - an event index region is configured;
    /// - the event index does not already hold that value.
    ///
    /// Returns `None` if nothing changed or a doorbell read failed.
    fn probe_inner(&mut self, update_event_idx: bool) -> Option<u32> {
        // Try to read forward.
        let doorbell = self.doorbells.read();
        let val = doorbell.read(self.db_id)?;
        if val != self.current {
            return Some(val);
        }

        // Nothing new. Skip the event index write if it is already up to date,
        // the caller did not ask for it, or shadow event indexes are not in use.
        if self.event_idx == val || !update_event_idx || !doorbell.has_event_idx() {
            return None;
        }

        // Update the event index so that the guest will write the real doorbell
        // on the next update.
        doorbell.write_event_idx(self.db_id, val);
        self.event_idx = val;

        // Double check after a memory barrier.
        std::sync::atomic::fence(Ordering::SeqCst);
        let val = doorbell.read(self.db_id)?;
        if val != self.current { Some(val) } else { None }
    }

    /// Checks for a new doorbell value and, if one is found, stores it in
    /// `current`.
    ///
    /// Returns `Ok(true)` if `current` advanced and `Ok(false)` if nothing
    /// changed. Returns `QueueError::InvalidDoorbell` if the new value is
    /// `>= len`; `current` is then left unchanged.
    fn probe(&mut self, update_event_idx: bool) -> Result<bool, QueueError> {
        // If shadow doorbells are in use, use that instead of what was written to the doorbell
        // register, as it may be more current.
        if let Some(val) = self.probe_inner(update_event_idx) {
            if val >= self.len {
                return Err(QueueError::InvalidDoorbell { val, len: self.len });
            }
            self.current = val;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Polls for a doorbell update, updating the event index if needed.
    ///
    /// First registers `cx`'s waker in the doorbell memory so that
    /// `DoorbellMemory::write` for this slot wakes the task.
    ///
    /// Returns `Ready(Ok(()))` once `current` has advanced, `Ready(Err(_))` for
    /// an out-of-range doorbell value, and `Pending` otherwise.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), QueueError>> {
        // Ensure we get woken up whenever the doorbell is written to.
        if self
            .registered_waker
            .as_ref()
            .is_none_or(|w| !cx.waker().will_wake(w))
        {
            // Binding the previous waker to a named local (rather than `_`)
            // delays its drop until the end of this block. The write guard is a
            // temporary of this statement, so the old waker is dropped after
            // the lock has been released.
            let _old_waker =
                self.doorbells.write().wakers[self.db_id as usize].replace(cx.waker().clone());
            self.registered_waker = Some(cx.waker().clone());
        }
        if !self.probe(true)? {
            return Poll::Pending;
        }
        Poll::Ready(Ok(()))
    }
}
235
/// A submission queue: a ring of `spec::Command` entries in guest memory.
///
/// Entries are read at `head` once the tail doorbell has moved past it.
#[derive(Inspect)]
pub struct SubmissionQueue {
    /// Tail doorbell state; its `len` is the queue length.
    tail: DoorbellState,
    /// Guest memory containing the queue entries.
    mem: GuestMemory,
    /// Index of the next entry to read.
    #[inspect(hex)]
    head: u32,
    /// Guest physical address of entry 0.
    #[inspect(hex)]
    gpa: u64,
}
245
/// Errors from submission and completion queue operations.
#[derive(Debug, Error)]
pub enum QueueError {
    /// A doorbell holds a value that is not less than the queue length.
    #[error("invalid doorbell value {val:#x}, len {len:#x}")]
    InvalidDoorbell { val: u32, len: u32 },
    /// Reading or writing a queue entry in guest memory failed.
    #[error("queue access error")]
    Memory(#[source] GuestMemoryError),
}
253
254impl SubmissionQueue {
255    pub fn new(cq: &CompletionQueue, db_id: u16, gpa: u64, len: u16) -> Self {
256        let doorbells = cq.head.doorbells.clone();
257        let mem = cq.mem.clone();
258        doorbells.read().write(db_id, 0);
259        Self {
260            tail: DoorbellState::new(doorbells, db_id, len.into()),
261            head: 0,
262            gpa,
263            mem,
264        }
265    }
266
267    /// This function returns a future for the next entry in the submission queue.  It also
268    /// has a side effect of updating the tail.
269    pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<spec::Command, QueueError>> {
270        let tail = self.tail.current;
271        if tail == self.head {
272            ready!(self.tail.poll(cx))?;
273        }
274        let command: spec::Command = self
275            .mem
276            .read_plain(
277                self.gpa
278                    .wrapping_add(self.head as u64 * size_of::<spec::Command>() as u64),
279            )
280            .map_err(QueueError::Memory)?;
281
282        self.head = advance(self.head, self.tail.len);
283        Poll::Ready(Ok(command))
284    }
285
286    pub fn sqhd(&self) -> u16 {
287        self.head as u16
288    }
289}
290
/// A completion queue: a ring of `spec::Completion` entries in guest memory.
///
/// Entries are written at `tail`; the consumer's progress is tracked through
/// the head doorbell.
#[derive(Inspect)]
pub struct CompletionQueue {
    /// Index of the next entry to write.
    #[inspect(hex)]
    tail: u32,
    /// Head doorbell state; its `len` is the queue length.
    head: DoorbellState,
    /// Phase tag stamped into each completion; flips whenever `tail` wraps to 0.
    phase: bool,
    /// Guest physical address of entry 0.
    #[inspect(hex)]
    gpa: u64,
    /// Interrupt delivered after each completion is written, if any.
    #[inspect(with = "Option::is_some")]
    interrupt: Option<Interrupt>,
    /// Guest memory containing the queue entries.
    mem: GuestMemory,
}
303
304impl CompletionQueue {
305    pub fn new(
306        doorbells: Arc<RwLock<DoorbellMemory>>,
307        db_id: u16,
308        mem: GuestMemory,
309        interrupt: Option<Interrupt>,
310        gpa: u64,
311        len: u16,
312    ) -> Self {
313        doorbells.read().write(db_id, 0);
314        Self {
315            tail: 0,
316            head: DoorbellState::new(doorbells, db_id, len.into()),
317            phase: true,
318            gpa,
319            interrupt,
320            mem,
321        }
322    }
323
324    /// Wait for free completions.
325    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), QueueError>> {
326        let next_tail = advance(self.tail, self.head.len);
327        if self.head.current == next_tail {
328            ready!(self.head.poll(cx))?;
329        }
330        Poll::Ready(Ok(()))
331    }
332
333    pub fn write(&mut self, mut data: spec::Completion) -> Result<bool, QueueError> {
334        let next = advance(self.tail, self.head.len);
335        // Check the doorbell register instead of requiring the caller to
336        // go around the slow path and call `poll_ready`.
337        if self.head.current == next && !self.head.probe(false)? {
338            return Ok(false);
339        }
340        data.status.set_phase(self.phase);
341
342        // Atomically write the low part of the completion entry first, then the
343        // high part, using release fences to ensure ordering.
344        //
345        // This is necessary to ensure the guest can observe the full completion
346        // once it observes the phase bit change (which is in the high part).
347        let [low, high]: [u64; 2] = zerocopy::transmute!(data);
348        let gpa = self
349            .gpa
350            .wrapping_add(self.tail as u64 * size_of::<spec::Completion>() as u64);
351        self.mem
352            .write_plain(gpa, &low)
353            .map_err(QueueError::Memory)?;
354        std::sync::atomic::fence(Ordering::Release);
355        self.mem
356            .write_plain(gpa + 8, &high)
357            .map_err(QueueError::Memory)?;
358        std::sync::atomic::fence(Ordering::Release);
359
360        if let Some(interrupt) = &self.interrupt {
361            interrupt.deliver();
362        }
363        self.tail = next;
364        if self.tail == 0 {
365            self.phase = !self.phase;
366        }
367        Ok(true)
368    }
369}
370
/// Returns the index that follows `n` in a ring of `l` entries, wrapping back
/// to zero when the end of the ring is reached.
fn advance(n: u32, l: u32) -> u32 {
    let next = n + 1;
    if next >= l { 0 } else { next }
}