membacking/mapping_manager/
manager.rsuse super::mappable::Mappable;
use super::object_cache::ObjectCache;
use super::object_cache::ObjectId;
use super::va_mapper::VaMapper;
use super::va_mapper::VaMapperError;
use crate::RemoteProcess;
use futures::StreamExt;
use futures::future::join_all;
use inspect::Inspect;
use inspect::InspectMut;
use memory_range::MemoryRange;
use mesh::MeshPayload;
use mesh::rpc::Rpc;
use mesh::rpc::RpcSend;
use pal_async::task::Spawn;
use slab::Slab;
use std::sync::Arc;
#[derive(Debug)]
pub struct MappingManager {
client: MappingManagerClient,
}
impl Inspect for MappingManager {
fn inspect(&self, req: inspect::Request<'_>) {
self.client
.req_send
.send(MappingRequest::Inspect(req.defer()));
}
}
impl MappingManager {
pub fn new(spawn: impl Spawn, max_addr: u64) -> Self {
let (req_send, mut req_recv) = mesh::mpsc_channel();
spawn
.spawn("mapping_manager", {
let mut task = MappingManagerTask::new();
async move {
task.run(&mut req_recv).await;
}
})
.detach();
Self {
client: MappingManagerClient {
id: ObjectId::new(),
req_send,
max_addr,
},
}
}
pub fn client(&self) -> &MappingManagerClient {
&self.client
}
}
#[derive(Debug, MeshPayload, Clone)]
pub struct MappingManagerClient {
req_send: mesh::Sender<MappingRequest>,
id: ObjectId,
max_addr: u64,
}
static MAPPER_CACHE: ObjectCache<VaMapper> = ObjectCache::new();
impl MappingManagerClient {
pub async fn new_mapper(&self) -> Result<Arc<VaMapper>, VaMapperError> {
MAPPER_CACHE
.get_or_insert_with(&self.id, async {
VaMapper::new(self.req_send.clone(), self.max_addr, None).await
})
.await
}
pub async fn new_remote_mapper(
&self,
process: RemoteProcess,
) -> Result<Arc<VaMapper>, VaMapperError> {
Ok(Arc::new(
VaMapper::new(self.req_send.clone(), self.max_addr, Some(process)).await?,
))
}
pub async fn add_mapping(
&self,
range: MemoryRange,
mappable: Mappable,
file_offset: u64,
writable: bool,
) {
let params = MappingParams {
range,
mappable,
file_offset,
writable,
};
self.req_send
.call(MappingRequest::AddMapping, params)
.await
.unwrap();
}
pub async fn remove_mappings(&self, range: MemoryRange) {
self.req_send
.call(MappingRequest::RemoveMappings, range)
.await
.unwrap();
}
}
#[derive(MeshPayload)]
pub enum MappingRequest {
AddMapper(Rpc<mesh::Sender<MapperRequest>, MapperId>),
RemoveMapper(MapperId),
SendMappings(MapperId, MemoryRange),
AddMapping(Rpc<MappingParams, ()>),
RemoveMappings(Rpc<MemoryRange, ()>),
Inspect(inspect::Deferred),
}
#[derive(InspectMut)]
struct MappingManagerTask {
#[inspect(with = "inspect_mappings")]
mappings: Vec<Mapping>,
#[inspect(skip)]
mappers: Mappers,
}
fn inspect_mappings(mappings: &Vec<Mapping>) -> impl '_ + Inspect {
inspect::adhoc(move |req| {
let mut resp = req.respond();
for mapping in mappings {
resp.field(
&mapping.params.range.to_string(),
inspect::adhoc(|req| {
req.respond()
.field("writable", mapping.params.writable)
.hex("file_offset", mapping.params.file_offset);
}),
);
}
})
}
struct Mapping {
params: MappingParams,
active_mappers: Vec<MapperId>,
}
#[derive(MeshPayload, Clone)]
pub struct MappingParams {
pub range: MemoryRange,
pub mappable: Mappable,
pub file_offset: u64,
pub writable: bool,
}
struct Mappers {
mappers: Slab<MapperComm>,
}
struct MapperComm {
req_send: mesh::Sender<MapperRequest>,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, MeshPayload)]
pub struct MapperId(usize);
#[derive(MeshPayload)]
pub enum MapperRequest {
Map(MappingParams),
NoMapping(MemoryRange),
Unmap(Rpc<MemoryRange, ()>),
}
impl MappingManagerTask {
fn new() -> Self {
Self {
mappers: Mappers {
mappers: Slab::new(),
},
mappings: Vec::new(),
}
}
async fn run(&mut self, req_recv: &mut mesh::Receiver<MappingRequest>) {
while let Some(req) = req_recv.next().await {
match req {
MappingRequest::AddMapper(rpc) => rpc.handle_sync(|send| self.add_mapper(send)),
MappingRequest::RemoveMapper(id) => {
self.remove_mapper(id);
}
MappingRequest::SendMappings(id, range) => {
self.send_mappings(id, range);
}
MappingRequest::AddMapping(rpc) => {
rpc.handle_sync(|params| self.add_mapping(params))
}
MappingRequest::RemoveMappings(rpc) => {
rpc.handle(async |range| self.remove_mappings(range).await)
.await
}
MappingRequest::Inspect(deferred) => deferred.inspect(&mut *self),
}
}
}
fn add_mapper(&mut self, req_send: mesh::Sender<MapperRequest>) -> MapperId {
let id = self.mappers.mappers.insert(MapperComm { req_send });
tracing::debug!(?id, "adding mapper");
MapperId(id)
}
fn remove_mapper(&mut self, id: MapperId) {
tracing::debug!(?id, "removing mapper");
self.mappers.mappers.remove(id.0);
for mapping in &mut self.mappings {
mapping.active_mappers.retain(|m| m != &id);
}
}
fn send_mappings(&mut self, id: MapperId, mut range: MemoryRange) {
while !range.is_empty() {
let (this_end, params) = if let Some(mapping) = self
.mappings
.iter_mut()
.filter(|mapping| mapping.params.range.overlaps(&range))
.min_by_key(|mapping| mapping.params.range.start())
{
if mapping.params.range.start() <= range.start() {
if !mapping.active_mappers.contains(&id) {
mapping.active_mappers.push(id);
}
(
mapping.params.range.end().min(range.end()),
Some(mapping.params.clone()),
)
} else {
(mapping.params.range.start(), None)
}
} else {
(range.end(), None)
};
let this_range = MemoryRange::new(range.start()..this_end);
let req = if let Some(params) = params {
tracing::debug!(range = %this_range, full_range = %params.range, "sending mapping for range");
MapperRequest::Map(params)
} else {
tracing::debug!(range = %this_range, "no mapping for range");
MapperRequest::NoMapping(this_range)
};
self.mappers.mappers[id.0].req_send.send(req);
range = MemoryRange::new(this_end..range.end());
}
}
fn add_mapping(&mut self, params: MappingParams) {
tracing::debug!(range = %params.range, "adding mapping");
assert!(!self.mappings.iter().any(|m| m.params.range == params.range));
self.mappings.push(Mapping {
params,
active_mappers: Vec::new(),
});
}
async fn remove_mappings(&mut self, range: MemoryRange) {
let mut mappers = Vec::new();
self.mappings.retain_mut(|mapping| {
if !range.contains(&mapping.params.range) {
assert!(
!range.overlaps(&mapping.params.range),
"no partial unmappings allowed"
);
return true;
}
tracing::debug!(range = %mapping.params.range, "removing mapping");
mappers.append(&mut mapping.active_mappers);
false
});
mappers.sort();
mappers.dedup();
self.mappers.invalidate(&mappers, range).await;
}
}
impl Mappers {
async fn invalidate(&self, ids: &[MapperId], range: MemoryRange) {
tracing::debug!(mapper_count = ids.len(), %range, "sending invalidations");
join_all(ids.iter().map(async |&MapperId(i)| {
if let Err(err) = self.mappers[i]
.req_send
.call(MapperRequest::Unmap, range)
.await
{
tracing::warn!(
error = &err as &dyn std::error::Error,
"mapper dropped invalidate request"
);
}
}))
.await;
}
}