flowey_lib_common/download_gh_release.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Download a GitHub release artifact.
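//!
//! A sketch of how a downstream node might request a file (the repo, tag, and
//! file names below are purely illustrative):
//!
//! ```ignore
//! // in the caller's `imports`:
//! //     ctx.import::<flowey_lib_common::download_gh_release::Node>();
//! let artifact = ctx.reqv(|path| flowey_lib_common::download_gh_release::Request {
//!     repo_owner: "some-org".into(),
//!     repo_name: "some-repo".into(),
//!     needs_auth: false,
//!     tag: "v1.2.3".into(),
//!     file_name: "artifact.tar.gz".into(),
//!     path,
//! });
//! ```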

use flowey::node::prelude::*;
use std::collections::BTreeMap;

flowey_request! {
    pub struct Request {
        /// First component of a GitHub repo path
        ///
        /// e.g. the "foo" in "github.com/foo/bar"
        pub repo_owner: String,
        /// Second component of a GitHub repo path
        ///
        /// e.g. the "bar" in "github.com/foo/bar"
        pub repo_name: String,
        /// Whether this repo requires authentication.
        ///
        /// If true, downloads will be routed through the `gh` CLI client, which
        /// will require auth to be set up. See
        /// [`use_gh_cli`](crate::use_gh_cli).
        pub needs_auth: bool,
        /// Tag associated with the release artifact.
        pub tag: String,
        /// Specific filename to download.
        pub file_name: String,
        /// Path to the downloaded artifact.
        pub path: WriteVar<PathBuf>,
    }
}

new_flow_node!(struct Node);

impl FlowNode for Node {
    type Request = Request;

    fn imports(ctx: &mut ImportCtx<'_>) {
        ctx.import::<crate::cache::Node>();
        ctx.import::<crate::use_gh_cli::Node>();
    }

    fn emit(requests: Vec<Self::Request>, ctx: &mut NodeCtx<'_>) -> anyhow::Result<()> {
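        // Consolidate all requests, keyed by (repo_owner, repo_name, tag),
        // mapping each requested file to every output var that wants its path.
        // This ensures a given release file is downloaded at most once, no
        // matter how many nodes ask for it.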
        let mut download_reqs: BTreeMap<
            (String, String, String),
            BTreeMap<String, Vec<WriteVar<PathBuf>>>,
        > = BTreeMap::new();
        let mut use_gh_cli = false;

        for req in requests {
            let Request {
                repo_owner,
                repo_name,
                needs_auth,
                tag,
                file_name,
                path,
            } = req;

            // If any package needs auth, we might as well download every
            // package using the GH CLI.
            use_gh_cli |= needs_auth;

            download_reqs
                .entry((repo_owner, repo_name, tag))
                .or_default()
                .entry(file_name)
                .or_default()
                .push(path);
        }

        if download_reqs.is_empty() {
            return Ok(());
        }

        let gh_cli = use_gh_cli.then(|| ctx.reqv(crate::use_gh_cli::Request::Get));
78
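        // Local runs get a persistent on-disk download cache; CI runs fall
        // back to the backend's cache service instead.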
        match ctx.persistent_dir() {
            Some(dir) => Self::with_local_cache(ctx, dir, download_reqs, gh_cli),
            None => Self::with_ci_cache(ctx, download_reqs, gh_cli),
        }

        Ok(())
    }
}
87
impl Node {
    // Have a single folder which caches downloaded artifacts.
    fn with_local_cache(
        ctx: &mut NodeCtx<'_>,
        persistent_dir: ReadVar<PathBuf>,
        download_reqs: BTreeMap<(String, String, String), BTreeMap<String, Vec<WriteVar<PathBuf>>>>,
        gh_cli: Option<ReadVar<PathBuf>>,
    ) {
        ctx.emit_rust_step("download artifacts from github releases", |ctx| {
            let gh_cli = gh_cli.claim(ctx);
            let persistent_dir = persistent_dir.claim(ctx);
            let download_reqs = download_reqs.claim(ctx);
            move |rt| {
                let persistent_dir = rt.read(persistent_dir);

                // First, check which requests are already present in the local cache.
                let mut remaining_download_reqs: BTreeMap<
                    (String, String, String),
                    BTreeMap<String, Vec<ClaimedWriteVar<PathBuf>>>,
                > = BTreeMap::new();
                for ((repo_owner, repo_name, tag), files) in download_reqs {
                    for (file, vars) in files {
                        let cached_file =
                            persistent_dir.join(format!("{repo_owner}/{repo_name}/{tag}/{file}"));
112
                        if cached_file.exists() {
                            for var in vars {
                                rt.write(var, &cached_file);
                            }
                        } else {
                            let existing = remaining_download_reqs
                                .entry((repo_owner.clone(), repo_name.clone(), tag.clone()))
                                .or_default()
                                .insert(file, vars);
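                            // requests were already merged per-file in emit(),
                            // so this (release, file) entry must be new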
                            assert!(existing.is_none());
                        }
                    }
                }

                if remaining_download_reqs.is_empty() {
                    log::info!("100% local cache hit!");
                    return Ok(());
                }

                download_all_reqs(rt, &remaining_download_reqs, &persistent_dir, gh_cli)?;
133
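                // Every file should now be present in the cache dir; fan the
                // resolved paths out to all requesters.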
                for ((repo_owner, repo_name, tag), files) in remaining_download_reqs {
                    for (file, vars) in files {
                        let file =
                            persistent_dir.join(format!("{repo_owner}/{repo_name}/{tag}/{file}"));
                        assert!(file.exists());
                        for var in vars {
                            rt.write(var, &file);
                        }
                    }
                }

                Ok(())
            }
        });
    }
149
    // Instead of having a cache directory per-repo (and spamming the
    // workflow with a whole bunch of cache task requests), have a single
    // cache directory for each flow's request-set.
    fn with_ci_cache(
        ctx: &mut NodeCtx<'_>,
        download_reqs: BTreeMap<(String, String, String), BTreeMap<String, Vec<WriteVar<PathBuf>>>>,
        gh_cli: Option<ReadVar<PathBuf>>,
    ) {
        let cache_dir = ctx.emit_rust_stepv("create gh-release-download cache dir", |_| {
            |_| Ok(std::env::current_dir()?.absolute()?)
        });
161
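        // Hash the entire request-set, so the cache key changes whenever any
        // (repo, tag, file) requested by this flow changes.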
        let request_set_hash = {
            let hasher = &mut rustc_hash::FxHasher::default();
            for ((repo_owner, repo_name, tag), files) in &download_reqs {
                std::hash::Hash::hash(repo_owner, hasher);
                std::hash::Hash::hash(repo_name, hasher);
                std::hash::Hash::hash(tag, hasher);
                for file in files.keys() {
                    std::hash::Hash::hash(file, hasher);
                }
            }
            let hash = std::hash::Hasher::finish(hasher);
            format!("{:08x}", hash)
        };
175
        let cache_key = ReadVar::from_static(format!("gh-release-download-{request_set_hash}"));
        let hitvar = ctx.reqv(|v| {
            crate::cache::Request {
                label: "gh-release-download".into(),
                dir: cache_dir.clone(),
                key: cache_key,
                restore_keys: None, // the key is an exact hash of the request-set
                hitvar: v,
            }
        });
186
        ctx.emit_rust_step("download artifacts from github releases", |ctx| {
            let cache_dir = cache_dir.claim(ctx);
            let hitvar = hitvar.claim(ctx);
            let gh_cli = gh_cli.claim(ctx);
            let download_reqs = download_reqs.claim(ctx);
            move |rt| {
                let cache_dir = rt.read(cache_dir);
                let hitvar = rt.read(hitvar);
195
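                // On a cache miss, download everything into the cache dir
                // before handing out paths.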
                if !matches!(hitvar, crate::cache::CacheHit::Hit) {
                    download_all_reqs(rt, &download_reqs, &cache_dir, gh_cli)?;
                }

                for ((repo_owner, repo_name, tag), files) in download_reqs {
                    for (file, vars) in files {
                        let file = cache_dir.join(format!("{repo_owner}/{repo_name}/{tag}/{file}"));
                        assert!(file.exists());
                        for var in vars {
                            rt.write(var, &file);
                        }
                    }
                }

                Ok(())
            }
        });
    }
}
215
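/// Download every file in `download_reqs` into `cache_dir`, laid out as
/// `{repo_owner}/{repo_name}/{tag}/{file}`. Downloads go through the `gh` CLI
/// when one is provided (required for repos that need auth), and fall back to
/// plain `curl` against the public release URL otherwise.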
fn download_all_reqs(
    rt: &mut RustRuntimeServices<'_>,
    download_reqs: &BTreeMap<
        (String, String, String),
        BTreeMap<String, Vec<WriteVar<PathBuf, VarClaimed>>>,
    >,
    cache_dir: &Path,
    gh_cli: Option<ReadVar<PathBuf, VarClaimed>>,
) -> anyhow::Result<()> {
    let sh = xshell::Shell::new()?;
226
    let gh_cli = gh_cli.map(|gh_cli| rt.read(gh_cli));
228
    for ((repo_owner, repo_name, tag), files) in download_reqs {
        let repo = format!("{repo_owner}/{repo_name}");

        let out_dir = cache_dir.join(format!("{repo_owner}/{repo_name}/{tag}"));
        fs_err::create_dir_all(&out_dir)?;
        sh.change_dir(&out_dir);

        if let Some(gh_cli) = &gh_cli {
            // FUTURE: while the gh CLI takes care of doing simultaneous downloads
            // in the context of a single (repo, tag), we might want to have flowey
            // spawn multiple processes to saturate the network connection in cases
            // where multiple (repo, tag) pairs are being pulled at the same time.
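            // Pass each requested file as a `--pattern` so a single `gh`
            // invocation fetches all of them.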
            let patterns = files.keys().flat_map(|k| ["--pattern".into(), k.clone()]);
            xshell::cmd!(
                sh,
                "{gh_cli} release download -R {repo} {tag} {patterns...} --skip-existing"
            )
            .run()?;
        } else {
            // FUTURE: parallelize curl invocations across all download_reqs
            for file in files.keys() {
                xshell::cmd!(sh, "curl --fail -L https://github.com/{repo_owner}/{repo_name}/releases/download/{tag}/{file} -o {file}").run()?;
            }
        }
    }

    Ok(())
}