xtask/tasks/fmt/rustfmt.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
3
4use crate::Xtask;
5use crate::fs_helpers::git_diffed;
6use crate::shell::XtaskShell;
7use clap::Parser;
8use std::path::PathBuf;
9
10#[derive(Parser)]
11#[clap(about = "Check that all repo files are formatted using rustfmt")]
12pub struct Rustfmt {
13    /// Run `rustfmt` on all `.rs` files in the repo
14    #[clap(long)]
15    pub fix: bool,
16
17    /// A list of files to check
18    ///
19    /// If no files were provided, all files in-tree will be checked
20    pub files: Vec<PathBuf>,
21
22    /// Only run checks on files that are currently diffed
23    #[clap(long, conflicts_with = "files")]
24    pub only_diffed: bool,
25}
26
27impl Rustfmt {
28    pub fn new(fix: bool, only_diffed: bool) -> Self {
29        Self {
30            fix,
31            files: Vec::new(),
32            only_diffed,
33        }
34    }
35}
36
/// The set of files a single `Rustfmt` run operates on.
#[derive(Debug)]
enum Files {
    /// The whole workspace, handed to `cargo fmt`.
    All,
    /// Only the `.rs` files reported by `git_diffed`, handed to `rustfmt`.
    OnlyDiffed,
    /// An explicit list of paths from the command line, handed to `rustfmt`.
    Specific(Vec<PathBuf>),
}
43
44impl Xtask for Rustfmt {
45    fn run(self, ctx: crate::XtaskCtx) -> anyhow::Result<()> {
46        let files = if self.only_diffed {
47            Files::OnlyDiffed
48        } else if self.files.is_empty() {
49            Files::All
50        } else {
51            Files::Specific(self.files)
52        };
53
54        log::trace!("running rustfmt on {:?}", files);
55
56        let sh = XtaskShell::new()?;
57        let rust_toolchain = sh.var("RUST_TOOLCHAIN").map(|s| format!("+{s}")).ok();
58        let fmt_check = (!self.fix).then_some("--check");
59
60        match files {
61            Files::All => {
62                sh.cmd("cargo")
63                    .args(rust_toolchain)
64                    .args(["fmt", "--"])
65                    .args(fmt_check)
66                    .quiet()
67                    .run()?;
68            }
69            Files::OnlyDiffed => {
70                let mut files = git_diffed(ctx.in_git_hook)?;
71                files.retain(|f| f.extension().unwrap_or_default() == "rs");
72
73                if !files.is_empty() {
74                    let res = sh
75                        .cmd("rustfmt")
76                        .args(rust_toolchain)
77                        .args(fmt_check)
78                        .args(&files)
79                        .quiet()
80                        .run();
81
82                    if res.is_err() {
83                        anyhow::bail!("found formatting issues in diffed files");
84                    }
85                }
86            }
87            Files::Specific(files) => {
88                assert!(!files.is_empty());
89
90                sh.cmd("rustfmt")
91                    .args(rust_toolchain)
92                    .args(fmt_check)
93                    .args(&files)
94                    .quiet()
95                    .run()?;
96            }
97        }
98
99        log::trace!("done rustfmt");
100        Ok(())
101    }
102}