1#![expect(dead_code)]
5
6use crate::virtio_util::VirtioPayloadReader;
7use crate::virtio_util::VirtioPayloadWriter;
8use async_trait::async_trait;
9use guestmem::GuestMemory;
10use guestmem::MappedMemoryRegion;
11use pal_async::task::Spawn;
12use std::io;
13use std::io::Write;
14use std::sync::Arc;
15use task_control::TaskControl;
16use virtio::DeviceTraits;
17use virtio::DeviceTraitsSharedMemory;
18use virtio::Resources;
19use virtio::VirtioDevice;
20use virtio::VirtioQueueCallbackWork;
21use virtio::VirtioQueueState;
22use virtio::VirtioQueueWorker;
23use virtio::VirtioQueueWorkerContext;
24use vmcore::vm_task::VmTaskDriver;
25use vmcore::vm_task::VmTaskDriverSource;
26use zerocopy::Immutable;
27use zerocopy::IntoBytes;
28use zerocopy::KnownLayout;
29
/// Virtio device ID identifying a file system (virtio-fs) device (decimal 26).
const VIRTIO_DEVICE_TYPE_FS: u16 = 0x1a;
31
/// Device-specific configuration space for the virtio-fs device.
///
/// The struct is viewed as raw bytes (`IntoBytes`) when servicing config-space reads
/// (`read_registers_u32`), so with `#[repr(C)]` the field order and sizes below define
/// the register layout seen by the guest. Do not reorder or resize fields.
#[repr(C)]
#[derive(IntoBytes, Immutable, KnownLayout)]
struct VirtioFsDeviceConfig {
    // Mount tag bytes; `new` zero-initializes the array and copies in a prefix of the tag,
    // so unused trailing bytes are NUL.
    tag: [u8; 36],
    // Number of request queues; `new` sets this to 1.
    num_request_queues: u32,
}
39
/// A virtio file system device that serves a FUSE file system to the guest.
pub struct VirtioFsDevice {
    // Device name, built as `virtio-fs-<tag>` by `new`.
    name: Box<str>,

    // Driver used to run queue worker tasks and the shutdown task spawned by `disable`.
    driver: VmTaskDriver,
    // Config space exposed to the guest through register reads.
    config: VirtioFsDeviceConfig,
    // Guest memory handed to each queue worker.
    memory: GuestMemory,
    // FUSE session shared by every queue worker.
    fs: Arc<fuse::Session>,
    // One running task per enabled queue; filled by `enable`, emptied by `disable`.
    workers: Vec<TaskControl<VirtioQueueWorker, VirtioQueueState>>,
    // Each worker gets a listener on this event; `disable` notifies all listeners.
    exit_event: event_listener::Event,
    // Size in bytes reported for shared memory region id 0 in `traits`.
    shmem_size: u64,
    // Callback invoked when a queue delivers a message that is not a valid FUSE request.
    notify_corruption: Arc<dyn Fn() + Sync + Send>,
}
53
54impl VirtioFsDevice {
55 pub fn new<Fs>(
57 driver_source: &VmTaskDriverSource,
58 tag: &str,
59 fs: Fs,
60 memory: GuestMemory,
61 shmem_size: u64,
62 notify_corruption: Option<Arc<dyn Fn() + Sync + Send>>,
63 ) -> Self
64 where
65 Fs: 'static + fuse::Fuse + Send + Sync,
66 {
67 let mut config = VirtioFsDeviceConfig {
68 tag: [0; 36],
69 num_request_queues: 1,
70 };
71
72 let notify_corruption = if let Some(notify) = notify_corruption {
73 notify
74 } else {
75 Arc::new(|| {})
76 };
77
78 let length = std::cmp::min(tag.len(), config.tag.len());
80 config.tag[..length].copy_from_slice(&tag.as_bytes()[..length]);
81
82 Self {
83 name: format!("virtio-fs-{}", tag).into(),
84 driver: driver_source.simple(),
85 config,
86 memory,
87 fs: Arc::new(fuse::Session::new(fs)),
88 workers: Vec::new(),
89 exit_event: event_listener::Event::new(),
90 shmem_size,
91 notify_corruption,
92 }
93 }
94}
95
96impl VirtioDevice for VirtioFsDevice {
97 fn traits(&self) -> DeviceTraits {
98 DeviceTraits {
99 device_id: VIRTIO_DEVICE_TYPE_FS,
100 device_features: 0,
101 max_queues: 2,
102 device_register_length: self.config.as_bytes().len() as u32,
103 shared_memory: DeviceTraitsSharedMemory {
104 id: 0,
105 size: self.shmem_size,
106 },
107 }
108 }
109
110 fn read_registers_u32(&self, offset: u16) -> u32 {
111 let offset = offset as usize;
112 let config = self.config.as_bytes();
113 if offset < config.len() {
114 u32::from_le_bytes(
115 config[offset..offset + 4]
116 .try_into()
117 .expect("Incorrect length"),
118 )
119 } else {
120 0
121 }
122 }
123
124 fn write_registers_u32(&mut self, offset: u16, val: u32) {
125 tracing::warn!(offset, val, "[virtiofs] Unknown write",);
126 }
127
128 fn enable(&mut self, resources: Resources) {
129 self.workers = resources
130 .queues
131 .into_iter()
132 .filter_map(|queue_resources| {
133 if !queue_resources.params.enable {
134 return None;
135 }
136 let worker = VirtioFsWorker {
137 fs: self.fs.clone(),
138 mem: self.memory.clone(),
139 shared_memory_region: resources.shared_memory_region.clone(),
140 shared_memory_size: resources.shared_memory_size,
141 notify_corruption: self.notify_corruption.clone(),
142 };
143 let worker = VirtioQueueWorker::new(self.driver.clone(), Box::new(worker));
144 Some(worker.into_running_task(
145 "virtiofs-virtio-queue".to_string(),
146 self.memory.clone(),
147 resources.features,
148 queue_resources,
149 self.exit_event.listen(),
150 ))
151 })
152 .collect();
153 }
154
155 fn disable(&mut self) {
156 self.exit_event.notify(usize::MAX);
157 let mut workers = self.workers.drain(..).collect::<Vec<_>>();
158 self.driver
159 .spawn("shutdown-virtiofs-queues".to_owned(), async move {
160 futures::future::join_all(workers.iter_mut().map(async |worker| {
161 worker.stop().await;
162 }))
163 .await;
164 })
165 .detach();
166 }
167}
168
/// Per-queue handler that parses FUSE requests out of descriptor chains and dispatches
/// them to the shared FUSE session.
struct VirtioFsWorker {
    // FUSE session shared with the device and all other queue workers.
    fs: Arc<fuse::Session>,
    // Guest memory used to read request payloads and write replies.
    mem: GuestMemory,
    // Shared memory region for file mappings, if the transport supplied one.
    shared_memory_region: Option<Arc<dyn MappedMemoryRegion>>,
    // Size in bytes of `shared_memory_region`.
    shared_memory_size: u64,
    // Called when a descriptor chain does not contain a valid FUSE message.
    notify_corruption: Arc<dyn Fn() + Sync + Send>,
}
176
177#[async_trait]
178impl VirtioQueueWorkerContext for VirtioFsWorker {
179 async fn process_work(&mut self, work: anyhow::Result<VirtioQueueCallbackWork>) -> bool {
180 if let Err(err) = work {
181 tracing::error!(
182 error = err.as_ref() as &dyn std::error::Error,
183 "Failed processing queue"
184 );
185 return false;
186 }
187
188 let mut work = work.unwrap();
189 let reader = VirtioPayloadReader::new(&self.mem, &work);
191 let request = match fuse::Request::new(reader) {
192 Ok(request) => request,
193 Err(e) => {
194 tracing::error!(
195 error = &e as &dyn std::error::Error,
196 "[virtiofs] Invalid FUSE message, error"
197 );
198 (self.notify_corruption)();
200 work.complete(0);
203 return true;
204 }
205 };
206
207 let mut sender = VirtioReplySender {
209 work,
210 mem: &self.mem,
211 };
212 let mapper = self
213 .shared_memory_region
214 .as_ref()
215 .map(|shared_memory_region| VirtioMapper {
216 region: shared_memory_region.as_ref(),
217 size: self.shared_memory_size,
218 });
219 self.fs.dispatch(
220 request,
221 &mut sender,
222 mapper.as_ref().map(|x| x as &dyn fuse::Mapper),
223 );
224 true
225 }
226}
227struct VirtioReplySender<'a> {
229 work: VirtioQueueCallbackWork,
230 mem: &'a GuestMemory,
231}
232
233impl fuse::ReplySender for VirtioReplySender<'_> {
234 fn send(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<()> {
235 let mut writer = VirtioPayloadWriter::new(self.mem, &self.work);
236 let mut size = 0;
237
238 for buf in bufs {
241 writer.write_all(buf)?;
242 size += buf.len();
243 }
244
245 self.work.complete(size as u32);
246 Ok(())
247 }
248}
249
250struct VirtioMapper<'a> {
251 region: &'a dyn MappedMemoryRegion,
252 size: u64,
253}
254
255impl fuse::Mapper for VirtioMapper<'_> {
256 fn map(
257 &self,
258 offset: u64,
259 file: fuse::FileRef<'_>,
260 file_offset: u64,
261 len: u64,
262 writable: bool,
263 ) -> lx::Result<()> {
264 let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
265 let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
266 self.region.map(offset, &file, file_offset, len, writable)?;
267 Ok(())
268 }
269
270 fn unmap(&self, offset: u64, len: u64) -> lx::Result<()> {
271 let offset = offset.try_into().map_err(|_| lx::Error::EINVAL)?;
272 let len = len.try_into().map_err(|_| lx::Error::EINVAL)?;
273 self.region.unmap(offset, len)?;
274 Ok(())
275 }
276
277 fn clear(&self) {
278 let result = self.region.unmap(0, self.size as usize);
279 if let Err(result) = result {
280 tracing::error!(
281 error = &result as &dyn std::error::Error,
282 "Failed to unmap shared memory"
283 );
284 }
285 }
286}