mesh_process/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Infrastructure to create a multi-process mesh and spawn child processes
5//! within it.
6
7// UNSAFETY: Needed to accept a raw Fd/Handle from our spawning process.
8#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// Platform IPC node type used by the mesh (Windows flavor, from
/// `mesh_remote::windows`).
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// Platform IPC node type used by the mesh (Unix flavor, from
/// `mesh_remote::unix`).
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// File descriptor number at which the parent places the mesh IPC fd in a
/// spawned child (see the `dup_fd` call in `spawn_process`). The same value is
/// advertised to the child through the invitation's `socket_fd` field.
#[cfg(unix)]
const IPC_FD: i32 = 3;

/// The environment variable for passing the mesh IPC invitation information to
/// a child process. This is passed through the environment instead of a command
/// line argument so that other processes cannot steal the invitation details
/// and use it to break into the mesh.
///
/// The value is the base64 encoding of a protobuf-encoded `Invitation`.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";
63
/// The payload handed to a child process through [`INVITATION_ENV_NAME`].
///
/// The parent protobuf-encodes this value and base64-encodes the resulting
/// bytes; the child reverses both steps in `node_from_environment`.
#[derive(Protobuf)]
struct Invitation {
    /// Name of the node being launched (the parent fills this in from the
    /// `ProcessConfig` name).
    node_name: String,
    /// Mesh address information taken from the invitation produced by the
    /// parent's `invite` call.
    address: InvitationAddress,
    /// Raw value of the directory handle the child should adopt. The parent
    /// also passes the handle itself to the child via the process builder.
    #[cfg(windows)]
    directory_handle: usize,
    /// Number of the file descriptor the child should adopt for mesh IPC.
    #[cfg(unix)]
    socket_fd: i32,
}
73
/// Records this process's mesh node name once `try_run_mesh_host` has joined
/// a mesh.
///
/// NOTE(review): presumably this exists so the name can be located from a
/// debugger or crash dump; confirm against the `debug_ptr` crate.
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76/// Runs a mesh host in the current thread, then exits the process, if this
77/// process was launched by [`Mesh::launch_host`].
78///
79/// The mesh invitation is provided via environment variables. If a mesh
80/// invitation is not available this function will return immediately with `Ok`.
81/// If a mesh invitation is available, this function joins the mesh and runs the
82/// future returned by `f` until `f` returns or the parent process shuts down
83/// the mesh.
84pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86    U: 'static + MeshPayload + Send,
87    F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89    block_on(async {
90        if let Some(r) = node_from_environment().await? {
91            let NodeResult {
92                node_name,
93                node,
94                initial_port,
95            } = r;
96            PROCESS_NAME.store(&node_name);
97            set_program_name(&format!("{base_name}-{node_name}"));
98            let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99                .await
100                .context("failed to receive initial message")?;
101            let _drop = (
102                f(init.init_message).map(Some),
103                handle_host_requests(init.requests).map(|()| None),
104            )
105                .race()
106                .await
107                .transpose()?;
108
109            tracing::debug!("waiting to shut down node");
110            node.shutdown().await;
111            drop(_drop);
112            std::process::exit(0);
113        }
114        Ok(())
115    })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119    while let Some(req) = recv.next().await {
120        match req {
121            HostRequest::Inspect(deferred) => {
122                deferred.respond(inspect_host);
123            }
124            HostRequest::Crash => panic!("explicit panic request"),
125        }
126    }
127}
128
/// Best-effort update of this process's program name.
///
/// On Linux this writes `name` to `/proc/self/comm` and ignores any error.
/// On every other target it does nothing.
fn set_program_name(name: &str) {
    #[cfg(target_os = "linux")]
    let _ = std::fs::write("/proc/self/comm", name);

    // `name` is only consumed on Linux; keep other targets warning-free.
    #[cfg(not(target_os = "linux"))]
    let _ = name;
}
136
/// The result of successfully joining a mesh from an environment invitation.
struct NodeResult {
    /// Node name carried in the invitation.
    node_name: String,
    /// The IPC node that joined the mesh.
    node: IpcNode,
    /// The local end of the port pair whose other end was handed to the node
    /// when joining; `try_run_mesh_host` receives the initial message on it.
    initial_port: mesh::local_node::Port,
}
142
143/// Create an IPC node from an invitation provided via the process environment.
144///
145/// Returns `None` if the invitation is not present in the environment.
146async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147    // return early with no node if the invitation is not present in the environment.
148    let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149        Ok(str) => str,
150        Err(_) => return Ok(None),
151    };
152
153    // Clear the string to avoid leaking the invitation information into child
154    // processes.
155    //
156    // TODO: this function is unsafe because
157    // it can cause UB if non-Rust code is concurrently accessing the
158    // environment in another thread. To be completely sound,
159    // either this function and its callers need to become
160    // `unsafe`, or we need to avoid using the environment to propagate the
161    // invitation so that we can avoid this call.
162    //
163    // SAFETY: Seems to work so far.
164    unsafe {
165        std::env::remove_var(INVITATION_ENV_NAME);
166    }
167
168    let invitation: Invitation = mesh::payload::decode(
169        &base64::engine::general_purpose::STANDARD
170            .decode(invitation_str)
171            .context("failed to base64 decode invitation")?,
172    )
173    .context("failed to protobuf decode invitation")?;
174
175    let (left, right) = mesh::local_node::Port::new_pair();
176
177    let node;
178    #[cfg(windows)]
179    {
180        // SAFETY: trusting the initiating process to pass a valid handle. A
181        // malicious process could pass a bad handle here, but a malicious
182        // process could also just corrupt our memory arbitrarily, so...
183        let directory =
184            unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
185
186        let invitation = mesh_remote::windows::AlpcInvitation {
187            address: invitation.address,
188            directory,
189        };
190
191        // join the node w/ the provided invitation and the send port of the channel.
192        node = mesh_remote::windows::AlpcNode::join(
193            pal_async::windows::TpPool::system(),
194            invitation,
195            left,
196        )
197        .context("failed to join mesh")?;
198    }
199
200    #[cfg(unix)]
201    {
202        // SAFETY: trusting the initiating process to pass a valid fd. A
203        // malicious process could pass a bad fd here, but a malicious
204        // process could also just corrupt our memory arbitrarily, so...
205        let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
206        let invitation = mesh_remote::unix::UnixInvitation {
207            address: invitation.address,
208            fd,
209        };
210
211        // FUTURE: use pool provided by the caller.
212        let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
213        node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
214            .await
215            .context("failed to join mesh")?;
216    }
217
218    Ok(Some(NodeResult {
219        node_name: invitation.node_name,
220        node,
221        initial_port: right,
222    }))
223}
224
/// Represents a mesh::Node with the ability to spawn new processes that can
/// communicate with any other process belonging to the same mesh.
///
/// # Process creation
/// A `Mesh` instance can spawn new processes with an initial communication
/// channel associated with the mesh. All processes originating from the same
/// mesh can potentially communicate and exchange channels with each other.
///
/// Each spawned process can be configured differently via [`ProcessConfig`].
/// Processes are created with [`Mesh::launch_host`].
///
/// ```no_run
/// # use mesh_process::{Mesh, ProcessConfig};
/// # futures::executor::block_on(async {
/// let mesh = Mesh::new("remote_mesh".to_string()).unwrap();
/// let (send, recv) = mesh::channel();
/// mesh.launch_host(ProcessConfig::new("test"), recv).await.unwrap();
/// send.send(String::from("message for new process"));
/// # })
/// ```
#[derive(Inspect)]
pub struct Mesh {
    /// Name passed to [`Mesh::new`]; used for the background task's name and
    /// the shutdown tracing span.
    #[inspect(rename = "name")]
    mesh_name: String,
    /// Sends requests (new host, inspect, crash) to the background mesh task.
    #[inspect(flatten, send = "MeshRequest::Inspect")]
    request: mesh::Sender<MeshRequest>,
    /// The background task that processes requests; awaited by
    /// [`Mesh::shutdown`].
    #[inspect(skip)]
    task: Task<()>,
}
254
/// Sandbox profile trait used for mesh hosts.
pub trait SandboxProfile: Send {
    /// Configures sandboxing for a process that is about to be launched.
    ///
    /// This executes in the parent context and sets up, via the pal builder
    /// object, any sandbox features that will be applied to the newly created
    /// process.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Performs any additional sandboxing from inside the child.
    ///
    /// This is intended to execute in the child process context after
    /// application-specific initialization is complete. It is optional, as
    /// not every sandbox profile needs to perform additional sandboxing; the
    /// default implementation does nothing and returns `Ok(())`.
    ///
    /// The child needs to be aware enough to instantiate its sandbox profile
    /// and invoke this method itself.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
271
272/// Configuration for launching a new process in the mesh.
273pub struct ProcessConfig {
274    name: String,
275    process_name: Option<PathBuf>,
276    process_args: Vec<OsString>,
277    stderr: Option<File>,
278    skip_worker_arg: bool,
279    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
280    env_vars: Vec<(OsString, OsString)>,
281}
282
283impl ProcessConfig {
284    /// Returns new process configuration using the current process as the
285    /// process name.
286    pub fn new(name: impl Into<String>) -> Self {
287        Self {
288            name: name.into(),
289            process_name: None,
290            process_args: Vec::new(),
291            stderr: None,
292            skip_worker_arg: false,
293            sandbox_profile: None,
294            env_vars: Vec::new(),
295        }
296    }
297
298    /// Returns a new process configuration using the current process as the
299    /// process name.
300    pub fn new_with_sandbox(
301        name: impl Into<String>,
302        sandbox_profile: Box<dyn SandboxProfile + Sync>,
303    ) -> Self {
304        Self {
305            name: name.into(),
306            process_name: None,
307            process_args: Vec::new(),
308            stderr: None,
309            skip_worker_arg: false,
310            sandbox_profile: Some(sandbox_profile),
311            env_vars: Vec::new(),
312        }
313    }
314
315    /// Sets the process name.
316    pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
317        self.process_name = Some(name.into());
318        self
319    }
320
321    /// Specifies whether to  appending `<node name>` to the process's command
322    /// line.
323    ///
324    /// This is done by default to make it easier to identify the process in
325    /// task lists, but if your process parses the command line then this may
326    /// get in the way.
327    pub fn skip_worker_arg(mut self, skip: bool) -> Self {
328        self.skip_worker_arg = skip;
329        self
330    }
331
332    /// Adds arguments to the process command line.
333    pub fn args<I>(mut self, args: I) -> Self
334    where
335        I: IntoIterator,
336        I::Item: Into<OsString>,
337    {
338        self.process_args.extend(args.into_iter().map(|x| x.into()));
339        self
340    }
341
342    /// Adds environment variables when launching the process.
343    pub fn env<I>(mut self, env_vars: I) -> Self
344    where
345        I: IntoIterator,
346        I::Item: Into<(OsString, OsString)>,
347    {
348        self.env_vars.extend(env_vars.into_iter().map(|x| x.into()));
349        self
350    }
351
352    /// Sets the process's stderr to `file`.
353    pub fn stderr(mut self, file: Option<File>) -> Self {
354        self.stderr = file;
355        self
356    }
357}
358
/// State owned by the background mesh task (see `MeshInner::run`).
struct MeshInner {
    /// Incoming requests from the owning `Mesh` handle.
    requests: mesh::Receiver<MeshRequest>,
    /// Hosts that have been launched and have not yet been reported as
    /// exited, keyed by slab index.
    hosts: Slab<MeshHostInner>,
    /// One receiver per spawned host process. Each resolves with the host's
    /// key in `hosts` once its wait thread has seen the child exit.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// Mesh node for host process communication.
    node: IpcNode,
    /// Name for this mesh instance, used for tracing/debugging.
    mesh_name: String,
    /// Job object. When closed, it will terminate all the child processes. This
    /// is used to ensure the child processes don't outlive the parent.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
373
/// Bookkeeping for one launched host process.
struct MeshHostInner {
    /// The name from the host's `ProcessConfig`.
    name: String,
    /// Process ID of the child.
    pid: i32,
    /// Node ID taken from the address of the invitation created for the child.
    node_id: mesh::NodeId,
    /// Sends requests (inspect, crash) to the host.
    send: mesh::Sender<HostRequest>,
}
380
/// Requests handled by the background mesh task.
enum MeshRequest {
    /// Launch a new host process; the RPC result reports whether the launch
    /// succeeded.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Inspect the mesh and its hosts.
    Inspect(inspect::Deferred),
    /// Crash the process with the given process ID: the mesh process itself
    /// if the ID is its own, otherwise the matching host.
    Crash(i32),
}
386
/// Parameters for [`MeshRequest::NewHost`].
struct NewHostParams {
    /// How to launch the new process.
    config: ProcessConfig,
    /// Port handed to the node's `invite` call; `Mesh::launch_host` creates
    /// it from the receiver of the already-sent initial message.
    recv: mesh::local_node::Port,
    /// Sender for host requests; the matching receiver travels to the host
    /// inside the initial message.
    request_send: mesh::Sender<HostRequest>,
}
392
393impl Mesh {
394    /// Creates a new mesh with the given name.
395    pub fn new(mesh_name: String) -> anyhow::Result<Self> {
396        #[cfg(windows)]
397        let job = {
398            let job = pal::windows::job::Job::new().context("failed to create job object")?;
399            job.set_terminate_on_close()
400                .context("failed to set job object terminate on close")?;
401            job
402        };
403
404        #[cfg(windows)]
405        let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
406            .context("AlpcNode creation failure")?;
407        #[cfg(unix)]
408        let node = {
409            // FUTURE: use pool provided by the caller.
410            let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
411            mesh_remote::unix::UnixNode::new(driver)
412        };
413
414        let (request, requests) = mesh::channel();
415        let mut inner = MeshInner {
416            requests,
417            hosts: Default::default(),
418            waiters: Default::default(),
419            node,
420            mesh_name: mesh_name.clone(),
421            #[cfg(windows)]
422            job,
423        };
424
425        // Spawn a separate thread for launching mesh processes to avoid bad
426        // interactions with any other pools.
427        let (_, driver) = DefaultPool::spawn_on_thread("mesh");
428        let task = driver.spawn(
429            format!("mesh-{}", &mesh_name),
430            async move { inner.run().await },
431        );
432
433        Ok(Self {
434            request,
435            mesh_name,
436            task,
437        })
438    }
439
440    /// Spawns a new host in the mesh with the provided configuration and
441    /// initial message.
442    ///
443    /// The initial message will be provided to the closure passed to
444    /// [`try_run_mesh_host()`].
445    pub async fn launch_host<T: 'static + MeshField + Send>(
446        &self,
447        config: ProcessConfig,
448        initial_message: T,
449    ) -> anyhow::Result<()> {
450        let (request_send, request_recv) = mesh::channel();
451
452        let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
453        init_send.send(InitialMessage {
454            requests: request_recv,
455            init_message: initial_message,
456        });
457
458        self.request
459            .call(
460                MeshRequest::NewHost,
461                NewHostParams {
462                    config,
463                    recv: init_recv.into(),
464                    request_send,
465                },
466            )
467            .await
468            .context("mesh failed")?
469    }
470
471    /// Shutdown the mesh and wait for any spawned processes to exit.
472    ///
473    /// The `Mesh` instance is no longer usable after `shutdown`.
474    pub async fn shutdown(self) {
475        let span = tracing::span!(
476            tracing::Level::INFO,
477            "mesh_shutdown",
478            name = self.mesh_name.as_str(),
479        );
480
481        async {
482            drop(self.request);
483            self.task.await;
484        }
485        .instrument(span)
486        .await;
487    }
488
489    /// Crashes the child process with the given process ID.
490    pub fn crash(&self, pid: i32) {
491        self.request.send(MeshRequest::Crash(pid));
492    }
493}
494
/// The first message delivered to a newly launched host.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Channel on which the host receives requests from the parent mesh.
    requests: mesh::Receiver<HostRequest>,
    /// The caller-supplied message from `Mesh::launch_host`, passed on to the
    /// closure given to `try_run_mesh_host`.
    init_message: T,
}
500
/// Requests sent from the parent mesh to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Inspect the host; answered with `inspect_host`.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Make the host panic deliberately.
    Crash,
}
507
/// Adds the inspect fields common to every mesh process to `resp`: currently
/// a `tasks` field populated from `inspect_task::inspect_task_list()`.
fn inspect_host(resp: &mut inspect::Response<'_>) {
    resp.field("tasks", inspect_task::inspect_task_list());
}
511
/// Inspect view of a single process in the mesh; built on demand while
/// answering a mesh inspect request.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// Host name (or the mesh name, for the mesh process itself).
    #[inspect(safe)]
    name: &'a str,
    /// Mesh node ID of the process.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limit information for the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
522
523impl MeshInner {
524    async fn run(&mut self) {
525        enum Event {
526            Request(MeshRequest),
527            Done(usize),
528        }
529
530        loop {
531            let event = futures::select! { // merge semantics
532                request = self.requests.select_next_some() => Event::Request(request),
533                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
534                complete => break,
535            };
536
537            match event {
538                Event::Request(request) => match request {
539                    MeshRequest::NewHost(rpc) => {
540                        rpc.handle(async |params| self.spawn_process(params).await)
541                            .await
542                    }
543                    MeshRequest::Inspect(deferred) => {
544                        deferred.respond(|resp| {
545                            resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
546                                let mut resp = req.respond();
547                                for host in self.hosts.iter().map(|(_, host)| host) {
548                                    resp.sensitivity_field_mut(
549                                        &host.pid.to_string(),
550                                        SensitivityLevel::Safe,
551                                        &mut inspect::adhoc(|req| {
552                                            req.respond()
553                                                .merge(&HostInspect {
554                                                    name: &host.name,
555                                                    node_id: host.node_id,
556                                                    #[cfg(target_os = "linux")]
557                                                    rlimit: inspect_rlimit::InspectRlimit::for_pid(
558                                                        host.pid,
559                                                    ),
560                                                })
561                                                .merge(inspect::send(
562                                                    &host.send,
563                                                    HostRequest::Inspect,
564                                                ));
565                                        }),
566                                    );
567                                }
568                            })
569                            .sensitivity_field_mut(
570                                &format!("hosts/{}", std::process::id()),
571                                SensitivityLevel::Safe,
572                                &mut inspect::adhoc(|req| {
573                                    let mut resp = req.respond();
574                                    resp.merge(&HostInspect {
575                                        name: &self.mesh_name,
576                                        node_id: self.node.id(),
577                                        #[cfg(target_os = "linux")]
578                                        rlimit: inspect_rlimit::InspectRlimit::new(),
579                                    });
580                                    inspect_host(&mut resp);
581                                }),
582                            );
583                        });
584                    }
585                    MeshRequest::Crash(pid) => {
586                        if pid == std::process::id() as i32 {
587                            panic!("explicit panic request");
588                        }
589
590                        let mut found = false;
591                        for (_, host) in &self.hosts {
592                            if host.pid == pid {
593                                host.send.send(HostRequest::Crash);
594                                found = true;
595                                break;
596                            }
597                        }
598
599                        if !found {
600                            tracing::error!("failed to crash process, pid {pid} not found");
601                        }
602                    }
603                },
604                Event::Done(id) => {
605                    self.hosts.remove(id);
606                }
607            }
608        }
609    }
610
611    /// Spawns a new process with a mesh channel associated with this `Mesh` instance.
612    #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
613    async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
614        let NewHostParams {
615            config,
616            recv,
617            request_send,
618        } = params;
619
620        let pid;
621        let node_id;
622
623        // If no process name was passed, use the current executable path to
624        // ensure we get the right file, but set arg0 to match how this process
625        // was launched.
626        let (arg0, process_name) = if let Some(n) = &config.process_name {
627            (None, Cow::Borrowed(n))
628        } else {
629            (
630                std::env::args_os().next(),
631                Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
632            )
633        };
634
635        let name = config.name.clone();
636
637        #[cfg(windows)]
638        let wait = {
639            let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
640            node_id = invitation.address.local_addr.node;
641
642            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
643                mesh::payload::encode(Invitation {
644                    node_name: name.clone(),
645                    address: invitation.address,
646                    directory_handle: invitation.directory.as_raw_handle() as usize,
647                }),
648            );
649
650            let mut args = config.process_args;
651            if !config.skip_worker_arg {
652                args.push(name.clone().into());
653            }
654
655            let mut builder = process::Builder::from_args(
656                arg0.as_ref()
657                    .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
658                &args,
659            );
660            if arg0.is_some() {
661                builder.application_name(process_name.as_path());
662            }
663            builder
664                .stdin(process::Stdio::Null)
665                .stdout(process::Stdio::Null)
666                .handle(&invitation.directory)
667                .env(INVITATION_ENV_NAME, invitation_env)
668                .extend_env(config.env_vars)
669                .job(self.job.as_handle());
670
671            if let Some(log_file) = config.stderr.as_ref() {
672                builder.stderr(process::Stdio::Handle(log_file.as_handle()));
673            }
674
675            if let Some(mut sandbox_profile) = config.sandbox_profile {
676                sandbox_profile.apply(&mut builder);
677            }
678
679            let child = builder.spawn().context("failed to launch mesh process")?;
680            // Wait for the child to connect to the mesh. TODO: timeout
681            handle.await;
682            pid = child.id() as i32;
683            tracing::Span::current().record("pid", pid);
684            move || {
685                child.wait();
686                let code = child.exit_code();
687                if code == 0 {
688                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
689                } else {
690                    tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
691                }
692            }
693        };
694        #[cfg(unix)]
695        let mut wait = {
696            use pal::unix::process;
697
698            let invitation = self
699                .node
700                .invite(recv)
701                .await
702                .context("mesh node invite error")?;
703
704            node_id = invitation.address.local_addr.node;
705
706            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
707                mesh::payload::encode(Invitation {
708                    node_name: name.clone(),
709                    address: invitation.address,
710                    socket_fd: IPC_FD,
711                }),
712            );
713
714            let mut command = process::Builder::new(process_name.into_owned());
715            if let Some(arg0) = arg0 {
716                command.arg0(arg0);
717            }
718            command
719                .args(&config.process_args)
720                .stdin(process::Stdio::Null)
721                .stdout(process::Stdio::Null)
722                .dup_fd(invitation.fd.as_fd(), IPC_FD)
723                .env(INVITATION_ENV_NAME, invitation_env);
724
725            if !config.skip_worker_arg {
726                command.arg(&name);
727            }
728
729            if let Some(log_file) = config.stderr.as_ref() {
730                command.stderr(process::Stdio::Fd(log_file.as_fd()));
731            }
732
733            if let Some(mut sandbox_profile) = config.sandbox_profile {
734                sandbox_profile.apply(&mut command);
735            }
736
737            let mut child = command.spawn().context("failed to launch mesh process")?;
738            pid = child.id();
739            tracing::Span::current().record("pid", pid);
740            move || {
741                let exit_status = child.wait().expect("mesh child wait failure");
742                if let Some(0) = exit_status.code() {
743                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
744                } else {
745                    tracing::error!(
746                        pid,
747                        name = name.as_str(),
748                        %exit_status,
749                        "mesh child abnormal exit"
750                    );
751                }
752            }
753        };
754
755        let (wait_send, wait_recv) = mesh::oneshot();
756
757        let id = self.hosts.insert(MeshHostInner {
758            name: config.name,
759            pid,
760            node_id,
761            send: request_send,
762        });
763
764        thread::Builder::new()
765            .name(format!("wait-mesh-child-{}", pid))
766            .spawn(move || {
767                wait();
768                wait_send.send(id);
769            })
770            .unwrap();
771
772        self.waiters.push(wait_recv);
773        Ok(())
774    }
775}