virtio_p9/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6#![cfg(any(windows, target_os = "linux"))]
7
8pub mod resolver;
9
10use anyhow::Context as _;
11use futures::StreamExt;
12use guestmem::GuestMemory;
13use inspect::InspectMut;
14use pal_async::wait::PolledWait;
15use plan9::Plan9FileSystem;
16use task_control::AsyncRun;
17use task_control::Cancelled;
18use task_control::InspectTaskMut;
19use task_control::StopTask;
20use task_control::TaskControl;
21use virtio::DeviceTraits;
22use virtio::QueueResources;
23use virtio::VirtioDevice;
24use virtio::VirtioQueue;
25use virtio::VirtioQueueCallbackWork;
26use virtio::queue::QueueState;
27use virtio::spec::VirtioDeviceFeatures;
28use vmcore::vm_task::VmTaskDriver;
29use vmcore::vm_task::VmTaskDriverSource;
30
// Value advertised in bank 0 of the device features (see `traits`).
// NOTE(review): presumably the mask for the virtio 9p "mount tag" feature
// bit (bit 0) -- confirm against the virtio specification.
const VIRTIO_9P_F_MOUNT_TAG: u32 = 1;
32
/// A virtio device of type `P9` that serves a [`Plan9FileSystem`] over a
/// single virtio queue.
#[derive(InspectMut)]
pub struct VirtioPlan9Device {
    /// The mount tag as exposed through the device registers: a two byte
    /// little-endian length, the tag bytes, and zero padding up to a multiple
    /// of four bytes (built in `new`).
    tag: Vec<u8>,
    /// Driver used to create the queue's polled wait and to run the worker
    /// task.
    driver: VmTaskDriver,
    /// The queue-processing task. It only has state (a `Plan9Queue`) between
    /// `start_queue` and `stop_queue`.
    #[inspect(mut)]
    worker: TaskControl<Plan9Worker, Plan9Queue>,
}
40
41impl VirtioPlan9Device {
42    pub fn new(
43        driver_source: &VmTaskDriverSource,
44        tag: &str,
45        fs: Plan9FileSystem,
46    ) -> VirtioPlan9Device {
47        // The tag uses the same format as 9p protocol strings (2 byte length followed by string).
48        let length = tag.len() + size_of::<u16>();
49
50        // Round the length up to a multiple of 4 to make the read function simpler.
51        let length = (length + 3) & !3;
52        let mut tag_buffer = vec![0u8; length];
53
54        // Write a string preceded by a two byte length.
55        {
56            use std::io::Write;
57            let mut cursor = std::io::Cursor::new(&mut tag_buffer);
58            cursor.write_all(&(tag.len() as u16).to_le_bytes()).unwrap();
59            cursor.write_all(tag.as_bytes()).unwrap();
60        }
61
62        VirtioPlan9Device {
63            tag: tag_buffer,
64            driver: driver_source.simple(),
65            worker: TaskControl::new(Plan9Worker { fs }),
66        }
67    }
68}
69
70impl VirtioDevice for VirtioPlan9Device {
71    fn traits(&self) -> DeviceTraits {
72        DeviceTraits {
73            device_id: virtio::spec::VirtioDeviceType::P9,
74            device_features: VirtioDeviceFeatures::new().with_bank(0, VIRTIO_9P_F_MOUNT_TAG),
75            max_queues: 1,
76            device_register_length: self.tag.len() as u32,
77            ..Default::default()
78        }
79    }
80
81    async fn read_registers_u32(&mut self, offset: u16) -> u32 {
82        assert!(self.tag.len().is_multiple_of(4));
83        assert!(offset.is_multiple_of(4));
84
85        let offset = offset as usize;
86        if offset < self.tag.len() {
87            u32::from_le_bytes(
88                self.tag[offset..offset + 4]
89                    .try_into()
90                    .expect("Incorrect length"),
91            )
92        } else {
93            0
94        }
95    }
96
97    async fn write_registers_u32(&mut self, offset: u16, val: u32) {
98        tracing::warn!(offset, val, "[VIRTIO 9P] Unknown write",);
99    }
100
101    async fn start_queue(
102        &mut self,
103        idx: u16,
104        resources: QueueResources,
105        features: &VirtioDeviceFeatures,
106        initial_state: Option<QueueState>,
107    ) -> anyhow::Result<()> {
108        assert_eq!(idx, 0);
109
110        let queue_event = PolledWait::new(&self.driver, resources.event)
111            .context("failed to create polled wait")?;
112        let queue = VirtioQueue::new(
113            features.clone(),
114            resources.params,
115            resources.guest_memory.clone(),
116            resources.notify,
117            queue_event,
118            initial_state,
119        )
120        .context("failed to create virtio queue")?;
121
122        self.worker.insert(
123            self.driver.clone(),
124            "virtio-9p-queue",
125            Plan9Queue {
126                queue,
127                mem: resources.guest_memory,
128            },
129        );
130        self.worker.start();
131        Ok(())
132    }
133
134    async fn stop_queue(&mut self, idx: u16) -> Option<QueueState> {
135        assert_eq!(idx, 0);
136        if !self.worker.has_state() {
137            return None;
138        }
139        self.worker.stop().await;
140        let state = self.worker.remove().queue.queue_state();
141        Some(state)
142    }
143
144    async fn reset(&mut self) {
145        self.worker.task().fs.reset();
146    }
147}
148
/// The long-lived part of the queue-processing task: it owns the file system
/// that requests are dispatched to.
#[derive(InspectMut)]
struct Plan9Worker {
    /// File system that processes each 9p message; excluded from inspection.
    #[inspect(skip)]
    fs: Plan9FileSystem,
}
154
/// Per-queue state given to the worker by `start_queue` and taken back by
/// `stop_queue`.
#[derive(InspectMut)]
struct Plan9Queue {
    /// The virtio queue that work items are pulled from.
    queue: VirtioQueue,
    /// Guest memory used to read request payloads and write responses.
    mem: GuestMemory,
}
160
161impl InspectTaskMut<Plan9Queue> for Plan9Worker {
162    fn inspect_mut(&mut self, req: inspect::Request<'_>, state: Option<&mut Plan9Queue>) {
163        req.respond().merge(self).merge(state);
164    }
165}
166
167impl AsyncRun<Plan9Queue> for Plan9Worker {
168    async fn run(
169        &mut self,
170        stop: &mut StopTask<'_>,
171        state: &mut Plan9Queue,
172    ) -> Result<(), Cancelled> {
173        loop {
174            let work = stop.until_stopped(state.queue.next()).await?;
175            let Some(work) = work else { break };
176            match work {
177                Ok(work) => {
178                    process_9p_request(&state.mem, &self.fs, work);
179                }
180                Err(err) => {
181                    tracing::error!(error = &err as &dyn std::error::Error, "queue error");
182                    break;
183                }
184            }
185        }
186        Ok(())
187    }
188}
189
190fn process_9p_request(mem: &GuestMemory, fs: &Plan9FileSystem, mut work: VirtioQueueCallbackWork) {
191    // Make a copy of the incoming message.
192    let mut message = vec![0; work.get_payload_length(false) as usize];
193    if let Err(e) = work.read(mem, &mut message) {
194        tracing::error!(
195            error = &e as &dyn std::error::Error,
196            "[VIRTIO 9P] Failed to read guest memory"
197        );
198        return;
199    }
200
201    // Allocate a temporary buffer for the response.
202    let mut response = vec![9; work.get_payload_length(true) as usize];
203    if let Ok(size) = fs.process_message(&message, &mut response) {
204        // Write out the response.
205        if let Err(e) = work.write(mem, &response[0..size]) {
206            tracing::error!(
207                error = &e as &dyn std::error::Error,
208                "[VIRTIO 9P] Failed to write guest memory"
209            );
210            return;
211        }
212
213        work.complete(size as u32);
214    }
215}