ohcldiag_dev/
completions.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::VmArg;
use crate::VmId;
use crate::new_client;
use clap::Parser;
use clap_dyn_complete::CustomCompleter;
use clap_dyn_complete::CustomCompleterFactory;
use pal_async::DefaultDriver;
use std::time::Duration;

// NOTE: the `///` comments in this block are consumed by `#[derive(Parser)]`
// and surface as CLI help text, so rewording them changes user-visible output.
/// Enable shell-completions
#[derive(Parser)]
pub struct Completions {
    /// Shell to generate completions for
    // Forwarded as-is to `clap_dyn_complete::emit_completion_stub` in `run`.
    shell: clap_dyn_complete::Shell,
}

impl Completions {
    /// Writes the completion stub for the selected shell to standard output.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by
    /// `clap_dyn_complete::emit_completion_stub`.
    pub fn run(self) -> anyhow::Result<()> {
        let Self { shell } = self;
        let mut out = std::io::stdout();

        // NOTE(review): the two string arguments are presumably the binary name
        // and the arguments the stub calls back with; confirm against
        // `clap_dyn_complete`.
        clap_dyn_complete::emit_completion_stub(shell, "ohcldiag-dev", ". complete", &mut out)?;
        Ok(())
    }
}

/// Factory that builds an [`OhcldiagDevComplete`] for a completion request.
pub(crate) struct OhcldiagDevCompleteFactory {
    /// Driver cloned and handed to `new_client` when a completer is built.
    pub driver: DefaultDriver,
}

impl CustomCompleterFactory for OhcldiagDevCompleteFactory {
    type CustomCompleter = OhcldiagDevComplete;

    /// Builds a completer, attaching a diagnostics client when possible.
    ///
    /// The client is `None` when the `VM` argument is absent, when looking it
    /// up fails, or when `new_client` returns an error. None of these abort
    /// the completion.
    async fn build(&self, ctx: &clap_dyn_complete::RootCtx<'_>) -> Self::CustomCompleter {
        let client = ctx
            .matches
            .try_get_one::<VmId>("VM")
            // A lookup error is treated the same as "argument not provided".
            .ok()
            .flatten()
            .and_then(|id| {
                let target = VmArg { id: id.clone() };
                new_client(self.driver.clone(), &target).ok()
            });

        OhcldiagDevComplete { client }
    }
}

/// Completer that supplies dynamic completion values for `ohcldiag-dev`.
pub(crate) struct OhcldiagDevComplete {
    /// Diagnostics client used for `inspect` path completion; `None` when no
    /// client could be created.
    client: Option<diag_client::DiagClient>,
}

impl CustomCompleter for OhcldiagDevComplete {
    /// Produces completion candidates for the argument `arg_id` of the
    /// subcommand identified by `subcommand_path`.
    ///
    /// * The top-level `VM` argument completes to the output of `list_vms`
    ///   (or nothing if that fails).
    /// * The `path` argument of `inspect` completes to the children of the
    ///   directory part of what has been typed so far, filtered by the
    ///   partial last segment.
    /// * Every other argument yields no candidates.
    ///
    /// When there is no client, or the inspect request fails, the single
    /// placeholder candidate `failed/to/connect` is returned.
    async fn complete(
        &self,
        ctx: &clap_dyn_complete::RootCtx<'_>,
        subcommand_path: &[&str],
        arg_id: &str,
    ) -> Vec<String> {
        match (subcommand_path, arg_id) {
            (["ohcldiag-dev"], "VM") => list_vms().unwrap_or_default(),
            (["ohcldiag-dev", "inspect"], "path") => {
                let connect_failed = || vec![String::from("failed/to/connect")];

                // Split the typed text into the directory to query and the
                // partial name to filter by; with no '/', query the root.
                let (prefix, partial) = match ctx.to_complete.rsplit_once('/') {
                    Some(parts) => parts,
                    None => ("", ctx.to_complete),
                };

                let client = match &self.client {
                    Some(client) => client,
                    None => return connect_failed(),
                };

                // Depth 1 with a one-second timeout.
                let listing = client
                    .inspect(prefix, Some(1), Some(Duration::from_secs(1)))
                    .await;
                let node = match listing {
                    Ok(node) => node,
                    Err(_) => return connect_failed(),
                };

                // Anything other than a directory has no children to offer.
                let inspect::Node::Dir(children) = node else {
                    return Vec::new();
                };

                children
                    .into_iter()
                    .filter(|child| child.name.starts_with(partial))
                    .map(|child| {
                        if prefix.is_empty() {
                            // NOTE(review): root-level entries always get a
                            // trailing '/', whatever their node kind; confirm
                            // this is intended.
                            format!("{}/", child.name)
                        } else {
                            let suffix = if matches!(child.node, inspect::Node::Dir(..)) {
                                "/"
                            } else {
                                ""
                            };
                            format!("{prefix}/{}{suffix}", child.name)
                        }
                    })
                    .collect()
            }
            _ => Vec::new(),
        }
    }
}

/// Lists VM names by running `hvc.exe list -q` and returning one entry per
/// line of its trimmed standard output.
///
/// If the command exits unsuccessfully, a single `"."` entry is returned.
///
/// # Errors
///
/// Fails if `hvc.exe` cannot be invoked or its standard output is not valid
/// UTF-8.
#[cfg(windows)]
fn list_vms() -> anyhow::Result<Vec<String>> {
    use anyhow::Context;

    let hvc = std::process::Command::new("hvc.exe")
        .args(["list", "-q"])
        .output()
        .context("failed to invoke hvc.exe")?;

    if !hvc.status.success() {
        return Ok(vec![String::from(".")]);
    }

    let text = std::str::from_utf8(&hvc.stdout).context("stdout isn't utf8")?;
    Ok(text.trim().lines().map(str::to_owned).collect())
}

/// Non-Windows variant of `list_vms`: always returns a single `"."` entry.
// NOTE(review): "." presumably stands for a default or local target; confirm
// with how `VmId` is parsed.
#[cfg(not(windows))]
fn list_vms() -> anyhow::Result<Vec<String>> {
    let placeholder = String::from(".");
    Ok(vec![placeholder])
}