vmotherboard/chipset/backing/arc_mutex/state_unit.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::VmmChipsetDevice;
5use async_trait::async_trait;
6use closeable_mutex::CloseableMutex;
7use futures::FutureExt;
8use futures::StreamExt;
9use futures::task::ArcWake;
10use futures::task::WakerRef;
11use futures::task::waker_ref;
12use inspect::InspectMut;
13use state_unit::StateRequest;
14use state_unit::StateUnit;
15use std::sync::Arc;
16use std::task::Context;
17use vmcore::save_restore::RestoreError;
18use vmcore::save_restore::SaveError;
19use vmcore::save_restore::SavedStateBlob;
20use vmcore::slim_event::SlimEvent;
21
22pub struct ArcMutexChipsetDeviceUnit {
23    device: Arc<CloseableMutex<dyn DynDevice>>,
24    poll_event: Arc<PollEvent>,
25    running: bool,
26    omit_saved_state: bool,
27}
28
29impl InspectMut for ArcMutexChipsetDeviceUnit {
30    fn inspect_mut(&mut self, req: inspect::Request<'_>) {
31        self.device.lock().inspect_mut(req);
32    }
33}
34
35/// Object-safe trait for the subset of [`VmmChipsetDevice`] that we use here.
36#[async_trait]
37trait DynDevice: InspectMut + Send {
38    fn start(&mut self);
39    async fn stop(&mut self);
40    async fn reset(&mut self);
41    fn poll_device(&mut self, cx: &mut Context<'_>);
42    fn save(&mut self) -> Result<SavedStateBlob, SaveError>;
43    fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError>;
44}
45
46#[async_trait]
47impl<T: VmmChipsetDevice> DynDevice for T {
48    fn start(&mut self) {
49        self.start()
50    }
51
52    async fn stop(&mut self) {
53        self.stop().await
54    }
55
56    async fn reset(&mut self) {
57        self.reset().await
58    }
59
60    fn poll_device(&mut self, cx: &mut Context<'_>) {
61        if let Some(poll) = self.supports_poll_device() {
62            poll.poll_device(cx);
63        }
64    }
65
66    fn save(&mut self) -> Result<SavedStateBlob, SaveError> {
67        self.save()
68    }
69
70    fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
71        self.restore(state)
72    }
73}
74
75struct PollEvent(SlimEvent);
76
77impl ArcWake for PollEvent {
78    fn wake_by_ref(arc_self: &Arc<Self>) {
79        arc_self.0.signal();
80    }
81}
82
83impl ArcMutexChipsetDeviceUnit {
84    pub fn new(
85        device: Arc<CloseableMutex<impl 'static + VmmChipsetDevice>>,
86        omit_saved_state: bool,
87    ) -> Self {
88        Self {
89            device,
90            poll_event: Arc::new(PollEvent(SlimEvent::new())),
91            running: false,
92            omit_saved_state,
93        }
94    }
95
96    pub async fn run(mut self, mut state_change: mesh::Receiver<StateRequest>) -> Self {
97        loop {
98            enum Event<'a> {
99                StateChange(StateRequest),
100                Poll(WakerRef<'a>),
101            }
102
103            // Wait for poll requests.
104            let poll_fut = async {
105                if self.running {
106                    self.poll_event.0.wait().await;
107                    return waker_ref(&self.poll_event);
108                }
109                // The device is not running. Never
110                // complete this future.
111                std::future::pending().await
112            };
113
114            let event = futures::select! {
115                req = state_change.next() => {
116                    if let Some(req) = req {
117                        Event::StateChange(req)
118                    } else {
119                        break;
120                    }
121                }
122                waker = poll_fut.fuse() => {
123                    Event::Poll(waker)
124                }
125            };
126
127            match event {
128                Event::StateChange(req) => {
129                    req.apply(&mut self).await;
130                }
131                Event::Poll(waker) => {
132                    let mut device = self.device.lock();
133                    device.poll_device(&mut Context::from_waker(&waker));
134                }
135            }
136        }
137        self
138    }
139}
140
141impl StateUnit for ArcMutexChipsetDeviceUnit {
142    async fn start(&mut self) {
143        self.running = true;
144
145        // Poll the device at least once.
146        let mut device = self.device.lock();
147        device.start();
148        device.poll_device(&mut Context::from_waker(&waker_ref(&self.poll_event)));
149    }
150
151    async fn stop(&mut self) {
152        self.device.clone().close().stop().await;
153        self.running = false;
154        // FUTURE: consider closing the mutex while the device is stopped to
155        // find bugs. This may be difficult or not worth it since it requires
156        // that:
157        //
158        // 1. all cross-device dependencies are exactly correct.
159        // 2. no device manipulation happens externally to normal VM operation
160        //    (e.g., no calls are made to the device while the VM is stopped).
161        //
162        // These are currently not true.
163    }
164
165    async fn reset(&mut self) -> anyhow::Result<()> {
166        self.device.clone().close().reset().await;
167        Ok(())
168    }
169
170    async fn save(&mut self) -> Result<Option<SavedStateBlob>, SaveError> {
171        if self.omit_saved_state {
172            return Ok(None);
173        }
174
175        // TODO: make async
176        let state = self.device.clone().close().save()?;
177        Ok(Some(state))
178    }
179
180    async fn restore(&mut self, state: SavedStateBlob) -> Result<(), RestoreError> {
181        // TODO: make async
182        self.device.clone().close().restore(state)
183    }
184}