1#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25#[cfg(unix)]
26use mesh_remote::InvitationAddress;
27#[cfg(unix)]
28use pal::unix::process::Builder as ProcessBuilder;
29#[cfg(windows)]
30use pal::windows::process;
31#[cfg(windows)]
32use pal::windows::process::Builder as ProcessBuilder;
33use pal_async::DefaultPool;
34use pal_async::task::Spawn;
35use pal_async::task::Task;
36use slab::Slab;
37use std::borrow::Cow;
38use std::ffi::OsString;
39use std::fs::File;
40#[cfg(unix)]
41use std::os::unix::prelude::*;
42#[cfg(windows)]
43use std::os::windows::prelude::*;
44use std::path::PathBuf;
45use std::thread;
46use tracing::Instrument;
47use tracing::instrument;
48use unicycle::FuturesUnordered;
49
50#[cfg(windows)]
51type IpcNode = mesh_remote::windows::AlpcNode;
52
53#[cfg(unix)]
54type IpcNode = mesh_remote::unix::UnixNode;
55
/// Environment variable through which a parent passes the (protobuf-encoded,
/// then base64-encoded) mesh invitation to a newly launched host process.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";

/// File descriptor number at which a launched host process receives its
/// invitation socket.
#[cfg(unix)]
const IPC_FD: i32 = 3;
64
/// Parameters a parent passes to a newly launched host process.
///
/// `MeshInner::spawn_process` protobuf-encodes this, base64-encodes the
/// result, and stores it in the [`INVITATION_ENV_NAME`] environment variable.
/// `node_from_environment` reverses both steps in the child.
///
/// NOTE(review): the protobuf field numbering presumably follows declaration
/// order, so do not reorder or insert fields without confirming wire
/// compatibility between parent and child binaries.
#[derive(Protobuf)]
struct Invitation {
    /// Name assigned to the host by the parent (its `ProcessConfig` name).
    node_name: String,
    /// ALPC invitation credentials, split out of the `AlpcInvitation`.
    #[cfg(windows)]
    credentials: mesh_remote::windows::AlpcInvitationCredentials,
    /// Mesh address from the unix invitation.
    #[cfg(unix)]
    address: InvitationAddress,
    /// Raw value of the ALPC directory handle that the parent passed to the
    /// child process.
    #[cfg(windows)]
    directory_handle: usize,
    /// File descriptor number of the invitation socket in the child. The
    /// parent always sets this to `IPC_FD`.
    #[cfg(unix)]
    socket_fd: i32,
}
77
/// Node name of this host process, stored by `try_run_mesh_host` once the
/// process has joined a mesh.
///
/// NOTE(review): presumably kept in a `DebugPtr` so it can be located from a
/// debugger or crash dump. Confirm against the `debug_ptr` crate.
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
79
80pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
89where
90 U: 'static + MeshPayload + Send,
91 F: AsyncFnOnce(U) -> anyhow::Result<T>,
92{
93 block_on(async {
94 if let Some(r) = node_from_environment().await? {
95 let NodeResult {
96 node_name,
97 node,
98 initial_port,
99 } = r;
100 PROCESS_NAME.store(&node_name);
101 set_program_name(&format!("{base_name}-{node_name}"));
102 let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
103 .await
104 .context("failed to receive initial message")?;
105 let _drop = (
106 f(init.init_message).map(Some),
107 handle_host_requests(init.requests).map(|()| None),
108 )
109 .race()
110 .await
111 .transpose()?;
112
113 tracing::debug!("waiting to shut down node");
114 node.shutdown().await;
115 drop(_drop);
116 std::process::exit(0);
117 }
118 Ok(())
119 })
120}
121
122async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
123 while let Some(req) = recv.next().await {
124 match req {
125 HostRequest::Inspect(deferred) => {
126 deferred.respond(inspect_host);
127 }
128 HostRequest::Crash => panic!("explicit panic request"),
129 }
130 }
131}
132
/// Best-effort rename of the current process to `name`.
///
/// On Linux this writes `name` to `/proc/self/comm` and ignores any failure.
/// On every other platform it does nothing.
fn set_program_name(name: &str) {
    if cfg!(target_os = "linux") {
        // Renaming is cosmetic, so an error is deliberately discarded.
        let _ = std::fs::write("/proc/self/comm", name);
    }
}
140
/// Outcome of joining the mesh from the invitation in this process's
/// environment (see `node_from_environment`).
struct NodeResult {
    /// Name the parent assigned to this host.
    node_name: String,
    /// This process's IPC node, joined to the parent's mesh.
    node: IpcNode,
    /// Local port on which this host's `InitialMessage` is received.
    initial_port: mesh::local_node::Port,
}
146
/// Joins the mesh that this process was invited to, if any.
///
/// The invitation is read from the [`INVITATION_ENV_NAME`] environment
/// variable. If the variable cannot be read (unset, or not valid Unicode),
/// this returns `Ok(None)`: the process was not launched as a mesh host.
///
/// Otherwise the function takes these steps:
/// - Removes the variable from the environment.
/// - Base64-decodes, then protobuf-decodes the invitation.
/// - Joins the parent's mesh with a new IPC node.
///
/// The returned [`NodeResult`] carries the node, its name, and the local port
/// on which the host's initial message arrives.
///
/// # Errors
///
/// Fails if the invitation cannot be decoded or the node cannot join the mesh.
async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
    let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
        Ok(str) => str,
        // Any error (not present / not Unicode) means "not a mesh host".
        Err(_) => return Ok(None),
    };

    // SAFETY: `remove_var` is unsafe because modifying the environment can race
    // with other threads that access it. This call happens before this
    // function spawns its own worker thread (unix branch below). It assumes the
    // caller has no other threads concurrently reading or writing the
    // environment. NOTE(review): confirm this for every caller of
    // `try_run_mesh_host`.
    //
    // Removing the variable also means a second call returns `Ok(None)`, so the
    // raw handle/fd below can never be claimed twice. It is presumably also
    // removed so that processes launched later by this host do not inherit
    // the invitation.
    unsafe {
        std::env::remove_var(INVITATION_ENV_NAME);
    }

    let invitation: Invitation = mesh::payload::decode(
        &base64::engine::general_purpose::STANDARD
            .decode(invitation_str)
            .context("failed to base64 decode invitation")?,
    )
    .context("failed to protobuf decode invitation")?;

    // `left` is handed to the node when joining. `right` is returned to the
    // caller as the port for the initial message.
    let (left, right) = mesh::local_node::Port::new_pair();

    // Initialized by exactly one of the platform-specific blocks below.
    let node;
    #[cfg(windows)]
    {
        // SAFETY: the parent passed this process the ALPC directory handle
        // (`builder.handle(&directory)` in `MeshInner::spawn_process`) and
        // encoded its raw value in the invitation. Ownership is taken exactly
        // once, because the invitation variable was removed above.
        // NOTE(review): this trusts the parent to send a valid handle value.
        let directory =
            unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };

        let invitation =
            mesh_remote::windows::AlpcInvitation::new(invitation.credentials, directory);

        node = mesh_remote::windows::AlpcNode::join(
            pal_async::windows::TpPool::system(),
            invitation,
            left,
        )
        .context("failed to join mesh")?;
    }

    #[cfg(unix)]
    {
        // SAFETY: the parent dup'd the invitation socket onto `IPC_FD` in this
        // process (`dup_fd(..., IPC_FD)` in `MeshInner::spawn_process`) and
        // sent that number as `socket_fd`. Ownership is taken exactly once,
        // because the invitation variable was removed above.
        // NOTE(review): this trusts the parent to send a valid fd number.
        let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
        let invitation = mesh_remote::unix::UnixInvitation {
            address: invitation.address,
            fd,
        };

        // Dedicated thread providing the driver the node runs on.
        let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
        node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
            .await
            .context("failed to join mesh")?;
    }

    Ok(Some(NodeResult {
        node_name: invitation.node_name,
        node,
        initial_port: right,
    }))
}
226
/// A handle to a mesh of host processes launched by this process.
///
/// Operations are sent as [`MeshRequest`]s to a background task (running
/// `MeshInner::run`) that owns the IPC node and the set of launched hosts.
/// Inspecting a `Mesh` reports its name and forwards the rest of the
/// inspection to that task via `MeshRequest::Inspect`.
#[derive(Inspect)]
pub struct Mesh {
    /// Name of the mesh, reported as `name` when inspected.
    #[inspect(rename = "name")]
    mesh_name: String,
    /// Sender for requests to the background task. Dropping it is how
    /// `shutdown` asks the task to stop.
    #[inspect(flatten, send = "MeshRequest::Inspect")]
    request: mesh::Sender<MeshRequest>,
    /// The background task itself.
    #[inspect(skip)]
    task: Task<()>,
}
256
/// A sandbox configuration for a host process.
pub trait SandboxProfile: Send {
    /// Configures `builder` for the child process.
    ///
    /// Called by the mesh in the launching process, after the rest of the
    /// builder has been configured and just before the child is spawned.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Optional finishing step. The default implementation does nothing and
    /// returns `Ok(())`.
    ///
    /// NOTE(review): nothing visible in this file calls `finalize`. Presumably
    /// a host process calls it itself after its own initialization; confirm
    /// with the implementors.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
273
/// Configuration for launching a host process with `Mesh::launch_host`.
pub struct ProcessConfig {
    /// Host name: sent to the child as its node name, and appended as the
    /// final command-line argument unless `skip_worker_arg` is set.
    name: String,
    /// Executable to launch. `None` re-launches the current executable with
    /// this process's `argv[0]`.
    process_name: Option<PathBuf>,
    /// Command-line arguments, placed before the optional host-name argument.
    process_args: Vec<OsString>,
    /// If set, the child's stderr is redirected to this file.
    stderr: Option<File>,
    /// When `true`, the host name is not appended to the arguments.
    skip_worker_arg: bool,
    /// Sandbox profile applied to the process builder before spawning.
    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
    /// Extra environment variables for the child, added by
    /// `ProcessConfig::env`.
    env_vars: Vec<(OsString, OsString)>,
}
284
285impl ProcessConfig {
286 pub fn new(name: impl Into<String>) -> Self {
289 Self {
290 name: name.into(),
291 process_name: None,
292 process_args: Vec::new(),
293 stderr: None,
294 skip_worker_arg: false,
295 sandbox_profile: None,
296 env_vars: Vec::new(),
297 }
298 }
299
300 pub fn new_with_sandbox(
303 name: impl Into<String>,
304 sandbox_profile: Box<dyn SandboxProfile + Sync>,
305 ) -> Self {
306 Self {
307 name: name.into(),
308 process_name: None,
309 process_args: Vec::new(),
310 stderr: None,
311 skip_worker_arg: false,
312 sandbox_profile: Some(sandbox_profile),
313 env_vars: Vec::new(),
314 }
315 }
316
317 pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
319 self.process_name = Some(name.into());
320 self
321 }
322
323 pub fn skip_worker_arg(mut self, skip: bool) -> Self {
330 self.skip_worker_arg = skip;
331 self
332 }
333
334 pub fn args<I>(mut self, args: I) -> Self
336 where
337 I: IntoIterator,
338 I::Item: Into<OsString>,
339 {
340 self.process_args.extend(args.into_iter().map(|x| x.into()));
341 self
342 }
343
344 pub fn env<I>(mut self, env_vars: I) -> Self
346 where
347 I: IntoIterator,
348 I::Item: Into<(OsString, OsString)>,
349 {
350 self.env_vars.extend(env_vars.into_iter().map(|x| x.into()));
351 self
352 }
353
354 pub fn stderr(mut self, file: Option<File>) -> Self {
356 self.stderr = file;
357 self
358 }
359}
360
/// State owned by the mesh's background task (see `MeshInner::run`).
struct MeshInner {
    /// Requests from the `Mesh` handle. This stream ends once the handle's
    /// sender is dropped.
    requests: mesh::Receiver<MeshRequest>,
    /// Launched hosts that have not yet been reaped, keyed by slab index.
    hosts: Slab<MeshHostInner>,
    /// One receiver per launched host. Each yields the host's slab index once
    /// that host's wait thread observes the child's exit.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// This process's IPC node, used to invite new hosts.
    node: IpcNode,
    /// Name of the mesh, used in tracing and inspection.
    mesh_name: String,
    /// Job object that every launched host is placed in. It is created with
    /// terminate-on-close in `Mesh::new`.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
375
/// Bookkeeping for one launched host process.
struct MeshHostInner {
    /// Host name from its `ProcessConfig`.
    name: String,
    /// OS process id of the host.
    pid: i32,
    /// Mesh node id taken from the invitation created for the host.
    node_id: mesh::NodeId,
    /// Channel for sending inspect/crash requests to the host.
    send: mesh::Sender<HostRequest>,
}
382
/// Requests from a `Mesh` handle to its background task.
enum MeshRequest {
    /// Launch a new host process. Replies with the new process's pid.
    NewHost(Rpc<NewHostParams, anyhow::Result<i32>>),
    /// Inspect the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Crash the host with this pid. If the pid is the current process, the
    /// mesh task panics instead.
    Crash(i32),
}
388
/// Parameters for launching a host, built by `Mesh::launch_host`.
struct NewHostParams {
    /// How to launch the process.
    config: ProcessConfig,
    /// Port on which the host's `InitialMessage` is already queued. It is
    /// passed to `node.invite` so the child's node receives it.
    recv: mesh::local_node::Port,
    /// Sender half of the host's request channel, kept by the mesh.
    request_send: mesh::Sender<HostRequest>,
}
394
395impl Mesh {
396 pub fn new(mesh_name: String) -> anyhow::Result<Self> {
398 #[cfg(windows)]
399 let job = {
400 let job = pal::windows::job::Job::new().context("failed to create job object")?;
401 job.set_terminate_on_close()
402 .context("failed to set job object terminate on close")?;
403 job
404 };
405
406 #[cfg(windows)]
407 let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
408 .context("AlpcNode creation failure")?;
409 #[cfg(unix)]
410 let node = {
411 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
413 mesh_remote::unix::UnixNode::new(driver)
414 };
415
416 let (request, requests) = mesh::channel();
417 let mut inner = MeshInner {
418 requests,
419 hosts: Default::default(),
420 waiters: Default::default(),
421 node,
422 mesh_name: mesh_name.clone(),
423 #[cfg(windows)]
424 job,
425 };
426
427 let (_, driver) = DefaultPool::spawn_on_thread("mesh");
430 let task = driver.spawn(
431 format!("mesh-{}", &mesh_name),
432 async move { inner.run().await },
433 );
434
435 Ok(Self {
436 request,
437 mesh_name,
438 task,
439 })
440 }
441
442 pub async fn launch_host<T: 'static + MeshField + Send>(
450 &self,
451 config: ProcessConfig,
452 initial_message: T,
453 ) -> anyhow::Result<i32> {
454 let (request_send, request_recv) = mesh::channel();
455
456 let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
457 init_send.send(InitialMessage {
458 requests: request_recv,
459 init_message: initial_message,
460 });
461
462 self.request
463 .call(
464 MeshRequest::NewHost,
465 NewHostParams {
466 config,
467 recv: init_recv.into(),
468 request_send,
469 },
470 )
471 .await
472 .context("mesh failed")?
473 }
474
475 pub async fn shutdown(self) {
479 let span = tracing::span!(
480 tracing::Level::INFO,
481 "mesh_shutdown",
482 name = self.mesh_name.as_str(),
483 );
484
485 async {
486 drop(self.request);
487 self.task.await;
488 }
489 .instrument(span)
490 .await;
491 }
492
493 pub fn crash(&self, pid: i32) {
495 self.request.send(MeshRequest::Crash(pid));
496 }
497}
498
/// The first message delivered to a newly launched host.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Requests from the mesh to this host, served by `handle_host_requests`.
    requests: mesh::Receiver<HostRequest>,
    /// Caller-supplied payload passed to the host's entry function.
    init_message: T,
}
504
/// Requests sent from the mesh to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Respond with the host's inspection data (see `inspect_host`).
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Panic the host process.
    Crash,
}
511
512fn inspect_host(resp: &mut inspect::Response<'_>) {
513 resp.field("tasks", inspect_task::inspect_task_list());
514}
515
/// Inspection fields reported for each host, and for the mesh's own process.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// Host name (or the mesh name, for the mesh's own process).
    #[inspect(safe)]
    name: &'a str,
    /// Mesh node id, reported in its `Debug` form.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limits of the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
526
527impl MeshInner {
528 async fn run(&mut self) {
529 enum Event {
530 Request(MeshRequest),
531 Done(usize),
532 }
533
534 loop {
535 let event = futures::select! { request = self.requests.select_next_some() => Event::Request(request),
537 n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
538 complete => break,
539 };
540
541 match event {
542 Event::Request(request) => match request {
543 MeshRequest::NewHost(rpc) => {
544 rpc.handle(async |params| self.spawn_process(params).await)
545 .await
546 }
547 MeshRequest::Inspect(deferred) => {
548 deferred.respond(|resp| {
549 resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
550 let mut resp = req.respond();
551 for host in self.hosts.iter().map(|(_, host)| host) {
552 resp.sensitivity_field_mut(
553 &host.pid.to_string(),
554 SensitivityLevel::Safe,
555 &mut inspect::adhoc(|req| {
556 req.respond()
557 .merge(&HostInspect {
558 name: &host.name,
559 node_id: host.node_id,
560 #[cfg(target_os = "linux")]
561 rlimit: inspect_rlimit::InspectRlimit::for_pid(
562 host.pid,
563 ),
564 })
565 .merge(inspect::send(
566 &host.send,
567 HostRequest::Inspect,
568 ));
569 }),
570 );
571 }
572 })
573 .sensitivity_field_mut(
574 &format!("hosts/{}", std::process::id()),
575 SensitivityLevel::Safe,
576 &mut inspect::adhoc(|req| {
577 let mut resp = req.respond();
578 resp.merge(&HostInspect {
579 name: &self.mesh_name,
580 node_id: self.node.id(),
581 #[cfg(target_os = "linux")]
582 rlimit: inspect_rlimit::InspectRlimit::new(),
583 });
584 inspect_host(&mut resp);
585 }),
586 );
587 });
588 }
589 MeshRequest::Crash(pid) => {
590 if pid == std::process::id() as i32 {
591 panic!("explicit panic request");
592 }
593
594 let mut found = false;
595 for (_, host) in &self.hosts {
596 if host.pid == pid {
597 host.send.send(HostRequest::Crash);
598 found = true;
599 break;
600 }
601 }
602
603 if !found {
604 tracing::error!("failed to crash process, pid {pid} not found");
605 }
606 }
607 },
608 Event::Done(id) => {
609 self.hosts.remove(id);
610 }
611 }
612 }
613 }
614
615 #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
617 async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<i32> {
618 let NewHostParams {
619 config,
620 recv,
621 request_send,
622 } = params;
623
624 let pid;
625 let node_id;
626
627 let (arg0, process_name) = if let Some(n) = &config.process_name {
631 (None, Cow::Borrowed(n))
632 } else {
633 (
634 std::env::args_os().next(),
635 Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
636 )
637 };
638
639 let name = config.name.clone();
640
641 #[cfg(windows)]
642 let wait = {
643 let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
644 node_id = invitation.node_id();
645 let (credentials, directory) = invitation.into_parts();
646
647 let invitation_env = base64::engine::general_purpose::STANDARD.encode(
648 mesh::payload::encode(Invitation {
649 node_name: name.clone(),
650 credentials,
651 directory_handle: directory.as_raw_handle() as usize,
652 }),
653 );
654
655 let mut args = config.process_args;
656 if !config.skip_worker_arg {
657 args.push(name.clone().into());
658 }
659
660 let mut builder = process::Builder::from_args(
661 arg0.as_ref()
662 .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
663 &args,
664 );
665 if arg0.is_some() {
666 builder.application_name(process_name.as_path());
667 }
668 builder
669 .stdin(process::Stdio::Null)
670 .stdout(process::Stdio::Null)
671 .handle(&directory)
672 .env(INVITATION_ENV_NAME, invitation_env)
673 .extend_env(config.env_vars)
674 .job(self.job.as_handle());
675
676 if let Some(log_file) = config.stderr.as_ref() {
677 builder.stderr(process::Stdio::Handle(log_file.as_handle()));
678 }
679
680 if let Some(mut sandbox_profile) = config.sandbox_profile {
681 sandbox_profile.apply(&mut builder);
682 }
683
684 let child = builder.spawn().context("failed to launch mesh process")?;
685 handle.await;
687 pid = child.id() as i32;
688 tracing::Span::current().record("pid", pid);
689 move || {
690 child.wait();
691 let code = child.exit_code();
692 if code == 0 {
693 tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
694 } else {
695 tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
696 }
697 }
698 };
699 #[cfg(unix)]
700 let mut wait = {
701 use pal::unix::process;
702
703 let invitation = self
704 .node
705 .invite(recv)
706 .await
707 .context("mesh node invite error")?;
708
709 node_id = invitation.address.local_addr.node;
710
711 let invitation_env = base64::engine::general_purpose::STANDARD.encode(
712 mesh::payload::encode(Invitation {
713 node_name: name.clone(),
714 address: invitation.address,
715 socket_fd: IPC_FD,
716 }),
717 );
718
719 let mut command = process::Builder::new(process_name.into_owned());
720 if let Some(arg0) = arg0 {
721 command.arg0(arg0);
722 }
723 command
724 .args(&config.process_args)
725 .stdin(process::Stdio::Null)
726 .stdout(process::Stdio::Null)
727 .dup_fd(invitation.fd.as_fd(), IPC_FD)
728 .env(INVITATION_ENV_NAME, invitation_env);
729
730 if !config.skip_worker_arg {
731 command.arg(&name);
732 }
733
734 if let Some(log_file) = config.stderr.as_ref() {
735 command.stderr(process::Stdio::Fd(log_file.as_fd()));
736 }
737
738 if let Some(mut sandbox_profile) = config.sandbox_profile {
739 sandbox_profile.apply(&mut command);
740 }
741
742 let mut child = command.spawn().context("failed to launch mesh process")?;
743 pid = child.id();
744 tracing::Span::current().record("pid", pid);
745 move || {
746 let exit_status = child.wait().expect("mesh child wait failure");
747 if let Some(0) = exit_status.code() {
748 tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
749 } else {
750 tracing::error!(
751 pid,
752 name = name.as_str(),
753 %exit_status,
754 "mesh child abnormal exit"
755 );
756 }
757 }
758 };
759
760 let (wait_send, wait_recv) = mesh::oneshot();
761
762 let id = self.hosts.insert(MeshHostInner {
763 name: config.name,
764 pid,
765 node_id,
766 send: request_send,
767 });
768
769 thread::Builder::new()
770 .name(format!("wait-mesh-child-{}", pid))
771 .spawn(move || {
772 wait();
773 wait_send.send(id);
774 })
775 .unwrap();
776
777 self.waiters.push(wait_recv);
778 Ok(pid)
779 }
780}