1use crate::DOORBELL_STRIDE_BITS;
7use crate::spec;
8use guestmem::GuestMemory;
9use guestmem::GuestMemoryError;
10use inspect::Inspect;
11use parking_lot::RwLock;
12use std::sync::Arc;
13use std::sync::atomic::Ordering;
14use std::task::Context;
15use std::task::Poll;
16use std::task::Waker;
17use std::task::ready;
18use thiserror::Error;
19use vmcore::interrupt::Interrupt;
20
/// Doorbell registers backed by guest memory, one slot per doorbell ID, with
/// an optional event-idx shadow region and a waker per doorbell.
pub struct DoorbellMemory {
    /// Backing memory holding the doorbell array (and event-idx array, if any).
    mem: GuestMemory,
    /// Byte offset of the doorbell array within `mem`.
    offset: u64,
    /// Byte offset of the event-idx array within `mem`, if configured.
    event_idx_offset: Option<u64>,
    /// Per-doorbell wakers, signaled when the corresponding doorbell is written.
    wakers: Vec<Option<Waker>>,
}
27
/// Error returned by [`DoorbellMemory::try_write`] when the doorbell ID is out
/// of range.
// Public error types should be `Debug` so callers can log/unwrap them.
#[derive(Debug)]
pub struct InvalidDoorbell;
29
30impl DoorbellMemory {
31 pub fn new(num_qids: u16) -> Self {
32 Self {
33 mem: GuestMemory::allocate((num_qids as usize) << DOORBELL_STRIDE_BITS),
34 offset: 0,
35 event_idx_offset: None,
36 wakers: (0..num_qids).map(|_| None).collect(),
37 }
38 }
39
40 pub fn replace_mem(
43 &mut self,
44 mem: GuestMemory,
45 offset: u64,
46 event_idx_offset: Option<u64>,
47 ) -> Result<(), GuestMemoryError> {
48 let len = self.wakers.len() << DOORBELL_STRIDE_BITS;
50 let mut current = vec![0; len];
51 self.mem.read_at(self.offset, &mut current)?;
52 mem.write_at(offset, ¤t)?;
53 if let Some(event_idx_offset) = event_idx_offset {
54 mem.write_at(event_idx_offset, ¤t)?;
56 }
57 self.mem = mem;
58 self.offset = offset;
59 self.event_idx_offset = event_idx_offset;
60 Ok(())
61 }
62
63 pub fn try_write(&self, db_id: u16, value: u32) -> Result<(), InvalidDoorbell> {
64 if (db_id as usize) >= self.wakers.len() {
65 return Err(InvalidDoorbell);
66 }
67 self.write(db_id, value);
68 Ok(())
69 }
70
71 fn write(&self, db_id: u16, value: u32) {
72 assert!((db_id as usize) < self.wakers.len());
73 let addr = self
74 .offset
75 .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS);
76 if let Err(err) = self.mem.write_plain(addr, &value) {
77 tracelimit::error_ratelimited!(
78 error = &err as &dyn std::error::Error,
79 "failed to write doorbell memory"
80 );
81 }
82 if let Some(waker) = &self.wakers[db_id as usize] {
83 waker.wake_by_ref();
84 }
85 }
86
87 fn read(&self, db_id: u16) -> Option<u32> {
88 assert!((db_id as usize) < self.wakers.len());
89 self.mem
90 .read_plain(
91 self.offset
92 .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS),
93 )
94 .inspect_err(|err| {
95 tracelimit::error_ratelimited!(
96 error = err as &dyn std::error::Error,
97 "failed to read doorbell memory"
98 );
99 })
100 .ok()
101 }
102
103 fn has_event_idx(&self) -> bool {
104 self.event_idx_offset.is_some()
105 }
106
107 fn write_event_idx(&self, db_id: u16, val: u32) {
108 assert!((db_id as usize) < self.wakers.len());
109 if let Err(err) = self.mem.write_plain(
110 self.event_idx_offset
111 .unwrap()
112 .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS),
113 &val,
114 ) {
115 tracelimit::error_ratelimited!(
116 error = &err as &dyn std::error::Error,
117 "failed to read event_idx memory"
118 )
119 }
120 }
121
122 fn read_event_idx(&self, db_id: u16) -> Option<u32> {
123 assert!((db_id as usize) < self.wakers.len());
124 self.mem
125 .read_plain(
126 self.event_idx_offset?
127 .wrapping_add((db_id as u64) << DOORBELL_STRIDE_BITS),
128 )
129 .inspect_err(|err| {
130 tracelimit::error_ratelimited!(
131 error = err as &dyn std::error::Error,
132 "failed to read doorbell memory"
133 );
134 })
135 .ok()
136 }
137}
138
/// Cached, per-queue view of a single doorbell register in the shared
/// [`DoorbellMemory`].
#[derive(Inspect)]
#[inspect(extra = "Self::inspect_shadow")]
struct DoorbellState {
    /// Last doorbell value observed and validated by `probe`.
    #[inspect(hex)]
    current: u32,
    /// Last value published to the event-idx shadow for this doorbell.
    #[inspect(hex)]
    event_idx: u32,
    /// Index of this doorbell within the shared doorbell array.
    db_id: u16,
    /// Byte offset of this doorbell within the doorbell array.
    db_offset: u64,
    /// Queue length; doorbell values must be strictly less than this.
    #[inspect(hex)]
    len: u32,
    /// Shared doorbell memory (and waker registry).
    #[inspect(skip)]
    doorbells: Arc<RwLock<DoorbellMemory>>,
    /// Waker most recently registered via `poll`, used to skip redundant
    /// re-registration.
    #[inspect(skip)]
    registered_waker: Option<Waker>,
}
155
impl DoorbellState {
    /// Adds computed fields to the inspect output: the live doorbell value
    /// and the event-idx shadow value, both read back from guest memory.
    fn inspect_shadow(&self, resp: &mut inspect::Response<'_>) {
        resp.field_with("doorbell", || {
            self.doorbells.read().read(self.db_id).map(inspect::AsHex)
        })
        .field_with("shadow_event_idx", || {
            self.doorbells
                .read()
                .read_event_idx(self.db_id)
                .map(inspect::AsHex)
        });
    }

    /// Creates state tracking doorbell `db_id` for a queue with `len` entries.
    fn new(doorbells: Arc<RwLock<DoorbellMemory>>, db_id: u16, len: u32) -> Self {
        Self {
            current: 0,
            event_idx: 0,
            len,
            doorbells,
            registered_waker: None,
            db_id,
            db_offset: (db_id as u64) << DOORBELL_STRIDE_BITS,
        }
    }

    /// Reads the doorbell, returning `Some(val)` if it differs from
    /// `self.current`, or `None` if unchanged (or if the read failed).
    ///
    /// When `update_event_idx` is set, the doorbell is unchanged, and an
    /// event-idx region exists, this also publishes the current value to the
    /// event-idx shadow and then re-reads the doorbell, closing the window
    /// where the doorbell writer could have rung just before the new
    /// event-idx became visible.
    fn probe_inner(&mut self, update_event_idx: bool) -> Option<u32> {
        let doorbell = self.doorbells.read();
        let val = doorbell.read(self.db_id)?;
        if val != self.current {
            return Some(val);
        }

        // No change; nothing more to do unless we need to (and can) publish
        // a fresh event-idx value.
        if self.event_idx == val || !update_event_idx || !doorbell.has_event_idx() {
            return None;
        }

        doorbell.write_event_idx(self.db_id, val);
        self.event_idx = val;

        // Order the event-idx write before the doorbell re-read so a doorbell
        // update that raced with the event-idx publication is observed here.
        std::sync::atomic::fence(Ordering::SeqCst);
        let val = doorbell.read(self.db_id)?;
        if val != self.current { Some(val) } else { None }
    }

    /// Checks the doorbell for a new value, validating it against the queue
    /// length and caching it in `self.current`.
    ///
    /// Returns `Ok(true)` if the doorbell advanced, `Ok(false)` if unchanged,
    /// or an error if the guest wrote an out-of-range value.
    fn probe(&mut self, update_event_idx: bool) -> Result<bool, QueueError> {
        if let Some(val) = self.probe_inner(update_event_idx) {
            if val >= self.len {
                return Err(QueueError::InvalidDoorbell { val, len: self.len });
            }
            self.current = val;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    /// Polls for the doorbell to advance, registering the task's waker to be
    /// woken on the next doorbell write.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), QueueError>> {
        // Register (or refresh) the waker *before* probing so a doorbell
        // write racing with the probe cannot be missed.
        if self
            .registered_waker
            .as_ref()
            .is_none_or(|w| !cx.waker().will_wake(w))
        {
            let _old_waker =
                self.doorbells.write().wakers[self.db_id as usize].replace(cx.waker().clone());
            self.registered_waker = Some(cx.waker().clone());
        }
        if !self.probe(true)? {
            return Poll::Pending;
        }
        Poll::Ready(Ok(()))
    }
}
235
/// A submission queue: a guest-written ring of commands consumed by the
/// device.
#[derive(Inspect)]
pub struct SubmissionQueue {
    /// Tail doorbell state, advanced by the guest as it submits commands.
    tail: DoorbellState,
    /// Guest memory used to read queue entries.
    mem: GuestMemory,
    /// Index of the next entry to consume.
    #[inspect(hex)]
    head: u32,
    /// Guest physical address of the ring buffer.
    #[inspect(hex)]
    gpa: u64,
}
245
/// Errors produced while processing a submission or completion queue.
#[derive(Debug, Error)]
pub enum QueueError {
    /// The guest wrote a doorbell value outside the queue's bounds.
    #[error("invalid doorbell value {val:#x}, len {len:#x}")]
    InvalidDoorbell { val: u32, len: u32 },
    /// A guest memory access failed.
    #[error("queue access error")]
    Memory(#[source] GuestMemoryError),
}
253
254impl SubmissionQueue {
255 pub fn new(cq: &CompletionQueue, db_id: u16, gpa: u64, len: u16) -> Self {
256 let doorbells = cq.head.doorbells.clone();
257 let mem = cq.mem.clone();
258 doorbells.read().write(db_id, 0);
259 Self {
260 tail: DoorbellState::new(doorbells, db_id, len.into()),
261 head: 0,
262 gpa,
263 mem,
264 }
265 }
266
267 pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<spec::Command, QueueError>> {
270 let tail = self.tail.current;
271 if tail == self.head {
272 ready!(self.tail.poll(cx))?;
273 }
274 let command: spec::Command = self
275 .mem
276 .read_plain(
277 self.gpa
278 .wrapping_add(self.head as u64 * size_of::<spec::Command>() as u64),
279 )
280 .map_err(QueueError::Memory)?;
281
282 self.head = advance(self.head, self.tail.len);
283 Poll::Ready(Ok(command))
284 }
285
286 pub fn sqhd(&self) -> u16 {
287 self.head as u16
288 }
289}
290
/// A completion queue: a device-written ring of completion entries consumed
/// by the guest.
#[derive(Inspect)]
pub struct CompletionQueue {
    /// Index of the next entry to write.
    #[inspect(hex)]
    tail: u32,
    /// Head doorbell state, advanced by the guest as it consumes entries.
    head: DoorbellState,
    /// Current phase tag; flipped each time `tail` wraps to 0.
    phase: bool,
    /// Guest physical address of the ring buffer.
    #[inspect(hex)]
    gpa: u64,
    /// Interrupt delivered after each completion write, if configured.
    #[inspect(with = "Option::is_some")]
    interrupt: Option<Interrupt>,
    /// Guest memory used to write queue entries.
    mem: GuestMemory,
}
303
impl CompletionQueue {
    /// Creates a completion queue using doorbell `db_id` for its head, with
    /// the ring buffer at guest physical address `gpa`, `len` entries, and an
    /// optional interrupt delivered on each completion.
    pub fn new(
        doorbells: Arc<RwLock<DoorbellMemory>>,
        db_id: u16,
        mem: GuestMemory,
        interrupt: Option<Interrupt>,
        gpa: u64,
        len: u16,
    ) -> Self {
        // Zero the head doorbell so the queue starts out empty.
        doorbells.read().write(db_id, 0);
        Self {
            tail: 0,
            head: DoorbellState::new(doorbells, db_id, len.into()),
            // First pass through the ring uses phase = true; flipped on wrap.
            phase: true,
            gpa,
            interrupt,
            mem,
        }
    }

    /// Polls until there is room to write another completion entry, i.e.
    /// until advancing the tail would not collide with the guest's head.
    pub fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), QueueError>> {
        let next_tail = advance(self.tail, self.head.len);
        if self.head.current == next_tail {
            ready!(self.head.poll(cx))?;
        }
        Poll::Ready(Ok(()))
    }

    /// Writes a completion entry at the tail.
    ///
    /// Returns `Ok(false)` without writing if the queue is full (even after
    /// re-probing the head doorbell once), `Ok(true)` on success.
    pub fn write(&mut self, mut data: spec::Completion) -> Result<bool, QueueError> {
        let next = advance(self.tail, self.head.len);
        // Apparently full: re-probe the head doorbell once in case the guest
        // consumed entries since the last check.
        if self.head.current == next && !self.head.probe(false)? {
            return Ok(false);
        }
        data.status.set_phase(self.phase);

        // Write the entry as two 8-byte halves with release fences so the
        // first half is visible to the guest before the second half (which
        // carries the status/phase the guest polls to detect a new entry).
        let [low, high]: [u64; 2] = zerocopy::transmute!(data);
        let gpa = self
            .gpa
            .wrapping_add(self.tail as u64 * size_of::<spec::Completion>() as u64);
        self.mem
            .write_plain(gpa, &low)
            .map_err(QueueError::Memory)?;
        std::sync::atomic::fence(Ordering::Release);
        self.mem
            .write_plain(gpa + 8, &high)
            .map_err(QueueError::Memory)?;
        std::sync::atomic::fence(Ordering::Release);

        if let Some(interrupt) = &self.interrupt {
            interrupt.deliver();
        }
        self.tail = next;
        if self.tail == 0 {
            // Wrapped around: flip the phase tag for the next pass.
            self.phase = !self.phase;
        }
        Ok(true)
    }
}
370
/// Advances ring index `n` by one within a ring of `l` entries, wrapping back
/// to 0 at the end.
fn advance(n: u32, l: u32) -> u32 {
    let next = n + 1;
    if next >= l { 0 } else { next }
}