// membacking/mapping_manager/manager.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implements the mapping manager, which keeps track of the VA mappers and
5//! their currently active mappings. It is responsible for invalidating mappings
6//! in each VA range when they are torn down by the region manager.
7
8use super::mappable::Mappable;
9use super::object_cache::ObjectCache;
10use super::object_cache::ObjectId;
11use super::va_mapper::VaMapper;
12use super::va_mapper::VaMapperError;
13use crate::RemoteProcess;
14use futures::StreamExt;
15use futures::future::join_all;
16use guestmem::ProvideShareableRegions;
17use guestmem::ShareableRegion;
18use inspect::Inspect;
19use inspect::InspectMut;
20use memory_range::MemoryRange;
21use mesh::MeshPayload;
22use mesh::rpc::Rpc;
23use mesh::rpc::RpcSend;
24use pal_async::task::Spawn;
25use slab::Slab;
26use std::sync::Arc;
27
/// The mapping manager.
///
/// Owns the background task that tracks mappings and mappers; all access goes
/// through the [`MappingManagerClient`] returned by [`Self::client`].
#[derive(Debug, Inspect)]
pub struct MappingManager {
    // Inspect requests are forwarded to the manager task over the client's
    // request channel rather than inspecting the client object itself.
    #[inspect(
        flatten,
        with = "|x| inspect::send(&x.req_send, MappingRequest::Inspect)"
    )]
    client: MappingManagerClient,
}
37
38impl MappingManager {
39    /// Returns a new mapping manager that can map addresses up to `max_addr`.
40    ///
41    /// If `private_ram` is true, mappers created from this manager will use
42    /// anonymous private memory for guest RAM instead of shared file-backed
43    /// memory.
44    pub fn new(spawn: impl Spawn, max_addr: u64, private_ram: bool) -> Self {
45        let (req_send, mut req_recv) = mesh::mpsc_channel();
46        spawn
47            .spawn("mapping_manager", {
48                let mut task = MappingManagerTask::new();
49                async move {
50                    task.run(&mut req_recv).await;
51                }
52            })
53            .detach();
54        Self {
55            client: MappingManagerClient {
56                id: ObjectId::new(),
57                req_send,
58                max_addr,
59                private_ram,
60            },
61        }
62    }
63
64    /// Returns an object used to access the mapping manager, potentially from a
65    /// remote process.
66    pub fn client(&self) -> &MappingManagerClient {
67        &self.client
68    }
69}
70
/// Provides access to the mapping manager.
///
/// Cloneable and sendable to other processes as a mesh payload.
#[derive(Debug, MeshPayload, Clone)]
pub struct MappingManagerClient {
    // Request channel to the mapping manager task.
    req_send: mesh::Sender<MappingRequest>,
    // Identity used to single-instance VA mappers in `MAPPER_CACHE`.
    id: ObjectId,
    // Upper bound of mappable guest addresses, passed to each new mapper.
    max_addr: u64,
    // Whether mappers should use anonymous private memory for guest RAM.
    private_ram: bool,
}
79
// Process-wide cache of VA mappers, keyed by the client's `ObjectId`, so that
// repeated `new_mapper` calls in one process share a single VA range.
static MAPPER_CACHE: ObjectCache<VaMapper> = ObjectCache::new();
81
82impl MappingManagerClient {
83    /// Returns a VA mapper for this guest memory.
84    ///
85    /// This will single instance the mapper, so this is safe to call multiple
86    /// times. If `private_ram` was set when creating the [`MappingManager`],
87    /// the mapper will use anonymous private memory for guest RAM.
88    pub async fn new_mapper(&self) -> Result<Arc<VaMapper>, VaMapperError> {
89        // Get the VA mapper from the mapper cache if possible to avoid keeping
90        // multiple VA ranges for this memory per process.
91        let private_ram = self.private_ram;
92        MAPPER_CACHE
93            .get_or_insert_with(&self.id, async {
94                VaMapper::new(self.req_send.clone(), self.max_addr, None, private_ram).await
95            })
96            .await
97    }
98
99    /// Returns a VA mapper for this guest memory, but map everything into the
100    /// address space of `process`.
101    ///
102    /// Each call will allocate a new unique mapper.
103    ///
104    /// Returns an error if private memory mode is enabled, since private
105    /// anonymous pages would be committed in the remote process and not
106    /// accessible locally.
107    pub async fn new_remote_mapper(
108        &self,
109        process: RemoteProcess,
110    ) -> Result<Arc<VaMapper>, VaMapperError> {
111        if self.private_ram {
112            return Err(VaMapperError::RemoteWithPrivateMemory);
113        }
114        Ok(Arc::new(
115            VaMapper::new(self.req_send.clone(), self.max_addr, Some(process), false).await?,
116        ))
117    }
118
119    /// Adds an active mapping.
120    ///
121    /// TODO: currently this will panic if the mapping overlaps an existing
122    /// mapping. This needs to be fixed to allow this to overlap existing
123    /// mappings, in which case the old ones will be split and replaced.
124    pub async fn add_mapping(&self, params: MappingParams) {
125        self.req_send
126            .call(MappingRequest::AddMapping, params)
127            .await
128            .unwrap();
129    }
130
131    /// Removes all mappings in `range`.
132    ///
133    /// TODO: allow this to split existing mappings.
134    pub async fn remove_mappings(&self, range: MemoryRange) {
135        self.req_send
136            .call(MappingRequest::RemoveMappings, range)
137            .await
138            .unwrap();
139    }
140}
141
/// A mapping request message, sent from clients and mappers to the mapping
/// manager task.
#[derive(MeshPayload)]
pub enum MappingRequest {
    /// Registers a new mapper's request channel; responds with its id.
    AddMapper(Rpc<mesh::Sender<MapperRequest>, MapperId>),
    /// Unregisters a mapper.
    RemoveMapper(MapperId),
    /// Asks the manager to send the mappings overlapping the range (and gap
    /// notifications) to the given mapper.
    SendMappings(MapperId, MemoryRange),
    /// Adds an active mapping.
    AddMapping(Rpc<MappingParams, ()>),
    /// Removes all mappings in the given range.
    RemoveMappings(Rpc<MemoryRange, ()>),
    /// Returns all mappings that have `dma_target` set.
    GetDmaTargetMappings(Rpc<(), Vec<MappingParams>>),
    /// Inspects the manager task's state.
    Inspect(inspect::Deferred),
}
154
/// State for the mapping manager's background task.
#[derive(InspectMut)]
struct MappingManagerTask {
    // The active mappings; unordered (lookups scan for the minimum start).
    #[inspect(with = "inspect_mappings")]
    mappings: Vec<Mapping>,
    // The registered mappers; not interesting to inspect directly.
    #[inspect(skip)]
    mappers: Mappers,
}
162
163fn inspect_mappings(mappings: &Vec<Mapping>) -> impl '_ + Inspect {
164    inspect::adhoc(move |req| {
165        let mut resp = req.respond();
166        for mapping in mappings {
167            resp.field(
168                &mapping.params.range.to_string(),
169                inspect::adhoc(|req| {
170                    req.respond()
171                        .field("writable", mapping.params.writable)
172                        .field("dma_target", mapping.params.dma_target)
173                        .hex("file_offset", mapping.params.file_offset);
174                }),
175            );
176        }
177    })
178}
179
/// An active mapping plus the set of mappers it has been sent to.
struct Mapping {
    // The parameters describing this mapping.
    params: MappingParams,
    // Ids of mappers that have received this mapping and therefore must be
    // invalidated when it is removed.
    active_mappers: Vec<MapperId>,
}
184
/// The mapping parameters.
///
/// Sent to each mapper in [`MapperRequest::Map`].
#[derive(MeshPayload, Clone)]
pub struct MappingParams {
    /// The memory range for the mapping.
    pub range: MemoryRange,
    /// The OS object to map.
    pub mappable: Mappable,
    /// The file offset into `mappable`.
    pub file_offset: u64,
    /// Whether to map the memory as writable.
    pub writable: bool,
    /// Whether this mapping is a DMA target (guest RAM or similar).
    ///
    /// DMA-target mappings are exposed via [`GuestMemorySharing`](guestmem::GuestMemorySharing) so
    /// that external consumers (vhost-user backends, etc.) can share the
    /// backing memory.
    pub dma_target: bool,
}
203
/// The set of registered mappers, indexed by the slab key wrapped in
/// [`MapperId`].
struct Mappers {
    mappers: Slab<MapperComm>,
}
207
/// The manager's communication channel to a single VA mapper.
struct MapperComm {
    // Sends map/unmap requests to the mapper.
    req_send: mesh::Sender<MapperRequest>,
}
211
/// Identifies a registered mapper; wraps its slab index within [`Mappers`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, MeshPayload)]
pub struct MapperId(usize);
214
/// A request to a VA mapper, sent by the mapping manager task.
#[derive(MeshPayload)]
pub enum MapperRequest {
    /// Map the specified mapping.
    Map(MappingParams),
    /// There is no mapping for the specified range, so release anything waiting
    /// on such a mapping to arrive.
    NoMapping(MemoryRange),
    /// Unmap the specified range and send a response when it's done.
    Unmap(Rpc<MemoryRange, ()>),
}
226
impl MappingManagerTask {
    /// Creates a task with no mappings and no registered mappers.
    fn new() -> Self {
        Self {
            mappers: Mappers {
                mappers: Slab::new(),
            },
            mappings: Vec::new(),
        }
    }

    /// Processes requests until `req_recv` is closed.
    async fn run(&mut self, req_recv: &mut mesh::Receiver<MappingRequest>) {
        while let Some(req) = req_recv.next().await {
            match req {
                MappingRequest::AddMapper(rpc) => rpc.handle_sync(|send| self.add_mapper(send)),
                MappingRequest::RemoveMapper(id) => {
                    self.remove_mapper(id);
                }
                MappingRequest::SendMappings(id, range) => {
                    self.send_mappings(id, range);
                }
                MappingRequest::AddMapping(rpc) => {
                    rpc.handle_sync(|params| self.add_mapping(params))
                }
                MappingRequest::RemoveMappings(rpc) => {
                    rpc.handle(async |range| self.remove_mappings(range).await)
                        .await
                }
                MappingRequest::GetDmaTargetMappings(rpc) => {
                    rpc.handle_sync(|()| self.get_dma_target_mappings())
                }
                MappingRequest::Inspect(deferred) => deferred.inspect(&mut *self),
            }
        }
    }

    /// Registers a new mapper's request channel and returns its id.
    fn add_mapper(&mut self, req_send: mesh::Sender<MapperRequest>) -> MapperId {
        let id = self.mappers.mappers.insert(MapperComm { req_send });
        tracing::debug!(?id, "adding mapper");
        MapperId(id)
    }

    /// Unregisters a mapper and drops it from every mapping's active list so
    /// it no longer receives invalidations.
    fn remove_mapper(&mut self, id: MapperId) {
        tracing::debug!(?id, "removing mapper");
        self.mappers.mappers.remove(id.0);
        for mapping in &mut self.mappings {
            mapping.active_mappers.retain(|m| m != &id);
        }
    }

    /// Walks `range` from low to high, sending mapper `id` a `Map` request
    /// for each overlapping mapping and a `NoMapping` notification for each
    /// gap, recording `id` as active on every mapping sent.
    fn send_mappings(&mut self, id: MapperId, mut range: MemoryRange) {
        while !range.is_empty() {
            // Find the next mapping that overlaps range.
            let (this_end, params) = if let Some(mapping) = self
                .mappings
                .iter_mut()
                .filter(|mapping| mapping.params.range.overlaps(&range))
                .min_by_key(|mapping| mapping.params.range.start())
            {
                if mapping.params.range.start() <= range.start() {
                    // Remember that this mapper holds the mapping so it gets
                    // invalidated when the mapping is removed.
                    if !mapping.active_mappers.contains(&id) {
                        mapping.active_mappers.push(id);
                    }
                    // The next mapping overlaps with the start of our range.
                    (
                        mapping.params.range.end().min(range.end()),
                        Some(mapping.params.clone()),
                    )
                } else {
                    // There's a gap before the next mapping.
                    (mapping.params.range.start(), None)
                }
            } else {
                // No matching mappings, consume the rest of the range.
                (range.end(), None)
            };
            let this_range = MemoryRange::new(range.start()..this_end);
            let req = if let Some(params) = params {
                tracing::debug!(range = %this_range, full_range = %params.range, "sending mapping for range");
                MapperRequest::Map(params)
            } else {
                tracing::debug!(range = %this_range, "no mapping for range");
                MapperRequest::NoMapping(this_range)
            };
            self.mappers.mappers[id.0].req_send.send(req);
            // Advance past the portion just handled.
            range = MemoryRange::new(this_end..range.end());
        }
    }

    /// Records a new mapping. Panics if it has the same range as an existing
    /// mapping (overlap splitting is not yet implemented; see the client-side
    /// TODO).
    fn add_mapping(&mut self, params: MappingParams) {
        tracing::debug!(range = %params.range, "adding mapping");

        assert!(!self.mappings.iter().any(|m| m.params.range == params.range));

        self.mappings.push(Mapping {
            params,
            active_mappers: Vec::new(),
        });
    }

    /// Returns a copy of the parameters of every mapping with `dma_target`
    /// set.
    fn get_dma_target_mappings(&self) -> Vec<MappingParams> {
        self.mappings
            .iter()
            .filter(|m| m.params.dma_target)
            .map(|m| m.params.clone())
            .collect()
    }

    /// Removes every mapping fully contained in `range` and invalidates the
    /// mappers that held them. Panics on partial overlap (no splitting yet).
    async fn remove_mappings(&mut self, range: MemoryRange) {
        let mut mappers = Vec::new();
        self.mappings.retain_mut(|mapping| {
            if !range.contains(&mapping.params.range) {
                assert!(
                    !range.overlaps(&mapping.params.range),
                    "no partial unmappings allowed"
                );
                return true;
            }
            tracing::debug!(range = %mapping.params.range, "removing mapping");
            // Steal the active mapper list; the mapping is being dropped.
            mappers.append(&mut mapping.active_mappers);
            false
        });
        // A mapper may have held several removed mappings; invalidate each
        // mapper only once.
        mappers.sort();
        mappers.dedup();
        self.mappers.invalidate(&mappers, range).await;
    }
}
353
354impl Mappers {
355    async fn invalidate(&self, ids: &[MapperId], range: MemoryRange) {
356        tracing::debug!(mapper_count = ids.len(), %range, "sending invalidations");
357        join_all(ids.iter().map(async |&MapperId(i)| {
358            if let Err(err) = self.mappers[i]
359                .req_send
360                .call(MapperRequest::Unmap, range)
361                .await
362            {
363                tracing::warn!(
364                    error = &err as &dyn std::error::Error,
365                    "mapper dropped invalidate request"
366                );
367            }
368        }))
369        .await;
370    }
371}
372
/// Implements [`ProvideShareableRegions`] by querying the
/// [`MappingManager`] for DMA-target mappings. Used by `VaMapper`'s
/// `sharing()` implementation.
pub(crate) struct DmaRegionProvider {
    /// Request channel to the mapping manager task.
    pub req_send: mesh::Sender<MappingRequest>,
}
379
380impl ProvideShareableRegions for DmaRegionProvider {
381    async fn get_regions(&self) -> Result<Vec<ShareableRegion>, guestmem::ShareableRegionError> {
382        let mappings = self
383            .req_send
384            .call(MappingRequest::GetDmaTargetMappings, ())
385            .await?;
386
387        Ok(mappings
388            .into_iter()
389            .map(|m| ShareableRegion {
390                guest_address: m.range.start(),
391                size: m.range.len(),
392                file: m.mappable.inner_arc(),
393                file_offset: m.file_offset,
394            })
395            .collect())
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use guestmem::ProvideShareableRegions;
    use memory_range::MemoryRange;

    /// Verifies that `DmaRegionProvider::get_regions` returns only mappings
    /// added with `dma_target: true`, with the correct address/size/offset.
    #[pal_async::async_test]
    async fn test_dma_target_regions_returned(spawn: impl Spawn) {
        let mm = MappingManager::new(&spawn, 0x200000, false);
        let client = mm.client().clone();

        let ram: Mappable = sparse_mmap::alloc_shared_memory(0x100000, "test-ram")
            .unwrap()
            .into();
        let device: Mappable = sparse_mmap::alloc_shared_memory(0x1000, "test-dev")
            .unwrap()
            .into();

        // DMA-target mapping: should be reported.
        client
            .add_mapping(MappingParams {
                range: MemoryRange::new(0..0x100000),
                mappable: ram,
                file_offset: 0,
                writable: true,
                dma_target: true,
            })
            .await;

        // Non-DMA mapping: should be filtered out.
        client
            .add_mapping(MappingParams {
                range: MemoryRange::new(0x100000..0x101000),
                mappable: device,
                file_offset: 0,
                writable: true,
                dma_target: false,
            })
            .await;

        let provider = DmaRegionProvider {
            req_send: client.req_send.clone(),
        };
        let regions = provider.get_regions().await.unwrap();

        // Only the DMA-target mapping should appear.
        assert_eq!(regions.len(), 1);
        assert_eq!(regions[0].guest_address, 0);
        assert_eq!(regions[0].size, 0x100000);
        assert_eq!(regions[0].file_offset, 0);
    }

    /// Verifies that `get_regions` is empty when no mapping is a DMA target.
    #[pal_async::async_test]
    async fn test_no_dma_targets_returns_empty(spawn: impl Spawn) {
        let mm = MappingManager::new(&spawn, 0x100000, false);
        let client = mm.client().clone();

        let mappable: Mappable = sparse_mmap::alloc_shared_memory(0x1000, "test")
            .unwrap()
            .into();

        client
            .add_mapping(MappingParams {
                range: MemoryRange::new(0..0x1000),
                mappable,
                file_offset: 0,
                writable: true,
                dma_target: false,
            })
            .await;

        let provider = DmaRegionProvider {
            req_send: client.req_send.clone(),
        };
        let regions = provider.get_regions().await.unwrap();
        assert!(regions.is_empty());
    }
}