virtio/
common.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::queue::QueueCoreCompleteWork;
5use crate::queue::QueueCoreGetWork;
6use crate::queue::QueueError;
7use crate::queue::QueueParams;
8use crate::queue::QueueWork;
9use crate::queue::VirtioQueuePayload;
10use crate::queue::new_queue;
11use crate::spec::VirtioDeviceFeatures;
12use async_trait::async_trait;
13use futures::FutureExt;
14use futures::Stream;
15use futures::StreamExt;
16use guestmem::DoorbellRegistration;
17use guestmem::GuestMemory;
18use guestmem::GuestMemoryError;
19use guestmem::MappedMemoryRegion;
20use inspect::Inspect;
21use pal_async::DefaultPool;
22use pal_async::driver::Driver;
23use pal_async::wait::PolledWait;
24use pal_event::Event;
25use parking_lot::Mutex;
26use std::io::Error;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::Context;
30use std::task::Poll;
31use std::task::ready;
32use task_control::AsyncRun;
33use task_control::StopTask;
34use task_control::TaskControl;
35use thiserror::Error;
36use vmcore::interrupt::Interrupt;
37
/// Callback interface for a queue worker: receives each work item pulled from
/// the queue (or the error produced while retrieving one).
#[async_trait]
pub trait VirtioQueueWorkerContext {
    /// Processes one work item. Returning `false` stops the worker loop.
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}
42
/// Shared completion-side state for a virtio queue: returns finished
/// descriptors to the used ring and tracks how many are still outstanding.
#[derive(Debug, Inspect)]
pub struct VirtioQueueUsedHandler {
    /// Completion half of the queue core, used to publish used-ring entries.
    #[inspect(skip)]
    core: QueueCoreCompleteWork,
    /// In-flight descriptor count plus an event notified when it reaches zero.
    #[inspect(with = "|x| x.lock().0")]
    outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
    /// Interrupt delivered to the guest when a completion requires notification.
    #[inspect(skip)]
    notify_guest: Interrupt,
}
52
53impl VirtioQueueUsedHandler {
54    fn new(core: QueueCoreCompleteWork, notify_guest: Interrupt) -> Self {
55        Self {
56            core,
57            outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
58            notify_guest,
59        }
60    }
61
62    pub fn add_outstanding_descriptor(&self) {
63        let (count, _) = &mut *self.outstanding_desc_count.lock();
64        *count += 1;
65    }
66
67    pub fn complete_descriptor(&mut self, work: &QueueWork, bytes_written: u32) {
68        match self.core.complete_descriptor(work, bytes_written) {
69            Ok(true) => {
70                self.notify_guest.deliver();
71            }
72            Ok(false) => {}
73            Err(err) => {
74                tracelimit::error_ratelimited!(
75                    error = &err as &dyn std::error::Error,
76                    "failed to complete descriptor"
77                );
78            }
79        }
80        {
81            let (count, event) = &mut *self.outstanding_desc_count.lock();
82            *count -= 1;
83            if *count == 0 {
84                event.notify(usize::MAX);
85            }
86        }
87    }
88}
89
/// A unit of work taken from a virtio queue, owning its payload buffers.
///
/// If the caller never calls `complete`, the descriptor is completed with
/// zero bytes written when this value is dropped.
pub struct VirtioQueueCallbackWork {
    /// Shared handler used to return this descriptor to the used ring.
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    /// Underlying queue work item (its payload has been moved out below).
    work: QueueWork,
    /// Readable and writeable buffers associated with this work item.
    pub payload: Vec<VirtioQueuePayload>,
    /// Guards against double completion.
    completed: bool,
}
96
97impl VirtioQueueCallbackWork {
98    pub fn new(
99        mut work: QueueWork,
100        used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
101    ) -> Self {
102        let used_queue_handler = used_queue_handler.clone();
103        let payload = std::mem::take(&mut work.payload);
104        used_queue_handler.lock().add_outstanding_descriptor();
105        Self {
106            work,
107            payload,
108            used_queue_handler,
109            completed: false,
110        }
111    }
112
113    pub fn complete(&mut self, bytes_written: u32) {
114        assert!(!self.completed);
115        self.used_queue_handler
116            .lock()
117            .complete_descriptor(&self.work, bytes_written);
118        self.completed = true;
119    }
120
121    pub fn descriptor_index(&self) -> u16 {
122        self.work.descriptor_index()
123    }
124
125    // Determine the total size of all readable or all writeable payload buffers.
126    pub fn get_payload_length(&self, writeable: bool) -> u64 {
127        self.payload
128            .iter()
129            .filter(|x| x.writeable == writeable)
130            .fold(0, |acc, x| acc + x.length as u64)
131    }
132
133    // Read all payload into a buffer.
134    pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
135        let mut remaining = target;
136        let mut read_bytes: usize = 0;
137        for payload in &self.payload {
138            if payload.writeable {
139                continue;
140            }
141
142            let size = std::cmp::min(payload.length as usize, remaining.len());
143            let (current, next) = remaining.split_at_mut(size);
144            mem.read_at(payload.address, current)?;
145            read_bytes += size;
146            if next.is_empty() {
147                break;
148            }
149
150            remaining = next;
151        }
152
153        Ok(read_bytes)
154    }
155
156    // Write the specified buffer to the payload buffers.
157    pub fn write_at_offset(
158        &self,
159        offset: u64,
160        mem: &GuestMemory,
161        source: &[u8],
162    ) -> Result<(), VirtioWriteError> {
163        let mut skip_bytes = offset;
164        let mut remaining = source;
165        for payload in &self.payload {
166            if !payload.writeable {
167                continue;
168            }
169
170            let payload_length = payload.length as u64;
171            if skip_bytes >= payload_length {
172                skip_bytes -= payload_length;
173                continue;
174            }
175
176            let size = std::cmp::min(
177                payload_length as usize - skip_bytes as usize,
178                remaining.len(),
179            );
180            let (current, next) = remaining.split_at(size);
181            mem.write_at(payload.address + skip_bytes, current)?;
182            remaining = next;
183            if remaining.is_empty() {
184                break;
185            }
186            skip_bytes = 0;
187        }
188
189        if !remaining.is_empty() {
190            return Err(VirtioWriteError::NotAllWritten(source.len()));
191        }
192
193        Ok(())
194    }
195
196    pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
197        self.write_at_offset(0, mem, source)
198    }
199}
200
/// Errors returned when writing into a work item's writeable payload buffers.
#[derive(Debug, Error)]
pub enum VirtioWriteError {
    /// A guest memory access failed.
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    /// Not all of the source buffer could be written to the payload buffers.
    #[error("{0:#x} bytes not written")]
    NotAllWritten(usize),
}
208
impl Drop for VirtioQueueCallbackWork {
    fn drop(&mut self) {
        // Ensure the descriptor is always returned to the used ring, even if
        // the caller never explicitly completed it (reported as zero bytes).
        if !self.completed {
            self.complete(0);
        }
    }
}
216
/// A virtio queue that yields work items as a [`Stream`] and tracks their
/// completion via a shared used-ring handler.
#[derive(Debug, Inspect)]
pub struct VirtioQueue {
    /// Work-retrieval half of the queue core.
    #[inspect(flatten)]
    core: QueueCoreGetWork,
    /// Shared with each outstanding [`VirtioQueueCallbackWork`].
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    /// Signaled when the guest kicks the queue.
    #[inspect(skip)]
    queue_event: PolledWait<Event>,
}
225
226impl VirtioQueue {
227    pub fn new(
228        features: VirtioDeviceFeatures,
229        params: QueueParams,
230        mem: GuestMemory,
231        notify: Interrupt,
232        queue_event: PolledWait<Event>,
233    ) -> Result<Self, QueueError> {
234        let (get_work, complete_work) = new_queue(features, mem, params)?;
235        let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
236            complete_work,
237            notify,
238        )));
239        Ok(Self {
240            core: get_work,
241            used_handler,
242            queue_event,
243        })
244    }
245
246    /// Polls until the queue is kicked by the guest, indicating new work may be available.
247    pub fn poll_kick(&mut self, cx: &mut Context<'_>) -> Poll<()> {
248        ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
249        Poll::Ready(())
250    }
251
252    /// Try to get the next work item from the queue. Returns `Ok(None)` if no
253    /// work is currently available, or an error if there was an issue accessing
254    /// the queue.
255    ///
256    /// If `None` is returned, then the queue will be armed so that the guest
257    /// will kick it when new work is available; the caller can use
258    /// [`poll_kick`](Self::poll_kick) to wait for this.
259    pub fn try_next(&mut self) -> Result<Option<VirtioQueueCallbackWork>, Error> {
260        Ok(self
261            .core
262            .try_next_work()
263            .map_err(Error::other)?
264            .map(|work| VirtioQueueCallbackWork::new(work, &self.used_handler)))
265    }
266
267    fn poll_next_buffer(
268        &mut self,
269        cx: &mut Context<'_>,
270    ) -> Poll<Result<Option<VirtioQueueCallbackWork>, Error>> {
271        loop {
272            if let Some(work) = self.try_next()? {
273                return Ok(Some(work)).into();
274            }
275            ready!(self.poll_kick(cx));
276        }
277    }
278}
279
impl Drop for VirtioQueue {
    fn drop(&mut self) {
        // `Arc::get_mut` fails while any other reference to the used handler
        // is alive, i.e. some VirtioQueueCallbackWork is still outstanding.
        if Arc::get_mut(&mut self.used_handler).is_none() {
            tracing::error!("Virtio queue dropped with outstanding work pending")
        }
    }
}
287
288impl Stream for VirtioQueue {
289    type Item = Result<VirtioQueueCallbackWork, Error>;
290
291    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
292        ready!(self.get_mut().poll_next_buffer(cx))
293            .transpose()
294            .into()
295    }
296}
297
/// Lifecycle state of the queue driven by a [`VirtioQueueWorker`].
enum VirtioQueueStateInner {
    /// Resources held until the worker first runs and constructs the queue.
    Initializing {
        mem: GuestMemory,
        features: VirtioDeviceFeatures,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        exit_event: event_listener::EventListener,
    },
    /// Transient placeholder installed while the initializing resources are
    /// being moved out; never observed across an await point.
    InitializationInProgress,
    /// The queue is constructed and processing work.
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}
313
/// Per-task state for a [`VirtioQueueWorker`].
pub struct VirtioQueueState {
    inner: VirtioQueueStateInner,
}
317
/// Worker that drives a single virtio queue, forwarding each work item to a
/// [`VirtioQueueWorkerContext`].
pub struct VirtioQueueWorker {
    /// Driver used to create the polled queue-event waiter.
    driver: Box<dyn Driver>,
    /// Device-specific work processor.
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}
322
323impl VirtioQueueWorker {
324    pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
325        Self {
326            driver: Box::new(driver),
327            context,
328        }
329    }
330
331    pub fn into_running_task(
332        self,
333        name: impl Into<String>,
334        mem: GuestMemory,
335        features: VirtioDeviceFeatures,
336        queue_resources: QueueResources,
337        exit_event: event_listener::EventListener,
338    ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
339        let name = name.into();
340        let (_, driver) = DefaultPool::spawn_on_thread(&name);
341
342        let mut task = TaskControl::new(self);
343        task.insert(
344            driver,
345            name,
346            VirtioQueueState {
347                inner: VirtioQueueStateInner::Initializing {
348                    mem,
349                    features,
350                    params: queue_resources.params,
351                    event: queue_resources.event,
352                    notify: queue_resources.notify,
353                    exit_event,
354                },
355            },
356        );
357        task.start();
358        task
359    }
360
361    async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
362        match &mut state.inner {
363            VirtioQueueStateInner::InitializationInProgress => unreachable!(),
364            VirtioQueueStateInner::Initializing { .. } => {
365                let VirtioQueueStateInner::Initializing {
366                    mem,
367                    features,
368                    params,
369                    event,
370                    notify,
371                    exit_event,
372                } = std::mem::replace(
373                    &mut state.inner,
374                    VirtioQueueStateInner::InitializationInProgress,
375                )
376                else {
377                    unreachable!()
378                };
379                let queue_event = PolledWait::new(&self.driver, event).unwrap();
380                let queue = VirtioQueue::new(features, params, mem, notify, queue_event);
381                if let Err(err) = queue {
382                    tracing::error!(
383                        err = &err as &dyn std::error::Error,
384                        "Failed to start queue"
385                    );
386                    false
387                } else {
388                    state.inner = VirtioQueueStateInner::Running {
389                        queue: queue.unwrap(),
390                        exit_event,
391                    };
392                    true
393                }
394            }
395            VirtioQueueStateInner::Running { queue, exit_event } => {
396                let mut exit = exit_event.fuse();
397                let mut queue_ready = queue.next().fuse();
398                let work = futures::select_biased! {
399                    _ = exit => return false,
400                    work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
401                };
402                self.context.process_work(work).await
403            }
404        }
405    }
406}
407
impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
    /// Repeatedly processes queue work until `run_queue` requests a stop
    /// (returns `false`) or the task is cancelled by its controller.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut VirtioQueueState,
    ) -> Result<(), task_control::Cancelled> {
        while stop.until_stopped(self.run_queue(state)).await? {}
        Ok(())
    }
}
418
/// Tracks doorbell registrations for queue notification events.
pub(crate) struct VirtioDoorbells {
    /// Registration interface; `None` when doorbells are unsupported.
    registration: Option<Arc<dyn DoorbellRegistration>>,
    /// Opaque doorbell handles; dropping a handle presumably unregisters the
    /// doorbell — see `DoorbellRegistration` for the actual contract.
    doorbells: Vec<Box<dyn Send + Sync>>,
}
423
424impl VirtioDoorbells {
425    pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
426        Self {
427            registration,
428            doorbells: Vec::new(),
429        }
430    }
431
432    pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
433        if let Some(registration) = &mut self.registration {
434            let doorbell = registration.register_doorbell(address, value, length, event);
435            if let Ok(doorbell) = doorbell {
436                self.doorbells.push(doorbell);
437            }
438        }
439    }
440
441    pub fn clear(&mut self) {
442        self.doorbells.clear();
443    }
444}
445
/// Shared-memory region description advertised by a device.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    /// Region identifier.
    pub id: u8,
    /// Region size in bytes; zero means no shared memory.
    pub size: u64,
}
451
/// Static properties a device reports to its transport.
#[derive(Clone, Debug, Default)]
pub struct DeviceTraits {
    /// Virtio device type identifier.
    pub device_id: u16,
    /// Feature bits the device offers for negotiation.
    pub device_features: VirtioDeviceFeatures,
    /// Maximum number of virtqueues the device supports.
    pub max_queues: u16,
    /// Length in bytes of the device-specific register region.
    pub device_register_length: u32,
    /// Shared-memory region description, if any.
    pub shared_memory: DeviceTraitsSharedMemory,
}
460
/// A virtio device implementation, driven by a transport.
pub trait VirtioDevice: inspect::InspectMut + Send {
    /// Returns the device's static traits (type, features, queue limits).
    fn traits(&self) -> DeviceTraits;
    /// Reads a 32-bit device register at `offset`.
    fn read_registers_u32(&self, offset: u16) -> u32;
    /// Writes a 32-bit device register at `offset`.
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Enable the device with the given resources.
    ///
    /// Called when the guest sets `DRIVER_OK`. On success, the device should
    /// start processing queues and the transport will reflect `DRIVER_OK` in
    /// the device status. On failure, the transport will log the error and
    /// leave `DRIVER_OK` unset, so the device remains inert and the guest
    /// will observe failures through IO timeouts.
    fn enable(&mut self, resources: Resources) -> anyhow::Result<()>;
    /// Poll the device to complete a disable/reset operation.
    ///
    /// This is called when the guest writes status=0 (device reset). The device
    /// should stop workers and drain any in-flight IO. Returns `Poll::Ready(())`
    /// when the disable is complete, or `Poll::Pending` if more work is needed.
    ///
    /// Devices that don't need async cleanup can return `Poll::Ready(())`
    /// immediately.
    fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()>;
}
483
/// Per-queue resources handed to a device on enable.
pub struct QueueResources {
    /// Ring addresses, size, and enabled state for the queue.
    pub params: QueueParams,
    /// Interrupt used to notify the guest of used-ring updates.
    pub notify: Interrupt,
    /// Event signaled when the guest kicks the queue.
    pub event: Event,
}
489
/// Resources passed to [`VirtioDevice::enable`].
pub struct Resources {
    /// Features negotiated with the guest driver.
    pub features: VirtioDeviceFeatures,
    /// One entry per virtqueue configured by the guest.
    pub queues: Vec<QueueResources>,
    /// Mapped shared-memory region, if the transport provides one.
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size in bytes of the shared-memory region.
    pub shared_memory_size: u64,
}