ohcldiag_dev/
main.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
//! A "move fast, break things" tool that provides no long-term CLI stability
//! guarantees.
6
7#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10mod completions;
11
12use anyhow::Context;
13use clap::ArgGroup;
14use clap::Args;
15use clap::Parser;
16use clap::Subcommand;
17use diag_client::DiagClient;
18use diag_client::PacketCaptureOperation;
19use futures::StreamExt;
20use futures::io::AllowStdIo;
21use futures_concurrency::future::Race;
22use pal_async::DefaultPool;
23use pal_async::driver::Driver;
24use pal_async::socket::PolledSocket;
25use pal_async::task::Spawn;
26use pal_async::timer::PolledTimer;
27use std::convert::Infallible;
28use std::ffi::OsStr;
29use std::io::ErrorKind;
30use std::io::IsTerminal;
31use std::io::Write;
32use std::net::TcpListener;
33use std::path::Path;
34use std::path::PathBuf;
35use std::str::FromStr;
36use std::sync::Arc;
37use std::time::Duration;
38use thiserror::Error;
39use tracing_subscriber::layer::SubscriberExt;
40use tracing_subscriber::util::SubscriberInitExt;
41use unicycle::FuturesUnordered;
42
43#[derive(Parser)]
44#[clap(about = "(dev) CLI to interact with the Underhill diagnostics server")]
45#[clap(long_about = r#"
46CLI to interact with the Underhill diagnostics server.
47
48DISCLAIMER:
49    `ohcldiag-dev` does not make ANY stability guarantees regarding the layout of
50    the CLI, the syntax that is emitted via stdout/stderr, the location of nodes
51    in the `inspect` graph, etc...
52
53        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
54        !! ANY AUTOMATION THAT USES ohcldiag-dev WILL EVENTUALLY BREAK !!
55        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
56"#)]
57struct Options {
58    #[clap(flatten)]
59    vm: VmArg,
60
61    #[clap(subcommand)]
62    command: Command,
63}
64
65#[derive(Subcommand)]
66enum Command {
67    #[clap(hide = true)]
68    Complete(clap_dyn_complete::Complete),
69    Completions(completions::Completions),
70    /// Starts an interactive terminal in VTL2.
71    Shell {
72        /// The shell process to start.
73        #[clap(default_value = "/bin/sh")]
74        shell: String,
75        /// The arguments to pass to the shell process.
76        args: Vec<String>,
77    },
78    /// Runs a process in VTL2.
79    Run {
80        /// The command to run.
81        command: String,
82        /// Arguments to pass to the command.
83        args: Vec<String>,
84    },
85    /// Inspects the Underhill state.
86    #[clap(visible_alias = "i")]
87    Inspect {
88        /// Recursively enumerate child nodes.
89        #[clap(short)]
90        recursive: bool,
91        /// Limit the recursive inspection depth.
92        #[clap(short, long, requires("recursive"))]
93        limit: Option<usize>,
94        /// Output in JSON format.
95        #[clap(short, long)]
96        json: bool,
97        /// Poll periodically.
98        #[clap(short)]
99        poll: bool,
100        /// The poll period in seconds.
101        #[clap(long, default_value = "1", requires("poll"))]
102        period: f64,
103        /// The count of polls
104        #[clap(long, requires("poll"))]
105        count: Option<usize>,
106        /// The path to inspect.
107        path: Option<String>,
108        /// Update the path with a new value.
109        #[clap(short, long, conflicts_with("recursive"))]
110        update: Option<String>,
111        /// Timeout to wait for the inspection. 0 means no timeout.
112        #[clap(short, default_value = "1", conflicts_with("update"))]
113        timeout: u64,
114    },
115    /// Updates an inspectable value.
116    #[clap(hide = true)]
117    Update {
118        /// The path.
119        path: String,
120        /// The new value.
121        value: String,
122    },
123    /// Starts the VM if it's waiting for the signal to start.
124    ///
125    /// Underhill must have been started with --wait-for-start or
126    /// OPENHCL_WAIT_FOR_START set.
127    Start {
128        /// Environment variables to set, in the form X=Y
129        #[clap(short, long)]
130        env: Vec<EnvString>,
131        /// Environment variables to clear
132        #[clap(short, long)]
133        unset: Vec<String>,
134        /// Extra command line arguments to append.
135        args: Vec<String>,
136    },
137    /// Writes the contents of the kernel message buffer, /dev/kmsg.
138    Kmsg {
139        /// Keep waiting for and writing new data as its logged.
140        #[clap(short, long)]
141        follow: bool,
142        /// Reconnect (retrying indefinitely) whenever the connection is lost.
143        #[clap(short, long)]
144        reconnect: bool,
145        /// Write verbose information about the connection state.
146        #[clap(short, long)]
147        verbose: bool,
148        /// Read kmsg from the VM's serial port.
149        ///
150        /// This only works on Hyper-V.
151        #[cfg(windows)]
152        #[clap(long, conflicts_with = "reconnect")]
153        serial: bool,
154        /// Pipe to read from for the serial port (or any other pipe)
155        ///
156        /// This only works on Hyper-V.
157        #[cfg(windows)]
158        #[clap(long, requires = "serial")]
159        pipe_path: Option<String>,
160    },
161    /// Writes the contents of the file.
162    File {
163        /// Keep waiting for and writing new data as its logged.
164        #[clap(short, long)]
165        follow: bool,
166        #[clap(short('p'), long)]
167        file_path: String,
168    },
169    /// Starts GDB server on stdio.
170    ///
171    /// Use this with gdb's target command:
172    ///
173    ///     target remote |ohcldiag-dev.exe gdbserver my-vm
174    ///
175    /// Or for multi-process debugging:
176    ///
177    ///     target extended-remote |ohcldiag-dev.exe gdbserver --multi my-vm
178    Gdbserver {
179        /// The pid to attach to. Defaults to Underhill's.
180        #[clap(long)]
181        pid: Option<i32>,
182        /// Use multi-process debugging, for use with gdb's extended-remote.
183        #[clap(long, conflicts_with("pid"))]
184        multi: bool,
185    },
186    /// Starts the GDB stub for debugging the guest on stdio.
187    ///
188    /// Use this with gdb's target command:
189    ///
190    ///     target remote |ohcldiag-dev.exe gdbstub my-vm
191    ///
192    Gdbstub {
193        /// The vsock prot to connect to.
194        #[clap(short, long, default_value = "4")]
195        port: u32,
196    },
197    /// Crashes the VM.
198    ///
199    /// Must specify the VM name, as well as the crash type.
200    #[clap(group(
201        ArgGroup::new("process")
202            .required(true)
203            .args(&["pid", "name"]),
204    ))]
205    Crash {
206        /// Type of crash.
207        ///
208        /// Current crash types supported: "panic"
209        crash_type: CrashType,
210        /// PID of underhill process to crash
211        #[clap(short, long)]
212        pid: Option<i32>,
213        /// Name of underhill process to crash
214        #[clap(short, long)]
215        name: Option<String>,
216    },
217    /// Streams the ELF core dump file of a process to the host.
218    ///
219    /// Streams the core dump file of a process to the host where the file
220    /// is saved as `dst`.
221    #[clap(group(
222        ArgGroup::new("process")
223        .required(true)
224        .args(&["pid", "name"]),
225    ))]
226    CoreDump {
227        /// Enable verbose output.
228        #[clap(short, long)]
229        verbose: bool,
230        /// PID of process to dump
231        #[clap(short, long)]
232        pid: Option<i32>,
233        /// Name of underhill process to dump
234        #[clap(short, long)]
235        name: Option<String>,
236        /// Destination file path. If omitted, the data is written to the standard
237        /// output unless it is a terminal. In that case, an error is returned.
238        dst: Option<PathBuf>,
239    },
240    /// Restarts the Underhill worker process, keeping VTL0 running.
241    Restart,
242    /// Get the current contents of the performance trace buffer, for use with
243    /// <https://ui.perfetto.dev>.
244    PerfTrace {
245        /// The output file. Defaults to stdout.
246        #[clap(short)]
247        output: Option<PathBuf>,
248    },
249    /// Sets up a relay between a virtual socket and a TCP client on the host.
250    VsockTcpRelay {
251        vsock_port: u32,
252        tcp_port: u16,
253        #[clap(long)]
254        allow_remote: bool,
255        /// Reconnect (retrying indefinitely) whenever either side of the
256        /// connection is lost.
257        ///
258        /// NOTE: Today, this does not handle the case where the vsock side is
259        /// not ready to connect. That will cause the relay to terminate.
260        #[clap(short, long)]
261        reconnect: bool,
262    },
263    /// Pause the VM (including all devices)
264    Pause,
265    /// Resume the VM
266    Resume,
267    /// Dumps the VM's VTL2 state without servicing or tearing down Underhill.
268    DumpSavedState {
269        /// The output file. Defaults to stdout.
270        #[clap(short)]
271        output: Option<PathBuf>,
272    },
273    /// Starts a network packet capture trace.
274    PacketCapture {
275        /// Destination file path. nic index is appended to the file name.
276        #[clap(short('w'), default_value = "nic")]
277        output: PathBuf,
278        /// Number of seconds for which to capture packets.
279        #[clap(short('G'), long, default_value = "60", value_parser = |arg: &str| -> Result<Duration, std::num::ParseIntError> {Ok(Duration::from_secs(arg.parse()?))})]
280        seconds: Duration,
281        /// Length of the packet to capture.
282        #[clap(short('s'), long, default_value = "65535", value_parser = clap::value_parser!(u16).range(1..))]
283        snaplen: u16,
284    },
285}
286
287#[derive(Debug, Clone, Args)]
288pub struct VmArg {
289    #[doc = r#"VM identifier.
290
291    This can be one of:
292
293    * vsock:PATH - A path to a hybrid vsock Unix socket for a VM, as used by OpenVMM
294
295    * unix:PATH - A path to a Unix socket for connecting to the control plane
296
297    "#]
298    #[cfg_attr(
299        windows,
300        doc = "* hyperv:NAME - A Hyper-V VM name
301
302    "
303    )]
304    #[cfg_attr(
305        windows,
306        doc = "* NAME_OR_PATH - Either a Hyper-V VM name, or a path as in vsock:PATH>"
307    )]
308    #[cfg_attr(not(windows), doc = "* PATH - A path as in vsock:PATH")]
309    #[clap(name = "VM")]
310    id: VmId,
311}
312
/// How to reach the diagnostics server of the target VM.
#[derive(Clone, Debug)]
enum VmId {
    /// A Hyper-V VM, identified by name.
    #[cfg(windows)]
    HyperV(String),
    /// A path to a hybrid vsock Unix socket, as used by OpenVMM.
    HybridVsock(PathBuf),
}

impl FromStr for VmId {
    type Err = Infallible;

    /// Classifies a VM identifier string. Parsing never fails.
    ///
    /// * `vsock:PATH` always selects a hybrid vsock path (only the first
    ///   `vsock:` prefix is stripped).
    /// * On Windows, `hyperv:NAME` always selects a Hyper-V VM name, and an
    ///   unprefixed string is treated as a Hyper-V VM name unless it is a Unix
    ///   socket (a failed socket check counts as "not a socket").
    /// * Everything else is treated as a hybrid vsock path.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some(path) = s.strip_prefix("vsock:") {
            return Ok(Self::HybridVsock(PathBuf::from(path)));
        }

        #[cfg(windows)]
        {
            if let Some(name) = s.strip_prefix("hyperv:") {
                return Ok(Self::HyperV(name.to_owned()));
            }
            let looks_like_socket = pal::windows::fs::is_unix_socket(s.as_ref()).unwrap_or(false);
            if !looks_like_socket {
                return Ok(Self::HyperV(s.to_owned()));
            }
        }

        // Default to hybrid vsock since this is what OpenVMM supports for
        // Underhill.
        Ok(Self::HybridVsock(PathBuf::from(s)))
    }
}
339
/// One `NAME=VALUE` environment assignment given on the command line.
#[derive(Clone)]
struct EnvString {
    /// The variable name: the text before the first `=`.
    name: String,
    /// The variable value: everything after the first `=`.
    value: String,
}
345
346#[derive(Clone, clap::ValueEnum)]
347enum CrashType {
348    #[clap(name = "panic")]
349    UhPanic,
350}
351
352#[derive(Debug, Error)]
353#[error("bad environment variable, expected VAR=value")]
354struct BadEnvString;
355
356impl FromStr for EnvString {
357    type Err = BadEnvString;
358
359    fn from_str(s: &str) -> Result<Self, Self::Err> {
360        let (name, value) = s.split_once('=').ok_or(BadEnvString)?;
361        Ok(Self {
362            name: name.to_owned(),
363            value: value.to_owned(),
364        })
365    }
366}
367
368// N.B. this exits after a successful completion.
369async fn run(
370    client: &DiagClient,
371    command: impl AsRef<str>,
372    args: impl IntoIterator<Item = impl AsRef<str>>,
373) -> anyhow::Result<()> {
374    // TODO: if stdout and stderr of this process are backed by the
375    // same thing, then pass combine_stderr instead.
376    let mut process = client
377        .exec(&command)
378        .args(args)
379        .stdin(true)
380        .stdout(true)
381        .stderr(true)
382        .spawn()
383        .await?;
384
385    let mut stdin = process.stdin.take().unwrap();
386    let mut stdout = process.stdout.take().unwrap();
387    let mut stderr = process.stderr.take().unwrap();
388
389    std::thread::spawn({
390        move || {
391            let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
392        }
393    });
394
395    let stderr_thread =
396        std::thread::spawn(move || std::io::copy(&mut stderr, &mut term::raw_stderr()));
397
398    std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed stdout copy")?;
399
400    stderr_thread
401        .join()
402        .unwrap()
403        .context("failed stderr thread")?;
404
405    let status = process.wait().await?;
406    std::process::exit(status.exit_code());
407}
408
409fn new_client(driver: impl Driver + Spawn + Clone, input: &VmArg) -> anyhow::Result<DiagClient> {
410    let client = match &input.id {
411        #[cfg(windows)]
412        VmId::HyperV(name) => DiagClient::from_hyperv_name(driver, name)?,
413        VmId::HybridVsock(path) => DiagClient::from_hybrid_vsock(driver, path),
414    };
415    Ok(client)
416}
417
418pub fn main() -> anyhow::Result<()> {
419    tracing_subscriber::registry()
420        .with(tracing_subscriber::fmt::layer())
421        .with(tracing_subscriber::EnvFilter::from_default_env())
422        .init();
423
424    term::enable_vt_and_utf8();
425    DefaultPool::run_with(async |driver| {
426        let Options { vm, command } = Options::parse();
427
428        match command {
429            Command::Complete(cmd) => {
430                cmd.println_to_stub_script::<Options>(
431                    None,
432                    completions::OhcldiagDevCompleteFactory {
433                        driver: driver.clone(),
434                    },
435                )
436                .await
437            }
438            Command::Completions(cmd) => cmd.run()?,
439            Command::Shell { shell, args } => {
440                let client = new_client(driver.clone(), &vm)?;
441
442                // Set TERM to ensure function keys and other characters work.
443                let term = std::env::var("TERM");
444                let term = term.as_deref().unwrap_or("xterm-256color");
445
446                let mut process = client
447                    .exec(&shell)
448                    .args(&args)
449                    .tty(true)
450                    .stdin(true)
451                    .stdout(true)
452                    .env("TERM", term)
453                    .spawn()
454                    .await?;
455
456                let mut stdin = process.stdin.take().unwrap();
457                let mut stdout = process.stdout.take().unwrap();
458
459                term::set_raw_console(true).expect("failed to set raw console mode");
460                std::thread::spawn({
461                    move || {
462                        let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
463                    }
464                });
465
466                std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed copy")?;
467
468                let status = process.wait().await?;
469
470                if !status.success() {
471                    eprintln!(
472                        "shell exited with non-zero exit code: {}",
473                        status.exit_code()
474                    );
475                }
476            }
477            Command::Run { command, args } => {
478                let client = new_client(driver.clone(), &vm)?;
479                run(&client, command, &args).await?;
480            }
481            Command::Inspect {
482                recursive,
483                limit,
484                json,
485                poll,
486                period,
487                count,
488                timeout,
489
490                path,
491                update,
492            } => {
493                let client = new_client(driver.clone(), &vm)?;
494
495                if let Some(update) = update {
496                    let Some(path) = path else {
497                        anyhow::bail!("must provide path for update")
498                    };
499
500                    let value = client.update(path, update).await?;
501                    println!("{value}");
502                } else {
503                    let timeout = if timeout == 0 {
504                        None
505                    } else {
506                        Some(Duration::from_secs(timeout))
507                    };
508                    let query = async || {
509                        client
510                            .inspect(
511                                path.as_deref().unwrap_or(""),
512                                if recursive { limit } else { Some(0) },
513                                timeout,
514                            )
515                            .await
516                    };
517
518                    if poll {
519                        let mut timer = PolledTimer::new(&driver);
520                        let period = Duration::from_secs_f64(period);
521                        let mut last_time = pal_async::timer::Instant::now();
522                        let mut last = query().await?;
523                        let mut count = count;
524
525                        loop {
526                            match count.as_mut() {
527                                Some(count) if *count == 0 => break,
528                                Some(count) => *count -= 1,
529                                None => {}
530                            }
531                            timer.sleep_until(last_time + period).await;
532                            let now = pal_async::timer::Instant::now();
533                            let this = query().await?;
534                            let diff = this.since(&last, now - last_time);
535                            if json {
536                                println!("{}", diff.json());
537                            } else {
538                                println!("{diff:#}");
539                            }
540                            last = this;
541                            last_time = now;
542                        }
543                    } else {
544                        let node = query().await?;
545                        if json {
546                            println!("{}", node.json());
547                        } else {
548                            println!("{node:#}");
549                        }
550                    }
551                }
552            }
553            Command::Update { path, value } => {
554                eprintln!(
555                    "`update` is deprecated - please use `ohcldiag-dev inspect <path> -u <new value>`"
556                );
557                let client = new_client(driver.clone(), &vm)?;
558                let value = client.update(path, value).await?;
559                println!("{value}");
560            }
561            Command::Start { env, unset, args } => {
562                let client = new_client(driver.clone(), &vm)?;
563
564                let env = env
565                    .into_iter()
566                    .map(|EnvString { name, value }| (name, Some(value)))
567                    .chain(unset.into_iter().map(|name| (name, None)));
568
569                client.start(env, args).await?;
570            }
571            Command::Kmsg {
572                follow,
573                reconnect,
574                verbose,
575                #[cfg(windows)]
576                serial,
577                #[cfg(windows)]
578                pipe_path,
579            } => {
580                let is_terminal = std::io::stdout().is_terminal();
581
582                #[cfg(windows)]
583                if serial {
584                    use diag_client::hyperv::ComPortAccessInfo;
585                    use futures::AsyncBufReadExt;
586
587                    let vm_name = match &vm.id {
588                        VmId::HyperV(name) => name,
589                        _ => anyhow::bail!("--serial is only supported for Hyper-V VMs"),
590                    };
591
592                    let port_access_info = if let Some(pipe_path) = pipe_path.as_ref() {
593                        ComPortAccessInfo::PortPipePath(pipe_path)
594                    } else {
595                        ComPortAccessInfo::NameAndPortNumber(vm_name, 3)
596                    };
597
598                    let pipe =
599                        diag_client::hyperv::open_serial_port(&driver, port_access_info).await?;
600                    let pipe = pal_async::pipe::PolledPipe::new(&driver, pipe)
601                        .context("failed to make a polled pipe")?;
602                    let pipe = futures::io::BufReader::new(pipe);
603
604                    let mut lines = pipe.lines();
605                    while let Some(line) = lines.next().await {
606                        let line = line?;
607                        if let Some(message) = kmsg::SyslogParsedEntry::new(&line) {
608                            println!("{}", message.display(is_terminal));
609                        } else {
610                            println!("{line}");
611                        }
612                    }
613
614                    return Ok(());
615                }
616
617                if verbose {
618                    eprintln!("Connecting to the diagnostics server.");
619                }
620
621                let client = new_client(driver.clone(), &vm)?;
622                'connect: loop {
623                    if reconnect {
624                        client.wait_for_server().await?;
625                    }
626                    let mut file_stream = client.kmsg(follow).await?;
627                    if verbose {
628                        eprintln!("Connected.");
629                    }
630
631                    while let Some(data) = file_stream.next().await {
632                        match data {
633                            Ok(data) => match kmsg::KmsgParsedEntry::new(&data) {
634                                Ok(message) => println!("{}", message.display(is_terminal)),
635                                Err(e) => println!("Invalid kmsg entry: {e:?}"),
636                            },
637                            Err(err) if reconnect && err.kind() == ErrorKind::ConnectionReset => {
638                                if verbose {
639                                    eprintln!(
640                                        "Connection reset to the diagnostics server. Reconnecting."
641                                    );
642                                }
643                                continue 'connect;
644                            }
645                            Err(err) => Err(err).context("failed to read kmsg")?,
646                        }
647                    }
648
649                    if reconnect {
650                        if verbose {
651                            eprintln!("Lost connection to the diagnostics server. Reconnecting.");
652                        }
653                        continue 'connect;
654                    }
655
656                    break;
657                }
658            }
659            Command::File { follow, file_path } => {
660                let client = new_client(driver.clone(), &vm)?;
661                let stream = client.read_file(follow, file_path).await?;
662                futures::io::copy(stream, &mut AllowStdIo::new(term::raw_stdout()))
663                    .await
664                    .context("failed to copy trace file")?;
665            }
666            Command::Gdbserver { multi, pid } => {
667                let client = new_client(driver.clone(), &vm)?;
668                // Pass the --once flag so that gdbserver exits after the stdio
669                // pipes are closed. Otherwise, gdbserver spins in a tight loop
670                // and never exits.
671                let gdbserver = "gdbserver --once";
672                let command = if multi {
673                    format!("{gdbserver} --multi -")
674                } else if let Some(pid) = pid {
675                    format!("{gdbserver} --attach - {pid}")
676                } else {
677                    format!("{gdbserver} --attach - \"$(cat /run/underhill.pid)\"")
678                };
679
680                run(&client, "/bin/sh", &["-c", &command]).await?;
681            }
682            Command::Gdbstub { port } => {
683                let vsock = match vm.id {
684                    VmId::HybridVsock(path) => {
685                        diag_client::connect_hybrid_vsock(&driver, &path, port).await?
686                    }
687                    #[cfg(windows)]
688                    VmId::HyperV(name) => {
689                        let vm_id = diag_client::hyperv::vm_id_from_name(&name)?;
690                        let stream =
691                            diag_client::hyperv::connect_vsock(&driver, vm_id, port).await?;
692                        PolledSocket::new(&driver, socket2::Socket::from(stream))?
693                    }
694                };
695
696                let vsock = Arc::new(vsock.into_inner());
697                // Spawn a thread to read stdin synchronously since pal_async
698                // does not offer a way to read it asynchronously.
699                let thread = std::thread::spawn({
700                    let vsock = vsock.clone();
701                    move || {
702                        let _ = std::io::copy(&mut std::io::stdin(), &mut vsock.as_ref());
703                    }
704                });
705
706                std::io::copy(&mut vsock.as_ref(), &mut term::raw_stdout())
707                    .context("failed stdout copy")?;
708                thread.join().unwrap();
709            }
710            Command::Crash {
711                crash_type,
712                pid,
713                name,
714            } => {
715                let client = new_client(driver.clone(), &vm)?;
716                let pid = if let Some(name) = name {
717                    client.get_pid(&name).await?
718                } else {
719                    pid.unwrap()
720                };
721                println!("Crashing PID: {pid}");
722                match crash_type {
723                    CrashType::UhPanic => {
724                        _ = client.crash(pid).await;
725                    }
726                }
727            }
728            Command::PacketCapture {
729                output,
730                seconds,
731                snaplen,
732            } => {
733                let client = new_client(driver.clone(), &vm)?;
734                println!(
735                    "Starting network packet capture. Wait for timeout or Ctrl-C to quit anytime."
736                );
737                let (_, num_streams) = client
738                    .packet_capture(PacketCaptureOperation::Query, 0, 0)
739                    .await?;
740                let file_stem = &output.file_stem().unwrap().to_string_lossy();
741                let extension = &output.extension().unwrap_or(OsStr::new("pcap"));
742                let mut new_output = PathBuf::from(&output);
743                let streams = client
744                    .packet_capture(PacketCaptureOperation::Start, num_streams, snaplen)
745                    .await?
746                    .0
747                    .into_iter()
748                    .enumerate()
749                    .map(|(i, i_stream)| {
750                        new_output.set_file_name(format!("{}-{}", &file_stem, i));
751                        new_output.set_extension(extension);
752                        let mut out = AllowStdIo::new(fs_err::File::create(&new_output)?);
753                        Ok(async move { futures::io::copy(i_stream, &mut out).await })
754                    })
755                    .collect::<Result<Vec<_>, std::io::Error>>()?;
756                capture_packets(client, streams, seconds).await;
757            }
758            Command::CoreDump {
759                verbose,
760                pid,
761                name,
762                dst,
763            } => {
764                ensure_not_terminal(&dst)?;
765                let client = new_client(driver.clone(), &vm)?;
766                let pid = if let Some(name) = name {
767                    client.get_pid(&name).await?
768                } else {
769                    pid.unwrap()
770                };
771                println!("Dumping PID: {pid}");
772                let file = create_or_stderr(&dst)?;
773                client
774                    .core_dump(
775                        pid,
776                        AllowStdIo::new(file),
777                        AllowStdIo::new(std::io::stderr()),
778                        verbose,
779                    )
780                    .await?;
781            }
782            Command::Restart => {
783                let client = new_client(driver.clone(), &vm)?;
784                client.restart().await?;
785            }
786            Command::PerfTrace { output } => {
787                ensure_not_terminal(&output)?;
788
789                let client = new_client(driver.clone(), &vm)?;
790
791                // Flush the perf trace.
792                client
793                    .update("trace/perf/flush".to_owned(), "true".to_owned())
794                    .await
795                    .context("failed to flush perf")?;
796
797                let file = create_or_stderr(&output)?;
798                let stream = client
799                    .read_file(false, "underhill.perfetto".to_owned())
800                    .await
801                    .context("failed to read trace file")?;
802
803                futures::io::copy(stream, &mut AllowStdIo::new(file))
804                    .await
805                    .context("failed to copy trace file")?;
806            }
807            Command::VsockTcpRelay {
808                vsock_port,
809                tcp_port,
810                allow_remote,
811                reconnect,
812            } => {
813                let addr = if allow_remote { "0.0.0.0" } else { "127.0.0.1" };
814                let listener = TcpListener::bind((addr, tcp_port))
815                    .with_context(|| format!("binding to port {}", tcp_port))?;
816                println!("TCP listening on {}:{}", addr, tcp_port);
817                'connect: loop {
818                    let (tcp_socket, tcp_addr) = listener.accept()?;
819                    let tcp_socket = PolledSocket::new(&driver, tcp_socket)?;
820                    println!("TCP accept on {:?}", tcp_addr);
821
822                    // TODO: support reconnect attempt for vsock like kmsg
823                    let vsock = match vm.id {
824                        VmId::HybridVsock(ref path) => {
825                            // TODO: reconnection attempt logic like kmsg is
826                            // broken for hybrid_vsock with end of file error,
827                            // if this is started before the vm is started
828                            diag_client::connect_hybrid_vsock(&driver, path, vsock_port).await?
829                        }
830                        #[cfg(windows)]
831                        VmId::HyperV(ref name) => {
832                            let vm_id = diag_client::hyperv::vm_id_from_name(name)?;
833                            let stream =
834                                diag_client::hyperv::connect_vsock(&driver, vm_id, vsock_port)
835                                    .await?;
836                            PolledSocket::new(&driver, socket2::Socket::from(stream))?
837                        }
838                    };
839                    println!("VSOCK connect to port {:?}", vsock_port);
840
841                    let (tcp_read, mut tcp_write) = tcp_socket.split();
842                    let (vsock_read, mut vsock_write) = vsock.split();
843                    let tx = futures::io::copy(tcp_read, &mut vsock_write);
844                    let rx = futures::io::copy(vsock_read, &mut tcp_write);
845                    let result = futures::future::try_join(tx, rx).await;
846                    match result {
847                        Ok(_) => {}
848                        Err(e) => match e.kind() {
849                            ErrorKind::ConnectionReset => {}
850                            _ => return Err(anyhow::Error::from(e)),
851                        },
852                    }
853                    println!("Connection closed");
854
855                    if reconnect {
856                        println!("Reconnecting...");
857                        continue 'connect;
858                    }
859
860                    break;
861                }
862            }
863            Command::Pause => {
864                let client = new_client(driver.clone(), &vm)?;
865                client.pause().await?;
866            }
867            Command::Resume => {
868                let client = new_client(driver.clone(), &vm)?;
869                client.resume().await?;
870            }
871            Command::DumpSavedState { output } => {
872                ensure_not_terminal(&output)?;
873                let client = new_client(driver.clone(), &vm)?;
874                let mut file = create_or_stderr(&output)?;
875                file.write_all(&client.dump_saved_state().await?)?;
876            }
877        }
878        Ok(())
879    })
880}
881
882fn ensure_not_terminal(path: &Option<PathBuf>) -> anyhow::Result<()> {
883    if path.is_none() && std::io::stdout().is_terminal() {
884        anyhow::bail!("cannot write to terminal");
885    }
886    Ok(())
887}
888
889fn create_or_stderr(path: &Option<PathBuf>) -> std::io::Result<fs_err::File> {
890    let file = match path {
891        Some(path) => fs_err::File::create(path)?,
892        None => fs_err::File::from_parts(term::raw_stdout(), "stdout"),
893    };
894    Ok(file)
895}
896
897async fn capture_packets(
898    client: DiagClient,
899    streams: Vec<impl Future<Output = Result<u64, std::io::Error>>>,
900    capture_duration: Duration,
901) {
902    let mut capture_streams = FuturesUnordered::from_iter(streams);
903    let (user_input_tx, mut user_input_rx) = mesh::channel();
904    ctrlc::set_handler(move || user_input_tx.send(())).expect("Error setting Ctrl-C handler");
905
906    let mut ctx = mesh::CancelContext::new().with_timeout(capture_duration);
907    let mut stop_signaled = std::pin::pin!(ctx.until_cancelled(user_input_rx.recv()));
908
909    let mut stop_streams = std::pin::pin!(async {
910        if let Err(err) = client
911            .packet_capture(PacketCaptureOperation::Stop, 0, 0)
912            .await
913        {
914            eprintln!("Failed stop: {err}");
915        }
916    });
917
918    #[derive(PartialEq)]
919    enum State {
920        Running,
921        Stopping,
922        StoppingStreamsDone,
923        Stopped,
924    }
925    let mut state = State::Running;
926    loop {
927        enum Event {
928            Continue,
929            StopSignaled,
930            StopComplete,
931            StreamsDone,
932        }
933        let stop = async {
934            match state {
935                State::Running => {
936                    (&mut stop_signaled).await.ok();
937                    Event::StopSignaled
938                }
939                State::Stopping | State::StoppingStreamsDone => {
940                    (&mut stop_streams).await;
941                    Event::StopComplete
942                }
943                State::Stopped => std::future::pending::<Event>().await,
944            }
945        };
946        let process_streams = async {
947            if state == State::StoppingStreamsDone {
948                std::future::pending::<()>().await;
949            }
950            match capture_streams.next().await {
951                Some(_) => Event::Continue,
952                None => Event::StreamsDone,
953            }
954        };
955        let event = (stop, process_streams).race();
956
957        // N.B Wait for all the copy tasks to complete to make sure the data is flushed to
958        //     ensure compatibility with the packet capture protocol.
959        match event.await {
960            Event::Continue => continue,
961            Event::StopSignaled => {
962                println!("Stopping packet capture...");
963                state = State::Stopping;
964            }
965            Event::StopComplete => {
966                println!("Waiting for data to be flushed...");
967                if state == State::Stopping {
968                    state = State::Stopped;
969                } else {
970                    break;
971                }
972            }
973            Event::StreamsDone if state == State::Stopping => {
974                state = State::StoppingStreamsDone;
975            }
976            Event::StreamsDone => {
977                if state != State::Stopped {
978                    println!("Lost connection with network.");
979                }
980                break;
981            }
982        }
983    }
984    println!("All done.");
985}