mesh_process/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Infrastructure to create a multi-process mesh and spawn child processes
5//! within it.
6
7// UNSAFETY: Needed to accept a raw Fd/Handle from our spawning process.
8#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// Mesh node type used for inter-process communication on Windows.
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// Mesh node type used for inter-process communication on Unix.
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// File descriptor number onto which the parent duplicates the mesh IPC fd
/// when spawning a child, and which it advertises in the invitation.
#[cfg(unix)]
const IPC_FD: i32 = 3;
57
/// The environment variable for passing the mesh IPC invitation information to
/// a child process. This is passed through the environment instead of a command
/// line argument so that other processes cannot steal the invitation details
/// and use them to break into the mesh.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";
63
/// Information a parent process hands to a child so that the child can join
/// the mesh.
///
/// The parent protobuf-encodes this value, base64-encodes the result, and
/// places it in the [`INVITATION_ENV_NAME`] environment variable of the child.
#[derive(Protobuf)]
struct Invitation {
    /// Name of the node being launched.
    node_name: String,
    /// Address portion of the platform-specific mesh invitation.
    address: InvitationAddress,
    /// Raw value of the invitation's directory handle, which the parent also
    /// passes to the child's process builder.
    #[cfg(windows)]
    directory_handle: usize,
    /// File descriptor number, in the child, of the invitation's fd.
    #[cfg(unix)]
    socket_fd: i32,
}
73
/// Node name of this process, recorded once the process has joined a mesh as
/// a host.
///
/// NOTE(review): presumably kept so the name can be located from a debugger
/// or a dump; confirm against the `debug_ptr` crate.
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76/// Runs a mesh host in the current thread, then exits the process, if this
77/// process was launched by [`Mesh::launch_host`].
78///
79/// The mesh invitation is provided via environment variables. If a mesh
80/// invitation is not available this function will return immediately with `Ok`.
81/// If a mesh invitation is available, this function joins the mesh and runs the
82/// future returned by `f` until `f` returns or the parent process shuts down
83/// the mesh.
84pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86    U: 'static + MeshPayload + Send,
87    F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89    block_on(async {
90        if let Some(r) = node_from_environment().await? {
91            let NodeResult {
92                node_name,
93                node,
94                initial_port,
95            } = r;
96            PROCESS_NAME.store(&node_name);
97            set_program_name(&format!("{base_name}-{node_name}"));
98            let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99                .await
100                .context("failed to receive initial message")?;
101            let _drop = (
102                f(init.init_message).map(Some),
103                handle_host_requests(init.requests).map(|()| None),
104            )
105                .race()
106                .await
107                .transpose()?;
108
109            tracing::debug!("waiting to shut down node");
110            node.shutdown().await;
111            drop(_drop);
112            std::process::exit(0);
113        }
114        Ok(())
115    })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119    while let Some(req) = recv.next().await {
120        match req {
121            HostRequest::Inspect(deferred) => {
122                deferred.respond(inspect_host);
123            }
124            HostRequest::Crash => panic!("explicit panic request"),
125        }
126    }
127}
128
/// Best-effort attempt to set the program name shown for this process.
///
/// On Linux this writes `name` to `/proc/self/comm` and ignores any failure.
/// On every other platform it does nothing.
fn set_program_name(name: &str) {
    #[cfg(target_os = "linux")]
    let _ = std::fs::write("/proc/self/comm", name.as_bytes());
    // Nothing to do elsewhere; just mark the argument as used.
    #[cfg(not(target_os = "linux"))]
    let _ = name;
}
136
/// Result of joining a mesh from an invitation found in the environment.
struct NodeResult {
    /// Node name carried by the invitation.
    node_name: String,
    /// The IPC node that joined the mesh.
    node: IpcNode,
    /// Peer of the port that was handed to the node when joining; the initial
    /// message from the parent is received on it.
    initial_port: mesh::local_node::Port,
}
142
143/// Create an IPC node from an invitation provided via the process environment.
144///
145/// Returns `None` if the invitation is not present in the environment.
146async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147    // return early with no node if the invitation is not present in the environment.
148    let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149        Ok(str) => str,
150        Err(_) => return Ok(None),
151    };
152
153    // Clear the string to avoid leaking the invitation information into child
154    // processes.
155    //
156    // TODO: this function is unsafe because
157    // it can cause UB if non-Rust code is concurrently accessing the
158    // environment in another thread. To be completely sound,
159    // either this function and its callers need to become
160    // `unsafe`, or we need to avoid using the environment to propagate the
161    // invitation so that we can avoid this call.
162    //
163    // SAFETY: Seems to work so far.
164    unsafe {
165        std::env::remove_var(INVITATION_ENV_NAME);
166    }
167
168    let invitation: Invitation = mesh::payload::decode(
169        &base64::engine::general_purpose::STANDARD
170            .decode(invitation_str)
171            .context("failed to base64 decode invitation")?,
172    )
173    .context("failed to protobuf decode invitation")?;
174
175    let (left, right) = mesh::local_node::Port::new_pair();
176
177    let node;
178    #[cfg(windows)]
179    {
180        // SAFETY: trusting the initiating process to pass a valid handle. A
181        // malicious process could pass a bad handle here, but a malicious
182        // process could also just corrupt our memory arbitrarily, so...
183        let directory =
184            unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
185
186        let invitation = mesh_remote::windows::AlpcInvitation {
187            address: invitation.address,
188            directory,
189        };
190
191        // join the node w/ the provided invitation and the send port of the channel.
192        node = mesh_remote::windows::AlpcNode::join(
193            pal_async::windows::TpPool::system(),
194            invitation,
195            left,
196        )
197        .context("failed to join mesh")?;
198    }
199
200    #[cfg(unix)]
201    {
202        // SAFETY: trusting the initiating process to pass a valid fd. A
203        // malicious process could pass a bad fd here, but a malicious
204        // process could also just corrupt our memory arbitrarily, so...
205        let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
206        let invitation = mesh_remote::unix::UnixInvitation {
207            address: invitation.address,
208            fd,
209        };
210
211        // FUTURE: use pool provided by the caller.
212        let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
213        node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
214            .await
215            .context("failed to join mesh")?;
216    }
217
218    Ok(Some(NodeResult {
219        node_name: invitation.node_name,
220        node,
221        initial_port: right,
222    }))
223}
224
/// Represents a mesh::Node with the ability to spawn new processes that can
/// communicate with any other process belonging to the same mesh.
///
/// # Process creation
/// A `Mesh` instance can spawn new processes with an initial communication
/// channel associated with the mesh. All processes originating from the same
/// mesh can potentially communicate and exchange channels with each other.
///
/// Each spawned process can be configured differently via [`ProcessConfig`].
/// Processes are created with [`Mesh::launch_host`].
///
/// ```no_run
/// # use mesh_process::{Mesh, ProcessConfig};
/// # futures::executor::block_on(async {
/// let mesh = Mesh::new("remote_mesh".to_string()).unwrap();
/// let (send, recv) = mesh::channel();
/// mesh.launch_host(ProcessConfig::new("test"), recv).await.unwrap();
/// send.send(String::from("message for new process"));
/// # })
/// ```
#[derive(Inspect)]
pub struct Mesh {
    /// Name given to the mesh at creation; also used to label tracing spans.
    #[inspect(rename = "name")]
    mesh_name: String,
    /// Sender for requests to the background mesh task.
    #[inspect(flatten, send = "MeshRequest::Inspect")]
    request: mesh::Sender<MeshRequest>,
    /// Background task that services `request` and tracks spawned hosts.
    #[inspect(skip)]
    task: Task<()>,
}
254
/// Sandbox profile trait used for mesh hosts.
///
/// A profile can be attached to a [`ProcessConfig`]; it is then given the
/// process builder to adjust before the host process is spawned.
pub trait SandboxProfile: Send {
    /// Apply executes in the parent context and configures any sandbox
    /// features that will be applied to the newly created process via
    /// the pal builder object.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Finalize is intended to execute in the child process context after
    /// application-specific initialization is complete. It's optional, as not
    /// every sandbox profile will need to perform additional sandboxing; the
    /// default implementation does nothing and returns `Ok(())`.
    /// In addition, the child will need to be aware enough to instantiate its
    /// sandbox profile and invoke this method.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
271
/// Configuration for launching a new process in the mesh.
pub struct ProcessConfig {
    /// Name of the new mesh node.
    name: String,
    /// Path of the executable to launch; `None` means the current executable.
    process_name: Option<PathBuf>,
    /// Arguments to put on the process's command line.
    process_args: Vec<OsString>,
    /// File to use as the process's stderr, if any.
    stderr: Option<File>,
    /// When `true`, the node name is not appended to the command line.
    skip_worker_arg: bool,
    /// Sandbox profile to apply to the process builder before spawning.
    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
}
281
282impl ProcessConfig {
283    /// Returns new process configuration using the current process as the
284    /// process name.
285    pub fn new(name: impl Into<String>) -> Self {
286        Self {
287            name: name.into(),
288            process_name: None,
289            process_args: Vec::new(),
290            stderr: None,
291            skip_worker_arg: false,
292            sandbox_profile: None,
293        }
294    }
295
296    /// Returns a new process configuration using the current process as the
297    /// process name.
298    pub fn new_with_sandbox(
299        name: impl Into<String>,
300        sandbox_profile: Box<dyn SandboxProfile + Sync>,
301    ) -> Self {
302        Self {
303            name: name.into(),
304            process_name: None,
305            process_args: Vec::new(),
306            stderr: None,
307            skip_worker_arg: false,
308            sandbox_profile: Some(sandbox_profile),
309        }
310    }
311
312    /// Sets the process name.
313    pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
314        self.process_name = Some(name.into());
315        self
316    }
317
318    /// Specifies whether to  appending `<node name>` to the process's command
319    /// line.
320    ///
321    /// This is done by default to make it easier to identify the process in
322    /// task lists, but if your process parses the command line then this may
323    /// get in the way.
324    pub fn skip_worker_arg(mut self, skip: bool) -> Self {
325        self.skip_worker_arg = skip;
326        self
327    }
328
329    /// Adds arguments to the process command line.
330    pub fn args<I>(mut self, args: I) -> Self
331    where
332        I: IntoIterator,
333        I::Item: Into<OsString>,
334    {
335        self.process_args.extend(args.into_iter().map(|x| x.into()));
336        self
337    }
338
339    /// Sets the process's stderr to `file`.
340    pub fn stderr(mut self, file: Option<File>) -> Self {
341        self.stderr = file;
342        self
343    }
344}
345
/// State owned by the background task behind a [`Mesh`].
struct MeshInner {
    /// Requests coming from the `Mesh` handle.
    requests: mesh::Receiver<MeshRequest>,
    /// Hosts that have been spawned and not yet reported as exited, keyed by
    /// slab index.
    hosts: Slab<MeshHostInner>,
    /// Handles for spawned host processes.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// Mesh node for host process communication.
    node: IpcNode,
    /// Name for this mesh instance, used for tracing/debugging.
    mesh_name: String,
    /// Job object. When closed, it will terminate all the child processes. This
    /// is used to ensure the child processes don't outlive the parent.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
360
/// Bookkeeping for one spawned host process.
struct MeshHostInner {
    /// Node name the host was configured with.
    name: String,
    /// Process ID of the host.
    pid: i32,
    /// Mesh node ID taken from the invitation issued to the host.
    node_id: mesh::NodeId,
    /// Sender for requests to the host (inspect, crash).
    send: mesh::Sender<HostRequest>,
}
367
/// Requests sent from a [`Mesh`] handle to its background task.
enum MeshRequest {
    /// Spawn a new host process; the reply carries the spawn result.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Inspect the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Crash the process with the given process ID.
    Crash(i32),
}
373
/// Parameters of a [`MeshRequest::NewHost`] request.
struct NewHostParams {
    /// How to launch the new process.
    config: ProcessConfig,
    /// Port holding the initial message; it is handed to the node when the
    /// invitation for the new process is created.
    recv: mesh::local_node::Port,
    /// Sender for host requests to the new process; its receiving end travels
    /// in the initial message.
    request_send: mesh::Sender<HostRequest>,
}
379
380impl Mesh {
381    /// Creates a new mesh with the given name.
382    pub fn new(mesh_name: String) -> anyhow::Result<Self> {
383        #[cfg(windows)]
384        let job = {
385            let job = pal::windows::job::Job::new().context("failed to create job object")?;
386            job.set_terminate_on_close()
387                .context("failed to set job object terminate on close")?;
388            job
389        };
390
391        #[cfg(windows)]
392        let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
393            .context("AlpcNode creation failure")?;
394        #[cfg(unix)]
395        let node = {
396            // FUTURE: use pool provided by the caller.
397            let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
398            mesh_remote::unix::UnixNode::new(driver)
399        };
400
401        let (request, requests) = mesh::channel();
402        let mut inner = MeshInner {
403            requests,
404            hosts: Default::default(),
405            waiters: Default::default(),
406            node,
407            mesh_name: mesh_name.clone(),
408            #[cfg(windows)]
409            job,
410        };
411
412        // Spawn a separate thread for launching mesh processes to avoid bad
413        // interactions with any other pools.
414        let (_, driver) = DefaultPool::spawn_on_thread("mesh");
415        let task = driver.spawn(
416            format!("mesh-{}", &mesh_name),
417            async move { inner.run().await },
418        );
419
420        Ok(Self {
421            request,
422            mesh_name,
423            task,
424        })
425    }
426
427    /// Spawns a new host in the mesh with the provided configuration and
428    /// initial message.
429    ///
430    /// The initial message will be provided to the closure passed to
431    /// [`try_run_mesh_host()`].
432    pub async fn launch_host<T: 'static + MeshField + Send>(
433        &self,
434        config: ProcessConfig,
435        initial_message: T,
436    ) -> anyhow::Result<()> {
437        let (request_send, request_recv) = mesh::channel();
438
439        let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
440        init_send.send(InitialMessage {
441            requests: request_recv,
442            init_message: initial_message,
443        });
444
445        self.request
446            .call(
447                MeshRequest::NewHost,
448                NewHostParams {
449                    config,
450                    recv: init_recv.into(),
451                    request_send,
452                },
453            )
454            .await
455            .context("mesh failed")?
456    }
457
458    /// Shutdown the mesh and wait for any spawned processes to exit.
459    ///
460    /// The `Mesh` instance is no longer usable after `shutdown`.
461    pub async fn shutdown(self) {
462        let span = tracing::span!(
463            tracing::Level::INFO,
464            "mesh_shutdown",
465            name = self.mesh_name.as_str(),
466        );
467
468        async {
469            drop(self.request);
470            self.task.await;
471        }
472        .instrument(span)
473        .await;
474    }
475
476    /// Crashes the child process with the given process ID.
477    pub fn crash(&self, pid: i32) {
478        self.request.send(MeshRequest::Crash(pid));
479    }
480}
481
/// First message delivered to a newly launched host process.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Receiver for requests the mesh sends to the host.
    requests: mesh::Receiver<HostRequest>,
    /// Caller-supplied message given to [`Mesh::launch_host`].
    init_message: T,
}
487
/// Requests the mesh can send to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Inspect the host.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Make the host panic on purpose.
    Crash,
}
494
495fn inspect_host(resp: &mut inspect::Response<'_>) {
496    resp.field("tasks", inspect_task::inspect_task_list());
497}
498
/// Per-process data reported when the mesh is inspected.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// Name of the host (or of the mesh, for the parent process).
    #[inspect(safe)]
    name: &'a str,
    /// Mesh node ID of the process.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limit information for the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
509
impl MeshInner {
    /// Main loop of the background mesh task.
    ///
    /// Handles requests coming from the `Mesh` handle and forgets hosts whose
    /// process has been reported as exited. The loop ends when `select!`
    /// reports that both the request stream and the set of waiters are
    /// complete.
    async fn run(&mut self) {
        // What woke the loop up.
        enum Event {
            // A request from the `Mesh` handle.
            Request(MeshRequest),
            // The host stored at this slab index has exited.
            Done(usize),
        }

        loop {
            let event = futures::select! { // merge semantics
                request = self.requests.select_next_some() => Event::Request(request),
                // NOTE(review): `unwrap` panics if a waiter's sender is dropped
                // without sending, which looks possible if a wait thread
                // panics -- confirm.
                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
                complete => break,
            };

            match event {
                Event::Request(request) => match request {
                    MeshRequest::NewHost(rpc) => {
                        rpc.handle(async |params| self.spawn_process(params).await)
                            .await
                    }
                    MeshRequest::Inspect(deferred) => {
                        deferred.respond(|resp| {
                            // One entry per spawned host under `hosts`, named by
                            // pid. Each merges the locally known data with the
                            // host's own answer to an inspect request.
                            resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
                                let mut resp = req.respond();
                                for host in self.hosts.iter().map(|(_, host)| host) {
                                    resp.sensitivity_field_mut(
                                        &host.pid.to_string(),
                                        SensitivityLevel::Safe,
                                        &mut inspect::adhoc(|req| {
                                            req.respond()
                                                .merge(&HostInspect {
                                                    name: &host.name,
                                                    node_id: host.node_id,
                                                    #[cfg(target_os = "linux")]
                                                    rlimit: inspect_rlimit::InspectRlimit::for_pid(
                                                        host.pid,
                                                    ),
                                                })
                                                .merge(inspect::send(
                                                    &host.send,
                                                    HostRequest::Inspect,
                                                ));
                                        }),
                                    );
                                }
                            })
                            // The current process is reported next to the
                            // hosts, under its own pid.
                            .sensitivity_field_mut(
                                &format!("hosts/{}", std::process::id()),
                                SensitivityLevel::Safe,
                                &mut inspect::adhoc(|req| {
                                    let mut resp = req.respond();
                                    resp.merge(&HostInspect {
                                        name: &self.mesh_name,
                                        node_id: self.node.id(),
                                        #[cfg(target_os = "linux")]
                                        rlimit: inspect_rlimit::InspectRlimit::new(),
                                    });
                                    inspect_host(&mut resp);
                                }),
                            );
                        });
                    }
                    MeshRequest::Crash(pid) => {
                        // A request naming this very process crashes it
                        // directly.
                        if pid == std::process::id() as i32 {
                            panic!("explicit panic request");
                        }

                        // Otherwise forward the request to the first host with
                        // a matching pid.
                        let mut found = false;
                        for (_, host) in &self.hosts {
                            if host.pid == pid {
                                host.send.send(HostRequest::Crash);
                                found = true;
                                break;
                            }
                        }

                        if !found {
                            tracing::error!("failed to crash process, pid {pid} not found");
                        }
                    }
                },
                Event::Done(id) => {
                    self.hosts.remove(id);
                }
            }
        }
    }

    /// Spawns a new process with a mesh channel associated with this `Mesh` instance.
    ///
    /// An invitation is created for `params.recv`, encoded into the
    /// [`INVITATION_ENV_NAME`] environment variable of the child, and the
    /// child is started. The host is then recorded in `hosts`, and a dedicated
    /// thread waits for the process to exit and reports the host's slab index
    /// through `waiters`.
    ///
    /// # Errors
    ///
    /// Fails if the current executable path cannot be determined (when no
    /// process name was configured), if the invitation cannot be created, or
    /// if the process cannot be launched.
    #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
    async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
        let NewHostParams {
            config,
            recv,
            request_send,
        } = params;

        // Both are assigned by whichever platform-specific block runs below.
        let pid;
        let node_id;

        // If no process name was passed, use the current executable path to
        // ensure we get the right file, but set arg0 to match how this process
        // was launched.
        let (arg0, process_name) = if let Some(n) = &config.process_name {
            (None, Cow::Borrowed(n))
        } else {
            (
                std::env::args_os().next(),
                Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
            )
        };

        let name = config.name.clone();

        // `wait` is a closure that blocks until the child exits and logs how it
        // exited; it runs on the thread spawned at the end of this function.
        #[cfg(windows)]
        let wait = {
            let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
            node_id = invitation.address.local_addr.node;

            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
                mesh::payload::encode(Invitation {
                    node_name: name.clone(),
                    address: invitation.address,
                    directory_handle: invitation.directory.as_raw_handle() as usize,
                }),
            );

            // The node name goes last on the command line unless the caller
            // opted out.
            let mut args = config.process_args;
            if !config.skip_worker_arg {
                args.push(name.clone().into());
            }

            let mut builder = process::Builder::from_args(
                arg0.as_ref()
                    .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
                &args,
            );
            if arg0.is_some() {
                builder.application_name(process_name.as_path());
            }
            builder
                .stdin(process::Stdio::Null)
                .stdout(process::Stdio::Null)
                .handle(&invitation.directory)
                .env(INVITATION_ENV_NAME, invitation_env)
                .job(self.job.as_handle());

            if let Some(log_file) = config.stderr.as_ref() {
                builder.stderr(process::Stdio::Handle(log_file.as_handle()));
            }

            if let Some(mut sandbox_profile) = config.sandbox_profile {
                sandbox_profile.apply(&mut builder);
            }

            let child = builder.spawn().context("failed to launch mesh process")?;
            // Wait for the child to connect to the mesh. TODO: timeout
            handle.await;
            pid = child.id() as i32;
            tracing::Span::current().record("pid", pid);
            move || {
                child.wait();
                let code = child.exit_code();
                if code == 0 {
                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
                } else {
                    tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
                }
            }
        };
        #[cfg(unix)]
        let mut wait = {
            use pal::unix::process;

            let invitation = self
                .node
                .invite(recv)
                .await
                .context("mesh node invite error")?;

            node_id = invitation.address.local_addr.node;

            // The child finds the invitation's fd at `IPC_FD`, where `dup_fd`
            // places it below.
            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
                mesh::payload::encode(Invitation {
                    node_name: name.clone(),
                    address: invitation.address,
                    socket_fd: IPC_FD,
                }),
            );

            let mut command = process::Builder::new(process_name.into_owned());
            if let Some(arg0) = arg0 {
                command.arg0(arg0);
            }
            command
                .args(&config.process_args)
                .stdin(process::Stdio::Null)
                .stdout(process::Stdio::Null)
                .dup_fd(invitation.fd.as_fd(), IPC_FD)
                .env(INVITATION_ENV_NAME, invitation_env);

            // The node name goes last on the command line unless the caller
            // opted out.
            if !config.skip_worker_arg {
                command.arg(&name);
            }

            if let Some(log_file) = config.stderr.as_ref() {
                command.stderr(process::Stdio::Fd(log_file.as_fd()));
            }

            if let Some(mut sandbox_profile) = config.sandbox_profile {
                sandbox_profile.apply(&mut command);
            }

            // NOTE(review): unlike the Windows path, nothing here waits for the
            // child to connect to the mesh before returning -- confirm this is
            // intended.
            let mut child = command.spawn().context("failed to launch mesh process")?;
            pid = child.id();
            tracing::Span::current().record("pid", pid);
            move || {
                let exit_status = child.wait().expect("mesh child wait failure");
                if let Some(0) = exit_status.code() {
                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
                } else {
                    tracing::error!(
                        pid,
                        name = name.as_str(),
                        %exit_status,
                        "mesh child abnormal exit"
                    );
                }
            }
        };

        // Used by the wait thread to tell `run` which host has exited.
        let (wait_send, wait_recv) = mesh::oneshot();

        let id = self.hosts.insert(MeshHostInner {
            name: config.name,
            pid,
            node_id,
            send: request_send,
        });

        // Waiting for the process blocks, so it gets a thread of its own.
        // NOTE(review): `unwrap` panics if the thread cannot be created.
        thread::Builder::new()
            .name(format!("wait-mesh-child-{}", pid))
            .spawn(move || {
                wait();
                wait_send.send(id);
            })
            .unwrap();

        self.waiters.push(wait_recv);
        Ok(())
    }
}
759        Ok(())
760    }
761}