mesh_process/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Infrastructure to create a multi-process mesh and spawn child processes
5//! within it.
6
7// UNSAFETY: Needed to accept a raw Fd/Handle from our spawning process.
8#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// The IPC node type used by this platform's mesh implementation.
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// The IPC node type used by this platform's mesh implementation.
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// The descriptor number at which a spawned child receives its mesh socket.
/// The parent dups the invitation fd to this number and reports it in
/// `Invitation::socket_fd`.
#[cfg(unix)]
const IPC_FD: i32 = 3;

/// The environment variable for passing the mesh IPC invitation information to
/// a child process. This is passed through the environment instead of a command
/// line argument so that other processes cannot steal the invitation details
/// and use it to break into the mesh.
///
/// The value is a protobuf-encoded [`Invitation`], base64-encoded with the
/// standard alphabet.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";

/// Invitation to join the mesh, sent from the parent to a child process via
/// [`INVITATION_ENV_NAME`].
///
/// NOTE(review): protobuf field numbers are presumably derived from
/// declaration order; avoid reordering fields without confirming.
#[derive(Protobuf)]
struct Invitation {
    /// Name of the child's mesh node.
    node_name: String,
    /// Mesh address the child uses to join.
    address: InvitationAddress,
    /// Raw value of a handle inherited by the child; the child adopts it as an
    /// `OwnedHandle` for the ALPC invitation directory.
    #[cfg(windows)]
    directory_handle: usize,
    /// Descriptor number of the inherited socket in the child; the parent
    /// always sends [`IPC_FD`].
    #[cfg(unix)]
    socket_fd: i32,
}

/// Name of this process's mesh node, stored by [`try_run_mesh_host`] after
/// joining (presumably so it is discoverable from a debugger — confirm).
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76/// Runs a mesh host in the current thread, then exits the process, if this
77/// process was launched by [`Mesh::launch_host`].
78///
79/// The mesh invitation is provided via environment variables. If a mesh
80/// invitation is not available this function will return immediately with `Ok`.
81/// If a mesh invitation is available, this function joins the mesh and runs the
82/// future returned by `f` until `f` returns or the parent process shuts down
83/// the mesh.
84pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86    U: 'static + MeshPayload + Send,
87    F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89    block_on(async {
90        if let Some(r) = node_from_environment().await? {
91            let NodeResult {
92                node_name,
93                node,
94                initial_port,
95            } = r;
96            PROCESS_NAME.store(&node_name);
97            set_program_name(&format!("{base_name}-{node_name}"));
98            let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99                .await
100                .context("failed to receive initial message")?;
101            let _drop = (
102                f(init.init_message).map(Some),
103                handle_host_requests(init.requests).map(|()| None),
104            )
105                .race()
106                .await
107                .transpose()?;
108
109            tracing::debug!("waiting to shut down node");
110            node.shutdown().await;
111            drop(_drop);
112            std::process::exit(0);
113        }
114        Ok(())
115    })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119    while let Some(req) = recv.next().await {
120        match req {
121            HostRequest::Inspect(deferred) => {
122                deferred.respond(inspect_host);
123            }
124            HostRequest::Crash => panic!("explicit panic request"),
125        }
126    }
127}
128
/// Sets the name the OS reports for this process, on a best-effort basis.
///
/// On Linux this writes `/proc/self/comm`, and any failure is ignored. On other
/// platforms it does nothing.
fn set_program_name(name: &str) {
    #[cfg(target_os = "linux")]
    {
        // Ignore errors: a missing or unwritable procfs is not fatal.
        let _ = std::fs::write("/proc/self/comm", name);
    }
    #[cfg(not(target_os = "linux"))]
    {
        // Nothing to do on this platform.
        let _ = name;
    }
}
136
/// Result of joining the mesh from an environment-provided invitation.
struct NodeResult {
    /// Node name assigned by the parent (from `Invitation::node_name`).
    node_name: String,
    /// The IPC node that joined the mesh.
    node: IpcNode,
    /// Local port paired with the one handed to the node when joining; the
    /// parent's `InitialMessage` is received on it.
    initial_port: mesh::local_node::Port,
}

/// Create an IPC node from an invitation provided via the process environment.
///
/// Returns `Ok(None)` if the invitation is not present in the environment (a
/// value that is not valid Unicode is treated the same way). Otherwise the
/// invitation variable is removed from the environment and decoded. The
/// inherited handle/fd is then adopted and the mesh is joined.
///
/// # Errors
///
/// Returns an error if the invitation cannot be base64- or protobuf-decoded,
/// or if joining the mesh fails.
async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
    // return early with no node if the invitation is not present in the environment.
    let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
        Ok(str) => str,
        Err(_) => return Ok(None),
    };

    // Clear the string to avoid leaking the invitation information into child
    // processes.
    //
    // TODO: this function will become unsafe in a future Rust edition because
    // it can cause UB if non-Rust code is concurrently accessing the
    // environment in another thread. To be completely sound (even in the
    // current edition), either this function and its callers need to become
    // `unsafe`, or we need to avoid using the environment to propagate the
    // invitation so that we can avoid this call.
    #[expect(deprecated_safe_2024)]
    std::env::remove_var(INVITATION_ENV_NAME);

    // The variable holds base64 text wrapping a protobuf-encoded `Invitation`.
    let invitation: Invitation = mesh::payload::decode(
        &base64::engine::general_purpose::STANDARD
            .decode(invitation_str)
            .context("failed to base64 decode invitation")?,
    )
    .context("failed to protobuf decode invitation")?;

    // `left` is handed to the node when joining; `right` is returned to the
    // caller as `initial_port`.
    let (left, right) = mesh::local_node::Port::new_pair();

    // Assigned by exactly one of the platform-specific blocks below.
    let node;
    #[cfg(windows)]
    {
        // SAFETY: trusting the initiating process to pass a valid handle. A
        // malicious process could pass a bad handle here, but a malicious
        // process could also just corrupt our memory arbitrarily, so...
        let directory =
            unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };

        let invitation = mesh_remote::windows::AlpcInvitation {
            address: invitation.address,
            directory,
        };

        // join the node w/ the provided invitation and the send port of the channel.
        node = mesh_remote::windows::AlpcNode::join(
            pal_async::windows::TpPool::system(),
            invitation,
            left,
        )
        .context("failed to join mesh")?;
    }

    #[cfg(unix)]
    {
        // SAFETY: trusting the initiating process to pass a valid fd. A
        // malicious process could pass a bad fd here, but a malicious
        // process could also just corrupt our memory arbitrarily, so...
        let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
        let invitation = mesh_remote::unix::UnixInvitation {
            address: invitation.address,
            fd,
        };

        // FUTURE: use pool provided by the caller.
        let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
        node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
            .await
            .context("failed to join mesh")?;
    }

    Ok(Some(NodeResult {
        node_name: invitation.node_name,
        node,
        initial_port: right,
    }))
}
221
/// Represents a mesh::Node with the ability to spawn new processes that can
/// communicate with any other process belonging to the same mesh.
///
/// # Process creation
/// A `Mesh` instance can spawn new processes with an initial communication
/// channel associated with the mesh. All processes originating from the same
/// mesh can potentially communicate and exchange channels with each other.
///
/// Each spawned process can be configured differently via [`ProcessConfig`].
/// Processes are created with [`Mesh::launch_host`].
///
/// ```no_run
/// # use mesh_process::{Mesh, ProcessConfig};
/// # futures::executor::block_on(async {
/// let mesh = Mesh::new("remote_mesh".to_string()).unwrap();
/// let (send, recv) = mesh::channel();
/// mesh.launch_host(ProcessConfig::new("test"), recv).await.unwrap();
/// send.send(String::from("message for new process"));
/// # })
/// ```
pub struct Mesh {
    /// Name of this mesh, used for tracing.
    mesh_name: String,
    /// Sends requests to the background mesh task; [`Mesh::shutdown`] drops
    /// it before awaiting `task`.
    request: mesh::Sender<MeshRequest>,
    /// The background task running the mesh request loop.
    task: Task<()>,
}
247
248/// Sandbox profile trait used for mesh hosts.
249pub trait SandboxProfile: Send {
250    /// Apply executes in the parent context and configures any sandbox
251    /// features that will be applied to the newly created process via
252    /// the pal builder object.
253    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);
254
255    /// Finalize is intended to execute in the child process context after
256    /// application specific initialization is complete. It's optional as not
257    /// every sandbox profile will need to perform additional sandboxing.
258    /// In addition, the child will need to be aware enough to instantiate its
259    /// sandbox profile and invoke this method.
260    fn finalize(&mut self) -> anyhow::Result<()> {
261        Ok(())
262    }
263}
264
265/// Configuration for launching a new process in the mesh.
266pub struct ProcessConfig {
267    name: String,
268    process_name: Option<PathBuf>,
269    process_args: Vec<OsString>,
270    stderr: Option<File>,
271    skip_worker_arg: bool,
272    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
273}
274
275impl ProcessConfig {
276    /// Returns new process configuration using the current process as the
277    /// process name.
278    pub fn new(name: impl Into<String>) -> Self {
279        Self {
280            name: name.into(),
281            process_name: None,
282            process_args: Vec::new(),
283            stderr: None,
284            skip_worker_arg: false,
285            sandbox_profile: None,
286        }
287    }
288
289    /// Returns a new process configuration using the current process as the
290    /// process name.
291    pub fn new_with_sandbox(
292        name: impl Into<String>,
293        sandbox_profile: Box<dyn SandboxProfile + Sync>,
294    ) -> Self {
295        Self {
296            name: name.into(),
297            process_name: None,
298            process_args: Vec::new(),
299            stderr: None,
300            skip_worker_arg: false,
301            sandbox_profile: Some(sandbox_profile),
302        }
303    }
304
305    /// Sets the process name.
306    pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
307        self.process_name = Some(name.into());
308        self
309    }
310
311    /// Specifies whether to  appending `<node name>` to the process's command
312    /// line.
313    ///
314    /// This is done by default to make it easier to identify the process in
315    /// task lists, but if your process parses the command line then this may
316    /// get in the way.
317    pub fn skip_worker_arg(mut self, skip: bool) -> Self {
318        self.skip_worker_arg = skip;
319        self
320    }
321
322    /// Adds arguments to the process command line.
323    pub fn args<I>(mut self, args: I) -> Self
324    where
325        I: IntoIterator,
326        I::Item: Into<OsString>,
327    {
328        self.process_args.extend(args.into_iter().map(|x| x.into()));
329        self
330    }
331
332    /// Sets the process's stderr to `file`.
333    pub fn stderr(mut self, file: Option<File>) -> Self {
334        self.stderr = file;
335        self
336    }
337}
338
/// State owned by the background mesh task (see `MeshInner::run`).
struct MeshInner {
    /// Requests from the [`Mesh`] handle.
    requests: mesh::Receiver<MeshRequest>,
    /// Live host processes, keyed by the slab index their waiter reports.
    hosts: Slab<MeshHostInner>,
    /// Handles for spawned host processes. Each resolves with the host's
    /// `hosts` index once that process has exited.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// Mesh node for host process communication.
    node: IpcNode,
    /// Name for this mesh instance, used for tracing/debugging.
    mesh_name: String,
    /// Job object. When closed, it will terminate all the child processes. This
    /// is used to ensure the child processes don't outlive the parent.
    ///
    /// Declared last, so it is dropped after the other fields.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}

/// Bookkeeping for one spawned host process.
struct MeshHostInner {
    /// The host's configured name (`ProcessConfig::name`).
    name: String,
    /// OS process ID of the host.
    pid: i32,
    /// Mesh node ID taken from the host's invitation address.
    node_id: mesh::NodeId,
    /// Sends inspect/crash requests to the host.
    send: mesh::Sender<HostRequest>,
}

/// Requests from a [`Mesh`] handle to its background task.
enum MeshRequest {
    /// Launch a new host process.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Inspect the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Crash the process with this pid.
    Crash(i32),
}

/// Parameters for launching a new host process.
struct NewHostParams {
    /// How to launch the process.
    config: ProcessConfig,
    /// Port carrying the host's `InitialMessage`; it is passed to the child
    /// through the mesh invitation.
    recv: mesh::local_node::Port,
    /// Sender for requests to the host; the matching receiver travels inside
    /// the initial message.
    request_send: mesh::Sender<HostRequest>,
}
372
373impl Inspect for Mesh {
374    fn inspect(&self, req: inspect::Request<'_>) {
375        self.request.send(MeshRequest::Inspect(req.defer()));
376    }
377}
378
379impl Mesh {
380    /// Creates a new mesh with the given name.
381    pub fn new(mesh_name: String) -> anyhow::Result<Self> {
382        #[cfg(windows)]
383        let job = {
384            let job = pal::windows::job::Job::new().context("failed to create job object")?;
385            job.set_terminate_on_close()
386                .context("failed to set job object terminate on close")?;
387            job
388        };
389
390        #[cfg(windows)]
391        let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
392            .context("AlpcNode creation failure")?;
393        #[cfg(unix)]
394        let node = {
395            // FUTURE: use pool provided by the caller.
396            let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
397            mesh_remote::unix::UnixNode::new(driver)
398        };
399
400        let (request, requests) = mesh::channel();
401        let mut inner = MeshInner {
402            requests,
403            hosts: Default::default(),
404            waiters: Default::default(),
405            node,
406            mesh_name: mesh_name.clone(),
407            #[cfg(windows)]
408            job,
409        };
410
411        // Spawn a separate thread for launching mesh processes to avoid bad
412        // interactions with any other pools.
413        let (_, driver) = DefaultPool::spawn_on_thread("mesh");
414        let task = driver.spawn(
415            format!("mesh-{}", &mesh_name),
416            async move { inner.run().await },
417        );
418
419        Ok(Self {
420            request,
421            mesh_name,
422            task,
423        })
424    }
425
426    /// Spawns a new host in the mesh with the provided configuration and
427    /// initial message.
428    ///
429    /// The initial message will be provided to the closure passed to
430    /// [`try_run_mesh_host()`].
431    pub async fn launch_host<T: 'static + MeshField + Send>(
432        &self,
433        config: ProcessConfig,
434        initial_message: T,
435    ) -> anyhow::Result<()> {
436        let (request_send, request_recv) = mesh::channel();
437
438        let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
439        init_send.send(InitialMessage {
440            requests: request_recv,
441            init_message: initial_message,
442        });
443
444        self.request
445            .call(
446                MeshRequest::NewHost,
447                NewHostParams {
448                    config,
449                    recv: init_recv.into(),
450                    request_send,
451                },
452            )
453            .await
454            .context("mesh failed")?
455    }
456
457    /// Shutdown the mesh and wait for any spawned processes to exit.
458    ///
459    /// The `Mesh` instance is no longer usable after `shutdown`.
460    pub async fn shutdown(self) {
461        let span = tracing::span!(
462            tracing::Level::INFO,
463            "mesh_shutdown",
464            name = self.mesh_name.as_str(),
465        );
466
467        async {
468            drop(self.request);
469            self.task.await;
470        }
471        .instrument(span)
472        .await;
473    }
474
475    /// Crashes the child process with the given process ID.
476    pub fn crash(&self, pid: i32) {
477        self.request.send(MeshRequest::Crash(pid));
478    }
479}
480
/// The first message a new host receives from the parent.
///
/// NOTE(review): the derived `MeshPayload` encoding presumably depends on
/// field order; avoid reordering without confirming.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Requests from the parent for this host.
    requests: mesh::Receiver<HostRequest>,
    /// Caller-provided payload, passed to the closure given to
    /// `try_run_mesh_host`.
    init_message: T,
}

/// Requests sent from the parent mesh to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Respond to an inspection with this host's task list.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Panic in the host's request handler.
    Crash,
}
493
494fn inspect_host(resp: &mut inspect::Response<'_>) {
495    resp.field("tasks", inspect_task::inspect_task_list());
496}
497
/// Inspection data reported for each host process, and for the mesh's own
/// process.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// The host's configured name, or the mesh name for the mesh's own process.
    #[inspect(safe)]
    name: &'a str,
    /// The process's mesh node ID, shown via its `Debug` output.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limits of the process (presumably its rlimits; Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
508
509impl MeshInner {
510    async fn run(&mut self) {
511        enum Event {
512            Request(MeshRequest),
513            Done(usize),
514        }
515
516        loop {
517            let event = futures::select! { // merge semantics
518                request = self.requests.select_next_some() => Event::Request(request),
519                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
520                complete => break,
521            };
522
523            match event {
524                Event::Request(request) => match request {
525                    MeshRequest::NewHost(rpc) => {
526                        rpc.handle(async |params| self.spawn_process(params).await)
527                            .await
528                    }
529                    MeshRequest::Inspect(deferred) => {
530                        deferred.respond(|resp| {
531                            resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
532                                let mut resp = req.respond();
533                                for host in self.hosts.iter().map(|(_, host)| host) {
534                                    resp.sensitivity_field_mut(
535                                        &host.pid.to_string(),
536                                        SensitivityLevel::Safe,
537                                        &mut inspect::adhoc(|req| {
538                                            req.respond()
539                                                .merge(&HostInspect {
540                                                    name: &host.name,
541                                                    node_id: host.node_id,
542                                                    #[cfg(target_os = "linux")]
543                                                    rlimit: inspect_rlimit::InspectRlimit::for_pid(
544                                                        host.pid,
545                                                    ),
546                                                })
547                                                .merge(inspect::adhoc(|req| {
548                                                    host.send
549                                                        .send(HostRequest::Inspect(req.defer()));
550                                                }));
551                                        }),
552                                    );
553                                }
554                            })
555                            .sensitivity_field_mut(
556                                &format!("hosts/{}", std::process::id()),
557                                SensitivityLevel::Safe,
558                                &mut inspect::adhoc(|req| {
559                                    let mut resp = req.respond();
560                                    resp.merge(&HostInspect {
561                                        name: &self.mesh_name,
562                                        node_id: self.node.id(),
563                                        #[cfg(target_os = "linux")]
564                                        rlimit: inspect_rlimit::InspectRlimit::new(),
565                                    });
566                                    inspect_host(&mut resp);
567                                }),
568                            );
569                        });
570                    }
571                    MeshRequest::Crash(pid) => {
572                        if pid == std::process::id() as i32 {
573                            panic!("explicit panic request");
574                        }
575
576                        let mut found = false;
577                        for (_, host) in &self.hosts {
578                            if host.pid == pid {
579                                host.send.send(HostRequest::Crash);
580                                found = true;
581                                break;
582                            }
583                        }
584
585                        if !found {
586                            tracing::error!("failed to crash process, pid {pid} not found");
587                        }
588                    }
589                },
590                Event::Done(id) => {
591                    self.hosts.remove(id);
592                }
593            }
594        }
595    }
596
597    /// Spawns a new process with a mesh channel associated with this `Mesh` instance.
598    #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
599    async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
600        let NewHostParams {
601            config,
602            recv,
603            request_send,
604        } = params;
605
606        let pid;
607        let node_id;
608
609        // If no process name was passed, use the current executable path to
610        // ensure we get the right file, but set arg0 to match how this process
611        // was launched.
612        let (arg0, process_name) = if let Some(n) = &config.process_name {
613            (None, Cow::Borrowed(n))
614        } else {
615            (
616                std::env::args_os().next(),
617                Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
618            )
619        };
620
621        let name = config.name.clone();
622
623        #[cfg(windows)]
624        let wait = {
625            let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
626            node_id = invitation.address.local_addr.node;
627
628            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
629                mesh::payload::encode(Invitation {
630                    node_name: name.clone(),
631                    address: invitation.address,
632                    directory_handle: invitation.directory.as_raw_handle() as usize,
633                }),
634            );
635
636            let mut args = config.process_args;
637            if !config.skip_worker_arg {
638                args.push(name.clone().into());
639            }
640
641            let mut builder = process::Builder::from_args(
642                arg0.as_ref()
643                    .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
644                &args,
645            );
646            if arg0.is_some() {
647                builder.application_name(process_name.as_path());
648            }
649            builder
650                .stdin(process::Stdio::Null)
651                .stdout(process::Stdio::Null)
652                .handle(&invitation.directory)
653                .env(INVITATION_ENV_NAME, invitation_env)
654                .job(self.job.as_handle());
655
656            if let Some(log_file) = config.stderr.as_ref() {
657                builder.stderr(process::Stdio::Handle(log_file.as_handle()));
658            }
659
660            if let Some(mut sandbox_profile) = config.sandbox_profile {
661                sandbox_profile.apply(&mut builder);
662            }
663
664            let child = builder.spawn().context("failed to launch mesh process")?;
665            // Wait for the child to connect to the mesh. TODO: timeout
666            handle.await;
667            pid = child.id() as i32;
668            tracing::Span::current().record("pid", pid);
669            move || {
670                child.wait();
671                let code = child.exit_code();
672                if code == 0 {
673                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
674                } else {
675                    tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
676                }
677            }
678        };
679        #[cfg(unix)]
680        let mut wait = {
681            use pal::unix::process;
682
683            let invitation = self
684                .node
685                .invite(recv)
686                .await
687                .context("mesh node invite error")?;
688
689            node_id = invitation.address.local_addr.node;
690
691            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
692                mesh::payload::encode(Invitation {
693                    node_name: name.clone(),
694                    address: invitation.address,
695                    socket_fd: IPC_FD,
696                }),
697            );
698
699            let mut command = process::Builder::new(process_name.into_owned());
700            if let Some(arg0) = arg0 {
701                command.arg0(arg0);
702            }
703            command
704                .args(&config.process_args)
705                .stdin(process::Stdio::Null)
706                .stdout(process::Stdio::Null)
707                .dup_fd(invitation.fd.as_fd(), IPC_FD)
708                .env(INVITATION_ENV_NAME, invitation_env);
709
710            if !config.skip_worker_arg {
711                command.arg(&name);
712            }
713
714            if let Some(log_file) = config.stderr.as_ref() {
715                command.stderr(process::Stdio::Fd(log_file.as_fd()));
716            }
717
718            if let Some(mut sandbox_profile) = config.sandbox_profile {
719                sandbox_profile.apply(&mut command);
720            }
721
722            let mut child = command.spawn().context("failed to launch mesh process")?;
723            pid = child.id();
724            tracing::Span::current().record("pid", pid);
725            move || {
726                let exit_status = child.wait().expect("mesh child wait failure");
727                if let Some(0) = exit_status.code() {
728                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
729                } else {
730                    tracing::error!(
731                        pid,
732                        name = name.as_str(),
733                        %exit_status,
734                        "mesh child abnormal exit"
735                    );
736                }
737            }
738        };
739
740        let (wait_send, wait_recv) = mesh::oneshot();
741
742        let id = self.hosts.insert(MeshHostInner {
743            name: config.name,
744            pid,
745            node_id,
746            send: request_send,
747        });
748
749        thread::Builder::new()
750            .name(format!("wait-mesh-child-{}", pid))
751            .spawn(move || {
752                wait();
753                wait_send.send(id);
754            })
755            .unwrap();
756
757        self.waiters.push(wait_recv);
758        Ok(())
759    }
760}