1use crate::virtio_util::VirtioPayloadReader;
5use crate::virtio_util::VirtioPayloadWriter;
6use anyhow::Context as _;
7use futures::StreamExt;
8use guestmem::GuestMemory;
9use guestmem::MappedMemoryRegion;
10use inspect::InspectMut;
11use pal_async::wait::PolledWait;
12use std::io;
13use std::io::Write;
14use std::sync::Arc;
15use task_control::AsyncRun;
16use task_control::Cancelled;
17use task_control::StopTask;
18use task_control::TaskControl;
19use virtio::DeviceTraits;
20use virtio::DeviceTraitsSharedMemory;
21use virtio::QueueResources;
22use virtio::VirtioDevice;
23use virtio::VirtioQueue;
24use virtio::VirtioQueueCallbackWork;
25use virtio::queue::QueueState;
26use virtio::spec::VirtioDeviceFeatures;
27use vmcore::vm_task::VmTaskDriver;
28use vmcore::vm_task::VmTaskDriverSource;
29use zerocopy::Immutable;
30use zerocopy::IntoBytes;
31use zerocopy::KnownLayout;
32
33#[repr(C)]
35#[derive(IntoBytes, Immutable, KnownLayout)]
36struct VirtioFsDeviceConfig {
37 tag: [u8; 36],
38 num_request_queues: u32,
39}
40
41#[derive(InspectMut)]
43pub struct VirtioFsDevice {
44 task_name: Box<str>,
45 driver: VmTaskDriver,
46 #[inspect(skip)]
47 config: VirtioFsDeviceConfig,
48 #[inspect(skip)]
49 fs: Arc<fuse::Session>,
50 #[inspect(skip)]
51 workers: Vec<TaskControl<VirtioFsWorker, VirtioFsQueue>>,
52 shmem_size: u64,
53 #[inspect(skip)]
54 shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
55 #[inspect(skip)]
56 notify_corruption: Arc<dyn Fn() + Sync + Send>,
57}
58
59impl VirtioFsDevice {
60 pub fn new<Fs>(
62 driver_source: &VmTaskDriverSource,
63 tag: &str,
64 fs: Fs,
65 shmem_size: u64,
66 notify_corruption: Option<Arc<dyn Fn() + Sync + Send>>,
67 ) -> Self
68 where
69 Fs: 'static + fuse::Fuse + Send + Sync,
70 {
71 let mut config = VirtioFsDeviceConfig {
72 tag: [0; 36],
73 num_request_queues: 1,
74 };
75
76 let notify_corruption = if let Some(notify) = notify_corruption {
77 notify
78 } else {
79 Arc::new(|| {})
80 };
81
82 let length = std::cmp::min(tag.len(), config.tag.len());
84 config.tag[..length].copy_from_slice(&tag.as_bytes()[..length]);
85
86 Self {
87 task_name: format!("virtiofs-{}", tag).into(),
88 driver: driver_source.simple(),
89 config,
90 fs: Arc::new(fuse::Session::new(fs)),
91 workers: Vec::new(),
92 shmem_size,
93 shared_memory_region: None,
94 notify_corruption,
95 }
96 }
97}
98
99impl VirtioDevice for VirtioFsDevice {
100 fn traits(&self) -> DeviceTraits {
101 DeviceTraits {
102 device_id: virtio::spec::VirtioDeviceType::FS,
103 device_features: VirtioDeviceFeatures::new()
104 .with_bank0(
105 virtio::spec::VirtioDeviceFeaturesBank0::new()
106 .with_ring_event_idx(true)
107 .with_ring_indirect_desc(true),
108 )
109 .with_bank1(virtio::spec::VirtioDeviceFeaturesBank1::new().with_ring_packed(true)),
110 max_queues: 2,
111 device_register_length: self.config.as_bytes().len() as u32,
112 shared_memory: DeviceTraitsSharedMemory {
113 id: 0,
114 size: self.shmem_size,
115 },
116 }
117 }
118
119 async fn read_registers_u32(&mut self, offset: u16) -> u32 {
120 let offset = offset as usize;
121 let config = self.config.as_bytes();
122 if offset < config.len() {
123 u32::from_le_bytes(
124 config[offset..offset + 4]
125 .try_into()
126 .expect("Incorrect length"),
127 )
128 } else {
129 0
130 }
131 }
132
133 async fn write_registers_u32(&mut self, offset: u16, val: u32) {
134 tracing::warn!(offset, val, "[virtiofs] Unknown write",);
135 }
136
137 fn set_shared_memory_region(
138 &mut self,
139 region: &Arc<dyn MappedMemoryRegion>,
140 ) -> anyhow::Result<()> {
141 self.shared_memory_region = Some(region.clone());
142 Ok(())
143 }
144
145 async fn start_queue(
146 &mut self,
147 idx: u16,
148 resources: QueueResources,
149 features: &VirtioDeviceFeatures,
150 initial_state: Option<QueueState>,
151 ) -> anyhow::Result<()> {
152 let mut tc = TaskControl::new(VirtioFsWorker {
153 fs: self.fs.clone(),
154 shared_memory_region: self.shared_memory_region.clone(),
155 shared_memory_size: self.shmem_size,
156 notify_corruption: self.notify_corruption.clone(),
157 });
158
159 let queue_event = PolledWait::new(&self.driver, resources.event)
160 .context("failed to create polled wait")?;
161 let queue = VirtioQueue::new(
162 features.clone(),
163 resources.params,
164 resources.guest_memory.clone(),
165 resources.notify,
166 queue_event,
167 initial_state,
168 )
169 .context("failed to create virtio queue")?;
170
171 tc.insert(
172 self.driver.clone(),
173 &*self.task_name,
174 VirtioFsQueue {
175 queue,
176 mem: resources.guest_memory,
177 },
178 );
179 tc.start();
180
181 let idx = idx as usize;
182 if idx >= self.workers.len() {
183 self.workers.resize_with(idx + 1, || {
184 TaskControl::new(VirtioFsWorker {
185 fs: self.fs.clone(),
186 shared_memory_region: None,
187 shared_memory_size: 0,
188 notify_corruption: self.notify_corruption.clone(),
189 })
190 });
191 }
192 self.workers[idx] = tc;
193 Ok(())
194 }
195
196 async fn stop_queue(&mut self, idx: u16) -> Option<QueueState> {
197 let idx = idx as usize;
198 if idx >= self.workers.len() || !self.workers[idx].has_state() {
199 return None;
200 }
201 self.workers[idx].stop().await;
202 let state = self.workers[idx].remove().queue.queue_state();
203 Some(state)
204 }
205
206 async fn reset(&mut self) {
207 self.workers.clear();
208 if let Some(region) = &self.shared_memory_region {
209 if let Err(e) = region.unmap(0, self.shmem_size as usize) {
210 tracing::error!(
211 error = &e as &dyn std::error::Error,
212 "failed to unmap DAX region on reset"
213 );
214 }
215 }
216 self.shared_memory_region = None;
217 self.fs.destroy();
218 }
219}
220
221struct VirtioFsWorker {
222 fs: Arc<fuse::Session>,
223 shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
224 shared_memory_size: u64,
225 notify_corruption: Arc<dyn Fn() + Sync + Send>,
226}
227
228struct VirtioFsQueue {
229 queue: VirtioQueue,
230 mem: GuestMemory,
231}
232
233impl AsyncRun<VirtioFsQueue> for VirtioFsWorker {
234 async fn run(
235 &mut self,
236 stop: &mut StopTask<'_>,
237 state: &mut VirtioFsQueue,
238 ) -> Result<(), Cancelled> {
239 loop {
240 let work = stop.until_stopped(state.queue.next()).await?;
241 let Some(work) = work else { break };
242 match work {
243 Ok(work) => {
244 let bytes = process_virtiofs_request(self, &state.mem, &work);
245 state.queue.complete(work, bytes);
246 }
247 Err(err) => {
248 tracing::error!(
249 error = &err as &dyn std::error::Error,
250 "Failed processing queue"
251 );
252 break;
253 }
254 }
255 }
256 Ok(())
257 }
258}
259
260fn process_virtiofs_request(
261 worker: &VirtioFsWorker,
262 mem: &GuestMemory,
263 work: &VirtioQueueCallbackWork,
264) -> u32 {
265 let reader = VirtioPayloadReader::new(mem, work);
267 let request = match fuse::Request::new(reader) {
268 Ok(request) => request,
269 Err(e) => {
270 tracing::error!(
271 error = &e as &dyn std::error::Error,
272 "[virtiofs] Invalid FUSE message, error"
273 );
274 (worker.notify_corruption)();
276 return 0;
279 }
280 };
281
282 let mut sender = VirtioReplySender {
288 work,
289 mem,
290 bytes_written: 0,
291 };
292 let mapper = worker
293 .shared_memory_region
294 .as_ref()
295 .map(|shared_memory_region| VirtioMapper {
296 region: shared_memory_region.as_ref(),
297 size: worker.shared_memory_size,
298 });
299 worker.fs.dispatch(
300 request,
301 &mut sender,
302 mapper.as_ref().map(|x| x as &dyn fuse::Mapper),
303 );
304 sender.bytes_written
305}
306struct VirtioReplySender<'a> {
311 work: &'a VirtioQueueCallbackWork,
312 mem: &'a GuestMemory,
313 bytes_written: u32,
314}
315
316impl fuse::ReplySender for VirtioReplySender<'_> {
317 fn send(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<()> {
318 let mut writer = VirtioPayloadWriter::new(self.mem, self.work);
319 let mut size = 0;
320
321 for buf in bufs {
324 writer.write_all(buf)?;
325 size += buf.len();
326 }
327
328 self.bytes_written = size as u32;
329 Ok(())
330 }
331}
332
333struct VirtioMapper<'a> {
334 region: &'a dyn MappedMemoryRegion,
335 size: u64,
336}
337
338impl fuse::Mapper for VirtioMapper<'_> {
339 fn map(
340 &self,
341 offset: u64,
342 file: fuse::FileRef<'_>,
343 file_offset: u64,
344 len: u64,
345 writable: bool,
346 ) -> lx::Result<()> {
347 let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
348 let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
349 self.region.map(offset, &file, file_offset, len, writable)?;
350 Ok(())
351 }
352
353 fn unmap(&self, offset: u64, len: u64) -> lx::Result<()> {
354 let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
355 let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
356 self.region.unmap(offset, len)?;
357 Ok(())
358 }
359
360 fn clear(&self) {
361 let result = self.region.unmap(0, self.size as usize);
362 if let Err(result) = result {
363 tracing::error!(
364 error = &result as &dyn std::error::Error,
365 "Failed to unmap shared memory"
366 );
367 }
368 }
369}