mesh_process/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Infrastructure to create a multi-process mesh and spawn child processes
5//! within it.
6
// UNSAFETY: Needed to accept a raw Fd/Handle from our spawning process and to
// remove the mesh invitation from the process environment.
8#![expect(unsafe_code)]
9
10use anyhow::Context;
11use base64::Engine;
12use debug_ptr::DebugPtr;
13use futures::FutureExt;
14use futures::StreamExt;
15use futures::executor::block_on;
16use futures_concurrency::future::Race;
17use inspect::Inspect;
18use inspect::SensitivityLevel;
19use mesh::MeshPayload;
20use mesh::OneshotReceiver;
21use mesh::message::MeshField;
22use mesh::payload::Protobuf;
23use mesh::rpc::Rpc;
24use mesh::rpc::RpcSend;
25use mesh_remote::InvitationAddress;
26#[cfg(unix)]
27use pal::unix::process::Builder as ProcessBuilder;
28#[cfg(windows)]
29use pal::windows::process;
30#[cfg(windows)]
31use pal::windows::process::Builder as ProcessBuilder;
32use pal_async::DefaultPool;
33use pal_async::task::Spawn;
34use pal_async::task::Task;
35use slab::Slab;
36use std::borrow::Cow;
37use std::ffi::OsString;
38use std::fs::File;
39#[cfg(unix)]
40use std::os::unix::prelude::*;
41#[cfg(windows)]
42use std::os::windows::prelude::*;
43use std::path::PathBuf;
44use std::thread;
45use tracing::Instrument;
46use tracing::instrument;
47use unicycle::FuturesUnordered;
48
/// Platform IPC node used by the parent mesh and by its host processes.
#[cfg(windows)]
type IpcNode = mesh_remote::windows::AlpcNode;

/// Platform IPC node used by the parent mesh and by its host processes.
#[cfg(unix)]
type IpcNode = mesh_remote::unix::UnixNode;

/// File descriptor number at which a Unix host process receives its mesh
/// invitation fd. The parent dups the invitation fd to this number when it
/// spawns the child and records the number in `Invitation::socket_fd`.
#[cfg(unix)]
const IPC_FD: i32 = 3;

/// The environment variable for passing the mesh IPC invitation information to
/// a child process. This is passed through the environment instead of a command
/// line argument so that other processes cannot steal the invitation details
/// and use them to break into the mesh. The child removes the variable from its
/// own environment as soon as it has read it.
const INVITATION_ENV_NAME: &str = "MESH_WORKER_INVITATION";

/// Invitation details handed from the parent mesh to a new host process.
///
/// The parent protobuf-encodes this, base64-encodes the bytes, and stores the
/// result in [`INVITATION_ENV_NAME`]; the child reverses those steps.
#[derive(Protobuf)]
struct Invitation {
    /// Node name assigned by the parent (the host's `ProcessConfig` name).
    node_name: String,
    /// Mesh address the new node joins at.
    address: InvitationAddress,
    /// Raw value of the ALPC directory handle, which the parent also passes to
    /// the child through the process builder so it is valid in the child.
    #[cfg(windows)]
    directory_handle: usize,
    /// Descriptor number of the invitation fd in the child; the parent always
    /// sets this to [`IPC_FD`].
    #[cfg(unix)]
    socket_fd: i32,
}

/// Node name of this host process, recorded by [`try_run_mesh_host`].
// NOTE(review): held in a `DebugPtr`, presumably so it can be located from a
// debugger or crash dump — confirm against the `debug_ptr` crate.
static PROCESS_NAME: DebugPtr<String> = DebugPtr::new();
75
76/// Runs a mesh host in the current thread, then exits the process, if this
77/// process was launched by [`Mesh::launch_host`].
78///
79/// The mesh invitation is provided via environment variables. If a mesh
80/// invitation is not available this function will return immediately with `Ok`.
81/// If a mesh invitation is available, this function joins the mesh and runs the
82/// future returned by `f` until `f` returns or the parent process shuts down
83/// the mesh.
84pub fn try_run_mesh_host<U, F, T>(base_name: &str, f: F) -> anyhow::Result<()>
85where
86    U: 'static + MeshPayload + Send,
87    F: AsyncFnOnce(U) -> anyhow::Result<T>,
88{
89    block_on(async {
90        if let Some(r) = node_from_environment().await? {
91            let NodeResult {
92                node_name,
93                node,
94                initial_port,
95            } = r;
96            PROCESS_NAME.store(&node_name);
97            set_program_name(&format!("{base_name}-{node_name}"));
98            let init = OneshotReceiver::<InitialMessage<U>>::from(initial_port)
99                .await
100                .context("failed to receive initial message")?;
101            let _drop = (
102                f(init.init_message).map(Some),
103                handle_host_requests(init.requests).map(|()| None),
104            )
105                .race()
106                .await
107                .transpose()?;
108
109            tracing::debug!("waiting to shut down node");
110            node.shutdown().await;
111            drop(_drop);
112            std::process::exit(0);
113        }
114        Ok(())
115    })
116}
117
118async fn handle_host_requests(mut recv: mesh::Receiver<HostRequest>) {
119    while let Some(req) = recv.next().await {
120        match req {
121            HostRequest::Inspect(deferred) => {
122                deferred.respond(inspect_host);
123            }
124            HostRequest::Crash => panic!("explicit panic request"),
125        }
126    }
127}
128
/// Best-effort update of this process's command name to `name`.
///
/// On Linux the name is written to `/proc/self/comm`, and any error from that
/// write is ignored. On other platforms this does nothing.
fn set_program_name(name: &str) {
    #[cfg(target_os = "linux")]
    let _ = std::fs::write("/proc/self/comm", name);
    #[cfg(not(target_os = "linux"))]
    let _ = name;
}
136
/// Result of joining the mesh with an invitation taken from the environment.
struct NodeResult {
    /// Node name carried in the invitation.
    node_name: String,
    /// The IPC node that joined the mesh.
    node: IpcNode,
    /// Local end of the port pair whose other end was handed to the node when
    /// joining; the host's initial message is received on this port.
    initial_port: mesh::local_node::Port,
}
142
143/// Create an IPC node from an invitation provided via the process environment.
144///
145/// Returns `None` if the invitation is not present in the environment.
146async fn node_from_environment() -> anyhow::Result<Option<NodeResult>> {
147    // return early with no node if the invitation is not present in the environment.
148    let invitation_str = match std::env::var(INVITATION_ENV_NAME) {
149        Ok(str) => str,
150        Err(_) => return Ok(None),
151    };
152
153    // Clear the string to avoid leaking the invitation information into child
154    // processes.
155    //
156    // TODO: this function is unsafe because
157    // it can cause UB if non-Rust code is concurrently accessing the
158    // environment in another thread. To be completely sound,
159    // either this function and its callers need to become
160    // `unsafe`, or we need to avoid using the environment to propagate the
161    // invitation so that we can avoid this call.
162    //
163    // SAFETY: Seems to work so far.
164    unsafe {
165        std::env::remove_var(INVITATION_ENV_NAME);
166    }
167
168    let invitation: Invitation = mesh::payload::decode(
169        &base64::engine::general_purpose::STANDARD
170            .decode(invitation_str)
171            .context("failed to base64 decode invitation")?,
172    )
173    .context("failed to protobuf decode invitation")?;
174
175    let (left, right) = mesh::local_node::Port::new_pair();
176
177    let node;
178    #[cfg(windows)]
179    {
180        // SAFETY: trusting the initiating process to pass a valid handle. A
181        // malicious process could pass a bad handle here, but a malicious
182        // process could also just corrupt our memory arbitrarily, so...
183        let directory =
184            unsafe { OwnedHandle::from_raw_handle(invitation.directory_handle as RawHandle) };
185
186        let invitation = mesh_remote::windows::AlpcInvitation {
187            address: invitation.address,
188            directory,
189        };
190
191        // join the node w/ the provided invitation and the send port of the channel.
192        node = mesh_remote::windows::AlpcNode::join(
193            pal_async::windows::TpPool::system(),
194            invitation,
195            left,
196        )
197        .context("failed to join mesh")?;
198    }
199
200    #[cfg(unix)]
201    {
202        // SAFETY: trusting the initiating process to pass a valid fd. A
203        // malicious process could pass a bad fd here, but a malicious
204        // process could also just corrupt our memory arbitrarily, so...
205        let fd = unsafe { OwnedFd::from_raw_fd(invitation.socket_fd) };
206        let invitation = mesh_remote::unix::UnixInvitation {
207            address: invitation.address,
208            fd,
209        };
210
211        // FUTURE: use pool provided by the caller.
212        let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
213        node = mesh_remote::unix::UnixNode::join(driver, invitation, left)
214            .await
215            .context("failed to join mesh")?;
216    }
217
218    Ok(Some(NodeResult {
219        node_name: invitation.node_name,
220        node,
221        initial_port: right,
222    }))
223}
224
/// Represents a mesh::Node with the ability to spawn new processes that can
/// communicate with any other process belonging to the same mesh.
///
/// Requests made through this handle are forwarded to a background task that
/// runs on its own `mesh` thread, started by [`Mesh::new`].
///
/// # Process creation
/// A `Mesh` instance can spawn new processes with an initial communication
/// channel associated with the mesh. All processes originating from the same
/// mesh can potentially communicate and exchange channels with each other.
///
/// Each spawned process can be configured differently via [`ProcessConfig`].
/// Processes are created with [`Mesh::launch_host`].
///
/// ```no_run
/// # use mesh_process::{Mesh, ProcessConfig};
/// # futures::executor::block_on(async {
/// let mesh = Mesh::new("remote_mesh".to_string()).unwrap();
/// let (send, recv) = mesh::channel();
/// mesh.launch_host(ProcessConfig::new("test"), recv).await.unwrap();
/// send.send(String::from("message for new process"));
/// # })
/// ```
pub struct Mesh {
    /// Name given to [`Mesh::new`]; used for the shutdown tracing span.
    mesh_name: String,
    /// Sends requests to the background mesh task.
    request: mesh::Sender<MeshRequest>,
    /// The background mesh task, awaited by [`Mesh::shutdown`].
    task: Task<()>,
}
250
/// Sandbox profile trait used for mesh hosts.
///
/// A profile is attached to a [`ProcessConfig`] and is consumed in the parent
/// when the host process is launched: it is dropped right after
/// [`SandboxProfile::apply`] returns.
pub trait SandboxProfile: Send {
    /// Apply executes in the parent context and configures any sandbox
    /// features that will be applied to the newly created process via
    /// the pal builder object.
    ///
    /// It is called once per launch. The mesh first finishes its own setup of
    /// `builder` (command line, stdio, environment, inherited handle/fd), and
    /// the process is spawned right after this returns.
    fn apply(&mut self, builder: &mut ProcessBuilder<'_>);

    /// Finalize is intended to execute in the child process context after
    /// application specific initialization is complete. It's optional as not
    /// every sandbox profile will need to perform additional sandboxing.
    /// In addition, the child will need to be aware enough to instantiate its
    /// sandbox profile and invoke this method.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    fn finalize(&mut self) -> anyhow::Result<()> {
        Ok(())
    }
}
267
/// Configuration for launching a new process in the mesh.
pub struct ProcessConfig {
    /// Node name of the host. It is sent to the child in its invitation, is
    /// appended to its command line unless `skip_worker_arg` is set, and is
    /// shown when the mesh is inspected.
    name: String,
    /// Executable to launch. `None` means the current executable, launched
    /// with this process's own `argv[0]`.
    process_name: Option<PathBuf>,
    /// Extra command-line arguments for the child.
    process_args: Vec<OsString>,
    /// File to use as the child's stderr. `None` leaves stderr unconfigured on
    /// the process builder.
    stderr: Option<File>,
    /// When `true`, `name` is not appended to the command line.
    skip_worker_arg: bool,
    /// Optional sandbox profile applied to the process builder before spawn.
    sandbox_profile: Option<Box<dyn SandboxProfile + Sync>>,
}
277
278impl ProcessConfig {
279    /// Returns new process configuration using the current process as the
280    /// process name.
281    pub fn new(name: impl Into<String>) -> Self {
282        Self {
283            name: name.into(),
284            process_name: None,
285            process_args: Vec::new(),
286            stderr: None,
287            skip_worker_arg: false,
288            sandbox_profile: None,
289        }
290    }
291
292    /// Returns a new process configuration using the current process as the
293    /// process name.
294    pub fn new_with_sandbox(
295        name: impl Into<String>,
296        sandbox_profile: Box<dyn SandboxProfile + Sync>,
297    ) -> Self {
298        Self {
299            name: name.into(),
300            process_name: None,
301            process_args: Vec::new(),
302            stderr: None,
303            skip_worker_arg: false,
304            sandbox_profile: Some(sandbox_profile),
305        }
306    }
307
308    /// Sets the process name.
309    pub fn process_name(mut self, name: impl Into<PathBuf>) -> Self {
310        self.process_name = Some(name.into());
311        self
312    }
313
314    /// Specifies whether to  appending `<node name>` to the process's command
315    /// line.
316    ///
317    /// This is done by default to make it easier to identify the process in
318    /// task lists, but if your process parses the command line then this may
319    /// get in the way.
320    pub fn skip_worker_arg(mut self, skip: bool) -> Self {
321        self.skip_worker_arg = skip;
322        self
323    }
324
325    /// Adds arguments to the process command line.
326    pub fn args<I>(mut self, args: I) -> Self
327    where
328        I: IntoIterator,
329        I::Item: Into<OsString>,
330    {
331        self.process_args.extend(args.into_iter().map(|x| x.into()));
332        self
333    }
334
335    /// Sets the process's stderr to `file`.
336    pub fn stderr(mut self, file: Option<File>) -> Self {
337        self.stderr = file;
338        self
339    }
340}
341
/// State owned by the background mesh task (see `MeshInner::run`).
///
/// Fields are dropped in declaration order, so on Windows the job object is
/// closed last.
struct MeshInner {
    /// Receives requests sent through [`Mesh`].
    requests: mesh::Receiver<MeshRequest>,
    /// Live hosts, keyed by slab id; an entry is removed once its wait thread
    /// reports that the child exited.
    hosts: Slab<MeshHostInner>,
    /// Handles for spawned host processes: each resolves with the host's slab
    /// id when its wait thread observes the child exit.
    waiters: FuturesUnordered<OneshotReceiver<usize>>,
    /// Mesh node for host process communication.
    node: IpcNode,
    /// Name for this mesh instance, used for tracing/debugging.
    mesh_name: String,
    /// Job object. When closed, it will terminate all the child processes. This
    /// is used to ensure the child processes don't outlive the parent.
    #[cfg(windows)]
    job: pal::windows::job::Job,
}
356
/// Bookkeeping for one launched host process.
struct MeshHostInner {
    /// Name from the host's `ProcessConfig`.
    name: String,
    /// OS process ID of the child.
    pid: i32,
    /// Mesh node ID assigned to the child in its invitation.
    node_id: mesh::NodeId,
    /// Sends inspect/crash requests to the child's request handler.
    send: mesh::Sender<HostRequest>,
}
363
/// Requests from a [`Mesh`] handle to the background mesh task.
enum MeshRequest {
    /// Launch a new host process; replies with the launch result.
    NewHost(Rpc<NewHostParams, anyhow::Result<()>>),
    /// Inspect the mesh and all of its hosts.
    Inspect(inspect::Deferred),
    /// Crash the process with this pid (which may be the current process).
    Crash(i32),
}
369
/// Parameters for [`MeshRequest::NewHost`].
struct NewHostParams {
    /// How to launch the process.
    config: ProcessConfig,
    /// Port the child's node joins with; the initial message is already
    /// queued on it.
    recv: mesh::local_node::Port,
    /// Sender half of the channel whose receiver is in the initial message.
    request_send: mesh::Sender<HostRequest>,
}
375
376impl Inspect for Mesh {
377    fn inspect(&self, req: inspect::Request<'_>) {
378        self.request.send(MeshRequest::Inspect(req.defer()));
379    }
380}
381
382impl Mesh {
383    /// Creates a new mesh with the given name.
384    pub fn new(mesh_name: String) -> anyhow::Result<Self> {
385        #[cfg(windows)]
386        let job = {
387            let job = pal::windows::job::Job::new().context("failed to create job object")?;
388            job.set_terminate_on_close()
389                .context("failed to set job object terminate on close")?;
390            job
391        };
392
393        #[cfg(windows)]
394        let node = mesh_remote::windows::AlpcNode::new(pal_async::windows::TpPool::system())
395            .context("AlpcNode creation failure")?;
396        #[cfg(unix)]
397        let node = {
398            // FUTURE: use pool provided by the caller.
399            let (_, driver) = DefaultPool::spawn_on_thread("mesh-worker-pool");
400            mesh_remote::unix::UnixNode::new(driver)
401        };
402
403        let (request, requests) = mesh::channel();
404        let mut inner = MeshInner {
405            requests,
406            hosts: Default::default(),
407            waiters: Default::default(),
408            node,
409            mesh_name: mesh_name.clone(),
410            #[cfg(windows)]
411            job,
412        };
413
414        // Spawn a separate thread for launching mesh processes to avoid bad
415        // interactions with any other pools.
416        let (_, driver) = DefaultPool::spawn_on_thread("mesh");
417        let task = driver.spawn(
418            format!("mesh-{}", &mesh_name),
419            async move { inner.run().await },
420        );
421
422        Ok(Self {
423            request,
424            mesh_name,
425            task,
426        })
427    }
428
429    /// Spawns a new host in the mesh with the provided configuration and
430    /// initial message.
431    ///
432    /// The initial message will be provided to the closure passed to
433    /// [`try_run_mesh_host()`].
434    pub async fn launch_host<T: 'static + MeshField + Send>(
435        &self,
436        config: ProcessConfig,
437        initial_message: T,
438    ) -> anyhow::Result<()> {
439        let (request_send, request_recv) = mesh::channel();
440
441        let (init_send, init_recv) = mesh::oneshot::<InitialMessage<T>>();
442        init_send.send(InitialMessage {
443            requests: request_recv,
444            init_message: initial_message,
445        });
446
447        self.request
448            .call(
449                MeshRequest::NewHost,
450                NewHostParams {
451                    config,
452                    recv: init_recv.into(),
453                    request_send,
454                },
455            )
456            .await
457            .context("mesh failed")?
458    }
459
460    /// Shutdown the mesh and wait for any spawned processes to exit.
461    ///
462    /// The `Mesh` instance is no longer usable after `shutdown`.
463    pub async fn shutdown(self) {
464        let span = tracing::span!(
465            tracing::Level::INFO,
466            "mesh_shutdown",
467            name = self.mesh_name.as_str(),
468        );
469
470        async {
471            drop(self.request);
472            self.task.await;
473        }
474        .instrument(span)
475        .await;
476    }
477
478    /// Crashes the child process with the given process ID.
479    pub fn crash(&self, pid: i32) {
480        self.request.send(MeshRequest::Crash(pid));
481    }
482}
483
/// First message a new host receives over its initial port.
#[derive(MeshPayload)]
struct InitialMessage<T> {
    /// Requests (inspect/crash) from the parent mesh to this host.
    requests: mesh::Receiver<HostRequest>,
    /// Caller-provided payload, handed to the closure given to
    /// [`try_run_mesh_host`].
    init_message: T,
}
489
/// Requests sent by the parent mesh to a host process.
#[derive(Debug, MeshPayload)]
enum HostRequest {
    /// Inspect the host; answered with `inspect_host`.
    #[mesh(transparent)]
    Inspect(inspect::Deferred),
    /// Make the host panic on purpose.
    Crash,
}
496
497fn inspect_host(resp: &mut inspect::Response<'_>) {
498    resp.field("tasks", inspect_task::inspect_task_list());
499}
500
/// Per-process inspect data reported under `hosts/<pid>`, both for spawned
/// hosts and for the current (mesh-owning) process.
#[derive(Inspect)]
struct HostInspect<'a> {
    /// Host name, or the mesh name for the current process.
    #[inspect(safe)]
    name: &'a str,
    /// Mesh node ID, reported using its `Debug` formatting.
    #[inspect(debug, safe)]
    node_id: mesh::NodeId,
    /// Resource limits of the process (Linux only).
    #[cfg(target_os = "linux")]
    #[inspect(safe)]
    rlimit: inspect_rlimit::InspectRlimit,
}
511
512impl MeshInner {
513    async fn run(&mut self) {
514        enum Event {
515            Request(MeshRequest),
516            Done(usize),
517        }
518
519        loop {
520            let event = futures::select! { // merge semantics
521                request = self.requests.select_next_some() => Event::Request(request),
522                n = self.waiters.select_next_some() => Event::Done(n.unwrap()),
523                complete => break,
524            };
525
526            match event {
527                Event::Request(request) => match request {
528                    MeshRequest::NewHost(rpc) => {
529                        rpc.handle(async |params| self.spawn_process(params).await)
530                            .await
531                    }
532                    MeshRequest::Inspect(deferred) => {
533                        deferred.respond(|resp| {
534                            resp.sensitivity_child("hosts", SensitivityLevel::Safe, |req| {
535                                let mut resp = req.respond();
536                                for host in self.hosts.iter().map(|(_, host)| host) {
537                                    resp.sensitivity_field_mut(
538                                        &host.pid.to_string(),
539                                        SensitivityLevel::Safe,
540                                        &mut inspect::adhoc(|req| {
541                                            req.respond()
542                                                .merge(&HostInspect {
543                                                    name: &host.name,
544                                                    node_id: host.node_id,
545                                                    #[cfg(target_os = "linux")]
546                                                    rlimit: inspect_rlimit::InspectRlimit::for_pid(
547                                                        host.pid,
548                                                    ),
549                                                })
550                                                .merge(inspect::adhoc(|req| {
551                                                    host.send
552                                                        .send(HostRequest::Inspect(req.defer()));
553                                                }));
554                                        }),
555                                    );
556                                }
557                            })
558                            .sensitivity_field_mut(
559                                &format!("hosts/{}", std::process::id()),
560                                SensitivityLevel::Safe,
561                                &mut inspect::adhoc(|req| {
562                                    let mut resp = req.respond();
563                                    resp.merge(&HostInspect {
564                                        name: &self.mesh_name,
565                                        node_id: self.node.id(),
566                                        #[cfg(target_os = "linux")]
567                                        rlimit: inspect_rlimit::InspectRlimit::new(),
568                                    });
569                                    inspect_host(&mut resp);
570                                }),
571                            );
572                        });
573                    }
574                    MeshRequest::Crash(pid) => {
575                        if pid == std::process::id() as i32 {
576                            panic!("explicit panic request");
577                        }
578
579                        let mut found = false;
580                        for (_, host) in &self.hosts {
581                            if host.pid == pid {
582                                host.send.send(HostRequest::Crash);
583                                found = true;
584                                break;
585                            }
586                        }
587
588                        if !found {
589                            tracing::error!("failed to crash process, pid {pid} not found");
590                        }
591                    }
592                },
593                Event::Done(id) => {
594                    self.hosts.remove(id);
595                }
596            }
597        }
598    }
599
600    /// Spawns a new process with a mesh channel associated with this `Mesh` instance.
601    #[instrument(name = "mesh_spawn_process", skip(self, params), fields(mesh_name = self.mesh_name.as_str(), pid = tracing::field::Empty))]
602    async fn spawn_process(&mut self, params: NewHostParams) -> anyhow::Result<()> {
603        let NewHostParams {
604            config,
605            recv,
606            request_send,
607        } = params;
608
609        let pid;
610        let node_id;
611
612        // If no process name was passed, use the current executable path to
613        // ensure we get the right file, but set arg0 to match how this process
614        // was launched.
615        let (arg0, process_name) = if let Some(n) = &config.process_name {
616            (None, Cow::Borrowed(n))
617        } else {
618            (
619                std::env::args_os().next(),
620                Cow::Owned(std::env::current_exe().context("failed to get current exe path")?),
621            )
622        };
623
624        let name = config.name.clone();
625
626        #[cfg(windows)]
627        let wait = {
628            let (invitation, handle) = self.node.invite(recv).context("mesh node invite error")?;
629            node_id = invitation.address.local_addr.node;
630
631            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
632                mesh::payload::encode(Invitation {
633                    node_name: name.clone(),
634                    address: invitation.address,
635                    directory_handle: invitation.directory.as_raw_handle() as usize,
636                }),
637            );
638
639            let mut args = config.process_args;
640            if !config.skip_worker_arg {
641                args.push(name.clone().into());
642            }
643
644            let mut builder = process::Builder::from_args(
645                arg0.as_ref()
646                    .map_or_else(|| process_name.as_os_str(), |x| x.as_os_str()),
647                &args,
648            );
649            if arg0.is_some() {
650                builder.application_name(process_name.as_path());
651            }
652            builder
653                .stdin(process::Stdio::Null)
654                .stdout(process::Stdio::Null)
655                .handle(&invitation.directory)
656                .env(INVITATION_ENV_NAME, invitation_env)
657                .job(self.job.as_handle());
658
659            if let Some(log_file) = config.stderr.as_ref() {
660                builder.stderr(process::Stdio::Handle(log_file.as_handle()));
661            }
662
663            if let Some(mut sandbox_profile) = config.sandbox_profile {
664                sandbox_profile.apply(&mut builder);
665            }
666
667            let child = builder.spawn().context("failed to launch mesh process")?;
668            // Wait for the child to connect to the mesh. TODO: timeout
669            handle.await;
670            pid = child.id() as i32;
671            tracing::Span::current().record("pid", pid);
672            move || {
673                child.wait();
674                let code = child.exit_code();
675                if code == 0 {
676                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
677                } else {
678                    tracing::error!(pid, name = name.as_str(), code, "mesh child abnormal exit");
679                }
680            }
681        };
682        #[cfg(unix)]
683        let mut wait = {
684            use pal::unix::process;
685
686            let invitation = self
687                .node
688                .invite(recv)
689                .await
690                .context("mesh node invite error")?;
691
692            node_id = invitation.address.local_addr.node;
693
694            let invitation_env = base64::engine::general_purpose::STANDARD.encode(
695                mesh::payload::encode(Invitation {
696                    node_name: name.clone(),
697                    address: invitation.address,
698                    socket_fd: IPC_FD,
699                }),
700            );
701
702            let mut command = process::Builder::new(process_name.into_owned());
703            if let Some(arg0) = arg0 {
704                command.arg0(arg0);
705            }
706            command
707                .args(&config.process_args)
708                .stdin(process::Stdio::Null)
709                .stdout(process::Stdio::Null)
710                .dup_fd(invitation.fd.as_fd(), IPC_FD)
711                .env(INVITATION_ENV_NAME, invitation_env);
712
713            if !config.skip_worker_arg {
714                command.arg(&name);
715            }
716
717            if let Some(log_file) = config.stderr.as_ref() {
718                command.stderr(process::Stdio::Fd(log_file.as_fd()));
719            }
720
721            if let Some(mut sandbox_profile) = config.sandbox_profile {
722                sandbox_profile.apply(&mut command);
723            }
724
725            let mut child = command.spawn().context("failed to launch mesh process")?;
726            pid = child.id();
727            tracing::Span::current().record("pid", pid);
728            move || {
729                let exit_status = child.wait().expect("mesh child wait failure");
730                if let Some(0) = exit_status.code() {
731                    tracing::info!(pid, name = name.as_str(), "mesh child exited successfully");
732                } else {
733                    tracing::error!(
734                        pid,
735                        name = name.as_str(),
736                        %exit_status,
737                        "mesh child abnormal exit"
738                    );
739                }
740            }
741        };
742
743        let (wait_send, wait_recv) = mesh::oneshot();
744
745        let id = self.hosts.insert(MeshHostInner {
746            name: config.name,
747            pid,
748            node_id,
749            send: request_send,
750        });
751
752        thread::Builder::new()
753            .name(format!("wait-mesh-child-{}", pid))
754            .spawn(move || {
755                wait();
756                wait_send.send(id);
757            })
758            .unwrap();
759
760        self.waiters.push(wait_recv);
761        Ok(())
762    }
763}