// diag_client/lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! The client for connecting to the Underhill diagnostics server.
5
6pub mod kmsg_stream;
7
8use anyhow::Context;
9use diag_proto::ExecRequest;
10use diag_proto::WaitRequest;
11use diag_proto::WaitResponse;
12use diag_proto::network_packet_capture_request::OpData;
13use diag_proto::network_packet_capture_request::Operation;
14use futures::AsyncReadExt;
15use futures::AsyncWrite;
16use futures::AsyncWriteExt;
17use inspect::Node;
18use inspect::ValueKind;
19use kmsg_stream::KmsgStream;
20use mesh_rpc::service::Status;
21use pal_async::driver::Driver;
22use pal_async::socket::PolledSocket;
23use pal_async::task::Spawn;
24use std::io::ErrorKind;
25use std::path::Path;
26use std::path::PathBuf;
27use std::time::Duration;
28use thiserror::Error;
29
30#[cfg(windows)]
31/// Functions for Hyper-V
32pub mod hyperv {
33    use super::ConnectError;
34    use anyhow::Context;
35    use guid::Guid;
36    use pal_async::driver::Driver;
37    use pal_async::socket::PolledSocket;
38    use pal_async::timer::PolledTimer;
39    use std::fs::File;
40    use std::io::Write;
41    use std::process::Command;
42    use std::time::Duration;
43    use vmsocket::VmAddress;
44    use vmsocket::VmSocket;
45    use vmsocket::VmStream;
46
    /// Defines how to access the serial port of a Hyper-V VM.
    pub enum ComPortAccessInfo<'a> {
        /// Access by VM name and COM port number.
        NameAndPortNumber(&'a str, u32),
        /// Access through the path of a named pipe.
        PortPipePath(&'a str),
    }
54
55    /// Get ID from name
56    pub fn vm_id_from_name(name: &str) -> anyhow::Result<Guid> {
57        let output = Command::new("hvc.exe")
58            .arg("id")
59            .arg(name)
60            .output()
61            .context("failed to launch hvc")?;
62
63        if output.status.success() {
64            let stdout = std::str::from_utf8(&output.stdout)
65                .context("failed to parse hvc output")?
66                .trim();
67            Ok(stdout
68                .parse()
69                .with_context(|| format!("failed to parse VM ID '{}'", &stdout))?)
70        } else {
71            anyhow::bail!(
72                "{}",
73                std::str::from_utf8(&output.stderr).context("failed to parse hvc error output")?
74            )
75        }
76    }
77
78    /// Connect to Hyper-V socket
79    pub async fn connect_vsock(
80        driver: &(impl Driver + ?Sized),
81        vm_id: Guid,
82        port: u32,
83    ) -> Result<VmStream, ConnectError> {
84        let socket = VmSocket::new()
85            .context("failed to create AF_HYPERV socket")
86            .map_err(ConnectError::other)?;
87
88        socket
89            .set_connect_timeout(Duration::from_secs(1))
90            .context("failed to set connect timeout")
91            .map_err(ConnectError::other)?;
92
93        socket
94            .set_high_vtl(true)
95            .context("failed to set socket for VTL2")
96            .map_err(ConnectError::other)?;
97
98        let mut socket: PolledSocket<socket2::Socket> = PolledSocket::new(driver, socket.into())
99            .context("failed to create polled socket")
100            .map_err(ConnectError::other)?;
101
102        socket
103            .connect(&VmAddress::hyperv_vsock(vm_id, port).into())
104            .await
105            .map_err(ConnectError::connect)?;
106
107        Ok(socket.convert().into_inner())
108    }
109
    /// Opens a serial port on a Hyper-V VM.
    ///
    /// If the VM is not running, it will periodically try to connect to the
    /// pipe until the VM starts running. In theory, we could instead create a
    /// named pipe server, which Hyper-V would connect to when the VM starts.
    /// However, in this mode, once the named pipe is disconnected, Hyper-V
    /// stops trying to reconnect until the VM is powered off and powered on
    /// again, so don't do that.
    pub async fn open_serial_port(
        driver: &(impl Driver + ?Sized),
        port: ComPortAccessInfo<'_>,
    ) -> anyhow::Result<File> {
        // Resolve the named pipe path: either it was given directly, or query
        // PowerShell for the pipe configured on the VM's COM port.
        let path = match port {
            ComPortAccessInfo::NameAndPortNumber(vm, num) => {
                let output = Command::new("powershell.exe")
                    .arg("-NoProfile")
                    .arg(format!(
                        r#"$x = Get-VMComPort "{vm}" -Number {num} -ErrorAction Stop; $x.Path"#,
                    ))
                    .output()
                    .context("failed to query VM com port")?;

                if !output.status.success() {
                    // Forward PowerShell's stderr to ours to aid diagnosis;
                    // ignore write failures since we're about to bail anyway.
                    let _ = std::io::stderr().write_all(&output.stderr);
                    anyhow::bail!(
                        "failed to query VM com port: exit status {}",
                        output.status.code().unwrap()
                    );
                }
                // Temporary lifetime extension keeps this String alive for
                // the remainder of the function.
                &String::from_utf8(output.stdout)?
            }
            ComPortAccessInfo::PortPipePath(path) => path,
        };

        let path = path.trim();
        if path.is_empty() {
            anyhow::bail!("Requested VM COM port is not configured");
        }

        // Poll until the pipe exists; it is absent while the VM is off.
        let mut timer = None;
        let pipe = loop {
            match fs_err::OpenOptions::new().read(true).write(true).open(path) {
                Ok(pipe) => break pipe.into(),
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                    // The VM is not running. Wait a bit and try again.
                    timer
                        .get_or_insert_with(|| PolledTimer::new(driver))
                        .sleep(Duration::from_millis(100))
                        .await;
                }
                Err(err) => Err(err)?,
            }
        };

        Ok(pipe)
    }
166}
167
168/// Connect to a vsock with port and path
169pub async fn connect_hybrid_vsock(
170    driver: &(impl Driver + ?Sized),
171    path: &Path,
172    port: u32,
173) -> Result<PolledSocket<socket2::Socket>, ConnectError> {
174    let socket = unix_socket::UnixStream::connect(path).map_err(ConnectError::connect)?;
175    let mut socket = PolledSocket::new(driver, socket).map_err(ConnectError::other)?;
176    socket
177        .write_all(format!("CONNECT {port}\n").as_bytes())
178        .await
179        .map_err(ConnectError::other)?;
180
181    let mut ok = [0; 3];
182    socket
183        .read_exact(&mut ok)
184        .await
185        .map_err(ConnectError::other)?;
186    if &ok != b"OK " {
187        // FUTURE: consider returning an error that can be retried. This may
188        // require some changes to the hybrid vsock protocol, unclear.
189        return Err(ConnectError::other(anyhow::anyhow!(
190            "missing hybrid vsock response"
191        )));
192    }
193
194    for _ in 0.."4294967295\n".len() {
195        let mut b = [0];
196        socket
197            .read_exact(&mut b)
198            .await
199            .map_err(ConnectError::other)?;
200        if b[0] == b'\n' {
201            // Don't need to parse the host port number.
202            return Ok(socket.convert());
203        }
204    }
205    Err(ConnectError::other(anyhow::anyhow!(
206        "invalid hybrid vsock response"
207    )))
208}
209
// The transport over which a single data connection is established.
enum SocketType<'a> {
    // Hyper-V socket addressed by VM ID (Windows only).
    #[cfg(windows)]
    VmId {
        vm_id: guid::Guid,
        port: u32,
    },
    // Hybrid vsock via a Unix socket at `path`.
    HybridVsock {
        path: &'a Path,
        port: u32,
    },
}
221
222async fn new_data_connection(
223    driver: &(impl Driver + ?Sized),
224    typ: SocketType<'_>,
225) -> anyhow::Result<(u64, PolledSocket<socket2::Socket>)> {
226    let mut socket = match typ {
227        #[cfg(windows)]
228        SocketType::VmId { vm_id, port } => {
229            let socket = hyperv::connect_vsock(driver, vm_id, port).await?;
230            PolledSocket::new(driver, socket2::Socket::from(socket))?
231        }
232        SocketType::HybridVsock { path, port } => connect_hybrid_vsock(driver, path, port).await?,
233    };
234
235    // Read the 8 byte connection id which is always sent first on the connection.
236    let mut id = [0; 8];
237    socket
238        .read_exact(&mut id)
239        .await
240        .context("reading connection id")?;
241    let id = u64::from_ne_bytes(id);
242    Ok((id, socket))
243}
244
/// Represents different VM types.
#[derive(Clone)]
enum VmType {
    /// A Hyper-V VM represented by a VM ID GUID, which uses a VmSocket to connect.
    #[cfg(windows)]
    HyperV(guid::Guid),
    /// A VM which uses hybrid vsock over Unix sockets.
    HybridVsock(PathBuf),
    /// A VM that cannot be used for data connections (created via
    /// [`DiagClient::from_dialer`]).
    None,
}
256
/// The diagnostics client.
pub struct DiagClient {
    // Transport used for opening additional data connections.
    vm: VmType,
    // ttrpc client over the control connection.
    ttrpc: mesh_rpc::Client,
    driver: Box<dyn Driver>,
}
263
/// Defines packet capture operations.
//
// `Debug`/`Clone`/`Copy`/`Eq` are derived in addition to `PartialEq`: this is
// a fieldless public enum, so the extra derives are free and backward
// compatible, and public types should be `Debug`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketCaptureOperation {
    /// Query details.
    Query,
    /// Start packet capture.
    Start,
    /// Stop packet capture.
    Stop,
}
274
/// An error connecting to the diagnostics server.
#[derive(Debug, Error)]
#[error("failed to connect")]
pub struct ConnectError {
    // Underlying cause, surfaced via the `source` chain.
    #[source]
    err: anyhow::Error,
    // Classification used by `retry_timeout` to decide retry policy.
    kind: ConnectErrorKind,
}

// How a connection failure is classified for retry purposes.
#[derive(Debug)]
enum ConnectErrorKind {
    // Not retryable.
    Other,
    // The VM isn't running yet; retry after a delay.
    VmNotStarted,
    // The connect timed out; retry immediately (the socket layer already
    // provides the delay).
    ServerTimedOut,
}
290
291impl ConnectError {
292    /// Returns the time to wait before retrying the connection. If `None`, the
293    /// connection should not be retried.
294    pub fn retry_timeout(&self) -> Option<Duration> {
295        match self.kind {
296            ConnectErrorKind::VmNotStarted => Some(Duration::from_secs(1)),
297            ConnectErrorKind::ServerTimedOut => {
298                // The socket infrastructure has an internal timeout.
299                Some(Duration::ZERO)
300            }
301            _ => None,
302        }
303    }
304
305    fn other(err: impl Into<anyhow::Error>) -> Self {
306        Self {
307            err: err.into(),
308            kind: ConnectErrorKind::Other,
309        }
310    }
311
312    fn connect(err: std::io::Error) -> Self {
313        let kind = match err.kind() {
314            ErrorKind::AddrNotAvailable => ConnectErrorKind::VmNotStarted,
315            ErrorKind::TimedOut => ConnectErrorKind::ServerTimedOut,
316            _ => match err.raw_os_error() {
317                #[cfg(windows)]
318                Some(windows_sys::Win32::Networking::WinSock::WSAENETUNREACH) => {
319                    ConnectErrorKind::VmNotStarted
320                }
321                _ => ConnectErrorKind::Other,
322            },
323        };
324        Self {
325            err: anyhow::Error::from(err).context("failed to connect"),
326            kind,
327        }
328    }
329}
330
// Connection factory handed to the ttrpc client; dials the diagnostics
// control port over the VM's transport.
struct VmConnector {
    vm: VmType,
    driver: Box<dyn Driver>,
}
335
impl mesh_rpc::client::Dial for VmConnector {
    type Stream = PolledSocket<socket2::Socket>;

    // Dials the VTL2 control port over the VM's transport, mapping transport
    // errors into `std::io::Error` as the trait requires.
    async fn dial(&mut self) -> std::io::Result<Self::Stream> {
        match &self.vm {
            #[cfg(windows)]
            VmType::HyperV(guid) => {
                let socket = hyperv::connect_vsock(
                    self.driver.as_ref(),
                    *guid,
                    diag_proto::VSOCK_CONTROL_PORT,
                )
                .await
                .map_err(|err| std::io::Error::new(ErrorKind::Other, err))?;
                Ok(PolledSocket::new(&self.driver, socket.into())?)
            }
            VmType::HybridVsock(path) => {
                let socket = connect_hybrid_vsock(
                    self.driver.as_ref(),
                    path,
                    diag_proto::VSOCK_CONTROL_PORT,
                )
                .await
                .map_err(|err| std::io::Error::new(ErrorKind::Other, err))?;
                Ok(socket)
            }
            // Clients built via `DiagClient::from_dialer` never install this
            // connector, so this arm is unreachable.
            VmType::None => unreachable!(),
        }
    }
}
366
367impl DiagClient {
368    /// Creates a client from Hyper-V VM name.
369    #[cfg(windows)]
370    pub fn from_hyperv_name(
371        driver: impl Driver + Spawn + Clone,
372        name: &str,
373    ) -> anyhow::Result<Self> {
374        Ok(Self::from_hyperv_id(
375            driver,
376            hyperv::vm_id_from_name(name).map_err(ConnectError::other)?,
377        ))
378    }
379
380    /// Creates a client from a Hyper-V or HCS VM ID.
381    #[cfg(windows)]
382    pub fn from_hyperv_id(driver: impl Driver + Spawn + Clone, vm_id: guid::Guid) -> Self {
383        let vm = VmType::HyperV(vm_id);
384        Self::new(
385            driver.clone(),
386            vm.clone(),
387            VmConnector {
388                vm,
389                driver: Box::new(driver),
390            },
391        )
392    }
393
394    /// Creates a client from a hybrid vsock Unix socket path.
395    pub fn from_hybrid_vsock(driver: impl Driver + Spawn + Clone, path: &Path) -> Self {
396        let vm = VmType::HybridVsock(path.into());
397        Self::new(
398            driver.clone(),
399            vm.clone(),
400            VmConnector {
401                vm,
402                driver: Box::new(driver.clone()),
403            },
404        )
405    }
406
    /// Creates a client from a dialer.
    ///
    /// This client won't be usable with operations that require additional
    /// connections (e.g. [`DiagClient::connect_data`]), since the VM type is
    /// [`VmType::None`].
    pub fn from_dialer(driver: impl Driver + Spawn, conn: impl mesh_rpc::client::Dial) -> Self {
        Self::new(driver, VmType::None, conn)
    }
413
414    fn new(driver: impl Driver + Spawn, vm: VmType, conn: impl mesh_rpc::client::Dial) -> Self {
415        Self {
416            vm,
417            ttrpc: mesh_rpc::client::ClientBuilder::new()
418                // Use a short reconnect timeout (compared to the normal 20
419                // seconds) since the VM may start at any time.
420                .retry_timeout(Duration::from_secs(1))
421                .build(&driver, conn),
422            driver: Box::new(driver),
423        }
424    }
425
426    /// Waits for the paravisor to be ready for RPCs.
427    pub async fn wait_for_server(&self) -> anyhow::Result<()> {
428        match self
429            .ttrpc
430            .call()
431            .wait_ready(true)
432            .start(diag_proto::OpenhclDiag::Ping, ())
433            .await
434        {
435            Ok(()) => {}
436            Err(Status { code, .. }) if code == mesh_rpc::service::Code::Unimplemented as i32 => {
437                // Older versions of the diag server don't support the ping
438                // RPC, but an unimplemented failure is good enough to know
439                // the server is ready.
440            }
441            Err(status) => return Err(grpc_status(status)),
442        }
443        Ok(())
444    }
445
446    /// Creates a builder for execing a command.
447    pub fn exec(&self, command: impl AsRef<str>) -> ExecBuilder<'_> {
448        ExecBuilder {
449            client: self,
450            with_stdin: false,
451            with_stdout: false,
452            with_stderr: false,
453            request: ExecRequest {
454                command: command.as_ref().to_owned(),
455                ..Default::default()
456            },
457        }
458    }
459
460    /// Creates a new data connection socket.
461    ///
462    /// This can be used with [`DiagClient::custom_call`].
463    pub async fn connect_data(&self) -> anyhow::Result<(u64, PolledSocket<socket2::Socket>)> {
464        let socket_type = match &self.vm {
465            #[cfg(windows)]
466            VmType::HyperV(guid) => SocketType::VmId {
467                vm_id: *guid,
468                port: diag_proto::VSOCK_DATA_PORT,
469            },
470            VmType::HybridVsock(path) => SocketType::HybridVsock {
471                path,
472                port: diag_proto::VSOCK_DATA_PORT,
473            },
474            VmType::None => {
475                anyhow::bail!("cannot make additional connections with this client")
476            }
477        };
478        new_data_connection(self.driver.as_ref(), socket_type).await
479    }
480
481    /// Sends an inspection request to the server.
482    pub async fn inspect(
483        &self,
484        path: impl Into<String>,
485        depth: Option<usize>,
486        timeout: Option<Duration>,
487    ) -> anyhow::Result<Node> {
488        let response = self.ttrpc.call().timeout(timeout).start(
489            inspect_proto::InspectService::Inspect,
490            inspect_proto::InspectRequest {
491                path: path.into(),
492                // It would be better to pass an Option<u32> in the proto, but that would break backcompat.
493                depth: depth.unwrap_or(u32::MAX as usize) as u32,
494            },
495        );
496
497        let response = response.await.map_err(grpc_status)?;
498        Ok(response.result)
499    }
500
501    /// Updates an inspectable value.
502    pub async fn update(
503        &self,
504        path: impl Into<String>,
505        value: impl Into<String>,
506    ) -> anyhow::Result<inspect::Value> {
507        let response = self.ttrpc.call().start(
508            inspect_proto::InspectService::Update,
509            inspect_proto::UpdateRequest {
510                path: path.into(),
511                value: value.into(),
512            },
513        );
514
515        let response = response.await.map_err(grpc_status)?;
516
517        Ok(response.new_value)
518    }
519
520    /// Get PID of a given process
521    pub async fn get_pid(&self, name: &str) -> anyhow::Result<i32> {
522        let hosts = self.inspect("mesh/hosts", Some(1), None).await?;
523        let mut plist = Vec::new();
524
525        let Node::Dir(processes) = hosts else {
526            anyhow::bail!("Hosts node is not a dir");
527        };
528        for process in processes {
529            let Node::Dir(pnode) = process.node else {
530                anyhow::bail!("Process node is not a dir");
531            };
532            for entry in pnode {
533                if entry.name == "name" {
534                    let Node::Value(value) = entry.node else {
535                        anyhow::bail!("Name node is not a value");
536                    };
537                    let ValueKind::String(strval) = value.kind else {
538                        anyhow::bail!("Name node is not a string");
539                    };
540                    if strval == name {
541                        return Ok(process.name.parse()?);
542                    }
543                    plist.push(strval);
544                }
545            }
546        }
547
548        anyhow::bail!("PID of {name} not found. Processes: {:?}", plist)
549    }
550
551    /// Starts the VM.
552    pub async fn start(
553        &self,
554        env: impl IntoIterator<Item = (String, Option<String>)>,
555        args: impl IntoIterator<Item = String>,
556    ) -> anyhow::Result<()> {
557        let request = diag_proto::StartRequest {
558            env: env
559                .into_iter()
560                .map(|(name, value)| diag_proto::EnvPair { name, value })
561                .collect(),
562            args: args.into_iter().collect(),
563        };
564        self.ttrpc
565            .call()
566            .start(diag_proto::UnderhillDiag::Start, request)
567            .await
568            .map_err(grpc_status)?;
569
570        Ok(())
571    }
572
573    /// Gets the contents of /dev/kmsg
574    pub async fn kmsg(&self, follow: bool) -> anyhow::Result<KmsgStream> {
575        let (conn, socket) = self.connect_data().await?;
576
577        self.ttrpc
578            .call()
579            .start(
580                diag_proto::UnderhillDiag::Kmsg,
581                diag_proto::KmsgRequest { follow, conn },
582            )
583            .await
584            .map_err(grpc_status)?;
585
586        Ok(KmsgStream::new(socket))
587    }
588
589    /// Gets the contents of the file
590    pub async fn read_file(
591        &self,
592        follow: bool,
593        file_path: String,
594    ) -> anyhow::Result<PolledSocket<socket2::Socket>> {
595        let (conn, socket) = self.connect_data().await?;
596
597        self.ttrpc
598            .call()
599            .start(
600                diag_proto::UnderhillDiag::ReadFile,
601                diag_proto::FileRequest {
602                    follow,
603                    conn,
604                    file_path,
605                },
606            )
607            .await
608            .map_err(grpc_status)?;
609
610        Ok(socket)
611    }
612
    /// Issues a call to the server using a custom RPC.
    ///
    /// This can be used to support extension RPCs that are not part of the main
    /// diagnostics service. Pair with [`DiagClient::connect_data`] for RPCs
    /// that take a data connection.
    pub fn custom_call(&self) -> mesh_rpc::client::CallBuilder<'_> {
        self.ttrpc.call()
    }
620
621    /// Crashes the VM.
622    pub async fn crash(&self, pid: i32) -> anyhow::Result<()> {
623        self.ttrpc
624            .call()
625            .start(
626                diag_proto::UnderhillDiag::Crash,
627                diag_proto::CrashRequest { pid },
628            )
629            .await
630            .map_err(grpc_status)?;
631
632        Ok(())
633    }
634
635    /// Sets up network packet capture trace.
636    pub async fn packet_capture(
637        &self,
638        op: PacketCaptureOperation,
639        num_streams: u32,
640        snaplen: u16,
641    ) -> anyhow::Result<(Vec<PolledSocket<socket2::Socket>>, u32)> {
642        let mut sockets = Vec::new();
643        let op_data = match op {
644            PacketCaptureOperation::Start => {
645                let mut conns = Vec::new();
646                for _ in 0..num_streams {
647                    let (conn, socket) = self.connect_data().await?;
648                    conns.push(conn);
649                    sockets.push(socket);
650                }
651                Some(OpData::StartData(diag_proto::StartPacketCaptureData {
652                    snaplen: snaplen.into(),
653                    conns,
654                }))
655            }
656            _ => None,
657        };
658
659        let operation = match op {
660            PacketCaptureOperation::Query => Operation::Query,
661            PacketCaptureOperation::Start => Operation::Start,
662            PacketCaptureOperation::Stop => Operation::Stop,
663        };
664
665        let response = self
666            .ttrpc
667            .call()
668            .start(
669                diag_proto::UnderhillDiag::PacketCapture,
670                diag_proto::NetworkPacketCaptureRequest {
671                    operation: operation.into(),
672                    op_data,
673                },
674            )
675            .await
676            .map_err(grpc_status)?;
677
678        Ok((sockets, response.num_streams))
679    }
680
681    /// Saves a core dump file being streamed from Underhill
682    pub async fn core_dump(
683        &self,
684        pid: i32,
685        mut writer: impl AsyncWrite + Unpin,
686        mut stderr: impl AsyncWrite + Unpin,
687        verbose: bool,
688    ) -> anyhow::Result<()> {
689        // Launch hcl-dump to dump the target process. Use raw_socket_io so that
690        // the diagnostics process does not have to be running during the core
691        // dump process; this ensures that we can dump the diagnostics process,
692        // too.
693        let mut process = self.exec("/bin/underhill-dump");
694        if verbose {
695            process.args(["-v"]);
696        }
697        let mut process = process
698            .args([pid.to_string()])
699            .stdin(false)
700            .stdout(true)
701            .stderr(true)
702            .raw_socket_io(true)
703            .spawn()
704            .await
705            .context("failed to launch underhill-dump")?;
706
707        let process_stdout = PolledSocket::new(&self.driver, process.stdout.take().unwrap())?;
708        let process_stderr = PolledSocket::new(&self.driver, process.stderr.take().unwrap())?;
709
710        let out = futures::io::copy(process_stdout, &mut writer);
711        let err = futures::io::copy(process_stderr, &mut stderr);
712
713        futures::try_join!(out, err)?;
714
715        let status = process
716            .wait()
717            .await
718            .context("failed to wait for underhill-dump")?;
719
720        if !status.success() {
721            anyhow::bail!(
722                "underhill-dump failed with exit code {}",
723                status.exit_code()
724            );
725        }
726        Ok(())
727    }
728
729    /// Restarts the Underhill worker.
730    pub async fn restart(&self) -> anyhow::Result<()> {
731        self.ttrpc
732            .call()
733            .start(diag_proto::UnderhillDiag::Restart, ())
734            .await
735            .map_err(grpc_status)?;
736
737        Ok(())
738    }
739
740    /// Pause the VM (including all devices).
741    pub async fn pause(&self) -> anyhow::Result<()> {
742        self.ttrpc
743            .call()
744            .start(diag_proto::UnderhillDiag::Pause, ())
745            .await
746            .map_err(grpc_status)?;
747
748        Ok(())
749    }
750
751    /// Resume the VM.
752    pub async fn resume(&self) -> anyhow::Result<()> {
753        self.ttrpc
754            .call()
755            .start(diag_proto::UnderhillDiag::Resume, ())
756            .await
757            .map_err(grpc_status)?;
758
759        Ok(())
760    }
761
762    /// Dumps the VM's VTL2 saved state.
763    pub async fn dump_saved_state(&self) -> anyhow::Result<Vec<u8>> {
764        let state = self
765            .ttrpc
766            .call()
767            .start(diag_proto::UnderhillDiag::DumpSavedState, ())
768            .await
769            .map_err(grpc_status)?;
770
771        Ok(state.data)
772    }
773}
774
775fn grpc_status(status: Status) -> anyhow::Error {
776    anyhow::anyhow!(status.message)
777}
778
/// A builder for launching a command in VTL2.
pub struct ExecBuilder<'a> {
    client: &'a DiagClient,
    // Whether to open a data connection for each stdio stream at spawn time.
    with_stdin: bool,
    with_stdout: bool,
    with_stderr: bool,
    // The accumulated exec request sent to the server on `spawn`.
    request: ExecRequest,
}
787
788impl ExecBuilder<'_> {
789    /// Adds `args` to the argument list.
790    pub fn args<T: AsRef<str>>(&mut self, args: impl IntoIterator<Item = T>) -> &mut Self {
791        self.request
792            .args
793            .extend(args.into_iter().map(|s| s.as_ref().to_owned()));
794        self
795    }
796
797    /// Sets whether the process is spawned with a TTY.
798    pub fn tty(&mut self, tty: bool) -> &mut Self {
799        self.request.tty = tty;
800        self
801    }
802
803    /// Specifies whether a stdin socket should be opened.
804    pub fn stdin(&mut self, stdin: bool) -> &mut Self {
805        self.with_stdin = stdin;
806        self
807    }
808
809    /// Specifies whether a stdout socket should be opened.
810    pub fn stdout(&mut self, stdout: bool) -> &mut Self {
811        self.with_stdout = stdout;
812        self
813    }
814
815    /// Specifies whether a stderr socket should be opened.
816    pub fn stderr(&mut self, stderr: bool) -> &mut Self {
817        self.with_stderr = stderr;
818        self
819    }
820
821    /// Specifies whether the processes's stdout and stderr should be combined
822    /// into a single stream (the stdout socket).
823    pub fn combine_stderr(&mut self, combine_stderr: bool) -> &mut Self {
824        self.request.combine_stderr = combine_stderr;
825        self
826    }
827
828    /// Specifies whether the vsock sockets used for stdio should be passed
829    /// directly to the launched process instead of going through relays.
830    pub fn raw_socket_io(&mut self, raw_socket_io: bool) -> &mut Self {
831        self.request.raw_socket_io = raw_socket_io;
832        self
833    }
834
835    /// Clears the default environment.
836    pub fn env_clear(&mut self) -> &mut Self {
837        self.request.clear_env = true;
838        self
839    }
840
841    /// Removes an environment variable.
842    pub fn env_remove(&mut self, name: impl AsRef<str>) -> &mut Self {
843        self.request.env.push(diag_proto::EnvPair {
844            name: name.as_ref().to_owned(),
845            value: None,
846        });
847        self
848    }
849
850    /// Sets an environment variable.
851    pub fn env(&mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> &mut Self {
852        self.request.env.push(diag_proto::EnvPair {
853            name: name.as_ref().to_owned(),
854            value: Some(value.as_ref().to_owned()),
855        });
856        self
857    }
858
859    /// Spawns the process.
860    pub async fn spawn(&self) -> anyhow::Result<Process> {
861        let mut request = self.request.clone();
862
863        let stdin = if self.with_stdin {
864            let (id, stdin) = self
865                .client
866                .connect_data()
867                .await
868                .context("failed to connect stdin")?;
869            request.stdin = id;
870
871            Some(stdin.into_inner())
872        } else {
873            None
874        };
875
876        let stdout = if self.with_stdout {
877            let (id, stdout) = self
878                .client
879                .connect_data()
880                .await
881                .context("failed to connect stdout")?;
882            request.stdout = id;
883
884            Some(stdout.into_inner())
885        } else {
886            None
887        };
888
889        let stderr = if self.with_stdout {
890            let (id, stderr) = self
891                .client
892                .connect_data()
893                .await
894                .context("failed to connect stderr")?;
895            request.stderr = id;
896
897            Some(stderr.into_inner())
898        } else {
899            None
900        };
901
902        let response = self
903            .client
904            .ttrpc
905            .call()
906            .start(diag_proto::UnderhillDiag::Exec, request)
907            .await
908            .map_err(grpc_status)?;
909
910        let wait = self.client.ttrpc.call().start(
911            diag_proto::UnderhillDiag::Wait,
912            WaitRequest { pid: response.pid },
913        );
914
915        Ok(Process {
916            stdin,
917            stdout,
918            stderr,
919            wait,
920            pid: response.pid,
921        })
922    }
923}
924
/// A process running in VTL2.
#[derive(Debug)]
pub struct Process {
    /// The standard input stream.
    pub stdin: Option<socket2::Socket>,
    /// The standard output stream.
    pub stdout: Option<socket2::Socket>,
    /// The standard error stream.
    pub stderr: Option<socket2::Socket>,
    // PID reported by the exec response.
    pid: i32,
    // In-flight wait RPC, consumed by `Process::wait`.
    wait: mesh_rpc::client::Call<WaitResponse>,
}
937
938impl Process {
939    /// Returns the process ID.
940    pub fn id(&self) -> i32 {
941        self.pid
942    }
943
944    /// Waits for the process to exit.
945    pub async fn wait(self) -> anyhow::Result<ExitStatus> {
946        let response = self
947            .wait
948            .await
949            .map_err(|err| anyhow::anyhow!("{}", err.message))?;
950
951        Ok(ExitStatus { response })
952    }
953}
954
/// Process exit status.
#[derive(Debug)]
pub struct ExitStatus {
    // Raw wait response received from the diag server.
    response: WaitResponse,
}
960
961impl ExitStatus {
962    /// The exit code.
963    pub fn exit_code(&self) -> i32 {
964        self.response.exit_code
965    }
966
967    /// Whether the process successfully terminated.
968    pub fn success(&self) -> bool {
969        self.response.exit_code == 0
970    }
971}