1#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// IPC node type used to connect mesh processes on Windows.
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// IPC node type used to connect mesh processes on Unix.
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// File descriptor number at which the invitation socket is placed in a
/// spawned child: the parent `dup_fd`s the socket to this number and records
/// the same number in the invitation.
#[cfg(unix)]
const IPC_FD: i32 = 3;

/// Name of the environment variable that carries the encoded invitation from
/// the parent process to a spawned child.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";
63
/// Invitation passed from the launching process to a child so the child can
/// join the mesh.
///
/// The launching side protobuf-encodes this value, base64-encodes the result,
/// and places it in the `INVITATION_ENV_NAME` environment variable of the
/// child; the child decodes it in `node_from_environment`.
#[derive(Protobuf)]
struct Invitation {
    /// Name of the node; the launcher fills this from the process config name.
    node_name: String,
    /// Mesh address information taken from the node's invitation.
    address: InvitationAddress,
    /// Raw value of the invitation's directory handle, which the launcher also
    /// passes to the child process builder.
    #[cfg(windows)]
    directory_handle: usize,
    /// File descriptor number of the invitation socket in the child; the
    /// launcher always sets this to `IPC_FD`.
    #[cfg(unix)]
    socket_fd: i32,
}
73
/// Holds the node name of this process once it is running as a mesh host; it
/// is stored in `try_run_mesh_host`.
// NOTE(review): presumably kept so a debugger can locate the name; the
// semantics of `DebugPtr` are not visible in this file, so confirm.
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86 U: 'static + MeshPayload + Send,
87 F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89 block_on(async {
90 if let Some(r) = node_from_environment().await? {
91 let NodeResult {
92 node_name,
93 node,
94 initial_port,
95 } = r;
96 PROCESS_NAME.store(&node_name);
97 set_program_name(&format!("{base_name}-{node_name}"));
98 let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99 .await
100 .context("failed to receive initial message")?;
101 let _drop = (
102 f(init.init_message).map(Some),
103 handle_host_requests(init.requests).map(|()| None),
104 )
105 .race()
106 .await
107 .transpose()?;
108
109 tracing::debug!("waiting to shut down node");
110 node.shutdown().await;
111 drop(_drop);
112 std::process::exit(0);
113 }
114 Ok(())
115 })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119 while let Some(req) = recv.next().await {
120 match req {
121 HostRequest::Inspect(deferred) => {
122 deferred.respond(inspect_host);
123 }
124 HostRequest::Crash => panic!("explicit panic request"),
125 }
126 }
127}
128
/// Best-effort update of the program name reported by the operating system.
///
/// On Linux this writes `name` to `/proc/self/comm` and ignores any failure;
/// on every other target it does nothing.
fn set_program_name(name: &str) {
    if cfg!(target_os = "linux") {
        // Failure to rename is not worth surfacing; the name is cosmetic.
        let _ = std::fs::write("/proc/self/comm", name);
    }
}
136
/// Result of joining a mesh from the environment-provided invitation, as
/// returned by `node_from_environment`.
struct NodeResult {
    /// Node name taken from the decoded invitation.
    node_name: String,
    /// The IPC node that joined the mesh.
    node: IpcNode,
    /// Local end of the port pair whose other end was handed to the node when
    /// joining; the host receives its initial message on it.
    initial_port: mesh::local_node::Port,
}
142
143async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147 let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149 Ok(str) => str,
150 Err(_) => return Ok(None),
151 };
152
153 unsafe {
165 std::env::remove_var(INVITATION_ENV_NAME);
166 }
167
168 let invitation: Invitation = mesh::payload::decode(
169 &base64::engine::general_purpose::STANDARD
170 .decode(invitation_str)
171 .context("failed to base64 decode invitation")?,
172 )
173 .context("failed to protobuf decode invitation")?;
174
175 let (left, right) = mesh::local_node::Port::new_pair();
176
177 let node;
178 #[cfg(windows)]
179 {
180 let directory =
184 unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
185
186 let invitation = mesh_remote::windows::AlpcInvitation {
187 address: invitation.address,
188 directory,
189 };
190
191 node = mesh_remote::windows::AlpcNode::join(
193 pal_async::windows::TpPool::system(),
194 invitation,
195 left,
196 )
197 .context("failed to join mesh")?;
198 }
199
200 #[cfg(unix)]
201 {
202 let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
206 let invitation = mesh_remote::unix::UnixInvitation {
207 address: invitation.address,
208 fd,
209 };
210
211 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
213 node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
214 .await
215 .context("failed to join mesh")?;
216 }
217
218 Ok(Some(NodeResult {
219 node_name: invitation.node_name,
220 node,
221 initial_port: right,
222 }))
223}
224
/// Handle to a mesh of processes.
///
/// Requests made through this handle are sent over a channel to a background
/// task (spawned in `Mesh::new`) that owns the actual mesh state.
#[derive(Inspect)]
pub struct Mesh {
    /// Name of the mesh; exposed to inspect as `name`.
    #[inspect(rename = "name")]
    mesh_name: String,
    /// Sender for requests to the background task; inspect requests are
    /// forwarded through it as `MeshRequest::Inspect`.
    #[inspect(flatten, send = "MeshRequest::Inspect")]
    request: mesh::Sender<MeshRequest>,
    /// Background task that processes requests; awaited in `shutdown`.
    #[inspect(skip)]
    task: Task<()>,
}
254
/// Hook that lets a caller adjust how a mesh child process is launched.
pub trait SandboxProfile: Send {
    /// Applies the profile to the process builder. The mesh calls this on the
    /// builder just before spawning the child process.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Optional finalization step; the default implementation does nothing and
    /// returns `Ok(())`.
    // NOTE(review): no caller of `finalize` is visible in this file; confirm
    // where and when it is expected to be invoked.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
271
/// Configuration for launching a child process in the mesh.
pub struct ProcessConfig {
    /// Name of the host; used as the mesh node name and, unless
    /// `skip_worker_arg` is set, appended as the last process argument.
    name: String,
    /// Path of the executable to launch. When `None`, the current executable
    /// is launched.
    process_name: Option<PathBuf>,
    /// Arguments passed to the child process.
    process_args: Vec<OsString>,
    /// File to use as the child's stderr, if any.
    stderr: Option<File>,
    /// When true, the host name is not appended to the argument list.
    skip_worker_arg: bool,
    /// Optional sandbox profile applied to the process builder before spawn.
    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
}
281
282impl ProcessConfig {
283 pub fn new(name: impl Into<String>) -> Self {
286 Self {
287 name: name.into(),
288 process_name: None,
289 process_args: Vec::new(),
290 stderr: None,
291 skip_worker_arg: false,
292 sandbox_profile: None,
293 }
294 }
295
296 pub fn new_with_sandbox(
299 name: impl Into<String>,
300 sandbox_profile: Box<dyn SandboxProfile + Sync>,
301 ) -> Self {
302 Self {
303 name: name.into(),
304 process_name: None,
305 process_args: Vec::new(),
306 stderr: None,
307 skip_worker_arg: false,
308 sandbox_profile: Some(sandbox_profile),
309 }
310 }
311
312 pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
314 self.process_name = Some(name.into());
315 self
316 }
317
318 pub fn skip_worker_arg(mut self, skip: bool) -> Self {
325 self.skip_worker_arg = skip;
326 self
327 }
328
329 pub fn args<I>(mut self, args: I) -> Self
331 where
332 I: IntoIterator,
333 I::Item: Into<OsString>,
334 {
335 self.process_args.extend(args.into_iter().map(|x| x.into()));
336 self
337 }
338
339 pub fn stderr(mut self, file: Option<File>) -> Self {
341 self.stderr = file;
342 self
343 }
344}
345
/// State owned by the mesh's background task.
struct MeshInner {
    /// Incoming requests from the `Mesh` handle.
    requests: mesh::Receiver<MeshRequest>,
    /// Hosts launched by this mesh, keyed by slab index.
    hosts: Slab<MeshHostInner>,
    /// One receiver per launched host; each yields the host's slab index once
    /// the thread waiting on the child process has seen it exit.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// This process's IPC node, used to create invitations for new hosts.
    node: IpcNode,
    /// Name of the mesh.
    mesh_name: String,
    /// Job object that child processes are added to; it is configured to
    /// terminate on close when the mesh is created.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
360
/// Bookkeeping for a single launched host process.
struct MeshHostInner {
    /// Host name, taken from the process config.
    name: String,
    /// Process id of the child.
    pid: i32,
    /// Mesh node id taken from the invitation issued for this host.
    node_id: mesh::NodeId,
    /// Channel for sending requests (inspect, crash) to the host.
    send: mesh::Sender<HostRequest>,
}
367
/// Requests sent from the `Mesh` handle to the background task.
enum MeshRequest {
    /// Launch a new host process; the RPC completes with the launch result.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Inspect the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Crash the process with the given pid: either this process or one of
    /// the launched hosts.
    Crash(i32),
}
373
/// Parameters of a `MeshRequest::NewHost` request.
struct NewHostParams {
    /// How to launch the process.
    config: ProcessConfig,
    /// Port carrying the initial message; handed to the node when creating
    /// the invitation for the new host.
    recv: mesh::local_node::Port,
    /// Sender for host requests; stored with the host's bookkeeping entry.
    request_send: mesh::Sender<HostRequest>,
}
379
380impl Mesh {
381 pub fn new(mesh_name: String) -> anyhow::Result<Self> {
383 #[cfg(windows)]
384 let job = {
385 let job = pal::windows::job::Job::new().context("failed to create job object")?;
386 job.set_terminate_on_close()
387 .context("failed to set job object terminate on close")?;
388 job
389 };
390
391 #[cfg(windows)]
392 let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
393 .context("AlpcNode creation failure")?;
394 #[cfg(unix)]
395 let node = {
396 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
398 mesh_remote::unix::UnixNode::new(driver)
399 };
400
401 let (request, requests) = mesh::channel();
402 let mut inner = MeshInner {
403 requests,
404 hosts: Default::default(),
405 waiters: Default::default(),
406 node,
407 mesh_name: mesh_name.clone(),
408 #[cfg(windows)]
409 job,
410 };
411
412 let (_, driver) = DefaultPool::spawn_on_thread("mesh");
415 let task = driver.spawn(
416 format!("mesh-{}", &mesh_name),
417 async move { inner.run().await },
418 );
419
420 Ok(Self {
421 request,
422 mesh_name,
423 task,
424 })
425 }
426
427 pub async fn launch_host<T: 'static + MeshField + Send>(
433 &self,
434 config: ProcessConfig,
435 initial_message: T,
436 ) -> anyhow::Result<()> {
437 let (request_send, request_recv) = mesh::channel();
438
439 let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
440 init_send.send(InitialMessage {
441 requests: request_recv,
442 init_message: initial_message,
443 });
444
445 self.request
446 .call(
447 MeshRequest::NewHost,
448 NewHostParams {
449 config,
450 recv: init_recv.into(),
451 request_send,
452 },
453 )
454 .await
455 .context("mesh failed")?
456 }
457
458 pub async fn shutdown(self) {
462 let span = tracing::span!(
463 tracing::Level::INFO,
464 "mesh_shutdown",
465 name = self.mesh_name.as_str(),
466 );
467
468 async {
469 drop(self.request);
470 self.task.await;
471 }
472 .instrument(span)
473 .await;
474 }
475
476 pub fn crash(&self, pid: i32) {
478 self.request.send(MeshRequest::Crash(pid));
479 }
480}
481
/// First message delivered to a newly launched host.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Channel on which the host receives `HostRequest`s from the mesh.
    requests: mesh::Receiver<HostRequest>,
    /// Caller-provided payload handed to the host's entry point.
    init_message: T,
}
487
/// Requests the mesh sends to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Inspect the host; answered with `inspect_host`.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Make the host panic deliberately.
    Crash,
}
494
495fn inspect_host(resp: &mut inspect::Response<'_>) {
496 resp.field("tasks", inspect_task::inspect_task_list());
497}
498
/// Per-process information merged into the inspect output for a host (and for
/// the mesh process itself).
#[derive(Inspect)]
struct HostInspect<'a> {
    /// Name of the host (or of the mesh, for the mesh process itself).
    #[inspect(safe)]
    name: &'a str,
    /// Mesh node id, reported using its `Debug` representation.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limit information for the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
509
impl MeshInner {
    /// Main loop of the mesh task.
    ///
    /// Processes requests from the `Mesh` handle and removes hosts whose
    /// process has exited. The loop ends when `futures::select!` takes its
    /// `complete` branch.
    async fn run(&mut self) {
        /// The two kinds of events the loop reacts to.
        enum Event {
            /// A request arrived from the `Mesh` handle.
            Request(MeshRequest),
            /// The host with this slab index has exited.
            Done(usize),
        }

        loop {
            // NOTE(review): `n.unwrap()` panics if a waiter resolves with an
            // error, which presumably happens when the wait thread drops its
            // sender without sending (e.g. after panicking); confirm against
            // `OneshotReceiver`.
            let event = futures::select! { request = self.requests.select_next_some() => Event::Request(request),
                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
                complete => break,
            };

            match event {
                Event::Request(request) => match request {
                    MeshRequest::NewHost(rpc) => {
                        // Launch the process and complete the RPC with the
                        // launch result.
                        rpc.handle(async |params| self.spawn_process(params).await)
                            .await
                    }
                    MeshRequest::Inspect(deferred) => {
                        deferred.respond(|resp| {
                            // One child per launched host under `hosts`, keyed
                            // by pid. Each merges locally known information
                            // with whatever the host itself reports.
                            resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
                                let mut resp = req.respond();
                                for host in self.hosts.iter().map(|(_, host)| host) {
                                    resp.sensitivity_field_mut(
                                        &host.pid.to_string(),
                                        SensitivityLevel::Safe,
                                        &mut inspect::adhoc(|req| {
                                            req.respond()
                                                .merge(&HostInspect {
                                                    name: &host.name,
                                                    node_id: host.node_id,
                                                    #[cfg(target_os = "linux")]
                                                    rlimit: inspect_rlimit::InspectRlimit::for_pid(
                                                        host.pid,
                                                    ),
                                                })
                                                .merge(inspect::send(
                                                    &host.send,
                                                    HostRequest::Inspect,
                                                ));
                                        }),
                                    );
                                }
                            })
                            // The mesh process itself is reported under
                            // `hosts/<own pid>` and is answered locally.
                            .sensitivity_field_mut(
                                &format!("hosts/{}", std::process::id()),
                                SensitivityLevel::Safe,
                                &mut inspect::adhoc(|req| {
                                    let mut resp = req.respond();
                                    resp.merge(&HostInspect {
                                        name: &self.mesh_name,
                                        node_id: self.node.id(),
                                        #[cfg(target_os = "linux")]
                                        rlimit: inspect_rlimit::InspectRlimit::new(),
                                    });
                                    inspect_host(&mut resp);
                                }),
                            );
                        });
                    }
                    MeshRequest::Crash(pid) => {
                        // A request naming this process crashes it directly.
                        if pid == std::process::id() as i32 {
                            panic!("explicit panic request");
                        }

                        // Otherwise forward the crash request to the first
                        // host with a matching pid.
                        let mut found = false;
                        for (_, host) in &self.hosts {
                            if host.pid == pid {
                                host.send.send(HostRequest::Crash);
                                found = true;
                                break;
                            }
                        }

                        if !found {
                            tracing::error!("failed to crash process, pid {pid} not found");
                        }
                    }
                },
                Event::Done(id) => {
                    // The child has exited; forget the host.
                    self.hosts.remove(id);
                }
            }
        }
    }

    /// Launches a host process, registers it in `hosts`, and starts a thread
    /// that waits for the process to exit.
    ///
    /// The child receives its invitation through the `INVITATION_ENV_NAME`
    /// environment variable, together with the invitation's directory handle
    /// (Windows) or its socket duplicated to `IPC_FD` (Unix).
    ///
    /// # Errors
    ///
    /// Returns an error if the current executable path cannot be determined
    /// (when no process name was configured), if creating the invitation
    /// fails, or if spawning the process fails.
    ///
    /// # Panics
    ///
    /// Panics if the wait thread cannot be spawned.
    #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
    async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
        let NewHostParams {
            config,
            recv,
            request_send,
        } = params;

        let pid;
        let node_id;

        // With an explicit process name there is no separate arg0. Otherwise
        // relaunch the current executable, reusing this process's own arg0.
        let (arg0, process_name) = if let Some(n) = &config.process_name {
            (None, Cow::Borrowed(n))
        } else {
            (
                std::env::args_os().next(),
                Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
            )
        };

        let name = config.name.clone();

        // `wait` is a closure that blocks until the child exits and logs how
        // it exited; it is run on a dedicated thread below.
        #[cfg(windows)]
        let wait = {
            let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
            node_id = invitation.address.local_addr.node;

            // The invitation travels to the child as base64-encoded protobuf.
            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
                mesh::payload::encode(Invitation {
                    node_name: name.clone(),
                    address: invitation.address,
                    directory_handle: invitation.directory.as_raw_handle() as usize,
                }),
            );

            // The host name is the final argument unless suppressed.
            let mut args = config.process_args;
            if !config.skip_worker_arg {
                args.push(name.clone().into());
            }

            let mut builder = process::Builder::from_args(
                arg0.as_ref()
                    .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
                &args,
            );
            if arg0.is_some() {
                builder.application_name(process_name.as_path());
            }
            builder
                .stdin(process::Stdio::Null)
                .stdout(process::Stdio::Null)
                .handle(&invitation.directory)
                .env(INVITATION_ENV_NAME, invitation_env)
                .job(self.job.as_handle());

            if let Some(log_file) = config.stderr.as_ref() {
                builder.stderr(process::Stdio::Handle(log_file.as_handle()));
            }

            if let Some(mut sandbox_profile) = config.sandbox_profile {
                sandbox_profile.apply(&mut builder);
            }

            let child = builder.spawn().context("failed to launch mesh process")?;
            // NOTE(review): `handle` comes from `invite`; awaiting it
            // presumably waits for the invitation to be consumed by the
            // child. Confirm against `AlpcNode::invite`.
            handle.await;
            pid = child.id() as i32;
            tracing::Span::current().record("pid", pid);
            move || {
                child.wait();
                let code = child.exit_code();
                if code == 0 {
                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
                } else {
                    tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
                }
            }
        };
        #[cfg(unix)]
        let mut wait = {
            use pal::unix::process;

            let invitation = self
                .node
                .invite(recv)
                .await
                .context("mesh node invite error")?;

            node_id = invitation.address.local_addr.node;

            // The invitation travels to the child as base64-encoded protobuf;
            // the socket itself is duplicated to `IPC_FD` in the child.
            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
                mesh::payload::encode(Invitation {
                    node_name: name.clone(),
                    address: invitation.address,
                    socket_fd: IPC_FD,
                }),
            );

            let mut command = process::Builder::new(process_name.into_owned());
            if let Some(arg0) = arg0 {
                command.arg0(arg0);
            }
            command
                .args(&config.process_args)
                .stdin(process::Stdio::Null)
                .stdout(process::Stdio::Null)
                .dup_fd(invitation.fd.as_fd(), IPC_FD)
                .env(INVITATION_ENV_NAME, invitation_env);

            // The host name is the final argument unless suppressed.
            if !config.skip_worker_arg {
                command.arg(&name);
            }

            if let Some(log_file) = config.stderr.as_ref() {
                command.stderr(process::Stdio::Fd(log_file.as_fd()));
            }

            if let Some(mut sandbox_profile) = config.sandbox_profile {
                sandbox_profile.apply(&mut command);
            }

            let mut child = command.spawn().context("failed to launch mesh process")?;
            pid = child.id();
            tracing::Span::current().record("pid", pid);
            move || {
                let exit_status = child.wait().expect("mesh child wait failure");
                if let Some(0) = exit_status.code() {
                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
                } else {
                    tracing::error!(
                        pid,
                        name = name.as_str(),
                        %exit_status,
                        "mesh child abnormal exit"
                    );
                }
            }
        };

        let (wait_send, wait_recv) = mesh::oneshot();

        let id = self.hosts.insert(MeshHostInner {
            name: config.name,
            pid,
            node_id,
            send: request_send,
        });

        // Waiting for the child blocks, so it gets its own thread; the slab
        // index is sent back once the child has exited so that `run` can
        // remove the host.
        thread::Builder::new()
            .name(format!("wait-mesh-child-{}", pid))
            .spawn(move || {
                wait();
                wait_send.send(id);
            })
            .unwrap();

        self.waiters.push(wait_recv);
        Ok(())
    }
}