1#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// The mesh IPC node implementation used on Windows.
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// The mesh IPC node implementation used on Unix.
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// File descriptor number at which a child host finds its mesh socket. The
/// parent `dup`s the invitation fd onto this number when spawning the child.
#[cfg(unix)]
const IPC_FD: i32 = 3;

/// Environment variable that carries the base64-encoded, protobuf-serialized
/// [`Invitation`] from the parent to a newly launched mesh host process.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";
63
/// Connection details handed from the parent to a child mesh host through the
/// [`INVITATION_ENV_NAME`] environment variable.
///
/// NOTE(review): the wire field numbers are presumably derived from declaration
/// order by `#[derive(Protobuf)]`; do not reorder fields without confirming.
#[derive(Protobuf)]
struct Invitation {
    /// Name of the host; the child records it in `PROCESS_NAME` and uses it in
    /// its program name.
    node_name: String,
    /// Mesh address the child's node joins at.
    address: InvitationAddress,
    /// Raw value of the directory handle the parent marks as inherited by the
    /// child; the child takes ownership of it.
    #[cfg(windows)]
    directory_handle: usize,
    /// Descriptor number of the inherited mesh socket (always [`IPC_FD`]).
    #[cfg(unix)]
    socket_fd: i32,
}
73
/// Name of this host process, stored once the process has joined a mesh as a
/// host.
///
/// NOTE(review): presumably kept in a `DebugPtr` so it can be found from a
/// debugger or crash dump — confirm.
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86 U: 'static + MeshPayload + Send,
87 F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89 block_on(async {
90 if let Some(r) = node_from_environment().await? {
91 let NodeResult {
92 node_name,
93 node,
94 initial_port,
95 } = r;
96 PROCESS_NAME.store(&node_name);
97 set_program_name(&format!("{base_name}-{node_name}"));
98 let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99 .await
100 .context("failed to receive initial message")?;
101 let _drop = (
102 f(init.init_message).map(Some),
103 handle_host_requests(init.requests).map(|()| None),
104 )
105 .race()
106 .await
107 .transpose()?;
108
109 tracing::debug!("waiting to shut down node");
110 node.shutdown().await;
111 drop(_drop);
112 std::process::exit(0);
113 }
114 Ok(())
115 })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119 while let Some(req) = recv.next().await {
120 match req {
121 HostRequest::Inspect(deferred) => {
122 deferred.respond(inspect_host);
123 }
124 HostRequest::Crash => panic!("explicit panic request"),
125 }
126 }
127}
128
/// Best-effort rename of the current process.
///
/// On Linux this writes `name` to `/proc/self/comm`; any failure is ignored.
/// On other platforms it does nothing.
fn set_program_name(name: &str) {
    #[cfg(target_os = "linux")]
    let _ = std::fs::write("/proc/self/comm", name);
    #[cfg(not(target_os = "linux"))]
    let _ = name;
}
136
/// Outcome of joining a mesh from the invitation found in the environment.
struct NodeResult {
    /// Host name carried by the invitation.
    node_name: String,
    /// This process's node in the mesh.
    node: IpcNode,
    /// Local port on which the host's [`InitialMessage`] is received.
    initial_port: mesh::local_node::Port,
}
142
143async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147 let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149 Ok(str) => str,
150 Err(_) => return Ok(None),
151 };
152
153 unsafe {
165 std::env::remove_var(INVITATION_ENV_NAME);
166 }
167
168 let invitation: Invitation = mesh::payload::decode(
169 &base64::engine::general_purpose::STANDARD
170 .decode(invitation_str)
171 .context("failed to base64 decode invitation")?,
172 )
173 .context("failed to protobuf decode invitation")?;
174
175 let (left, right) = mesh::local_node::Port::new_pair();
176
177 let node;
178 #[cfg(windows)]
179 {
180 let directory =
184 unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
185
186 let invitation = mesh_remote::windows::AlpcInvitation {
187 address: invitation.address,
188 directory,
189 };
190
191 node = mesh_remote::windows::AlpcNode::join(
193 pal_async::windows::TpPool::system(),
194 invitation,
195 left,
196 )
197 .context("failed to join mesh")?;
198 }
199
200 #[cfg(unix)]
201 {
202 let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
206 let invitation = mesh_remote::unix::UnixInvitation {
207 address: invitation.address,
208 fd,
209 };
210
211 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
213 node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
214 .await
215 .context("failed to join mesh")?;
216 }
217
218 Ok(Some(NodeResult {
219 node_name: invitation.node_name,
220 node,
221 initial_port: right,
222 }))
223}
224
/// Handle to a mesh of host processes launched by this process.
///
/// Operations are forwarded as [`MeshRequest`]s to a background task that
/// owns the mesh state.
#[derive(Inspect)]
pub struct Mesh {
    /// Name of the mesh, reported as `name` when inspected.
    #[inspect(rename = "name")]
    mesh_name: String,
    /// Channel to the background task. Inspection is forwarded through it as
    /// [`MeshRequest::Inspect`] and flattened into this object's output.
    #[inspect(flatten, send = "MeshRequest::Inspect")]
    request: mesh::Sender<MeshRequest>,
    /// The background task running the mesh's event loop.
    #[inspect(skip)]
    task: Task<()>,
}
254
/// A sandboxing policy for mesh host processes.
pub trait SandboxProfile: Send {
    /// Configures the process builder for a new host before it is spawned.
    ///
    /// Invoked in the parent while launching the host.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Completes sandbox setup. The default implementation does nothing and
    /// returns `Ok(())`.
    ///
    /// NOTE(review): this module never calls `finalize`; presumably it is
    /// invoked elsewhere once host initialization is done — confirm.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
271
/// Configuration for launching a mesh host process.
pub struct ProcessConfig {
    /// Host name. It is used in logs and inspection, sent in the invitation,
    /// and (unless `skip_worker_arg` is set) passed as the final argument.
    name: String,
    /// Executable to launch; when `None`, the current executable is relaunched.
    process_name: Option<PathBuf>,
    /// Additional command-line arguments for the child.
    process_args: Vec<OsString>,
    /// File to use as the child's stderr, if any.
    stderr: Option<File>,
    /// When `true`, the host name is not appended to the arguments.
    skip_worker_arg: bool,
    /// Optional sandbox profile applied to the process builder before spawning.
    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
    /// Additional environment variables requested for the child.
    env_vars: Vec<(OsString, OsString)>,
}
282
283impl ProcessConfig {
284 pub fn new(name: impl Into<String>) -> Self {
287 Self {
288 name: name.into(),
289 process_name: None,
290 process_args: Vec::new(),
291 stderr: None,
292 skip_worker_arg: false,
293 sandbox_profile: None,
294 env_vars: Vec::new(),
295 }
296 }
297
298 pub fn new_with_sandbox(
301 name: impl Into<String>,
302 sandbox_profile: Box<dyn SandboxProfile + Sync>,
303 ) -> Self {
304 Self {
305 name: name.into(),
306 process_name: None,
307 process_args: Vec::new(),
308 stderr: None,
309 skip_worker_arg: false,
310 sandbox_profile: Some(sandbox_profile),
311 env_vars: Vec::new(),
312 }
313 }
314
315 pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
317 self.process_name = Some(name.into());
318 self
319 }
320
321 pub fn skip_worker_arg(mut self, skip: bool) -> Self {
328 self.skip_worker_arg = skip;
329 self
330 }
331
332 pub fn args<I>(mut self, args: I) -> Self
334 where
335 I: IntoIterator,
336 I::Item: Into<OsString>,
337 {
338 self.process_args.extend(args.into_iter().map(|x| x.into()));
339 self
340 }
341
342 pub fn env<I>(mut self, env_vars: I) -> Self
344 where
345 I: IntoIterator,
346 I::Item: Into<(OsString, OsString)>,
347 {
348 self.env_vars.extend(env_vars.into_iter().map(|x| x.into()));
349 self
350 }
351
352 pub fn stderr(mut self, file: Option<File>) -> Self {
354 self.stderr = file;
355 self
356 }
357}
358
/// State owned by a mesh's background task.
struct MeshInner {
    /// Requests arriving from the [`Mesh`] handle.
    requests: mesh::Receiver<MeshRequest>,
    /// Launched host processes, keyed by slab index.
    hosts: Slab<MeshHostInner>,
    /// Exit notifications from the per-child wait threads. Each one yields the
    /// exited child's slab index.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// This process's own node in the mesh.
    node: IpcNode,
    /// Name of the mesh.
    mesh_name: String,
    /// Job object that launched hosts are assigned to. It is created with
    /// `set_terminate_on_close`.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
373
/// Bookkeeping for one launched host process.
struct MeshHostInner {
    /// Host name from its `ProcessConfig`.
    name: String,
    /// OS process ID of the host.
    pid: i32,
    /// Mesh node ID taken from the host's invitation address.
    node_id: mesh::NodeId,
    /// Channel for sending [`HostRequest`]s (inspect, crash) to the host.
    send: mesh::Sender<HostRequest>,
}
380
/// Requests sent from a [`Mesh`] handle to its background task.
enum MeshRequest {
    /// Launch a new host process; replies with the launch result.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Inspect the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Panic the process with this pid: either this process or one of its hosts.
    Crash(i32),
}
386
/// Parameters for launching a new host process.
struct NewHostParams {
    /// How to launch the process.
    config: ProcessConfig,
    /// Port carrying the host's already-sent [`InitialMessage`]. It is handed
    /// to the new node when that node is invited into the mesh.
    recv: mesh::local_node::Port,
    /// Sending half of the host's request channel, retained by the mesh.
    request_send: mesh::Sender<HostRequest>,
}
392
393impl Mesh {
394 pub fn new(mesh_name: String) -> anyhow::Result<Self> {
396 #[cfg(windows)]
397 let job = {
398 let job = pal::windows::job::Job::new().context("failed to create job object")?;
399 job.set_terminate_on_close()
400 .context("failed to set job object terminate on close")?;
401 job
402 };
403
404 #[cfg(windows)]
405 let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
406 .context("AlpcNode creation failure")?;
407 #[cfg(unix)]
408 let node = {
409 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
411 mesh_remote::unix::UnixNode::new(driver)
412 };
413
414 let (request, requests) = mesh::channel();
415 let mut inner = MeshInner {
416 requests,
417 hosts: Default::default(),
418 waiters: Default::default(),
419 node,
420 mesh_name: mesh_name.clone(),
421 #[cfg(windows)]
422 job,
423 };
424
425 let (_, driver) = DefaultPool::spawn_on_thread("mesh");
428 let task = driver.spawn(
429 format!("mesh-{}", &mesh_name),
430 async move { inner.run().await },
431 );
432
433 Ok(Self {
434 request,
435 mesh_name,
436 task,
437 })
438 }
439
440 pub async fn launch_host<T: 'static + MeshField + Send>(
446 &self,
447 config: ProcessConfig,
448 initial_message: T,
449 ) -> anyhow::Result<()> {
450 let (request_send, request_recv) = mesh::channel();
451
452 let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
453 init_send.send(InitialMessage {
454 requests: request_recv,
455 init_message: initial_message,
456 });
457
458 self.request
459 .call(
460 MeshRequest::NewHost,
461 NewHostParams {
462 config,
463 recv: init_recv.into(),
464 request_send,
465 },
466 )
467 .await
468 .context("mesh failed")?
469 }
470
471 pub async fn shutdown(self) {
475 let span = tracing::span!(
476 tracing::Level::INFO,
477 "mesh_shutdown",
478 name = self.mesh_name.as_str(),
479 );
480
481 async {
482 drop(self.request);
483 self.task.await;
484 }
485 .instrument(span)
486 .await;
487 }
488
489 pub fn crash(&self, pid: i32) {
491 self.request.send(MeshRequest::Crash(pid));
492 }
493}
494
/// First message delivered to a newly launched host over its initial port.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Channel on which the host receives [`HostRequest`]s from the mesh.
    requests: mesh::Receiver<HostRequest>,
    /// Caller-provided payload passed to the host's entry point.
    init_message: T,
}
500
/// Requests sent from the mesh to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Inspect the host; answered by the host with `inspect_host`.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Make the host panic deliberately.
    Crash,
}
507
508fn inspect_host(resp: &mut inspect::Response<'_>) {
509 resp.field("tasks", inspect_task::inspect_task_list());
510}
511
/// Per-process data reported under `hosts/<pid>` when the mesh is inspected.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// Host (or mesh) name.
    #[inspect(safe)]
    name: &'a str,
    /// Mesh node ID, reported using its `Debug` representation.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limits of the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
522
523impl MeshInner {
524 async fn run(&mut self) {
525 enum Event {
526 Request(MeshRequest),
527 Done(usize),
528 }
529
530 loop {
531 let event = futures::select! { request = self.requests.select_next_some() => Event::Request(request),
533 n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
534 complete => break,
535 };
536
537 match event {
538 Event::Request(request) => match request {
539 MeshRequest::NewHost(rpc) => {
540 rpc.handle(async |params| self.spawn_process(params).await)
541 .await
542 }
543 MeshRequest::Inspect(deferred) => {
544 deferred.respond(|resp| {
545 resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
546 let mut resp = req.respond();
547 for host in self.hosts.iter().map(|(_, host)| host) {
548 resp.sensitivity_field_mut(
549 &host.pid.to_string(),
550 SensitivityLevel::Safe,
551 &mut inspect::adhoc(|req| {
552 req.respond()
553 .merge(&HostInspect {
554 name: &host.name,
555 node_id: host.node_id,
556 #[cfg(target_os = "linux")]
557 rlimit: inspect_rlimit::InspectRlimit::for_pid(
558 host.pid,
559 ),
560 })
561 .merge(inspect::send(
562 &host.send,
563 HostRequest::Inspect,
564 ));
565 }),
566 );
567 }
568 })
569 .sensitivity_field_mut(
570 &format!("hosts/{}", std::process::id()),
571 SensitivityLevel::Safe,
572 &mut inspect::adhoc(|req| {
573 let mut resp = req.respond();
574 resp.merge(&HostInspect {
575 name: &self.mesh_name,
576 node_id: self.node.id(),
577 #[cfg(target_os = "linux")]
578 rlimit: inspect_rlimit::InspectRlimit::new(),
579 });
580 inspect_host(&mut resp);
581 }),
582 );
583 });
584 }
585 MeshRequest::Crash(pid) => {
586 if pid == std::process::id() as i32 {
587 panic!("explicit panic request");
588 }
589
590 let mut found = false;
591 for (_, host) in &self.hosts {
592 if host.pid == pid {
593 host.send.send(HostRequest::Crash);
594 found = true;
595 break;
596 }
597 }
598
599 if !found {
600 tracing::error!("failed to crash process, pid {pid} not found");
601 }
602 }
603 },
604 Event::Done(id) => {
605 self.hosts.remove(id);
606 }
607 }
608 }
609 }
610
611 #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
613 async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
614 let NewHostParams {
615 config,
616 recv,
617 request_send,
618 } = params;
619
620 let pid;
621 let node_id;
622
623 let (arg0, process_name) = if let Some(n) = &config.process_name {
627 (None, Cow::Borrowed(n))
628 } else {
629 (
630 std::env::args_os().next(),
631 Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
632 )
633 };
634
635 let name = config.name.clone();
636
637 #[cfg(windows)]
638 let wait = {
639 let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
640 node_id = invitation.address.local_addr.node;
641
642 let invitation_env = base64::engine::general_purpose::STANDARD.encode(
643 mesh::payload::encode(Invitation {
644 node_name: name.clone(),
645 address: invitation.address,
646 directory_handle: invitation.directory.as_raw_handle() as usize,
647 }),
648 );
649
650 let mut args = config.process_args;
651 if !config.skip_worker_arg {
652 args.push(name.clone().into());
653 }
654
655 let mut builder = process::Builder::from_args(
656 arg0.as_ref()
657 .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
658 &args,
659 );
660 if arg0.is_some() {
661 builder.application_name(process_name.as_path());
662 }
663 builder
664 .stdin(process::Stdio::Null)
665 .stdout(process::Stdio::Null)
666 .handle(&invitation.directory)
667 .env(INVITATION_ENV_NAME, invitation_env)
668 .extend_env(config.env_vars)
669 .job(self.job.as_handle());
670
671 if let Some(log_file) = config.stderr.as_ref() {
672 builder.stderr(process::Stdio::Handle(log_file.as_handle()));
673 }
674
675 if let Some(mut sandbox_profile) = config.sandbox_profile {
676 sandbox_profile.apply(&mut builder);
677 }
678
679 let child = builder.spawn().context("failed to launch mesh process")?;
680 handle.await;
682 pid = child.id() as i32;
683 tracing::Span::current().record("pid", pid);
684 move || {
685 child.wait();
686 let code = child.exit_code();
687 if code == 0 {
688 tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
689 } else {
690 tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
691 }
692 }
693 };
694 #[cfg(unix)]
695 let mut wait = {
696 use pal::unix::process;
697
698 let invitation = self
699 .node
700 .invite(recv)
701 .await
702 .context("mesh node invite error")?;
703
704 node_id = invitation.address.local_addr.node;
705
706 let invitation_env = base64::engine::general_purpose::STANDARD.encode(
707 mesh::payload::encode(Invitation {
708 node_name: name.clone(),
709 address: invitation.address,
710 socket_fd: IPC_FD,
711 }),
712 );
713
714 let mut command = process::Builder::new(process_name.into_owned());
715 if let Some(arg0) = arg0 {
716 command.arg0(arg0);
717 }
718 command
719 .args(&config.process_args)
720 .stdin(process::Stdio::Null)
721 .stdout(process::Stdio::Null)
722 .dup_fd(invitation.fd.as_fd(), IPC_FD)
723 .env(INVITATION_ENV_NAME, invitation_env);
724
725 if !config.skip_worker_arg {
726 command.arg(&name);
727 }
728
729 if let Some(log_file) = config.stderr.as_ref() {
730 command.stderr(process::Stdio::Fd(log_file.as_fd()));
731 }
732
733 if let Some(mut sandbox_profile) = config.sandbox_profile {
734 sandbox_profile.apply(&mut command);
735 }
736
737 let mut child = command.spawn().context("failed to launch mesh process")?;
738 pid = child.id();
739 tracing::Span::current().record("pid", pid);
740 move || {
741 let exit_status = child.wait().expect("mesh child wait failure");
742 if let Some(0) = exit_status.code() {
743 tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
744 } else {
745 tracing::error!(
746 pid,
747 name = name.as_str(),
748 %exit_status,
749 "mesh child abnormal exit"
750 );
751 }
752 }
753 };
754
755 let (wait_send, wait_recv) = mesh::oneshot();
756
757 let id = self.hosts.insert(MeshHostInner {
758 name: config.name,
759 pid,
760 node_id,
761 send: request_send,
762 });
763
764 thread::Builder::new()
765 .name(format!("wait-mesh-child-{}", pid))
766 .spawn(move || {
767 wait();
768 wait_send.send(id);
769 })
770 .unwrap();
771
772 self.waiters.push(wait_recv);
773 Ok(())
774 }
775}