membacking/memory_manager/device_memory.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! This implements the [`MemoryMapper`] trait and related functionality for
5//! [`GuestMemoryManager`](super::GuestMemoryManager).
6
7use super::DEVICE_PRIORITY;
8use crate::mapping_manager::Mappable;
9use crate::region_manager::MapParams;
10use crate::region_manager::RegionHandle;
11use crate::region_manager::RegionManagerClient;
12use futures::executor::block_on;
13use guestmem::MappableGuestMemory;
14use guestmem::MappedMemoryRegion;
15use guestmem::MemoryMapper;
16use memory_range::MemoryRange;
17use parking_lot::Mutex;
18use std::io;
19use std::sync::Arc;
20
21/// A [`MemoryMapper`] implementation for
22/// [`GuestMemoryManager`](super::GuestMemoryManager).
23#[derive(Clone, Debug)]
24pub struct DeviceMemoryMapper {
25    region_manager: RegionManagerClient,
26}
27
28impl DeviceMemoryMapper {
29    pub(super) fn new(region_manager: RegionManagerClient) -> Self {
30        Self { region_manager }
31    }
32}
33
34impl MemoryMapper for DeviceMemoryMapper {
35    fn new_region(
36        &self,
37        len: usize,
38        debug_name: String,
39    ) -> io::Result<(Box<dyn MappableGuestMemory>, Arc<dyn MappedMemoryRegion>)> {
40        let region = Arc::new(DeviceMemoryRegion {
41            len,
42            debug_name,
43            region_manager: self.region_manager.clone(),
44            state: Mutex::new(DeviceRegionState {
45                handle: None,
46                mappings: Vec::new(),
47            }),
48        });
49
50        Ok((Box::new(DeviceMemoryControl(region.clone())), region))
51    }
52}
53
54#[derive(Debug)]
55struct DeviceMemoryRegion {
56    debug_name: String,
57    len: usize,
58    region_manager: RegionManagerClient,
59    state: Mutex<DeviceRegionState>,
60}
61
62#[derive(Debug)]
63struct DeviceRegionState {
64    handle: Option<RegionHandle>,
65    mappings: Vec<DeviceMapping>,
66}
67
68#[derive(Debug)]
69struct DeviceMapping {
70    range: MemoryRange,
71    file_offset: u64,
72    mappable: Mappable,
73    writable: bool,
74}
75
76impl DeviceMemoryRegion {
77    fn validated_memory_range(&self, offset: usize, len: usize) -> io::Result<MemoryRange> {
78        (offset..offset.wrapping_add(len))
79            .try_into()
80            .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))
81    }
82}
83
84impl MappedMemoryRegion for DeviceMemoryRegion {
85    fn map(
86        &self,
87        offset: usize,
88        section: &dyn sparse_mmap::AsMappableRef,
89        file_offset: u64,
90        len: usize,
91        writable: bool,
92    ) -> io::Result<()> {
93        #[cfg(unix)]
94        let mappable = section.as_fd().try_clone_to_owned()?;
95        #[cfg(windows)]
96        let mappable = section.as_handle().try_clone_to_owned()?;
97
98        let range = self.validated_memory_range(offset, len)?;
99        let new_mapping = DeviceMapping {
100            range,
101            file_offset,
102            mappable: mappable.into(),
103            writable,
104        };
105
106        let mut state = self.state.lock();
107        for mapping in &state.mappings {
108            if mapping.range.overlaps(&new_mapping.range) {
109                todo!("support overlapping mappings");
110            }
111        }
112
113        if let Some(handle) = &state.handle {
114            block_on(handle.add_mapping(
115                new_mapping.range,
116                new_mapping.mappable.clone(),
117                new_mapping.file_offset,
118                new_mapping.writable,
119            ));
120        }
121        state.mappings.push(new_mapping);
122        Ok(())
123    }
124
125    fn unmap(&self, offset: usize, len: usize) -> io::Result<()> {
126        let range = self.validated_memory_range(offset, len)?;
127        let mut state = self.state.lock();
128        state.mappings.retain(|mapping| {
129            if !range.contains(&mapping.range) && range.overlaps(&mapping.range) {
130                todo!("support overlapping mappings");
131            }
132            range.contains(&mapping.range)
133        });
134
135        if let Some(handle) = &state.handle {
136            block_on(handle.remove_mappings(range));
137        }
138        Ok(())
139    }
140}
141
142#[derive(Debug)]
143struct DeviceMemoryControl(Arc<DeviceMemoryRegion>);
144
145impl MappableGuestMemory for DeviceMemoryControl {
146    fn map_to_guest(&mut self, gpa: u64, writable: bool) -> io::Result<()> {
147        #[expect(clippy::await_holding_lock)] // Treat all this as sync for now.
148        block_on(async {
149            let mut state = self.0.state.lock();
150            if let Some(handle) = state.handle.take() {
151                handle.teardown().await;
152            }
153            let handle = self
154                .0
155                .region_manager
156                .new_region(
157                    self.0.debug_name.clone(),
158                    MemoryRange::try_from(gpa..gpa.wrapping_add(self.0.len as u64))
159                        .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?,
160                    DEVICE_PRIORITY,
161                )
162                .await
163                .map_err(io::Error::other)?;
164
165            for mapping in &state.mappings {
166                handle
167                    .add_mapping(
168                        mapping.range,
169                        mapping.mappable.clone(),
170                        mapping.file_offset,
171                        mapping.writable,
172                    )
173                    .await;
174            }
175
176            handle
177                .map(MapParams {
178                    writable,
179                    executable: true,
180                    prefetch: false,
181                })
182                .await;
183
184            state.handle = Some(handle);
185            Ok(())
186        })
187    }
188
189    fn unmap_from_guest(&mut self) {
190        let mut state = self.0.state.lock();
191        if let Some(handle) = state.handle.take() {
192            block_on(handle.teardown());
193        }
194    }
195}