disk_backend/
sync_wrapper.rsuse crate::Disk;
use futures::executor::block_on;
use guestmem::GuestMemory;
use scsi_buffers::OwnedRequestBuffers;
use std::io;
use std::io::Read;
use std::io::Seek;
use std::io::SeekFrom;
use std::io::Write;
pub struct BlockingDisk {
inner: Disk,
pos: u64,
buffer: Vec<u8>,
buffer_dirty: bool,
}
impl BlockingDisk {
pub fn new(inner: Disk) -> Self {
let sector_size = inner.sector_size();
BlockingDisk {
inner,
pos: 0,
buffer: vec![0; sector_size as usize],
buffer_dirty: false,
}
}
async fn fetch(&mut self) -> io::Result<()> {
if self.buffer_dirty {
block_on(self.flush())?;
}
let guest_mem = GuestMemory::allocate(self.inner.sector_size() as usize);
let read_buffers = OwnedRequestBuffers::linear(0, self.inner.sector_size() as usize, true);
let binding = read_buffers.buffer(&guest_mem);
let result = self
.inner
.read_vectored(&binding, self.pos / self.inner.sector_size() as u64)
.await;
guest_mem
.read_at(0, &mut self.buffer)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Fetch error: {}", e)))?;
result.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Fetch error: {}", e)))
}
async fn flush(&mut self) -> io::Result<()> {
if self.buffer_dirty {
let guest_mem = GuestMemory::allocate(self.inner.sector_size() as usize);
guest_mem.write_at(0, &self.buffer).unwrap();
let write_buffers =
OwnedRequestBuffers::linear(0, self.inner.sector_size() as usize, true);
let binding = write_buffers.buffer(&guest_mem);
let future = self.inner.write_vectored(
&binding,
self.pos / self.inner.sector_size() as u64,
true,
);
let result = future.await;
self.buffer_dirty = false;
result.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Fetch error: {}", e)))
} else {
Ok(())
}
}
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if buf.len() % self.inner.sector_size() as usize == 0 && !self.buffer_dirty {
return self.read_full_sector(buf);
}
let mut total_bytes_read = 0;
let mut remaining = buf.len();
if self.buffer_dirty {
block_on(self.flush())?;
}
while remaining > 0 {
block_on(self.fetch())?;
let offset = (self.pos % self.inner.sector_size() as u64) as usize;
let bytes_to_copy =
std::cmp::min(remaining, self.inner.sector_size() as usize - offset);
buf[total_bytes_read..total_bytes_read + bytes_to_copy]
.copy_from_slice(&self.buffer[offset..offset + bytes_to_copy]);
self.pos += bytes_to_copy as u64;
total_bytes_read += bytes_to_copy;
remaining -= bytes_to_copy;
if remaining > 0 && offset + bytes_to_copy == self.inner.sector_size() as usize {
block_on(self.fetch())?;
}
}
Ok(total_bytes_read)
}
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.len() % self.inner.sector_size() as usize == 0 && !self.buffer_dirty {
return self.write_full_sector(buf);
}
let mut total_bytes_written = 0;
let mut remaining = buf.len();
while remaining > 0 {
let offset = (self.pos % self.inner.sector_size() as u64) as usize;
let bytes_to_copy =
std::cmp::min(remaining, self.inner.sector_size() as usize - offset);
if self.buffer_dirty {
block_on(self.flush())?;
} else if bytes_to_copy < self.inner.sector_size() as usize {
block_on(self.fetch())?;
}
self.buffer[offset..offset + bytes_to_copy]
.copy_from_slice(&buf[total_bytes_written..total_bytes_written + bytes_to_copy]);
self.buffer_dirty = true;
if offset + bytes_to_copy == self.inner.sector_size() as usize {
block_on(self.flush())?;
}
self.pos += bytes_to_copy as u64;
total_bytes_written += bytes_to_copy;
remaining -= bytes_to_copy;
}
Ok(total_bytes_written)
}
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
if self.buffer_dirty {
block_on(self.flush())?;
}
let new_pos = match pos {
SeekFrom::Start(offset) => offset,
SeekFrom::Current(offset) => self.pos.wrapping_add(offset as u64),
SeekFrom::End(offset) => {
let end =
self.inner.sector_count() as i64 * self.inner.sector_size() as i64 + offset;
end.try_into().unwrap()
}
};
self.pos = new_pos;
Ok(new_pos)
}
fn read_full_sector(&mut self, buf: &mut [u8]) -> io::Result<usize> {
assert_eq!(
buf.len() % self.inner.sector_size() as usize,
0,
"Buffer size must be a multiple of sector size"
);
let guest_mem = GuestMemory::allocate(buf.len());
let read_buffers = OwnedRequestBuffers::linear(0, buf.len(), true);
let binding = read_buffers.buffer(&guest_mem);
let future = self
.inner
.read_vectored(&binding, self.pos / self.inner.sector_size() as u64);
block_on(future)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Read error: {}", e)))?;
guest_mem
.read_at(0, buf)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Fetch error: {}", e)))?;
self.pos += buf.len() as u64;
Ok(buf.len())
}
fn write_full_sector(&mut self, buf: &[u8]) -> io::Result<usize> {
assert_eq!(
buf.len() % self.inner.sector_size() as usize,
0,
"Buffer size must be a multiple of sector size"
);
let guest_mem = GuestMemory::allocate(buf.len());
guest_mem.write_at(0, buf).unwrap();
let write_buffers = OwnedRequestBuffers::linear(0, buf.len(), true);
let binding = write_buffers.buffer(&guest_mem);
let future =
self.inner
.write_vectored(&binding, self.pos / self.inner.sector_size() as u64, true);
block_on(future)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("Write error: {}", e)))?;
self.pos += buf.len() as u64;
Ok(buf.len())
}
}
impl Read for BlockingDisk {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.read(buf)
}
}
impl Write for BlockingDisk {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
block_on(self.flush())
}
}
impl Seek for BlockingDisk {
fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
self.seek(pos)
}
}