1#![forbid(unsafe_code)]
7
8pub mod resolver;
9
10use anyhow::Context;
11use disk_backend::Disk;
12use disk_backend::DiskError;
13use disk_backend::UnmapBehavior;
14use disk_layered::DiskLayer;
15use disk_layered::LayerAttach;
16use disk_layered::LayerConfiguration;
17use disk_layered::LayerIo;
18use disk_layered::LayeredDisk;
19use disk_layered::SectorMarker;
20use disk_layered::WriteNoOverwrite;
21use guestmem::MemoryRead;
22use guestmem::MemoryWrite;
23use inspect::Inspect;
24use parking_lot::RwLock;
25use scsi_buffers::RequestBuffers;
26use std::collections::BTreeMap;
27use std::collections::btree_map::Entry;
28use std::fmt;
29use std::fmt::Debug;
30use std::sync::atomic::AtomicU64;
31use std::sync::atomic::Ordering;
32use thiserror::Error;
33
/// The default sector size, in bytes, used when no explicit sector size is
/// configured.
pub const DEFAULT_SECTOR_SIZE: u32 = 512;
36
/// A RAM disk layer that is created lazily at attach time, inheriting any
/// unspecified geometry from the layer below it.
pub struct LazyRamDiskLayer {
    // Total disk size in bytes; inherited from the lower layer when `None`.
    len: Option<u64>,
    // Sector size in bytes; inherited from the lower layer, falling back to
    // `DEFAULT_SECTOR_SIZE`, when `None`.
    sector_size: Option<u32>,
}
43
44impl LazyRamDiskLayer {
45 pub fn new() -> Self {
51 Self {
52 len: None,
53 sector_size: None,
54 }
55 }
56
57 pub fn with_len(mut self, len: u64) -> Self {
61 self.len = Some(len);
62 self
63 }
64
65 pub fn with_sector_size(mut self, sector_size: u32) -> Self {
70 self.sector_size = Some(sector_size);
71 self
72 }
73}
74
/// A disk layer backed entirely by RAM, storing written sectors in a
/// sector-indexed map.
#[derive(Inspect)]
#[inspect(extra = "Self::inspect_extra")]
pub struct RamDiskLayer {
    // The mutable layer state, behind a lock for concurrent IO.
    #[inspect(flatten)]
    state: RwLock<RamState>,
    // Cached copy of `state.sector_count`, kept in sync by `resize` so the
    // count can be read without taking the lock.
    #[inspect(skip)]
    sector_count: AtomicU64,
    // Signaled on every resize to wake `wait_resize` callers.
    #[inspect(skip)]
    resize_event: event_listener::Event,
    // Logical sector size in bytes (a power of two >= 512).
    sector_size: u32,
    // log2(sector_size); used to convert between bytes and sectors.
    #[inspect(skip)]
    sector_shift: u32,
}
89
/// The lock-protected mutable state of a `RamDiskLayer`.
#[derive(Inspect)]
struct RamState {
    // Populated sectors, keyed by sector number. Absent sectors either read
    // as zero (at or beyond `zero_after`) or are left for lower layers.
    #[inspect(skip)]
    data: BTreeMap<u64, Box<[u8]>>,
    // The current size of the disk, in sectors.
    #[inspect(skip)]
    sector_count: u64,
    // Sectors at or beyond this offset are known to be zero even when absent
    // from `data`; maintained by `resize` and `unmap`.
    zero_after: u64,
}
98
impl RamDiskLayer {
    /// Adds computed and mutable fields to the inspect response.
    fn inspect_extra(&self, resp: &mut inspect::Response<'_>) {
        // Approximate committed memory: populated sector count times the
        // sector size.
        resp.field_with("committed_size", || {
            self.state.read().data.len() * self.sector_size as usize
        })
        // Mutable field: writing "sector_count" resizes the disk; reading it
        // reports the current size in sectors.
        .field_mut_with("sector_count", |new_count| {
            if let Some(new_count) = new_count {
                self.resize(new_count.parse().context("invalid sector count")?)?;
            }
            anyhow::Ok(self.sector_count())
        });
    }
}
112
impl Debug for RamDiskLayer {
    // Manual impl: only the sector count is interesting for diagnostics; the
    // sector data itself would be enormous.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RamDiskLayer")
            .field("sector_count", &self.sector_count)
            .finish()
    }
}
120
/// An error constructing a [`RamDiskLayer`].
#[derive(Error, Debug)]
pub enum Error {
    /// The requested disk size is not an even number of sectors.
    #[error("disk size {disk_size:#x} is not a multiple of the sector size {sector_size}")]
    NotSectorMultiple {
        /// The requested disk size, in bytes.
        disk_size: u64,
        /// The configured sector size, in bytes.
        sector_size: u32,
    },
    /// A zero-byte disk was requested.
    #[error("disk has no sectors")]
    EmptyDisk,
    /// The requested sector size is invalid.
    #[error("sector size {0} is not a power-of-two >= 512")]
    InvalidSectorSize(u32),
}
139
impl RamDiskLayer {
    /// Makes a new RAM disk layer of `size` bytes with the default 512-byte
    /// sector size.
    pub fn new(size: u64) -> Result<Self, Error> {
        Self::new_with_sector_size(size, DEFAULT_SECTOR_SIZE)
    }

    /// Makes a new RAM disk layer of `size` bytes with the given sector
    /// size.
    ///
    /// Fails if `sector_size` is not a power of two >= 512, if `size` is
    /// zero, or if `size` is not a multiple of `sector_size`.
    pub fn new_with_sector_size(size: u64, sector_size: u32) -> Result<Self, Error> {
        if !sector_size.is_power_of_two() || sector_size < DEFAULT_SECTOR_SIZE {
            return Err(Error::InvalidSectorSize(sector_size));
        }
        let sector_shift = sector_size.trailing_zeros();
        let sector_count = {
            if size == 0 {
                return Err(Error::EmptyDisk);
            }
            if !size.is_multiple_of(sector_size as u64) {
                return Err(Error::NotSectorMultiple {
                    disk_size: size,
                    sector_size,
                });
            }
            size >> sector_shift
        };
        Ok(Self {
            state: RwLock::new(RamState {
                data: BTreeMap::new(),
                sector_count,
                // A freshly created disk reads entirely as zero.
                zero_after: sector_count,
            }),
            sector_count: sector_count.into(),
            resize_event: Default::default(),
            sector_size,
            sector_shift,
        })
    }

    /// Resizes the disk to `new_sector_count` sectors, discarding any data
    /// beyond the new end of the disk.
    fn resize(&self, new_sector_count: u64) -> anyhow::Result<()> {
        if new_sector_count == 0 {
            anyhow::bail!("invalid sector count");
        }
        let _removed = {
            let mut state = self.state.write();
            // If the disk is later re-extended past the new end, the newly
            // exposed sectors must read as zero.
            state.zero_after = new_sector_count.min(state.zero_after);
            state.sector_count = new_sector_count;
            // Keep the lock-free cached count in sync with the state.
            self.sector_count.store(new_sector_count, Ordering::Relaxed);
            // Detach the truncated sectors while holding the lock, but let
            // them drop after the lock is released (via `_removed`) to keep
            // the critical section short.
            state.data.split_off(&new_sector_count)
        };
        // Wake any `wait_resize` callers.
        self.resize_event.notify(usize::MAX);
        Ok(())
    }

    /// Writes `buffers` to the disk starting at `sector`. When `overwrite`
    /// is false, sectors already populated in this layer are left unchanged.
    fn write_maybe_overwrite(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        overwrite: bool,
    ) -> Result<(), DiskError> {
        let sector_bytes = self.sector_size as usize;
        let count = buffers.len() >> self.sector_shift;
        tracing::trace!(sector, count, "write");
        let mut state = self.state.write();
        if sector + count as u64 > state.sector_count {
            return Err(DiskError::IllegalBlock);
        }
        for i in 0..count {
            let cur = sector + i as u64;
            let buf = buffers.subrange(i << self.sector_shift, sector_bytes);
            let mut reader = buf.reader();
            match state.data.entry(cur) {
                Entry::Vacant(entry) => {
                    let mut data = vec![0u8; sector_bytes].into_boxed_slice();
                    reader.read(&mut data)?;
                    entry.insert(data);
                }
                Entry::Occupied(mut entry) => {
                    // For no-overwrite writes, skip already-populated
                    // sectors.
                    if overwrite {
                        reader.read(entry.get_mut())?;
                    }
                }
            }
        }
        Ok(())
    }
}
233
234impl LayerAttach for LazyRamDiskLayer {
235 type Error = Error;
236 type Layer = RamDiskLayer;
237
238 async fn attach(
239 self,
240 lower_layer_metadata: Option<disk_layered::DiskLayerMetadata>,
241 ) -> Result<Self::Layer, Self::Error> {
242 let sector_size = self
243 .sector_size
244 .or(lower_layer_metadata.as_ref().map(|meta| meta.sector_size))
245 .unwrap_or(DEFAULT_SECTOR_SIZE);
246 let total_size = match self.len {
247 Some(len) => len,
248 None => {
249 let meta = lower_layer_metadata.ok_or(Error::EmptyDisk)?;
250 meta.sector_count * meta.sector_size as u64
251 }
252 };
253 RamDiskLayer::new_with_sector_size(total_size, sector_size)
254 }
255}
256
257impl LayerIo for RamDiskLayer {
258 fn layer_type(&self) -> &str {
259 "ram"
260 }
261
262 fn sector_count(&self) -> u64 {
263 self.sector_count.load(Ordering::Relaxed)
264 }
265
266 fn sector_size(&self) -> u32 {
267 self.sector_size
268 }
269
270 fn is_logically_read_only(&self) -> bool {
271 false
272 }
273
274 fn disk_id(&self) -> Option<[u8; 16]> {
275 None
276 }
277
278 fn physical_sector_size(&self) -> u32 {
279 self.sector_size
280 }
281
282 fn is_fua_respected(&self) -> bool {
283 true
284 }
285
286 async fn read(
287 &self,
288 buffers: &RequestBuffers<'_>,
289 sector: u64,
290 mut marker: SectorMarker<'_>,
291 ) -> Result<(), DiskError> {
292 let sector_bytes = self.sector_size as usize;
293 let count = (buffers.len() >> self.sector_shift) as u64;
294 let end = sector + count;
295 tracing::trace!(sector, count, "read");
296 let state = self.state.read();
297 if end > state.sector_count {
298 return Err(DiskError::IllegalBlock);
299 }
300 let mut range = state.data.range(sector..end);
301 let mut last = sector;
302 while last < end {
303 let r = range.next();
304 let next = r.map(|(&s, _)| s).unwrap_or(end);
305 if next > last && next > state.zero_after {
306 let zero_start = last.max(state.zero_after);
309 let zero_count = next - zero_start;
310 let offset = ((zero_start - sector) as usize) << self.sector_shift;
311 let len = (zero_count as usize) << self.sector_shift;
312 buffers.subrange(offset, len).writer().zero(len)?;
313 marker.set_range(zero_start..next);
314 }
315 if let Some((&s, buf)) = r {
316 let offset = ((s - sector) as usize) << self.sector_shift;
317 buffers.subrange(offset, sector_bytes).writer().write(buf)?;
318
319 marker.set(s);
320 }
321 last = next;
322 }
323 Ok(())
324 }
325
326 async fn write(
327 &self,
328 buffers: &RequestBuffers<'_>,
329 sector: u64,
330 _fua: bool,
331 ) -> Result<(), DiskError> {
332 self.write_maybe_overwrite(buffers, sector, true)
333 }
334
335 fn write_no_overwrite(&self) -> Option<impl WriteNoOverwrite> {
336 Some(self)
337 }
338
339 async fn sync_cache(&self) -> Result<(), DiskError> {
340 tracing::trace!("sync_cache");
341 Ok(())
342 }
343
344 async fn wait_resize(&self, sector_count: u64) -> u64 {
345 loop {
346 let listen = self.resize_event.listen();
347 let current = self.sector_count();
348 if current != sector_count {
349 break current;
350 }
351 listen.await;
352 }
353 }
354
355 async fn unmap(
356 &self,
357 sector_offset: u64,
358 sector_count: u64,
359 _block_level_only: bool,
360 next_is_zero: bool,
361 ) -> Result<(), DiskError> {
362 tracing::trace!(sector_offset, sector_count, "unmap");
363 let mut state = self.state.write();
364 if sector_offset + sector_count > state.sector_count {
365 return Err(DiskError::IllegalBlock);
366 }
367 if !next_is_zero {
368 if sector_offset + sector_count < state.zero_after {
371 return Ok(());
372 }
373 state.zero_after = state.zero_after.min(sector_offset);
376 }
377 let mut next_sector = sector_offset;
380 let end = sector_offset + sector_count;
381 while next_sector < end {
382 let Some((§or, _)) = state.data.range_mut(next_sector..).next() else {
383 break;
384 };
385 if sector >= end {
386 break;
387 }
388 state.data.remove(§or);
389 next_sector = sector + 1;
390 }
391 Ok(())
392 }
393
394 fn unmap_behavior(&self) -> UnmapBehavior {
395 UnmapBehavior::Unspecified
398 }
399
400 fn optimal_unmap_sectors(&self) -> u32 {
401 1
402 }
403}
404
impl WriteNoOverwrite for RamDiskLayer {
    /// Writes `buffers` starting at `sector`, populating only sectors not
    /// already present in this layer; existing sectors are left intact.
    async fn write_no_overwrite(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
    ) -> Result<(), DiskError> {
        self.write_maybe_overwrite(buffers, sector, false)
    }
}
414
/// Creates a RAM-backed [`Disk`] of `size` bytes with the default 512-byte
/// sector size.
pub fn ram_disk(size: u64, read_only: bool) -> anyhow::Result<Disk> {
    ram_disk_with_sector_size(size, read_only, DEFAULT_SECTOR_SIZE)
}
423
/// Creates a RAM-backed [`Disk`] of `size` bytes with the given sector
/// size, in bytes.
pub fn ram_disk_with_sector_size(
    size: u64,
    read_only: bool,
    sector_size: u32,
) -> anyhow::Result<Disk> {
    use futures::future::FutureExt;

    let disk = Disk::new(
        LayeredDisk::new(
            read_only,
            vec![LayerConfiguration {
                layer: DiskLayer::new(RamDiskLayer::new_with_sector_size(size, sector_size)?),
                write_through: false,
                read_cache: false,
            }],
        )
        // Constructing a layered disk over a RAM layer performs no real IO,
        // so the future completes immediately.
        .now_or_never()
        .expect("RamDiskLayer won't block")?,
    )?;

    Ok(disk)
}
450
#[cfg(test)]
mod tests {
    use super::RamDiskLayer;
    use disk_backend::DiskIo;
    use disk_layered::DiskLayer;
    use disk_layered::LayerConfiguration;
    use disk_layered::LayerIo;
    use disk_layered::LayeredDisk;
    use guestmem::GuestMemory;
    use pal_async::async_test;
    use scsi_buffers::OwnedRequestBuffers;
    use test_with_tracing::test;
    use zerocopy::IntoBytes;

    const SECTOR_SIZE: u32 = 512;
    const SECTOR_U64: u64 = SECTOR_SIZE as u64;
    const SECTOR_USIZE: usize = SECTOR_SIZE as usize;

    // Asserts that `count` sectors of guest memory, starting at buffer
    // sector `start`, hold the test pattern for disk sector `sector` tagged
    // with `high` in the top byte of each 32-bit word.
    fn check(mem: &GuestMemory, sector: u64, start: usize, count: usize, high: u8) {
        let mut buf = vec![0u32; count * SECTOR_USIZE / 4];
        mem.read_at(start as u64 * SECTOR_U64, buf.as_mut_bytes())
            .unwrap();
        for (i, &b) in buf.iter().enumerate() {
            let offset = sector * SECTOR_U64 + i as u64 * 4;
            // Each word of the pattern encodes its own word index on disk.
            let expected = (offset as u32 / 4) | ((high as u32) << 24);
            assert!(
                b == expected,
                "at sector {}, word {}, got {:#x}, expected {:#x}",
                offset / SECTOR_U64,
                (offset % SECTOR_U64) / 4,
                b,
                expected
            );
        }
    }

    // Reads `count` sectors at `sector` from the disk into the start of
    // guest memory.
    async fn read(mem: &GuestMemory, disk: &mut impl DiskIo, sector: u64, count: usize) {
        disk.read_vectored(
            &OwnedRequestBuffers::linear(0, count * SECTOR_USIZE, true).buffer(mem),
            sector,
        )
        .await
        .unwrap();
    }

    // Writes the test pattern for `count` sectors at `sector` directly to a
    // layer, bypassing the layered disk.
    async fn write_layer(
        mem: &GuestMemory,
        disk: &mut impl LayerIo,
        sector: u64,
        count: usize,
        high: u8,
    ) {
        let buf: Vec<_> = (sector * SECTOR_U64 / 4..(sector + count as u64) * SECTOR_U64 / 4)
            .map(|x| x as u32 | ((high as u32) << 24))
            .collect();
        let len = SECTOR_USIZE * count;
        mem.write_at(0, &buf.as_bytes()[..len]).unwrap();

        disk.write(
            &OwnedRequestBuffers::linear(0, len, false).buffer(mem),
            sector,
            false,
        )
        .await
        .unwrap();
    }

    // Writes the test pattern for `count` sectors at `sector` through the
    // disk interface.
    async fn write(mem: &GuestMemory, disk: &mut impl DiskIo, sector: u64, count: usize, high: u8) {
        let buf: Vec<_> = (sector * SECTOR_U64 / 4..(sector + count as u64) * SECTOR_U64 / 4)
            .map(|x| x as u32 | ((high as u32) << 24))
            .collect();
        let len = SECTOR_USIZE * count;
        mem.write_at(0, &buf.as_bytes()[..len]).unwrap();

        disk.write_vectored(
            &OwnedRequestBuffers::linear(0, len, false).buffer(mem),
            sector,
            false,
        )
        .await
        .unwrap();
    }

    // Builds a two-layer disk: an empty RAM layer stacked over a lower RAM
    // layer pre-filled with the test pattern (high byte 0).
    async fn prep_disk(size: usize) -> (GuestMemory, LayeredDisk) {
        let guest_mem = GuestMemory::allocate(size);
        let mut lower = RamDiskLayer::new(size as u64).unwrap();
        write_layer(&guest_mem, &mut lower, 0, size / SECTOR_USIZE, 0).await;
        let upper = RamDiskLayer::new(size as u64).unwrap();
        let upper = LayeredDisk::new(
            false,
            Vec::from_iter([upper, lower].map(|layer| LayerConfiguration {
                layer: DiskLayer::new(layer),
                write_through: false,
                read_cache: false,
            })),
        )
        .await
        .unwrap();
        (guest_mem, upper)
    }

    // Writes through the upper layer must shadow the lower layer, while
    // untouched sectors still read from the lower layer.
    #[async_test]
    async fn diff() {
        const SIZE: usize = 1024 * 1024;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        read(&guest_mem, &mut upper, 10, 2).await;
        check(&guest_mem, 10, 0, 2, 0);
        write(&guest_mem, &mut upper, 10, 2, 1).await;
        write(&guest_mem, &mut upper, 11, 1, 2).await;
        read(&guest_mem, &mut upper, 9, 5).await;
        check(&guest_mem, 9, 0, 1, 0);
        check(&guest_mem, 10, 1, 1, 1);
        check(&guest_mem, 11, 2, 1, 2);
        check(&guest_mem, 12, 3, 1, 0);
    }

    // Resizes the disk through the mutable inspect field and verifies the
    // reported sector count.
    async fn resize(disk: &LayeredDisk, new_size: u64) {
        let inspect::ValueKind::Unsigned(v) =
            inspect::update("layers/0/backing/sector_count", &new_size.to_string(), disk)
                .await
                .unwrap()
                .kind
        else {
            panic!("bad inspect value")
        };
        assert_eq!(new_size, v);
    }

    // Shrinking and then re-growing the disk must leave the re-extended
    // region reading as zero.
    #[async_test]
    async fn test_resize() {
        const SIZE: usize = 1024 * 1024;
        const SECTORS: usize = SIZE / SECTOR_USIZE;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        check(&guest_mem, 0, 0, SECTORS, 0);
        resize(&upper, SECTORS as u64 / 2).await;
        resize(&upper, SECTORS as u64).await;
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS / 2, 0);
        for s in SECTORS / 2..SECTORS {
            let mut buf = [0u8; SECTOR_USIZE];
            guest_mem.read_at(s as u64 * SECTOR_U64, &mut buf).unwrap();
            assert_eq!(buf, [0u8; SECTOR_USIZE]);
        }
    }

    // Unmapping a range that does not reach the zero tail is a no-op (the
    // lower layer's data remains visible); unmapping through the end of the
    // disk zeroes the range.
    #[async_test]
    async fn test_unmap() {
        const SIZE: usize = 1024 * 1024;
        const SECTORS: usize = SIZE / SECTOR_USIZE;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        upper.unmap(0, SECTORS as u64 - 1, false).await.unwrap();
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS, 0);
        upper
            .unmap(SECTORS as u64 / 2, SECTORS as u64 / 2, false)
            .await
            .unwrap();
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS / 2, 0);
        for s in SECTORS / 2..SECTORS {
            let mut buf = [0u8; SECTOR_USIZE];
            guest_mem.read_at(s as u64 * SECTOR_U64, &mut buf).unwrap();
            assert_eq!(buf, [0u8; SECTOR_USIZE]);
        }
    }

    // A write/read round trip through a 4096-byte-sector layer must return
    // the written data unchanged.
    #[async_test]
    async fn test_4096_sector_write_read() {
        const SECTOR_4K: usize = 4096;
        const DISK_SIZE: u64 = 1024 * 1024;
        const SECTOR_COUNT: u64 = DISK_SIZE / SECTOR_4K as u64;

        let layer = RamDiskLayer::new_with_sector_size(DISK_SIZE, 4096).unwrap();
        assert_eq!(layer.sector_size(), 4096);
        assert_eq!(layer.physical_sector_size(), 4096);
        assert_eq!(layer.sector_count(), SECTOR_COUNT);

        let guest_mem = GuestMemory::allocate(SECTOR_4K * 2);

        // A non-trivial byte pattern (period 251, coprime with 4096).
        let pattern: Vec<u8> = (0..SECTOR_4K).map(|i| (i % 251) as u8).collect();
        guest_mem.write_at(0, &pattern).unwrap();
        let bufs = OwnedRequestBuffers::linear(0, SECTOR_4K, false);
        layer
            .write(&bufs.buffer(&guest_mem), 0, false)
            .await
            .unwrap();

        let disk = LayeredDisk::new(
            false,
            vec![LayerConfiguration {
                layer: DiskLayer::new(layer),
                write_through: false,
                read_cache: false,
            }],
        )
        .await
        .unwrap();

        // Clear guest memory before reading back so stale data can't pass.
        guest_mem.write_at(0, &vec![0u8; SECTOR_4K]).unwrap();
        disk.read_vectored(
            &OwnedRequestBuffers::linear(0, SECTOR_4K, true).buffer(&guest_mem),
            0,
        )
        .await
        .unwrap();

        let mut readback = vec![0u8; SECTOR_4K];
        guest_mem.read_at(0, &mut readback).unwrap();
        assert_eq!(readback, pattern);
    }

    // 1 MiB / 4 KiB = 256 sectors.
    #[test]
    fn test_sector_count_4096() {
        let layer = RamDiskLayer::new_with_sector_size(1024 * 1024, 4096).unwrap();
        assert_eq!(layer.sector_count(), 256);
    }

    #[test]
    fn test_invalid_sector_size_not_power_of_two() {
        let err = RamDiskLayer::new_with_sector_size(4096, 1000).unwrap_err();
        assert!(matches!(err, super::Error::InvalidSectorSize(1000)));
    }

    #[test]
    fn test_invalid_sector_size_too_small() {
        let err = RamDiskLayer::new_with_sector_size(4096, 256).unwrap_err();
        assert!(matches!(err, super::Error::InvalidSectorSize(256)));
    }

    #[test]
    fn test_invalid_disk_size_not_multiple() {
        let err = RamDiskLayer::new_with_sector_size(5000, 4096).unwrap_err();
        assert!(matches!(
            err,
            super::Error::NotSectorMultiple {
                disk_size: 5000,
                sector_size: 4096
            }
        ));
    }

    // A lazy layer with no explicit geometry must inherit the lower layer's
    // sector size and count at attach time.
    #[async_test]
    async fn test_lazy_inherits_sector_size() {
        let lower = RamDiskLayer::new_with_sector_size(1024 * 1024, 4096).unwrap();
        let upper = LayeredDisk::new(
            false,
            vec![
                LayerConfiguration {
                    layer: DiskLayer::new(super::LazyRamDiskLayer::new()),
                    write_through: false,
                    read_cache: false,
                },
                LayerConfiguration {
                    layer: DiskLayer::new(lower),
                    write_through: false,
                    read_cache: false,
                },
            ],
        )
        .await
        .unwrap();

        assert_eq!(upper.sector_size(), 4096);
        assert_eq!(upper.sector_count(), 256);
    }

    // Stacking layers with different sector sizes must fail, regardless of
    // which order the mismatched layers appear in.
    #[async_test]
    async fn test_mismatched_sector_sizes_rejected() {
        const DISK_SIZE: u64 = 1024 * 1024;

        let result = LayeredDisk::new(
            false,
            vec![
                LayerConfiguration {
                    layer: DiskLayer::new(
                        RamDiskLayer::new_with_sector_size(DISK_SIZE, 4096).unwrap(),
                    ),
                    write_through: false,
                    read_cache: false,
                },
                LayerConfiguration {
                    layer: DiskLayer::new(RamDiskLayer::new(DISK_SIZE).unwrap()),
                    write_through: false,
                    read_cache: false,
                },
            ],
        )
        .await;
        assert!(result.is_err());

        let result = LayeredDisk::new(
            false,
            vec![
                LayerConfiguration {
                    layer: DiskLayer::new(RamDiskLayer::new(DISK_SIZE).unwrap()),
                    write_through: false,
                    read_cache: false,
                },
                LayerConfiguration {
                    layer: DiskLayer::new(
                        RamDiskLayer::new_with_sector_size(DISK_SIZE, 4096).unwrap(),
                    ),
                    write_through: false,
                    read_cache: false,
                },
            ],
        )
        .await;
        assert!(result.is_err());
    }
}