1#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
49#[cfg(windows)]
50type IpcNode = mesh_remote::windows::AlpcNode;
51
52#[cfg(unix)]
53type IpcNode = mesh_remote::unix::UnixNode;
54
55#[cfg(unix)]
56const IPC_FD: i32 = 3;
57
58const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";
63
64#[derive(Protobuf)]
65struct Invitation {
66 node_name: String,
67 address: InvitationAddress,
68 #[cfg(windows)]
69 directory_handle: usize,
70 #[cfg(unix)]
71 socket_fd: i32,
72}
73
74static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86 U: 'static + MeshPayload + Send,
87 F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89 block_on(async {
90 if let Some(r) = node_from_environment().await? {
91 let NodeResult {
92 node_name,
93 node,
94 initial_port,
95 } = r;
96 PROCESS_NAME.store(&node_name);
97 set_program_name(&format!("{base_name}-{node_name}"));
98 let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99 .await
100 .context("failed to receive initial message")?;
101 let _drop = (
102 f(init.init_message).map(Some),
103 handle_host_requests(init.requests).map(|()| None),
104 )
105 .race()
106 .await
107 .transpose()?;
108
109 tracing::debug!("waiting to shut down node");
110 node.shutdown().await;
111 drop(_drop);
112 std::process::exit(0);
113 }
114 Ok(())
115 })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119 while let Some(req) = recv.next().await {
120 match req {
121 HostRequest::Inspect(deferred) => {
122 deferred.respond(inspect_host);
123 }
124 HostRequest::Crash => panic!("explicit panic request"),
125 }
126 }
127}
128
129fn set_program_name(name: &str) {
130 let _ = name;
131 #[cfg(target_os = "linux")]
132 {
133 let _ = std::fs::write("/proc/self/comm", name);
134 }
135}
136
137struct NodeResult {
138 node_name: String,
139 node: IpcNode,
140 initial_port: mesh::local_node::Port,
141}
142
143async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147 let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149 Ok(str) => str,
150 Err(_) => return Ok(None),
151 };
152
153 unsafe {
165 std::env::remove_var(INVITATION_ENV_NAME);
166 }
167
168 let invitation: Invitation = mesh::payload::decode(
169 &base64::engine::general_purpose::STANDARD
170 .decode(invitation_str)
171 .context("failed to base64 decode invitation")?,
172 )
173 .context("failed to protobuf decode invitation")?;
174
175 let (left, right) = mesh::local_node::Port::new_pair();
176
177 let node;
178 #[cfg(windows)]
179 {
180 let directory =
184 unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
185
186 let invitation = mesh_remote::windows::AlpcInvitation {
187 address: invitation.address,
188 directory,
189 };
190
191 node = mesh_remote::windows::AlpcNode::join(
193 pal_async::windows::TpPool::system(),
194 invitation,
195 left,
196 )
197 .context("failed to join mesh")?;
198 }
199
200 #[cfg(unix)]
201 {
202 let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
206 let invitation = mesh_remote::unix::UnixInvitation {
207 address: invitation.address,
208 fd,
209 };
210
211 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
213 node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
214 .await
215 .context("failed to join mesh")?;
216 }
217
218 Ok(Some(NodeResult {
219 node_name: invitation.node_name,
220 node,
221 initial_port: right,
222 }))
223}
224
225pub struct Mesh {
246 mesh_name: String,
247 request: mesh::Sender<MeshRequest>,
248 task: Task<()>,
249}
250
251pub trait SandboxProfile: Send {
253 fn apply(&mut self, builder: &mut ProcessBuilder<'_>);
257
258 fn finalize(&mut self) -> anyhow::Result<()> {
264 Ok(())
265 }
266}
267
268pub struct ProcessConfig {
270 name: String,
271 process_name: Option<PathBuf>,
272 process_args: Vec<OsString>,
273 stderr: Option<File>,
274 skip_worker_arg: bool,
275 sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
276}
277
278impl ProcessConfig {
279 pub fn new(name: impl Into<String>) -> Self {
282 Self {
283 name: name.into(),
284 process_name: None,
285 process_args: Vec::new(),
286 stderr: None,
287 skip_worker_arg: false,
288 sandbox_profile: None,
289 }
290 }
291
292 pub fn new_with_sandbox(
295 name: impl Into<String>,
296 sandbox_profile: Box<dyn SandboxProfile + Sync>,
297 ) -> Self {
298 Self {
299 name: name.into(),
300 process_name: None,
301 process_args: Vec::new(),
302 stderr: None,
303 skip_worker_arg: false,
304 sandbox_profile: Some(sandbox_profile),
305 }
306 }
307
308 pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
310 self.process_name = Some(name.into());
311 self
312 }
313
314 pub fn skip_worker_arg(mut self, skip: bool) -> Self {
321 self.skip_worker_arg = skip;
322 self
323 }
324
325 pub fn args<I>(mut self, args: I) -> Self
327 where
328 I: IntoIterator,
329 I::Item: Into<OsString>,
330 {
331 self.process_args.extend(args.into_iter().map(|x| x.into()));
332 self
333 }
334
335 pub fn stderr(mut self, file: Option<File>) -> Self {
337 self.stderr = file;
338 self
339 }
340}
341
342struct MeshInner {
343 requests: mesh::Receiver<MeshRequest>,
344 hosts: Slab<MeshHostInner>,
345 waiters: FuturesUnordered<OneshotReceiver<usize>>,
347 node: IpcNode,
349 mesh_name: String,
351 #[cfg(windows)]
354 job: pal::windows::job::Job,
355}
356
357struct MeshHostInner {
358 name: String,
359 pid: i32,
360 node_id: mesh::NodeId,
361 send: mesh::Sender<HostRequest>,
362}
363
364enum MeshRequest {
365 NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
366 Inspect(inspect::Deferred),
367 Crash(i32),
368}
369
370struct NewHostParams {
371 config: ProcessConfig,
372 recv: mesh::local_node::Port,
373 request_send: mesh::Sender<HostRequest>,
374}
375
376impl Inspect for Mesh {
377 fn inspect(&self, req: inspect::Request<'_>) {
378 self.request.send(MeshRequest::Inspect(req.defer()));
379 }
380}
381
382impl Mesh {
383 pub fn new(mesh_name: String) -> anyhow::Result<Self> {
385 #[cfg(windows)]
386 let job = {
387 let job = pal::windows::job::Job::new().context("failed to create job object")?;
388 job.set_terminate_on_close()
389 .context("failed to set job object terminate on close")?;
390 job
391 };
392
393 #[cfg(windows)]
394 let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
395 .context("AlpcNode creation failure")?;
396 #[cfg(unix)]
397 let node = {
398 let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
400 mesh_remote::unix::UnixNode::new(driver)
401 };
402
403 let (request, requests) = mesh::channel();
404 let mut inner = MeshInner {
405 requests,
406 hosts: Default::default(),
407 waiters: Default::default(),
408 node,
409 mesh_name: mesh_name.clone(),
410 #[cfg(windows)]
411 job,
412 };
413
414 let (_, driver) = DefaultPool::spawn_on_thread("mesh");
417 let task = driver.spawn(
418 format!("mesh-{}", &mesh_name),
419 async move { inner.run().await },
420 );
421
422 Ok(Self {
423 request,
424 mesh_name,
425 task,
426 })
427 }
428
429 pub async fn launch_host<T: 'static + MeshField + Send>(
435 &self,
436 config: ProcessConfig,
437 initial_message: T,
438 ) -> anyhow::Result<()> {
439 let (request_send, request_recv) = mesh::channel();
440
441 let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
442 init_send.send(InitialMessage {
443 requests: request_recv,
444 init_message: initial_message,
445 });
446
447 self.request
448 .call(
449 MeshRequest::NewHost,
450 NewHostParams {
451 config,
452 recv: init_recv.into(),
453 request_send,
454 },
455 )
456 .await
457 .context("mesh failed")?
458 }
459
460 pub async fn shutdown(self) {
464 let span = tracing::span!(
465 tracing::Level::INFO,
466 "mesh_shutdown",
467 name = self.mesh_name.as_str(),
468 );
469
470 async {
471 drop(self.request);
472 self.task.await;
473 }
474 .instrument(span)
475 .await;
476 }
477
478 pub fn crash(&self, pid: i32) {
480 self.request.send(MeshRequest::Crash(pid));
481 }
482}
483
484#[derive(MeshPayload)]
485struct InitialMessage<T> {
486 requests: mesh::Receiver<HostRequest>,
487 init_message: T,
488}
489
490#[derive(Debug, MeshPayload)]
491enum HostRequest {
492 #[mesh(transparent)]
493 Inspect(inspect::Deferred),
494 Crash,
495}
496
497fn inspect_host(resp: &mut inspect::Response<'_>) {
498 resp.field("tasks", inspect_task::inspect_task_list());
499}
500
501#[derive(Inspect)]
502struct HostInspect<'a> {
503 #[inspect(safe)]
504 name: &'a str,
505 #[inspect(debug, safe)]
506 node_id: mesh::NodeId,
507 #[cfg(target_os = "linux")]
508 #[inspect(safe)]
509 rlimit: inspect_rlimit::InspectRlimit,
510}
511
512impl MeshInner {
513 async fn run(&mut self) {
514 enum Event {
515 Request(MeshRequest),
516 Done(usize),
517 }
518
519 loop {
520 let event = futures::select! { request = self.requests.select_next_some() => Event::Request(request),
522 n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
523 complete => break,
524 };
525
526 match event {
527 Event::Request(request) => match request {
528 MeshRequest::NewHost(rpc) => {
529 rpc.handle(async |params| self.spawn_process(params).await)
530 .await
531 }
532 MeshRequest::Inspect(deferred) => {
533 deferred.respond(|resp| {
534 resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
535 let mut resp = req.respond();
536 for host in self.hosts.iter().map(|(_, host)| host) {
537 resp.sensitivity_field_mut(
538 &host.pid.to_string(),
539 SensitivityLevel::Safe,
540 &mut inspect::adhoc(|req| {
541 req.respond()
542 .merge(&HostInspect {
543 name: &host.name,
544 node_id: host.node_id,
545 #[cfg(target_os = "linux")]
546 rlimit: inspect_rlimit::InspectRlimit::for_pid(
547 host.pid,
548 ),
549 })
550 .merge(inspect::adhoc(|req| {
551 host.send
552 .send(HostRequest::Inspect(req.defer()));
553 }));
554 }),
555 );
556 }
557 })
558 .sensitivity_field_mut(
559 &format!("hosts/{}", std::process::id()),
560 SensitivityLevel::Safe,
561 &mut inspect::adhoc(|req| {
562 let mut resp = req.respond();
563 resp.merge(&HostInspect {
564 name: &self.mesh_name,
565 node_id: self.node.id(),
566 #[cfg(target_os = "linux")]
567 rlimit: inspect_rlimit::InspectRlimit::new(),
568 });
569 inspect_host(&mut resp);
570 }),
571 );
572 });
573 }
574 MeshRequest::Crash(pid) => {
575 if pid == std::process::id() as i32 {
576 panic!("explicit panic request");
577 }
578
579 let mut found = false;
580 for (_, host) in &self.hosts {
581 if host.pid == pid {
582 host.send.send(HostRequest::Crash);
583 found = true;
584 break;
585 }
586 }
587
588 if !found {
589 tracing::error!("failed to crash process, pid {pid} not found");
590 }
591 }
592 },
593 Event::Done(id) => {
594 self.hosts.remove(id);
595 }
596 }
597 }
598 }
599
600 #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
602 async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
603 let NewHostParams {
604 config,
605 recv,
606 request_send,
607 } = params;
608
609 let pid;
610 let node_id;
611
612 let (arg0, process_name) = if let Some(n) = &config.process_name {
616 (None, Cow::Borrowed(n))
617 } else {
618 (
619 std::env::args_os().next(),
620 Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
621 )
622 };
623
624 let name = config.name.clone();
625
626 #[cfg(windows)]
627 let wait = {
628 let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
629 node_id = invitation.address.local_addr.node;
630
631 let invitation_env = base64::engine::general_purpose::STANDARD.encode(
632 mesh::payload::encode(Invitation {
633 node_name: name.clone(),
634 address: invitation.address,
635 directory_handle: invitation.directory.as_raw_handle() as usize,
636 }),
637 );
638
639 let mut args = config.process_args;
640 if !config.skip_worker_arg {
641 args.push(name.clone().into());
642 }
643
644 let mut builder = process::Builder::from_args(
645 arg0.as_ref()
646 .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
647 &args,
648 );
649 if arg0.is_some() {
650 builder.application_name(process_name.as_path());
651 }
652 builder
653 .stdin(process::Stdio::Null)
654 .stdout(process::Stdio::Null)
655 .handle(&invitation.directory)
656 .env(INVITATION_ENV_NAME, invitation_env)
657 .job(self.job.as_handle());
658
659 if let Some(log_file) = config.stderr.as_ref() {
660 builder.stderr(process::Stdio::Handle(log_file.as_handle()));
661 }
662
663 if let Some(mut sandbox_profile) = config.sandbox_profile {
664 sandbox_profile.apply(&mut builder);
665 }
666
667 let child = builder.spawn().context("failed to launch mesh process")?;
668 handle.await;
670 pid = child.id() as i32;
671 tracing::Span::current().record("pid", pid);
672 move || {
673 child.wait();
674 let code = child.exit_code();
675 if code == 0 {
676 tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
677 } else {
678 tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
679 }
680 }
681 };
682 #[cfg(unix)]
683 let mut wait = {
684 use pal::unix::process;
685
686 let invitation = self
687 .node
688 .invite(recv)
689 .await
690 .context("mesh node invite error")?;
691
692 node_id = invitation.address.local_addr.node;
693
694 let invitation_env = base64::engine::general_purpose::STANDARD.encode(
695 mesh::payload::encode(Invitation {
696 node_name: name.clone(),
697 address: invitation.address,
698 socket_fd: IPC_FD,
699 }),
700 );
701
702 let mut command = process::Builder::new(process_name.into_owned());
703 if let Some(arg0) = arg0 {
704 command.arg0(arg0);
705 }
706 command
707 .args(&config.process_args)
708 .stdin(process::Stdio::Null)
709 .stdout(process::Stdio::Null)
710 .dup_fd(invitation.fd.as_fd(), IPC_FD)
711 .env(INVITATION_ENV_NAME, invitation_env);
712
713 if !config.skip_worker_arg {
714 command.arg(&name);
715 }
716
717 if let Some(log_file) = config.stderr.as_ref() {
718 command.stderr(process::Stdio::Fd(log_file.as_fd()));
719 }
720
721 if let Some(mut sandbox_profile) = config.sandbox_profile {
722 sandbox_profile.apply(&mut command);
723 }
724
725 let mut child = command.spawn().context("failed to launch mesh process")?;
726 pid = child.id();
727 tracing::Span::current().record("pid", pid);
728 move || {
729 let exit_status = child.wait().expect("mesh child wait failure");
730 if let Some(0) = exit_status.code() {
731 tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
732 } else {
733 tracing::error!(
734 pid,
735 name = name.as_str(),
736 %exit_status,
737 "mesh child abnormal exit"
738 );
739 }
740 }
741 };
742
743 let (wait_send, wait_recv) = mesh::oneshot();
744
745 let id = self.hosts.insert(MeshHostInner {
746 name: config.name,
747 pid,
748 node_id,
749 send: request_send,
750 });
751
752 thread::Builder::new()
753 .name(format!("wait-mesh-child-{}", pid))
754 .spawn(move || {
755 wait();
756 wait_send.send(id);
757 })
758 .unwrap();
759
760 self.waiters.push(wait_recv);
761 Ok(())
762 }
763}