disklayer_ram/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! RAM-backed disk layer implementation.
5
6#![forbid(unsafe_code)]
7
8pub mod resolver;
9
10use anyhow::Context;
11use disk_backend::Disk;
12use disk_backend::DiskError;
13use disk_backend::UnmapBehavior;
14use disk_layered::DiskLayer;
15use disk_layered::LayerAttach;
16use disk_layered::LayerConfiguration;
17use disk_layered::LayerIo;
18use disk_layered::LayeredDisk;
19use disk_layered::SectorMarker;
20use disk_layered::WriteNoOverwrite;
21use guestmem::MemoryRead;
22use guestmem::MemoryWrite;
23use inspect::Inspect;
24use parking_lot::RwLock;
25use scsi_buffers::RequestBuffers;
26use std::collections::BTreeMap;
27use std::collections::btree_map::Entry;
28use std::fmt;
29use std::fmt::Debug;
30use std::sync::atomic::AtomicU64;
31use std::sync::atomic::Ordering;
32use thiserror::Error;
33
/// A disk layer backed by RAM, which lazily infers its topology from the layer
/// it is being stacked on-top of.
///
/// The actual [`RamDiskLayer`] is created at attach time, once the lower
/// layer's metadata (and therefore the required size) is known.
#[non_exhaustive]
pub struct LazyRamDiskLayer {}
38
39impl LazyRamDiskLayer {
40    /// Create a new lazy RAM-backed disk layer
41    pub fn new() -> Self {
42        Self {}
43    }
44}
45
/// A disk layer backed entirely by RAM.
#[derive(Inspect)]
#[inspect(extra = "Self::inspect_extra")]
pub struct RamDiskLayer {
    // Sector data and the authoritative size, guarded by one lock.
    #[inspect(flatten)]
    state: RwLock<RamState>,
    // Cached copy of `state.sector_count`, readable without taking the lock;
    // kept in sync by `resize`.
    #[inspect(skip)]
    sector_count: AtomicU64,
    // Signaled on every resize to wake `wait_resize` callers.
    #[inspect(skip)]
    resize_event: event_listener::Event,
}
57
/// Mutable state of a [`RamDiskLayer`], guarded by its `state` lock.
#[derive(Inspect)]
struct RamState {
    // Written sectors, keyed by sector index. A sector absent from the map
    // has not been written in this layer.
    #[inspect(skip)]
    data: BTreeMap<u64, Sector>,
    // Current size of the layer in sectors.
    #[inspect(skip)] // handled in inspect_extra()
    sector_count: u64,
    // Sectors at or beyond this index that are missing from `data` must read
    // as zero (they were discarded by a shrinking resize or unmapped);
    // missing sectors before it are simply not produced by this layer.
    zero_after: u64,
}
66
impl RamDiskLayer {
    /// Adds synthetic inspect fields: a computed `committed_size`, and a
    /// mutable `sector_count` that resizes the layer when written.
    fn inspect_extra(&self, resp: &mut inspect::Response<'_>) {
        resp.field_with("committed_size", || {
            // Approximate RAM usage: committed sectors times the in-memory
            // size of one sector.
            self.state.read().data.len() * size_of::<Sector>()
        })
        .field_mut_with("sector_count", |new_count| {
            // Writing a value to this inspect field resizes the disk.
            if let Some(new_count) = new_count {
                self.resize(new_count.parse().context("invalid sector count")?)?;
            }
            anyhow::Ok(self.sector_count())
        });
    }
}
80
// Manual `Debug` impl: the sector map would be enormous to dump, and the lock
// and event fields are not interesting, so only the sector count is shown.
impl Debug for RamDiskLayer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RamDiskLayer")
            .field("sector_count", &self.sector_count)
            .finish()
    }
}
88
/// An error creating a RAM disk.
#[derive(Error, Debug)]
pub enum Error {
    /// The disk size is not a multiple of the sector size.
    #[error("disk size {disk_size:#x} is not a multiple of the sector size {sector_size}")]
    NotSectorMultiple {
        /// The disk size, in bytes.
        disk_size: u64,
        /// The sector size, in bytes.
        sector_size: u32,
    },
    /// The disk has no sectors (requested size was zero).
    #[error("disk has no sectors")]
    EmptyDisk,
}
104
/// The contents of a single 512-byte sector.
struct Sector([u8; 512]);

/// The fixed logical (and physical) sector size of a RAM disk layer.
const SECTOR_SIZE: u32 = 512;
108
109impl RamDiskLayer {
110    /// Makes a new RAM disk layer of `size` bytes.
111    pub fn new(size: u64) -> Result<Self, Error> {
112        let sector_count = {
113            if size == 0 {
114                return Err(Error::EmptyDisk);
115            }
116            if size % SECTOR_SIZE as u64 != 0 {
117                return Err(Error::NotSectorMultiple {
118                    disk_size: size,
119                    sector_size: SECTOR_SIZE,
120                });
121            }
122            size / SECTOR_SIZE as u64
123        };
124        Ok(Self {
125            state: RwLock::new(RamState {
126                data: BTreeMap::new(),
127                sector_count,
128                zero_after: sector_count,
129            }),
130            sector_count: sector_count.into(),
131            resize_event: Default::default(),
132        })
133    }
134
135    fn resize(&self, new_sector_count: u64) -> anyhow::Result<()> {
136        if new_sector_count == 0 {
137            anyhow::bail!("invalid sector count");
138        }
139        // Remove any truncated data and update the sector count under the lock.
140        let _removed = {
141            let mut state = self.state.write();
142            // Remember that any non-present sectors after this point need to be zeroed.
143            state.zero_after = new_sector_count.min(state.zero_after);
144            state.sector_count = new_sector_count;
145            // Cache the sector count in an atomic for the fast path.
146            //
147            // FUTURE: remove uses of .sector_count() in the IO path,
148            // eliminating the need for this.
149            self.sector_count.store(new_sector_count, Ordering::Relaxed);
150            state.data.split_off(&new_sector_count)
151        };
152        self.resize_event.notify(usize::MAX);
153        Ok(())
154    }
155
156    fn write_maybe_overwrite(
157        &self,
158        buffers: &RequestBuffers<'_>,
159        sector: u64,
160        overwrite: bool,
161    ) -> Result<(), DiskError> {
162        let count = buffers.len() / SECTOR_SIZE as usize;
163        tracing::trace!(sector, count, "write");
164        let mut state = self.state.write();
165        if sector + count as u64 > state.sector_count {
166            return Err(DiskError::IllegalBlock);
167        }
168        for i in 0..count {
169            let cur = i + sector as usize;
170            let buf = buffers.subrange(i * SECTOR_SIZE as usize, SECTOR_SIZE as usize);
171            let mut reader = buf.reader();
172            match state.data.entry(cur as u64) {
173                Entry::Vacant(entry) => {
174                    entry.insert(Sector(reader.read_plain()?));
175                }
176                Entry::Occupied(mut entry) => {
177                    if overwrite {
178                        reader.read(&mut entry.get_mut().0)?;
179                    }
180                }
181            }
182        }
183        Ok(())
184    }
185}
186
187impl LayerAttach for LazyRamDiskLayer {
188    type Error = Error;
189    type Layer = RamDiskLayer;
190
191    async fn attach(
192        self,
193        lower_layer_metadata: Option<disk_layered::DiskLayerMetadata>,
194    ) -> Result<Self::Layer, Self::Error> {
195        RamDiskLayer::new(
196            lower_layer_metadata
197                .map(|x| x.sector_count * x.sector_size as u64)
198                .ok_or(Error::EmptyDisk)?,
199        )
200    }
201}
202
impl LayerIo for RamDiskLayer {
    fn layer_type(&self) -> &str {
        "ram"
    }

    fn sector_count(&self) -> u64 {
        // Read the cached atomic rather than taking the state lock; `resize`
        // keeps it in sync with `state.sector_count`.
        self.sector_count.load(Ordering::Relaxed)
    }

    fn sector_size(&self) -> u32 {
        SECTOR_SIZE
    }

    fn is_logically_read_only(&self) -> bool {
        false
    }

    fn disk_id(&self) -> Option<[u8; 16]> {
        // A RAM layer has no persistent identity.
        None
    }

    fn physical_sector_size(&self) -> u32 {
        SECTOR_SIZE
    }

    fn is_fua_respected(&self) -> bool {
        // Writes land directly in memory; there is no volatile cache to
        // flush, so FUA is trivially respected.
        true
    }

    /// Reads the sectors of this layer into `buffers`, marking each sector it
    /// produced via `marker`; unmarked sectors are left for other layers.
    async fn read(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        mut marker: SectorMarker<'_>,
    ) -> Result<(), DiskError> {
        let count = (buffers.len() / SECTOR_SIZE as usize) as u64;
        let end = sector + count;
        tracing::trace!(sector, count, "read");
        let state = self.state.read();
        if end > state.sector_count {
            return Err(DiskError::IllegalBlock);
        }
        // Walk the present sectors in [sector, end), filling any gap that
        // falls at/after `zero_after` with zeroes.
        let mut range = state.data.range(sector..end);
        let mut last = sector;
        while last < end {
            let r = range.next();
            // `next` is the next present sector, or `end` if none remain.
            let next = r.map(|(&s, _)| s).unwrap_or(end);
            if next > last && next > state.zero_after {
                // Some non-present sectors need to be zeroed, since they are
                // after the zero-after point (due to a resize).
                let zero_start = last.max(state.zero_after);
                let zero_count = next - zero_start;
                let offset = (zero_start - sector) as usize * SECTOR_SIZE as usize;
                let len = zero_count as usize * SECTOR_SIZE as usize;
                buffers.subrange(offset, len).writer().zero(len)?;
                marker.set_range(zero_start..next);
            }
            if let Some((&s, buf)) = r {
                // Copy the present sector's data to its offset within the
                // request buffer and mark it as produced.
                let offset = (s - sector) as usize * SECTOR_SIZE as usize;
                buffers
                    .subrange(offset, SECTOR_SIZE as usize)
                    .writer()
                    .write(&buf.0)?;

                marker.set(s);
            }
            last = next;
        }
        Ok(())
    }

    async fn write(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
        _fua: bool,
    ) -> Result<(), DiskError> {
        // FUA is ignored: memory writes are already "durable" for this layer
        // (see `is_fua_respected`).
        self.write_maybe_overwrite(buffers, sector, true)
    }

    fn write_no_overwrite(&self) -> Option<impl WriteNoOverwrite> {
        // Supported; see the `WriteNoOverwrite` impl.
        Some(self)
    }

    async fn sync_cache(&self) -> Result<(), DiskError> {
        // Nothing to flush; all data is already in memory.
        tracing::trace!("sync_cache");
        Ok(())
    }

    /// Waits until the sector count differs from `sector_count`, returning
    /// the new count.
    async fn wait_resize(&self, sector_count: u64) -> u64 {
        loop {
            // Register the listener before re-checking the count so a
            // concurrent `resize` notification cannot be missed.
            let listen = self.resize_event.listen();
            let current = self.sector_count();
            if current != sector_count {
                break current;
            }
            listen.await;
        }
    }

    async fn unmap(
        &self,
        sector_offset: u64,
        sector_count: u64,
        _block_level_only: bool,
        next_is_zero: bool,
    ) -> Result<(), DiskError> {
        tracing::trace!(sector_offset, sector_count, "unmap");
        let mut state = self.state.write();
        if sector_offset + sector_count > state.sector_count {
            return Err(DiskError::IllegalBlock);
        }
        if !next_is_zero {
            // This would create a hole of zeroes, which we cannot represent in
            // the tree. Ignore the unmap.
            if sector_offset + sector_count < state.zero_after {
                return Ok(());
            }
            // The unmap is within or will extend the not-present-is-zero
            // region, so allow it.
            state.zero_after = state.zero_after.min(sector_offset);
        }
        // Sadly, there appears to be no way to remove a range of entries
        // from a btree map.
        let mut next_sector = sector_offset;
        let end = sector_offset + sector_count;
        while next_sector < end {
            // Find the next present sector at or after `next_sector`.
            let Some((&sector, _)) = state.data.range_mut(next_sector..).next() else {
                break;
            };
            if sector >= end {
                break;
            }
            state.data.remove(&sector);
            next_sector = sector + 1;
        }
        Ok(())
    }

    fn unmap_behavior(&self) -> UnmapBehavior {
        // This layer zeroes if the lower layer is zero, but otherwise does
        // nothing, so we must report unspecified.
        UnmapBehavior::Unspecified
    }

    fn optimal_unmap_sectors(&self) -> u32 {
        1
    }
}
352
impl WriteNoOverwrite for RamDiskLayer {
    /// Writes only sectors not already present in this layer, leaving
    /// existing sector contents untouched.
    async fn write_no_overwrite(
        &self,
        buffers: &RequestBuffers<'_>,
        sector: u64,
    ) -> Result<(), DiskError> {
        self.write_maybe_overwrite(buffers, sector, false)
    }
}
362
363/// Create a RAM disk of `size` bytes.
364///
365/// This is a convenience function for creating a layered disk with a single RAM
366/// layer. It is useful since non-layered RAM disks are used all over the place,
367/// especially in tests.
368pub fn ram_disk(size: u64, read_only: bool) -> anyhow::Result<Disk> {
369    use futures::future::FutureExt;
370
371    let disk = Disk::new(
372        LayeredDisk::new(
373            read_only,
374            vec![LayerConfiguration {
375                layer: DiskLayer::new(RamDiskLayer::new(size)?),
376                write_through: false,
377                read_cache: false,
378            }],
379        )
380        .now_or_never()
381        .expect("RamDiskLayer won't block")?,
382    )?;
383
384    Ok(disk)
385}
386
#[cfg(test)]
mod tests {
    use super::RamDiskLayer;
    use super::SECTOR_SIZE;
    use disk_backend::DiskIo;
    use disk_layered::DiskLayer;
    use disk_layered::LayerConfiguration;
    use disk_layered::LayerIo;
    use disk_layered::LayeredDisk;
    use guestmem::GuestMemory;
    use pal_async::async_test;
    use scsi_buffers::OwnedRequestBuffers;
    use test_with_tracing::test;
    use zerocopy::IntoBytes;

    const SECTOR_U64: u64 = SECTOR_SIZE as u64;
    const SECTOR_USIZE: usize = SECTOR_SIZE as usize;

    /// Asserts that `count` sectors of guest memory, starting at buffer
    /// sector `start`, contain the test pattern for disk sectors beginning at
    /// `sector`, tagged with `high` in the top byte of each 32-bit word.
    fn check(mem: &GuestMemory, sector: u64, start: usize, count: usize, high: u8) {
        let mut buf = vec![0u32; count * SECTOR_USIZE / 4];
        mem.read_at(start as u64 * SECTOR_U64, buf.as_mut_bytes())
            .unwrap();
        for (i, &b) in buf.iter().enumerate() {
            // Each word encodes its disk offset (in words) in the low three
            // bytes and the `high` tag in the top byte.
            let offset = sector * SECTOR_U64 + i as u64 * 4;
            let expected = (offset as u32 / 4) | ((high as u32) << 24);
            assert!(
                b == expected,
                "at sector {}, word {}, got {:#x}, expected {:#x}",
                offset / SECTOR_U64,
                (offset % SECTOR_U64) / 4,
                b,
                expected
            );
        }
    }

    /// Reads `count` sectors at `sector` from `disk` into the start of guest
    /// memory.
    async fn read(mem: &GuestMemory, disk: &mut impl DiskIo, sector: u64, count: usize) {
        disk.read_vectored(
            &OwnedRequestBuffers::linear(0, count * SECTOR_USIZE, true).buffer(mem),
            sector,
        )
        .await
        .unwrap();
    }

    /// Writes the test pattern (tagged with `high`) for sectors
    /// `sector..sector + count` directly through the `LayerIo` interface.
    async fn write_layer(
        mem: &GuestMemory,
        disk: &mut impl LayerIo,
        sector: u64,
        count: usize,
        high: u8,
    ) {
        let buf: Vec<_> = (sector * SECTOR_U64 / 4..(sector + count as u64) * SECTOR_U64 / 4)
            .map(|x| x as u32 | ((high as u32) << 24))
            .collect();
        let len = SECTOR_USIZE * count;
        mem.write_at(0, &buf.as_bytes()[..len]).unwrap();

        disk.write(
            &OwnedRequestBuffers::linear(0, len, false).buffer(mem),
            sector,
            false,
        )
        .await
        .unwrap();
    }

    /// Writes the test pattern (tagged with `high`) through the full
    /// `DiskIo` interface.
    async fn write(mem: &GuestMemory, disk: &mut impl DiskIo, sector: u64, count: usize, high: u8) {
        let buf: Vec<_> = (sector * SECTOR_U64 / 4..(sector + count as u64) * SECTOR_U64 / 4)
            .map(|x| x as u32 | ((high as u32) << 24))
            .collect();
        let len = SECTOR_USIZE * count;
        mem.write_at(0, &buf.as_bytes()[..len]).unwrap();

        disk.write_vectored(
            &OwnedRequestBuffers::linear(0, len, false).buffer(mem),
            sector,
            false,
        )
        .await
        .unwrap();
    }

    /// Builds a two-layer disk: a fully populated lower RAM layer (pattern
    /// tag 0) beneath an initially empty upper RAM layer.
    async fn prep_disk(size: usize) -> (GuestMemory, LayeredDisk) {
        let guest_mem = GuestMemory::allocate(size);
        let mut lower = RamDiskLayer::new(size as u64).unwrap();
        write_layer(&guest_mem, &mut lower, 0, size / SECTOR_USIZE, 0).await;
        let upper = RamDiskLayer::new(size as u64).unwrap();
        let upper = LayeredDisk::new(
            false,
            Vec::from_iter([upper, lower].map(|layer| LayerConfiguration {
                layer: DiskLayer::new(layer),
                write_through: false,
                read_cache: false,
            })),
        )
        .await
        .unwrap();
        (guest_mem, upper)
    }

    #[async_test]
    async fn diff() {
        const SIZE: usize = 1024 * 1024;

        // Writes to the layered disk should shadow the lower layer only for
        // the sectors actually written; reads of other sectors still see the
        // lower layer's pattern (tag 0).
        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        read(&guest_mem, &mut upper, 10, 2).await;
        check(&guest_mem, 10, 0, 2, 0);
        write(&guest_mem, &mut upper, 10, 2, 1).await;
        write(&guest_mem, &mut upper, 11, 1, 2).await;
        read(&guest_mem, &mut upper, 9, 5).await;
        check(&guest_mem, 9, 0, 1, 0);
        check(&guest_mem, 10, 1, 1, 1);
        check(&guest_mem, 11, 2, 1, 2);
        check(&guest_mem, 12, 3, 1, 0);
    }

    /// Resizes the top layer through its mutable inspect `sector_count`
    /// field (see `RamDiskLayer::inspect_extra`) and verifies the reported
    /// new size.
    async fn resize(disk: &LayeredDisk, new_size: u64) {
        let inspect::ValueKind::Unsigned(v) =
            inspect::update("layers/0/backing/sector_count", &new_size.to_string(), disk)
                .await
                .unwrap()
                .kind
        else {
            panic!("bad inspect value")
        };
        assert_eq!(new_size, v);
    }

    #[async_test]
    async fn test_resize() {
        const SIZE: usize = 1024 * 1024;
        const SECTORS: usize = SIZE / SECTOR_USIZE;

        // After shrinking and growing back, the re-grown second half must
        // read as zeroes (zero_after) rather than exposing lower-layer data.
        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        check(&guest_mem, 0, 0, SECTORS, 0);
        resize(&upper, SECTORS as u64 / 2).await;
        resize(&upper, SECTORS as u64).await;
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS / 2, 0);
        for s in SECTORS / 2..SECTORS {
            let mut buf = [0u8; SECTOR_USIZE];
            guest_mem.read_at(s as u64 * SECTOR_U64, &mut buf).unwrap();
            assert_eq!(buf, [0u8; SECTOR_USIZE]);
        }
    }

    #[async_test]
    async fn test_unmap() {
        const SIZE: usize = 1024 * 1024;
        const SECTORS: usize = SIZE / SECTOR_USIZE;

        let (guest_mem, mut upper) = prep_disk(SIZE).await;
        // An unmap that stops short of the end of the disk would create an
        // unrepresentable hole of zeroes, so it is ignored: data remains.
        upper.unmap(0, SECTORS as u64 - 1, false).await.unwrap();
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS, 0);
        // An unmap extending to the end of the disk is honored: the second
        // half then reads as zeroes.
        upper
            .unmap(SECTORS as u64 / 2, SECTORS as u64 / 2, false)
            .await
            .unwrap();
        read(&guest_mem, &mut upper, 0, SECTORS).await;
        check(&guest_mem, 0, 0, SECTORS / 2, 0);
        for s in SECTORS / 2..SECTORS {
            let mut buf = [0u8; SECTOR_USIZE];
            guest_mem.read_at(s as u64 * SECTOR_U64, &mut buf).unwrap();
            assert_eq!(buf, [0u8; SECTOR_USIZE]);
        }
    }
}
555}