1use crate::queue::QueueCoreCompleteWork;
5use crate::queue::QueueCoreGetWork;
6use crate::queue::QueueError;
7use crate::queue::QueueParams;
8use crate::queue::QueueWork;
9use crate::queue::VirtioQueuePayload;
10use crate::queue::new_queue;
11use crate::spec::VirtioDeviceFeatures;
12use async_trait::async_trait;
13use futures::FutureExt;
14use futures::Stream;
15use futures::StreamExt;
16use guestmem::DoorbellRegistration;
17use guestmem::GuestMemory;
18use guestmem::GuestMemoryError;
19use guestmem::MappedMemoryRegion;
20use inspect::Inspect;
21use pal_async::DefaultPool;
22use pal_async::driver::Driver;
23use pal_async::wait::PolledWait;
24use pal_event::Event;
25use parking_lot::Mutex;
26use std::io::Error;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::Context;
30use std::task::Poll;
31use std::task::ready;
32use task_control::AsyncRun;
33use task_control::StopTask;
34use task_control::TaskControl;
35use thiserror::Error;
36use vmcore::interrupt::Interrupt;
37
/// Trait implemented by device-specific workers to handle work items pulled
/// from a virtio queue.
#[async_trait]
pub trait VirtioQueueWorkerContext {
    /// Processes one work item, or the error that occurred while retrieving
    /// it. Returning `false` stops the worker's run loop.
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}
42
/// Completion side of a virtio queue: returns finished descriptors to the
/// guest's used ring and tracks how many descriptors are still outstanding.
#[derive(Debug, Inspect)]
pub struct VirtioQueueUsedHandler {
    #[inspect(skip)]
    core: QueueCoreCompleteWork,
    // Count of descriptors handed out but not yet completed, paired with an
    // event that is notified when the count returns to zero.
    #[inspect(with = "|x| x.lock().0")]
    outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
    // Interrupt delivered to the guest when a completion requires notification.
    #[inspect(skip)]
    notify_guest: Interrupt,
}
52
53impl VirtioQueueUsedHandler {
54 fn new(core: QueueCoreCompleteWork, notify_guest: Interrupt) -> Self {
55 Self {
56 core,
57 outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
58 notify_guest,
59 }
60 }
61
62 pub fn add_outstanding_descriptor(&self) {
63 let (count, _) = &mut *self.outstanding_desc_count.lock();
64 *count += 1;
65 }
66
67 pub fn complete_descriptor(&mut self, work: &QueueWork, bytes_written: u32) {
68 match self.core.complete_descriptor(work, bytes_written) {
69 Ok(true) => {
70 self.notify_guest.deliver();
71 }
72 Ok(false) => {}
73 Err(err) => {
74 tracelimit::error_ratelimited!(
75 error = &err as &dyn std::error::Error,
76 "failed to complete descriptor"
77 );
78 }
79 }
80 {
81 let (count, event) = &mut *self.outstanding_desc_count.lock();
82 *count -= 1;
83 if *count == 0 {
84 event.notify(usize::MAX);
85 }
86 }
87 }
88}
89
/// A unit of work pulled from a virtio queue: the descriptor chain's payload
/// buffers plus the bookkeeping needed to complete the descriptor.
///
/// If dropped without an explicit `complete` call, the descriptor is
/// completed with zero bytes written (see the `Drop` impl).
pub struct VirtioQueueCallbackWork {
    // Shared handler used to return the descriptor to the used ring.
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    work: QueueWork,
    // The descriptor chain's buffers, in chain order.
    pub payload: Vec<VirtioQueuePayload>,
    // Set once `complete` has run, so `Drop` does not complete twice.
    completed: bool,
}
96
97impl VirtioQueueCallbackWork {
98 pub fn new(
99 mut work: QueueWork,
100 used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
101 ) -> Self {
102 let used_queue_handler = used_queue_handler.clone();
103 let payload = std::mem::take(&mut work.payload);
104 used_queue_handler.lock().add_outstanding_descriptor();
105 Self {
106 work,
107 payload,
108 used_queue_handler,
109 completed: false,
110 }
111 }
112
113 pub fn complete(&mut self, bytes_written: u32) {
114 assert!(!self.completed);
115 self.used_queue_handler
116 .lock()
117 .complete_descriptor(&self.work, bytes_written);
118 self.completed = true;
119 }
120
121 pub fn descriptor_index(&self) -> u16 {
122 self.work.descriptor_index()
123 }
124
125 pub fn get_payload_length(&self, writeable: bool) -> u64 {
127 self.payload
128 .iter()
129 .filter(|x| x.writeable == writeable)
130 .fold(0, |acc, x| acc + x.length as u64)
131 }
132
133 pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
135 let mut remaining = target;
136 let mut read_bytes: usize = 0;
137 for payload in &self.payload {
138 if payload.writeable {
139 continue;
140 }
141
142 let size = std::cmp::min(payload.length as usize, remaining.len());
143 let (current, next) = remaining.split_at_mut(size);
144 mem.read_at(payload.address, current)?;
145 read_bytes += size;
146 if next.is_empty() {
147 break;
148 }
149
150 remaining = next;
151 }
152
153 Ok(read_bytes)
154 }
155
156 pub fn write_at_offset(
158 &self,
159 offset: u64,
160 mem: &GuestMemory,
161 source: &[u8],
162 ) -> Result<(), VirtioWriteError> {
163 let mut skip_bytes = offset;
164 let mut remaining = source;
165 for payload in &self.payload {
166 if !payload.writeable {
167 continue;
168 }
169
170 let payload_length = payload.length as u64;
171 if skip_bytes >= payload_length {
172 skip_bytes -= payload_length;
173 continue;
174 }
175
176 let size = std::cmp::min(
177 payload_length as usize - skip_bytes as usize,
178 remaining.len(),
179 );
180 let (current, next) = remaining.split_at(size);
181 mem.write_at(payload.address + skip_bytes, current)?;
182 remaining = next;
183 if remaining.is_empty() {
184 break;
185 }
186 skip_bytes = 0;
187 }
188
189 if !remaining.is_empty() {
190 return Err(VirtioWriteError::NotAllWritten(source.len()));
191 }
192
193 Ok(())
194 }
195
196 pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
197 self.write_at_offset(0, mem, source)
198 }
199}
200
/// Errors returned when writing into a work item's guest buffers.
#[derive(Debug, Error)]
pub enum VirtioWriteError {
    /// The underlying guest memory access failed.
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    /// The writeable buffers were too small to hold all of the source data.
    #[error("{0:#x} bytes not written")]
    NotAllWritten(usize),
}
208
209impl Drop for VirtioQueueCallbackWork {
210 fn drop(&mut self) {
211 if !self.completed {
212 self.complete(0);
213 }
214 }
215}
216
/// An active virtio queue, combining the work-retrieval side, the shared
/// completion handler, and the event used to signal new available work.
#[derive(Debug, Inspect)]
pub struct VirtioQueue {
    #[inspect(flatten)]
    core: QueueCoreGetWork,
    // Completion handler; also cloned into each outstanding work item.
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    // Event waited on in `poll_kick` when the queue is currently empty.
    #[inspect(skip)]
    queue_event: PolledWait<Event>,
}
225
226impl VirtioQueue {
227 pub fn new(
228 features: VirtioDeviceFeatures,
229 params: QueueParams,
230 mem: GuestMemory,
231 notify: Interrupt,
232 queue_event: PolledWait<Event>,
233 ) -> Result<Self, QueueError> {
234 let (get_work, complete_work) = new_queue(features, mem, params)?;
235 let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
236 complete_work,
237 notify,
238 )));
239 Ok(Self {
240 core: get_work,
241 used_handler,
242 queue_event,
243 })
244 }
245
246 pub fn poll_kick(&mut self, cx: &mut Context<'_>) -> Poll<()> {
248 ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
249 Poll::Ready(())
250 }
251
252 pub fn try_next(&mut self) -> Result<Option<VirtioQueueCallbackWork>, Error> {
260 Ok(self
261 .core
262 .try_next_work()
263 .map_err(Error::other)?
264 .map(|work| VirtioQueueCallbackWork::new(work, &self.used_handler)))
265 }
266
267 fn poll_next_buffer(
268 &mut self,
269 cx: &mut Context<'_>,
270 ) -> Poll<Result<Option<VirtioQueueCallbackWork>, Error>> {
271 loop {
272 if let Some(work) = self.try_next()? {
273 return Ok(Some(work)).into();
274 }
275 ready!(self.poll_kick(cx));
276 }
277 }
278}
279
280impl Drop for VirtioQueue {
281 fn drop(&mut self) {
282 if Arc::get_mut(&mut self.used_handler).is_none() {
283 tracing::error!("Virtio queue dropped with outstanding work pending")
284 }
285 }
286}
287
288impl Stream for VirtioQueue {
289 type Item = Result<VirtioQueueCallbackWork, Error>;
290
291 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
292 ready!(self.get_mut().poll_next_buffer(cx))
293 .transpose()
294 .into()
295 }
296}
297
/// Lifecycle of a queue inside its worker task.
enum VirtioQueueStateInner {
    /// Resources captured; the queue itself has not been constructed yet.
    Initializing {
        mem: GuestMemory,
        features: VirtioDeviceFeatures,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        exit_event: event_listener::EventListener,
    },
    /// Transient placeholder used while moving resources out of
    /// `Initializing`; never observed across an await point.
    InitializationInProgress,
    /// The queue is constructed and serving work.
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}
313
/// Per-task state driven by a [`VirtioQueueWorker`].
pub struct VirtioQueueState {
    inner: VirtioQueueStateInner,
}
317
/// Drives a single virtio queue, dispatching each work item to a
/// device-specific [`VirtioQueueWorkerContext`].
pub struct VirtioQueueWorker {
    // Driver used to construct the polled queue event.
    driver: Box<dyn Driver>,
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}
322
323impl VirtioQueueWorker {
324 pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
325 Self {
326 driver: Box::new(driver),
327 context,
328 }
329 }
330
331 pub fn into_running_task(
332 self,
333 name: impl Into<String>,
334 mem: GuestMemory,
335 features: VirtioDeviceFeatures,
336 queue_resources: QueueResources,
337 exit_event: event_listener::EventListener,
338 ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
339 let name = name.into();
340 let (_, driver) = DefaultPool::spawn_on_thread(&name);
341
342 let mut task = TaskControl::new(self);
343 task.insert(
344 driver,
345 name,
346 VirtioQueueState {
347 inner: VirtioQueueStateInner::Initializing {
348 mem,
349 features,
350 params: queue_resources.params,
351 event: queue_resources.event,
352 notify: queue_resources.notify,
353 exit_event,
354 },
355 },
356 );
357 task.start();
358 task
359 }
360
361 async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
362 match &mut state.inner {
363 VirtioQueueStateInner::InitializationInProgress => unreachable!(),
364 VirtioQueueStateInner::Initializing { .. } => {
365 let VirtioQueueStateInner::Initializing {
366 mem,
367 features,
368 params,
369 event,
370 notify,
371 exit_event,
372 } = std::mem::replace(
373 &mut state.inner,
374 VirtioQueueStateInner::InitializationInProgress,
375 )
376 else {
377 unreachable!()
378 };
379 let queue_event = PolledWait::new(&self.driver, event).unwrap();
380 let queue = VirtioQueue::new(features, params, mem, notify, queue_event);
381 if let Err(err) = queue {
382 tracing::error!(
383 err = &err as &dyn std::error::Error,
384 "Failed to start queue"
385 );
386 false
387 } else {
388 state.inner = VirtioQueueStateInner::Running {
389 queue: queue.unwrap(),
390 exit_event,
391 };
392 true
393 }
394 }
395 VirtioQueueStateInner::Running { queue, exit_event } => {
396 let mut exit = exit_event.fuse();
397 let mut queue_ready = queue.next().fuse();
398 let work = futures::select_biased! {
399 _ = exit => return false,
400 work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
401 };
402 self.context.process_work(work).await
403 }
404 }
405 }
406}
407
408impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
409 async fn run(
410 &mut self,
411 stop: &mut StopTask<'_>,
412 state: &mut VirtioQueueState,
413 ) -> Result<(), task_control::Cancelled> {
414 while stop.until_stopped(self.run_queue(state)).await? {}
415 Ok(())
416 }
417}
418
/// Tracks doorbell registrations for a virtio device's queue notification
/// addresses.
pub(crate) struct VirtioDoorbells {
    // Registration interface, if the backend supports doorbells.
    registration: Option<Arc<dyn DoorbellRegistration>>,
    // Opaque handles keeping the registered doorbells alive.
    doorbells: Vec<Box<dyn Send + Sync>>,
}
423
424impl VirtioDoorbells {
425 pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
426 Self {
427 registration,
428 doorbells: Vec::new(),
429 }
430 }
431
432 pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
433 if let Some(registration) = &mut self.registration {
434 let doorbell = registration.register_doorbell(address, value, length, event);
435 if let Ok(doorbell) = doorbell {
436 self.doorbells.push(doorbell);
437 }
438 }
439 }
440
441 pub fn clear(&mut self) {
442 self.doorbells.clear();
443 }
444}
445
/// Description of a device's shared memory region.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    /// Shared memory region id.
    pub id: u8,
    /// Region size in bytes.
    pub size: u64,
}
451
/// Static properties a virtio device reports to its transport.
#[derive(Clone, Debug, Default)]
pub struct DeviceTraits {
    /// Virtio device id.
    pub device_id: u16,
    /// Feature bits the device offers.
    pub device_features: VirtioDeviceFeatures,
    /// Maximum number of queues the device supports.
    pub max_queues: u16,
    /// Length in bytes of the device-specific register space.
    pub device_register_length: u32,
    /// The device's shared memory requirements.
    pub shared_memory: DeviceTraitsSharedMemory,
}
460
/// Trait implemented by virtio device backends.
pub trait VirtioDevice: inspect::InspectMut + Send {
    /// Returns the device's static properties.
    fn traits(&self) -> DeviceTraits;
    /// Reads a 32-bit device-specific register at `offset`.
    fn read_registers_u32(&self, offset: u16) -> u32;
    /// Writes a 32-bit device-specific register at `offset`.
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Starts the device with the negotiated features and queue resources.
    fn enable(&mut self, resources: Resources) -> anyhow::Result<()>;
    /// Polls the device's teardown to completion.
    fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()>;
}
483
/// Per-queue resources handed to a device when it is enabled.
pub struct QueueResources {
    /// Guest-configured queue parameters.
    pub params: QueueParams,
    /// Interrupt used to notify the guest of used descriptors.
    pub notify: Interrupt,
    /// Event signaled when the guest kicks the queue.
    pub event: Event,
}
489
/// Resources a virtio device receives on enable.
pub struct Resources {
    /// Features negotiated with the driver.
    pub features: VirtioDeviceFeatures,
    /// One entry per enabled queue.
    pub queues: Vec<QueueResources>,
    /// Mapped shared memory region, if one is configured.
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size in bytes of the shared memory region.
    pub shared_memory_size: u64,
}