ohcldiag_dev/
main.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A "move fast, break things" tool, that provides no long-term CLI stability
5//! guarantees.
6
7#![expect(missing_docs)]
8
9mod completions;
10
11use anyhow::Context;
12use clap::ArgGroup;
13use clap::Args;
14use clap::Parser;
15use clap::Subcommand;
16use diag_client::DiagClient;
17use diag_client::PacketCaptureOperation;
18use futures::StreamExt;
19use futures::io::AllowStdIo;
20use futures_concurrency::future::Race;
21use pal_async::DefaultPool;
22use pal_async::driver::Driver;
23use pal_async::socket::PolledSocket;
24use pal_async::task::Spawn;
25use pal_async::timer::PolledTimer;
26use std::convert::Infallible;
27use std::ffi::OsStr;
28use std::io::ErrorKind;
29use std::io::IsTerminal;
30use std::io::Write;
31use std::net::TcpListener;
32use std::path::Path;
33use std::path::PathBuf;
34use std::str::FromStr;
35use std::sync::Arc;
36use std::time::Duration;
37use thiserror::Error;
38use tracing_subscriber::layer::SubscriberExt;
39use tracing_subscriber::util::SubscriberInitExt;
40use unicycle::FuturesUnordered;
41
// Top-level command line for `ohcldiag-dev`: the VM to talk to, followed by
// the subcommand to run against it.
//
// N.B. Plain `//` comments are used on purpose here. clap's derive turns `///`
// doc comments into help text, which would compete with the explicit
// `about`/`long_about` attributes below.
#[derive(Parser)]
#[clap(about = "(dev) CLI to interact with the Underhill diagnostics server")]
#[clap(long_about = r#"
CLI to interact with the Underhill diagnostics server.

DISCLAIMER:
    `ohcldiag-dev` does not make ANY stability guarantees regarding the layout of
    the CLI, the syntax that is emitted via stdout/stderr, the location of nodes
    in the `inspect` graph, etc...

        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        !! ANY AUTOMATION THAT USES ohcldiag-dev WILL EVENTUALLY BREAK !!
        !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"#)]
struct Options {
    // The VM identifier argument (`VM`), shared by every subcommand.
    #[clap(flatten)]
    vm: VmArg,

    // The operation to perform.
    #[clap(subcommand)]
    command: Command,
}
63
64#[derive(Subcommand)]
65enum Command {
66    #[clap(hide = true)]
67    Complete(clap_dyn_complete::Complete),
68    Completions(completions::Completions),
69    /// Starts an interactive terminal in VTL2.
70    Shell {
71        /// The shell process to start.
72        #[clap(default_value = "/bin/sh")]
73        shell: String,
74        /// The arguments to pass to the shell process.
75        args: Vec<String>,
76    },
77    /// Runs a process in VTL2.
78    Run {
79        /// The command to run.
80        command: String,
81        /// Arguments to pass to the command.
82        args: Vec<String>,
83    },
84    /// Inspects the Underhill state.
85    #[clap(visible_alias = "i")]
86    Inspect {
87        /// Recursively enumerate child nodes.
88        #[clap(short)]
89        recursive: bool,
90        /// Limit the recursive inspection depth.
91        #[clap(short, long, requires("recursive"))]
92        limit: Option<usize>,
93        /// Output in JSON format.
94        #[clap(short, long)]
95        json: bool,
96        /// Poll periodically.
97        #[clap(short)]
98        poll: bool,
99        /// The poll period in seconds.
100        #[clap(long, default_value = "1", requires("poll"))]
101        period: f64,
102        /// The count of polls
103        #[clap(long, requires("poll"))]
104        count: Option<usize>,
105        /// The path to inspect.
106        path: Option<String>,
107        /// Update the path with a new value.
108        #[clap(short, long, conflicts_with("recursive"))]
109        update: Option<String>,
110        /// Timeout to wait for the inspection. 0 means no timeout.
111        #[clap(short, default_value = "1", conflicts_with("update"))]
112        timeout: u64,
113    },
114    /// Updates an inspectable value.
115    #[clap(hide = true)]
116    Update {
117        /// The path.
118        path: String,
119        /// The new value.
120        value: String,
121    },
122    /// Starts the VM if it's waiting for the signal to start.
123    ///
124    /// Underhill must have been started with --wait-for-start or
125    /// OPENHCL_WAIT_FOR_START set.
126    Start {
127        /// Environment variables to set, in the form X=Y
128        #[clap(short, long)]
129        env: Vec<EnvString>,
130        /// Environment variables to clear
131        #[clap(short, long)]
132        unset: Vec<String>,
133        /// Extra command line arguments to append.
134        args: Vec<String>,
135    },
136    /// Writes the contents of the kernel message buffer, /dev/kmsg.
137    Kmsg {
138        /// Keep waiting for and writing new data as its logged.
139        #[clap(short, long)]
140        follow: bool,
141        /// Reconnect (retrying indefinitely) whenever the connection is lost.
142        #[clap(short, long)]
143        reconnect: bool,
144        /// Write verbose information about the connection state.
145        #[clap(short, long)]
146        verbose: bool,
147        /// Read kmsg from the VM's serial port.
148        ///
149        /// This only works on Hyper-V.
150        #[cfg(windows)]
151        #[clap(long, conflicts_with = "reconnect")]
152        serial: bool,
153        /// Pipe to read from for the serial port (or any other pipe)
154        ///
155        /// This only works on Hyper-V.
156        #[cfg(windows)]
157        #[clap(long, requires = "serial")]
158        pipe_path: Option<String>,
159    },
160    /// Writes the contents of the file.
161    File {
162        /// Keep waiting for and writing new data as its logged.
163        #[clap(short, long)]
164        follow: bool,
165        #[clap(short('p'), long)]
166        file_path: String,
167    },
168    /// Starts GDB server on stdio.
169    ///
170    /// Use this with gdb's target command:
171    ///
172    ///     target remote |ohcldiag-dev.exe gdbserver my-vm
173    ///
174    /// Or for multi-process debugging:
175    ///
176    ///     target extended-remote |ohcldiag-dev.exe gdbserver --multi my-vm
177    Gdbserver {
178        /// The pid to attach to. Defaults to Underhill's.
179        #[clap(long)]
180        pid: Option<i32>,
181        /// Use multi-process debugging, for use with gdb's extended-remote.
182        #[clap(long, conflicts_with("pid"))]
183        multi: bool,
184    },
185    /// Starts the GDB stub for debugging the guest on stdio.
186    ///
187    /// Use this with gdb's target command:
188    ///
189    ///     target remote |ohcldiag-dev.exe gdbstub my-vm
190    ///
191    Gdbstub {
192        /// The vsock prot to connect to.
193        #[clap(short, long, default_value = "4")]
194        port: u32,
195    },
196    /// Crashes the VM.
197    ///
198    /// Must specify the VM name, as well as the crash type.
199    #[clap(group(
200        ArgGroup::new("process")
201            .required(true)
202            .args(&["pid", "name"]),
203    ))]
204    Crash {
205        /// Type of crash.
206        ///
207        /// Current crash types supported: "panic"
208        crash_type: CrashType,
209        /// PID of underhill process to crash
210        #[clap(short, long)]
211        pid: Option<i32>,
212        /// Name of underhill process to crash
213        #[clap(short, long)]
214        name: Option<String>,
215    },
216    /// Streams the ELF core dump file of a process to the host.
217    ///
218    /// Streams the core dump file of a process to the host where the file
219    /// is saved as `dst`.
220    #[clap(group(
221        ArgGroup::new("process")
222        .required(true)
223        .args(&["pid", "name"]),
224    ))]
225    CoreDump {
226        /// Enable verbose output.
227        #[clap(short, long)]
228        verbose: bool,
229        /// PID of process to dump
230        #[clap(short, long)]
231        pid: Option<i32>,
232        /// Name of underhill process to dump
233        #[clap(short, long)]
234        name: Option<String>,
235        /// Destination file path. If omitted, the data is written to the standard
236        /// output unless it is a terminal. In that case, an error is returned.
237        dst: Option<PathBuf>,
238    },
239    /// Restarts the Underhill worker process, keeping VTL0 running.
240    Restart,
241    /// Get the current contents of the performance trace buffer, for use with
242    /// <https://ui.perfetto.dev>.
243    PerfTrace {
244        /// The output file. Defaults to stdout.
245        #[clap(short)]
246        output: Option<PathBuf>,
247    },
248    /// Sets up a relay between a virtual socket and a TCP client on the host.
249    VsockTcpRelay {
250        vsock_port: u32,
251        tcp_port: u16,
252        #[clap(long)]
253        allow_remote: bool,
254        /// Reconnect (retrying indefinitely) whenever either side of the
255        /// connection is lost.
256        ///
257        /// NOTE: Today, this does not handle the case where the vsock side is
258        /// not ready to connect. That will cause the relay to terminate.
259        #[clap(short, long)]
260        reconnect: bool,
261    },
262    /// Pause the VM (including all devices)
263    Pause,
264    /// Resume the VM
265    Resume,
266    /// Dumps the VM's VTL2 state without servicing or tearing down Underhill.
267    DumpSavedState {
268        /// The output file. Defaults to stdout.
269        #[clap(short)]
270        output: Option<PathBuf>,
271    },
272    /// Starts a network packet capture trace.
273    PacketCapture {
274        /// Destination file path. nic index is appended to the file name.
275        #[clap(short('w'), default_value = "nic")]
276        output: PathBuf,
277        /// Number of seconds for which to capture packets.
278        #[clap(short('G'), long, default_value = "60", value_parser = |arg: &str| -> Result<Duration, std::num::ParseIntError> {Ok(Duration::from_secs(arg.parse()?))})]
279        seconds: Duration,
280        /// Length of the packet to capture.
281        #[clap(short('s'), long, default_value = "65535", value_parser = clap::value_parser!(u16).range(1..))]
282        snaplen: u16,
283    },
284}
285
286#[derive(Debug, Clone, Args)]
287pub struct VmArg {
288    #[doc = r#"VM identifier.
289
290    This can be one of:
291
292    * vsock:PATH - A path to a hybrid vsock Unix socket for a VM, as used by HvLite
293
294    * unix:PATH - A path to a Unix socket for connecting to the control plane
295
296    "#]
297    #[cfg_attr(
298        windows,
299        doc = "* hyperv:NAME - A Hyper-V VM name
300
301    "
302    )]
303    #[cfg_attr(
304        windows,
305        doc = "* NAME_OR_PATH - Either a Hyper-V VM name, or a path as in vsock:PATH>"
306    )]
307    #[cfg_attr(not(windows), doc = "* PATH - A path as in vsock:PATH")]
308    #[clap(name = "VM")]
309    id: VmId,
310}
311
/// How to reach the VM's diagnostics server, as parsed from the `VM`
/// command-line argument.
#[derive(Debug, Clone)]
enum VmId {
    /// A Hyper-V VM, identified by name.
    #[cfg(windows)]
    HyperV(String),
    /// A path to a hybrid vsock Unix socket.
    HybridVsock(PathBuf),
}

impl FromStr for VmId {
    type Err = Infallible;

    /// Parses a VM identifier. This never fails.
    ///
    /// * `vsock:PATH` always yields [`VmId::HybridVsock`].
    /// * On Windows, `hyperv:NAME` yields `VmId::HyperV`, as does any other
    ///   string for which `is_unix_socket` does not report `Ok(true)`.
    /// * Everything else is treated as a hybrid vsock path.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // An explicit `vsock:` prefix wins on every platform.
        if let Some(socket_path) = s.strip_prefix("vsock:") {
            return Ok(Self::HybridVsock(Path::new(socket_path).to_path_buf()));
        }

        #[cfg(windows)]
        {
            // Either an explicit `hyperv:` prefix, or a bare string that does
            // not name an existing Unix socket, selects a Hyper-V VM by name.
            let hyperv_name = match s.strip_prefix("hyperv:") {
                Some(name) => Some(name),
                None if !pal::windows::fs::is_unix_socket(s.as_ref()).unwrap_or(false) => Some(s),
                None => None,
            };
            if let Some(name) = hyperv_name {
                return Ok(Self::HyperV(name.to_owned()));
            }
        }

        // Default to hybrid vsock since this is what HvLite supports for
        // Underhill.
        Ok(Self::HybridVsock(Path::new(s).to_path_buf()))
    }
}
338
/// A single `NAME=VALUE` environment variable assignment, as accepted by the
/// `start` subcommand's `--env` option.
#[derive(Clone)]
struct EnvString {
    /// The text before the first `=`.
    name: String,
    /// The text after the first `=`.
    value: String,
}
344
// The kinds of crash the `crash` subcommand can trigger.
//
// N.B. Plain `//` comments are used here because `///` doc comments on a
// `ValueEnum` feed into clap's generated help output.
#[derive(Clone, clap::ValueEnum)]
enum CrashType {
    // Spelled `panic` on the command line.
    #[clap(name = "panic")]
    UhPanic,
}
350
/// Error returned when a string cannot be parsed as an [`EnvString`].
#[derive(Debug, Error)]
#[error("bad environment variable, expected VAR=value")]
struct BadEnvString;
354
355impl FromStr for EnvString {
356    type Err = BadEnvString;
357
358    fn from_str(s: &str) -> Result<Self, Self::Err> {
359        let (name, value) = s.split_once('=').ok_or(BadEnvString)?;
360        Ok(Self {
361            name: name.to_owned(),
362            value: value.to_owned(),
363        })
364    }
365}
366
367// N.B. this exits after a successful completion.
368async fn run(
369    client: &DiagClient,
370    command: impl AsRef<str>,
371    args: impl IntoIterator<Item = impl AsRef<str>>,
372) -> anyhow::Result<()> {
373    // TODO: if stdout and stderr of this process are backed by the
374    // same thing, then pass combine_stderr instead.
375    let mut process = client
376        .exec(&command)
377        .args(args)
378        .stdin(true)
379        .stdout(true)
380        .stderr(true)
381        .spawn()
382        .await?;
383
384    let mut stdin = process.stdin.take().unwrap();
385    let mut stdout = process.stdout.take().unwrap();
386    let mut stderr = process.stderr.take().unwrap();
387
388    std::thread::spawn({
389        move || {
390            let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
391        }
392    });
393
394    let stderr_thread =
395        std::thread::spawn(move || std::io::copy(&mut stderr, &mut term::raw_stderr()));
396
397    std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed stdout copy")?;
398
399    stderr_thread
400        .join()
401        .unwrap()
402        .context("failed stdout copy")?;
403
404    let status = process.wait().await?;
405    std::process::exit(status.exit_code());
406}
407
408fn new_client(driver: impl Driver + Spawn + Clone, input: &VmArg) -> anyhow::Result<DiagClient> {
409    let client = match &input.id {
410        #[cfg(windows)]
411        VmId::HyperV(name) => DiagClient::from_hyperv_name(driver, name)?,
412        VmId::HybridVsock(path) => DiagClient::from_hybrid_vsock(driver, path),
413    };
414    Ok(client)
415}
416
417pub fn main() -> anyhow::Result<()> {
418    tracing_subscriber::registry()
419        .with(tracing_subscriber::fmt::layer())
420        .with(tracing_subscriber::EnvFilter::from_default_env())
421        .init();
422
423    term::enable_vt_and_utf8();
424    DefaultPool::run_with(async |driver| {
425        let Options { vm, command } = Options::parse();
426
427        match command {
428            Command::Complete(cmd) => {
429                cmd.println_to_stub_script::<Options>(
430                    None,
431                    completions::OhcldiagDevCompleteFactory {
432                        driver: driver.clone(),
433                    },
434                )
435                .await
436            }
437            Command::Completions(cmd) => cmd.run()?,
438            Command::Shell { shell, args } => {
439                let client = new_client(driver.clone(), &vm)?;
440
441                // Set TERM to ensure function keys and other characters work.
442                let term = std::env::var("TERM");
443                let term = term.as_deref().unwrap_or("xterm-256color");
444
445                let mut process = client
446                    .exec(&shell)
447                    .args(&args)
448                    .tty(true)
449                    .stdin(true)
450                    .stdout(true)
451                    .env("TERM", term)
452                    .spawn()
453                    .await?;
454
455                let mut stdin = process.stdin.take().unwrap();
456                let mut stdout = process.stdout.take().unwrap();
457
458                term::set_raw_console(true);
459                std::thread::spawn({
460                    move || {
461                        let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
462                    }
463                });
464
465                std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed copy")?;
466
467                let status = process.wait().await?;
468
469                if !status.success() {
470                    eprintln!(
471                        "shell exited with non-zero exit code: {}",
472                        status.exit_code()
473                    );
474                }
475            }
476            Command::Run { command, args } => {
477                let client = new_client(driver.clone(), &vm)?;
478                run(&client, command, &args).await?;
479            }
480            Command::Inspect {
481                recursive,
482                limit,
483                json,
484                poll,
485                period,
486                count,
487                timeout,
488
489                path,
490                update,
491            } => {
492                let client = new_client(driver.clone(), &vm)?;
493
494                if let Some(update) = update {
495                    let Some(path) = path else {
496                        anyhow::bail!("must provide path for update")
497                    };
498
499                    let value = client.update(path, update).await?;
500                    println!("{value}");
501                } else {
502                    let timeout = if timeout == 0 {
503                        None
504                    } else {
505                        Some(Duration::from_secs(timeout))
506                    };
507                    let query = async || {
508                        client
509                            .inspect(
510                                path.as_deref().unwrap_or(""),
511                                if recursive { limit } else { Some(0) },
512                                timeout,
513                            )
514                            .await
515                    };
516
517                    if poll {
518                        let mut timer = PolledTimer::new(&driver);
519                        let period = Duration::from_secs_f64(period);
520                        let mut last_time = pal_async::timer::Instant::now();
521                        let mut last = query().await?;
522                        let mut count = count;
523
524                        loop {
525                            match count.as_mut() {
526                                Some(count) if *count == 0 => break,
527                                Some(count) => *count -= 1,
528                                None => {}
529                            }
530                            timer.sleep_until(last_time + period).await;
531                            let now = pal_async::timer::Instant::now();
532                            let this = query().await?;
533                            let diff = this.since(&last, now - last_time);
534                            if json {
535                                println!("{}", diff.json());
536                            } else {
537                                println!("{diff:#}");
538                            }
539                            last = this;
540                            last_time = now;
541                        }
542                    } else {
543                        let node = query().await?;
544                        if json {
545                            println!("{}", node.json());
546                        } else {
547                            println!("{node:#}");
548                        }
549                    }
550                }
551            }
552            Command::Update { path, value } => {
553                eprintln!(
554                    "`update` is deprecated - please use `ohcldiag-dev inspect <path> -u <new value>`"
555                );
556                let client = new_client(driver.clone(), &vm)?;
557                let value = client.update(path, value).await?;
558                println!("{value}");
559            }
560            Command::Start { env, unset, args } => {
561                let client = new_client(driver.clone(), &vm)?;
562
563                let env = env
564                    .into_iter()
565                    .map(|EnvString { name, value }| (name, Some(value)))
566                    .chain(unset.into_iter().map(|name| (name, None)));
567
568                client.start(env, args).await?;
569            }
570            Command::Kmsg {
571                follow,
572                reconnect,
573                verbose,
574                #[cfg(windows)]
575                serial,
576                #[cfg(windows)]
577                pipe_path,
578            } => {
579                let is_terminal = std::io::stdout().is_terminal();
580
581                #[cfg(windows)]
582                if serial {
583                    use diag_client::hyperv::ComPortAccessInfo;
584                    use futures::AsyncBufReadExt;
585
586                    let vm_name = match &vm.id {
587                        VmId::HyperV(name) => name,
588                        _ => anyhow::bail!("--serial is only supported for Hyper-V VMs"),
589                    };
590
591                    let port_access_info = if let Some(pipe_path) = pipe_path.as_ref() {
592                        ComPortAccessInfo::PortPipePath(pipe_path)
593                    } else {
594                        ComPortAccessInfo::NameAndPortNumber(vm_name, 3)
595                    };
596
597                    let pipe =
598                        diag_client::hyperv::open_serial_port(&driver, port_access_info).await?;
599                    let pipe = pal_async::pipe::PolledPipe::new(&driver, pipe)
600                        .context("failed to make a polled pipe")?;
601                    let pipe = futures::io::BufReader::new(pipe);
602
603                    let mut lines = pipe.lines();
604                    while let Some(line) = lines.next().await {
605                        let line = line?;
606                        if let Some(message) = kmsg::SyslogParsedEntry::new(&line) {
607                            println!("{}", message.display(is_terminal));
608                        } else {
609                            println!("{line}");
610                        }
611                    }
612
613                    return Ok(());
614                }
615
616                if verbose {
617                    eprintln!("Connecting to the diagnostics server.");
618                }
619
620                let client = new_client(driver.clone(), &vm)?;
621                'connect: loop {
622                    if reconnect {
623                        client.wait_for_server().await?;
624                    }
625                    let mut file_stream = client.kmsg(follow).await?;
626                    if verbose {
627                        eprintln!("Connected.");
628                    }
629
630                    while let Some(data) = file_stream.next().await {
631                        match data {
632                            Ok(data) => {
633                                let message = kmsg::KmsgParsedEntry::new(&data)?;
634                                println!("{}", message.display(is_terminal));
635                            }
636                            Err(err) if reconnect && err.kind() == ErrorKind::ConnectionReset => {
637                                if verbose {
638                                    eprintln!(
639                                        "Connection reset to the diagnostics server. Reconnecting."
640                                    );
641                                }
642                                continue 'connect;
643                            }
644                            Err(err) => Err(err).context("failed to read kmsg")?,
645                        }
646                    }
647
648                    if reconnect {
649                        if verbose {
650                            eprintln!("Lost connection to the diagnostics server. Reconnecting.");
651                        }
652                        continue 'connect;
653                    }
654
655                    break;
656                }
657            }
658            Command::File { follow, file_path } => {
659                let client = new_client(driver.clone(), &vm)?;
660                let stream = client.read_file(follow, file_path).await?;
661                futures::io::copy(stream, &mut AllowStdIo::new(term::raw_stdout()))
662                    .await
663                    .context("failed to copy trace file")?;
664            }
665            Command::Gdbserver { multi, pid } => {
666                let client = new_client(driver.clone(), &vm)?;
667                // Pass the --once flag so that gdbserver exits after the stdio
668                // pipes are closed. Otherwise, gdbserver spins in a tight loop
669                // and never exits.
670                let gdbserver = "gdbserver --once";
671                let command = if multi {
672                    format!("{gdbserver} --multi -")
673                } else if let Some(pid) = pid {
674                    format!("{gdbserver} --attach - {pid}")
675                } else {
676                    format!("{gdbserver} --attach - \"$(cat /run/underhill.pid)\"")
677                };
678
679                run(&client, "/bin/sh", &["-c", &command]).await?;
680            }
681            Command::Gdbstub { port } => {
682                let vsock = match vm.id {
683                    VmId::HybridVsock(path) => {
684                        diag_client::connect_hybrid_vsock(&driver, &path, port).await?
685                    }
686                    #[cfg(windows)]
687                    VmId::HyperV(name) => {
688                        let vm_id = diag_client::hyperv::vm_id_from_name(&name)?;
689                        let stream =
690                            diag_client::hyperv::connect_vsock(&driver, vm_id, port).await?;
691                        PolledSocket::new(&driver, socket2::Socket::from(stream))?
692                    }
693                };
694
695                let vsock = Arc::new(vsock.into_inner());
696                // Spawn a thread to read stdin synchronously since pal_async
697                // does not offer a way to read it asynchronously.
698                let thread = std::thread::spawn({
699                    let vsock = vsock.clone();
700                    move || {
701                        let _ = std::io::copy(&mut std::io::stdin(), &mut vsock.as_ref());
702                    }
703                });
704
705                std::io::copy(&mut vsock.as_ref(), &mut term::raw_stdout())
706                    .context("failed stdout copy")?;
707                thread.join().unwrap();
708            }
709            Command::Crash {
710                crash_type,
711                pid,
712                name,
713            } => {
714                let client = new_client(driver.clone(), &vm)?;
715                let pid = if let Some(name) = name {
716                    client.get_pid(&name).await?
717                } else {
718                    pid.unwrap()
719                };
720                println!("Crashing PID: {pid}");
721                match crash_type {
722                    CrashType::UhPanic => {
723                        _ = client.crash(pid).await;
724                    }
725                }
726            }
727            Command::PacketCapture {
728                output,
729                seconds,
730                snaplen,
731            } => {
732                let client = new_client(driver.clone(), &vm)?;
733                println!(
734                    "Starting network packet capture. Wait for timeout or Ctrl-C to quit anytime."
735                );
736                let (_, num_streams) = client
737                    .packet_capture(PacketCaptureOperation::Query, 0, 0)
738                    .await?;
739                let file_stem = &output.file_stem().unwrap().to_string_lossy();
740                let extension = &output.extension().unwrap_or(OsStr::new("pcap"));
741                let mut new_output = PathBuf::from(&output);
742                let streams = client
743                    .packet_capture(PacketCaptureOperation::Start, num_streams, snaplen)
744                    .await?
745                    .0
746                    .into_iter()
747                    .enumerate()
748                    .map(|(i, i_stream)| {
749                        new_output.set_file_name(format!("{}-{}", &file_stem, i));
750                        new_output.set_extension(extension);
751                        let mut out = AllowStdIo::new(fs_err::File::create(&new_output)?);
752                        Ok(async move { futures::io::copy(i_stream, &mut out).await })
753                    })
754                    .collect::<Result<Vec<_>, std::io::Error>>()?;
755                capture_packets(client, streams, seconds).await;
756            }
757            Command::CoreDump {
758                verbose,
759                pid,
760                name,
761                dst,
762            } => {
763                ensure_not_terminal(&dst)?;
764                let client = new_client(driver.clone(), &vm)?;
765                let pid = if let Some(name) = name {
766                    client.get_pid(&name).await?
767                } else {
768                    pid.unwrap()
769                };
770                println!("Dumping PID: {pid}");
771                let file = create_or_stderr(&dst)?;
772                client
773                    .core_dump(
774                        pid,
775                        AllowStdIo::new(file),
776                        AllowStdIo::new(std::io::stderr()),
777                        verbose,
778                    )
779                    .await?;
780            }
781            Command::Restart => {
782                let client = new_client(driver.clone(), &vm)?;
783                client.restart().await?;
784            }
785            Command::PerfTrace { output } => {
786                ensure_not_terminal(&output)?;
787
788                let client = new_client(driver.clone(), &vm)?;
789
790                // Flush the perf trace.
791                client
792                    .update("trace/perf/flush".to_owned(), "true".to_owned())
793                    .await
794                    .context("failed to flush perf")?;
795
796                let file = create_or_stderr(&output)?;
797                let stream = client
798                    .read_file(false, "underhill.perfetto".to_owned())
799                    .await
800                    .context("failed to read trace file")?;
801
802                futures::io::copy(stream, &mut AllowStdIo::new(file))
803                    .await
804                    .context("failed to copy trace file")?;
805            }
806            Command::VsockTcpRelay {
807                vsock_port,
808                tcp_port,
809                allow_remote,
810                reconnect,
811            } => {
812                let addr = if allow_remote { "0.0.0.0" } else { "127.0.0.1" };
813                let listener = TcpListener::bind((addr, tcp_port))
814                    .with_context(|| format!("binding to port {}", tcp_port))?;
815                println!("TCP listening on {}:{}", addr, tcp_port);
816                'connect: loop {
817                    let (tcp_socket, tcp_addr) = listener.accept()?;
818                    let tcp_socket = PolledSocket::new(&driver, tcp_socket)?;
819                    println!("TCP accept on {:?}", tcp_addr);
820
821                    // TODO: support reconnect attempt for vsock like kmsg
822                    let vsock = match vm.id {
823                        VmId::HybridVsock(ref path) => {
824                            // TODO: reconnection attempt logic like kmsg is
825                            // broken for hybrid_vsock with end of file error,
826                            // if this is started before the vm is started
827                            diag_client::connect_hybrid_vsock(&driver, path, vsock_port).await?
828                        }
829                        #[cfg(windows)]
830                        VmId::HyperV(ref name) => {
831                            let vm_id = diag_client::hyperv::vm_id_from_name(name)?;
832                            let stream =
833                                diag_client::hyperv::connect_vsock(&driver, vm_id, vsock_port)
834                                    .await?;
835                            PolledSocket::new(&driver, socket2::Socket::from(stream))?
836                        }
837                    };
838                    println!("VSOCK connect to port {:?}", vsock_port);
839
840                    let (tcp_read, mut tcp_write) = tcp_socket.split();
841                    let (vsock_read, mut vsock_write) = vsock.split();
842                    let tx = futures::io::copy(tcp_read, &mut vsock_write);
843                    let rx = futures::io::copy(vsock_read, &mut tcp_write);
844                    let result = futures::future::try_join(tx, rx).await;
845                    match result {
846                        Ok(_) => {}
847                        Err(e) => match e.kind() {
848                            ErrorKind::ConnectionReset => {}
849                            _ => return Err(anyhow::Error::from(e)),
850                        },
851                    }
852                    println!("Connection closed");
853
854                    if reconnect {
855                        println!("Reconnecting...");
856                        continue 'connect;
857                    }
858
859                    break;
860                }
861            }
862            Command::Pause => {
863                let client = new_client(driver.clone(), &vm)?;
864                client.pause().await?;
865            }
866            Command::Resume => {
867                let client = new_client(driver.clone(), &vm)?;
868                client.resume().await?;
869            }
870            Command::DumpSavedState { output } => {
871                ensure_not_terminal(&output)?;
872                let client = new_client(driver.clone(), &vm)?;
873                let mut file = create_or_stderr(&output)?;
874                file.write_all(&client.dump_saved_state().await?)?;
875            }
876        }
877        Ok(())
878    })
879}
880
881fn ensure_not_terminal(path: &Option<PathBuf>) -> anyhow::Result<()> {
882    if path.is_none() && std::io::stdout().is_terminal() {
883        anyhow::bail!("cannot write to terminal");
884    }
885    Ok(())
886}
887
888fn create_or_stderr(path: &Option<PathBuf>) -> std::io::Result<fs_err::File> {
889    let file = match path {
890        Some(path) => fs_err::File::create(path)?,
891        None => fs_err::File::from_parts(term::raw_stdout(), "stdout"),
892    };
893    Ok(file)
894}
895
896async fn capture_packets(
897    client: DiagClient,
898    streams: Vec<impl std::future::Future<Output = Result<u64, std::io::Error>>>,
899    capture_duration: Duration,
900) {
901    let mut capture_streams = FuturesUnordered::from_iter(streams);
902    let (user_input_tx, mut user_input_rx) = mesh::channel();
903    ctrlc::set_handler(move || user_input_tx.send(())).expect("Error setting Ctrl-C handler");
904
905    let mut ctx = mesh::CancelContext::new().with_timeout(capture_duration);
906    let mut stop_signaled = std::pin::pin!(ctx.until_cancelled(user_input_rx.recv()));
907
908    let mut stop_streams = std::pin::pin!(async {
909        if let Err(err) = client
910            .packet_capture(PacketCaptureOperation::Stop, 0, 0)
911            .await
912        {
913            eprintln!("Failed stop: {err}");
914        }
915    });
916
917    #[derive(PartialEq)]
918    enum State {
919        Running,
920        Stopping,
921        StoppingStreamsDone,
922        Stopped,
923    }
924    let mut state = State::Running;
925    loop {
926        enum Event {
927            Continue,
928            StopSignaled,
929            StopComplete,
930            StreamsDone,
931        }
932        let stop = async {
933            match state {
934                State::Running => {
935                    (&mut stop_signaled).await.ok();
936                    Event::StopSignaled
937                }
938                State::Stopping | State::StoppingStreamsDone => {
939                    (&mut stop_streams).await;
940                    Event::StopComplete
941                }
942                State::Stopped => std::future::pending::<Event>().await,
943            }
944        };
945        let process_streams = async {
946            if state == State::StoppingStreamsDone {
947                std::future::pending::<()>().await;
948            }
949            match capture_streams.next().await {
950                Some(_) => Event::Continue,
951                None => Event::StreamsDone,
952            }
953        };
954        let event = (stop, process_streams).race();
955
956        // N.B Wait for all the copy tasks to complete to make sure the data is flushed to
957        //     ensure compatibility with the packet capture protocol.
958        match event.await {
959            Event::Continue => continue,
960            Event::StopSignaled => {
961                println!("Stopping packet capture...");
962                state = State::Stopping;
963            }
964            Event::StopComplete => {
965                println!("Waiting for data to be flushed...");
966                if state == State::Stopping {
967                    state = State::Stopped;
968                } else {
969                    break;
970                }
971            }
972            Event::StreamsDone if state == State::Stopping => {
973                state = State::StoppingStreamsDone;
974            }
975            Event::StreamsDone => {
976                if state != State::Stopped {
977                    println!("Lost connection with network.");
978                }
979                break;
980            }
981        }
982    }
983    println!("All done.");
984}