// membacking/mapping_manager/manager.rs

use super::mappable::Mappable;
use super::object_cache::ObjectCache;
use super::object_cache::ObjectId;
use super::va_mapper::VaMapper;
use super::va_mapper::VaMapperError;
use crate::RemoteProcess;
use futures::StreamExt;
use futures::future::join_all;
use guestmem::ProvideShareableRegions;
use guestmem::ShareableRegion;
use inspect::Inspect;
use inspect::InspectMut;
use memory_range::MemoryRange;
use mesh::MeshPayload;
use mesh::rpc::Rpc;
use mesh::rpc::RpcSend;
use pal_async::task::Spawn;
use slab::Slab;
use std::sync::Arc;

/// Owns the memory-mapping state for a memory backing. Construction spawns a
/// background task that services [`MappingRequest`]s; this struct holds the
/// client used to reach that task.
#[derive(Debug, Inspect)]
pub struct MappingManager {
    // Inspect requests are forwarded over the client's request channel so the
    // task that owns the mapping state responds with its own view.
    #[inspect(
        flatten,
        with = "|x| inspect::send(&x.req_send, MappingRequest::Inspect)"
    )]
    client: MappingManagerClient,
}
37
38impl MappingManager {
39 pub fn new(spawn: impl Spawn, max_addr: u64, private_ram: bool) -> Self {
45 let (req_send, mut req_recv) = mesh::mpsc_channel();
46 spawn
47 .spawn("mapping_manager", {
48 let mut task = MappingManagerTask::new();
49 async move {
50 task.run(&mut req_recv).await;
51 }
52 })
53 .detach();
54 Self {
55 client: MappingManagerClient {
56 id: ObjectId::new(),
57 req_send,
58 max_addr,
59 private_ram,
60 },
61 }
62 }
63
64 pub fn client(&self) -> &MappingManagerClient {
67 &self.client
68 }
69}
70
/// Client handle for the mapping manager task. Clonable, and sendable across
/// processes as a mesh payload.
#[derive(Debug, MeshPayload, Clone)]
pub struct MappingManagerClient {
    /// Channel to the task spawned by [`MappingManager::new`].
    req_send: mesh::Sender<MappingRequest>,
    /// Identity used to key the global mapper cache.
    id: ObjectId,
    /// Maximum address passed to new VA mappers.
    max_addr: u64,
    /// Whether the backing memory is private; remote mappers are rejected
    /// when this is set.
    private_ram: bool,
}
79
// Cache of VA mappers keyed by client object ID, so repeated `new_mapper`
// calls for the same client share a single `VaMapper`.
static MAPPER_CACHE: ObjectCache<VaMapper> = ObjectCache::new();
81
82impl MappingManagerClient {
83 pub async fn new_mapper(&self) -> Result<Arc<VaMapper>, VaMapperError> {
89 let private_ram = self.private_ram;
92 MAPPER_CACHE
93 .get_or_insert_with(&self.id, async {
94 VaMapper::new(self.req_send.clone(), self.max_addr, None, private_ram).await
95 })
96 .await
97 }
98
99 pub async fn new_remote_mapper(
108 &self,
109 process: RemoteProcess,
110 ) -> Result<Arc<VaMapper>, VaMapperError> {
111 if self.private_ram {
112 return Err(VaMapperError::RemoteWithPrivateMemory);
113 }
114 Ok(Arc::new(
115 VaMapper::new(self.req_send.clone(), self.max_addr, Some(process), false).await?,
116 ))
117 }
118
119 pub async fn add_mapping(&self, params: MappingParams) {
125 self.req_send
126 .call(MappingRequest::AddMapping, params)
127 .await
128 .unwrap();
129 }
130
131 pub async fn remove_mappings(&self, range: MemoryRange) {
135 self.req_send
136 .call(MappingRequest::RemoveMappings, range)
137 .await
138 .unwrap();
139 }
140}
141
/// Requests serviced by the mapping manager task.
#[derive(MeshPayload)]
pub enum MappingRequest {
    /// Register a mapper's request channel; responds with its assigned ID.
    AddMapper(Rpc<mesh::Sender<MapperRequest>, MapperId>),
    /// Unregister a mapper and forget it as an active user of any mapping.
    RemoveMapper(MapperId),
    /// Push the mappings overlapping the range to the given mapper.
    SendMappings(MapperId, MemoryRange),
    /// Record a new mapping.
    AddMapping(Rpc<MappingParams, ()>),
    /// Remove all mappings contained in the range, invalidating active mappers.
    RemoveMappings(Rpc<MemoryRange, ()>),
    /// Return the parameters of all mappings flagged as DMA targets.
    GetDmaTargetMappings(Rpc<(), Vec<MappingParams>>),
    /// Respond to an inspect request with the task's state.
    Inspect(inspect::Deferred),
}
154
/// State owned by the mapping manager worker task.
#[derive(InspectMut)]
struct MappingManagerTask {
    // Rendered field-by-field (keyed by range) rather than as a plain list.
    #[inspect(with = "inspect_mappings")]
    mappings: Vec<Mapping>,
    // Channels only; nothing useful to inspect.
    #[inspect(skip)]
    mappers: Mappers,
}
162
163fn inspect_mappings(mappings: &Vec<Mapping>) -> impl '_ + Inspect {
164 inspect::adhoc(move |req| {
165 let mut resp = req.respond();
166 for mapping in mappings {
167 resp.field(
168 &mapping.params.range.to_string(),
169 inspect::adhoc(|req| {
170 req.respond()
171 .field("writable", mapping.params.writable)
172 .field("dma_target", mapping.params.dma_target)
173 .hex("file_offset", mapping.params.file_offset);
174 }),
175 );
176 }
177 })
178}
179
/// A registered mapping plus the set of mappers it has been sent to.
struct Mapping {
    params: MappingParams,
    // Mappers that received this mapping; these must be invalidated when the
    // mapping is removed.
    active_mappers: Vec<MapperId>,
}
184
/// Parameters describing a single memory mapping.
#[derive(MeshPayload, Clone)]
pub struct MappingParams {
    /// Guest address range the mapping covers.
    pub range: MemoryRange,
    /// The backing object to map.
    pub mappable: Mappable,
    /// Byte offset into `mappable` where the mapping starts.
    pub file_offset: u64,
    /// Whether the mapping is writable.
    pub writable: bool,
    /// Whether the mapping is reported as a shareable DMA-target region.
    pub dma_target: bool,
}
203
/// The set of registered mappers, indexed by [`MapperId`].
struct Mappers {
    mappers: Slab<MapperComm>,
}
207
/// Communication channel to one registered mapper.
struct MapperComm {
    req_send: mesh::Sender<MapperRequest>,
}
211
/// Identifier of a registered mapper (its index in the mapper slab).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, MeshPayload)]
pub struct MapperId(usize);
214
/// Requests sent from the mapping manager to an individual mapper.
#[derive(MeshPayload)]
pub enum MapperRequest {
    /// Establish the given mapping.
    Map(MappingParams),
    /// There is currently no mapping for the given range.
    NoMapping(MemoryRange),
    /// Unmap the given range, replying once done.
    Unmap(Rpc<MemoryRange, ()>),
}
226
impl MappingManagerTask {
    /// Creates empty task state: no mappings, no mappers.
    fn new() -> Self {
        Self {
            mappers: Mappers {
                mappers: Slab::new(),
            },
            mappings: Vec::new(),
        }
    }

    /// Services requests until the channel is closed.
    async fn run(&mut self, req_recv: &mut mesh::Receiver<MappingRequest>) {
        while let Some(req) = req_recv.next().await {
            match req {
                MappingRequest::AddMapper(rpc) => rpc.handle_sync(|send| self.add_mapper(send)),
                MappingRequest::RemoveMapper(id) => {
                    self.remove_mapper(id);
                }
                MappingRequest::SendMappings(id, range) => {
                    self.send_mappings(id, range);
                }
                MappingRequest::AddMapping(rpc) => {
                    rpc.handle_sync(|params| self.add_mapping(params))
                }
                MappingRequest::RemoveMappings(rpc) => {
                    // Async: waits for all mappers to acknowledge invalidation.
                    rpc.handle(async |range| self.remove_mappings(range).await)
                        .await
                }
                MappingRequest::GetDmaTargetMappings(rpc) => {
                    rpc.handle_sync(|()| self.get_dma_target_mappings())
                }
                MappingRequest::Inspect(deferred) => deferred.inspect(&mut *self),
            }
        }
    }

    /// Registers a mapper's request channel and returns its new ID.
    fn add_mapper(&mut self, req_send: mesh::Sender<MapperRequest>) -> MapperId {
        let id = self.mappers.mappers.insert(MapperComm { req_send });
        tracing::debug!(?id, "adding mapper");
        MapperId(id)
    }

    /// Unregisters a mapper and drops it from every mapping's active set so
    /// it no longer receives invalidations.
    fn remove_mapper(&mut self, id: MapperId) {
        tracing::debug!(?id, "removing mapper");
        self.mappers.mappers.remove(id.0);
        for mapping in &mut self.mappings {
            mapping.active_mappers.retain(|m| m != &id);
        }
    }

    /// Walks `range` left to right, sending the mapper either the mapping
    /// that covers each sub-range or a `NoMapping` notification for gaps.
    fn send_mappings(&mut self, id: MapperId, mut range: MemoryRange) {
        while !range.is_empty() {
            // Find the lowest-starting mapping that overlaps what's left.
            let (this_end, params) = if let Some(mapping) = self
                .mappings
                .iter_mut()
                .filter(|mapping| mapping.params.range.overlaps(&range))
                .min_by_key(|mapping| mapping.params.range.start())
            {
                if mapping.params.range.start() <= range.start() {
                    // The mapping covers the head of the range. Record the
                    // mapper as an active user so it gets invalidated later.
                    if !mapping.active_mappers.contains(&id) {
                        mapping.active_mappers.push(id);
                    }
                    (
                        mapping.params.range.end().min(range.end()),
                        Some(mapping.params.clone()),
                    )
                } else {
                    // Gap before the next mapping starts.
                    (mapping.params.range.start(), None)
                }
            } else {
                // Nothing overlaps the rest of the range.
                (range.end(), None)
            };
            let this_range = MemoryRange::new(range.start()..this_end);
            let req = if let Some(params) = params {
                tracing::debug!(range = %this_range, full_range = %params.range, "sending mapping for range");
                MapperRequest::Map(params)
            } else {
                tracing::debug!(range = %this_range, "no mapping for range");
                MapperRequest::NoMapping(this_range)
            };
            self.mappers.mappers[id.0].req_send.send(req);
            // Advance past the portion just handled.
            range = MemoryRange::new(this_end..range.end());
        }
    }

    /// Records a new mapping. The range must not duplicate an existing one.
    fn add_mapping(&mut self, params: MappingParams) {
        tracing::debug!(range = %params.range, "adding mapping");

        assert!(!self.mappings.iter().any(|m| m.params.range == params.range));

        self.mappings.push(Mapping {
            params,
            active_mappers: Vec::new(),
        });
    }

    /// Returns the parameters of every mapping flagged as a DMA target.
    fn get_dma_target_mappings(&self) -> Vec<MappingParams> {
        self.mappings
            .iter()
            .filter(|m| m.params.dma_target)
            .map(|m| m.params.clone())
            .collect()
    }

    /// Removes every mapping fully contained in `range` and waits for each
    /// affected mapper to unmap it. Partial overlaps are a caller bug.
    async fn remove_mappings(&mut self, range: MemoryRange) {
        let mut mappers = Vec::new();
        self.mappings.retain_mut(|mapping| {
            if !range.contains(&mapping.params.range) {
                assert!(
                    !range.overlaps(&mapping.params.range),
                    "no partial unmappings allowed"
                );
                return true;
            }
            tracing::debug!(range = %mapping.params.range, "removing mapping");
            mappers.append(&mut mapping.active_mappers);
            false
        });
        // A mapper may have held several removed mappings; invalidate once.
        mappers.sort();
        mappers.dedup();
        self.mappers.invalidate(&mappers, range).await;
    }
}
353
354impl Mappers {
355 async fn invalidate(&self, ids: &[MapperId], range: MemoryRange) {
356 tracing::debug!(mapper_count = ids.len(), %range, "sending invalidations");
357 join_all(ids.iter().map(async |&MapperId(i)| {
358 if let Err(err) = self.mappers[i]
359 .req_send
360 .call(MapperRequest::Unmap, range)
361 .await
362 {
363 tracing::warn!(
364 error = &err as &dyn std::error::Error,
365 "mapper dropped invalidate request"
366 );
367 }
368 }))
369 .await;
370 }
371}
372
/// Provides shareable regions by asking the mapping manager for the mappings
/// flagged as DMA targets.
pub(crate) struct DmaRegionProvider {
    /// Channel to the mapping manager task.
    pub req_send: mesh::Sender<MappingRequest>,
}
379
380impl ProvideShareableRegions for DmaRegionProvider {
381 async fn get_regions(&self) -> Result<Vec<ShareableRegion>, guestmem::ShareableRegionError> {
382 let mappings = self
383 .req_send
384 .call(MappingRequest::GetDmaTargetMappings, ())
385 .await?;
386
387 Ok(mappings
388 .into_iter()
389 .map(|m| ShareableRegion {
390 guest_address: m.range.start(),
391 size: m.range.len(),
392 file: m.mappable.inner_arc(),
393 file_offset: m.file_offset,
394 })
395 .collect())
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use guestmem::ProvideShareableRegions;
    use memory_range::MemoryRange;

    /// Only mappings added with `dma_target: true` should be reported by the
    /// DMA region provider.
    #[pal_async::async_test]
    async fn test_dma_target_regions_returned(spawn: impl Spawn) {
        let mm = MappingManager::new(&spawn, 0x200000, false);
        let client = mm.client().clone();

        let ram: Mappable = sparse_mmap::alloc_shared_memory(0x100000, "test-ram")
            .unwrap()
            .into();
        let device: Mappable = sparse_mmap::alloc_shared_memory(0x1000, "test-dev")
            .unwrap()
            .into();

        // DMA-target mapping: expected in the result.
        client
            .add_mapping(MappingParams {
                range: MemoryRange::new(0..0x100000),
                mappable: ram,
                file_offset: 0,
                writable: true,
                dma_target: true,
            })
            .await;

        // Non-DMA mapping: must be filtered out.
        client
            .add_mapping(MappingParams {
                range: MemoryRange::new(0x100000..0x101000),
                mappable: device,
                file_offset: 0,
                writable: true,
                dma_target: false,
            })
            .await;

        let provider = DmaRegionProvider {
            req_send: client.req_send.clone(),
        };
        let regions = provider.get_regions().await.unwrap();

        assert_eq!(regions.len(), 1);
        assert_eq!(regions[0].guest_address, 0);
        assert_eq!(regions[0].size, 0x100000);
        assert_eq!(regions[0].file_offset, 0);
    }

    /// With no DMA-target mappings registered, the provider returns an empty
    /// region list.
    #[pal_async::async_test]
    async fn test_no_dma_targets_returns_empty(spawn: impl Spawn) {
        let mm = MappingManager::new(&spawn, 0x100000, false);
        let client = mm.client().clone();

        let mappable: Mappable = sparse_mmap::alloc_shared_memory(0x1000, "test")
            .unwrap()
            .into();

        client
            .add_mapping(MappingParams {
                range: MemoryRange::new(0..0x1000),
                mappable,
                file_offset: 0,
                writable: true,
                dma_target: false,
            })
            .await;

        let provider = DmaRegionProvider {
            req_send: client.req_send.clone(),
        };
        let regions = provider.get_regions().await.unwrap();
        assert!(regions.is_empty());
    }
}