1#![expect(missing_docs)]
8#![forbid(unsafe_code)]
9
10mod completions;
11
12use anyhow::Context;
13use clap::ArgGroup;
14use clap::Args;
15use clap::Parser;
16use clap::Subcommand;
17use diag_client::DiagClient;
18use diag_client::PacketCaptureOperation;
19use futures::StreamExt;
20use futures::io::AllowStdIo;
21use futures_concurrency::future::Race;
22use pal_async::DefaultPool;
23use pal_async::driver::Driver;
24use pal_async::socket::PolledSocket;
25use pal_async::task::Spawn;
26use pal_async::timer::PolledTimer;
27use std::convert::Infallible;
28use std::ffi::OsStr;
29use std::io::ErrorKind;
30use std::io::IsTerminal;
31use std::io::Write;
32use std::net::TcpListener;
33use std::path::Path;
34use std::path::PathBuf;
35use std::str::FromStr;
36use std::sync::Arc;
37use std::time::Duration;
38use thiserror::Error;
39use tracing_subscriber::layer::SubscriberExt;
40use tracing_subscriber::util::SubscriberInitExt;
41use unicycle::FuturesUnordered;
42
// Top-level command line: the target VM argument followed by a subcommand.
//
// Plain `//` comments are used deliberately on clap-derived items: clap turns
// `///` doc comments into `--help` text.
#[derive(Parser)]
#[clap(about = "(dev) CLI to interact with the Underhill diagnostics server")]
#[clap(long_about = r#"
CLI to interact with the Underhill diagnostics server.

DISCLAIMER:
    `ohcldiag-dev` does not make ANY stability guarantees regarding the layout of
    the CLI, the syntax that is emitted via stdout/stderr, the location of nodes
    in the `inspect` graph, etc...

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !! ANY AUTOMATION THAT USES ohcldiag-dev WILL EVENTUALLY BREAK !!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"#)]
struct Options {
    // Which VM to connect to (the `VM` argument, flattened from `VmArg`).
    #[clap(flatten)]
    vm: VmArg,

    // The operation to perform against that VM.
    #[clap(subcommand)]
    command: Command,
}
64
// All subcommands. Variant and field names (and field order, for positional
// arguments) define the CLI surface, so they must not be renamed or reordered.
//
// Plain `//` comments are used deliberately: clap turns `///` doc comments on
// derived items into `--help` text.
#[derive(Subcommand)]
enum Command {
    // Hidden hook invoked by generated shell-completion scripts.
    #[clap(hide = true)]
    Complete(clap_dyn_complete::Complete),
    // Generate shell completion scripts (see the `completions` module).
    Completions(completions::Completions),
    // Interactive shell in the VM, run with a TTY and the local `TERM` value.
    Shell {
        // Shell binary to launch in the VM.
        #[clap(default_value = "/bin/sh")]
        shell: String,
        // Extra arguments passed to the shell.
        args: Vec<String>,
    },
    // Run a command in the VM, mirroring stdin/stdout/stderr; this process then
    // exits with the remote exit code.
    Run {
        // Program to run in the VM.
        command: String,
        // Arguments for the program.
        args: Vec<String>,
    },
    // Inspect (or, with -u, update) a node of the diagnostics inspect tree.
    #[clap(visible_alias = "i")]
    Inspect {
        // Recurse into child nodes.
        #[clap(short)]
        recursive: bool,
        // Depth limit passed to `inspect` when recursing (requires -r).
        #[clap(short, long, requires("recursive"))]
        limit: Option<usize>,
        // Print results (or poll diffs) as JSON.
        #[clap(short, long)]
        json: bool,
        // Poll repeatedly and print the difference between successive results.
        #[clap(short)]
        poll: bool,
        // Polling period, in seconds.
        #[clap(long, default_value = "1", requires("poll"))]
        period: f64,
        // Number of diffs to print before stopping; polls forever if omitted.
        #[clap(long, requires("poll"))]
        count: Option<usize>,
        // Path of the node to inspect; the root ("") if omitted.
        path: Option<String>,
        // New value to write to `path` instead of inspecting it.
        #[clap(short, long, conflicts_with("recursive"))]
        update: Option<String>,
        // Inspect timeout, in seconds; 0 disables the timeout.
        #[clap(short, default_value = "1", conflicts_with("update"))]
        timeout: u64,
    },
    // Deprecated alias for `inspect <path> -u <value>`.
    #[clap(hide = true)]
    Update {
        // Inspect path to update.
        path: String,
        // New value.
        value: String,
    },
    // Calls `DiagClient::start` with environment overrides and arguments.
    Start {
        // `VAR=value` assignments (parsed via `EnvString`).
        #[clap(short, long)]
        env: Vec<EnvString>,
        // Variable names passed to the server with a `None` value.
        #[clap(short, long)]
        unset: Vec<String>,
        // Arguments forwarded to `DiagClient::start`.
        args: Vec<String>,
    },
    // Print kernel log messages from the VM.
    Kmsg {
        // Keep streaming new messages.
        #[clap(short, long)]
        follow: bool,
        // Wait for the server and reconnect after a reset or lost connection.
        #[clap(short, long)]
        reconnect: bool,
        // Print connection status messages to stderr.
        #[clap(short, long)]
        verbose: bool,
        // Read from the Hyper-V VM's serial port instead of the diagnostics server.
        #[cfg(windows)]
        #[clap(long, conflicts_with = "reconnect")]
        serial: bool,
        // Explicit serial pipe path; otherwise COM port 3 of the named VM is used.
        #[cfg(windows)]
        #[clap(long, requires = "serial")]
        pipe_path: Option<String>,
    },
    // Copy a file from the VM to stdout.
    File {
        // Keep streaming data as the file grows.
        #[clap(short, long)]
        follow: bool,
        // Path of the file in the VM.
        #[clap(short('p'), long)]
        file_path: String,
    },
    // Run `gdbserver --once` in the VM over this process's stdio.
    Gdbserver {
        // Process to attach to; defaults to the PID in /run/underhill.pid.
        #[clap(long)]
        pid: Option<i32>,
        // Start gdbserver in --multi mode instead of attaching.
        #[clap(long, conflicts_with("pid"))]
        multi: bool,
    },
    // Relay stdin/stdout to a vsock port on the VM (for a gdb stub).
    Gdbstub {
        // Vsock port to connect to.
        #[clap(short, long, default_value = "4")]
        port: u32,
    },
    // Crash a process in the VM, selected by --pid or --name (exactly one required).
    #[clap(group(
        ArgGroup::new("process")
            .required(true)
            .args(&["pid", "name"]),
    ))]
    Crash {
        // Kind of crash to trigger.
        crash_type: CrashType,
        // Target process ID.
        #[clap(short, long)]
        pid: Option<i32>,
        // Target process name, resolved with `DiagClient::get_pid`.
        #[clap(short, long)]
        name: Option<String>,
    },
    // Write a core dump of a process in the VM, selected by --pid or --name.
    #[clap(group(
        ArgGroup::new("process")
            .required(true)
            .args(&["pid", "name"]),
    ))]
    CoreDump {
        // Forward verbose dump progress to stderr.
        #[clap(short, long)]
        verbose: bool,
        // Target process ID.
        #[clap(short, long)]
        pid: Option<i32>,
        // Target process name, resolved with `DiagClient::get_pid`.
        #[clap(short, long)]
        name: Option<String>,
        // Output file; stdout (which must not be a terminal) if omitted.
        dst: Option<PathBuf>,
    },
    // Calls `DiagClient::restart`.
    Restart,
    // Flush and download the perf trace (underhill.perfetto).
    PerfTrace {
        // Output file; stdout (which must not be a terminal) if omitted.
        #[clap(short)]
        output: Option<PathBuf>,
    },
    // Relay TCP connections on a local port to a vsock port on the VM.
    VsockTcpRelay {
        // Vsock port on the VM.
        vsock_port: u32,
        // Local TCP port to listen on.
        tcp_port: u16,
        // Listen on 0.0.0.0 instead of 127.0.0.1.
        #[clap(long)]
        allow_remote: bool,
        // Keep accepting new TCP connections after one closes.
        #[clap(short, long)]
        reconnect: bool,
    },
    // Calls `DiagClient::pause`.
    Pause,
    // Calls `DiagClient::resume`.
    Resume,
    // Write the bytes returned by `DiagClient::dump_saved_state`.
    DumpSavedState {
        // Output file; stdout (which must not be a terminal) if omitted.
        #[clap(short)]
        output: Option<PathBuf>,
    },
    // Capture network packets into one file per capture stream.
    PacketCapture {
        // Base output path; files are named `<stem>-<index>.<ext>` (default ext: pcap).
        #[clap(short('w'), default_value = "nic")]
        output: PathBuf,
        // Capture duration, given on the command line in whole seconds.
        #[clap(short('G'), long, default_value = "60", value_parser = |arg: &str| -> Result<Duration, std::num::ParseIntError> {Ok(Duration::from_secs(arg.parse()?))})]
        seconds: Duration,
        // Snapshot length passed to the capture request (1..=65535).
        #[clap(short('s'), long, default_value = "65535", value_parser = clap::value_parser!(u16).range(1..))]
        snaplen: u16,
    },
}
286
287#[derive(Debug, Clone, Args)]
288pub struct VmArg {
289 #[doc = r#"VM identifier.
290
291 This can be one of:
292
293 * vsock:PATH - A path to a hybrid vsock Unix socket for a VM, as used by OpenVMM
294
295 * unix:PATH - A path to a Unix socket for connecting to the control plane
296
297 "#]
298 #[cfg_attr(
299 windows,
300 doc = "* hyperv:NAME - A Hyper-V VM name
301
302 "
303 )]
304 #[cfg_attr(
305 windows,
306 doc = "* NAME_OR_PATH - Either a Hyper-V VM name, or a path as in vsock:PATH>"
307 )]
308 #[cfg_attr(not(windows), doc = "* PATH - A path as in vsock:PATH")]
309 #[clap(name = "VM")]
310 id: VmId,
311}
312
/// How to reach the target VM's diagnostics server, as parsed from the `VM`
/// command-line argument by the `FromStr` impl.
#[derive(Debug, Clone)]
enum VmId {
    /// A Hyper-V VM identified by name (Windows only).
    #[cfg(windows)]
    HyperV(String),
    /// A VM reachable through the hybrid vsock Unix socket at this path.
    HybridVsock(PathBuf),
}
319
320impl FromStr for VmId {
321 type Err = Infallible;
322
323 fn from_str(s: &str) -> Result<Self, Self::Err> {
324 if let Some(s) = s.strip_prefix("vsock:") {
325 Ok(Self::HybridVsock(Path::new(s).to_owned()))
326 } else {
327 #[cfg(windows)]
328 if let Some(s) = s.strip_prefix("hyperv:") {
329 return Ok(Self::HyperV(s.to_owned()));
330 } else if !pal::windows::fs::is_unix_socket(s.as_ref()).unwrap_or(false) {
331 return Ok(Self::HyperV(s.to_owned()));
332 }
333 Ok(Self::HybridVsock(Path::new(s).to_owned()))
336 }
337 }
338}
339
/// A `VAR=value` environment assignment parsed from the command line.
#[derive(Clone)]
struct EnvString {
    /// Variable name: the text before the first `=`.
    name: String,
    /// Variable value: everything after the first `=` (may be empty).
    value: String,
}
345
// Kind of crash the `crash` subcommand triggers.
//
// Plain `//` comments are used deliberately: clap turns `///` doc comments on
// `ValueEnum` variants into help text.
#[derive(Clone, clap::ValueEnum)]
enum CrashType {
    // Spelled `panic` on the command line; handled by calling `DiagClient::crash`.
    #[clap(name = "panic")]
    UhPanic,
}
351
/// Error returned when an environment assignment (e.g. `start --env`) has no
/// `=` separator.
#[derive(Debug, Error)]
#[error("bad environment variable, expected VAR=value")]
struct BadEnvString;
355
356impl FromStr for EnvString {
357 type Err = BadEnvString;
358
359 fn from_str(s: &str) -> Result<Self, Self::Err> {
360 let (name, value) = s.split_once('=').ok_or(BadEnvString)?;
361 Ok(Self {
362 name: name.to_owned(),
363 value: value.to_owned(),
364 })
365 }
366}
367
368async fn run(
370 client: &DiagClient,
371 command: impl AsRef<str>,
372 args: impl IntoIterator<Item = impl AsRef<str>>,
373) -> anyhow::Result<()> {
374 let mut process = client
377 .exec(&command)
378 .args(args)
379 .stdin(true)
380 .stdout(true)
381 .stderr(true)
382 .spawn()
383 .await?;
384
385 let mut stdin = process.stdin.take().unwrap();
386 let mut stdout = process.stdout.take().unwrap();
387 let mut stderr = process.stderr.take().unwrap();
388
389 std::thread::spawn({
390 move || {
391 let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
392 }
393 });
394
395 let stderr_thread =
396 std::thread::spawn(move || std::io::copy(&mut stderr, &mut term::raw_stderr()));
397
398 std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed stdout copy")?;
399
400 stderr_thread
401 .join()
402 .unwrap()
403 .context("failed stderr thread")?;
404
405 let status = process.wait().await?;
406 std::process::exit(status.exit_code());
407}
408
409fn new_client(driver: impl Driver + Spawn + Clone, input: &VmArg) -> anyhow::Result<DiagClient> {
410 let client = match &input.id {
411 #[cfg(windows)]
412 VmId::HyperV(name) => DiagClient::from_hyperv_name(driver, name)?,
413 VmId::HybridVsock(path) => DiagClient::from_hybrid_vsock(driver, path),
414 };
415 Ok(client)
416}
417
418pub fn main() -> anyhow::Result<()> {
419 tracing_subscriber::registry()
420 .with(tracing_subscriber::fmt::layer())
421 .with(tracing_subscriber::EnvFilter::from_default_env())
422 .init();
423
424 term::enable_vt_and_utf8();
425 DefaultPool::run_with(async |driver| {
426 let Options { vm, command } = Options::parse();
427
428 match command {
429 Command::Complete(cmd) => {
430 cmd.println_to_stub_script::<Options>(
431 None,
432 completions::OhcldiagDevCompleteFactory {
433 driver: driver.clone(),
434 },
435 )
436 .await
437 }
438 Command::Completions(cmd) => cmd.run()?,
439 Command::Shell { shell, args } => {
440 let client = new_client(driver.clone(), &vm)?;
441
442 let term = std::env::var("TERM");
444 let term = term.as_deref().unwrap_or("xterm-256color");
445
446 let mut process = client
447 .exec(&shell)
448 .args(&args)
449 .tty(true)
450 .stdin(true)
451 .stdout(true)
452 .env("TERM", term)
453 .spawn()
454 .await?;
455
456 let mut stdin = process.stdin.take().unwrap();
457 let mut stdout = process.stdout.take().unwrap();
458
459 term::set_raw_console(true).expect("failed to set raw console mode");
460 std::thread::spawn({
461 move || {
462 let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
463 }
464 });
465
466 std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed copy")?;
467
468 let status = process.wait().await?;
469
470 if !status.success() {
471 eprintln!(
472 "shell exited with non-zero exit code: {}",
473 status.exit_code()
474 );
475 }
476 }
477 Command::Run { command, args } => {
478 let client = new_client(driver.clone(), &vm)?;
479 run(&client, command, &args).await?;
480 }
481 Command::Inspect {
482 recursive,
483 limit,
484 json,
485 poll,
486 period,
487 count,
488 timeout,
489
490 path,
491 update,
492 } => {
493 let client = new_client(driver.clone(), &vm)?;
494
495 if let Some(update) = update {
496 let Some(path) = path else {
497 anyhow::bail!("must provide path for update")
498 };
499
500 let value = client.update(path, update).await?;
501 println!("{value}");
502 } else {
503 let timeout = if timeout == 0 {
504 None
505 } else {
506 Some(Duration::from_secs(timeout))
507 };
508 let query = async || {
509 client
510 .inspect(
511 path.as_deref().unwrap_or(""),
512 if recursive { limit } else { Some(0) },
513 timeout,
514 )
515 .await
516 };
517
518 if poll {
519 let mut timer = PolledTimer::new(&driver);
520 let period = Duration::from_secs_f64(period);
521 let mut last_time = pal_async::timer::Instant::now();
522 let mut last = query().await?;
523 let mut count = count;
524
525 loop {
526 match count.as_mut() {
527 Some(count) if *count == 0 => break,
528 Some(count) => *count -= 1,
529 None => {}
530 }
531 timer.sleep_until(last_time + period).await;
532 let now = pal_async::timer::Instant::now();
533 let this = query().await?;
534 let diff = this.since(&last, now - last_time);
535 if json {
536 println!("{}", diff.json());
537 } else {
538 println!("{diff:#}");
539 }
540 last = this;
541 last_time = now;
542 }
543 } else {
544 let node = query().await?;
545 if json {
546 println!("{}", node.json());
547 } else {
548 println!("{node:#}");
549 }
550 }
551 }
552 }
553 Command::Update { path, value } => {
554 eprintln!(
555 "`update` is deprecated - please use `ohcldiag-dev inspect <path> -u <new value>`"
556 );
557 let client = new_client(driver.clone(), &vm)?;
558 let value = client.update(path, value).await?;
559 println!("{value}");
560 }
561 Command::Start { env, unset, args } => {
562 let client = new_client(driver.clone(), &vm)?;
563
564 let env = env
565 .into_iter()
566 .map(|EnvString { name, value }| (name, Some(value)))
567 .chain(unset.into_iter().map(|name| (name, None)));
568
569 client.start(env, args).await?;
570 }
571 Command::Kmsg {
572 follow,
573 reconnect,
574 verbose,
575 #[cfg(windows)]
576 serial,
577 #[cfg(windows)]
578 pipe_path,
579 } => {
580 let is_terminal = std::io::stdout().is_terminal();
581
582 #[cfg(windows)]
583 if serial {
584 use diag_client::hyperv::ComPortAccessInfo;
585 use futures::AsyncBufReadExt;
586
587 let vm_name = match &vm.id {
588 VmId::HyperV(name) => name,
589 _ => anyhow::bail!("--serial is only supported for Hyper-V VMs"),
590 };
591
592 let port_access_info = if let Some(pipe_path) = pipe_path.as_ref() {
593 ComPortAccessInfo::PortPipePath(pipe_path)
594 } else {
595 ComPortAccessInfo::NameAndPortNumber(vm_name, 3)
596 };
597
598 let pipe =
599 diag_client::hyperv::open_serial_port(&driver, port_access_info).await?;
600 let pipe = pal_async::pipe::PolledPipe::new(&driver, pipe)
601 .context("failed to make a polled pipe")?;
602 let pipe = futures::io::BufReader::new(pipe);
603
604 let mut lines = pipe.lines();
605 while let Some(line) = lines.next().await {
606 let line = line?;
607 if let Some(message) = kmsg::SyslogParsedEntry::new(&line) {
608 println!("{}", message.display(is_terminal));
609 } else {
610 println!("{line}");
611 }
612 }
613
614 return Ok(());
615 }
616
617 if verbose {
618 eprintln!("Connecting to the diagnostics server.");
619 }
620
621 let client = new_client(driver.clone(), &vm)?;
622 'connect: loop {
623 if reconnect {
624 client.wait_for_server().await?;
625 }
626 let mut file_stream = client.kmsg(follow).await?;
627 if verbose {
628 eprintln!("Connected.");
629 }
630
631 while let Some(data) = file_stream.next().await {
632 match data {
633 Ok(data) => match kmsg::KmsgParsedEntry::new(&data) {
634 Ok(message) => println!("{}", message.display(is_terminal)),
635 Err(e) => println!("Invalid kmsg entry: {e:?}"),
636 },
637 Err(err) if reconnect && err.kind() == ErrorKind::ConnectionReset => {
638 if verbose {
639 eprintln!(
640 "Connection reset to the diagnostics server. Reconnecting."
641 );
642 }
643 continue 'connect;
644 }
645 Err(err) => Err(err).context("failed to read kmsg")?,
646 }
647 }
648
649 if reconnect {
650 if verbose {
651 eprintln!("Lost connection to the diagnostics server. Reconnecting.");
652 }
653 continue 'connect;
654 }
655
656 break;
657 }
658 }
659 Command::File { follow, file_path } => {
660 let client = new_client(driver.clone(), &vm)?;
661 let stream = client.read_file(follow, file_path).await?;
662 futures::io::copy(stream, &mut AllowStdIo::new(term::raw_stdout()))
663 .await
664 .context("failed to copy trace file")?;
665 }
666 Command::Gdbserver { multi, pid } => {
667 let client = new_client(driver.clone(), &vm)?;
668 let gdbserver = "gdbserver --once";
672 let command = if multi {
673 format!("{gdbserver} --multi -")
674 } else if let Some(pid) = pid {
675 format!("{gdbserver} --attach - {pid}")
676 } else {
677 format!("{gdbserver} --attach - \"$(cat /run/underhill.pid)\"")
678 };
679
680 run(&client, "/bin/sh", &["-c", &command]).await?;
681 }
682 Command::Gdbstub { port } => {
683 let vsock = match vm.id {
684 VmId::HybridVsock(path) => {
685 diag_client::connect_hybrid_vsock(&driver, &path, port).await?
686 }
687 #[cfg(windows)]
688 VmId::HyperV(name) => {
689 let vm_id = diag_client::hyperv::vm_id_from_name(&name)?;
690 let stream =
691 diag_client::hyperv::connect_vsock(&driver, vm_id, port).await?;
692 PolledSocket::new(&driver, socket2::Socket::from(stream))?
693 }
694 };
695
696 let vsock = Arc::new(vsock.into_inner());
697 let thread = std::thread::spawn({
700 let vsock = vsock.clone();
701 move || {
702 let _ = std::io::copy(&mut std::io::stdin(), &mut vsock.as_ref());
703 }
704 });
705
706 std::io::copy(&mut vsock.as_ref(), &mut term::raw_stdout())
707 .context("failed stdout copy")?;
708 thread.join().unwrap();
709 }
710 Command::Crash {
711 crash_type,
712 pid,
713 name,
714 } => {
715 let client = new_client(driver.clone(), &vm)?;
716 let pid = if let Some(name) = name {
717 client.get_pid(&name).await?
718 } else {
719 pid.unwrap()
720 };
721 println!("Crashing PID: {pid}");
722 match crash_type {
723 CrashType::UhPanic => {
724 _ = client.crash(pid).await;
725 }
726 }
727 }
728 Command::PacketCapture {
729 output,
730 seconds,
731 snaplen,
732 } => {
733 let client = new_client(driver.clone(), &vm)?;
734 println!(
735 "Starting network packet capture. Wait for timeout or Ctrl-C to quit anytime."
736 );
737 let (_, num_streams) = client
738 .packet_capture(PacketCaptureOperation::Query, 0, 0)
739 .await?;
740 let file_stem = &output.file_stem().unwrap().to_string_lossy();
741 let extension = &output.extension().unwrap_or(OsStr::new("pcap"));
742 let mut new_output = PathBuf::from(&output);
743 let streams = client
744 .packet_capture(PacketCaptureOperation::Start, num_streams, snaplen)
745 .await?
746 .0
747 .into_iter()
748 .enumerate()
749 .map(|(i, i_stream)| {
750 new_output.set_file_name(format!("{}-{}", &file_stem, i));
751 new_output.set_extension(extension);
752 let mut out = AllowStdIo::new(fs_err::File::create(&new_output)?);
753 Ok(async move { futures::io::copy(i_stream, &mut out).await })
754 })
755 .collect::<Result<Vec<_>, std::io::Error>>()?;
756 capture_packets(client, streams, seconds).await;
757 }
758 Command::CoreDump {
759 verbose,
760 pid,
761 name,
762 dst,
763 } => {
764 ensure_not_terminal(&dst)?;
765 let client = new_client(driver.clone(), &vm)?;
766 let pid = if let Some(name) = name {
767 client.get_pid(&name).await?
768 } else {
769 pid.unwrap()
770 };
771 println!("Dumping PID: {pid}");
772 let file = create_or_stderr(&dst)?;
773 client
774 .core_dump(
775 pid,
776 AllowStdIo::new(file),
777 AllowStdIo::new(std::io::stderr()),
778 verbose,
779 )
780 .await?;
781 }
782 Command::Restart => {
783 let client = new_client(driver.clone(), &vm)?;
784 client.restart().await?;
785 }
786 Command::PerfTrace { output } => {
787 ensure_not_terminal(&output)?;
788
789 let client = new_client(driver.clone(), &vm)?;
790
791 client
793 .update("trace/perf/flush".to_owned(), "true".to_owned())
794 .await
795 .context("failed to flush perf")?;
796
797 let file = create_or_stderr(&output)?;
798 let stream = client
799 .read_file(false, "underhill.perfetto".to_owned())
800 .await
801 .context("failed to read trace file")?;
802
803 futures::io::copy(stream, &mut AllowStdIo::new(file))
804 .await
805 .context("failed to copy trace file")?;
806 }
807 Command::VsockTcpRelay {
808 vsock_port,
809 tcp_port,
810 allow_remote,
811 reconnect,
812 } => {
813 let addr = if allow_remote { "0.0.0.0" } else { "127.0.0.1" };
814 let listener = TcpListener::bind((addr, tcp_port))
815 .with_context(|| format!("binding to port {}", tcp_port))?;
816 println!("TCP listening on {}:{}", addr, tcp_port);
817 'connect: loop {
818 let (tcp_socket, tcp_addr) = listener.accept()?;
819 let tcp_socket = PolledSocket::new(&driver, tcp_socket)?;
820 println!("TCP accept on {:?}", tcp_addr);
821
822 let vsock = match vm.id {
824 VmId::HybridVsock(ref path) => {
825 diag_client::connect_hybrid_vsock(&driver, path, vsock_port).await?
829 }
830 #[cfg(windows)]
831 VmId::HyperV(ref name) => {
832 let vm_id = diag_client::hyperv::vm_id_from_name(name)?;
833 let stream =
834 diag_client::hyperv::connect_vsock(&driver, vm_id, vsock_port)
835 .await?;
836 PolledSocket::new(&driver, socket2::Socket::from(stream))?
837 }
838 };
839 println!("VSOCK connect to port {:?}", vsock_port);
840
841 let (tcp_read, mut tcp_write) = tcp_socket.split();
842 let (vsock_read, mut vsock_write) = vsock.split();
843 let tx = futures::io::copy(tcp_read, &mut vsock_write);
844 let rx = futures::io::copy(vsock_read, &mut tcp_write);
845 let result = futures::future::try_join(tx, rx).await;
846 match result {
847 Ok(_) => {}
848 Err(e) => match e.kind() {
849 ErrorKind::ConnectionReset => {}
850 _ => return Err(anyhow::Error::from(e)),
851 },
852 }
853 println!("Connection closed");
854
855 if reconnect {
856 println!("Reconnecting...");
857 continue 'connect;
858 }
859
860 break;
861 }
862 }
863 Command::Pause => {
864 let client = new_client(driver.clone(), &vm)?;
865 client.pause().await?;
866 }
867 Command::Resume => {
868 let client = new_client(driver.clone(), &vm)?;
869 client.resume().await?;
870 }
871 Command::DumpSavedState { output } => {
872 ensure_not_terminal(&output)?;
873 let client = new_client(driver.clone(), &vm)?;
874 let mut file = create_or_stderr(&output)?;
875 file.write_all(&client.dump_saved_state().await?)?;
876 }
877 }
878 Ok(())
879 })
880}
881
882fn ensure_not_terminal(path: &Option<PathBuf>) -> anyhow::Result<()> {
883 if path.is_none() && std::io::stdout().is_terminal() {
884 anyhow::bail!("cannot write to terminal");
885 }
886 Ok(())
887}
888
889fn create_or_stderr(path: &Option<PathBuf>) -> std::io::Result<fs_err::File> {
890 let file = match path {
891 Some(path) => fs_err::File::create(path)?,
892 None => fs_err::File::from_parts(term::raw_stdout(), "stdout"),
893 };
894 Ok(file)
895}
896
/// Drives the per-stream packet capture copies until the capture ends.
///
/// `streams` are the futures copying each capture stream into its output
/// file. Their individual results are ignored.
///
/// A stop is signaled by Ctrl-C or when the `CancelContext` timeout of
/// `capture_duration` fires. A `PacketCaptureOperation::Stop` request is then
/// sent, and the loop waits for both that request and the remaining streams to
/// finish. If every stream finishes before a stop is signaled, "Lost connection
/// with network." is printed and no stop request is sent.
async fn capture_packets(
    client: DiagClient,
    streams: Vec<impl Future<Output = Result<u64, std::io::Error>>>,
    capture_duration: Duration,
) {
    let mut capture_streams = FuturesUnordered::from_iter(streams);
    // Turn Ctrl-C into a channel message so it can be awaited.
    let (user_input_tx, mut user_input_rx) = mesh::channel();
    ctrlc::set_handler(move || user_input_tx.send(())).expect("Error setting Ctrl-C handler");

    // Completes on Ctrl-C, or via the context's timeout after
    // `capture_duration`. The loop below treats either outcome as "stop".
    let mut ctx = mesh::CancelContext::new().with_timeout(capture_duration);
    let mut stop_signaled = std::pin::pin!(ctx.until_cancelled(user_input_rx.recv()));

    // Sends the Stop request when first polled; a failure is reported, not fatal.
    let mut stop_streams = std::pin::pin!(async {
        if let Err(err) = client
            .packet_capture(PacketCaptureOperation::Stop, 0, 0)
            .await
        {
            eprintln!("Failed stop: {err}");
        }
    });

    // Running:             capturing; waiting for a stop signal.
    // Stopping:            Stop request in flight; streams still draining.
    // StoppingStreamsDone: Stop request in flight; all streams already finished.
    // Stopped:             Stop request done; waiting for streams to finish.
    #[derive(PartialEq)]
    enum State {
        Running,
        Stopping,
        StoppingStreamsDone,
        Stopped,
    }
    let mut state = State::Running;
    loop {
        enum Event {
            Continue,
            StopSignaled,
            StopComplete,
            StreamsDone,
        }
        // Which "stop" future is relevant depends on the current state. The
        // pinned futures live outside the loop, so progress survives the race
        // dropping this wrapper.
        let stop = async {
            match state {
                State::Running => {
                    (&mut stop_signaled).await.ok();
                    Event::StopSignaled
                }
                State::Stopping | State::StoppingStreamsDone => {
                    (&mut stop_streams).await;
                    Event::StopComplete
                }
                State::Stopped => std::future::pending::<Event>().await,
            }
        };
        // Once all streams are done (after the stop started), stop polling them.
        let process_streams = async {
            if state == State::StoppingStreamsDone {
                std::future::pending::<()>().await;
            }
            match capture_streams.next().await {
                Some(_) => Event::Continue,
                None => Event::StreamsDone,
            }
        };
        let event = (stop, process_streams).race();

        // Advance the state machine on whichever future finished first.
        match event.await {
            Event::Continue => continue,
            Event::StopSignaled => {
                println!("Stopping packet capture...");
                state = State::Stopping;
            }
            Event::StopComplete => {
                println!("Waiting for data to be flushed...");
                if state == State::Stopping {
                    state = State::Stopped;
                } else {
                    break;
                }
            }
            Event::StreamsDone if state == State::Stopping => {
                state = State::StoppingStreamsDone;
            }
            Event::StreamsDone => {
                if state != State::Stopped {
                    println!("Lost connection with network.");
                }
                break;
            }
        }
    }
    println!("All done.");
}