virtiofs/
virtio.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::virtio_util::VirtioPayloadReader;
5use crate::virtio_util::VirtioPayloadWriter;
6use anyhow::Context as _;
7use futures::StreamExt;
8use guestmem::GuestMemory;
9use guestmem::MappedMemoryRegion;
10use inspect::InspectMut;
11use pal_async::wait::PolledWait;
12use std::io;
13use std::io::Write;
14use std::sync::Arc;
15use task_control::AsyncRun;
16use task_control::Cancelled;
17use task_control::StopTask;
18use task_control::TaskControl;
19use virtio::DeviceTraits;
20use virtio::DeviceTraitsSharedMemory;
21use virtio::QueueResources;
22use virtio::VirtioDevice;
23use virtio::VirtioQueue;
24use virtio::VirtioQueueCallbackWork;
25use virtio::queue::QueueState;
26use virtio::spec::VirtioDeviceFeatures;
27use vmcore::vm_task::VmTaskDriver;
28use vmcore::vm_task::VmTaskDriverSource;
29use zerocopy::Immutable;
30use zerocopy::IntoBytes;
31use zerocopy::KnownLayout;
32
33/// PCI configuration space values for virtio-fs devices.
34#[repr(C)]
35#[derive(IntoBytes, Immutable, KnownLayout)]
36struct VirtioFsDeviceConfig {
37    tag: [u8; 36],
38    num_request_queues: u32,
39}
40
41/// A virtio-fs PCI device.
42#[derive(InspectMut)]
43pub struct VirtioFsDevice {
44    task_name: Box<str>,
45    driver: VmTaskDriver,
46    #[inspect(skip)]
47    config: VirtioFsDeviceConfig,
48    #[inspect(skip)]
49    fs: Arc<fuse::Session>,
50    #[inspect(skip)]
51    workers: Vec<TaskControl<VirtioFsWorker, VirtioFsQueue>>,
52    shmem_size: u64,
53    #[inspect(skip)]
54    shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
55    #[inspect(skip)]
56    notify_corruption: Arc<dyn Fn() + Sync + Send>,
57}
58
59impl VirtioFsDevice {
60    /// Creates a new `VirtioFsDevice` with the specified mount tag.
61    pub fn new<Fs>(
62        driver_source: &VmTaskDriverSource,
63        tag: &str,
64        fs: Fs,
65        shmem_size: u64,
66        notify_corruption: Option<Arc<dyn Fn() + Sync + Send>>,
67    ) -> Self
68    where
69        Fs: 'static + fuse::Fuse + Send + Sync,
70    {
71        let mut config = VirtioFsDeviceConfig {
72            tag: [0; 36],
73            num_request_queues: 1,
74        };
75
76        let notify_corruption = if let Some(notify) = notify_corruption {
77            notify
78        } else {
79            Arc::new(|| {})
80        };
81
82        // Copy the tag into the config space (truncate it for now if too long).
83        let length = std::cmp::min(tag.len(), config.tag.len());
84        config.tag[..length].copy_from_slice(&tag.as_bytes()[..length]);
85
86        Self {
87            task_name: format!("virtiofs-{}", tag).into(),
88            driver: driver_source.simple(),
89            config,
90            fs: Arc::new(fuse::Session::new(fs)),
91            workers: Vec::new(),
92            shmem_size,
93            shared_memory_region: None,
94            notify_corruption,
95        }
96    }
97}
98
99impl VirtioDevice for VirtioFsDevice {
100    fn traits(&self) -> DeviceTraits {
101        DeviceTraits {
102            device_id: virtio::spec::VirtioDeviceType::FS,
103            device_features: VirtioDeviceFeatures::new()
104                .with_bank0(
105                    virtio::spec::VirtioDeviceFeaturesBank0::new()
106                        .with_ring_event_idx(true)
107                        .with_ring_indirect_desc(true),
108                )
109                .with_bank1(virtio::spec::VirtioDeviceFeaturesBank1::new().with_ring_packed(true)),
110            max_queues: 2,
111            device_register_length: self.config.as_bytes().len() as u32,
112            shared_memory: DeviceTraitsSharedMemory {
113                id: 0,
114                size: self.shmem_size,
115            },
116        }
117    }
118
119    async fn read_registers_u32(&mut self, offset: u16) -> u32 {
120        let offset = offset as usize;
121        let config = self.config.as_bytes();
122        if offset < config.len() {
123            u32::from_le_bytes(
124                config[offset..offset + 4]
125                    .try_into()
126                    .expect("Incorrect length"),
127            )
128        } else {
129            0
130        }
131    }
132
133    async fn write_registers_u32(&mut self, offset: u16, val: u32) {
134        tracing::warn!(offset, val, "[virtiofs] Unknown write",);
135    }
136
137    fn set_shared_memory_region(
138        &mut self,
139        region: &Arc<dyn MappedMemoryRegion>,
140    ) -> anyhow::Result<()> {
141        self.shared_memory_region = Some(region.clone());
142        Ok(())
143    }
144
145    async fn start_queue(
146        &mut self,
147        idx: u16,
148        resources: QueueResources,
149        features: &VirtioDeviceFeatures,
150        initial_state: Option<QueueState>,
151    ) -> anyhow::Result<()> {
152        let mut tc = TaskControl::new(VirtioFsWorker {
153            fs: self.fs.clone(),
154            shared_memory_region: self.shared_memory_region.clone(),
155            shared_memory_size: self.shmem_size,
156            notify_corruption: self.notify_corruption.clone(),
157        });
158
159        let queue_event = PolledWait::new(&self.driver, resources.event)
160            .context("failed to create polled wait")?;
161        let queue = VirtioQueue::new(
162            features.clone(),
163            resources.params,
164            resources.guest_memory.clone(),
165            resources.notify,
166            queue_event,
167            initial_state,
168        )
169        .context("failed to create virtio queue")?;
170
171        tc.insert(
172            self.driver.clone(),
173            &*self.task_name,
174            VirtioFsQueue {
175                queue,
176                mem: resources.guest_memory,
177            },
178        );
179        tc.start();
180
181        let idx = idx as usize;
182        if idx >= self.workers.len() {
183            self.workers.resize_with(idx + 1, || {
184                TaskControl::new(VirtioFsWorker {
185                    fs: self.fs.clone(),
186                    shared_memory_region: None,
187                    shared_memory_size: 0,
188                    notify_corruption: self.notify_corruption.clone(),
189                })
190            });
191        }
192        self.workers[idx] = tc;
193        Ok(())
194    }
195
196    async fn stop_queue(&mut self, idx: u16) -> Option<QueueState> {
197        let idx = idx as usize;
198        if idx >= self.workers.len() || !self.workers[idx].has_state() {
199            return None;
200        }
201        self.workers[idx].stop().await;
202        let state = self.workers[idx].remove().queue.queue_state();
203        Some(state)
204    }
205
206    async fn reset(&mut self) {
207        self.workers.clear();
208        if let Some(region) = &self.shared_memory_region {
209            if let Err(e) = region.unmap(0, self.shmem_size as usize) {
210                tracing::error!(
211                    error = &e as &dyn std::error::Error,
212                    "failed to unmap DAX region on reset"
213                );
214            }
215        }
216        self.shared_memory_region = None;
217        self.fs.destroy();
218    }
219}
220
221struct VirtioFsWorker {
222    fs: Arc<fuse::Session>,
223    shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
224    shared_memory_size: u64,
225    notify_corruption: Arc<dyn Fn() + Sync + Send>,
226}
227
228struct VirtioFsQueue {
229    queue: VirtioQueue,
230    mem: GuestMemory,
231}
232
233impl AsyncRun<VirtioFsQueue> for VirtioFsWorker {
234    async fn run(
235        &mut self,
236        stop: &mut StopTask<'_>,
237        state: &mut VirtioFsQueue,
238    ) -> Result<(), Cancelled> {
239        loop {
240            let work = stop.until_stopped(state.queue.next()).await?;
241            let Some(work) = work else { break };
242            match work {
243                Ok(work) => {
244                    let bytes = process_virtiofs_request(self, &state.mem, &work);
245                    state.queue.complete(work, bytes);
246                }
247                Err(err) => {
248                    tracing::error!(
249                        error = &err as &dyn std::error::Error,
250                        "Failed processing queue"
251                    );
252                    break;
253                }
254            }
255        }
256        Ok(())
257    }
258}
259
260fn process_virtiofs_request(
261    worker: &VirtioFsWorker,
262    mem: &GuestMemory,
263    work: &VirtioQueueCallbackWork,
264) -> u32 {
265    // Parse the request.
266    let reader = VirtioPayloadReader::new(mem, work);
267    let request = match fuse::Request::new(reader) {
268        Ok(request) => request,
269        Err(e) => {
270            tracing::error!(
271                error = &e as &dyn std::error::Error,
272                "[virtiofs] Invalid FUSE message, error"
273            );
274            // Often this will result in the guest failing the device as there is no response to a request.
275            (worker.notify_corruption)();
276            // This only happens if even the header couldn't be parsed, so there's no way
277            // to send an error reply since the request's unique ID isn't known.
278            return 0;
279        }
280    };
281
282    // Dispatch to the file system. The sender writes the reply into guest
283    // memory but does not complete the descriptor—completion happens once,
284    // after dispatch returns. For FUSE no-reply operations (Forget,
285    // BatchForget, Destroy), send() is never called and bytes_written
286    // stays 0.
287    let mut sender = VirtioReplySender {
288        work,
289        mem,
290        bytes_written: 0,
291    };
292    let mapper = worker
293        .shared_memory_region
294        .as_ref()
295        .map(|shared_memory_region| VirtioMapper {
296            region: shared_memory_region.as_ref(),
297            size: worker.shared_memory_size,
298        });
299    worker.fs.dispatch(
300        request,
301        &mut sender,
302        mapper.as_ref().map(|x| x as &dyn fuse::Mapper),
303    );
304    sender.bytes_written
305}
306/// An implementation of `ReplySender` for virtio payload.
307///
308/// Writes the FUSE reply into guest memory and records the byte count.
309/// Does not complete the descriptor—the caller is responsible for that.
310struct VirtioReplySender<'a> {
311    work: &'a VirtioQueueCallbackWork,
312    mem: &'a GuestMemory,
313    bytes_written: u32,
314}
315
316impl fuse::ReplySender for VirtioReplySender<'_> {
317    fn send(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<()> {
318        let mut writer = VirtioPayloadWriter::new(self.mem, self.work);
319        let mut size = 0;
320
321        // Write all the slices to the payload buffers.
322        // N.B. write_vectored isn't used because it isn't guaranteed to write all the data.
323        for buf in bufs {
324            writer.write_all(buf)?;
325            size += buf.len();
326        }
327
328        self.bytes_written = size as u32;
329        Ok(())
330    }
331}
332
333struct VirtioMapper<'a> {
334    region: &'a dyn MappedMemoryRegion,
335    size: u64,
336}
337
338impl fuse::Mapper for VirtioMapper<'_> {
339    fn map(
340        &self,
341        offset: u64,
342        file: fuse::FileRef<'_>,
343        file_offset: u64,
344        len: u64,
345        writable: bool,
346    ) -> lx::Result<()> {
347        let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
348        let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
349        self.region.map(offset, &file, file_offset, len, writable)?;
350        Ok(())
351    }
352
353    fn unmap(&self, offset: u64, len: u64) -> lx::Result<()> {
354        let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
355        let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
356        self.region.unmap(offset, len)?;
357        Ok(())
358    }
359
360    fn clear(&self) {
361        let result = self.region.unmap(0, self.size as usize);
362        if let Err(result) = result {
363            tracing::error!(
364                error = &result as &dyn std::error::Error,
365                "Failed to unmap shared memory"
366            );
367        }
368    }
369}