membacking/mapping_manager/manager.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Implements the mapping manager, which keeps track of the VA mappers and
//! their currently active mappings. It is responsible for invalidating mappings
//! in each VA range when they are torn down by the region manager.
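//!
//! A minimal usage sketch (illustrative only; `spawner` is assumed to
//! implement [`pal_async::task::Spawn`], and error handling is elided):
//!
//! ```ignore
//! let manager = MappingManager::new(spawner, max_addr);
//! // The client is cheap to clone and can be sent to other processes.
//! let client = manager.client().clone();
//! // Map the managed ranges into this process's address space.
//! let mapper = client.new_mapper().await?;
//! ```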

use super::mappable::Mappable;
use super::object_cache::ObjectCache;
use super::object_cache::ObjectId;
use super::va_mapper::VaMapper;
use super::va_mapper::VaMapperError;
use crate::RemoteProcess;
use futures::StreamExt;
use futures::future::join_all;
use inspect::Inspect;
use inspect::InspectMut;
use memory_range::MemoryRange;
use mesh::MeshPayload;
use mesh::rpc::Rpc;
use mesh::rpc::RpcSend;
use pal_async::task::Spawn;
use slab::Slab;
use std::sync::Arc;

/// The mapping manager.
#[derive(Debug, Inspect)]
pub struct MappingManager {
    #[inspect(
        flatten,
        with = "|x| inspect::send(&x.req_send, MappingRequest::Inspect)"
    )]
    client: MappingManagerClient,
}

impl MappingManager {
    /// Returns a new mapping manager that can map addresses up to `max_addr`.
    pub fn new(spawn: impl Spawn, max_addr: u64) -> Self {
        let (req_send, mut req_recv) = mesh::mpsc_channel();
        spawn
            .spawn("mapping_manager", {
                let mut task = MappingManagerTask::new();
                async move {
                    task.run(&mut req_recv).await;
                }
            })
            .detach();
        Self {
            client: MappingManagerClient {
                id: ObjectId::new(),
                req_send,
                max_addr,
            },
        }
    }

    /// Returns an object used to access the mapping manager, potentially from
    /// a remote process.
    pub fn client(&self) -> &MappingManagerClient {
        &self.client
    }
}

/// Provides access to the mapping manager.
#[derive(Debug, MeshPayload, Clone)]
pub struct MappingManagerClient {
    req_send: mesh::Sender<MappingRequest>,
    id: ObjectId,
    max_addr: u64,
}

static MAPPER_CACHE: ObjectCache<VaMapper> = ObjectCache::new();

impl MappingManagerClient {
    /// Returns a VA mapper for this guest memory.
    ///
    /// The mapper is a single instance per process, cached by the client's
    /// identity, so this is safe to call multiple times.
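    ///
    /// A sketch (assuming a `client` already in scope):
    ///
    /// ```ignore
    /// let a = client.new_mapper().await?;
    /// let b = client.new_mapper().await?;
    /// // Both calls resolve to the same cached mapper.
    /// assert!(Arc::ptr_eq(&a, &b));
    /// ```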
    pub async fn new_mapper(&self) -> Result<Arc<VaMapper>, VaMapperError> {
        // Get the VA mapper from the mapper cache if possible to avoid keeping
        // multiple VA ranges for this memory per process.
        MAPPER_CACHE
            .get_or_insert_with(&self.id, async {
                VaMapper::new(self.req_send.clone(), self.max_addr, None).await
            })
            .await
    }

    /// Returns a VA mapper for this guest memory, but maps everything into the
    /// address space of `process`.
    ///
    /// Each call allocates a new, unique mapper.
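    ///
    /// A sketch (how `process` is obtained is outside this module's scope):
    ///
    /// ```ignore
    /// let remote_mapper = client.new_remote_mapper(process).await?;
    /// ```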
    pub async fn new_remote_mapper(
        &self,
        process: RemoteProcess,
    ) -> Result<Arc<VaMapper>, VaMapperError> {
        Ok(Arc::new(
            VaMapper::new(self.req_send.clone(), self.max_addr, Some(process)).await?,
        ))
    }

    /// Adds an active mapping.
    ///
    /// TODO: currently this panics if the mapping overlaps an existing
    /// mapping. This should be fixed to allow overlaps, in which case the old
    /// mappings will be split and replaced.
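    ///
    /// A sketch (the range, mappable, and flags are illustrative):
    ///
    /// ```ignore
    /// client
    ///     .add_mapping(
    ///         MemoryRange::new(0x100000..0x200000),
    ///         mappable,
    ///         0,    // file offset into `mappable`
    ///         true, // writable
    ///     )
    ///     .await;
    /// ```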
    pub async fn add_mapping(
        &self,
        range: MemoryRange,
        mappable: Mappable,
        file_offset: u64,
        writable: bool,
    ) {
        let params = MappingParams {
            range,
            mappable,
            file_offset,
            writable,
        };

        self.req_send
            .call(MappingRequest::AddMapping, params)
            .await
            .unwrap();
    }

    /// Removes all mappings in `range`.
    ///
    /// TODO: allow this to split existing mappings.
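    ///
    /// A sketch, removing the range mapped above (a partial overlap with an
    /// existing mapping currently panics the task):
    ///
    /// ```ignore
    /// client.remove_mappings(MemoryRange::new(0x100000..0x200000)).await;
    /// ```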
    pub async fn remove_mappings(&self, range: MemoryRange) {
        self.req_send
            .call(MappingRequest::RemoveMappings, range)
            .await
            .unwrap();
    }
}

/// A mapping request message.
#[derive(MeshPayload)]
pub enum MappingRequest {
    /// Registers a new mapper, returning its ID.
    AddMapper(Rpc<mesh::Sender<MapperRequest>, MapperId>),
    /// Unregisters the mapper with the given ID.
    RemoveMapper(MapperId),
    /// Sends the mappings overlapping the given range to the given mapper.
    SendMappings(MapperId, MemoryRange),
    /// Adds a new mapping.
    AddMapping(Rpc<MappingParams, ()>),
    /// Removes all mappings in the given range, invalidating them in any
    /// active mappers.
    RemoveMappings(Rpc<MemoryRange, ()>),
    /// Inspects the mapping manager's state.
    Inspect(inspect::Deferred),
}

#[derive(InspectMut)]
struct MappingManagerTask {
    #[inspect(with = "inspect_mappings")]
    mappings: Vec<Mapping>,
    #[inspect(skip)]
    mappers: Mappers,
}

fn inspect_mappings(mappings: &[Mapping]) -> impl '_ + Inspect {
    inspect::adhoc(move |req| {
        let mut resp = req.respond();
        for mapping in mappings {
            resp.field(
                &mapping.params.range.to_string(),
                inspect::adhoc(|req| {
                    req.respond()
                        .field("writable", mapping.params.writable)
                        .hex("file_offset", mapping.params.file_offset);
                }),
            );
        }
    })
}

/// An active mapping and the set of mappers it has been sent to.
struct Mapping {
    params: MappingParams,
    active_mappers: Vec<MapperId>,
}

/// The mapping parameters.
#[derive(MeshPayload, Clone)]
pub struct MappingParams {
    /// The memory range for the mapping.
    pub range: MemoryRange,
    /// The OS object to map.
    pub mappable: Mappable,
    /// The file offset into `mappable`.
    pub file_offset: u64,
    /// Whether to map the memory as writable.
    pub writable: bool,
}

/// The set of registered mappers.
struct Mappers {
    mappers: Slab<MapperComm>,
}

/// The communication channel to a single mapper.
struct MapperComm {
    req_send: mesh::Sender<MapperRequest>,
}

/// An identifier for a registered mapper.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, MeshPayload)]
pub struct MapperId(usize);

/// A request to a VA mapper.
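///
/// A sketch of the loop a mapper might run to service these requests
/// (`req_recv` and `mapping_table` are hypothetical):
///
/// ```ignore
/// while let Some(req) = req_recv.next().await {
///     match req {
///         MapperRequest::Map(params) => mapping_table.map(&params),
///         MapperRequest::NoMapping(range) => mapping_table.wake_waiters(range),
///         MapperRequest::Unmap(rpc) => rpc.handle_sync(|range| mapping_table.unmap(range)),
///     }
/// }
/// ```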
#[derive(MeshPayload)]
pub enum MapperRequest {
    /// Map the specified mapping.
    Map(MappingParams),
    /// There is no mapping for the specified range, so release anything
    /// waiting on such a mapping to arrive.
    NoMapping(MemoryRange),
    /// Unmap the specified range and send a response when it's done.
    Unmap(Rpc<MemoryRange, ()>),
}

impl MappingManagerTask {
    fn new() -> Self {
        Self {
            mappers: Mappers {
                mappers: Slab::new(),
            },
            mappings: Vec::new(),
        }
    }

    async fn run(&mut self, req_recv: &mut mesh::Receiver<MappingRequest>) {
        while let Some(req) = req_recv.next().await {
            match req {
                MappingRequest::AddMapper(rpc) => rpc.handle_sync(|send| self.add_mapper(send)),
                MappingRequest::RemoveMapper(id) => {
                    self.remove_mapper(id);
                }
                MappingRequest::SendMappings(id, range) => {
                    self.send_mappings(id, range);
                }
                MappingRequest::AddMapping(rpc) => {
                    rpc.handle_sync(|params| self.add_mapping(params))
                }
                MappingRequest::RemoveMappings(rpc) => {
                    rpc.handle(async |range| self.remove_mappings(range).await)
                        .await
                }
                MappingRequest::Inspect(deferred) => deferred.inspect(&mut *self),
            }
        }
    }

    fn add_mapper(&mut self, req_send: mesh::Sender<MapperRequest>) -> MapperId {
        let id = self.mappers.mappers.insert(MapperComm { req_send });
        tracing::debug!(?id, "adding mapper");
        MapperId(id)
    }

    fn remove_mapper(&mut self, id: MapperId) {
        tracing::debug!(?id, "removing mapper");
        self.mappers.mappers.remove(id.0);
        for mapping in &mut self.mappings {
            mapping.active_mappers.retain(|m| m != &id);
        }
    }

    fn send_mappings(&mut self, id: MapperId, mut range: MemoryRange) {
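        // Worked example: with mappings at 0x1000..0x3000 and 0x4000..0x5000,
        // a request for 0x0..0x6000 sends, in order: NoMapping(0x0..0x1000),
        // Map(the 0x1000..0x3000 mapping), NoMapping(0x3000..0x4000),
        // Map(the 0x4000..0x5000 mapping), NoMapping(0x5000..0x6000).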
        while !range.is_empty() {
            // Find the next mapping that overlaps range.
            let (this_end, params) = if let Some(mapping) = self
                .mappings
                .iter_mut()
                .filter(|mapping| mapping.params.range.overlaps(&range))
                .min_by_key(|mapping| mapping.params.range.start())
            {
                if mapping.params.range.start() <= range.start() {
                    // The next mapping overlaps with the start of our range.
                    if !mapping.active_mappers.contains(&id) {
                        mapping.active_mappers.push(id);
                    }
                    (
                        mapping.params.range.end().min(range.end()),
                        Some(mapping.params.clone()),
                    )
                } else {
                    // There's a gap before the next mapping.
                    (mapping.params.range.start(), None)
                }
            } else {
                // No matching mappings; consume the rest of the range.
                (range.end(), None)
            };
            let this_range = MemoryRange::new(range.start()..this_end);
            let req = if let Some(params) = params {
                tracing::debug!(range = %this_range, full_range = %params.range, "sending mapping for range");
                MapperRequest::Map(params)
            } else {
                tracing::debug!(range = %this_range, "no mapping for range");
                MapperRequest::NoMapping(this_range)
            };
            self.mappers.mappers[id.0].req_send.send(req);
            range = MemoryRange::new(this_end..range.end());
        }
    }

    fn add_mapping(&mut self, params: MappingParams) {
        tracing::debug!(range = %params.range, "adding mapping");

        // Overlapping an existing mapping is not yet supported; see the TODO
        // on `MappingManagerClient::add_mapping`.
        assert!(
            !self
                .mappings
                .iter()
                .any(|m| m.params.range.overlaps(&params.range)),
            "overlapping mappings are not supported"
        );

        self.mappings.push(Mapping {
            params,
            active_mappers: Vec::new(),
        });
    }

    async fn remove_mappings(&mut self, range: MemoryRange) {
        let mut mappers = Vec::new();
        self.mappings.retain_mut(|mapping| {
            if !range.contains(&mapping.params.range) {
                assert!(
                    !range.overlaps(&mapping.params.range),
                    "no partial unmappings allowed"
                );
                return true;
            }
            tracing::debug!(range = %mapping.params.range, "removing mapping");
            mappers.append(&mut mapping.active_mappers);
            false
        });
        mappers.sort();
        mappers.dedup();
        self.mappers.invalidate(&mappers, range).await;
    }
}

impl Mappers {
    async fn invalidate(&self, ids: &[MapperId], range: MemoryRange) {
        tracing::debug!(mapper_count = ids.len(), %range, "sending invalidations");
        join_all(ids.iter().map(async |&MapperId(i)| {
            if let Err(err) = self.mappers[i]
                .req_send
                .call(MapperRequest::Unmap, range)
                .await
            {
                tracing::warn!(
                    error = &err as &dyn std::error::Error,
                    "mapper dropped invalidate request"
                );
            }
        }))
        .await;
    }
}