disk_backend/
sync_wrapper.rs1use crate::Disk;
17use futures::executor::block_on;
18use guestmem::GuestMemory;
19use scsi_buffers::OwnedRequestBuffers;
20use std::io;
21use std::io::Read;
22use std::io::Seek;
23use std::io::SeekFrom;
24use std::io::Write;
25
26pub struct BlockingDisk {
29 inner: Disk,
31 pos: u64,
33 buffer: Vec<u8>,
35 buffer_dirty: bool,
37}
38
39impl BlockingDisk {
40 pub fn new(inner: Disk) -> Self {
42 let sector_size = inner.sector_size();
43 BlockingDisk {
44 inner,
45 pos: 0,
46 buffer: vec![0; sector_size as usize],
47 buffer_dirty: false,
48 }
49 }
50
51 async fn fetch(&mut self) -> io::Result<()> {
53 if self.buffer_dirty {
54 block_on(self.flush())?;
55 }
56 let guest_mem = GuestMemory::allocate(self.inner.sector_size() as usize);
57 let read_buffers = OwnedRequestBuffers::linear(0, self.inner.sector_size() as usize, true);
58 let binding = read_buffers.buffer(&guest_mem);
59 let result = self
60 .inner
61 .read_vectored(&binding, self.pos / self.inner.sector_size() as u64)
62 .await;
63 guest_mem
64 .read_at(0, &mut self.buffer)
65 .map_err(|e| io::Error::other(format!("Fetch error: {}", e)))?;
66 result.map_err(|e| io::Error::other(format!("Fetch error: {}", e)))
67 }
68
69 async fn flush(&mut self) -> io::Result<()> {
71 if self.buffer_dirty {
72 let guest_mem = GuestMemory::allocate(self.inner.sector_size() as usize);
73 guest_mem.write_at(0, &self.buffer).unwrap();
74 let write_buffers =
75 OwnedRequestBuffers::linear(0, self.inner.sector_size() as usize, true);
76 let binding = write_buffers.buffer(&guest_mem);
77 let future = self.inner.write_vectored(
78 &binding,
79 self.pos / self.inner.sector_size() as u64,
80 true,
81 );
82 let result = future.await;
83 self.buffer_dirty = false;
84 result.map_err(|e| io::Error::other(format!("Fetch error: {}", e)))
85 } else {
86 Ok(())
87 }
88 }
89
90 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
92 if buf.len() % self.inner.sector_size() as usize == 0 && !self.buffer_dirty {
95 return self.read_full_sector(buf);
96 }
97 let mut total_bytes_read = 0;
99 let mut remaining = buf.len();
100 if self.buffer_dirty {
101 block_on(self.flush())?;
102 }
103 while remaining > 0 {
104 block_on(self.fetch())?;
105 let offset = (self.pos % self.inner.sector_size() as u64) as usize;
106 let bytes_to_copy =
107 std::cmp::min(remaining, self.inner.sector_size() as usize - offset);
108 buf[total_bytes_read..total_bytes_read + bytes_to_copy]
109 .copy_from_slice(&self.buffer[offset..offset + bytes_to_copy]);
110 self.pos += bytes_to_copy as u64;
111 total_bytes_read += bytes_to_copy;
112 remaining -= bytes_to_copy;
113 if remaining > 0 && offset + bytes_to_copy == self.inner.sector_size() as usize {
114 block_on(self.fetch())?;
116 }
117 }
118 Ok(total_bytes_read)
119 }
120
121 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
123 if buf.len() % self.inner.sector_size() as usize == 0 && !self.buffer_dirty {
126 return self.write_full_sector(buf);
127 }
128 let mut total_bytes_written = 0;
130 let mut remaining = buf.len();
131 while remaining > 0 {
132 let offset = (self.pos % self.inner.sector_size() as u64) as usize;
133 let bytes_to_copy =
134 std::cmp::min(remaining, self.inner.sector_size() as usize - offset);
135 if self.buffer_dirty {
136 block_on(self.flush())?;
137 } else if bytes_to_copy < self.inner.sector_size() as usize {
138 block_on(self.fetch())?;
140 }
141 self.buffer[offset..offset + bytes_to_copy]
142 .copy_from_slice(&buf[total_bytes_written..total_bytes_written + bytes_to_copy]);
143 self.buffer_dirty = true;
144 if offset + bytes_to_copy == self.inner.sector_size() as usize {
146 block_on(self.flush())?;
147 }
148 self.pos += bytes_to_copy as u64;
149 total_bytes_written += bytes_to_copy;
150 remaining -= bytes_to_copy;
151 }
152 Ok(total_bytes_written)
153 }
154
155 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
157 if self.buffer_dirty {
158 block_on(self.flush())?;
159 }
160 let new_pos = match pos {
161 SeekFrom::Start(offset) => offset,
162 SeekFrom::Current(offset) => self.pos.wrapping_add(offset as u64),
163 SeekFrom::End(offset) => {
164 let end =
165 self.inner.sector_count() as i64 * self.inner.sector_size() as i64 + offset;
166 end.try_into().unwrap()
167 }
168 };
169 self.pos = new_pos;
170 Ok(new_pos)
171 }
172
173 fn read_full_sector(&mut self, buf: &mut [u8]) -> io::Result<usize> {
175 assert_eq!(
176 buf.len() % self.inner.sector_size() as usize,
177 0,
178 "Buffer size must be a multiple of sector size"
179 );
180 let guest_mem = GuestMemory::allocate(buf.len());
181 let read_buffers = OwnedRequestBuffers::linear(0, buf.len(), true);
182 let binding = read_buffers.buffer(&guest_mem);
183 let future = self
184 .inner
185 .read_vectored(&binding, self.pos / self.inner.sector_size() as u64);
186 block_on(future).map_err(|e| io::Error::other(format!("Read error: {}", e)))?;
187
188 guest_mem
190 .read_at(0, buf)
191 .map_err(|e| io::Error::other(format!("Fetch error: {}", e)))?;
192 self.pos += buf.len() as u64;
194 Ok(buf.len())
195 }
196
197 fn write_full_sector(&mut self, buf: &[u8]) -> io::Result<usize> {
199 assert_eq!(
200 buf.len() % self.inner.sector_size() as usize,
201 0,
202 "Buffer size must be a multiple of sector size"
203 );
204 let guest_mem = GuestMemory::allocate(buf.len());
205 guest_mem.write_at(0, buf).unwrap();
206 let write_buffers = OwnedRequestBuffers::linear(0, buf.len(), true);
207 let binding = write_buffers.buffer(&guest_mem);
208 let future =
209 self.inner
210 .write_vectored(&binding, self.pos / self.inner.sector_size() as u64, true);
211 block_on(future).map_err(|e| io::Error::other(format!("Write error: {}", e)))?;
212 self.pos += buf.len() as u64;
214 Ok(buf.len())
215 }
216}
217
218impl Read for BlockingDisk {
219 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
220 self.read(buf)
221 }
222}
223
224impl Write for BlockingDisk {
225 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
226 self.write(buf)
227 }
228
229 fn flush(&mut self) -> io::Result<()> {
230 block_on(self.flush())
231 }
232}
233
234impl Seek for BlockingDisk {
235 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
236 self.seek(pos)
237 }
238}