sparse_mmap/
lib.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Memory-related abstractions.

// UNSAFETY: Manual pointer manipulation, dealing with mmap, and a signal handler.
#![expect(unsafe_code)]
#![expect(missing_docs)]
#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]

pub mod alloc;
mod trycopy_windows_arm64;
mod trycopy_windows_x64;
pub mod unix;
pub mod windows;

pub use sys::AsMappableRef;
pub use sys::Mappable;
pub use sys::MappableRef;
pub use sys::SparseMapping;
pub use sys::alloc_shared_memory;
pub use sys::new_mappable_from_file;

use std::mem::MaybeUninit;
use std::sync::atomic::AtomicU8;
use thiserror::Error;
#[cfg(unix)]
use unix as sys;
#[cfg(windows)]
use windows as sys;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// Must be called before using `try_copy` on Unix platforms.
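///
/// Safe to call more than once; only the first call installs the signal
/// handlers. A minimal sketch of startup use:
///
/// ```no_run
/// sparse_mmap::initialize_try_copy();
/// ```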
pub fn initialize_try_copy() {
    #[cfg(unix)]
    {
        static INIT: std::sync::Once = std::sync::Once::new();
        INIT.call_once(|| unsafe {
            let err = install_signal_handlers();
            if err != 0 {
                panic!(
                    "could not install signal handlers: {}",
                    std::io::Error::from_raw_os_error(err)
                )
            }
        });
    }
}

unsafe extern "C" {
    #[cfg(unix)]
    fn install_signal_handlers() -> i32;

    fn try_memmove(
        dest: *mut u8,
        src: *const u8,
        length: usize,
        failure: *mut AccessFailure,
    ) -> i32;
    fn try_memset(dest: *mut u8, c: i32, length: usize, failure: *mut AccessFailure) -> i32;
    fn try_cmpxchg8(
        dest: *mut u8,
        expected: &mut u8,
        desired: u8,
        failure: *mut AccessFailure,
    ) -> i32;
    fn try_cmpxchg16(
        dest: *mut u16,
        expected: &mut u16,
        desired: u16,
        failure: *mut AccessFailure,
    ) -> i32;
    fn try_cmpxchg32(
        dest: *mut u32,
        expected: &mut u32,
        desired: u32,
        failure: *mut AccessFailure,
    ) -> i32;
    fn try_cmpxchg64(
        dest: *mut u64,
        expected: &mut u64,
        desired: u64,
        failure: *mut AccessFailure,
    ) -> i32;
    fn try_read8(dest: *mut u8, src: *const u8, failure: *mut AccessFailure) -> i32;
    fn try_read16(dest: *mut u16, src: *const u16, failure: *mut AccessFailure) -> i32;
    fn try_read32(dest: *mut u32, src: *const u32, failure: *mut AccessFailure) -> i32;
    fn try_read64(dest: *mut u64, src: *const u64, failure: *mut AccessFailure) -> i32;
    fn try_write8(dest: *mut u8, value: u8, failure: *mut AccessFailure) -> i32;
    fn try_write16(dest: *mut u16, value: u16, failure: *mut AccessFailure) -> i32;
    fn try_write32(dest: *mut u32, value: u32, failure: *mut AccessFailure) -> i32;
    fn try_write64(dest: *mut u64, value: u64, failure: *mut AccessFailure) -> i32;
}

/// Fault information filled in by the `try_*` helper routines when an access
/// fails. `#[repr(C)]` so it matches the layout expected by the native helpers.
#[repr(C)]
struct AccessFailure {
    address: *mut u8,
    #[cfg(unix)]
    si_signo: i32,
    #[cfg(unix)]
    si_code: i32,
}

#[derive(Debug, Error)]
#[error("failed to {} memory", if self.is_write { "write" } else { "read" })]
pub struct MemoryError {
    offset: usize,
    is_write: bool,
    #[source]
    source: OsAccessError,
}

#[derive(Debug, Error)]
enum OsAccessError {
    #[cfg(windows)]
    #[error("access violation")]
    AccessViolation,
    #[cfg(unix)]
    #[error("SIGSEGV (si_code = {0:x})")]
    Sigsegv(u32),
    #[cfg(unix)]
    #[error("SIGBUS (si_code = {0:x})")]
    Sigbus(u32),
}

impl MemoryError {
    fn new(src: Option<*const u8>, dest: *mut u8, len: usize, failure: &AccessFailure) -> Self {
        let (offset, is_write) = if failure.address.is_null() {
            // In the case of a general protection fault (#GP) the provided address is zero.
            (0, src.is_none())
        } else if (dest..dest.wrapping_add(len)).contains(&failure.address) {
            (failure.address as usize - dest as usize, true)
        } else if let Some(src) = src {
            if (src..src.wrapping_add(len)).contains(&failure.address.cast_const()) {
                (failure.address as usize - src as usize, false)
            } else {
                panic!(
                    "invalid failure address: {:p} src: {:p} dest: {:p} len: {:#x}",
                    failure.address, src, dest, len
                );
            }
        } else {
            panic!(
                "invalid failure address: {:p} src: None dest: {:p} len: {:#x}",
                failure.address, dest, len
            );
        };
        #[cfg(windows)]
        let source = OsAccessError::AccessViolation;
        #[cfg(unix)]
        let source = match failure.si_signo {
            libc::SIGSEGV => OsAccessError::Sigsegv(failure.si_code as u32),
            libc::SIGBUS => OsAccessError::Sigbus(failure.si_code as u32),
            _ => {
                panic!(
                    "unexpected signal: {} src: {:?} dest: {:p} len: {:#x}",
                    failure.si_signo, src, dest, len
                );
            }
        };
        Self {
            offset,
            is_write,
            source,
        }
    }

    /// Returns the byte offset into the buffer at which the access violation
    /// occurred.
    pub fn offset(&self) -> usize {
        self.offset
    }
}

/// Copies `count` elements from `src` to `dest`. `src` and `dest` may overlap.
/// Fails on access violation/SIGSEGV. Note that in case of failure, some of the
/// bytes (even partial elements) may already have been copied.
///
/// This also fails if `initialize_try_copy` has not been called.
///
/// # Safety
///
/// This routine is safe to use even if the memory pointed to by `src` or
/// `dest` is being concurrently mutated.
///
/// WARNING: This routine should only be used when you know that `src` and
/// `dest` are valid, reserved addresses but you do not know if they are mapped
/// with the appropriate protection. For example, this routine is useful if
/// `dest` is a sparse mapping where some pages are mapped with
/// PAGE_NOACCESS/PROT_NONE, and some are mapped with PAGE_READWRITE/PROT_WRITE.
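///
/// A minimal usage sketch grounded in this crate's tests; the mapping and
/// sizes are illustrative:
///
/// ```no_run
/// use sparse_mmap::SparseMapping;
/// use sparse_mmap::initialize_try_copy;
/// use sparse_mmap::try_copy;
///
/// initialize_try_copy();
/// let mapping = SparseMapping::new(0x10000).unwrap();
/// let data = [1u8, 2, 3, 4];
/// // Nothing is committed yet, so this faults and returns Err rather than crashing.
/// let r = unsafe { try_copy(data.as_ptr(), mapping.as_ptr().cast::<u8>(), data.len()) };
/// assert!(r.is_err());
/// ```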
pub unsafe fn try_copy<T>(src: *const T, dest: *mut T, count: usize) -> Result<(), MemoryError> {
    let mut failure = MaybeUninit::uninit();
    let len = count * size_of::<T>();
    // SAFETY: guaranteed by caller.
    let ret = unsafe {
        try_memmove(
            dest.cast::<u8>(),
            src.cast::<u8>(),
            len,
            failure.as_mut_ptr(),
        )
    };
    match ret {
        0 => Ok(()),
        _ => Err(MemoryError::new(
            Some(src.cast()),
            dest.cast(),
            len,
            // SAFETY: failure is initialized in the failure path.
            unsafe { failure.assume_init_ref() },
        )),
    }
}

/// Writes `count * size_of::<T>()` bytes of the value `val` starting at
/// `dest`. Fails on access violation/SIGSEGV. Note that in case of failure,
/// some of the bytes (even partial elements) may already have been written.
///
/// This also fails if `initialize_try_copy` has not been called.
///
/// # Safety
///
/// This routine is safe to use even if the memory pointed to by `dest` is
/// being concurrently mutated.
///
/// WARNING: This routine should only be used when you know that `dest` is a
/// valid, reserved address but you do not know if it is mapped with the
/// appropriate protection. For example, this routine is useful if `dest` is a
/// sparse mapping where some pages are mapped with PAGE_NOACCESS/PROT_NONE, and
/// some are mapped with PAGE_READWRITE/PROT_WRITE.
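///
/// A minimal sketch (`SparseMapping::fill_at` is the safe wrapper; sizes are
/// illustrative):
///
/// ```no_run
/// use sparse_mmap::SparseMapping;
/// use sparse_mmap::initialize_try_copy;
/// use sparse_mmap::try_write_bytes;
///
/// initialize_try_copy();
/// let page_size = SparseMapping::page_size();
/// let mapping = SparseMapping::new(page_size * 2).unwrap();
/// mapping.alloc(0, page_size).unwrap();
/// // Zero the first 128 bytes; a fault would surface as Err(MemoryError).
/// unsafe { try_write_bytes(mapping.as_ptr().cast::<u8>(), 0, 128) }.unwrap();
/// ```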
pub unsafe fn try_write_bytes<T>(dest: *mut T, val: u8, count: usize) -> Result<(), MemoryError> {
    let mut failure = MaybeUninit::uninit();
    let len = count * size_of::<T>();
    // SAFETY: guaranteed by caller.
    let ret = unsafe { try_memset(dest.cast::<u8>(), val.into(), len, failure.as_mut_ptr()) };
    match ret {
        0 => Ok(()),
        _ => Err(MemoryError::new(
            None,
            dest.cast(),
            len,
            // SAFETY: failure is initialized in the failure path.
            unsafe { failure.assume_init_ref() },
        )),
    }
}

/// Atomically swaps the value at `dest` with `new` when `*dest` is `current`,
/// using a sequentially-consistent memory ordering.
///
/// Returns `Ok(Ok(new))` if the swap was successful, `Ok(Err(*dest))` if the
/// swap failed because the value did not match `current`, or
/// `Err(MemoryError)` if the swap could not be attempted due to an access
/// violation.
///
/// Fails at compile time if the size is not 1, 2, 4, or 8 bytes, or if the type
/// is under-aligned.
///
/// # Safety
///
/// This routine is safe to use even if the memory pointed to by `dest` is
/// being concurrently mutated.
///
/// WARNING: This routine should only be used when you know that `dest` is a
/// valid, reserved address but you do not know if it is mapped with the
/// appropriate protection. For example, this routine is useful if `dest` is a
/// sparse mapping where some pages are mapped with PAGE_NOACCESS/PROT_NONE, and
/// some are mapped with PAGE_READWRITE/PROT_WRITE.
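///
/// A short sketch based on this crate's tests; offsets are illustrative:
///
/// ```no_run
/// use sparse_mmap::SparseMapping;
/// use sparse_mmap::initialize_try_copy;
/// use sparse_mmap::try_compare_exchange;
///
/// initialize_try_copy();
/// let page_size = SparseMapping::page_size();
/// let mapping = SparseMapping::new(page_size * 2).unwrap();
/// mapping.alloc(0, page_size).unwrap();
/// let base = mapping.as_ptr().cast::<u8>();
/// unsafe {
///     // Newly allocated pages read as zero, so this swap succeeds...
///     assert_eq!(try_compare_exchange(base, 0u8, 1).unwrap().unwrap(), 1);
///     // ...and this one fails, returning the observed value.
///     assert_eq!(try_compare_exchange(base, 0u8, 2).unwrap().unwrap_err(), 1);
/// }
/// ```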
pub unsafe fn try_compare_exchange<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
    dest: *mut T,
    mut current: T,
    new: T,
) -> Result<Result<T, T>, MemoryError> {
    const {
        assert!(matches!(size_of::<T>(), 1 | 2 | 4 | 8));
        assert!(align_of::<T>() >= size_of::<T>());
    };
    let mut failure = MaybeUninit::uninit();
    // SAFETY: guaranteed by caller
    let ret = unsafe {
        match size_of::<T>() {
            1 => try_cmpxchg8(
                dest.cast(),
                std::mem::transmute::<&mut T, &mut u8>(&mut current),
                std::mem::transmute_copy::<T, u8>(&new),
                failure.as_mut_ptr(),
            ),
            2 => try_cmpxchg16(
                dest.cast(),
                std::mem::transmute::<&mut T, &mut u16>(&mut current),
                std::mem::transmute_copy::<T, u16>(&new),
                failure.as_mut_ptr(),
            ),
            4 => try_cmpxchg32(
                dest.cast(),
                std::mem::transmute::<&mut T, &mut u32>(&mut current),
                std::mem::transmute_copy::<T, u32>(&new),
                failure.as_mut_ptr(),
            ),
            8 => try_cmpxchg64(
                dest.cast(),
                std::mem::transmute::<&mut T, &mut u64>(&mut current),
                std::mem::transmute_copy::<T, u64>(&new),
                failure.as_mut_ptr(),
            ),
            _ => unreachable!(),
        }
    };
    match ret {
        n if n > 0 => Ok(Ok(new)),
        0 => Ok(Err(current)),
        _ => Err(MemoryError::new(
            None,
            dest.cast(),
            size_of::<T>(),
            // SAFETY: failure is initialized in the failure path.
            unsafe { failure.assume_init_ref() },
        )),
    }
}

/// Reads the value at `src` using one or more read instructions.
///
/// If `T` is 1, 2, 4, or 8 bytes in size, then exactly one read instruction is
/// used.
///
/// Returns `Ok(T)` if the read was successful, or `Err(MemoryError)` if the
/// read was unsuccessful.
///
/// # Safety
///
/// This routine is safe to use even if the memory pointed to by `src` is
/// being concurrently mutated.
///
/// WARNING: This routine should only be used when you know that `src` is a
/// valid, reserved address but you do not know if it is mapped with the
/// appropriate protection. For example, this routine is useful if `src` is a
/// sparse mapping where some pages are mapped with PAGE_NOACCESS/PROT_NONE, and
/// some are mapped with PAGE_READWRITE/PROT_WRITE.
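///
/// A minimal sketch; `SparseMapping::read_volatile` is the bounds-checked
/// wrapper:
///
/// ```no_run
/// use sparse_mmap::initialize_try_copy;
/// use sparse_mmap::try_read_volatile;
///
/// initialize_try_copy();
/// let value = 42u32;
/// // One 4-byte load; a fault would return Err instead of crashing.
/// let read = unsafe { try_read_volatile(&value as *const u32) }.unwrap();
/// assert_eq!(read, 42);
/// ```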
pub unsafe fn try_read_volatile<T: FromBytes + Immutable + KnownLayout>(
    src: *const T,
) -> Result<T, MemoryError> {
    let mut dest = MaybeUninit::<T>::uninit();
    let mut failure = MaybeUninit::uninit();
    // SAFETY: guaranteed by caller
    let ret = unsafe {
        match size_of::<T>() {
            1 => try_read8(dest.as_mut_ptr().cast(), src.cast(), failure.as_mut_ptr()),
            2 => try_read16(dest.as_mut_ptr().cast(), src.cast(), failure.as_mut_ptr()),
            4 => try_read32(dest.as_mut_ptr().cast(), src.cast(), failure.as_mut_ptr()),
            8 => try_read64(dest.as_mut_ptr().cast(), src.cast(), failure.as_mut_ptr()),
            _ => try_memmove(
                dest.as_mut_ptr().cast(),
                src.cast::<u8>(),
                size_of::<T>(),
                failure.as_mut_ptr(),
            ),
        }
    };
    match ret {
        0 => {
            // SAFETY: dest was fully initialized by try_read.
            Ok(unsafe { dest.assume_init() })
        }
        _ => Err(MemoryError::new(
            Some(src.cast()),
            dest.as_mut_ptr().cast(),
            size_of::<T>(),
            // SAFETY: failure is initialized in the failure path.
            unsafe { failure.assume_init_ref() },
        )),
    }
}

/// Writes `value` at `dest` using one or more write instructions.
///
/// If `T` is 1, 2, 4, or 8 bytes in size, then exactly one write instruction is
/// used.
///
/// Returns `Ok(())` if the write was successful, or `Err(MemoryError)` if the
/// write was unsuccessful.
///
/// # Safety
///
/// This routine is safe to use even if the memory pointed to by `dest` is
/// being concurrently mutated.
///
/// WARNING: This routine should only be used when you know that `dest` is a
/// valid, reserved address but you do not know if it is mapped with the
/// appropriate protection. For example, this routine is useful if `dest` is a
/// sparse mapping where some pages are mapped with PAGE_NOACCESS/PROT_NONE, and
/// some are mapped with PAGE_READWRITE/PROT_WRITE.
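///
/// A minimal sketch mirroring the `try_read_volatile` example:
///
/// ```no_run
/// use sparse_mmap::initialize_try_copy;
/// use sparse_mmap::try_write_volatile;
///
/// initialize_try_copy();
/// let mut value = 0u32;
/// // One 4-byte store; a fault would return Err instead of crashing.
/// unsafe { try_write_volatile(&mut value as *mut u32, &42u32) }.unwrap();
/// assert_eq!(value, 42);
/// ```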
pub unsafe fn try_write_volatile<T: IntoBytes + Immutable + KnownLayout>(
    dest: *mut T,
    value: &T,
) -> Result<(), MemoryError> {
    let mut failure = MaybeUninit::uninit();
    // SAFETY: guaranteed by caller
    let ret = unsafe {
        match size_of::<T>() {
            1 => try_write8(
                dest.cast(),
                std::mem::transmute_copy(value),
                failure.as_mut_ptr(),
            ),
            2 => try_write16(
                dest.cast(),
                std::mem::transmute_copy(value),
                failure.as_mut_ptr(),
            ),
            4 => try_write32(
                dest.cast(),
                std::mem::transmute_copy(value),
                failure.as_mut_ptr(),
            ),
            8 => try_write64(
                dest.cast(),
                std::mem::transmute_copy(value),
                failure.as_mut_ptr(),
            ),
            _ => try_memmove(
                dest.cast(),
                std::ptr::from_ref(value).cast(),
                size_of::<T>(),
                failure.as_mut_ptr(),
            ),
        }
    };
    match ret {
        0 => Ok(()),
        _ => Err(MemoryError::new(
            None,
            dest.cast(),
            size_of::<T>(),
            // SAFETY: failure is initialized in the failure path.
            unsafe { failure.assume_init_ref() },
        )),
    }
}

#[derive(Debug, Error)]
pub enum SparseMappingError {
    #[error("out of bounds")]
    OutOfBounds,
    #[error(transparent)]
    Memory(MemoryError),
}

impl SparseMapping {
    /// Gets the supported page size for sparse mappings.
    pub fn page_size() -> usize {
        sys::page_size()
    }

    fn check(&self, offset: usize, len: usize) -> Result<(), SparseMappingError> {
        if self.len() < offset || self.len() - offset < len {
            return Err(SparseMappingError::OutOfBounds);
        }
        Ok(())
    }

    /// Reads a type `T` from `offset` in the sparse mapping using a single
    /// read instruction when `T` is 1, 2, 4, or 8 bytes in size; larger types
    /// fall back to a multi-instruction copy.
    pub fn read_volatile<T: FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
    ) -> Result<T, SparseMappingError> {
        assert!(self.is_local(), "cannot read from remote mappings");

        self.check(offset, size_of::<T>())?;
        // SAFETY: the bounds have been checked above.
        unsafe { try_read_volatile(self.as_ptr().byte_add(offset).cast()) }
            .map_err(SparseMappingError::Memory)
    }

    /// Writes a type `T` at `offset` in the sparse mapping using a single
    /// write instruction when `T` is 1, 2, 4, or 8 bytes in size; larger types
    /// fall back to a multi-instruction copy.
    pub fn write_volatile<T: IntoBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
        value: &T,
    ) -> Result<(), SparseMappingError> {
        assert!(self.is_local(), "cannot write to remote mappings");

        self.check(offset, size_of::<T>())?;
        // SAFETY: the bounds have been checked above.
        unsafe { try_write_volatile(self.as_ptr().byte_add(offset).cast(), value) }
            .map_err(SparseMappingError::Memory)
    }

    /// Tries to write into the sparse mapping.
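    ///
    /// A small usage sketch, paired with `read_at` and `fill_at` (offsets are
    /// illustrative):
    ///
    /// ```no_run
    /// use sparse_mmap::SparseMapping;
    ///
    /// sparse_mmap::initialize_try_copy();
    /// let page_size = SparseMapping::page_size();
    /// let mapping = SparseMapping::new(page_size * 2).unwrap();
    /// mapping.alloc(0, page_size).unwrap();
    /// mapping.write_at(16, b"hello").unwrap();
    /// let mut buf = [0u8; 5];
    /// mapping.read_at(16, &mut buf).unwrap();
    /// assert_eq!(&buf, b"hello");
    /// mapping.fill_at(16, 0, 5).unwrap();
    /// ```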
    pub fn write_at(&self, offset: usize, data: &[u8]) -> Result<(), SparseMappingError> {
        assert!(self.is_local(), "cannot write to remote mappings");

        self.check(offset, data.len())?;
        // SAFETY: the bounds have been checked above.
        unsafe {
            let dest = self.as_ptr().cast::<u8>().add(offset);
            try_copy(data.as_ptr(), dest, data.len()).map_err(SparseMappingError::Memory)
        }
    }

    /// Tries to read from the sparse mapping.
    pub fn read_at(&self, offset: usize, data: &mut [u8]) -> Result<(), SparseMappingError> {
        assert!(self.is_local(), "cannot read from remote mappings");

        self.check(offset, data.len())?;
        // SAFETY: the bounds have been checked above.
        unsafe {
            let src = (self.as_ptr() as *const u8).add(offset);
            try_copy(src, data.as_mut_ptr(), data.len()).map_err(SparseMappingError::Memory)
        }
    }

    /// Tries to read a type `T` from `offset`.
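    ///
    /// Unlike `read_volatile`, this accepts any size of `T`. A sketch with a
    /// hypothetical header type (assumes the `zerocopy` derive feature):
    ///
    /// ```no_run
    /// use sparse_mmap::SparseMapping;
    /// use zerocopy::FromBytes;
    /// use zerocopy::Immutable;
    /// use zerocopy::KnownLayout;
    ///
    /// #[derive(FromBytes, Immutable, KnownLayout)]
    /// #[repr(C)]
    /// struct Header {
    ///     magic: u32,
    ///     len: u32,
    /// }
    ///
    /// sparse_mmap::initialize_try_copy();
    /// let page_size = SparseMapping::page_size();
    /// let mapping = SparseMapping::new(page_size * 2).unwrap();
    /// mapping.alloc(0, page_size).unwrap();
    /// let _header: Header = mapping.read_plain(0).unwrap();
    /// ```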
    pub fn read_plain<T: FromBytes + Immutable + KnownLayout>(
        &self,
        offset: usize,
    ) -> Result<T, SparseMappingError> {
        if matches!(size_of::<T>(), 1 | 2 | 4 | 8) {
            self.read_volatile(offset)
        } else {
            let mut obj = MaybeUninit::<T>::uninit();
            // SAFETY: `obj` is a valid target for writes.
            unsafe {
                self.read_at(
                    offset,
                    std::slice::from_raw_parts_mut(obj.as_mut_ptr().cast::<u8>(), size_of::<T>()),
                )?;
            }
            // SAFETY: `obj` was fully initialized by `read_at`.
            Ok(unsafe { obj.assume_init() })
        }
    }

    /// Tries to fill a region of the sparse mapping with `val`.
    pub fn fill_at(&self, offset: usize, val: u8, len: usize) -> Result<(), SparseMappingError> {
        assert!(self.is_local(), "cannot fill remote mappings");

        self.check(offset, len)?;
        // SAFETY: the bounds have been checked above.
        unsafe {
            let dest = self.as_ptr().cast::<u8>().add(offset);
            try_write_bytes(dest, val, len).map_err(SparseMappingError::Memory)
        }
    }

    /// Gets a slice for accessing the mapped data directly.
    ///
    /// This is safe from a Rust memory model perspective, since the underlying
    /// VA is either mapped and is owned in a shared state by this object (in
    /// which case &[AtomicU8] access from multiple threads is fine), or the VA
    /// is not mapped but is reserved and so will not be mapped by another Rust
    /// object.
    ///
    /// In the latter case, actually accessing the data may cause a fault, which
    /// will likely lead to a process crash, so care must nonetheless be taken
    /// when using this method.
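    ///
    /// A short sketch (the `Ordering` choice here is illustrative):
    ///
    /// ```no_run
    /// use sparse_mmap::SparseMapping;
    /// use std::sync::atomic::Ordering;
    ///
    /// let page_size = SparseMapping::page_size();
    /// let mapping = SparseMapping::new(page_size * 2).unwrap();
    /// mapping.alloc(0, page_size).unwrap();
    /// let bytes = mapping.atomic_slice(0, 16);
    /// bytes[0].store(0xcc, Ordering::Relaxed);
    /// assert_eq!(bytes[0].load(Ordering::Relaxed), 0xcc);
    /// ```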
    pub fn atomic_slice(&self, start: usize, len: usize) -> &[AtomicU8] {
        assert!(self.len() >= start && self.len() - start >= len);
        // SAFETY: slice is within the mapped range
        unsafe { std::slice::from_raw_parts((self.as_ptr() as *const AtomicU8).add(start), len) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Copy, Clone, Debug)]
    enum Primitive {
        Read,
        Write,
        CompareAndSwap,
    }

    #[repr(u32)]
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    enum Size {
        Bit8 = 8,
        Bit16 = 16,
        Bit32 = 32,
        Bit64 = 64,
    }
    fn test_unsafe_primitive(primitive: Primitive, size: Size) {
        // NOTE: this test provides a very basic validation of each primitive,
        // mostly to check that the failure address is returned correctly. See
        // other tests for more.
        let mut dest = !0u64;
        let dest_addr = std::ptr::from_mut(&mut dest).cast::<()>();
        let src = 0x5555_5555_5555_5555u64;
        let src_addr = std::ptr::from_ref(&src).cast::<()>();
        let bad_addr_mut = 0x100 as *mut (); // Within 0..0x1000
        let bad_addr = bad_addr_mut.cast_const();
        let nonsense_addr = !0u64 as *mut ();
        let expected = if size != Size::Bit64 {
            dest.wrapping_shl(size as u32) | src.wrapping_shr(64 - (size as u32))
        } else {
            src
        };
        let mut af = AccessFailure {
            address: nonsense_addr.cast(),
            #[cfg(unix)]
            si_signo: 0,
            #[cfg(unix)]
            si_code: 0,
        };
        let af_addr = &mut af as *mut _;

        let res = unsafe {
            match size {
                Size::Bit8 => match primitive {
                    Primitive::Read => try_read8(dest_addr.cast(), src_addr.cast(), af_addr),
                    Primitive::Write => try_write8(dest_addr.cast(), src as u8, af_addr),
                    Primitive::CompareAndSwap => {
                        1 - try_cmpxchg8(dest_addr.cast(), &mut (dest as u8), src as u8, af_addr)
                    }
                },
                Size::Bit16 => match primitive {
                    Primitive::Read => try_read16(dest_addr.cast(), src_addr.cast(), af_addr),
                    Primitive::Write => try_write16(dest_addr.cast(), src as u16, af_addr),
                    Primitive::CompareAndSwap => {
                        1 - try_cmpxchg16(dest_addr.cast(), &mut (dest as u16), src as u16, af_addr)
                    }
                },
                Size::Bit32 => match primitive {
                    Primitive::Read => try_read32(dest_addr.cast(), src_addr.cast(), af_addr),
                    Primitive::Write => try_write32(dest_addr.cast(), src as u32, af_addr),
                    Primitive::CompareAndSwap => {
                        1 - try_cmpxchg32(dest_addr.cast(), &mut (dest as u32), src as u32, af_addr)
                    }
                },
                Size::Bit64 => match primitive {
                    Primitive::Read => try_read64(dest_addr.cast(), src_addr.cast(), af_addr),
                    Primitive::Write => try_write64(dest_addr.cast(), src, af_addr),
                    Primitive::CompareAndSwap => {
                        1 - try_cmpxchg64(dest_addr.cast(), &mut { dest }, src, af_addr)
                    }
                },
            }
        };
        assert_eq!(
            dest, expected,
            "Expected value must match the result for {primitive:?} and {size:?}"
        );
        assert_eq!(
            res, 0,
            "Success should be returned for {primitive:?} and {size:?}"
        );
        assert_eq!(
            af.address,
            nonsense_addr.cast(),
            "Fault address must not be set for {primitive:?} and {size:?}"
        );

        let res = unsafe {
            match size {
                Size::Bit8 => match primitive {
                    Primitive::Read => try_read8(dest_addr.cast(), bad_addr.cast(), af_addr),
                    Primitive::Write => try_write8(bad_addr_mut.cast(), src as u8, af_addr),
                    Primitive::CompareAndSwap => {
                        try_cmpxchg8(bad_addr_mut.cast(), &mut (dest as u8), src as u8, af_addr)
                    }
                },
                Size::Bit16 => match primitive {
                    Primitive::Read => try_read16(dest_addr.cast(), bad_addr.cast(), af_addr),
                    Primitive::Write => try_write16(bad_addr_mut.cast(), src as u16, af_addr),
                    Primitive::CompareAndSwap => {
                        try_cmpxchg16(bad_addr_mut.cast(), &mut (dest as u16), src as u16, af_addr)
                    }
                },
                Size::Bit32 => match primitive {
                    Primitive::Read => try_read32(dest_addr.cast(), bad_addr.cast(), af_addr),
                    Primitive::Write => try_write32(bad_addr_mut.cast(), src as u32, af_addr),
                    Primitive::CompareAndSwap => {
                        try_cmpxchg32(bad_addr_mut.cast(), &mut (dest as u32), src as u32, af_addr)
                    }
                },
                Size::Bit64 => match primitive {
                    Primitive::Read => try_read64(dest_addr.cast(), bad_addr.cast(), af_addr),
                    Primitive::Write => try_write64(bad_addr_mut.cast(), src, af_addr),
                    Primitive::CompareAndSwap => {
                        try_cmpxchg64(bad_addr_mut.cast(), &mut { dest }, src, af_addr)
                    }
                },
            }
        };
        assert_eq!(
            dest, expected,
            "Destination must be unchanged after a fault for {primitive:?} and {size:?}"
        );
        assert_eq!(
            res, -1,
            "Error code must be returned for {primitive:?} and {size:?}"
        );
        assert_eq!(
            af.address,
            bad_addr_mut.cast(),
            "Fault address must be set for {primitive:?} and {size:?}"
        );
    }

    #[test]
    fn test_unsafe_primitives() {
        initialize_try_copy();

        for primitive in [Primitive::Read, Primitive::Write, Primitive::CompareAndSwap] {
            for size in [Size::Bit8, Size::Bit16, Size::Bit32, Size::Bit64] {
                test_unsafe_primitive(primitive, size);
            }
        }
    }

    static BUF: [u8; 65536] = [0xcc; 65536];

    fn test_with(range_size: usize) {
        let page_size = SparseMapping::page_size();

        let mapping = SparseMapping::new(range_size).unwrap();
        mapping.alloc(page_size, page_size).unwrap();
        let slice = unsafe {
            std::slice::from_raw_parts_mut(mapping.as_ptr().add(page_size).cast::<u8>(), page_size)
        };
        slice.copy_from_slice(&BUF[..page_size]);
        mapping.unmap(page_size, page_size).unwrap();

        mapping.alloc(range_size - page_size, page_size).unwrap();
        let slice = unsafe {
            std::slice::from_raw_parts_mut(
                mapping.as_ptr().add(range_size - page_size).cast::<u8>(),
                page_size,
            )
        };
        slice.copy_from_slice(&BUF[..page_size]);
        mapping.unmap(range_size - page_size, page_size).unwrap();
        drop(mapping);
    }

    #[test]
    fn test_sparse_mapping() {
        test_with(0x100000);
        test_with(0x200000);
        test_with(0x200000 + SparseMapping::page_size());
        test_with(0x40000000);
        test_with(0x40000000 + SparseMapping::page_size());
    }

    #[test]
    fn test_try_copy() {
        initialize_try_copy();

        let mapping = SparseMapping::new(2 * 1024 * 1024).unwrap();
        let page_size = SparseMapping::page_size();
        mapping.alloc(page_size, page_size).unwrap();
        let base = mapping.as_ptr().cast::<u8>();
        unsafe {
            try_copy(BUF.as_ptr(), base, 100).unwrap_err();
            try_copy(BUF.as_ptr(), base.add(page_size), 100).unwrap();
            try_copy(BUF.as_ptr(), base.add(page_size), page_size + 1).unwrap_err();
        }
    }

    #[test]
    fn test_cmpxchg() {
        initialize_try_copy();

        let page_size = SparseMapping::page_size();
        let mapping = SparseMapping::new(page_size * 2).unwrap();
        mapping.alloc(0, page_size).unwrap();
        let base = mapping.as_ptr().cast::<u8>();
        unsafe {
            assert_eq!(try_compare_exchange(base.add(8), 0, 1).unwrap().unwrap(), 1);
            assert_eq!(
                try_compare_exchange(base.add(8), 0, 2)
                    .unwrap()
                    .unwrap_err(),
                1
            );
            assert_eq!(
                try_compare_exchange(base.cast::<u64>().add(1), 1, 2)
                    .unwrap()
                    .unwrap(),
                2
            );
            try_compare_exchange(base.add(page_size), 0, 2).unwrap_err();
        }
    }

    #[test]
    fn test_overlapping_mappings() {
        #![expect(clippy::identity_op)]

        let page_size = SparseMapping::page_size();
        let mapping = SparseMapping::new(0x10 * page_size).unwrap();
        mapping.alloc(0x1 * page_size, 0x4 * page_size).unwrap();
        mapping.alloc(0x1 * page_size, 0x2 * page_size).unwrap();
        mapping.alloc(0x2 * page_size, 0x3 * page_size).unwrap();
        mapping.alloc(0, 0x10 * page_size).unwrap();
        mapping.alloc(0x8 * page_size, 0x8 * page_size).unwrap();
        mapping.unmap(0xc * page_size, 0x2 * page_size).unwrap();
        mapping.alloc(0x9 * page_size, 0x4 * page_size).unwrap();
        mapping.unmap(0x3 * page_size, 0xb * page_size).unwrap();

        mapping.alloc(0x5 * page_size, 0x4 * page_size).unwrap();
        mapping.alloc(0x6 * page_size, 0x2 * page_size).unwrap();
        mapping.alloc(0x6 * page_size, 0x1 * page_size).unwrap();
        mapping.alloc(0x4 * page_size, 0x3 * page_size).unwrap();

        let shmem = alloc_shared_memory(0x4 * page_size).unwrap();
        mapping
            .map_file(0x5 * page_size, 0x4 * page_size, &shmem, 0, true)
            .unwrap();
        mapping
            .map_file(0x6 * page_size, 0x2 * page_size, &shmem, 0, true)
            .unwrap();
        mapping
            .map_file(0x6 * page_size, 0x1 * page_size, &shmem, 0, true)
            .unwrap();
        mapping
            .map_file(0x4 * page_size, 0x3 * page_size, &shmem, 0, true)
            .unwrap();

        drop(mapping);
    }
}