1#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10mod completions;
11
12use anyhow::Context;
13use clap::ArgGroup;
14use clap::Args;
15use clap::Parser;
16use clap::Subcommand;
17use diag_client::DiagClient;
18use diag_client::PacketCaptureOperation;
19use futures::StreamExt;
20use futures::io::AllowStdIo;
21use futures_concurrency::future::Race;
22use pal_async::DefaultPool;
23use pal_async::driver::Driver;
24use pal_async::socket::PolledSocket;
25use pal_async::task::Spawn;
26use pal_async::timer::PolledTimer;
27use std::convert::Infallible;
28use std::ffi::OsStr;
29use std::io::ErrorKind;
30use std::io::IsTerminal;
31use std::io::Write;
32use std::net::TcpListener;
33use std::path::Path;
34use std::path::PathBuf;
35use std::str::FromStr;
36use std::sync::Arc;
37use std::time::Duration;
38use thiserror::Error;
39use tracing_subscriber::layer::SubscriberExt;
40use tracing_subscriber::util::SubscriberInitExt;
41use unicycle::FuturesUnordered;
42
// Top-level command line for the tool: a VM selector shared by all
// subcommands plus the subcommand to run.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive macros,
// `///` doc comments are turned into help text, which would change what the
// program prints.
#[derive(Parser)]
#[clap(about = "(dev) CLI to interact with the Underhill diagnostics server")]
#[clap(long_about = r#"
CLI to interact with the Underhill diagnostics server.

DISCLAIMER:
 `ohcldiag-dev` does not make ANY stability guarantees regarding the layout of
 the CLI, the syntax that is emitted via stdout/stderr, the location of nodes
 in the `inspect` graph, etc...

 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 !! ANY AUTOMATION THAT USES ohcldiag-dev WILL EVENTUALLY BREAK !!
 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"#)]
struct Options {
    // The VM to talk to; flattened so its `VM` argument appears directly on
    // the top-level command.
    #[clap(flatten)]
    vm: VmArg,

    // The operation to perform; dispatched in `main`.
    #[clap(subcommand)]
    command: Command,
}
64
// The subcommands understood by the tool. Each variant is handled by the
// matching arm of the `match command` in `main`.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive macros,
// `///` doc comments are turned into help text, which would change what the
// program prints.
#[derive(Subcommand)]
enum Command {
    // Hidden helper used by the shell-completion machinery.
    #[clap(hide = true)]
    Complete(clap_dyn_complete::Complete),
    // Completion-related subcommand implemented in the `completions` module.
    Completions(completions::Completions),
    // Runs `shell` (with `args`) in the VM with a TTY, wiring it to the local
    // console in raw mode.
    Shell {
        #[clap(default_value = "/bin/sh")]
        shell: String,
        args: Vec<String>,
    },
    // Runs `command` (with `args`) in the VM, relaying stdin/stdout/stderr,
    // and exits with the remote process's exit code.
    Run {
        command: String,
        args: Vec<String>,
    },
    // Queries (or, with `--update`, modifies) a node in the inspect graph.
    #[clap(visible_alias = "i")]
    Inspect {
        // When unset, the query is issued with a depth limit of `Some(0)`.
        #[clap(short)]
        recursive: bool,
        // Depth limit passed to the query when `recursive` is set.
        #[clap(short, long, requires("recursive"))]
        limit: Option<usize>,
        // Print the result as JSON instead of the default display format.
        #[clap(short, long)]
        json: bool,
        // Repeatedly query and print the difference from the previous result.
        #[clap(short)]
        poll: bool,
        // Polling period in seconds (fractional values allowed).
        #[clap(long, default_value = "1", requires("poll"))]
        period: f64,
        // Number of polling iterations; unlimited when absent.
        #[clap(long, requires("poll"))]
        count: Option<usize>,
        // Inspect path; the empty path is used when absent.
        path: Option<String>,
        // New value to write to `path` (which is then required).
        #[clap(short, long, conflicts_with("recursive"))]
        update: Option<String>,
        // Query timeout in seconds; 0 means no timeout is passed.
        #[clap(short, default_value = "1", conflicts_with("update"))]
        timeout: u64,
    },
    // Deprecated spelling of `inspect <path> -u <value>`; prints a
    // deprecation notice when used.
    #[clap(hide = true)]
    Update {
        path: String,
        value: String,
    },
    // Calls `DiagClient::start` with the given environment changes and args.
    Start {
        // `NAME=value` pairs to set.
        #[clap(short, long)]
        env: Vec<EnvString>,
        // Variable names passed with no value (i.e. to be unset).
        #[clap(short, long)]
        unset: Vec<String>,
        args: Vec<String>,
    },
    // Prints kernel log (kmsg) entries from the VM.
    Kmsg {
        // Forwarded to `DiagClient::kmsg`.
        #[clap(short, long)]
        follow: bool,
        // Wait for the server and reconnect when the stream ends or the
        // connection is reset.
        #[clap(short, long)]
        reconnect: bool,
        // Print connection status messages to stderr.
        #[clap(short, long)]
        verbose: bool,
        // Read the log from a Hyper-V VM serial port instead of the
        // diagnostics server.
        #[cfg(windows)]
        #[clap(long, conflicts_with = "reconnect")]
        serial: bool,
        // Explicit pipe path for the serial port; otherwise the VM name and
        // port number 3 are used.
        #[cfg(windows)]
        #[clap(long, requires = "serial")]
        pipe_path: Option<String>,
    },
    // Copies the contents of a file in the VM to stdout.
    File {
        // Forwarded to `DiagClient::read_file`.
        #[clap(short, long)]
        follow: bool,
        #[clap(short('p'), long)]
        file_path: String,
    },
    // Runs `gdbserver --once` in the VM over stdio.
    Gdbserver {
        // Process to attach to; when absent (and `multi` is unset) the PID is
        // read from /run/underhill.pid inside the VM.
        #[clap(long)]
        pid: Option<i32>,
        // Run gdbserver in `--multi` mode instead of attaching.
        #[clap(long, conflicts_with("pid"))]
        multi: bool,
    },
    // Connects stdio to a vsock port on the VM.
    Gdbstub {
        // The vsock port to connect to.
        #[clap(short, long, default_value = "4")]
        port: u32,
    },
    // Crashes a process in the VM; exactly one of `pid`/`name` group members
    // must be supplied.
    #[clap(group(
        ArgGroup::new("process")
            .required(true)
            .args(&["pid", "name"]),
    ))]
    Crash {
        crash_type: CrashType,
        #[clap(short, long)]
        pid: Option<i32>,
        #[clap(short, long)]
        name: Option<String>,
    },
    // Streams a core dump of a process in the VM to `dst` (or stdout).
    #[clap(group(
        ArgGroup::new("process")
            .required(true)
            .args(&["pid", "name"]),
    ))]
    CoreDump {
        // Forwarded to `DiagClient::core_dump`.
        #[clap(short, long)]
        verbose: bool,
        #[clap(short, long)]
        pid: Option<i32>,
        #[clap(short, long)]
        name: Option<String>,
        // Destination file; stdout is used when absent (and must not be a
        // terminal).
        dst: Option<PathBuf>,
    },
    // Calls `DiagClient::restart`.
    Restart,
    // Flushes the perf trace and copies `underhill.perfetto` to `output`
    // (or stdout).
    PerfTrace {
        #[clap(short)]
        output: Option<PathBuf>,
    },
    // Listens on a local TCP port and relays each accepted connection to a
    // vsock port on the VM.
    VsockTcpRelay {
        vsock_port: u32,
        tcp_port: u16,
        // Bind to 0.0.0.0 instead of 127.0.0.1.
        #[clap(long)]
        allow_remote: bool,
        // Keep accepting new connections after one closes.
        #[clap(short, long)]
        reconnect: bool,
    },
    // Calls `DiagClient::pause`.
    Pause,
    // Calls `DiagClient::resume`.
    Resume,
    // Writes the bytes returned by `DiagClient::dump_saved_state` to `output`
    // (or stdout).
    DumpSavedState {
        #[clap(short)]
        output: Option<PathBuf>,
    },
    // Captures network packets into one file per stream.
    PacketCapture {
        // Base output path; a `-<index>` suffix is added per stream and the
        // extension defaults to `pcap`.
        #[clap(short('w'), default_value = "nic")]
        output: PathBuf,
        // Capture duration in whole seconds.
        #[clap(short('G'), long, default_value = "60", value_parser = |arg: &str| -> Result<Duration, std::num::ParseIntError> {Ok(Duration::from_secs(arg.parse()?))})]
        seconds: Duration,
        // Snapshot length forwarded to the capture start request; must be at
        // least 1.
        #[clap(short('s'), long, default_value = "65535", value_parser = clap::value_parser!(u16).range(1..))]
        snaplen: u16,
    },
    // Writes the bytes returned by `DiagClient::memory_profile_trace` for a
    // process to `output` (or stdout). One of `pid`/`name` is required; this
    // is checked at runtime in `main`.
    MemoryProfileTrace {
        #[clap(short, long)]
        pid: Option<i32>,
        #[clap(short, long)]
        name: Option<String>,
        #[clap(short)]
        output: Option<PathBuf>,
    },
    // Writes "<log_level>,<output>" to the `vm/uefi/process_diagnostics`
    // inspect node and prints the returned value.
    EfiDiagnostics {
        log_level: EfiDiagnosticsLogLevel,
        output: EfiDiagnosticsOutput,
    },
}
314
315#[derive(Debug, Clone, Args)]
316pub struct VmArg {
317 #[doc = r#"VM identifier.
318
319 This can be one of:
320
321 * vsock:PATH - A path to a hybrid vsock Unix socket for a VM, as used by OpenVMM
322
323 * unix:PATH - A path to a Unix socket for connecting to the control plane
324
325 "#]
326 #[cfg_attr(
327 windows,
328 doc = "* hyperv:NAME - A Hyper-V VM name
329
330 "
331 )]
332 #[cfg_attr(
333 windows,
334 doc = "* NAME_OR_PATH - Either a Hyper-V VM name, or a path as in vsock:PATH>"
335 )]
336 #[cfg_attr(not(windows), doc = "* PATH - A path as in vsock:PATH")]
337 #[clap(name = "VM")]
338 id: VmId,
339}
340
/// How to reach a VM's diagnostics server, as parsed from the `VM` argument.
#[derive(Debug, Clone)]
enum VmId {
    /// A Hyper-V VM, identified by name (Windows only).
    #[cfg(windows)]
    HyperV(String),
    /// A path to a hybrid vsock socket.
    HybridVsock(PathBuf),
}
347
348impl FromStr for VmId {
349 type Err = Infallible;
350
351 fn from_str(s: &str) -> Result<Self, Self::Err> {
352 if let Some(s) = s.strip_prefix("vsock:") {
353 Ok(Self::HybridVsock(Path::new(s).to_owned()))
354 } else {
355 #[cfg(windows)]
356 if let Some(s) = s.strip_prefix("hyperv:") {
357 return Ok(Self::HyperV(s.to_owned()));
358 } else if !pal::windows::fs::is_unix_socket(s.as_ref()).unwrap_or(false) {
359 return Ok(Self::HyperV(s.to_owned()));
360 }
361 Ok(Self::HybridVsock(Path::new(s).to_owned()))
364 }
365 }
366}
367
/// An environment variable assignment parsed from a `NAME=value` argument.
#[derive(Clone)]
struct EnvString {
    /// The text before the first `=`.
    name: String,
    /// The text after the first `=`.
    value: String,
}
373
// The kinds of crash the `crash` subcommand can trigger.
//
// Plain `//` comments are used because clap's `ValueEnum` derive turns `///`
// doc comments into help text.
#[derive(Clone, clap::ValueEnum)]
enum CrashType {
    // Spelled `panic` on the command line; handled in `main` by calling
    // `DiagClient::crash` for the selected PID.
    #[clap(name = "panic")]
    UhPanic,
}
379
// Log level selector for the `efi-diagnostics` subcommand. The value sent to
// the VM for each variant is given by `as_inspect_value`.
//
// Plain `//` comments are used because clap's `ValueEnum` derive turns `///`
// doc comments into help text.
#[derive(Clone, clap::ValueEnum)]
enum EfiDiagnosticsLogLevel {
    // Sent as "default".
    Default,
    // Sent as "info".
    Info,
    // Sent as "full".
    Full,
}
389
390impl EfiDiagnosticsLogLevel {
391 fn as_inspect_value(&self) -> &'static str {
392 match self {
393 EfiDiagnosticsLogLevel::Default => "default",
394 EfiDiagnosticsLogLevel::Info => "info",
395 EfiDiagnosticsLogLevel::Full => "full",
396 }
397 }
398}
399
// Output destination selector for the `efi-diagnostics` subcommand. The
// value sent to the VM for each variant is given by `as_inspect_value`.
//
// Plain `//` comments are used because clap's `ValueEnum` derive turns `///`
// doc comments into help text.
#[derive(Clone, clap::ValueEnum)]
enum EfiDiagnosticsOutput {
    // Sent as "stdout".
    Stdout,
    // Sent as "tracing".
    Tracing,
}
407
408impl EfiDiagnosticsOutput {
409 fn as_inspect_value(&self) -> &'static str {
410 match self {
411 EfiDiagnosticsOutput::Stdout => "stdout",
412 EfiDiagnosticsOutput::Tracing => "tracing",
413 }
414 }
415}
416
/// Error returned when an environment argument contains no `=` separator.
#[derive(Debug, Error)]
#[error("bad environment variable, expected VAR=value")]
struct BadEnvString;
420
421impl FromStr for EnvString {
422 type Err = BadEnvString;
423
424 fn from_str(s: &str) -> Result<Self, Self::Err> {
425 let (name, value) = s.split_once('=').ok_or(BadEnvString)?;
426 Ok(Self {
427 name: name.to_owned(),
428 value: value.to_owned(),
429 })
430 }
431}
432
433async fn run(
435 client: &DiagClient,
436 command: impl AsRef<str>,
437 args: impl IntoIterator<Item = impl AsRef<str>>,
438) -> anyhow::Result<()> {
439 let mut process = client
442 .exec(&command)
443 .args(args)
444 .stdin(true)
445 .stdout(true)
446 .stderr(true)
447 .spawn()
448 .await?;
449
450 let mut stdin = process.stdin.take().unwrap();
451 let mut stdout = process.stdout.take().unwrap();
452 let mut stderr = process.stderr.take().unwrap();
453
454 std::thread::spawn({
455 move || {
456 let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
457 }
458 });
459
460 let stderr_thread =
461 std::thread::spawn(move || std::io::copy(&mut stderr, &mut term::raw_stderr()));
462
463 std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed stdout copy")?;
464
465 stderr_thread
466 .join()
467 .unwrap()
468 .context("failed stderr thread")?;
469
470 let status = process.wait().await?;
471 std::process::exit(status.exit_code());
472}
473
474fn new_client(driver: impl Driver + Spawn + Clone, input: &VmArg) -> anyhow::Result<DiagClient> {
475 let client = match &input.id {
476 #[cfg(windows)]
477 VmId::HyperV(name) => DiagClient::from_hyperv_name(driver, name)?,
478 VmId::HybridVsock(path) => DiagClient::from_hybrid_vsock(driver, path),
479 };
480 Ok(client)
481}
482
483pub fn main() -> anyhow::Result<()> {
484 tracing_subscriber::registry()
485 .with(tracing_subscriber::fmt::layer())
486 .with(tracing_subscriber::EnvFilter::from_default_env())
487 .init();
488
489 term::enable_vt_and_utf8();
490 DefaultPool::run_with(async |driver| {
491 let Options { vm, command } = Options::parse();
492
493 match command {
494 Command::Complete(cmd) => {
495 cmd.println_to_stub_script::<Options>(
496 None,
497 completions::OhcldiagDevCompleteFactory {
498 driver: driver.clone(),
499 },
500 )
501 .await
502 }
503 Command::Completions(cmd) => cmd.run()?,
504 Command::Shell { shell, args } => {
505 let client = new_client(driver.clone(), &vm)?;
506
507 let term = std::env::var("TERM");
509 let term = term.as_deref().unwrap_or("xterm-256color");
510
511 let mut process = client
512 .exec(&shell)
513 .args(&args)
514 .tty(true)
515 .stdin(true)
516 .stdout(true)
517 .env("TERM", term)
518 .spawn()
519 .await?;
520
521 let mut stdin = process.stdin.take().unwrap();
522 let mut stdout = process.stdout.take().unwrap();
523
524 term::set_raw_console(true).expect("failed to set raw console mode");
525 std::thread::spawn({
526 move || {
527 let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
528 }
529 });
530
531 std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed copy")?;
532
533 let status = process.wait().await?;
534
535 if !status.success() {
536 eprintln!(
537 "shell exited with non-zero exit code: {}",
538 status.exit_code()
539 );
540 }
541 }
542 Command::Run { command, args } => {
543 let client = new_client(driver.clone(), &vm)?;
544 run(&client, command, &args).await?;
545 }
546 Command::Inspect {
547 recursive,
548 limit,
549 json,
550 poll,
551 period,
552 count,
553 timeout,
554
555 path,
556 update,
557 } => {
558 let client = new_client(driver.clone(), &vm)?;
559
560 if let Some(update) = update {
561 let Some(path) = path else {
562 anyhow::bail!("must provide path for update")
563 };
564
565 let value = client.update(path, update).await?;
566 match value.kind {
567 inspect::ValueKind::String(s) => println!("{s}"),
568 _ => println!("{value}"),
569 }
570 } else {
571 let timeout = if timeout == 0 {
572 None
573 } else {
574 Some(Duration::from_secs(timeout))
575 };
576 let query = async || {
577 client
578 .inspect(
579 path.as_deref().unwrap_or(""),
580 if recursive { limit } else { Some(0) },
581 timeout,
582 )
583 .await
584 };
585
586 if poll {
587 let mut timer = PolledTimer::new(&driver);
588 let period = Duration::from_secs_f64(period);
589 let mut last_time = pal_async::timer::Instant::now();
590 let mut last = query().await?;
591 let mut count = count;
592
593 loop {
594 match count.as_mut() {
595 Some(count) if *count == 0 => break,
596 Some(count) => *count -= 1,
597 None => {}
598 }
599 timer.sleep_until(last_time + period).await;
600 let now = pal_async::timer::Instant::now();
601 let this = query().await?;
602 let diff = this.since(&last, now - last_time);
603 if json {
604 println!("{}", diff.json());
605 } else {
606 println!("{diff:#}");
607 }
608 last = this;
609 last_time = now;
610 }
611 } else {
612 let node = query().await?;
613 if json {
614 println!("{}", node.json());
615 } else {
616 println!("{node:#}");
617 }
618 }
619 }
620 }
621 Command::Update { path, value } => {
622 eprintln!(
623 "`update` is deprecated - please use `ohcldiag-dev inspect <path> -u <new value>`"
624 );
625 let client = new_client(driver.clone(), &vm)?;
626 let value = client.update(path, value).await?;
627 match value.kind {
628 inspect::ValueKind::String(s) => println!("{s}"),
629 _ => println!("{value}"),
630 }
631 }
632 Command::Start { env, unset, args } => {
633 let client = new_client(driver.clone(), &vm)?;
634
635 let env = env
636 .into_iter()
637 .map(|EnvString { name, value }| (name, Some(value)))
638 .chain(unset.into_iter().map(|name| (name, None)));
639
640 client.start(env, args).await?;
641 }
642 Command::Kmsg {
643 follow,
644 reconnect,
645 verbose,
646 #[cfg(windows)]
647 serial,
648 #[cfg(windows)]
649 pipe_path,
650 } => {
651 let is_terminal = std::io::stdout().is_terminal();
652
653 #[cfg(windows)]
654 if serial {
655 use diag_client::hyperv::ComPortAccessInfo;
656 use futures::AsyncBufReadExt;
657
658 let vm_name = match &vm.id {
659 VmId::HyperV(name) => name,
660 _ => anyhow::bail!("--serial is only supported for Hyper-V VMs"),
661 };
662
663 let port_access_info = if let Some(pipe_path) = pipe_path.as_ref() {
664 ComPortAccessInfo::PortPipePath(pipe_path)
665 } else {
666 ComPortAccessInfo::NameAndPortNumber(vm_name, 3)
667 };
668
669 let pipe =
670 diag_client::hyperv::open_serial_port(&driver, port_access_info).await?;
671 let pipe = pal_async::pipe::PolledPipe::new(&driver, pipe)
672 .context("failed to make a polled pipe")?;
673 let pipe = futures::io::BufReader::new(pipe);
674
675 let mut lines = pipe.lines();
676 while let Some(line) = lines.next().await {
677 let line = line?;
678 if let Some(message) = kmsg::SyslogParsedEntry::new(&line) {
679 println!("{}", message.display(is_terminal));
680 } else {
681 println!("{line}");
682 }
683 }
684
685 return Ok(());
686 }
687
688 if verbose {
689 eprintln!("Connecting to the diagnostics server.");
690 }
691
692 let client = new_client(driver.clone(), &vm)?;
693 'connect: loop {
694 if reconnect {
695 client.wait_for_server().await?;
696 }
697 let mut file_stream = client.kmsg(follow).await?;
698 if verbose {
699 eprintln!("Connected.");
700 }
701
702 while let Some(data) = file_stream.next().await {
703 match data {
704 Ok(data) => match kmsg::KmsgParsedEntry::new(&data) {
705 Ok(message) => println!("{}", message.display(is_terminal)),
706 Err(e) => println!("Invalid kmsg entry: {e:?}"),
707 },
708 Err(err) if reconnect && err.kind() == ErrorKind::ConnectionReset => {
709 if verbose {
710 eprintln!(
711 "Connection reset to the diagnostics server. Reconnecting."
712 );
713 }
714 continue 'connect;
715 }
716 Err(err) => Err(err).context("failed to read kmsg")?,
717 }
718 }
719
720 if reconnect {
721 if verbose {
722 eprintln!("Lost connection to the diagnostics server. Reconnecting.");
723 }
724 continue 'connect;
725 }
726
727 break;
728 }
729 }
730 Command::File { follow, file_path } => {
731 let client = new_client(driver.clone(), &vm)?;
732 let stream = client.read_file(follow, file_path).await?;
733 futures::io::copy(stream, &mut AllowStdIo::new(term::raw_stdout()))
734 .await
735 .context("failed to copy trace file")?;
736 }
737 Command::Gdbserver { multi, pid } => {
738 let client = new_client(driver.clone(), &vm)?;
739 let gdbserver = "gdbserver --once";
743 let command = if multi {
744 format!("{gdbserver} --multi -")
745 } else if let Some(pid) = pid {
746 format!("{gdbserver} --attach - {pid}")
747 } else {
748 format!("{gdbserver} --attach - \"$(cat /run/underhill.pid)\"")
749 };
750
751 run(&client, "/bin/sh", &["-c", &command]).await?;
752 }
753 Command::Gdbstub { port } => {
754 let vsock = match vm.id {
755 VmId::HybridVsock(path) => {
756 diag_client::connect_hybrid_vsock(&driver, &path, port).await?
757 }
758 #[cfg(windows)]
759 VmId::HyperV(name) => {
760 let vm_id = diag_client::hyperv::vm_id_from_name(&name)?;
761 let stream =
762 diag_client::hyperv::connect_vsock(&driver, vm_id, port).await?;
763 PolledSocket::new(&driver, socket2::Socket::from(stream))?
764 }
765 };
766
767 let vsock = Arc::new(vsock.into_inner());
768 let thread = std::thread::spawn({
771 let vsock = vsock.clone();
772 move || {
773 let _ = std::io::copy(&mut std::io::stdin(), &mut vsock.as_ref());
774 }
775 });
776
777 std::io::copy(&mut vsock.as_ref(), &mut term::raw_stdout())
778 .context("failed stdout copy")?;
779 thread.join().unwrap();
780 }
781 Command::Crash {
782 crash_type,
783 pid,
784 name,
785 } => {
786 let client = new_client(driver.clone(), &vm)?;
787 let pid = if let Some(name) = name {
788 client.get_pid(&name).await?
789 } else {
790 pid.unwrap()
791 };
792 println!("Crashing PID: {pid}");
793 match crash_type {
794 CrashType::UhPanic => {
795 _ = client.crash(pid).await;
796 }
797 }
798 }
799 Command::PacketCapture {
800 output,
801 seconds,
802 snaplen,
803 } => {
804 let client = new_client(driver.clone(), &vm)?;
805 println!(
806 "Starting network packet capture. Wait for timeout or Ctrl-C to quit anytime."
807 );
808 let (_, num_streams) = client
809 .packet_capture(PacketCaptureOperation::Query, 0, 0)
810 .await?;
811 let file_stem = &output.file_stem().unwrap().to_string_lossy();
812 let extension = &output.extension().unwrap_or(OsStr::new("pcap"));
813 let mut new_output = PathBuf::from(&output);
814 let streams = client
815 .packet_capture(PacketCaptureOperation::Start, num_streams, snaplen)
816 .await?
817 .0
818 .into_iter()
819 .enumerate()
820 .map(|(i, i_stream)| {
821 new_output.set_file_name(format!("{}-{}", &file_stem, i));
822 new_output.set_extension(extension);
823 let mut out = AllowStdIo::new(fs_err::File::create(&new_output)?);
824 Ok(async move { futures::io::copy(i_stream, &mut out).await })
825 })
826 .collect::<Result<Vec<_>, std::io::Error>>()?;
827 capture_packets(client, streams, seconds).await;
828 }
829 Command::CoreDump {
830 verbose,
831 pid,
832 name,
833 dst,
834 } => {
835 ensure_not_terminal(&dst)?;
836 let client = new_client(driver.clone(), &vm)?;
837 let pid = if let Some(name) = name {
838 client.get_pid(&name).await?
839 } else {
840 pid.unwrap()
841 };
842 println!("Dumping PID: {pid}");
843 let file = create_or_stderr(&dst)?;
844 client
845 .core_dump(
846 pid,
847 AllowStdIo::new(file),
848 AllowStdIo::new(std::io::stderr()),
849 verbose,
850 )
851 .await?;
852 }
853 Command::Restart => {
854 let client = new_client(driver.clone(), &vm)?;
855 client.restart().await?;
856 }
857 Command::PerfTrace { output } => {
858 ensure_not_terminal(&output)?;
859
860 let client = new_client(driver.clone(), &vm)?;
861
862 client
864 .update("trace/perf/flush".to_owned(), "true".to_owned())
865 .await
866 .context("failed to flush perf")?;
867
868 let file = create_or_stderr(&output)?;
869 let stream = client
870 .read_file(false, "underhill.perfetto".to_owned())
871 .await
872 .context("failed to read trace file")?;
873
874 futures::io::copy(stream, &mut AllowStdIo::new(file))
875 .await
876 .context("failed to copy trace file")?;
877 }
878 Command::VsockTcpRelay {
879 vsock_port,
880 tcp_port,
881 allow_remote,
882 reconnect,
883 } => {
884 let addr = if allow_remote { "0.0.0.0" } else { "127.0.0.1" };
885 let listener = TcpListener::bind((addr, tcp_port))
886 .with_context(|| format!("binding to port {}", tcp_port))?;
887 println!("TCP listening on {}:{}", addr, tcp_port);
888 'connect: loop {
889 let (tcp_socket, tcp_addr) = listener.accept()?;
890 let tcp_socket = PolledSocket::new(&driver, tcp_socket)?;
891 println!("TCP accept on {:?}", tcp_addr);
892
893 let vsock = match vm.id {
895 VmId::HybridVsock(ref path) => {
896 diag_client::connect_hybrid_vsock(&driver, path, vsock_port).await?
900 }
901 #[cfg(windows)]
902 VmId::HyperV(ref name) => {
903 let vm_id = diag_client::hyperv::vm_id_from_name(name)?;
904 let stream =
905 diag_client::hyperv::connect_vsock(&driver, vm_id, vsock_port)
906 .await?;
907 PolledSocket::new(&driver, socket2::Socket::from(stream))?
908 }
909 };
910 println!("VSOCK connect to port {:?}", vsock_port);
911
912 let (tcp_read, mut tcp_write) = tcp_socket.split();
913 let (vsock_read, mut vsock_write) = vsock.split();
914 let tx = futures::io::copy(tcp_read, &mut vsock_write);
915 let rx = futures::io::copy(vsock_read, &mut tcp_write);
916 let result = futures::future::try_join(tx, rx).await;
917 match result {
918 Ok(_) => {}
919 Err(e) => match e.kind() {
920 ErrorKind::ConnectionReset => {}
921 _ => return Err(anyhow::Error::from(e)),
922 },
923 }
924 println!("Connection closed");
925
926 if reconnect {
927 println!("Reconnecting...");
928 continue 'connect;
929 }
930
931 break;
932 }
933 }
934 Command::Pause => {
935 let client = new_client(driver.clone(), &vm)?;
936 client.pause().await?;
937 }
938 Command::Resume => {
939 let client = new_client(driver.clone(), &vm)?;
940 client.resume().await?;
941 }
942 Command::DumpSavedState { output } => {
943 ensure_not_terminal(&output)?;
944 let client = new_client(driver.clone(), &vm)?;
945 let mut file = create_or_stderr(&output)?;
946 file.write_all(&client.dump_saved_state().await?)?;
947 }
948 Command::MemoryProfileTrace { pid, name, output } => {
949 let client = new_client(driver.clone(), &vm)?;
950 let pid = if let Some(name) = name {
951 client.get_pid(&name).await?
952 } else if let Some(pid) = pid {
953 pid
954 } else {
955 anyhow::bail!("either --pid or --name must be specified");
956 };
957 let mut file = create_or_stderr(&output)?;
961 file.write_all(&client.memory_profile_trace(pid).await?)?;
962 }
963 Command::EfiDiagnostics { log_level, output } => {
964 let client = new_client(driver.clone(), &vm)?;
965 let arg = format!(
966 "{},{}",
967 log_level.as_inspect_value(),
968 output.as_inspect_value()
969 );
970 let value = client
971 .update("vm/uefi/process_diagnostics", &arg)
972 .await
973 .context("failed to process EFI diagnostics")?;
974 match value.kind {
975 inspect::ValueKind::String(s) => print!("{s}"),
976 _ => print!("{value}"),
977 }
978 }
979 }
980 Ok(())
981 })
982}
983
984fn ensure_not_terminal(path: &Option<PathBuf>) -> anyhow::Result<()> {
985 if path.is_none() && std::io::stdout().is_terminal() {
986 anyhow::bail!("cannot write to terminal");
987 }
988 Ok(())
989}
990
991fn create_or_stderr(path: &Option<PathBuf>) -> std::io::Result<fs_err::File> {
992 let file = match path {
993 Some(path) => fs_err::File::create(path)?,
994 None => fs_err::File::from_parts(term::raw_stdout(), "stdout"),
995 };
996 Ok(file)
997}
998
/// Drives a packet capture until it is stopped and all stream copies finish.
///
/// `streams` are the per-stream copy futures; they are polled concurrently.
/// The capture is stopped when the first of these happens: Ctrl-C is
/// received, the cancel context built with `capture_duration` as its timeout
/// fires, or every stream future has completed on its own. A stop request is
/// then sent via `client`, and the function returns once both the stop
/// request and all stream futures are done (or the streams ended before any
/// stop was signaled).
///
/// Errors from individual stream futures are ignored, and a failed stop
/// request is only reported on stderr.
///
/// # Panics
///
/// Panics if the Ctrl-C handler cannot be installed.
async fn capture_packets(
    client: DiagClient,
    streams: Vec<impl Future<Output = Result<u64, std::io::Error>>>,
    capture_duration: Duration,
) {
    let mut capture_streams = FuturesUnordered::from_iter(streams);
    // Ctrl-C is delivered to the async code through a channel.
    let (user_input_tx, mut user_input_rx) = mesh::channel();
    ctrlc::set_handler(move || user_input_tx.send(())).expect("Error setting Ctrl-C handler");

    // Completes on Ctrl-C or when `ctx` is cancelled; the context is created
    // with `capture_duration` as its timeout.
    let mut ctx = mesh::CancelContext::new().with_timeout(capture_duration);
    let mut stop_signaled = std::pin::pin!(ctx.until_cancelled(user_input_rx.recv()));

    // The stop request. It is pinned outside the loop so that it keeps its
    // progress across loop iterations instead of being restarted.
    let mut stop_streams = std::pin::pin!(async {
        if let Err(err) = client
            .packet_capture(PacketCaptureOperation::Stop, 0, 0)
            .await
        {
            eprintln!("Failed stop: {err}");
        }
    });

    #[derive(PartialEq)]
    enum State {
        // Capturing; waiting for a stop signal.
        Running,
        // Stop requested; stop request and streams both still in progress.
        Stopping,
        // Stop requested and all streams finished; stop request in progress.
        StoppingStreamsDone,
        // Stop request finished; waiting for the remaining streams.
        Stopped,
    }
    let mut state = State::Running;
    loop {
        enum Event {
            // One stream finished; others may remain.
            Continue,
            // Ctrl-C or cancellation was observed.
            StopSignaled,
            // The stop request finished.
            StopComplete,
            // No streams remain.
            StreamsDone,
        }
        // What to wait for on the "stop" side depends on the current state.
        let stop = async {
            match state {
                State::Running => {
                    (&mut stop_signaled).await.ok();
                    Event::StopSignaled
                }
                State::Stopping | State::StoppingStreamsDone => {
                    (&mut stop_streams).await;
                    Event::StopComplete
                }
                // Nothing left to wait for on this side.
                State::Stopped => std::future::pending::<Event>().await,
            }
        };
        let process_streams = async {
            // Once the streams are known to be done, do not poll the (empty)
            // set again; only the stop request can make progress.
            if state == State::StoppingStreamsDone {
                std::future::pending::<()>().await;
            }
            match capture_streams.next().await {
                Some(_) => Event::Continue,
                None => Event::StreamsDone,
            }
        };
        // Whichever side completes first determines the next transition.
        let event = (stop, process_streams).race();

        match event.await {
            Event::Continue => continue,
            Event::StopSignaled => {
                println!("Stopping packet capture...");
                state = State::Stopping;
            }
            Event::StopComplete => {
                println!("Waiting for data to be flushed...");
                if state == State::Stopping {
                    // Streams are still running; keep draining them.
                    state = State::Stopped;
                } else {
                    // Streams had already finished; everything is done.
                    break;
                }
            }
            Event::StreamsDone if state == State::Stopping => {
                state = State::StoppingStreamsDone;
            }
            Event::StreamsDone => {
                // Streams ending before the stop request completed (here:
                // while still `Running`) is reported as a lost connection.
                if state != State::Stopped {
                    println!("Lost connection with network.");
                }
                break;
            }
        }
    }
    println!("All done.");
}