virtiofs/
virtio.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(dead_code)]
5
6use crate::virtio_util::VirtioPayloadReader;
7use crate::virtio_util::VirtioPayloadWriter;
8use async_trait::async_trait;
9use guestmem::GuestMemory;
10use guestmem::MappedMemoryRegion;
11use pal_async::task::Spawn;
12use std::io;
13use std::io::Write;
14use std::sync::Arc;
15use task_control::TaskControl;
16use virtio::DeviceTraits;
17use virtio::DeviceTraitsSharedMemory;
18use virtio::Resources;
19use virtio::VirtioDevice;
20use virtio::VirtioQueueCallbackWork;
21use virtio::VirtioQueueState;
22use virtio::VirtioQueueWorker;
23use virtio::VirtioQueueWorkerContext;
24use vmcore::vm_task::VmTaskDriver;
25use vmcore::vm_task::VmTaskDriverSource;
26use zerocopy::Immutable;
27use zerocopy::IntoBytes;
28use zerocopy::KnownLayout;
29
/// Virtio device ID reported as `device_id` by this device; 26 identifies a
/// file system (virtio-fs) device.
const VIRTIO_DEVICE_TYPE_FS: u16 = 26;
31
/// Device-specific configuration space values for virtio-fs devices.
///
/// The guest sees this structure byte-for-byte through `read_registers_u32`
/// (via `as_bytes`), so the field order and `#[repr(C)]` layout define the
/// register offsets. Do not reorder fields.
#[repr(C)]
#[derive(IntoBytes, Immutable, KnownLayout)]
struct VirtioFsDeviceConfig {
    /// Mount tag bytes; any bytes past the end of the tag are zero.
    tag: [u8; 36],
    /// Number of request queues advertised to the guest.
    num_request_queues: u32,
}
39
/// A virtio-fs PCI device.
///
/// FUSE requests arriving on the device's virtio queues are dispatched to a
/// shared `fuse::Session`, with one queue worker per enabled queue.
pub struct VirtioFsDevice {
    /// Device name, formatted as `virtio-fs-<tag>`.
    name: Box<str>,

    /// Driver that runs the queue workers and the shutdown task.
    driver: VmTaskDriver,
    /// Config-space contents returned by `read_registers_u32`.
    config: VirtioFsDeviceConfig,
    /// Guest memory handed to each queue worker.
    memory: GuestMemory,
    /// FUSE session shared by all queue workers.
    fs: Arc<fuse::Session>,
    /// Running queue workers; populated by `enable`, drained by `disable`.
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    /// Each worker receives a listener on this event; `disable` notifies it.
    exit_event: event_listener::Event,
    /// Size of the shared memory region reported in `traits`.
    shmem_size: u64,
    /// Callback invoked when a FUSE request cannot be parsed.
    notify_corruption: Arc<dyn Fn() + Sync + Send>,
}
53
54impl VirtioFsDevice {
55    /// Creates a new `VirtioFsDevice` with the specified mount tag.
56    pub fn new<Fs>(
57        driver_source: &VmTaskDriverSource,
58        tag: &str,
59        fs: Fs,
60        memory: GuestMemory,
61        shmem_size: u64,
62        notify_corruption: Option<Arc<dyn Fn() + Sync + Send>>,
63    ) -> Self
64    where
65        Fs: 'static + fuse::Fuse + Send + Sync,
66    {
67        let mut config = VirtioFsDeviceConfig {
68            tag: [0; 36],
69            num_request_queues: 1,
70        };
71
72        let notify_corruption = if let Some(notify) = notify_corruption {
73            notify
74        } else {
75            Arc::new(|| {})
76        };
77
78        // Copy the tag into the config space (truncate it for now if too long).
79        let length = std::cmp::min(tag.len(), config.tag.len());
80        config.tag[..length].copy_from_slice(&tag.as_bytes()[..length]);
81
82        Self {
83            name: format!("virtio-fs-{}", tag).into(),
84            driver: driver_source.simple(),
85            config,
86            memory,
87            fs: Arc::new(fuse::Session::new(fs)),
88            workers: Vec::new(),
89            exit_event: event_listener::Event::new(),
90            shmem_size,
91            notify_corruption,
92        }
93    }
94}
95
96impl VirtioDevice for VirtioFsDevice {
97    fn traits(&self) -> DeviceTraits {
98        DeviceTraits {
99            device_id: VIRTIO_DEVICE_TYPE_FS,
100            device_features: 0,
101            max_queues: 2,
102            device_register_length: self.config.as_bytes().len() as u32,
103            shared_memory: DeviceTraitsSharedMemory {
104                id: 0,
105                size: self.shmem_size,
106            },
107        }
108    }
109
110    fn read_registers_u32(&self, offset: u16) -> u32 {
111        let offset = offset as usize;
112        let config = self.config.as_bytes();
113        if offset < config.len() {
114            u32::from_le_bytes(
115                config[offset..offset + 4]
116                    .try_into()
117                    .expect("Incorrect length"),
118            )
119        } else {
120            0
121        }
122    }
123
124    fn write_registers_u32(&mut self, offset: u16, val: u32) {
125        tracing::warn!(offset, val, "[virtiofs] Unknown write",);
126    }
127
128    fn enable(&mut self, resources: Resources) {
129        self.workers = resources
130            .queues
131            .into_iter()
132            .filter_map(|queue_resources| {
133                if !queue_resources.params.enable {
134                    return None;
135                }
136                let worker = VirtioFsWorker {
137                    fs: self.fs.clone(),
138                    mem: self.memory.clone(),
139                    shared_memory_region: resources.shared_memory_region.clone(),
140                    shared_memory_size: resources.shared_memory_size,
141                    notify_corruption: self.notify_corruption.clone(),
142                };
143                let worker = VirtioQueueWorker::new(self.driver.clone(), Box::new(worker));
144                Some(worker.into_running_task(
145                    "virtiofs-virtio-queue".to_string(),
146                    self.memory.clone(),
147                    resources.features,
148                    queue_resources,
149                    self.exit_event.listen(),
150                ))
151            })
152            .collect();
153    }
154
155    fn disable(&mut self) {
156        self.exit_event.notify(usize::MAX);
157        let mut workers = self.workers.drain(..).collect::<Vec<_>>();
158        self.driver
159            .spawn("shutdown-virtiofs-queues".to_owned(), async move {
160                futures::future::join_all(workers.iter_mut().map(async |worker| {
161                    worker.stop().await;
162                }))
163                .await;
164            })
165            .detach();
166    }
167}
168
/// Per-queue worker state: everything needed to parse FUSE requests from one
/// virtio queue and dispatch them to the session.
struct VirtioFsWorker {
    /// FUSE session shared with the device and the other queue workers.
    fs: Arc<fuse::Session>,
    /// Guest memory used to read requests and write replies.
    mem: GuestMemory,
    /// Shared memory region backing the `fuse::Mapper` passed to dispatch, if any.
    shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    /// Size of `shared_memory_region`.
    shared_memory_size: u64,
    /// Callback invoked when a FUSE request cannot be parsed.
    notify_corruption: Arc<dyn Fn() + Sync + Send>,
}
176
177#[async_trait]
178impl VirtioQueueWorkerContext for VirtioFsWorker {
179    async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
180        if let Err(err) = work {
181            tracing::error!(
182                error = err.as_ref() as &dyn std::error::Error,
183                "Failed processing queue"
184            );
185            return false;
186        }
187
188        let mut work = work.unwrap();
189        // Parse the request.
190        let reader = VirtioPayloadReader::new(&self.mem, &work);
191        let request = match fuse::Request::new(reader) {
192            Ok(request) => request,
193            Err(e) => {
194                tracing::error!(
195                    error = &e as &dyn std::error::Error,
196                    "[virtiofs] Invalid FUSE message, error"
197                );
198                // Often this will result in the guest failing the device as there is no response to a request.
199                (self.notify_corruption)();
200                // This only happens if even the header couldn't be parsed, so there's no way
201                // to send an error reply since the request's unique ID isn't known.
202                work.complete(0);
203                return true;
204            }
205        };
206
207        // Dispatch to the file system.
208        let mut sender = VirtioReplySender {
209            work,
210            mem: &self.mem,
211        };
212        let mapper = self
213            .shared_memory_region
214            .as_ref()
215            .map(|shared_memory_region| VirtioMapper {
216                region: shared_memory_region.as_ref(),
217                size: self.shared_memory_size,
218            });
219        self.fs.dispatch(
220            request,
221            &mut sender,
222            mapper.as_ref().map(|x| x as &dyn fuse::Mapper),
223        );
224        true
225    }
226}
/// An implementation of `fuse::ReplySender` that writes replies into the
/// payload buffers of a virtio queue work item and then completes it.
struct VirtioReplySender<'a> {
    /// Work item the reply is written to and completed on.
    work: VirtioQueueCallbackWork,
    /// Guest memory the payload buffers are accessed through.
    mem: &'a GuestMemory,
}
232
233impl fuse::ReplySender for VirtioReplySender<'_> {
234    fn send(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<()> {
235        let mut writer = VirtioPayloadWriter::new(self.mem, &self.work);
236        let mut size = 0;
237
238        // Write all the slices to the payload buffers.
239        // N.B. write_vectored isn't used because it isn't guaranteed to write all the data.
240        for buf in bufs {
241            writer.write_all(buf)?;
242            size += buf.len();
243        }
244
245        self.work.complete(size as u32);
246        Ok(())
247    }
248}
249
/// A `fuse::Mapper` that maps file ranges into a shared memory region.
struct VirtioMapper<'a> {
    /// Region that file ranges are mapped into and unmapped from.
    region: &'a dyn MappedMemoryRegion,
    /// Size of the region; `clear` unmaps this many bytes starting at offset 0.
    size: u64,
}
254
255impl fuse::Mapper for VirtioMapper<'_> {
256    fn map(
257        &self,
258        offset: u64,
259        file: fuse::FileRef<'_>,
260        file_offset: u64,
261        len: u64,
262        writable: bool,
263    ) -> lx::Result<()> {
264        let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
265        let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
266        self.region.map(offset, &file, file_offset, len, writable)?;
267        Ok(())
268    }
269
270    fn unmap(&self, offset: u64, len: u64) -> lx::Result<()> {
271        let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
272        let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
273        self.region.unmap(offset, len)?;
274        Ok(())
275    }
276
277    fn clear(&self) {
278        let result = self.region.unmap(0, self.size as usize);
279        if let Err(result) = result {
280            tracing::error!(
281                error = &result as &dyn std::error::Error,
282                "Failed to unmap shared memory"
283            );
284        }
285    }
286}