ohcldiag_dev/
completions.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use crate::VmArg;
5use crate::VmId;
6use crate::new_client;
7use clap::Parser;
8use clap_dyn_complete::CustomCompleter;
9use clap_dyn_complete::CustomCompleterFactory;
10use pal_async::DefaultDriver;
11use std::time::Duration;
12
13/// Enable shell-completions
14#[derive(Parser)]
15pub struct Completions {
16    /// Shell to generate completions for
17    shell: clap_dyn_complete::Shell,
18}
19
20impl Completions {
21    pub fn run(self) -> anyhow::Result<()> {
22        clap_dyn_complete::emit_completion_stub(
23            self.shell,
24            "ohcldiag-dev",
25            ". complete",
26            &mut std::io::stdout(),
27        )?;
28        Ok(())
29    }
30}
31
32pub(crate) struct OhcldiagDevCompleteFactory {
33    pub driver: DefaultDriver,
34}
35
36impl CustomCompleterFactory for OhcldiagDevCompleteFactory {
37    type CustomCompleter = OhcldiagDevComplete;
38    async fn build(&self, ctx: &clap_dyn_complete::RootCtx<'_>) -> Self::CustomCompleter {
39        let vm = ctx.matches.try_get_one::<VmId>("VM").unwrap_or_default();
40        let client = if let Some(vm) = vm {
41            new_client(self.driver.clone(), &VmArg { id: vm.clone() }).ok()
42        } else {
43            None
44        };
45
46        OhcldiagDevComplete { client }
47    }
48}
49
50pub(crate) struct OhcldiagDevComplete {
51    client: Option<diag_client::DiagClient>,
52}
53
54impl CustomCompleter for OhcldiagDevComplete {
55    async fn complete(
56        &self,
57        ctx: &clap_dyn_complete::RootCtx<'_>,
58        subcommand_path: &[&str],
59        arg_id: &str,
60    ) -> Vec<String> {
61        match (subcommand_path, arg_id) {
62            (["ohcldiag-dev"], "VM") => list_vms().unwrap_or_default(),
63            (["ohcldiag-dev", "inspect"], "path") => {
64                let on_error = vec!["failed/to/connect".into()];
65
66                let (parent_path, to_complete) = (ctx.to_complete)
67                    .rsplit_once('/')
68                    .unwrap_or(("", ctx.to_complete));
69
70                let Some(client) = self.client.as_ref() else {
71                    return on_error;
72                };
73
74                let Ok(node) = client
75                    .inspect(parent_path, Some(1), Some(Duration::from_secs(1)))
76                    .await
77                else {
78                    return on_error;
79                };
80
81                let mut completions = Vec::new();
82
83                if let inspect::Node::Dir(dir) = node {
84                    for entry in dir {
85                        if entry.name.starts_with(to_complete) {
86                            if parent_path.is_empty() {
87                                completions.push(format!("{}/", entry.name))
88                            } else {
89                                completions.push(format!(
90                                    "{}/{}{}",
91                                    parent_path,
92                                    entry.name,
93                                    if matches!(entry.node, inspect::Node::Dir(..)) {
94                                        "/"
95                                    } else {
96                                        ""
97                                    }
98                                ))
99                            }
100                        }
101                    }
102                }
103                completions
104            }
105            _ => Vec::new(),
106        }
107    }
108}
109
/// Lists candidate values for the `VM` argument by running `hvc.exe list -q`.
///
/// On a successful exit status, returns one entry per line of the command's
/// trimmed standard output. On a failing exit status, returns the single
/// entry `"."`.
///
/// # Errors
///
/// Fails if `hvc.exe` cannot be invoked or if its standard output is not valid
/// UTF-8.
#[cfg(windows)]
fn list_vms() -> anyhow::Result<Vec<String>> {
    use anyhow::Context;

    let hvc = std::process::Command::new("hvc.exe")
        .args(["list", "-q"])
        .output()
        .context("failed to invoke hvc.exe")?;

    if !hvc.status.success() {
        return Ok(vec![".".into()]);
    }

    let text = std::str::from_utf8(&hvc.stdout).context("stdout isn't utf8")?;
    let names = text.trim().lines().map(str::to_owned).collect();
    Ok(names)
}
126
127#[cfg(not(windows))]
128fn list_vms() -> anyhow::Result<Vec<String>> {
129    Ok(vec![".".into()])
130}