chipset_device_worker/
worker.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A worker for running ChipsetDevice implementations in a separate process.
5//!
6//! This worker provides process isolation for any device implementing the
7//! ChipsetDevice trait. It handles serialization and deserialization of
8//! device operations across process boundaries.
9
10#![forbid(unsafe_code)]
11
12mod configure;
13
14use crate::RemoteDynamicResolvers;
15use crate::guestmem::GuestMemoryRemoteBuilder;
16use crate::protocol::*;
17use anyhow::Context;
18use chipset_device::ChipsetDevice;
19use chipset_device::io::IoResult;
20use chipset_device::io::deferred::DeferredToken;
21use chipset_device_resources::ErasedChipsetDevice;
22use chipset_device_resources::ResolveChipsetDeviceHandleParams;
23use mesh::MeshPayload;
24use mesh::error::RemoteError;
25use mesh_worker::Worker;
26use mesh_worker::WorkerId;
27use mesh_worker::WorkerRpc;
28use pal_async::DefaultPool;
29use std::task::Poll;
30use vm_resource::Resource;
31use vm_resource::ResourceResolver;
32use vm_resource::kind::ChipsetDeviceHandleKind;
33use vmcore::device_state::ChangeDeviceState;
34use vmcore::save_restore::ProtobufSaveRestore;
35
36/// Worker ID for ChipsetDevice workers.
37pub(crate) const fn remote_chipset_device_worker_id<T: RemoteDynamicResolvers>()
38-> WorkerId<RemoteChipsetDeviceWorkerParameters<T>> {
39    WorkerId::new(T::WORKER_ID_STR)
40}
41
/// Parameters for launching a remote chipset device worker.
///
/// The whole struct derives `MeshPayload` so it can be handed to the worker.
/// NOTE(review): the mesh encoding may depend on field order; confirm before
/// reordering fields.
#[derive(MeshPayload)]
pub struct RemoteChipsetDeviceWorkerParameters<T> {
    /// Resource describing the device to resolve inside the worker.
    pub(crate) device: Resource<ChipsetDeviceHandleKind>,
    /// Resolvers registered with the worker's `ResourceResolver` before
    /// `device` is resolved.
    pub(crate) dyn_resolvers: T,
    /// Serializable inputs used to build the device's resolve parameters.
    pub(crate) inputs: RemoteChipsetDeviceHandleParams,

    /// Incoming device requests (state changes, MMIO/PIO/PCI config
    /// accesses, save/restore).
    pub(crate) req_recv: mesh::Receiver<DeviceRequest>,
    /// Outgoing read/write completions, each tagged with its request id.
    pub(crate) resp_send: mesh::Sender<DeviceResponse>,
    /// Used once, after resolution, to report the device's MMIO/PIO static
    /// regions and PCI information.
    pub(crate) cap_send: mesh::OneshotSender<DeviceInit>,
}
53
/// Serializable inputs from which the worker constructs the
/// `ResolveChipsetDeviceHandleParams` used to resolve its device.
#[derive(MeshPayload)]
pub(crate) struct RemoteChipsetDeviceHandleParams {
    /// Passed to the resolver as `device_name`.
    pub device_name: String,
    /// Passed to the resolver as `is_restoring`.
    pub is_restoring: bool,
    /// Builder for the device's VM time source; built on the worker's pool
    /// driver before resolution.
    pub vmtime: vmcore::vmtime::VmTimeSourceBuilder,
    /// Builder for the guest memory given to the device (built with the name
    /// `"remote_gm"`).
    pub guest_memory: GuestMemoryRemoteBuilder,
    /// Builder for the encrypted guest memory given to the device (built with
    /// the name `"remote_enc_gm"`).
    pub encrypted_guest_memory: GuestMemoryRemoteBuilder,
}
62
/// The chipset device worker.
///
/// This worker wraps any device implementing ChipsetDevice and handles
/// device operations sent via mesh channels.
///
/// Fields are dropped in declaration order, so `device` is dropped before
/// `pool`.
pub struct RemoteChipsetDeviceWorker<T> {
    /// The resolved device that requests are dispatched to.
    device: ErasedChipsetDevice,
    /// Created in `new`, where it drives device resolution; `run` takes it
    /// (leaving `None`) and runs the event loop on it.
    pool: Option<DefaultPool>,
    /// Incoming device requests.
    req_recv: mesh::Receiver<DeviceRequest>,
    /// Outgoing read/write completions.
    resp_send: mesh::Sender<DeviceResponse>,
    /// Reads the device answered with `IoResult::Defer`; polled until done.
    deferred_reads: Vec<DeferredRead>,
    /// Writes the device answered with `IoResult::Defer`; polled until done.
    deferred_writes: Vec<DeferredWrite>,

    /// Marks the resolver type `T` from the worker's parameters; no `T` value
    /// is stored.
    _phantom_resolvers: std::marker::PhantomData<T>,
}
77
78struct DeferredRead {
79    id: usize,
80    token: DeferredToken,
81    size: usize,
82}
83
84struct DeferredWrite {
85    id: usize,
86    token: DeferredToken,
87}
88
89impl<T: RemoteDynamicResolvers> Worker for RemoteChipsetDeviceWorker<T> {
90    type Parameters = RemoteChipsetDeviceWorkerParameters<T>;
91    type State = ();
92    const ID: WorkerId<Self::Parameters> = remote_chipset_device_worker_id();
93
94    fn new(params: Self::Parameters) -> anyhow::Result<Self> {
95        let mut pool = DefaultPool::new();
96
97        let RemoteChipsetDeviceWorkerParameters {
98            device,
99            dyn_resolvers,
100            inputs,
101
102            req_recv,
103            resp_send,
104            cap_send,
105        } = params;
106
107        let mut resolver = ResourceResolver::new();
108
109        let driver = pool.driver();
110        let mut device = pool
111            .run_until(async move {
112                dyn_resolvers
113                    .register_remote_dynamic_resolvers(&mut resolver)
114                    .await?;
115                resolver
116                    .resolve(
117                        device,
118                        ResolveChipsetDeviceHandleParams {
119                            device_name: &inputs.device_name,
120                            guest_memory: &inputs.guest_memory.build("remote_gm"),
121                            encrypted_guest_memory: &inputs
122                                .encrypted_guest_memory
123                                .build("remote_enc_gm"),
124                            vmtime: &inputs
125                                .vmtime
126                                .build(&driver)
127                                .await
128                                .context("failed to build vmtime source")?,
129                            is_restoring: inputs.is_restoring,
130                            task_driver_source: &vmcore::vm_task::VmTaskDriverSource::new(
131                                vmcore::vm_task::thread::ThreadDriverBackend::new(driver),
132                            ),
133                            // TODO: Actually wire these up
134                            configure: &mut configure::RemoteConfigureChipsetDevice {},
135                            register_mmio: &mut configure::RemoteRegisterMmio {},
136                            register_pio: &mut configure::RemoteRegisterPio {},
137                        },
138                    )
139                    .await
140                    .context("failed to resolve device")
141            })?
142            .0;
143
144        if device.supports_acknowledge_pic_interrupt().is_some()
145            || device.supports_handle_eoi().is_some()
146            || device.supports_line_interrupt_target().is_some()
147        {
148            anyhow::bail!("remote device requires unimplemented functionality");
149        }
150
151        cap_send.send(DeviceInit {
152            mmio: device.supports_mmio().map(|m| MmioInit {
153                static_regions: m
154                    .get_static_regions()
155                    .iter()
156                    .map(|(name, range)| ((*name).into(), *range.start(), *range.end()))
157                    .collect(),
158            }),
159            pio: device.supports_pio().map(|p| PioInit {
160                static_regions: p
161                    .get_static_regions()
162                    .iter()
163                    .map(|(name, range)| ((*name).into(), *range.start(), *range.end()))
164                    .collect(),
165            }),
166            pci: device.supports_pci().map(|p| PciInit {
167                suggested_bdf: p.suggested_bdf(),
168            }),
169        });
170
171        Ok(Self {
172            device,
173            pool: Some(pool),
174            req_recv,
175            resp_send,
176            deferred_reads: Vec::new(),
177            deferred_writes: Vec::new(),
178            _phantom_resolvers: std::marker::PhantomData,
179        })
180    }
181
182    fn restart(_state: Self::State) -> anyhow::Result<Self> {
183        todo!()
184    }
185
186    fn run(mut self, mut rpc_recv: mesh::Receiver<WorkerRpc<Self::State>>) -> anyhow::Result<()> {
187        self.pool.take().unwrap().run_until(async move {
188            loop {
189                enum WorkerEvent {
190                    Rpc(WorkerRpc<()>),
191                    DeviceRequest(DeviceRequest),
192                }
193
194                let event = std::future::poll_fn(|cx| {
195                    if let Some(poll_device) = self.device.supports_poll_device() {
196                        poll_device.poll_device(cx);
197                    }
198
199                    self.deferred_reads
200                        .extract_if(.., |read| {
201                            let mut data = vec![0; read.size];
202                            match read.token.poll_read(cx, &mut data) {
203                                Poll::Ready(r) => {
204                                    self.resp_send.send(DeviceResponse::Read {
205                                        id: read.id,
206                                        result: r.map(|_| data),
207                                    });
208                                    true
209                                }
210                                Poll::Pending => false,
211                            }
212                        })
213                        .for_each(|_| ());
214
215                    self.deferred_writes
216                        .extract_if(.., |write| match write.token.poll_write(cx) {
217                            Poll::Ready(r) => {
218                                self.resp_send.send(DeviceResponse::Write {
219                                    id: write.id,
220                                    result: r,
221                                });
222                                true
223                            }
224                            Poll::Pending => false,
225                        })
226                        .for_each(|_| ());
227
228                    // If either of these channels fail, we fail the worker too.
229                    if let Poll::Ready(r) = rpc_recv.poll_recv(cx) {
230                        return Poll::Ready(r.map(WorkerEvent::Rpc));
231                    }
232                    if let Poll::Ready(r) = self.req_recv.poll_recv(cx) {
233                        return Poll::Ready(r.map(WorkerEvent::DeviceRequest));
234                    }
235                    Poll::Pending
236                })
237                .await?;
238
239                match event {
240                    WorkerEvent::Rpc(rpc) => match rpc {
241                        WorkerRpc::Inspect(deferred) => {
242                            deferred.inspect(&mut self.device);
243                        }
244                        WorkerRpc::Stop => {
245                            return Ok(());
246                        }
247                        WorkerRpc::Restart(rpc) => {
248                            rpc.complete(Err(RemoteError::new(anyhow::anyhow!("not supported"))));
249                        }
250                    },
251                    WorkerEvent::DeviceRequest(req) => match req {
252                        DeviceRequest::Start => self.device.start(),
253                        DeviceRequest::Stop(rpc) => {
254                            rpc.handle(async |()| self.device.stop().await).await
255                        }
256                        DeviceRequest::Reset(rpc) => {
257                            self.deferred_reads.clear();
258                            self.deferred_writes.clear();
259                            rpc.handle(async |()| self.device.reset().await).await
260                        }
261                        DeviceRequest::MmioRead(ReadRequest { id, address, size }) => {
262                            let mut data = vec![0; size];
263                            let result = self
264                                .device
265                                .supports_mmio()
266                                .unwrap()
267                                .mmio_read(address, &mut data);
268                            self.handle_read_result(id, result, data);
269                        }
270                        DeviceRequest::MmioWrite(WriteRequest { id, address, data }) => {
271                            let result = self
272                                .device
273                                .supports_mmio()
274                                .unwrap()
275                                .mmio_write(address, &data);
276                            self.handle_write_result(id, result);
277                        }
278                        DeviceRequest::PioRead(ReadRequest { id, address, size }) => {
279                            let mut data = vec![0; size];
280                            let result = self
281                                .device
282                                .supports_pio()
283                                .unwrap()
284                                .io_read(address, &mut data);
285                            self.handle_read_result(id, result, data);
286                        }
287                        DeviceRequest::PioWrite(WriteRequest { id, address, data }) => {
288                            let result =
289                                self.device.supports_pio().unwrap().io_write(address, &data);
290                            self.handle_write_result(id, result);
291                        }
292                        DeviceRequest::PciConfigRead(ReadRequest { id, address, size }) => {
293                            assert_eq!(size, 4);
294                            let mut data = 0;
295                            let result = self
296                                .device
297                                .supports_pci()
298                                .unwrap()
299                                .pci_cfg_read(address, &mut data);
300                            self.handle_read_result(id, result, data.to_ne_bytes().to_vec());
301                        }
302                        DeviceRequest::PciConfigWrite(WriteRequest { id, address, data }) => {
303                            let result = self
304                                .device
305                                .supports_pci()
306                                .unwrap()
307                                .pci_cfg_write(address, data);
308                            self.handle_write_result(id, result);
309                        }
310                        DeviceRequest::Save(rpc) => {
311                            rpc.handle_failable_sync(|()| self.device.save())
312                        }
313                        DeviceRequest::Restore(rpc) => {
314                            rpc.handle_failable_sync(|state| self.device.restore(state))
315                        }
316                    },
317                }
318            }
319        })
320    }
321}
322
323impl<T> RemoteChipsetDeviceWorker<T> {
324    fn handle_read_result(&mut self, id: usize, result: IoResult, data: Vec<u8>) {
325        match result {
326            IoResult::Ok => self.resp_send.send(DeviceResponse::Read {
327                id,
328                result: Ok(data),
329            }),
330            IoResult::Err(io_error) => self.resp_send.send(DeviceResponse::Read {
331                id,
332                result: Err(io_error),
333            }),
334            IoResult::Defer(token) => self.deferred_reads.push(DeferredRead {
335                id,
336                token,
337                size: data.len(),
338            }),
339        }
340    }
341
342    fn handle_write_result(&mut self, id: usize, result: IoResult) {
343        match result {
344            IoResult::Ok => self
345                .resp_send
346                .send(DeviceResponse::Write { id, result: Ok(()) }),
347            IoResult::Err(io_error) => self.resp_send.send(DeviceResponse::Write {
348                id,
349                result: Err(io_error),
350            }),
351            IoResult::Defer(token) => self.deferred_writes.push(DeferredWrite { id, token }),
352        }
353    }
354}