disk_backend/
sync_wrapper.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A wrapper around [`Disk`] that adapts the trait for use with
5//! synchronous [`std::io`] traits (such as `Read`, `Write`, `Seek`, etc...).
6//!
7//! NOTE: this is _not_ code that should see wide use across the HvLite
8//! codebase! It was written to support a very-specific use-case: leveraging
9//! existing, synchronous, Rust/C library code that reformats/repartitions
10//! drives.
11//!
//! The fact that this adapter exists should be considered an implementation
//! wart, and it would be great if we could swap out any dependent code with
//! native-async implementations at some point in the future.
15
16use crate::Disk;
17use futures::executor::block_on;
18use guestmem::GuestMemory;
19use scsi_buffers::OwnedRequestBuffers;
20use std::io;
21use std::io::Read;
22use std::io::Seek;
23use std::io::SeekFrom;
24use std::io::Write;
25
/// Wrapper around [`Disk`] that implements the synchronous [`std::io`]
/// traits (such as `Read`, `Write`, `Seek`, etc...) using [`block_on`].
///
/// Maintains a single-sector internal cache (`buffer`) so that reads and
/// writes that are not sector-aligned can be serviced by read-modify-write
/// of whole sectors.
pub struct BlockingDisk {
    /// Inner disk instance for base operations.
    inner: Disk,
    /// The current position in the disk, in bytes (not sectors).
    pos: u64,
    /// One-sector scratch buffer used for read/write staging; allocated to
    /// exactly `sector_size` bytes at construction.
    buffer: Vec<u8>,
    /// True when `buffer` holds modified data (for the sector containing
    /// `pos`) that has not yet been written back to `inner`.
    buffer_dirty: bool,
}
38
39impl BlockingDisk {
40    /// Create a new blocking disk wrapping `inner`.
41    pub fn new(inner: Disk) -> Self {
42        let sector_size = inner.sector_size();
43        BlockingDisk {
44            inner,
45            pos: 0,
46            buffer: vec![0; sector_size as usize],
47            buffer_dirty: false,
48        }
49    }
50
51    /// Fetches data from the disk into the buffer, flushing the buffer if it is dirty.
52    async fn fetch(&mut self) -> io::Result<()> {
53        if self.buffer_dirty {
54            block_on(self.flush())?;
55        }
56        let guest_mem = GuestMemory::allocate(self.inner.sector_size() as usize);
57        let read_buffers = OwnedRequestBuffers::linear(0, self.inner.sector_size() as usize, true);
58        let binding = read_buffers.buffer(&guest_mem);
59        let result = self
60            .inner
61            .read_vectored(&binding, self.pos / self.inner.sector_size() as u64)
62            .await;
63        guest_mem
64            .read_at(0, &mut self.buffer)
65            .map_err(|e| io::Error::other(format!("Fetch error: {}", e)))?;
66        result.map_err(|e| io::Error::other(format!("Fetch error: {}", e)))
67    }
68
69    /// Writes the buffer to the disk if it is dirty.
70    async fn flush(&mut self) -> io::Result<()> {
71        if self.buffer_dirty {
72            let guest_mem = GuestMemory::allocate(self.inner.sector_size() as usize);
73            guest_mem.write_at(0, &self.buffer).unwrap();
74            let write_buffers =
75                OwnedRequestBuffers::linear(0, self.inner.sector_size() as usize, true);
76            let binding = write_buffers.buffer(&guest_mem);
77            let future = self.inner.write_vectored(
78                &binding,
79                self.pos / self.inner.sector_size() as u64,
80                true,
81            );
82            let result = future.await;
83            self.buffer_dirty = false;
84            result.map_err(|e| io::Error::other(format!("Fetch error: {}", e)))
85        } else {
86            Ok(())
87        }
88    }
89
90    /// Reads data from the disk into the provided buffer.
91    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
92        // If the buffer size is a multiple of sector size and the buffer is not dirty
93        // use the read_full_sector method
94        if buf.len() % self.inner.sector_size() as usize == 0 && !self.buffer_dirty {
95            return self.read_full_sector(buf);
96        }
97        // Buffer size is not multiple of sector size
98        let mut total_bytes_read = 0;
99        let mut remaining = buf.len();
100        if self.buffer_dirty {
101            block_on(self.flush())?;
102        }
103        while remaining > 0 {
104            block_on(self.fetch())?;
105            let offset = (self.pos % self.inner.sector_size() as u64) as usize;
106            let bytes_to_copy =
107                std::cmp::min(remaining, self.inner.sector_size() as usize - offset);
108            buf[total_bytes_read..total_bytes_read + bytes_to_copy]
109                .copy_from_slice(&self.buffer[offset..offset + bytes_to_copy]);
110            self.pos += bytes_to_copy as u64;
111            total_bytes_read += bytes_to_copy;
112            remaining -= bytes_to_copy;
113            if remaining > 0 && offset + bytes_to_copy == self.inner.sector_size() as usize {
114                // Reached the end of a sector, fetch the next sector on the next read
115                block_on(self.fetch())?;
116            }
117        }
118        Ok(total_bytes_read)
119    }
120
121    /// Writes data from the provided buffer to the disk.
122    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
123        // If the buffer size is a multiple of sector size and the buffer is not dirty
124        // use the write_full_sector method
125        if buf.len() % self.inner.sector_size() as usize == 0 && !self.buffer_dirty {
126            return self.write_full_sector(buf);
127        }
128        // Buffer size is not multiple of sector size
129        let mut total_bytes_written = 0;
130        let mut remaining = buf.len();
131        while remaining > 0 {
132            let offset = (self.pos % self.inner.sector_size() as u64) as usize;
133            let bytes_to_copy =
134                std::cmp::min(remaining, self.inner.sector_size() as usize - offset);
135            if self.buffer_dirty {
136                block_on(self.flush())?;
137            } else if bytes_to_copy < self.inner.sector_size() as usize {
138                // Fetch the current sector if we are not writing a full sector
139                block_on(self.fetch())?;
140            }
141            self.buffer[offset..offset + bytes_to_copy]
142                .copy_from_slice(&buf[total_bytes_written..total_bytes_written + bytes_to_copy]);
143            self.buffer_dirty = true;
144            // Reached the end of a sector, flush the buffer
145            if offset + bytes_to_copy == self.inner.sector_size() as usize {
146                block_on(self.flush())?;
147            }
148            self.pos += bytes_to_copy as u64;
149            total_bytes_written += bytes_to_copy;
150            remaining -= bytes_to_copy;
151        }
152        Ok(total_bytes_written)
153    }
154
155    /// Adjusts the current position in the disk.
156    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
157        if self.buffer_dirty {
158            block_on(self.flush())?;
159        }
160        let new_pos = match pos {
161            SeekFrom::Start(offset) => offset,
162            SeekFrom::Current(offset) => self.pos.wrapping_add(offset as u64),
163            SeekFrom::End(offset) => {
164                let end =
165                    self.inner.sector_count() as i64 * self.inner.sector_size() as i64 + offset;
166                end.try_into().unwrap()
167            }
168        };
169        self.pos = new_pos;
170        Ok(new_pos)
171    }
172
173    /// Reads a full sector from the disk into the provided buffer.
174    fn read_full_sector(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175        assert_eq!(
176            buf.len() % self.inner.sector_size() as usize,
177            0,
178            "Buffer size must be a multiple of sector size"
179        );
180        let guest_mem = GuestMemory::allocate(buf.len());
181        let read_buffers = OwnedRequestBuffers::linear(0, buf.len(), true);
182        let binding = read_buffers.buffer(&guest_mem);
183        let future = self
184            .inner
185            .read_vectored(&binding, self.pos / self.inner.sector_size() as u64);
186        block_on(future).map_err(|e| io::Error::other(format!("Read error: {}", e)))?;
187
188        // Copy the data read from guest memory to the input buffer
189        guest_mem
190            .read_at(0, buf)
191            .map_err(|e| io::Error::other(format!("Fetch error: {}", e)))?;
192        // Update the position based on the bytes read
193        self.pos += buf.len() as u64;
194        Ok(buf.len())
195    }
196
197    /// Writes a full sector from the provided buffer to the disk.
198    fn write_full_sector(&mut self, buf: &[u8]) -> io::Result<usize> {
199        assert_eq!(
200            buf.len() % self.inner.sector_size() as usize,
201            0,
202            "Buffer size must be a multiple of sector size"
203        );
204        let guest_mem = GuestMemory::allocate(buf.len());
205        guest_mem.write_at(0, buf).unwrap();
206        let write_buffers = OwnedRequestBuffers::linear(0, buf.len(), true);
207        let binding = write_buffers.buffer(&guest_mem);
208        let future =
209            self.inner
210                .write_vectored(&binding, self.pos / self.inner.sector_size() as u64, true);
211        block_on(future).map_err(|e| io::Error::other(format!("Write error: {}", e)))?;
212        // Update the position based on the bytes written
213        self.pos += buf.len() as u64;
214        Ok(buf.len())
215    }
216}
217
218impl Read for BlockingDisk {
219    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
220        self.read(buf)
221    }
222}
223
224impl Write for BlockingDisk {
225    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
226        self.write(buf)
227    }
228
229    fn flush(&mut self) -> io::Result<()> {
230        block_on(self.flush())
231    }
232}
233
234impl Seek for BlockingDisk {
235    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
236        self.seek(pos)
237    }
238}