#![expect(unsafe_code)]
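
//! Buffer types for storage I/O requests: guest-memory-backed request
//! buffers, lockable I/O vectors, and tracked bounce buffers.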

use guestmem::AccessError;
use guestmem::GuestMemory;
use guestmem::LockedRange;
use guestmem::LockedRangeImpl;
use guestmem::MemoryRead;
use guestmem::MemoryWrite;
use guestmem::ranges::PagedRange;
use guestmem::ranges::PagedRangeWriter;
use safeatomic::AsAtomicBytes;
use smallvec::SmallVec;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

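/// A raw pointer/length pair describing a buffer of atomic bytes.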
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub struct AtomicIoVec {
    /// The starting address of the buffer.
    pub address: *const AtomicU8,
    /// The length of the buffer in bytes.
    pub len: usize,
}

impl Default for AtomicIoVec {
    fn default() -> Self {
        Self {
            address: std::ptr::null(),
            len: 0,
        }
    }
}

impl From<&'_ [AtomicU8]> for AtomicIoVec {
    fn from(p: &'_ [AtomicU8]) -> Self {
        Self {
            address: p.as_ptr(),
            len: p.len(),
        }
    }
}

impl AtomicIoVec {
    /// Returns the buffer as a slice.
    ///
    /// # Safety
    /// The caller must ensure `address` points to `len` valid bytes for the
    /// lifetime of the returned slice.
    pub unsafe fn as_slice_unchecked(&self) -> &[AtomicU8] {
        // SAFETY: guaranteed by the caller.
        unsafe { std::slice::from_raw_parts(self.address, self.len) }
    }
}

// SAFETY: `AtomicIoVec` is a plain pointer/length pair; the code that
// dereferences the pointer is responsible for upholding the usual guarantees.
unsafe impl Send for AtomicIoVec {}
// SAFETY: as above.
unsafe impl Sync for AtomicIoVec {}

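/// A borrowed, validated view over an [`AtomicIoVec`], guaranteed to reference
/// a valid buffer for the lifetime `'a`.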
#[derive(Debug, Copy, Clone, Default)]
#[repr(transparent)]
pub struct IoBuffer<'a> {
    io_vec: AtomicIoVec,
    phantom: PhantomData<&'a AtomicU8>,
}

impl<'a> IoBuffer<'a> {
    /// Creates a new buffer from a slice of atomic bytes.
    pub fn new(buffer: &'a [AtomicU8]) -> Self {
        Self {
            io_vec: AtomicIoVec {
                address: buffer.as_ptr(),
                len: buffer.len(),
            },
            phantom: PhantomData,
        }
    }

    /// Reinterprets an [`AtomicIoVec`] as an `IoBuffer`.
    ///
    /// # Safety
    /// The caller must ensure `io_vec` describes a valid buffer for the
    /// returned lifetime.
    pub unsafe fn from_io_vec(io_vec: &AtomicIoVec) -> &Self {
        // SAFETY: `IoBuffer` is `repr(transparent)` over `AtomicIoVec`, and
        // the caller guarantees the buffer's validity.
        unsafe { std::mem::transmute(io_vec) }
    }

    /// Reinterprets a slice of [`AtomicIoVec`]s as a slice of `IoBuffer`s.
    ///
    /// # Safety
    /// The caller must ensure each entry describes a valid buffer for the
    /// returned lifetime.
    pub unsafe fn from_io_vecs(io_vecs: &[AtomicIoVec]) -> &[Self] {
        // SAFETY: `IoBuffer` is `repr(transparent)` over `AtomicIoVec`, and
        // the caller guarantees the buffers' validity.
        unsafe { std::mem::transmute(io_vecs) }
    }

    /// Returns a pointer to the start of the buffer.
    pub fn as_ptr(&self) -> *const AtomicU8 {
        self.io_vec.address
    }

    /// Returns the length of the buffer in bytes.
    pub fn len(&self) -> usize {
        self.io_vec.len
    }
}

impl Deref for IoBuffer<'_> {
    type Target = [AtomicU8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: the `IoBuffer` lifetime guarantees the underlying buffer is
        // still valid.
        unsafe { self.io_vec.as_slice_unchecked() }
    }
}

const PAGE_SIZE: usize = 4096;

#[repr(C, align(4096))]
#[derive(Clone, IntoBytes, Immutable, KnownLayout, FromBytes)]
struct Page([u8; PAGE_SIZE]);

const ZERO_PAGE: Page = Page([0; PAGE_SIZE]);

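/// A page-aligned, zero-initialized buffer for bouncing I/O data that cannot
/// be transferred directly to or from guest memory.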
pub struct BounceBuffer {
    /// The pages backing the buffer. Must not be reallocated, since `io_vec`
    /// points into this allocation.
    pages: Vec<Page>,
    io_vec: AtomicIoVec,
}

impl BounceBuffer {
    /// Allocates a new zeroed bounce buffer of `size` bytes, backed by a whole
    /// number of pages.
    pub fn new(size: usize) -> Self {
        let mut pages = vec![ZERO_PAGE; size.div_ceil(PAGE_SIZE)];
        let io_vec = pages.as_mut_bytes()[..size].as_atomic_bytes().into();
        BounceBuffer { pages, io_vec }
    }

    fn len(&self) -> usize {
        self.io_vec.len
    }

    /// Returns the buffer contents as a mutable byte slice.
    pub fn as_mut_bytes(&mut self) -> &mut [u8] {
        // SAFETY: `pages` contains at least `len()` initialized bytes, and the
        // mutable borrow of `self` prevents concurrent access.
        unsafe { std::slice::from_raw_parts_mut(self.pages.as_mut_ptr().cast::<u8>(), self.len()) }
    }

    /// Returns the buffer as a single-element slice of I/O buffers, suitable
    /// for passing to I/O routines.
    pub fn io_vecs(&self) -> &[IoBuffer<'_>] {
        // SAFETY: `io_vec` points into `pages`, which outlives the returned
        // borrow of `self`.
        std::slice::from_ref(unsafe { IoBuffer::from_io_vec(&self.io_vec) })
    }
}

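/// Guest memory buffers locked for I/O, exposed as a list of I/O buffers.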
pub struct LockedIoBuffers(LockedRangeImpl<LockedIoVecs>);

impl LockedIoBuffers {
    /// Returns the locked buffers as a slice of I/O buffers.
    pub fn io_vecs(&self) -> &[IoBuffer<'_>] {
        // SAFETY: the underlying locked range keeps the buffers valid for as
        // long as `self` is alive.
        unsafe { IoBuffer::from_io_vecs(&self.0.get().0) }
    }
}

struct LockedIoVecs(SmallVec<[AtomicIoVec; 64]>);

impl LockedIoVecs {
    fn new() -> Self {
        Self(Default::default())
    }
}

impl LockedRange for LockedIoVecs {
    fn push_sub_range(&mut self, sub_range: &[AtomicU8]) {
        self.0.push(sub_range.into());
    }
}

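/// A [`MemoryWrite`] implementation that enforces the request's write
/// permission: writes to a read-only request fail with
/// [`AccessError::ReadOnly`] and report a length of zero.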
struct PermissionedMemoryWriter<'a> {
    range: PagedRange<'a>,
    writer: PagedRangeWriter<'a>,
    is_write: bool,
}

impl PermissionedMemoryWriter<'_> {
    fn new<'a>(
        range: PagedRange<'a>,
        guest_memory: &'a GuestMemory,
        is_write: bool,
    ) -> PermissionedMemoryWriter<'a> {
        // For read-only buffers, use an empty range so that any write attempt
        // fails and the reported length is zero.
        let range = if is_write { range } else { PagedRange::empty() };
        PermissionedMemoryWriter {
            range,
            writer: range.writer(guest_memory),
            is_write,
        }
    }
}

impl MemoryWrite for PermissionedMemoryWriter<'_> {
    fn write(&mut self, data: &[u8]) -> Result<(), AccessError> {
        self.writer.write(data).map_err(|e| {
            if self.is_write {
                e
            } else {
                AccessError::ReadOnly
            }
        })
    }

    fn fill(&mut self, val: u8, len: usize) -> Result<(), AccessError> {
        self.writer.fill(val, len).map_err(|e| {
            if self.is_write {
                e
            } else {
                AccessError::ReadOnly
            }
        })
    }

    fn len(&self) -> usize {
        self.range.len()
    }
}

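/// The guest memory buffers for an I/O request: a paged range of guest memory
/// plus the request's write permission.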
#[derive(Clone, Debug)]
pub struct RequestBuffers<'a> {
    range: PagedRange<'a>,
    guest_memory: &'a GuestMemory,
    is_write: bool,
}

impl<'a> RequestBuffers<'a> {
    /// Creates a new set of request buffers over `range` in `guest_memory`.
    pub fn new(guest_memory: &'a GuestMemory, range: PagedRange<'a>, is_write: bool) -> Self {
        Self {
            range,
            guest_memory,
            is_write,
        }
    }

    /// Returns true if the buffers are empty.
    pub fn is_empty(&self) -> bool {
        self.range.is_empty()
    }

    /// Returns the total length of the buffers in bytes.
    pub fn len(&self) -> usize {
        self.range.len()
    }

    /// Returns the backing guest memory.
    pub fn guest_memory(&self) -> &GuestMemory {
        self.guest_memory
    }

    /// Returns the paged range describing the buffers.
    pub fn range(&self) -> PagedRange<'_> {
        self.range
    }

    /// Returns whether the buffers' offset and length are both multiples of
    /// `alignment`, which must be a power of two. Alignments larger than the
    /// page size always return `false`.
    pub fn is_aligned(&self, alignment: usize) -> bool {
        assert!(alignment.is_power_of_two());
        ((self.range.offset() | self.range.len() | PAGE_SIZE) & (alignment - 1)) == 0
    }

    /// Returns a writer for the buffers. Writes fail with
    /// [`AccessError::ReadOnly`] if the buffers are not writable.
    pub fn writer(&self) -> impl MemoryWrite + '_ {
        PermissionedMemoryWriter::new(self.range, self.guest_memory, self.is_write)
    }

    /// Returns a reader for the buffers.
    pub fn reader(&self) -> impl MemoryRead + '_ {
        self.range.reader(self.guest_memory)
    }

    /// Locks the buffers in guest memory, returning I/O buffers for direct
    /// access. Fails with [`AccessError::ReadOnly`] if `for_write` is set but
    /// the buffers are not writable.
    pub fn lock(&self, for_write: bool) -> Result<LockedIoBuffers, AccessError> {
        if for_write && !self.is_write {
            return Err(AccessError::ReadOnly);
        }
        Ok(LockedIoBuffers(
            self.guest_memory
                .lock_range(self.range, LockedIoVecs::new())?,
        ))
    }

    /// Returns a subrange of the buffers starting at `offset` bytes, `len`
    /// bytes long.
    pub fn subrange(&self, offset: usize, len: usize) -> Self {
        Self {
            range: self.range.subrange(offset, len),
            guest_memory: self.guest_memory,
            is_write: self.is_write,
        }
    }
}

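/// An owned description of request buffers (guest page numbers, offset, and
/// length) that can be borrowed against a [`GuestMemory`] to produce
/// [`RequestBuffers`].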
#[derive(Debug, Clone)]
pub struct OwnedRequestBuffers {
    gpns: Vec<u64>,
    offset: usize,
    len: usize,
    is_write: bool,
}

impl OwnedRequestBuffers {
    /// Creates a new writable buffer covering the full pages in `gpns`.
    pub fn new(gpns: &[u64]) -> Self {
        Self::new_unaligned(gpns, 0, gpns.len() * PAGE_SIZE)
    }

    /// Creates a new writable buffer covering `len` bytes of the pages in
    /// `gpns`, starting at byte `offset` into the first page.
    pub fn new_unaligned(gpns: &[u64], offset: usize, len: usize) -> Self {
        Self {
            gpns: gpns.to_vec(),
            offset,
            len,
            is_write: true,
        }
    }

    /// Creates a buffer describing a linear range of guest memory, starting at
    /// guest physical address `offset` and `len` bytes long.
    pub fn linear(offset: u64, len: usize, is_write: bool) -> Self {
        let start_page = offset / PAGE_SIZE as u64;
        let end_page = (offset + len as u64).div_ceil(PAGE_SIZE as u64);
        let gpns: Vec<u64> = (start_page..end_page).collect();
        Self {
            gpns,
            offset: (offset % PAGE_SIZE as u64) as usize,
            len,
            is_write,
        }
    }

    /// Borrows the buffers against `guest_memory`, returning a
    /// [`RequestBuffers`] that can be used for an I/O.
    pub fn buffer<'a>(&'a self, guest_memory: &'a GuestMemory) -> RequestBuffers<'a> {
        RequestBuffers::new(
            guest_memory,
            PagedRange::new(self.offset, self.len, &self.gpns).unwrap(),
            self.is_write,
        )
    }

    /// Returns the length of the buffers in bytes.
    pub fn len(&self) -> usize {
        self.len
    }
}

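/// A [`BounceBuffer`] whose pages are charged against a per-thread quota in a
/// [`BounceBufferTracker`] and returned to it on drop.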
pub struct TrackedBounceBuffer<'a> {
    /// The underlying bounce buffer.
    pub buffer: BounceBuffer,
    /// Count of free bounce buffer pages on the owning thread.
    free_pages: &'a AtomicUsize,
    /// Event signaled when pages are freed back to the tracker.
    event: &'a event_listener::Event,
}
409
410impl Drop for TrackedBounceBuffer<'_> {
411 fn drop(&mut self) {
412 let pages = self.buffer.len().div_ceil(4096);
413 self.free_pages.fetch_add(pages, Ordering::SeqCst);
414 self.event.notify(usize::MAX);
415 }
416}
417
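/// Tracks per-thread usage of bounce buffer pages, making callers wait until
/// enough pages are free before allocating a new bounce buffer.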
#[derive(Debug)]
pub struct BounceBufferTracker {
    /// Per-thread count of free bounce buffer pages.
    free_pages: Vec<AtomicUsize>,
    /// Per-thread event signaled when pages are freed.
    event: Vec<event_listener::Event>,
}

impl BounceBufferTracker {
    /// Creates a new tracker that allows up to `max_bounce_buffer_pages`
    /// bounce buffer pages for each of `threads` threads.
    pub fn new(max_bounce_buffer_pages: usize, threads: usize) -> Self {
        let mut free_pages = Vec::with_capacity(threads);
        let mut event = Vec::with_capacity(threads);

        (0..threads).for_each(|_| {
            event.push(event_listener::Event::new());
            free_pages.push(AtomicUsize::new(max_bounce_buffer_pages));
        });

        Self { free_pages, event }
    }

    /// Acquires a bounce buffer of `size` bytes, charged against `thread`'s
    /// page quota, waiting until enough pages are free.
    pub async fn acquire_bounce_buffers<'a, 'b>(
        &'b self,
        size: usize,
        thread: usize,
    ) -> Box<TrackedBounceBuffer<'a>>
    where
        'b: 'a,
    {
        let pages = size.div_ceil(4096);
        let event = self.event.get(thread).unwrap();
        let free_pages = self.free_pages.get(thread).unwrap();

        loop {
            // Register the listener before attempting to reserve pages so that
            // a concurrent release cannot be missed.
            let listener = event.listen();
            if free_pages
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| x.checked_sub(pages))
                .is_ok()
            {
                break;
            }
            listener.await;
        }

        Box::new(TrackedBounceBuffer {
            buffer: BounceBuffer::new(size),
            free_pages,
            event,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sparse_mmap::SparseMapping;
    const SIZE_1MB: usize = 1048576;

    #[test]
    fn correct_read_only_behavior() {
        let mapping = SparseMapping::new(SIZE_1MB * 4).unwrap();
        let guest_memory = GuestMemory::new("test-scsi-buffers", mapping);
        let range = PagedRange::new(0, 4096, &[0]).unwrap();
        let buffers = RequestBuffers::new(&guest_memory, range, false);

        let r = buffers.writer().write(&[1; 4096]);
        assert!(
            matches!(r, Err(AccessError::ReadOnly)),
            "Expected read-only error, got {:?}",
            r
        );

        let r = buffers.writer().fill(1, 4096);
        assert!(
            matches!(r, Err(AccessError::ReadOnly)),
            "Expected read-only error, got {:?}",
            r
        );

        assert!(
            buffers.writer().len() == 0,
            "Length should be 0 for read-only writer"
        );
    }
}