ohcldiag_dev/
main.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A "move fast, break things" tool, that provides no long-term CLI stability
5//! guarantees.
6
7#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10mod completions;
11
12use anyhow::Context;
13use clap::ArgGroup;
14use clap::Args;
15use clap::Parser;
16use clap::Subcommand;
17use diag_client::DiagClient;
18use diag_client::PacketCaptureOperation;
19use futures::StreamExt;
20use futures::io::AllowStdIo;
21use futures_concurrency::future::Race;
22use pal_async::DefaultPool;
23use pal_async::driver::Driver;
24use pal_async::socket::PolledSocket;
25use pal_async::task::Spawn;
26use pal_async::timer::PolledTimer;
27use std::convert::Infallible;
28use std::ffi::OsStr;
29use std::io::ErrorKind;
30use std::io::IsTerminal;
31use std::io::Write;
32use std::net::TcpListener;
33use std::path::Path;
34use std::path::PathBuf;
35use std::str::FromStr;
36use std::sync::Arc;
37use std::time::Duration;
38use thiserror::Error;
39use tracing_subscriber::layer::SubscriberExt;
40use tracing_subscriber::util::SubscriberInitExt;
41use unicycle::FuturesUnordered;
42
// Top-level command line for `ohcldiag-dev`: the VM to talk to plus the
// subcommand to run against it.
//
// N.B. Plain `//` comments are used here (rather than `///`) so that the help
//      text generated from the explicit `about`/`long_about` attributes is
//      left untouched.
#[derive(Parser)]
#[clap(about = "(dev) CLI to interact with the Underhill diagnostics server")]
#[clap(long_about = r#"
CLI to interact with the Underhill diagnostics server.

DISCLAIMER:
    `ohcldiag-dev` does not make ANY stability guarantees regarding the layout of
    the CLI, the syntax that is emitted via stdout/stderr, the location of nodes
    in the `inspect` graph, etc...

        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        !! ANY AUTOMATION THAT USES ohcldiag-dev WILL EVENTUALLY BREAK !!
        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"#)]
struct Options {
    // The VM to operate on; its arguments are flattened into this command
    // line.
    #[clap(flatten)]
    vm: VmArg,

    // The operation to perform.
    #[clap(subcommand)]
    command: Command,
}
64
65#[derive(Subcommand)]
66enum Command {
67    #[clap(hide = true)]
68    Complete(clap_dyn_complete::Complete),
69    Completions(completions::Completions),
70    /// Starts an interactive terminal in VTL2.
71    Shell {
72        /// The shell process to start.
73        #[clap(default_value = "/bin/sh")]
74        shell: String,
75        /// The arguments to pass to the shell process.
76        args: Vec<String>,
77    },
78    /// Runs a process in VTL2.
79    Run {
80        /// The command to run.
81        command: String,
82        /// Arguments to pass to the command.
83        args: Vec<String>,
84    },
85    /// Inspects the Underhill state.
86    #[clap(visible_alias = "i")]
87    Inspect {
88        /// Recursively enumerate child nodes.
89        #[clap(short)]
90        recursive: bool,
91        /// Limit the recursive inspection depth.
92        #[clap(short, long, requires("recursive"))]
93        limit: Option<usize>,
94        /// Output in JSON format.
95        #[clap(short, long)]
96        json: bool,
97        /// Poll periodically.
98        #[clap(short)]
99        poll: bool,
100        /// The poll period in seconds.
101        #[clap(long, default_value = "1", requires("poll"))]
102        period: f64,
103        /// The count of polls
104        #[clap(long, requires("poll"))]
105        count: Option<usize>,
106        /// The path to inspect.
107        path: Option<String>,
108        /// Update the path with a new value.
109        #[clap(short, long, conflicts_with("recursive"))]
110        update: Option<String>,
111        /// Timeout to wait for the inspection. 0 means no timeout.
112        #[clap(short, default_value = "1", conflicts_with("update"))]
113        timeout: u64,
114    },
115    /// Updates an inspectable value.
116    #[clap(hide = true)]
117    Update {
118        /// The path.
119        path: String,
120        /// The new value.
121        value: String,
122    },
123    /// Starts the VM if it's waiting for the signal to start.
124    ///
125    /// Underhill must have been started with --wait-for-start or
126    /// OPENHCL_WAIT_FOR_START set.
127    Start {
128        /// Environment variables to set, in the form X=Y
129        #[clap(short, long)]
130        env: Vec<EnvString>,
131        /// Environment variables to clear
132        #[clap(short, long)]
133        unset: Vec<String>,
134        /// Extra command line arguments to append.
135        args: Vec<String>,
136    },
137    /// Writes the contents of the kernel message buffer, /dev/kmsg.
138    Kmsg {
139        /// Keep waiting for and writing new data as its logged.
140        #[clap(short, long)]
141        follow: bool,
142        /// Reconnect (retrying indefinitely) whenever the connection is lost.
143        #[clap(short, long)]
144        reconnect: bool,
145        /// Write verbose information about the connection state.
146        #[clap(short, long)]
147        verbose: bool,
148        /// Read kmsg from the VM's serial port.
149        ///
150        /// This only works on Hyper-V.
151        #[cfg(windows)]
152        #[clap(long, conflicts_with = "reconnect")]
153        serial: bool,
154        /// Pipe to read from for the serial port (or any other pipe)
155        ///
156        /// This only works on Hyper-V.
157        #[cfg(windows)]
158        #[clap(long, requires = "serial")]
159        pipe_path: Option<String>,
160    },
161    /// Writes the contents of the file.
162    File {
163        /// Keep waiting for and writing new data as its logged.
164        #[clap(short, long)]
165        follow: bool,
166        #[clap(short('p'), long)]
167        file_path: String,
168    },
169    /// Starts GDB server on stdio.
170    ///
171    /// Use this with gdb's target command:
172    ///
173    ///     target remote |ohcldiag-dev.exe gdbserver my-vm
174    ///
175    /// Or for multi-process debugging:
176    ///
177    ///     target extended-remote |ohcldiag-dev.exe gdbserver --multi my-vm
178    Gdbserver {
179        /// The pid to attach to. Defaults to Underhill's.
180        #[clap(long)]
181        pid: Option<i32>,
182        /// Use multi-process debugging, for use with gdb's extended-remote.
183        #[clap(long, conflicts_with("pid"))]
184        multi: bool,
185    },
186    /// Starts the GDB stub for debugging the guest on stdio.
187    ///
188    /// Use this with gdb's target command:
189    ///
190    ///     target remote |ohcldiag-dev.exe gdbstub my-vm
191    ///
192    Gdbstub {
193        /// The vsock prot to connect to.
194        #[clap(short, long, default_value = "4")]
195        port: u32,
196    },
197    /// Crashes the VM.
198    ///
199    /// Must specify the VM name, as well as the crash type.
200    #[clap(group(
201        ArgGroup::new("process")
202            .required(true)
203            .args(&["pid", "name"]),
204    ))]
205    Crash {
206        /// Type of crash.
207        ///
208        /// Current crash types supported: "panic"
209        crash_type: CrashType,
210        /// PID of underhill process to crash
211        #[clap(short, long)]
212        pid: Option<i32>,
213        /// Name of underhill process to crash
214        #[clap(short, long)]
215        name: Option<String>,
216    },
217    /// Streams the ELF core dump file of a process to the host.
218    ///
219    /// Streams the core dump file of a process to the host where the file
220    /// is saved as `dst`.
221    #[clap(group(
222        ArgGroup::new("process")
223        .required(true)
224        .args(&["pid", "name"]),
225    ))]
226    CoreDump {
227        /// Enable verbose output.
228        #[clap(short, long)]
229        verbose: bool,
230        /// PID of process to dump
231        #[clap(short, long)]
232        pid: Option<i32>,
233        /// Name of underhill process to dump
234        #[clap(short, long)]
235        name: Option<String>,
236        /// Destination file path. If omitted, the data is written to the standard
237        /// output unless it is a terminal. In that case, an error is returned.
238        dst: Option<PathBuf>,
239    },
240    /// Restarts the Underhill worker process, keeping VTL0 running.
241    Restart,
242    /// Get the current contents of the performance trace buffer, for use with
243    /// <https://ui.perfetto.dev>.
244    PerfTrace {
245        /// The output file. Defaults to stdout.
246        #[clap(short)]
247        output: Option<PathBuf>,
248    },
249    /// Sets up a relay between a virtual socket and a TCP client on the host.
250    VsockTcpRelay {
251        vsock_port: u32,
252        tcp_port: u16,
253        #[clap(long)]
254        allow_remote: bool,
255        /// Reconnect (retrying indefinitely) whenever either side of the
256        /// connection is lost.
257        ///
258        /// NOTE: Today, this does not handle the case where the vsock side is
259        /// not ready to connect. That will cause the relay to terminate.
260        #[clap(short, long)]
261        reconnect: bool,
262    },
263    /// Pause the VM (including all devices)
264    Pause,
265    /// Resume the VM
266    Resume,
267    /// Dumps the VM's VTL2 state without servicing or tearing down Underhill.
268    DumpSavedState {
269        /// The output file. Defaults to stdout.
270        #[clap(short)]
271        output: Option<PathBuf>,
272    },
273    /// Starts a network packet capture trace.
274    PacketCapture {
275        /// Destination file path. nic index is appended to the file name.
276        #[clap(short('w'), default_value = "nic")]
277        output: PathBuf,
278        /// Number of seconds for which to capture packets.
279        #[clap(short('G'), long, default_value = "60", value_parser = |arg: &str| -> Result<Duration, std::num::ParseIntError> {Ok(Duration::from_secs(arg.parse()?))})]
280        seconds: Duration,
281        /// Length of the packet to capture.
282        #[clap(short('s'), long, default_value = "65535", value_parser = clap::value_parser!(u16).range(1..))]
283        snaplen: u16,
284    },
285    /// Memory usage profile tracing.
286    MemoryProfileTrace {
287        /// PID of process to collect the trace for
288        #[clap(short, long)]
289        pid: Option<i32>,
290        /// Name of underhill process to dump
291        #[clap(short, long)]
292        name: Option<String>,
293        /// The output file. Defaults to stdout.
294        #[clap(short)]
295        output: Option<PathBuf>,
296    },
297    /// Processes EFI diagnostics from guest memory and outputs the logs.
298    ///
299    /// The log level filter controls which UEFI log entries are emitted.
300    /// The buffer already contains all log levels; this filter selects
301    /// which ones to display.
302    EfiDiagnostics {
303        /// The log level filter to apply.
304        ///
305        /// Accepted values: "default" (errors+warnings), "info" (errors+warnings+info),
306        /// "full" (all levels).
307        log_level: EfiDiagnosticsLogLevel,
308        /// The output destination.
309        ///
310        /// Accepted values: "stdout", "tracing".
311        output: EfiDiagnosticsOutput,
312    },
313}
314
315#[derive(Debug, Clone, Args)]
316pub struct VmArg {
317    #[doc = r#"VM identifier.
318
319    This can be one of:
320
321    * vsock:PATH - A path to a hybrid vsock Unix socket for a VM, as used by OpenVMM
322
323    * unix:PATH - A path to a Unix socket for connecting to the control plane
324
325    "#]
326    #[cfg_attr(
327        windows,
328        doc = "* hyperv:NAME - A Hyper-V VM name
329
330    "
331    )]
332    #[cfg_attr(
333        windows,
334        doc = "* NAME_OR_PATH - Either a Hyper-V VM name, or a path as in vsock:PATH>"
335    )]
336    #[cfg_attr(not(windows), doc = "* PATH - A path as in vsock:PATH")]
337    #[clap(name = "VM")]
338    id: VmId,
339}
340
/// How to reach the target VM, as parsed from the `VM` command line argument.
#[derive(Debug, Clone)]
enum VmId {
    /// A Hyper-V VM, identified by name (Windows only).
    #[cfg(windows)]
    HyperV(String),
    /// A path to a hybrid vsock Unix socket for the VM.
    HybridVsock(PathBuf),
}
347
348impl FromStr for VmId {
349    type Err = Infallible;
350
351    fn from_str(s: &str) -> Result<Self, Self::Err> {
352        if let Some(s) = s.strip_prefix("vsock:") {
353            Ok(Self::HybridVsock(Path::new(s).to_owned()))
354        } else {
355            #[cfg(windows)]
356            if let Some(s) = s.strip_prefix("hyperv:") {
357                return Ok(Self::HyperV(s.to_owned()));
358            } else if !pal::windows::fs::is_unix_socket(s.as_ref()).unwrap_or(false) {
359                return Ok(Self::HyperV(s.to_owned()));
360            }
361            // Default to hybrid vsock since this is what OpenVMM supports for
362            // Underhill.
363            Ok(Self::HybridVsock(Path::new(s).to_owned()))
364        }
365    }
366}
367
/// An environment variable assignment, parsed from a `NAME=value` string.
#[derive(Clone)]
struct EnvString {
    /// The variable name (the text before the first `=`).
    name: String,
    /// The variable value (the text after the first `=`).
    value: String,
}
373
// The kind of crash requested through the `Crash` subcommand.
#[derive(Clone, clap::ValueEnum)]
enum CrashType {
    // Spelled `panic` on the command line.
    #[clap(name = "panic")]
    UhPanic,
}
379
// Log level filter accepted by the `EfiDiagnostics` subcommand.
#[derive(Clone, clap::ValueEnum)]
enum EfiDiagnosticsLogLevel {
    /// Errors and warnings only
    Default,
    /// Errors, warnings, and info
    Info,
    /// All log levels
    Full,
}
389
390impl EfiDiagnosticsLogLevel {
391    fn as_inspect_value(&self) -> &'static str {
392        match self {
393            EfiDiagnosticsLogLevel::Default => "default",
394            EfiDiagnosticsLogLevel::Info => "info",
395            EfiDiagnosticsLogLevel::Full => "full",
396        }
397    }
398}
399
// Output destination accepted by the `EfiDiagnostics` subcommand.
#[derive(Clone, clap::ValueEnum)]
enum EfiDiagnosticsOutput {
    /// Emit to stdout
    Stdout,
    /// Emit to tracing
    Tracing,
}
407
408impl EfiDiagnosticsOutput {
409    fn as_inspect_value(&self) -> &'static str {
410        match self {
411            EfiDiagnosticsOutput::Stdout => "stdout",
412            EfiDiagnosticsOutput::Tracing => "tracing",
413        }
414    }
415}
416
/// Error returned when a string cannot be parsed as an [`EnvString`].
#[derive(Debug, Error)]
#[error("bad environment variable, expected VAR=value")]
struct BadEnvString;
420
421impl FromStr for EnvString {
422    type Err = BadEnvString;
423
424    fn from_str(s: &str) -> Result<Self, Self::Err> {
425        let (name, value) = s.split_once('=').ok_or(BadEnvString)?;
426        Ok(Self {
427            name: name.to_owned(),
428            value: value.to_owned(),
429        })
430    }
431}
432
433// N.B. this exits after a successful completion.
434async fn run(
435    client: &DiagClient,
436    command: impl AsRef<str>,
437    args: impl IntoIterator<Item = impl AsRef<str>>,
438) -> anyhow::Result<()> {
439    // TODO: if stdout and stderr of this process are backed by the
440    // same thing, then pass combine_stderr instead.
441    let mut process = client
442        .exec(&command)
443        .args(args)
444        .stdin(true)
445        .stdout(true)
446        .stderr(true)
447        .spawn()
448        .await?;
449
450    let mut stdin = process.stdin.take().unwrap();
451    let mut stdout = process.stdout.take().unwrap();
452    let mut stderr = process.stderr.take().unwrap();
453
454    std::thread::spawn({
455        move || {
456            let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
457        }
458    });
459
460    let stderr_thread =
461        std::thread::spawn(move || std::io::copy(&mut stderr, &mut term::raw_stderr()));
462
463    std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed stdout copy")?;
464
465    stderr_thread
466        .join()
467        .unwrap()
468        .context("failed stderr thread")?;
469
470    let status = process.wait().await?;
471    std::process::exit(status.exit_code());
472}
473
474fn new_client(driver: impl Driver + Spawn + Clone, input: &VmArg) -> anyhow::Result<DiagClient> {
475    let client = match &input.id {
476        #[cfg(windows)]
477        VmId::HyperV(name) => DiagClient::from_hyperv_name(driver, name)?,
478        VmId::HybridVsock(path) => DiagClient::from_hybrid_vsock(driver, path),
479    };
480    Ok(client)
481}
482
483pub fn main() -> anyhow::Result<()> {
484    tracing_subscriber::registry()
485        .with(tracing_subscriber::fmt::layer())
486        .with(tracing_subscriber::EnvFilter::from_default_env())
487        .init();
488
489    term::enable_vt_and_utf8();
490    DefaultPool::run_with(async |driver| {
491        let Options { vm, command } = Options::parse();
492
493        match command {
494            Command::Complete(cmd) => {
495                cmd.println_to_stub_script::<Options>(
496                    None,
497                    completions::OhcldiagDevCompleteFactory {
498                        driver: driver.clone(),
499                    },
500                )
501                .await
502            }
503            Command::Completions(cmd) => cmd.run()?,
504            Command::Shell { shell, args } => {
505                let client = new_client(driver.clone(), &vm)?;
506
507                // Set TERM to ensure function keys and other characters work.
508                let term = std::env::var("TERM");
509                let term = term.as_deref().unwrap_or("xterm-256color");
510
511                let mut process = client
512                    .exec(&shell)
513                    .args(&args)
514                    .tty(true)
515                    .stdin(true)
516                    .stdout(true)
517                    .env("TERM", term)
518                    .spawn()
519                    .await?;
520
521                let mut stdin = process.stdin.take().unwrap();
522                let mut stdout = process.stdout.take().unwrap();
523
524                term::set_raw_console(true).expect("failed to set raw console mode");
525                std::thread::spawn({
526                    move || {
527                        let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
528                    }
529                });
530
531                std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed copy")?;
532
533                let status = process.wait().await?;
534
535                if !status.success() {
536                    eprintln!(
537                        "shell exited with non-zero exit code: {}",
538                        status.exit_code()
539                    );
540                }
541            }
542            Command::Run { command, args } => {
543                let client = new_client(driver.clone(), &vm)?;
544                run(&client, command, &args).await?;
545            }
546            Command::Inspect {
547                recursive,
548                limit,
549                json,
550                poll,
551                period,
552                count,
553                timeout,
554
555                path,
556                update,
557            } => {
558                let client = new_client(driver.clone(), &vm)?;
559
560                if let Some(update) = update {
561                    let Some(path) = path else {
562                        anyhow::bail!("must provide path for update")
563                    };
564
565                    let value = client.update(path, update).await?;
566                    match value.kind {
567                        inspect::ValueKind::String(s) => println!("{s}"),
568                        _ => println!("{value}"),
569                    }
570                } else {
571                    let timeout = if timeout == 0 {
572                        None
573                    } else {
574                        Some(Duration::from_secs(timeout))
575                    };
576                    let query = async || {
577                        client
578                            .inspect(
579                                path.as_deref().unwrap_or(""),
580                                if recursive { limit } else { Some(0) },
581                                timeout,
582                            )
583                            .await
584                    };
585
586                    if poll {
587                        let mut timer = PolledTimer::new(&driver);
588                        let period = Duration::from_secs_f64(period);
589                        let mut last_time = pal_async::timer::Instant::now();
590                        let mut last = query().await?;
591                        let mut count = count;
592
593                        loop {
594                            match count.as_mut() {
595                                Some(count) if *count == 0 => break,
596                                Some(count) => *count -= 1,
597                                None => {}
598                            }
599                            timer.sleep_until(last_time + period).await;
600                            let now = pal_async::timer::Instant::now();
601                            let this = query().await?;
602                            let diff = this.since(&last, now - last_time);
603                            if json {
604                                println!("{}", diff.json());
605                            } else {
606                                println!("{diff:#}");
607                            }
608                            last = this;
609                            last_time = now;
610                        }
611                    } else {
612                        let node = query().await?;
613                        if json {
614                            println!("{}", node.json());
615                        } else {
616                            println!("{node:#}");
617                        }
618                    }
619                }
620            }
621            Command::Update { path, value } => {
622                eprintln!(
623                    "`update` is deprecated - please use `ohcldiag-dev inspect <path> -u <new value>`"
624                );
625                let client = new_client(driver.clone(), &vm)?;
626                let value = client.update(path, value).await?;
627                match value.kind {
628                    inspect::ValueKind::String(s) => println!("{s}"),
629                    _ => println!("{value}"),
630                }
631            }
632            Command::Start { env, unset, args } => {
633                let client = new_client(driver.clone(), &vm)?;
634
635                let env = env
636                    .into_iter()
637                    .map(|EnvString { name, value }| (name, Some(value)))
638                    .chain(unset.into_iter().map(|name| (name, None)));
639
640                client.start(env, args).await?;
641            }
642            Command::Kmsg {
643                follow,
644                reconnect,
645                verbose,
646                #[cfg(windows)]
647                serial,
648                #[cfg(windows)]
649                pipe_path,
650            } => {
651                let is_terminal = std::io::stdout().is_terminal();
652
653                #[cfg(windows)]
654                if serial {
655                    use diag_client::hyperv::ComPortAccessInfo;
656                    use futures::AsyncBufReadExt;
657
658                    let vm_name = match &vm.id {
659                        VmId::HyperV(name) => name,
660                        _ => anyhow::bail!("--serial is only supported for Hyper-V VMs"),
661                    };
662
663                    let port_access_info = if let Some(pipe_path) = pipe_path.as_ref() {
664                        ComPortAccessInfo::PortPipePath(pipe_path)
665                    } else {
666                        ComPortAccessInfo::NameAndPortNumber(vm_name, 3)
667                    };
668
669                    let pipe =
670                        diag_client::hyperv::open_serial_port(&driver, port_access_info).await?;
671                    let pipe = pal_async::pipe::PolledPipe::new(&driver, pipe)
672                        .context("failed to make a polled pipe")?;
673                    let pipe = futures::io::BufReader::new(pipe);
674
675                    let mut lines = pipe.lines();
676                    while let Some(line) = lines.next().await {
677                        let line = line?;
678                        if let Some(message) = kmsg::SyslogParsedEntry::new(&line) {
679                            println!("{}", message.display(is_terminal));
680                        } else {
681                            println!("{line}");
682                        }
683                    }
684
685                    return Ok(());
686                }
687
688                if verbose {
689                    eprintln!("Connecting to the diagnostics server.");
690                }
691
692                let client = new_client(driver.clone(), &vm)?;
693                'connect: loop {
694                    if reconnect {
695                        client.wait_for_server().await?;
696                    }
697                    let mut file_stream = client.kmsg(follow).await?;
698                    if verbose {
699                        eprintln!("Connected.");
700                    }
701
702                    while let Some(data) = file_stream.next().await {
703                        match data {
704                            Ok(data) => match kmsg::KmsgParsedEntry::new(&data) {
705                                Ok(message) => println!("{}", message.display(is_terminal)),
706                                Err(e) => println!("Invalid kmsg entry: {e:?}"),
707                            },
708                            Err(err) if reconnect && err.kind() == ErrorKind::ConnectionReset => {
709                                if verbose {
710                                    eprintln!(
711                                        "Connection reset to the diagnostics server. Reconnecting."
712                                    );
713                                }
714                                continue 'connect;
715                            }
716                            Err(err) => Err(err).context("failed to read kmsg")?,
717                        }
718                    }
719
720                    if reconnect {
721                        if verbose {
722                            eprintln!("Lost connection to the diagnostics server. Reconnecting.");
723                        }
724                        continue 'connect;
725                    }
726
727                    break;
728                }
729            }
730            Command::File { follow, file_path } => {
731                let client = new_client(driver.clone(), &vm)?;
732                let stream = client.read_file(follow, file_path).await?;
733                futures::io::copy(stream, &mut AllowStdIo::new(term::raw_stdout()))
734                    .await
735                    .context("failed to copy trace file")?;
736            }
737            Command::Gdbserver { multi, pid } => {
738                let client = new_client(driver.clone(), &vm)?;
739                // Pass the --once flag so that gdbserver exits after the stdio
740                // pipes are closed. Otherwise, gdbserver spins in a tight loop
741                // and never exits.
742                let gdbserver = "gdbserver --once";
743                let command = if multi {
744                    format!("{gdbserver} --multi -")
745                } else if let Some(pid) = pid {
746                    format!("{gdbserver} --attach - {pid}")
747                } else {
748                    format!("{gdbserver} --attach - \"$(cat /run/underhill.pid)\"")
749                };
750
751                run(&client, "/bin/sh", &["-c", &command]).await?;
752            }
753            Command::Gdbstub { port } => {
754                let vsock = match vm.id {
755                    VmId::HybridVsock(path) => {
756                        diag_client::connect_hybrid_vsock(&driver, &path, port).await?
757                    }
758                    #[cfg(windows)]
759                    VmId::HyperV(name) => {
760                        let vm_id = diag_client::hyperv::vm_id_from_name(&name)?;
761                        let stream =
762                            diag_client::hyperv::connect_vsock(&driver, vm_id, port).await?;
763                        PolledSocket::new(&driver, socket2::Socket::from(stream))?
764                    }
765                };
766
767                let vsock = Arc::new(vsock.into_inner());
768                // Spawn a thread to read stdin synchronously since pal_async
769                // does not offer a way to read it asynchronously.
770                let thread = std::thread::spawn({
771                    let vsock = vsock.clone();
772                    move || {
773                        let _ = std::io::copy(&mut std::io::stdin(), &mut vsock.as_ref());
774                    }
775                });
776
777                std::io::copy(&mut vsock.as_ref(), &mut term::raw_stdout())
778                    .context("failed stdout copy")?;
779                thread.join().unwrap();
780            }
781            Command::Crash {
782                crash_type,
783                pid,
784                name,
785            } => {
786                let client = new_client(driver.clone(), &vm)?;
787                let pid = if let Some(name) = name {
788                    client.get_pid(&name).await?
789                } else {
790                    pid.unwrap()
791                };
792                println!("Crashing PID: {pid}");
793                match crash_type {
794                    CrashType::UhPanic => {
795                        _ = client.crash(pid).await;
796                    }
797                }
798            }
799            Command::PacketCapture {
800                output,
801                seconds,
802                snaplen,
803            } => {
804                let client = new_client(driver.clone(), &vm)?;
805                println!(
806                    "Starting network packet capture. Wait for timeout or Ctrl-C to quit anytime."
807                );
808                let (_, num_streams) = client
809                    .packet_capture(PacketCaptureOperation::Query, 0, 0)
810                    .await?;
811                let file_stem = &output.file_stem().unwrap().to_string_lossy();
812                let extension = &output.extension().unwrap_or(OsStr::new("pcap"));
813                let mut new_output = PathBuf::from(&output);
814                let streams = client
815                    .packet_capture(PacketCaptureOperation::Start, num_streams, snaplen)
816                    .await?
817                    .0
818                    .into_iter()
819                    .enumerate()
820                    .map(|(i, i_stream)| {
821                        new_output.set_file_name(format!("{}-{}", &file_stem, i));
822                        new_output.set_extension(extension);
823                        let mut out = AllowStdIo::new(fs_err::File::create(&new_output)?);
824                        Ok(async move { futures::io::copy(i_stream, &mut out).await })
825                    })
826                    .collect::<Result<Vec<_>, std::io::Error>>()?;
827                capture_packets(client, streams, seconds).await;
828            }
829            Command::CoreDump {
830                verbose,
831                pid,
832                name,
833                dst,
834            } => {
835                ensure_not_terminal(&dst)?;
836                let client = new_client(driver.clone(), &vm)?;
837                let pid = if let Some(name) = name {
838                    client.get_pid(&name).await?
839                } else {
840                    pid.unwrap()
841                };
842                println!("Dumping PID: {pid}");
843                let file = create_or_stderr(&dst)?;
844                client
845                    .core_dump(
846                        pid,
847                        AllowStdIo::new(file),
848                        AllowStdIo::new(std::io::stderr()),
849                        verbose,
850                    )
851                    .await?;
852            }
853            Command::Restart => {
854                let client = new_client(driver.clone(), &vm)?;
855                client.restart().await?;
856            }
857            Command::PerfTrace { output } => {
858                ensure_not_terminal(&output)?;
859
860                let client = new_client(driver.clone(), &vm)?;
861
862                // Flush the perf trace.
863                client
864                    .update("trace/perf/flush".to_owned(), "true".to_owned())
865                    .await
866                    .context("failed to flush perf")?;
867
868                let file = create_or_stderr(&output)?;
869                let stream = client
870                    .read_file(false, "underhill.perfetto".to_owned())
871                    .await
872                    .context("failed to read trace file")?;
873
874                futures::io::copy(stream, &mut AllowStdIo::new(file))
875                    .await
876                    .context("failed to copy trace file")?;
877            }
878            Command::VsockTcpRelay {
879                vsock_port,
880                tcp_port,
881                allow_remote,
882                reconnect,
883            } => {
884                let addr = if allow_remote { "0.0.0.0" } else { "127.0.0.1" };
885                let listener = TcpListener::bind((addr, tcp_port))
886                    .with_context(|| format!("binding to port {}", tcp_port))?;
887                println!("TCP listening on {}:{}", addr, tcp_port);
888                'connect: loop {
889                    let (tcp_socket, tcp_addr) = listener.accept()?;
890                    let tcp_socket = PolledSocket::new(&driver, tcp_socket)?;
891                    println!("TCP accept on {:?}", tcp_addr);
892
893                    // TODO: support reconnect attempt for vsock like kmsg
894                    let vsock = match vm.id {
895                        VmId::HybridVsock(ref path) => {
896                            // TODO: reconnection attempt logic like kmsg is
897                            // broken for hybrid_vsock with end of file error,
898                            // if this is started before the vm is started
899                            diag_client::connect_hybrid_vsock(&driver, path, vsock_port).await?
900                        }
901                        #[cfg(windows)]
902                        VmId::HyperV(ref name) => {
903                            let vm_id = diag_client::hyperv::vm_id_from_name(name)?;
904                            let stream =
905                                diag_client::hyperv::connect_vsock(&driver, vm_id, vsock_port)
906                                    .await?;
907                            PolledSocket::new(&driver, socket2::Socket::from(stream))?
908                        }
909                    };
910                    println!("VSOCK connect to port {:?}", vsock_port);
911
912                    let (tcp_read, mut tcp_write) = tcp_socket.split();
913                    let (vsock_read, mut vsock_write) = vsock.split();
914                    let tx = futures::io::copy(tcp_read, &mut vsock_write);
915                    let rx = futures::io::copy(vsock_read, &mut tcp_write);
916                    let result = futures::future::try_join(tx, rx).await;
917                    match result {
918                        Ok(_) => {}
919                        Err(e) => match e.kind() {
920                            ErrorKind::ConnectionReset => {}
921                            _ => return Err(anyhow::Error::from(e)),
922                        },
923                    }
924                    println!("Connection closed");
925
926                    if reconnect {
927                        println!("Reconnecting...");
928                        continue 'connect;
929                    }
930
931                    break;
932                }
933            }
934            Command::Pause => {
935                let client = new_client(driver.clone(), &vm)?;
936                client.pause().await?;
937            }
938            Command::Resume => {
939                let client = new_client(driver.clone(), &vm)?;
940                client.resume().await?;
941            }
942            Command::DumpSavedState { output } => {
943                ensure_not_terminal(&output)?;
944                let client = new_client(driver.clone(), &vm)?;
945                let mut file = create_or_stderr(&output)?;
946                file.write_all(&client.dump_saved_state().await?)?;
947            }
948            Command::MemoryProfileTrace { pid, name, output } => {
949                let client = new_client(driver.clone(), &vm)?;
950                let pid = if let Some(name) = name {
951                    client.get_pid(&name).await?
952                } else if let Some(pid) = pid {
953                    pid
954                } else {
955                    anyhow::bail!("either --pid or --name must be specified");
956                };
957                // Do not write anything on the stdout in case the output
958                // is set to stdout, to avoid breaking the output format
959                // of the trace.
960                let mut file = create_or_stderr(&output)?;
961                file.write_all(&client.memory_profile_trace(pid).await?)?;
962            }
963            Command::EfiDiagnostics { log_level, output } => {
964                let client = new_client(driver.clone(), &vm)?;
965                let arg = format!(
966                    "{},{}",
967                    log_level.as_inspect_value(),
968                    output.as_inspect_value()
969                );
970                let value = client
971                    .update("vm/uefi/process_diagnostics", &arg)
972                    .await
973                    .context("failed to process EFI diagnostics")?;
974                match value.kind {
975                    inspect::ValueKind::String(s) => print!("{s}"),
976                    _ => print!("{value}"),
977                }
978            }
979        }
980        Ok(())
981    })
982}
983
984fn ensure_not_terminal(path: &Option<PathBuf>) -> anyhow::Result<()> {
985    if path.is_none() && std::io::stdout().is_terminal() {
986        anyhow::bail!("cannot write to terminal");
987    }
988    Ok(())
989}
990
991fn create_or_stderr(path: &Option<PathBuf>) -> std::io::Result<fs_err::File> {
992    let file = match path {
993        Some(path) => fs_err::File::create(path)?,
994        None => fs_err::File::from_parts(term::raw_stdout(), "stdout"),
995    };
996    Ok(file)
997}
998
999async fn capture_packets(
1000    client: DiagClient,
1001    streams: Vec<impl Future<Output = Result<u64, std::io::Error>>>,
1002    capture_duration: Duration,
1003) {
1004    let mut capture_streams = FuturesUnordered::from_iter(streams);
1005    let (user_input_tx, mut user_input_rx) = mesh::channel();
1006    ctrlc::set_handler(move || user_input_tx.send(())).expect("Error setting Ctrl-C handler");
1007
1008    let mut ctx = mesh::CancelContext::new().with_timeout(capture_duration);
1009    let mut stop_signaled = std::pin::pin!(ctx.until_cancelled(user_input_rx.recv()));
1010
1011    let mut stop_streams = std::pin::pin!(async {
1012        if let Err(err) = client
1013            .packet_capture(PacketCaptureOperation::Stop, 0, 0)
1014            .await
1015        {
1016            eprintln!("Failed stop: {err}");
1017        }
1018    });
1019
1020    #[derive(PartialEq)]
1021    enum State {
1022        Running,
1023        Stopping,
1024        StoppingStreamsDone,
1025        Stopped,
1026    }
1027    let mut state = State::Running;
1028    loop {
1029        enum Event {
1030            Continue,
1031            StopSignaled,
1032            StopComplete,
1033            StreamsDone,
1034        }
1035        let stop = async {
1036            match state {
1037                State::Running => {
1038                    (&mut stop_signaled).await.ok();
1039                    Event::StopSignaled
1040                }
1041                State::Stopping | State::StoppingStreamsDone => {
1042                    (&mut stop_streams).await;
1043                    Event::StopComplete
1044                }
1045                State::Stopped => std::future::pending::<Event>().await,
1046            }
1047        };
1048        let process_streams = async {
1049            if state == State::StoppingStreamsDone {
1050                std::future::pending::<()>().await;
1051            }
1052            match capture_streams.next().await {
1053                Some(_) => Event::Continue,
1054                None => Event::StreamsDone,
1055            }
1056        };
1057        let event = (stop, process_streams).race();
1058
1059        // N.B Wait for all the copy tasks to complete to make sure the data is flushed to
1060        //     ensure compatibility with the packet capture protocol.
1061        match event.await {
1062            Event::Continue => continue,
1063            Event::StopSignaled => {
1064                println!("Stopping packet capture...");
1065                state = State::Stopping;
1066            }
1067            Event::StopComplete => {
1068                println!("Waiting for data to be flushed...");
1069                if state == State::Stopping {
1070                    state = State::Stopped;
1071                } else {
1072                    break;
1073                }
1074            }
1075            Event::StreamsDone if state == State::Stopping => {
1076                state = State::StoppingStreamsDone;
1077            }
1078            Event::StreamsDone => {
1079                if state != State::Stopped {
1080                    println!("Lost connection with network.");
1081                }
1082                break;
1083            }
1084        }
1085    }
1086    println!("All done.");
1087}