1#![expect(missing_docs)]
8
9mod completions;
10
11use anyhow::Context;
12use clap::ArgGroup;
13use clap::Args;
14use clap::Parser;
15use clap::Subcommand;
16use diag_client::DiagClient;
17use diag_client::PacketCaptureOperation;
18use futures::StreamExt;
19use futures::io::AllowStdIo;
20use futures_concurrency::future::Race;
21use pal_async::DefaultPool;
22use pal_async::driver::Driver;
23use pal_async::socket::PolledSocket;
24use pal_async::task::Spawn;
25use pal_async::timer::PolledTimer;
26use std::convert::Infallible;
27use std::ffi::OsStr;
28use std::io::ErrorKind;
29use std::io::IsTerminal;
30use std::io::Write;
31use std::net::TcpListener;
32use std::path::Path;
33use std::path::PathBuf;
34use std::str::FromStr;
35use std::sync::Arc;
36use std::time::Duration;
37use thiserror::Error;
38use tracing_subscriber::layer::SubscriberExt;
39use tracing_subscriber::util::SubscriberInitExt;
40use unicycle::FuturesUnordered;
41
// Top-level command-line options for `ohcldiag-dev`.
//
// `about`/`long_about` are set explicitly below, so plain `//` comments are
// used here to keep clap's generated help text unchanged.
#[derive(Parser)]
#[clap(about = "(dev) CLI to interact with the Underhill diagnostics server")]
#[clap(long_about = r#"
CLI to interact with the Underhill diagnostics server.

DISCLAIMER:
 `ohcldiag-dev` does not make ANY stability guarantees regarding the layout of
 the CLI, the syntax that is emitted via stdout/stderr, the location of nodes
 in the `inspect` graph, etc...

 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
 !! ANY AUTOMATION THAT USES ohcldiag-dev WILL EVENTUALLY BREAK !!
 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
"#)]
struct Options {
    // The VM to connect to (the positional `VM` argument, see `VmArg`).
    #[clap(flatten)]
    vm: VmArg,

    // The subcommand to run against that VM.
    #[clap(subcommand)]
    command: Command,
}
63
64#[derive(Subcommand)]
65enum Command {
66 #[clap(hide = true)]
67 Complete(clap_dyn_complete::Complete),
68 Completions(completions::Completions),
69 Shell {
71 #[clap(default_value = "/bin/sh")]
73 shell: String,
74 args: Vec<String>,
76 },
77 Run {
79 command: String,
81 args: Vec<String>,
83 },
84 #[clap(visible_alias = "i")]
86 Inspect {
87 #[clap(short)]
89 recursive: bool,
90 #[clap(short, long, requires("recursive"))]
92 limit: Option<usize>,
93 #[clap(short, long)]
95 json: bool,
96 #[clap(short)]
98 poll: bool,
99 #[clap(long, default_value = "1", requires("poll"))]
101 period: f64,
102 #[clap(long, requires("poll"))]
104 count: Option<usize>,
105 path: Option<String>,
107 #[clap(short, long, conflicts_with("recursive"))]
109 update: Option<String>,
110 #[clap(short, default_value = "1", conflicts_with("update"))]
112 timeout: u64,
113 },
114 #[clap(hide = true)]
116 Update {
117 path: String,
119 value: String,
121 },
122 Start {
127 #[clap(short, long)]
129 env: Vec<EnvString>,
130 #[clap(short, long)]
132 unset: Vec<String>,
133 args: Vec<String>,
135 },
136 Kmsg {
138 #[clap(short, long)]
140 follow: bool,
141 #[clap(short, long)]
143 reconnect: bool,
144 #[clap(short, long)]
146 verbose: bool,
147 #[cfg(windows)]
151 #[clap(long, conflicts_with = "reconnect")]
152 serial: bool,
153 #[cfg(windows)]
157 #[clap(long, requires = "serial")]
158 pipe_path: Option<String>,
159 },
160 File {
162 #[clap(short, long)]
164 follow: bool,
165 #[clap(short('p'), long)]
166 file_path: String,
167 },
168 Gdbserver {
178 #[clap(long)]
180 pid: Option<i32>,
181 #[clap(long, conflicts_with("pid"))]
183 multi: bool,
184 },
185 Gdbstub {
192 #[clap(short, long, default_value = "4")]
194 port: u32,
195 },
196 #[clap(group(
200 ArgGroup::new("process")
201 .required(true)
202 .args(&["pid", "name"]),
203 ))]
204 Crash {
205 crash_type: CrashType,
209 #[clap(short, long)]
211 pid: Option<i32>,
212 #[clap(short, long)]
214 name: Option<String>,
215 },
216 #[clap(group(
221 ArgGroup::new("process")
222 .required(true)
223 .args(&["pid", "name"]),
224 ))]
225 CoreDump {
226 #[clap(short, long)]
228 verbose: bool,
229 #[clap(short, long)]
231 pid: Option<i32>,
232 #[clap(short, long)]
234 name: Option<String>,
235 dst: Option<PathBuf>,
238 },
239 Restart,
241 PerfTrace {
244 #[clap(short)]
246 output: Option<PathBuf>,
247 },
248 VsockTcpRelay {
250 vsock_port: u32,
251 tcp_port: u16,
252 #[clap(long)]
253 allow_remote: bool,
254 #[clap(short, long)]
260 reconnect: bool,
261 },
262 Pause,
264 Resume,
266 DumpSavedState {
268 #[clap(short)]
270 output: Option<PathBuf>,
271 },
272 PacketCapture {
274 #[clap(short('w'), default_value = "nic")]
276 output: PathBuf,
277 #[clap(short('G'), long, default_value = "60", value_parser = |arg: &str| -> Result<Duration, std::num::ParseIntError> {Ok(Duration::from_secs(arg.parse()?))})]
279 seconds: Duration,
280 #[clap(short('s'), long, default_value = "65535", value_parser = clap::value_parser!(u16).range(1..))]
282 snaplen: u16,
283 },
284}
285
286#[derive(Debug, Clone, Args)]
287pub struct VmArg {
288 #[doc = r#"VM identifier.
289
290 This can be one of:
291
292 * vsock:PATH - A path to a hybrid vsock Unix socket for a VM, as used by HvLite
293
294 * unix:PATH - A path to a Unix socket for connecting to the control plane
295
296 "#]
297 #[cfg_attr(
298 windows,
299 doc = "* hyperv:NAME - A Hyper-V VM name
300
301 "
302 )]
303 #[cfg_attr(
304 windows,
305 doc = "* NAME_OR_PATH - Either a Hyper-V VM name, or a path as in vsock:PATH>"
306 )]
307 #[cfg_attr(not(windows), doc = "* PATH - A path as in vsock:PATH")]
308 #[clap(name = "VM")]
309 id: VmId,
310}
311
/// Where to find the VM, parsed from the `VM` command-line argument.
#[derive(Debug, Clone)]
enum VmId {
    /// A Hyper-V VM, identified by name.
    #[cfg(windows)]
    HyperV(String),
    /// A hybrid vsock Unix socket path.
    HybridVsock(PathBuf),
}

impl FromStr for VmId {
    type Err = Infallible;

    /// Parses a VM identifier; this never fails.
    ///
    /// `vsock:PATH` always yields a hybrid vsock path. On Windows, `hyperv:NAME`
    /// yields a Hyper-V name, and any other input is treated as a Hyper-V name
    /// unless it refers to an existing Unix socket. Everything else is a path.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some(path) = s.strip_prefix("vsock:") {
            return Ok(Self::HybridVsock(PathBuf::from(path)));
        }

        #[cfg(windows)]
        {
            if let Some(name) = s.strip_prefix("hyperv:") {
                return Ok(Self::HyperV(name.to_owned()));
            }
            // Failure to query the file counts as "not a socket", so the input
            // falls back to being a VM name.
            let is_socket = pal::windows::fs::is_unix_socket(s.as_ref()).unwrap_or(false);
            if !is_socket {
                return Ok(Self::HyperV(s.to_owned()));
            }
        }

        Ok(Self::HybridVsock(PathBuf::from(s)))
    }
}
338
339#[derive(Clone)]
340struct EnvString {
341 name: String,
342 value: String,
343}
344
// The kinds of crash the `crash` subcommand can trigger.
//
// This derives `clap::ValueEnum`, so each variant's `#[clap(name)]` is the
// literal value accepted on the command line. Plain `//` comments are used so
// the generated possible-value help text stays unchanged.
#[derive(Clone, clap::ValueEnum)]
enum CrashType {
    // Accepted as `panic`; `main` handles it by calling `client.crash(pid)`.
    #[clap(name = "panic")]
    UhPanic,
}
350
/// Error returned when parsing an `EnvString` fails, i.e. the argument is not
/// of the form `VAR=value`.
#[derive(Debug, Error)]
#[error("bad environment variable, expected VAR=value")]
struct BadEnvString;
354
355impl FromStr for EnvString {
356 type Err = BadEnvString;
357
358 fn from_str(s: &str) -> Result<Self, Self::Err> {
359 let (name, value) = s.split_once('=').ok_or(BadEnvString)?;
360 Ok(Self {
361 name: name.to_owned(),
362 value: value.to_owned(),
363 })
364 }
365}
366
367async fn run(
369 client: &DiagClient,
370 command: impl AsRef<str>,
371 args: impl IntoIterator<Item = impl AsRef<str>>,
372) -> anyhow::Result<()> {
373 let mut process = client
376 .exec(&command)
377 .args(args)
378 .stdin(true)
379 .stdout(true)
380 .stderr(true)
381 .spawn()
382 .await?;
383
384 let mut stdin = process.stdin.take().unwrap();
385 let mut stdout = process.stdout.take().unwrap();
386 let mut stderr = process.stderr.take().unwrap();
387
388 std::thread::spawn({
389 move || {
390 let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
391 }
392 });
393
394 let stderr_thread =
395 std::thread::spawn(move || std::io::copy(&mut stderr, &mut term::raw_stderr()));
396
397 std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed stdout copy")?;
398
399 stderr_thread
400 .join()
401 .unwrap()
402 .context("failed stdout copy")?;
403
404 let status = process.wait().await?;
405 std::process::exit(status.exit_code());
406}
407
408fn new_client(driver: impl Driver + Spawn + Clone, input: &VmArg) -> anyhow::Result<DiagClient> {
409 let client = match &input.id {
410 #[cfg(windows)]
411 VmId::HyperV(name) => DiagClient::from_hyperv_name(driver, name)?,
412 VmId::HybridVsock(path) => DiagClient::from_hybrid_vsock(driver, path),
413 };
414 Ok(client)
415}
416
417pub fn main() -> anyhow::Result<()> {
418 tracing_subscriber::registry()
419 .with(tracing_subscriber::fmt::layer())
420 .with(tracing_subscriber::EnvFilter::from_default_env())
421 .init();
422
423 term::enable_vt_and_utf8();
424 DefaultPool::run_with(async |driver| {
425 let Options { vm, command } = Options::parse();
426
427 match command {
428 Command::Complete(cmd) => {
429 cmd.println_to_stub_script::<Options>(
430 None,
431 completions::OhcldiagDevCompleteFactory {
432 driver: driver.clone(),
433 },
434 )
435 .await
436 }
437 Command::Completions(cmd) => cmd.run()?,
438 Command::Shell { shell, args } => {
439 let client = new_client(driver.clone(), &vm)?;
440
441 let term = std::env::var("TERM");
443 let term = term.as_deref().unwrap_or("xterm-256color");
444
445 let mut process = client
446 .exec(&shell)
447 .args(&args)
448 .tty(true)
449 .stdin(true)
450 .stdout(true)
451 .env("TERM", term)
452 .spawn()
453 .await?;
454
455 let mut stdin = process.stdin.take().unwrap();
456 let mut stdout = process.stdout.take().unwrap();
457
458 term::set_raw_console(true);
459 std::thread::spawn({
460 move || {
461 let _ = std::io::copy(&mut std::io::stdin(), &mut stdin);
462 }
463 });
464
465 std::io::copy(&mut stdout, &mut term::raw_stdout()).context("failed copy")?;
466
467 let status = process.wait().await?;
468
469 if !status.success() {
470 eprintln!(
471 "shell exited with non-zero exit code: {}",
472 status.exit_code()
473 );
474 }
475 }
476 Command::Run { command, args } => {
477 let client = new_client(driver.clone(), &vm)?;
478 run(&client, command, &args).await?;
479 }
480 Command::Inspect {
481 recursive,
482 limit,
483 json,
484 poll,
485 period,
486 count,
487 timeout,
488
489 path,
490 update,
491 } => {
492 let client = new_client(driver.clone(), &vm)?;
493
494 if let Some(update) = update {
495 let Some(path) = path else {
496 anyhow::bail!("must provide path for update")
497 };
498
499 let value = client.update(path, update).await?;
500 println!("{value}");
501 } else {
502 let timeout = if timeout == 0 {
503 None
504 } else {
505 Some(Duration::from_secs(timeout))
506 };
507 let query = async || {
508 client
509 .inspect(
510 path.as_deref().unwrap_or(""),
511 if recursive { limit } else { Some(0) },
512 timeout,
513 )
514 .await
515 };
516
517 if poll {
518 let mut timer = PolledTimer::new(&driver);
519 let period = Duration::from_secs_f64(period);
520 let mut last_time = pal_async::timer::Instant::now();
521 let mut last = query().await?;
522 let mut count = count;
523
524 loop {
525 match count.as_mut() {
526 Some(count) if *count == 0 => break,
527 Some(count) => *count -= 1,
528 None => {}
529 }
530 timer.sleep_until(last_time + period).await;
531 let now = pal_async::timer::Instant::now();
532 let this = query().await?;
533 let diff = this.since(&last, now - last_time);
534 if json {
535 println!("{}", diff.json());
536 } else {
537 println!("{diff:#}");
538 }
539 last = this;
540 last_time = now;
541 }
542 } else {
543 let node = query().await?;
544 if json {
545 println!("{}", node.json());
546 } else {
547 println!("{node:#}");
548 }
549 }
550 }
551 }
552 Command::Update { path, value } => {
553 eprintln!(
554 "`update` is deprecated - please use `ohcldiag-dev inspect <path> -u <new value>`"
555 );
556 let client = new_client(driver.clone(), &vm)?;
557 let value = client.update(path, value).await?;
558 println!("{value}");
559 }
560 Command::Start { env, unset, args } => {
561 let client = new_client(driver.clone(), &vm)?;
562
563 let env = env
564 .into_iter()
565 .map(|EnvString { name, value }| (name, Some(value)))
566 .chain(unset.into_iter().map(|name| (name, None)));
567
568 client.start(env, args).await?;
569 }
570 Command::Kmsg {
571 follow,
572 reconnect,
573 verbose,
574 #[cfg(windows)]
575 serial,
576 #[cfg(windows)]
577 pipe_path,
578 } => {
579 let is_terminal = std::io::stdout().is_terminal();
580
581 #[cfg(windows)]
582 if serial {
583 use diag_client::hyperv::ComPortAccessInfo;
584 use futures::AsyncBufReadExt;
585
586 let vm_name = match &vm.id {
587 VmId::HyperV(name) => name,
588 _ => anyhow::bail!("--serial is only supported for Hyper-V VMs"),
589 };
590
591 let port_access_info = if let Some(pipe_path) = pipe_path.as_ref() {
592 ComPortAccessInfo::PortPipePath(pipe_path)
593 } else {
594 ComPortAccessInfo::NameAndPortNumber(vm_name, 3)
595 };
596
597 let pipe =
598 diag_client::hyperv::open_serial_port(&driver, port_access_info).await?;
599 let pipe = pal_async::pipe::PolledPipe::new(&driver, pipe)
600 .context("failed to make a polled pipe")?;
601 let pipe = futures::io::BufReader::new(pipe);
602
603 let mut lines = pipe.lines();
604 while let Some(line) = lines.next().await {
605 let line = line?;
606 if let Some(message) = kmsg::SyslogParsedEntry::new(&line) {
607 println!("{}", message.display(is_terminal));
608 } else {
609 println!("{line}");
610 }
611 }
612
613 return Ok(());
614 }
615
616 if verbose {
617 eprintln!("Connecting to the diagnostics server.");
618 }
619
620 let client = new_client(driver.clone(), &vm)?;
621 'connect: loop {
622 if reconnect {
623 client.wait_for_server().await?;
624 }
625 let mut file_stream = client.kmsg(follow).await?;
626 if verbose {
627 eprintln!("Connected.");
628 }
629
630 while let Some(data) = file_stream.next().await {
631 match data {
632 Ok(data) => {
633 let message = kmsg::KmsgParsedEntry::new(&data)?;
634 println!("{}", message.display(is_terminal));
635 }
636 Err(err) if reconnect && err.kind() == ErrorKind::ConnectionReset => {
637 if verbose {
638 eprintln!(
639 "Connection reset to the diagnostics server. Reconnecting."
640 );
641 }
642 continue 'connect;
643 }
644 Err(err) => Err(err).context("failed to read kmsg")?,
645 }
646 }
647
648 if reconnect {
649 if verbose {
650 eprintln!("Lost connection to the diagnostics server. Reconnecting.");
651 }
652 continue 'connect;
653 }
654
655 break;
656 }
657 }
658 Command::File { follow, file_path } => {
659 let client = new_client(driver.clone(), &vm)?;
660 let stream = client.read_file(follow, file_path).await?;
661 futures::io::copy(stream, &mut AllowStdIo::new(term::raw_stdout()))
662 .await
663 .context("failed to copy trace file")?;
664 }
665 Command::Gdbserver { multi, pid } => {
666 let client = new_client(driver.clone(), &vm)?;
667 let gdbserver = "gdbserver --once";
671 let command = if multi {
672 format!("{gdbserver} --multi -")
673 } else if let Some(pid) = pid {
674 format!("{gdbserver} --attach - {pid}")
675 } else {
676 format!("{gdbserver} --attach - \"$(cat /run/underhill.pid)\"")
677 };
678
679 run(&client, "/bin/sh", &["-c", &command]).await?;
680 }
681 Command::Gdbstub { port } => {
682 let vsock = match vm.id {
683 VmId::HybridVsock(path) => {
684 diag_client::connect_hybrid_vsock(&driver, &path, port).await?
685 }
686 #[cfg(windows)]
687 VmId::HyperV(name) => {
688 let vm_id = diag_client::hyperv::vm_id_from_name(&name)?;
689 let stream =
690 diag_client::hyperv::connect_vsock(&driver, vm_id, port).await?;
691 PolledSocket::new(&driver, socket2::Socket::from(stream))?
692 }
693 };
694
695 let vsock = Arc::new(vsock.into_inner());
696 let thread = std::thread::spawn({
699 let vsock = vsock.clone();
700 move || {
701 let _ = std::io::copy(&mut std::io::stdin(), &mut vsock.as_ref());
702 }
703 });
704
705 std::io::copy(&mut vsock.as_ref(), &mut term::raw_stdout())
706 .context("failed stdout copy")?;
707 thread.join().unwrap();
708 }
709 Command::Crash {
710 crash_type,
711 pid,
712 name,
713 } => {
714 let client = new_client(driver.clone(), &vm)?;
715 let pid = if let Some(name) = name {
716 client.get_pid(&name).await?
717 } else {
718 pid.unwrap()
719 };
720 println!("Crashing PID: {pid}");
721 match crash_type {
722 CrashType::UhPanic => {
723 _ = client.crash(pid).await;
724 }
725 }
726 }
727 Command::PacketCapture {
728 output,
729 seconds,
730 snaplen,
731 } => {
732 let client = new_client(driver.clone(), &vm)?;
733 println!(
734 "Starting network packet capture. Wait for timeout or Ctrl-C to quit anytime."
735 );
736 let (_, num_streams) = client
737 .packet_capture(PacketCaptureOperation::Query, 0, 0)
738 .await?;
739 let file_stem = &output.file_stem().unwrap().to_string_lossy();
740 let extension = &output.extension().unwrap_or(OsStr::new("pcap"));
741 let mut new_output = PathBuf::from(&output);
742 let streams = client
743 .packet_capture(PacketCaptureOperation::Start, num_streams, snaplen)
744 .await?
745 .0
746 .into_iter()
747 .enumerate()
748 .map(|(i, i_stream)| {
749 new_output.set_file_name(format!("{}-{}", &file_stem, i));
750 new_output.set_extension(extension);
751 let mut out = AllowStdIo::new(fs_err::File::create(&new_output)?);
752 Ok(async move { futures::io::copy(i_stream, &mut out).await })
753 })
754 .collect::<Result<Vec<_>, std::io::Error>>()?;
755 capture_packets(client, streams, seconds).await;
756 }
757 Command::CoreDump {
758 verbose,
759 pid,
760 name,
761 dst,
762 } => {
763 ensure_not_terminal(&dst)?;
764 let client = new_client(driver.clone(), &vm)?;
765 let pid = if let Some(name) = name {
766 client.get_pid(&name).await?
767 } else {
768 pid.unwrap()
769 };
770 println!("Dumping PID: {pid}");
771 let file = create_or_stderr(&dst)?;
772 client
773 .core_dump(
774 pid,
775 AllowStdIo::new(file),
776 AllowStdIo::new(std::io::stderr()),
777 verbose,
778 )
779 .await?;
780 }
781 Command::Restart => {
782 let client = new_client(driver.clone(), &vm)?;
783 client.restart().await?;
784 }
785 Command::PerfTrace { output } => {
786 ensure_not_terminal(&output)?;
787
788 let client = new_client(driver.clone(), &vm)?;
789
790 client
792 .update("trace/perf/flush".to_owned(), "true".to_owned())
793 .await
794 .context("failed to flush perf")?;
795
796 let file = create_or_stderr(&output)?;
797 let stream = client
798 .read_file(false, "underhill.perfetto".to_owned())
799 .await
800 .context("failed to read trace file")?;
801
802 futures::io::copy(stream, &mut AllowStdIo::new(file))
803 .await
804 .context("failed to copy trace file")?;
805 }
806 Command::VsockTcpRelay {
807 vsock_port,
808 tcp_port,
809 allow_remote,
810 reconnect,
811 } => {
812 let addr = if allow_remote { "0.0.0.0" } else { "127.0.0.1" };
813 let listener = TcpListener::bind((addr, tcp_port))
814 .with_context(|| format!("binding to port {}", tcp_port))?;
815 println!("TCP listening on {}:{}", addr, tcp_port);
816 'connect: loop {
817 let (tcp_socket, tcp_addr) = listener.accept()?;
818 let tcp_socket = PolledSocket::new(&driver, tcp_socket)?;
819 println!("TCP accept on {:?}", tcp_addr);
820
821 let vsock = match vm.id {
823 VmId::HybridVsock(ref path) => {
824 diag_client::connect_hybrid_vsock(&driver, path, vsock_port).await?
828 }
829 #[cfg(windows)]
830 VmId::HyperV(ref name) => {
831 let vm_id = diag_client::hyperv::vm_id_from_name(name)?;
832 let stream =
833 diag_client::hyperv::connect_vsock(&driver, vm_id, vsock_port)
834 .await?;
835 PolledSocket::new(&driver, socket2::Socket::from(stream))?
836 }
837 };
838 println!("VSOCK connect to port {:?}", vsock_port);
839
840 let (tcp_read, mut tcp_write) = tcp_socket.split();
841 let (vsock_read, mut vsock_write) = vsock.split();
842 let tx = futures::io::copy(tcp_read, &mut vsock_write);
843 let rx = futures::io::copy(vsock_read, &mut tcp_write);
844 let result = futures::future::try_join(tx, rx).await;
845 match result {
846 Ok(_) => {}
847 Err(e) => match e.kind() {
848 ErrorKind::ConnectionReset => {}
849 _ => return Err(anyhow::Error::from(e)),
850 },
851 }
852 println!("Connection closed");
853
854 if reconnect {
855 println!("Reconnecting...");
856 continue 'connect;
857 }
858
859 break;
860 }
861 }
862 Command::Pause => {
863 let client = new_client(driver.clone(), &vm)?;
864 client.pause().await?;
865 }
866 Command::Resume => {
867 let client = new_client(driver.clone(), &vm)?;
868 client.resume().await?;
869 }
870 Command::DumpSavedState { output } => {
871 ensure_not_terminal(&output)?;
872 let client = new_client(driver.clone(), &vm)?;
873 let mut file = create_or_stderr(&output)?;
874 file.write_all(&client.dump_saved_state().await?)?;
875 }
876 }
877 Ok(())
878 })
879}
880
881fn ensure_not_terminal(path: &Option<PathBuf>) -> anyhow::Result<()> {
882 if path.is_none() && std::io::stdout().is_terminal() {
883 anyhow::bail!("cannot write to terminal");
884 }
885 Ok(())
886}
887
888fn create_or_stderr(path: &Option<PathBuf>) -> std::io::Result<fs_err::File> {
889 let file = match path {
890 Some(path) => fs_err::File::create(path)?,
891 None => fs_err::File::from_parts(term::raw_stdout(), "stdout"),
892 };
893 Ok(file)
894}
895
/// Drives the packet-capture copy futures until the capture finishes.
///
/// The capture is stopped when Ctrl-C is pressed or `capture_duration`
/// elapses, whichever comes first. Stopping sends
/// `PacketCaptureOperation::Stop` through `client`, then keeps draining
/// `streams` so buffered data is still written.
///
/// States:
/// - `Running`: copying data while waiting for a stop signal.
/// - `Stopping`: stop request in flight, streams still draining.
/// - `StoppingStreamsDone`: stop request in flight, all streams finished.
/// - `Stopped`: stop request done; waiting for the streams to finish.
///
/// If all streams end while still `Running`, this is reported as a lost
/// connection.
async fn capture_packets(
    client: DiagClient,
    streams: Vec<impl std::future::Future<Output = Result<u64, std::io::Error>>>,
    capture_duration: Duration,
) {
    let mut capture_streams = FuturesUnordered::from_iter(streams);
    let (user_input_tx, mut user_input_rx) = mesh::channel();
    ctrlc::set_handler(move || user_input_tx.send(())).expect("Error setting Ctrl-C handler");

    // Resolves on the first Ctrl-C, or when the timeout cancels the wait.
    let mut ctx = mesh::CancelContext::new().with_timeout(capture_duration);
    let mut stop_signaled = std::pin::pin!(ctx.until_cancelled(user_input_rx.recv()));

    // Sends the stop request. It is only polled once a stop has been signaled.
    let mut stop_streams = std::pin::pin!(async {
        if let Err(err) = client
            .packet_capture(PacketCaptureOperation::Stop, 0, 0)
            .await
        {
            eprintln!("Failed stop: {err}");
        }
    });

    #[derive(PartialEq)]
    enum State {
        Running,
        Stopping,
        StoppingStreamsDone,
        Stopped,
    }
    let mut state = State::Running;
    loop {
        enum Event {
            Continue,
            StopSignaled,
            StopComplete,
            StreamsDone,
        }
        // Each pinned future is awaited only in states that occur before it
        // completes, so neither is polled again after finishing.
        let stop = async {
            match state {
                State::Running => {
                    // Timeout and Ctrl-C are treated the same; `.ok()`
                    // discards which one happened.
                    (&mut stop_signaled).await.ok();
                    Event::StopSignaled
                }
                State::Stopping | State::StoppingStreamsDone => {
                    (&mut stop_streams).await;
                    Event::StopComplete
                }
                State::Stopped => std::future::pending::<Event>().await,
            }
        };
        let process_streams = async {
            // All streams already finished; don't poll the empty set again.
            if state == State::StoppingStreamsDone {
                std::future::pending::<()>().await;
            }
            // NOTE(review): a stream's `Result` (bytes copied or I/O error) is
            // discarded here; copy errors are not reported.
            match capture_streams.next().await {
                Some(_) => Event::Continue,
                None => Event::StreamsDone,
            }
        };
        // Whichever future completes first wins; the loser is dropped and
        // rebuilt on the next iteration.
        let event = (stop, process_streams).race();

        match event.await {
            Event::Continue => continue,
            Event::StopSignaled => {
                println!("Stopping packet capture...");
                state = State::Stopping;
            }
            Event::StopComplete => {
                println!("Waiting for data to be flushed...");
                if state == State::Stopping {
                    state = State::Stopped;
                } else {
                    // StoppingStreamsDone: the streams are already done.
                    break;
                }
            }
            Event::StreamsDone if state == State::Stopping => {
                state = State::StoppingStreamsDone;
            }
            Event::StreamsDone => {
                if state != State::Stopped {
                    println!("Lost connection with network.");
                }
                break;
            }
        }
    }
    println!("All done.");
}