// virtio/common.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::queue::QueueCore;
5use crate::queue::QueueError;
6use crate::queue::QueueParams;
7use crate::queue::VirtioQueuePayload;
8use async_trait::async_trait;
9use futures::FutureExt;
10use futures::Stream;
11use futures::StreamExt;
12use guestmem::DoorbellRegistration;
13use guestmem::GuestMemory;
14use guestmem::GuestMemoryError;
15use guestmem::MappedMemoryRegion;
16use pal_async::DefaultPool;
17use pal_async::driver::Driver;
18use pal_async::task::Spawn;
19use pal_async::wait::PolledWait;
20use pal_event::Event;
21use parking_lot::Mutex;
22use std::io::Error;
23use std::pin::Pin;
24use std::sync::Arc;
25use std::task::Context;
26use std::task::Poll;
27use std::task::ready;
28use task_control::AsyncRun;
29use task_control::StopTask;
30use task_control::TaskControl;
31use thiserror::Error;
32use vmcore::interrupt::Interrupt;
33use vmcore::vm_task::VmTaskDriver;
34use vmcore::vm_task::VmTaskDriverSource;
35
/// Consumer of virtio queue work items, driven by a [`VirtioQueueWorker`].
#[async_trait]
pub trait VirtioQueueWorkerContext {
    /// Processes one unit of queue work (or the error that occurred while
    /// retrieving it). Returns `false` to stop the worker loop, `true` to
    /// keep processing.
    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool;
}
40
/// Writes completed descriptors back to the used ring and tracks how many
/// descriptors are still outstanding (handed out but not yet completed).
#[derive(Debug)]
pub struct VirtioQueueUsedHandler {
    core: QueueCore,
    // Next used-ring index to write; advanced by `complete_descriptor`.
    last_used_index: u16,
    // Count of outstanding descriptors plus an event to wake waiters when
    // the count drops to zero. Shared so work items can decrement it.
    outstanding_desc_count: Arc<Mutex<(u16, event_listener::Event)>>,
    // Interrupt delivered to the guest when a completion requires notice.
    notify_guest: Interrupt,
}
48
impl VirtioQueueUsedHandler {
    /// Creates a used-ring handler over `core`, delivering `notify_guest`
    /// when a completion requires a guest interrupt.
    fn new(core: QueueCore, notify_guest: Interrupt) -> Self {
        Self {
            core,
            last_used_index: 0,
            outstanding_desc_count: Arc::new(Mutex::new((0, event_listener::Event::new()))),
            notify_guest,
        }
    }

    /// Records that a descriptor has been handed out and not yet completed.
    pub fn add_outstanding_descriptor(&self) {
        let (count, _) = &mut *self.outstanding_desc_count.lock();
        *count += 1;
    }

    /// Returns a listener that resolves once all outstanding descriptors
    /// have been completed.
    pub fn await_outstanding_descriptors(&self) -> event_listener::EventListener {
        let (count, event) = &*self.outstanding_desc_count.lock();
        // Register the listener before inspecting the count (both under the
        // lock) so a concurrent completion cannot be missed; if nothing is
        // outstanding, notify immediately so the listener resolves at once.
        let listener = event.listen();
        if *count == 0 {
            event.notify(usize::MAX);
        }
        listener
    }

    /// Completes `descriptor_index` on the used ring with `bytes_written`,
    /// delivering a guest interrupt when required, and wakes waiters if this
    /// was the last outstanding descriptor.
    pub fn complete_descriptor(&mut self, descriptor_index: u16, bytes_written: u32) {
        match self.core.complete_descriptor(
            &mut self.last_used_index,
            descriptor_index,
            bytes_written,
        ) {
            Ok(true) => {
                // The queue indicated the guest should be notified.
                self.notify_guest.deliver();
            }
            Ok(false) => {}
            Err(err) => {
                tracelimit::error_ratelimited!(
                    error = &err as &dyn std::error::Error,
                    "failed to complete descriptor"
                );
            }
        }
        {
            // Decrement even when the ring update failed so waiters on
            // `await_outstanding_descriptors` are not stuck forever.
            let (count, event) = &mut *self.outstanding_desc_count.lock();
            *count -= 1;
            if *count == 0 {
                event.notify(usize::MAX);
            }
        }
    }
}
99
/// A single unit of queue work: one available descriptor chain, with buffers
/// described by `payload`. Completing it (explicitly or on drop) writes the
/// descriptor to the used ring.
pub struct VirtioQueueCallbackWork {
    /// The buffers making up the descriptor chain, in chain order.
    pub payload: Vec<VirtioQueuePayload>,
    // Shared handler used to report completion back to the used ring.
    used_queue_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    // Head index of the descriptor chain this work item represents.
    descriptor_index: u16,
    // Set once `complete` has been called; checked by `Drop`.
    completed: bool,
}
106
impl VirtioQueueCallbackWork {
    /// Wraps the descriptor chain `payload` for `descriptor_index`,
    /// registering it as outstanding with the used-ring handler.
    pub fn new(
        payload: Vec<VirtioQueuePayload>,
        used_queue_handler: &Arc<Mutex<VirtioQueueUsedHandler>>,
        descriptor_index: u16,
    ) -> Self {
        let used_queue_handler = used_queue_handler.clone();
        used_queue_handler.lock().add_outstanding_descriptor();
        Self {
            payload,
            used_queue_handler,
            descriptor_index,
            completed: false,
        }
    }

    /// Completes the descriptor, reporting `bytes_written` to the guest.
    ///
    /// Panics if called more than once. If never called, `Drop` completes
    /// the descriptor with zero bytes written.
    pub fn complete(&mut self, bytes_written: u32) {
        assert!(!self.completed);
        self.used_queue_handler
            .lock()
            .complete_descriptor(self.descriptor_index, bytes_written);
        self.completed = true;
    }

    /// Returns the head index of the descriptor chain.
    pub fn descriptor_index(&self) -> u16 {
        self.descriptor_index
    }

    // Determine the total size of all readable or all writeable payload buffers.
    pub fn get_payload_length(&self, writeable: bool) -> u64 {
        self.payload
            .iter()
            .filter(|x| x.writeable == writeable)
            .fold(0, |acc, x| acc + x.length as u64)
    }

    // Read all payload into a buffer.
    // Copies from the readable (device-readable) buffers in chain order
    // until `target` is full, returning the number of bytes copied.
    pub fn read(&self, mem: &GuestMemory, target: &mut [u8]) -> Result<usize, GuestMemoryError> {
        let mut remaining = target;
        let mut read_bytes: usize = 0;
        for payload in &self.payload {
            if payload.writeable {
                continue;
            }

            let size = std::cmp::min(payload.length as usize, remaining.len());
            let (current, next) = remaining.split_at_mut(size);
            mem.read_at(payload.address, current)?;
            read_bytes += size;
            if next.is_empty() {
                break;
            }

            remaining = next;
        }

        Ok(read_bytes)
    }

    // Write the specified buffer to the payload buffers.
    // Skips `offset` bytes into the writeable buffers, then copies all of
    // `source`; fails if the writeable space runs out before `source` does.
    pub fn write_at_offset(
        &self,
        offset: u64,
        mem: &GuestMemory,
        source: &[u8],
    ) -> Result<(), VirtioWriteError> {
        let mut skip_bytes = offset;
        let mut remaining = source;
        for payload in &self.payload {
            if !payload.writeable {
                continue;
            }

            let payload_length = payload.length as u64;
            if skip_bytes >= payload_length {
                // Entire buffer falls before the requested offset.
                skip_bytes -= payload_length;
                continue;
            }

            let size = std::cmp::min(
                payload_length as usize - skip_bytes as usize,
                remaining.len(),
            );
            let (current, next) = remaining.split_at(size);
            mem.write_at(payload.address + skip_bytes, current)?;
            remaining = next;
            if remaining.is_empty() {
                break;
            }
            // The offset only applies within the first buffer written to.
            skip_bytes = 0;
        }

        if !remaining.is_empty() {
            return Err(VirtioWriteError::NotAllWritten(source.len()));
        }

        Ok(())
    }

    /// Writes `source` starting at the beginning of the writeable buffers.
    pub fn write(&self, mem: &GuestMemory, source: &[u8]) -> Result<(), VirtioWriteError> {
        self.write_at_offset(0, mem, source)
    }
}
210
/// Errors returned when writing into a work item's payload buffers.
#[derive(Debug, Error)]
pub enum VirtioWriteError {
    /// The guest memory access itself failed.
    #[error(transparent)]
    Memory(#[from] GuestMemoryError),
    /// The writeable buffers were too small; carries the source length.
    #[error("{0:#x} bytes not written")]
    NotAllWritten(usize),
}
218
impl Drop for VirtioQueueCallbackWork {
    fn drop(&mut self) {
        // Ensure the descriptor is always returned to the guest: an
        // uncompleted work item is completed with zero bytes written.
        if !self.completed {
            self.complete(0);
        }
    }
}
226
/// A virtio queue that yields available descriptor chains as a [`Stream`] of
/// [`VirtioQueueCallbackWork`] items.
#[derive(Debug)]
pub struct VirtioQueue {
    core: QueueCore,
    // Next available-ring index to consume.
    last_avail_index: u16,
    // Shared used-ring handler, cloned into each work item for completion.
    used_handler: Arc<Mutex<VirtioQueueUsedHandler>>,
    // Guest-signaled event indicating new available descriptors.
    queue_event: PolledWait<Event>,
}
234
impl VirtioQueue {
    /// Creates a queue over guest memory `mem` with negotiated `features`
    /// and ring configuration `params`. `notify` is delivered to the guest
    /// on used-ring updates that require an interrupt; `queue_event` is
    /// signaled by the guest when new descriptors become available.
    pub fn new(
        features: u64,
        params: QueueParams,
        mem: GuestMemory,
        notify: Interrupt,
        queue_event: PolledWait<Event>,
    ) -> Result<Self, QueueError> {
        let core = QueueCore::new(features, mem, params)?;
        let used_handler = Arc::new(Mutex::new(VirtioQueueUsedHandler::new(
            core.clone(),
            notify,
        )));
        Ok(Self {
            core,
            last_avail_index: 0,
            used_handler,
            queue_event,
        })
    }

    /// Waits until every handed-out work item has been completed.
    async fn wait_for_outstanding_descriptors(&self) {
        let wait_for_descriptors = self.used_handler.lock().await_outstanding_descriptors();
        wait_for_descriptors.await;
    }

    /// Polls for the next available descriptor chain, waiting on the guest's
    /// queue event whenever the available ring is empty.
    fn poll_next_buffer(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<VirtioQueueCallbackWork>, QueueError>> {
        let descriptor_index = loop {
            // Re-check the ring after each event wakeup in case descriptors
            // arrived between the check and the wait.
            if let Some(descriptor_index) = self.core.descriptor_index(self.last_avail_index)? {
                break descriptor_index;
            };
            ready!(self.queue_event.wait().poll_unpin(cx)).expect("waits on Event cannot fail");
        };
        // Collect the full descriptor chain for this buffer.
        let payload = self
            .core
            .reader(descriptor_index)
            .collect::<Result<Vec<_>, _>>()?;

        self.last_avail_index = self.last_avail_index.wrapping_add(1);
        Poll::Ready(Ok(Some(VirtioQueueCallbackWork::new(
            payload,
            &self.used_handler,
            descriptor_index,
        ))))
    }
}
284
impl Drop for VirtioQueue {
    fn drop(&mut self) {
        // Outstanding work items hold clones of `used_handler`; if the Arc
        // is still shared here, descriptors are still in flight and their
        // completions will be lost. Log rather than panic.
        if Arc::get_mut(&mut self.used_handler).is_none() {
            tracing::error!("Virtio queue dropped with outstanding work pending")
        }
    }
}
292
impl Stream for VirtioQueue {
    type Item = Result<VirtioQueueCallbackWork, Error>;

    /// Yields available descriptor chains as work items, converting queue
    /// errors to `std::io::Error` for stream consumers.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(r) = ready!(self.get_mut().poll_next_buffer(cx)).transpose() else {
            return Poll::Ready(None);
        };

        Poll::Ready(Some(r.map_err(Error::other)))
    }
}
304
/// Lifecycle of a queue worker's state machine.
enum VirtioQueueStateInner {
    /// Parameters captured until the queue is constructed on the worker
    /// thread during the first run.
    Initializing {
        mem: GuestMemory,
        features: u64,
        params: QueueParams,
        event: Event,
        notify: Interrupt,
        exit_event: event_listener::EventListener,
    },
    /// Placeholder left behind while the `Initializing` fields are being
    /// moved out; never observed across an await.
    InitializationInProgress,
    /// Queue is constructed and being polled for work.
    Running {
        queue: VirtioQueue,
        exit_event: event_listener::EventListener,
    },
}
320
/// Per-task state for a queue worker (see [`VirtioQueueStateInner`]).
pub struct VirtioQueueState {
    inner: VirtioQueueStateInner,
}
324
/// Worker that pulls work items from a [`VirtioQueue`] and feeds them to a
/// [`VirtioQueueWorkerContext`].
pub struct VirtioQueueWorker {
    // Driver used to register the queue event for polling.
    driver: Box<dyn Driver>,
    // Consumer of the queue's work items.
    context: Box<dyn VirtioQueueWorkerContext + Send>,
}
329
330impl VirtioQueueWorker {
331    pub fn new(driver: impl Driver, context: Box<dyn VirtioQueueWorkerContext + Send>) -> Self {
332        Self {
333            driver: Box::new(driver),
334            context,
335        }
336    }
337
338    pub fn into_running_task(
339        self,
340        name: impl Into<String>,
341        mem: GuestMemory,
342        features: u64,
343        queue_resources: QueueResources,
344        exit_event: event_listener::EventListener,
345    ) -> TaskControl<VirtioQueueWorker, VirtioQueueState> {
346        let name = name.into();
347        let (_, driver) = DefaultPool::spawn_on_thread(&name);
348
349        let mut task = TaskControl::new(self);
350        task.insert(
351            driver,
352            name,
353            VirtioQueueState {
354                inner: VirtioQueueStateInner::Initializing {
355                    mem,
356                    features,
357                    params: queue_resources.params,
358                    event: queue_resources.event,
359                    notify: queue_resources.notify,
360                    exit_event,
361                },
362            },
363        );
364        task.start();
365        task
366    }
367
368    async fn run_queue(&mut self, state: &mut VirtioQueueState) -> bool {
369        match &mut state.inner {
370            VirtioQueueStateInner::InitializationInProgress => unreachable!(),
371            VirtioQueueStateInner::Initializing { .. } => {
372                let VirtioQueueStateInner::Initializing {
373                    mem,
374                    features,
375                    params,
376                    event,
377                    notify,
378                    exit_event,
379                } = std::mem::replace(
380                    &mut state.inner,
381                    VirtioQueueStateInner::InitializationInProgress,
382                )
383                else {
384                    unreachable!()
385                };
386                let queue_event = PolledWait::new(&self.driver, event).unwrap();
387                let queue = VirtioQueue::new(features, params, mem, notify, queue_event);
388                if let Err(err) = queue {
389                    tracing::error!(
390                        err = &err as &dyn std::error::Error,
391                        "Failed to start queue"
392                    );
393                    false
394                } else {
395                    state.inner = VirtioQueueStateInner::Running {
396                        queue: queue.unwrap(),
397                        exit_event,
398                    };
399                    true
400                }
401            }
402            VirtioQueueStateInner::Running { queue, exit_event } => {
403                let mut exit = exit_event.fuse();
404                let mut queue_ready = queue.next().fuse();
405                let work = futures::select_biased! {
406                    _ = exit => return false,
407                    work = queue_ready => work.expect("queue will never complete").map_err(anyhow::Error::from),
408                };
409                self.context.process_work(work).await
410            }
411        }
412    }
413}
414
impl AsyncRun<VirtioQueueState> for VirtioQueueWorker {
    /// Drives the worker loop until `run_queue` returns `false` or the task
    /// is stopped via `stop`.
    async fn run(
        &mut self,
        stop: &mut StopTask<'_>,
        state: &mut VirtioQueueState,
    ) -> Result<(), task_control::Cancelled> {
        while stop.until_stopped(self.run_queue(state)).await? {}
        Ok(())
    }
}
425
/// Snapshot of the device's configuration when it transitions to running.
pub struct VirtioRunningState {
    /// Negotiated feature bits.
    pub features: u64,
    /// Per-queue enable flags, indexed by queue number.
    pub enabled_queues: Vec<bool>,
}
430
/// Device lifecycle state reported to [`LegacyVirtioDevice::state_change`].
pub enum VirtioState {
    Unknown,
    Running(VirtioRunningState),
    Stopped,
}
436
/// Tracks doorbell registrations for queue notification addresses.
pub(crate) struct VirtioDoorbells {
    // Optional provider; when `None`, `add` is a no-op.
    registration: Option<Arc<dyn DoorbellRegistration>>,
    // Registration handles; dropping one unregisters the doorbell.
    doorbells: Vec<Box<dyn Send + Sync>>,
}
441
442impl VirtioDoorbells {
443    pub fn new(registration: Option<Arc<dyn DoorbellRegistration>>) -> Self {
444        Self {
445            registration,
446            doorbells: Vec::new(),
447        }
448    }
449
450    pub fn add(&mut self, address: u64, value: Option<u64>, length: Option<u32>, event: &Event) {
451        if let Some(registration) = &mut self.registration {
452            let doorbell = registration.register_doorbell(address, value, length, event);
453            if let Ok(doorbell) = doorbell {
454                self.doorbells.push(doorbell);
455            }
456        }
457    }
458
459    pub fn clear(&mut self) {
460        self.doorbells.clear();
461    }
462}
463
/// Describes a device's shared memory region requirement, if any.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraitsSharedMemory {
    /// Shared memory region id.
    pub id: u8,
    /// Region size in bytes; zero means no shared memory is used.
    pub size: u64,
}
469
/// Static characteristics of a virtio device implementation.
#[derive(Copy, Clone, Debug, Default)]
pub struct DeviceTraits {
    /// Virtio device type id.
    pub device_id: u16,
    /// Feature bits the device offers.
    pub device_features: u64,
    /// Maximum number of virtqueues the device supports.
    pub max_queues: u16,
    /// Size of the device-specific register region in bytes.
    pub device_register_length: u32,
    /// Shared memory requirements.
    pub shared_memory: DeviceTraitsSharedMemory,
}
478
/// Older device interface where the transport drives queue workers; wrap
/// with [`LegacyWrapper`] to adapt to [`VirtioDevice`].
pub trait LegacyVirtioDevice: Send {
    fn traits(&self) -> DeviceTraits;
    fn read_registers_u32(&self, offset: u16) -> u32;
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Returns the work consumer for queue `index`.
    fn get_work_callback(&mut self, index: u16) -> Box<dyn VirtioQueueWorkerContext + Send>;
    /// Notifies the device of a lifecycle transition.
    fn state_change(&mut self, state: &VirtioState);
}
486
/// Interface a virtio transport uses to drive a device implementation.
pub trait VirtioDevice: Send {
    fn traits(&self) -> DeviceTraits;
    fn read_registers_u32(&self, offset: u16) -> u32;
    fn write_registers_u32(&mut self, offset: u16, val: u32);
    /// Activates the device with the negotiated queue/memory resources.
    fn enable(&mut self, resources: Resources);
    /// Deactivates the device, releasing queue resources.
    fn disable(&mut self);
}
494
/// Per-queue resources handed to a device on enable.
pub struct QueueResources {
    /// Ring configuration for the queue.
    pub params: QueueParams,
    /// Interrupt to deliver on used-ring updates.
    pub notify: Interrupt,
    /// Event the guest signals for new available descriptors.
    pub event: Event,
}
500
/// All resources handed to a device on enable.
pub struct Resources {
    /// Negotiated feature bits.
    pub features: u64,
    /// Resources for each queue, indexed by queue number.
    pub queues: Vec<QueueResources>,
    /// Mapped shared memory region, if the device requested one.
    pub shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size of the shared memory region in bytes.
    pub shared_memory_size: u64,
}
507
/// Wraps an object implementing [`LegacyVirtioDevice`] and implements [`VirtioDevice`].
pub struct LegacyWrapper<T: LegacyVirtioDevice> {
    device: T,
    // Driver used to poll queue events and to spawn the shutdown task.
    driver: VmTaskDriver,
    mem: GuestMemory,
    // One worker task per enabled queue; populated by `enable`.
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    // Signaled on disable to make every worker's select loop exit.
    exit_event: event_listener::Event,
}
516
impl<T: LegacyVirtioDevice> LegacyWrapper<T> {
    /// Wraps `device`, using `driver_source` for queue workers and `mem` for
    /// guest memory access.
    pub fn new(driver_source: &VmTaskDriverSource, device: T, mem: &GuestMemory) -> Self {
        Self {
            device,
            driver: driver_source.simple(),
            mem: mem.clone(),
            workers: Vec::new(),
            exit_event: event_listener::Event::new(),
        }
    }
}
528
529impl<T: LegacyVirtioDevice> VirtioDevice for LegacyWrapper<T> {
530    fn traits(&self) -> DeviceTraits {
531        self.device.traits()
532    }
533
534    fn read_registers_u32(&self, offset: u16) -> u32 {
535        self.device.read_registers_u32(offset)
536    }
537
538    fn write_registers_u32(&mut self, offset: u16, val: u32) {
539        self.device.write_registers_u32(offset, val)
540    }
541
542    fn enable(&mut self, resources: Resources) {
543        let running_state = VirtioRunningState {
544            features: resources.features,
545            enabled_queues: resources
546                .queues
547                .iter()
548                .map(|QueueResources { params, .. }| params.enable)
549                .collect(),
550        };
551
552        self.device
553            .state_change(&VirtioState::Running(running_state));
554        self.workers = resources
555            .queues
556            .into_iter()
557            .enumerate()
558            .filter_map(|(i, queue_resources)| {
559                if !queue_resources.params.enable {
560                    return None;
561                }
562                let worker = VirtioQueueWorker::new(
563                    self.driver.clone(),
564                    self.device.get_work_callback(i as u16),
565                );
566                Some(worker.into_running_task(
567                    "virtio-queue".to_string(),
568                    self.mem.clone(),
569                    resources.features,
570                    queue_resources,
571                    self.exit_event.listen(),
572                ))
573            })
574            .collect();
575    }
576
577    fn disable(&mut self) {
578        if self.workers.is_empty() {
579            return;
580        }
581        self.exit_event.notify(usize::MAX);
582        self.device.state_change(&VirtioState::Stopped);
583        let mut workers = self.workers.drain(..).collect::<Vec<_>>();
584        self.driver
585            .spawn("shutdown-legacy-virtio-queues".to_owned(), async move {
586                futures::future::join_all(workers.iter_mut().map(async |worker| {
587                    worker.stop().await;
588                    if let Some(VirtioQueueStateInner::Running { queue, .. }) =
589                        worker.state_mut().map(|s| &s.inner)
590                    {
591                        queue.wait_for_outstanding_descriptors().await;
592                    }
593                }))
594                .await;
595            })
596            .detach();
597    }
598}
599
impl<T: LegacyVirtioDevice> Drop for LegacyWrapper<T> {
    fn drop(&mut self) {
        // Ensure workers are shut down even if the owner never called
        // `disable` explicitly; `disable` is a no-op when already stopped.
        fn drop_inner<T: LegacyVirtioDevice>(this: &mut LegacyWrapper<T>) {}
        self.disable();
    }
}