1#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// The mesh node type used for cross-process communication on Windows
/// (`AlpcNode`).
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// The mesh node type used for cross-process communication on Unix
/// (`UnixNode`).
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// The file descriptor number the parent duplicates the mesh IPC socket onto
/// when spawning a child, and the value it sends as `Invitation::socket_fd`.
#[cfg(unix)]
const IPC_FD: i32 = 3;

/// Name of the environment variable that carries the base64-encoded
/// [`Invitation`] from the parent to a newly launched child process.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";
63
/// The information a parent hands to a child process so the child can join
/// the parent's mesh.
///
/// The parent protobuf-encodes this, base64-encodes the result, and places it
/// in the `MESH_WORKER_INVITATION` environment variable of the child.
#[derive(Protobuf)]
struct Invitation {
    /// The name of the host being launched (taken from `ProcessConfig`).
    node_name: String,
    /// The address portion of the invitation produced by the parent's node.
    address: InvitationAddress,
    /// The raw value of the invitation's directory handle, which the parent
    /// also passes to the process builder when spawning the child.
    #[cfg(windows)]
    directory_handle: usize,
    /// The file descriptor number at which the child finds the IPC socket.
    /// The parent always sends `IPC_FD`.
    #[cfg(unix)]
    socket_fd: i32,
}
73
/// Holds the mesh node name of this process once it has joined a mesh as a
/// child host (stored by `try_run_mesh_host`).
///
/// Nothing in this file reads it back; going by the `DebugPtr` type name it is
/// presumably there so the name can be located from a debugger or crash dump
/// (TODO confirm against the `debug_ptr` crate).
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86 U: 'static + MeshPayload + Send,
87 F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89 block_on(async {
90 if let Some(r) = node_from_environment().await? {
91 let NodeResult {
92 node_name,
93 node,
94 initial_port,
95 } = r;
96 PROCESS_NAME.store(&node_name);
97 set_program_name(&format!("{base_name}-{node_name}"));
98 let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99 .await
100 .context("failed to receive initial message")?;
101 let _drop = (
102 f(init.init_message).map(Some),
103 handle_host_requests(init.requests).map(|()| None),
104 )
105 .race()
106 .await
107 .transpose()?;
108
109 tracing::debug!("waiting to shut down node");
110 node.shutdown().await;
111 drop(_drop);
112 std::process::exit(0);
113 }
114 Ok(())
115 })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119 while let Some(req) = recv.next().await {
120 match req {
121 HostRequest::Inspect(deferred) => {
122 deferred.respond(inspect_host);
123 }
124 HostRequest::Crash => panic!("explicit panic request"),
125 }
126 }
127}
128
/// Best-effort attempt to set the program name shown for this process.
///
/// On Linux this writes `name` to `/proc/self/comm` and ignores any error.
/// On every other platform it does nothing.
fn set_program_name(name: &str) {
    #[cfg(target_os = "linux")]
    let _ = std::fs::write("/proc/self/comm", name);
    // Keeps `name` used on platforms where there is nothing to do.
    #[cfg(not(target_os = "linux"))]
    let _ = name;
}
136
/// The result of successfully joining a mesh from an environment invitation.
struct NodeResult {
    /// The node name carried in the invitation.
    node_name: String,
    /// The IPC node that joined the parent's mesh.
    node: IpcNode,
    /// The local end of the port pair whose other end was handed to the node
    /// when joining; the initial message is received on it.
    initial_port: mesh::local_node::Port,
}
142
143async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147 let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149 Ok(str) => str,
150 Err(_) => return Ok(None),
151 };
152
153 #[expect(deprecated_safe_2024)]
163 std::env::remove_var(INVITATION_ENV_NAME);
164
165 let invitation: Invitation = mesh::payload::decode(
166 &base64::engine::general_purpose::STANDARD
167 .decode(invitation_str)
168 .context("failed to base64 decode invitation")?,
169 )
170 .context("failed to protobuf decode invitation")?;
171
172 let (left, right) = mesh::local_node::Port::new_pair();
173
174 let node;
175 #[cfg(windows)]
176 {
177 let directory =
181 unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
182
183 let invitation = mesh_remote::windows::AlpcInvitation {
184 address: invitation.address,
185 directory,
186 };
187
188 node = mesh_remote::windows::AlpcNode::join(
190 pal_async::windows::TpPool::system(),
191 invitation,
192 left,
193 )
194 .context("failed to join mesh")?;
195 }
196
197 #[cfg(unix)]
198 {
199 let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
203 let invitation = mesh_remote::unix::UnixInvitation {
204 address: invitation.address,
205 fd,
206 };
207
208 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
210 node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
211 .await
212 .context("failed to join mesh")?;
213 }
214
215 Ok(Some(NodeResult {
216 node_name: invitation.node_name,
217 node,
218 initial_port: right,
219 }))
220}
221
/// Handle to a mesh that can launch child host processes.
///
/// The work is done by a background task; this handle only sends it requests.
pub struct Mesh {
    /// The name given to the mesh at creation; used in tracing output.
    mesh_name: String,
    /// Sends requests to the background mesh task.
    request: mesh::Sender<MeshRequest>,
    /// The background mesh task, awaited by `shutdown`.
    task: Task<()>,
}
247
/// A sandbox configuration that can be applied to a mesh host process.
pub trait SandboxProfile: Send {
    /// Configures `builder` for the sandbox.
    ///
    /// The mesh calls this on the process builder just before the child
    /// process is spawned.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Optional second step; the default implementation does nothing and
    /// returns `Ok(())`.
    ///
    /// NOTE(review): nothing in this file calls `finalize`. It is presumably
    /// meant to be invoked by the launched process itself once its own
    /// initialization is done — confirm against callers.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
264
/// Configuration for launching a mesh host process.
pub struct ProcessConfig {
    /// The host's name. It is sent to the child as the invitation's node name
    /// and, unless `skip_worker_arg` is set, appended as the last argument.
    name: String,
    /// The executable to launch; `None` means the current executable.
    process_name: Option<PathBuf>,
    /// Arguments passed to the process ahead of the optional name argument.
    process_args: Vec<OsString>,
    /// File to use as the child's stderr, if any.
    stderr: Option<File>,
    /// When `true`, the host name is not appended to the argument list.
    skip_worker_arg: bool,
    /// Sandbox profile applied to the process builder before spawning.
    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
}
274
275impl ProcessConfig {
276 pub fn new(name: impl Into<String>) -> Self {
279 Self {
280 name: name.into(),
281 process_name: None,
282 process_args: Vec::new(),
283 stderr: None,
284 skip_worker_arg: false,
285 sandbox_profile: None,
286 }
287 }
288
289 pub fn new_with_sandbox(
292 name: impl Into<String>,
293 sandbox_profile: Box<dyn SandboxProfile + Sync>,
294 ) -> Self {
295 Self {
296 name: name.into(),
297 process_name: None,
298 process_args: Vec::new(),
299 stderr: None,
300 skip_worker_arg: false,
301 sandbox_profile: Some(sandbox_profile),
302 }
303 }
304
305 pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
307 self.process_name = Some(name.into());
308 self
309 }
310
311 pub fn skip_worker_arg(mut self, skip: bool) -> Self {
318 self.skip_worker_arg = skip;
319 self
320 }
321
322 pub fn args<I>(mut self, args: I) -> Self
324 where
325 I: IntoIterator,
326 I::Item: Into<OsString>,
327 {
328 self.process_args.extend(args.into_iter().map(|x| x.into()));
329 self
330 }
331
332 pub fn stderr(mut self, file: Option<File>) -> Self {
334 self.stderr = file;
335 self
336 }
337}
338
/// State owned by the background mesh task.
struct MeshInner {
    /// Incoming requests from `Mesh` handles.
    requests: mesh::Receiver<MeshRequest>,
    /// The hosts that have been launched and have not yet been reported as
    /// exited.
    hosts: Slab<MeshHostInner>,
    /// One receiver per launched host; each yields the host's index in
    /// `hosts` once its wait thread has seen the child exit.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// This process's mesh node, used to create invitations for children.
    node: IpcNode,
    /// The name of the mesh, used in tracing and inspection output.
    mesh_name: String,
    /// Job object that every child is launched into. It is created with
    /// terminate-on-close set.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
353
/// Bookkeeping for one launched host process.
struct MeshHostInner {
    /// The host's configured name.
    name: String,
    /// The process ID of the child.
    pid: i32,
    /// The mesh node ID taken from the invitation issued to the child.
    node_id: mesh::NodeId,
    /// Sends inspect/crash requests to the child.
    send: mesh::Sender<HostRequest>,
}
360
/// Requests sent from a `Mesh` handle to the background mesh task.
enum MeshRequest {
    /// Launch a new host process; the reply is the result of spawning it.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Answer an inspection request for the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Make the process with the given pid panic.
    Crash(i32),
}
366
/// Parameters of a `MeshRequest::NewHost` request.
struct NewHostParams {
    /// How to launch the process.
    config: ProcessConfig,
    /// The port handed to the node when inviting the child; it is the
    /// receiving end of the initial-message channel.
    recv: mesh::local_node::Port,
    /// Sender kept by the mesh for sending `HostRequest`s to the new host.
    request_send: mesh::Sender<HostRequest>,
}
372
373impl Inspect for Mesh {
374 fn inspect(&self, req: inspect::Request<'_>) {
375 self.request.send(MeshRequest::Inspect(req.defer()));
376 }
377}
378
379impl Mesh {
380 pub fn new(mesh_name: String) -> anyhow::Result<Self> {
382 #[cfg(windows)]
383 let job = {
384 let job = pal::windows::job::Job::new().context("failed to create job object")?;
385 job.set_terminate_on_close()
386 .context("failed to set job object terminate on close")?;
387 job
388 };
389
390 #[cfg(windows)]
391 let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
392 .context("AlpcNode creation failure")?;
393 #[cfg(unix)]
394 let node = {
395 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
397 mesh_remote::unix::UnixNode::new(driver)
398 };
399
400 let (request, requests) = mesh::channel();
401 let mut inner = MeshInner {
402 requests,
403 hosts: Default::default(),
404 waiters: Default::default(),
405 node,
406 mesh_name: mesh_name.clone(),
407 #[cfg(windows)]
408 job,
409 };
410
411 let (_, driver) = DefaultPool::spawn_on_thread("mesh");
414 let task = driver.spawn(
415 format!("mesh-{}", &mesh_name),
416 async move { inner.run().await },
417 );
418
419 Ok(Self {
420 request,
421 mesh_name,
422 task,
423 })
424 }
425
426 pub async fn launch_host<T: 'static + MeshField + Send>(
432 &self,
433 config: ProcessConfig,
434 initial_message: T,
435 ) -> anyhow::Result<()> {
436 let (request_send, request_recv) = mesh::channel();
437
438 let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
439 init_send.send(InitialMessage {
440 requests: request_recv,
441 init_message: initial_message,
442 });
443
444 self.request
445 .call(
446 MeshRequest::NewHost,
447 NewHostParams {
448 config,
449 recv: init_recv.into(),
450 request_send,
451 },
452 )
453 .await
454 .context("mesh failed")?
455 }
456
457 pub async fn shutdown(self) {
461 let span = tracing::span!(
462 tracing::Level::INFO,
463 "mesh_shutdown",
464 name = self.mesh_name.as_str(),
465 );
466
467 async {
468 drop(self.request);
469 self.task.await;
470 }
471 .instrument(span)
472 .await;
473 }
474
475 pub fn crash(&self, pid: i32) {
477 self.request.send(MeshRequest::Crash(pid));
478 }
479}
480
/// The first message a newly launched host receives from the mesh.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Channel on which the host receives `HostRequest`s from the mesh.
    requests: mesh::Receiver<HostRequest>,
    /// The caller-supplied message passed to `Mesh::launch_host`.
    init_message: T,
}
486
/// Requests the mesh sends to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Answer an inspection request with the host's own state.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Panic deliberately.
    Crash,
}
493
494fn inspect_host(resp: &mut inspect::Response<'_>) {
495 resp.field("tasks", inspect_task::inspect_task_list());
496}
497
/// Inspection view of a single process in the mesh, built on demand by the
/// mesh task for each child host and for the mesh's own process.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// The host name (or the mesh name for the mesh's own process).
    #[inspect(safe)]
    name: &'a str,
    /// The process's mesh node ID.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource-limit information for the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
508
impl MeshInner {
    /// Runs the mesh task's event loop.
    ///
    /// The loop handles requests from `Mesh` handles and child-exit
    /// notifications, and ends once both sources are exhausted.
    ///
    /// # Panics
    ///
    /// Panics if a wait receiver resolves with an error, or when asked to
    /// crash its own process.
    async fn run(&mut self) {
        enum Event {
            Request(MeshRequest),
            Done(usize),
        }

        loop {
            let event = futures::select! {
                request = self.requests.select_next_some() => Event::Request(request),
                // The payload is the index in `hosts` of the child that exited.
                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
                // Taken once both the request stream and the set of waiters
                // have terminated.
                complete => break,
            };

            match event {
                Event::Request(request) => match request {
                    MeshRequest::NewHost(rpc) => {
                        rpc.handle(async |params| self.spawn_process(params).await)
                            .await
                    }
                    MeshRequest::Inspect(deferred) => {
                        deferred.respond(|resp| {
                            // Each child host appears under `hosts/<pid>`.
                            resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
                                let mut resp = req.respond();
                                for host in self.hosts.iter().map(|(_, host)| host) {
                                    resp.sensitivity_field_mut(
                                        &host.pid.to_string(),
                                        SensitivityLevel::Safe,
                                        &mut inspect::adhoc(|req| {
                                            // Merge what the mesh knows about
                                            // the host with what the host
                                            // reports about itself.
                                            req.respond()
                                                .merge(&HostInspect {
                                                    name: &host.name,
                                                    node_id: host.node_id,
                                                    #[cfg(target_os = "linux")]
                                                    rlimit: inspect_rlimit::InspectRlimit::for_pid(
                                                        host.pid,
                                                    ),
                                                })
                                                .merge(inspect::adhoc(|req| {
                                                    host.send
                                                        .send(HostRequest::Inspect(req.defer()));
                                                }));
                                        }),
                                    );
                                }
                            })
                            // The mesh's own process is reported alongside the
                            // children, keyed by its own pid.
                            .sensitivity_field_mut(
                                &format!("hosts/{}", std::process::id()),
                                SensitivityLevel::Safe,
                                &mut inspect::adhoc(|req| {
                                    let mut resp = req.respond();
                                    resp.merge(&HostInspect {
                                        name: &self.mesh_name,
                                        node_id: self.node.id(),
                                        #[cfg(target_os = "linux")]
                                        rlimit: inspect_rlimit::InspectRlimit::new(),
                                    });
                                    inspect_host(&mut resp);
                                }),
                            );
                        });
                    }
                    MeshRequest::Crash(pid) => {
                        // A crash request for the mesh's own process is
                        // honored right here.
                        if pid == std::process::id() as i32 {
                            panic!("explicit panic request");
                        }

                        // Otherwise forward it to the first host with that pid.
                        let mut found = false;
                        for (_, host) in &self.hosts {
                            if host.pid == pid {
                                host.send.send(HostRequest::Crash);
                                found = true;
                                break;
                            }
                        }

                        if !found {
                            tracing::error!("failed to crash process, pid {pid} not found");
                        }
                    }
                },
                Event::Done(id) => {
                    // The child exited; forget it.
                    self.hosts.remove(id);
                }
            }
        }
    }

    /// Launches a host process and registers it with the mesh.
    ///
    /// An invitation is created for `params.recv`, encoded into the child's
    /// environment, and the child is spawned with stdin and stdout set to
    /// null. The host is then recorded in `hosts`, and a dedicated thread
    /// waits for the child to exit and reports its index through `waiters`.
    ///
    /// # Errors
    ///
    /// Fails if the current executable path cannot be determined, the
    /// invitation cannot be created, or the process cannot be spawned.
    ///
    /// # Panics
    ///
    /// Panics if the wait thread cannot be spawned.
    // The `pid` span field starts out empty and is recorded once the child
    // has been spawned.
    #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
    async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
        let NewHostParams {
            config,
            recv,
            request_send,
        } = params;

        let pid;
        let node_id;

        // With an explicit process name, launch it without overriding arg0.
        // Otherwise relaunch the current executable and pass along this
        // process's own first argument as arg0.
        let (arg0, process_name) = if let Some(n) = &config.process_name {
            (None, Cow::Borrowed(n))
        } else {
            (
                std::env::args_os().next(),
                Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
            )
        };

        let name = config.name.clone();

        #[cfg(windows)]
        let wait = {
            let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
            node_id = invitation.address.local_addr.node;

            // The child reads this back in `node_from_environment`.
            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
                mesh::payload::encode(Invitation {
                    node_name: name.clone(),
                    address: invitation.address,
                    directory_handle: invitation.directory.as_raw_handle() as usize,
                }),
            );

            // The host name goes last, after the configured arguments.
            let mut args = config.process_args;
            if !config.skip_worker_arg {
                args.push(name.clone().into());
            }

            let mut builder = process::Builder::from_args(
                arg0.as_ref()
                    .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
                &args,
            );
            if arg0.is_some() {
                builder.application_name(process_name.as_path());
            }
            builder
                .stdin(process::Stdio::Null)
                .stdout(process::Stdio::Null)
                .handle(&invitation.directory)
                .env(INVITATION_ENV_NAME, invitation_env)
                .job(self.job.as_handle());

            if let Some(log_file) = config.stderr.as_ref() {
                builder.stderr(process::Stdio::Handle(log_file.as_handle()));
            }

            if let Some(mut sandbox_profile) = config.sandbox_profile {
                sandbox_profile.apply(&mut builder);
            }

            let child = builder.spawn().context("failed to launch mesh process")?;
            // NOTE(review): presumably completes once the child has acted on
            // the invitation; there is no timeout here — confirm against
            // `AlpcNode::invite`.
            handle.await;
            pid = child.id() as i32;
            tracing::Span::current().record("pid", pid);
            // Blocking wait, run later on a dedicated thread.
            move || {
                child.wait();
                let code = child.exit_code();
                if code == 0 {
                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
                } else {
                    tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
                }
            }
        };
        #[cfg(unix)]
        let mut wait = {
            use pal::unix::process;

            let invitation = self
                .node
                .invite(recv)
                .await
                .context("mesh node invite error")?;

            node_id = invitation.address.local_addr.node;

            // The child reads this back in `node_from_environment`; the socket
            // is always made available at `IPC_FD`.
            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
                mesh::payload::encode(Invitation {
                    node_name: name.clone(),
                    address: invitation.address,
                    socket_fd: IPC_FD,
                }),
            );

            let mut command = process::Builder::new(process_name.into_owned());
            if let Some(arg0) = arg0 {
                command.arg0(arg0);
            }
            command
                .args(&config.process_args)
                .stdin(process::Stdio::Null)
                .stdout(process::Stdio::Null)
                .dup_fd(invitation.fd.as_fd(), IPC_FD)
                .env(INVITATION_ENV_NAME, invitation_env);

            // The host name goes last, after the configured arguments.
            if !config.skip_worker_arg {
                command.arg(&name);
            }

            if let Some(log_file) = config.stderr.as_ref() {
                command.stderr(process::Stdio::Fd(log_file.as_fd()));
            }

            if let Some(mut sandbox_profile) = config.sandbox_profile {
                sandbox_profile.apply(&mut command);
            }

            let mut child = command.spawn().context("failed to launch mesh process")?;
            pid = child.id();
            tracing::Span::current().record("pid", pid);
            // Blocking wait, run later on a dedicated thread.
            move || {
                let exit_status = child.wait().expect("mesh child wait failure");
                if let Some(0) = exit_status.code() {
                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
                } else {
                    tracing::error!(
                        pid,
                        name = name.as_str(),
                        %exit_status,
                        "mesh child abnormal exit"
                    );
                }
            }
        };

        let (wait_send, wait_recv) = mesh::oneshot();

        let id = self.hosts.insert(MeshHostInner {
            name: config.name,
            pid,
            node_id,
            send: request_send,
        });

        // Wait for the child on its own thread, then report the host's index
        // so `run` can remove it.
        thread::Builder::new()
            .name(format!("wait-mesh-child-{}", pid))
            .spawn(move || {
                wait();
                wait_send.send(id);
            })
            .unwrap();

        self.waiters.push(wait_recv);
        Ok(())
    }
}