#![forbid(unsafe_code)]
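
//! RAM-backed disk layer implementation.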

pub mod resolver;

use anyhow::Context;
use disk_backend::Disk;
use disk_backend::DiskError;
use disk_backend::UnmapBehavior;
use disk_layered::DiskLayer;
use disk_layered::LayerAttach;
use disk_layered::LayerConfiguration;
use disk_layered::LayerIo;
use disk_layered::LayeredDisk;
use disk_layered::SectorMarker;
use disk_layered::WriteNoOverwrite;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use inspect::Inspect;
use parking_lot::RwLock;
use scsi_buffers::RequestBuffers;
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt;
use std::fmt::Debug;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use thiserror::Error;

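/// A RAM disk layer that is lazily sized at attach time to match the layer
/// beneath it.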
#[non_exhaustive]
pub struct LazyRamDiskLayer {}

impl LazyRamDiskLayer {
    /// Creates a new lazily sized RAM disk layer.
    pub fn new() -> Self {
        Self {}
    }
}

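/// A disk layer backed by RAM, storing only the sectors that have been
/// written.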
#[derive(Inspect)]
#[inspect(extra = "Self::inspect_extra")]
pub struct RamDiskLayer {
    #[inspect(flatten)]
    state: RwLock<RamState>,
    #[inspect(skip)]
    sector_count: AtomicU64,
    #[inspect(skip)]
    resize_event: event_listener::Event,
}

#[derive(Inspect)]
struct RamState {
    /// The written sectors, keyed by sector number.
    #[inspect(skip)]
    data: BTreeMap<u64, Sector>,
    #[inspect(skip)]
    sector_count: u64,
    /// Sectors at or beyond this point are known to be zero even when they
    /// are absent from `data`.
    zero_after: u64,
}

impl RamDiskLayer {
    fn inspect_extra(&self, resp: &mut inspect::Response<'_>) {
        resp.field_with("committed_size", || {
            self.state.read().data.len() * size_of::<Sector>()
        })
        .field_mut_with("sector_count", |new_count| {
            // Allow the layer to be resized through the inspect tree.
            if let Some(new_count) = new_count {
                self.resize(new_count.parse().context("invalid sector count")?)?;
            }
            anyhow::Ok(self.sector_count())
        });
    }
}

impl Debug for RamDiskLayer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RamDiskLayer")
            .field("sector_count", &self.sector_count)
            .finish()
    }
}

/// An error that can occur when creating a RAM disk layer.
#[derive(Error, Debug)]
pub enum Error {
    /// The disk size is not a multiple of the sector size.
    #[error("disk size {disk_size:#x} is not a multiple of the sector size {sector_size}")]
    NotSectorMultiple {
        /// The requested disk size, in bytes.
        disk_size: u64,
        /// The sector size, in bytes.
        sector_size: u32,
    },
    /// The disk has no sectors.
    #[error("disk has no sectors")]
    EmptyDisk,
}

/// A single sector's worth of data.
struct Sector([u8; 512]);

const SECTOR_SIZE: u32 = 512;

impl RamDiskLayer {
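    /// Makes a new RAM disk layer of `size` bytes.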
    pub fn new(size: u64) -> Result<Self, Error> {
        let sector_count = {
            if size == 0 {
                return Err(Error::EmptyDisk);
            }
            if size % SECTOR_SIZE as u64 != 0 {
                return Err(Error::NotSectorMultiple {
                    disk_size: size,
                    sector_size: SECTOR_SIZE,
                });
            }
            size / SECTOR_SIZE as u64
        };
        Ok(Self {
            state: RwLock::new(RamState {
                data: BTreeMap::new(),
                sector_count,
                zero_after: sector_count,
            }),
            sector_count: sector_count.into(),
            resize_event: Default::default(),
        })
    }

    fn resize(&self, new_sector_count: u64) -> anyhow::Result<()> {
        if new_sector_count == 0 {
            anyhow::bail!("invalid sector count");
        }
        // Update the state under the lock, then drop any removed sectors
        // outside the lock.
        let _removed = {
            let mut state = self.state.write();
            // The truncated sectors must read as zero if the disk is
            // subsequently grown again.
            state.zero_after = new_sector_count.min(state.zero_after);
            state.sector_count = new_sector_count;
            self.sector_count.store(new_sector_count, Ordering::Relaxed);
            // Split off and return the sectors that are now beyond the end
            // of the disk.
            state.data.split_off(&new_sector_count)
        };
        self.resize_event.notify(usize::MAX);
        Ok(())
    }

    /// Writes `buffers` to the layer starting at `sector`. If `overwrite` is
    /// false, sectors already present in this layer are left untouched.
    fn write_maybe_overwrite(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        overwrite: bool,
    ) -> Result<(), DiskError> {
        let count = buffers.len() / SECTOR_SIZE as usize;
        tracing::trace!(sector, count, "write");
        let mut state = self.state.write();
        if sector + count as u64 > state.sector_count {
            return Err(DiskError::IllegalBlock);
        }
        for i in 0..count {
            let cur = i + sector as usize;
            let buf = buffers.subrange(i * SECTOR_SIZE as usize, SECTOR_SIZE as usize);
            let mut reader = buf.reader();
            match state.data.entry(cur as u64) {
                Entry::Vacant(entry) => {
                    entry.insert(Sector(reader.read_plain()?));
                }
                Entry::Occupied(mut entry) => {
                    if overwrite {
                        reader.read(&mut entry.get_mut().0)?;
                    }
                }
            }
        }
        Ok(())
    }
}

impl LayerAttach for LazyRamDiskLayer {
    type Error = Error;
    type Layer = RamDiskLayer;

    async fn attach(
        self,
        lower_layer_metadata: Option<disk_layered::DiskLayerMetadata>,
    ) -> Result<Self::Layer, Self::Error> {
        // Size the layer to match the layer beneath it; fail if this is the
        // bottommost layer.
        RamDiskLayer::new(
            lower_layer_metadata
                .map(|x| x.sector_count * x.sector_size as u64)
                .ok_or(Error::EmptyDisk)?,
        )
    }
}

impl LayerIo for RamDiskLayer {
    fn layer_type(&self) -> &str {
        "ram"
    }

    fn sector_count(&self) -> u64 {
        self.sector_count.load(Ordering::Relaxed)
    }

    fn sector_size(&self) -> u32 {
        SECTOR_SIZE
    }

    fn is_logically_read_only(&self) -> bool {
        false
    }

    fn disk_id(&self) -> Option<[u8; 16]> {
        None
    }

    fn physical_sector_size(&self) -> u32 {
        SECTOR_SIZE
    }

    fn is_fua_respected(&self) -> bool {
        true
    }

    async fn read(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        mut marker: SectorMarker<'_>,
    ) -> Result<(), DiskError> {
        let count = (buffers.len() / SECTOR_SIZE as usize) as u64;
        let end = sector + count;
        tracing::trace!(sector, count, "read");
        let state = self.state.read();
        if end > state.sector_count {
            return Err(DiskError::IllegalBlock);
        }
        let mut range = state.data.range(sector..end);
        let mut last = sector;
        while last < end {
            let r = range.next();
            let next = r.map(|(&s, _)| s).unwrap_or(end);
            if next > last && next > state.zero_after {
                // The sectors in `last..next` are missing from this layer,
                // but those at or beyond `zero_after` are known to be zero:
                // fill them and mark them as present.
                let zero_start = last.max(state.zero_after);
                let zero_count = next - zero_start;
                let offset = (zero_start - sector) as usize * SECTOR_SIZE as usize;
                let len = zero_count as usize * SECTOR_SIZE as usize;
                buffers.subrange(offset, len).writer().zero(len)?;
                marker.set_range(zero_start..next);
            }
            if let Some((&s, buf)) = r {
                let offset = (s - sector) as usize * SECTOR_SIZE as usize;
                buffers
                    .subrange(offset, SECTOR_SIZE as usize)
                    .writer()
                    .write(&buf.0)?;

                marker.set(s);
            }
            last = next;
        }
        Ok(())
    }

    async fn write(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        _fua: bool,
    ) -> Result<(), DiskError> {
        self.write_maybe_overwrite(buffers, sector, true)
    }

    fn write_no_overwrite(&self) -> Option<impl WriteNoOverwrite> {
        Some(self)
    }

    async fn sync_cache(&self) -> Result<(), DiskError> {
        tracing::trace!("sync_cache");
        Ok(())
    }

    async fn wait_resize(&self, sector_count: u64) -> u64 {
        loop {
            let listen = self.resize_event.listen();
            let current = self.sector_count();
            if current != sector_count {
                break current;
            }
            listen.await;
        }
    }

    async fn unmap(
        &self,
        sector_offset: u64,
        sector_count: u64,
        _block_level_only: bool,
        next_is_zero: bool,
    ) -> Result<(), DiskError> {
        tracing::trace!(sector_offset, sector_count, "unmap");
        let mut state = self.state.write();
        if sector_offset + sector_count > state.sector_count {
            return Err(DiskError::IllegalBlock);
        }
        if !next_is_zero {
            // The next layer is not known to be zero, so dropping sectors
            // from this layer would expose the next layer's contents. The
            // range can only be forgotten if it reaches `zero_after`, in
            // which case the zero region is extended down to cover it.
            if sector_offset + sector_count < state.zero_after {
                return Ok(());
            }
            state.zero_after = state.zero_after.min(sector_offset);
        }
        // Remove the unmapped sectors from the map.
        let mut next_sector = sector_offset;
        let end = sector_offset + sector_count;
        while next_sector < end {
            let Some((&sector, _)) = state.data.range_mut(next_sector..).next() else {
                break;
            };
            if sector >= end {
                break;
            }
            state.data.remove(&sector);
            next_sector = sector + 1;
        }
        Ok(())
    }

    fn unmap_behavior(&self) -> UnmapBehavior {
        // This layer zeroes the range when the next layer is zero but may
        // otherwise ignore the unmap, so the behavior is unspecified.
        UnmapBehavior::Unspecified
    }

    fn optimal_unmap_sectors(&self) -> u32 {
        1
    }
}

impl WriteNoOverwrite for RamDiskLayer {
    async fn write_no_overwrite(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
    ) -> Result<(), DiskError> {
        self.write_maybe_overwrite(buffers, sector, false)
    }
}

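/// Makes a RAM-backed disk of `size` bytes.
///
/// A minimal usage sketch (illustrative only; the `disklayer_ram` crate path
/// is an assumption, not taken from this file):
///
/// ```ignore
/// // A 1 MiB read-write RAM disk.
/// let disk = disklayer_ram::ram_disk(1024 * 1024, false)?;
/// ```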
pub fn ram_disk(size: u64, read_only: bool) -> anyhow::Result<Disk> {
    use futures::future::FutureExt;

    let disk = Disk::new(
        LayeredDisk::new(
            read_only,
            vec![LayerConfiguration {
                layer: DiskLayer::new(RamDiskLayer::new(size)?),
                write_through: false,
                read_cache: false,
            }],
        )
        .now_or_never()
        .expect("RamDiskLayer won't block")?,
    )?;

    Ok(disk)
}

#[cfg(test)]
mod tests {
    use super::RamDiskLayer;
    use super::SECTOR_SIZE;
    use disk_backend::DiskIo;
    use disk_layered::DiskLayer;
    use disk_layered::LayerConfiguration;
    use disk_layered::LayerIo;
    use disk_layered::LayeredDisk;
    use guestmem::GuestMemory;
    use pal_async::async_test;
    use scsi_buffers::OwnedRequestBuffers;
    use test_with_tracing::test;
    use zerocopy::IntoBytes;

    const SECTOR_U64: u64 = SECTOR_SIZE as u64;
    const SECTOR_USIZE: usize = SECTOR_SIZE as usize;

    /// Checks that `count` sectors at `start` in guest memory contain the
    /// word pattern written by `write`/`write_layer` for disk `sector` with
    /// the given `high` byte.
    fn check(mem: &GuestMemory, sector: u64, start: usize, count: usize, high: u8) {
        let mut buf = vec![0u32; count * SECTOR_USIZE / 4];
        mem.read_at(start as u64 * SECTOR_U64, buf.as_mut_bytes())
            .unwrap();
        for (i, &b) in buf.iter().enumerate() {
            let offset = sector * SECTOR_U64 + i as u64 * 4;
            let expected = (offset as u32 / 4) | ((high as u32) << 24);
            assert!(
                b == expected,
                "at sector {}, word {}, got {:#x}, expected {:#x}",
                offset / SECTOR_U64,
                (offset % SECTOR_U64) / 4,
                b,
                expected
            );
        }
    }

    async fn read(mem: &GuestMemory, disk: &mut impl DiskIo, sector: u64, count: usize) {
        disk.read_vectored(
            &OwnedRequestBuffers::linear(0, count * SECTOR_USIZE, true).buffer(mem),
            sector,
        )
        .await
        .unwrap();
    }

    /// Writes a per-word pattern (tagged with `high` in the top byte)
    /// directly to a layer.
    async fn write_layer(
        mem: &GuestMemory,
        disk: &mut impl LayerIo,
        sector: u64,
        count: usize,
        high: u8,
    ) {
        let buf: Vec<_> = (sector * SECTOR_U64 / 4..(sector + count as u64) * SECTOR_U64 / 4)
            .map(|x| x as u32 | ((high as u32) << 24))
            .collect();
        let len = SECTOR_USIZE * count;
        mem.write_at(0, &buf.as_bytes()[..len]).unwrap();

        disk.write(
            &OwnedRequestBuffers::linear(0, len, false).buffer(mem),
            sector,
            false,
        )
        .await
        .unwrap();
    }

    /// Writes the same pattern as `write_layer`, but through the disk.
    async fn write(mem: &GuestMemory, disk: &mut impl DiskIo, sector: u64, count: usize, high: u8) {
        let buf: Vec<_> = (sector * SECTOR_U64 / 4..(sector + count as u64) * SECTOR_U64 / 4)
            .map(|x| x as u32 | ((high as u32) << 24))
            .collect();
        let len = SECTOR_USIZE * count;
        mem.write_at(0, &buf.as_bytes()[..len]).unwrap();

        disk.write_vectored(
            &OwnedRequestBuffers::linear(0, len, false).buffer(mem),
            sector,
            false,
        )
        .await
        .unwrap();
    }

    /// Builds a two-layer disk: an empty RAM layer over a RAM layer fully
    /// populated with the test pattern.
    async fn prep_disk(size: usize) -> (GuestMemory, LayeredDisk) {
        let guest_mem = GuestMemory::allocate(size);
        let mut lower = RamDiskLayer::new(size as u64).unwrap();
        write_layer(&guest_mem, &mut lower, 0, size / SECTOR_USIZE, 0).await;
        let upper = RamDiskLayer::new(size as u64).unwrap();
        let upper = LayeredDisk::new(
            false,
            Vec::from_iter([upper, lower].map(|layer| LayerConfiguration {
                layer: DiskLayer::new(layer),
                write_through: false,
                read_cache: false,
            })),
        )
        .await
        .unwrap();
        (guest_mem, upper)
    }

    #[async_test]
    async fn diff() {
        const SIZE: usize = 1024 * 1024;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        read(&guest_mem, &mut upper, 10, 2).await;
        check(&guest_mem, 10, 0, 2, 0);
        write(&guest_mem, &mut upper, 10, 2, 1).await;
        write(&guest_mem, &mut upper, 11, 1, 2).await;
        read(&guest_mem, &mut upper, 9, 5).await;
        check(&guest_mem, 9, 0, 1, 0);
        check(&guest_mem, 10, 1, 1, 1);
        check(&guest_mem, 11, 2, 1, 2);
        check(&guest_mem, 12, 3, 1, 0);
    }

    /// Resizes the RAM disk layer through its mutable inspect field.
    async fn resize(disk: &LayeredDisk, new_size: u64) {
        let inspect::ValueKind::Unsigned(v) =
            inspect::update("layers/0/backing/sector_count", &new_size.to_string(), disk)
                .await
                .unwrap()
                .kind
        else {
            panic!("bad inspect value")
        };
        assert_eq!(new_size, v);
    }

    #[async_test]
    async fn test_resize() {
        const SIZE: usize = 1024 * 1024;
        const SECTORS: usize = SIZE / SECTOR_USIZE;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        check(&guest_mem, 0, 0, SECTORS, 0);
        resize(&upper, SECTORS as u64 / 2).await;
        resize(&upper, SECTORS as u64).await;
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS / 2, 0);
        for s in SECTORS / 2..SECTORS {
            let mut buf = [0u8; SECTOR_USIZE];
            guest_mem.read_at(s as u64 * SECTOR_U64, &mut buf).unwrap();
            assert_eq!(buf, [0u8; SECTOR_USIZE]);
        }
    }

    #[async_test]
    async fn test_unmap() {
        const SIZE: usize = 1024 * 1024;
        const SECTORS: usize = SIZE / SECTOR_USIZE;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        upper.unmap(0, SECTORS as u64 - 1, false).await.unwrap();
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS, 0);
        upper
            .unmap(SECTORS as u64 / 2, SECTORS as u64 / 2, false)
            .await
            .unwrap();
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS / 2, 0);
        for s in SECTORS / 2..SECTORS {
            let mut buf = [0u8; SECTOR_USIZE];
            guest_mem.read_at(s as u64 * SECTOR_U64, &mut buf).unwrap();
            assert_eq!(buf, [0u8; SECTOR_USIZE]);
        }
    }
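
    /// A sketch added for illustration (not part of the original test
    /// suite): verifies that the no-overwrite path used by
    /// `WriteNoOverwrite` leaves already-written sectors untouched.
    #[async_test]
    async fn test_write_no_overwrite() {
        const SIZE: usize = 1024 * 1024;

        let guest_mem = GuestMemory::allocate(SIZE);
        let mut layer = RamDiskLayer::new(SIZE as u64).unwrap();
        // Populate sectors 0..4 with pattern 1.
        write_layer(&guest_mem, &mut layer, 0, 4, 1).await;
        // Attempt to overwrite the same sectors with pattern 2 through the
        // no-overwrite path.
        let buf: Vec<u32> = (0..4 * SECTOR_U64 / 4)
            .map(|x| x as u32 | (2u32 << 24))
            .collect();
        guest_mem.write_at(0, buf.as_bytes()).unwrap();
        layer
            .write_maybe_overwrite(
                &OwnedRequestBuffers::linear(0, 4 * SECTOR_USIZE, false).buffer(&guest_mem),
                0,
                false,
            )
            .unwrap();
        // Read back through a single-layer disk; pattern 1 must survive.
        let mut disk = LayeredDisk::new(
            false,
            vec![LayerConfiguration {
                layer: DiskLayer::new(layer),
                write_through: false,
                read_cache: false,
            }],
        )
        .await
        .unwrap();
        read(&guest_mem, &mut disk, 0, 4).await;
        check(&guest_mem, 0, 0, 4, 1);
    }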
}