uevent/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Implements a listener to wait for Linux kobject uevents.
5//!
6//! These are used to wait for device hotplug events, disk capacity changes, and
7//! other asynchronous hardware state changes in Linux.
8
9#![cfg(target_os = "linux")]
10
11mod bind_kobject_uevent;
12
13use anyhow::Context;
14use fs_err::PathExt;
15use futures::AsyncReadExt;
16use futures::FutureExt;
17use futures::StreamExt;
18use futures_concurrency::future::Race;
19use mesh::rpc::Rpc;
20use mesh::rpc::RpcSend;
21use pal_async::driver::SpawnDriver;
22use pal_async::socket::PolledSocket;
23use pal_async::task::Task;
24use socket2::Socket;
25use std::future::Future;
26use std::io;
27use std::path::Path;
28use std::path::PathBuf;
29use thiserror::Error;
30
31/// A listener for Linux udev events.
32pub struct UeventListener {
33    _task: Task<()>,
34    send: mesh::Sender<TaskRequest>,
35}
36
37/// An error from [`UeventListener::new`].
38#[derive(Debug, Error)]
39#[error("failed to create uevent socket")]
40pub struct NewUeventListenerError(#[source] io::Error);
41
42impl UeventListener {
43    /// Opens a new netlink socket and starts listening on it.
44    pub fn new(driver: &impl SpawnDriver) -> Result<Self, NewUeventListenerError> {
45        let socket =
46            bind_kobject_uevent::bind_kobject_uevent_socket().map_err(NewUeventListenerError)?;
47        let socket = PolledSocket::new(driver, socket).map_err(NewUeventListenerError)?;
48        let (send, recv) = mesh::mpsc_channel();
49        let thing = ListenerTask {
50            socket,
51            callbacks: Vec::new(),
52            recv,
53            next_id: 0,
54        };
55        let task = driver.spawn("uevent", async move { thing.run().await });
56        Ok(Self { _task: task, send })
57    }
58
59    /// Adds a callback function that receives every event.
60    pub async fn add_custom_callback(
61        &self,
62        callback: impl 'static + Send + FnMut(Notification<'_>),
63    ) -> CallbackHandle {
64        self.send
65            .call(TaskRequest::NewFilter, Box::new(callback))
66            .await
67            .unwrap()
68    }
69
70    /// Adds a callback that runs when the block device with the given
71    /// major/minor numbers has been resized or a rescan event was triggered
72    /// where the caller is required to rescan for the condition
73    pub async fn add_block_resize_callback(
74        &self,
75        major: u32,
76        minor: u32,
77        mut notify: impl 'static + Send + FnMut(),
78    ) -> CallbackHandle {
79        self.add_custom_callback(move |event| match event {
80            Notification::Event(kvs) => {
81                if (kvs.get("RESCAN") == Some("true"))
82                    || (kvs.get("RESIZE") == Some("1")
83                        && kvs.get("SUBSYSTEM") == Some("block")
84                        && kvs.get("ACTION") == Some("change")
85                        && kvs.get("MAJOR").is_some_and(|x| x.parse() == Ok(major))
86                        && kvs.get("MINOR").is_some_and(|x| x.parse() == Ok(minor)))
87                {
88                    notify();
89                }
90            }
91        })
92        .await
93    }
94
95    /// Waits for a child of the provided devpath (typically something under
96    /// /sys) to exist.
97    ///
98    /// If it does not immediately exist, this will poll the path for existence
99    /// each time a new uevent arrives.
100    ///
101    /// `f` will be called with the file name of the child, and a boolean: true
102    /// if the child was found by uevent, false if it was found by sysfs. It
103    /// should return `Some(_)` if the child is the correct one.
104    ///
105    /// This is inefficient if there are lots of waiters and lots of incoming
106    /// uevents, but this is not an expected use case.
107    pub async fn wait_for_matching_child<T, F, Fut>(&self, path: &Path, f: F) -> io::Result<T>
108    where
109        F: Fn(PathBuf, bool) -> Fut,
110        Fut: Future<Output = Option<T>>,
111    {
112        let scan_for_matching_child = async || {
113            for entry in path.fs_err_read_dir()? {
114                let entry = entry?;
115                if let Some(r) = f(entry.path(), false).await {
116                    return Ok::<Option<T>, io::Error>(Some(r));
117                }
118            }
119            Ok(None)
120        };
121
122        // Fast path.
123        if path.exists() {
124            if let Some(child) = scan_for_matching_child().await? {
125                return Ok(child);
126            }
127        }
128
129        // Get the absolute devpath to make child lookups fast.
130        self.wait_for_devpath(path).await?;
131        let path = path.fs_err_canonicalize()?;
132        let path_clone = path.clone();
133        let parent_devpath = path
134            .strip_prefix("/sys")
135            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid devpath"))?
136            .to_path_buf();
137
138        let (send, mut recv) = mesh::channel();
139        let _handle = self
140            .add_custom_callback({
141                move |notification| {
142                    match notification {
143                        Notification::Event(uevent) => {
144                            // uevent can return a rescan event in some cases where it is not sure
145                            // about the end state. In those cases, the end state needs to be checked
146                            // again for any change.
147                            if uevent.get("RESCAN") == Some("true") {
148                                if let Ok(read_dir) = path_clone.fs_err_read_dir() {
149                                    for entry in read_dir {
150                                        if let Ok(sub_entry) = entry {
151                                            send.send((sub_entry.path(), false));
152                                        }
153                                    }
154                                }
155                            } else if uevent.get("ACTION") == Some("add") {
156                                let Some(devpath) = uevent.get("DEVPATH") else {
157                                    return;
158                                };
159                                // Remove the leading /.
160                                let devpath = Path::new(&devpath[1..]);
161                                if devpath.parent() == Some(&parent_devpath) {
162                                    send.send((Path::new("/sys").join(devpath), true));
163                                }
164                            }
165                        }
166                    }
167                }
168            })
169            .await;
170
171        if let Some(child) = scan_for_matching_child().await? {
172            return Ok(child);
173        }
174
175        tracing::debug!(path = %path.display(), "waiting for child nodes");
176        while let Some((path, is_uevent)) = recv.next().await {
177            if let Some(r) = f(path, is_uevent).await {
178                return Ok(r);
179            }
180        }
181
182        Err(io::Error::new(
183            io::ErrorKind::InvalidInput,
184            "Did not find a matching path",
185        ))
186    }
187
188    /// Waits for the provided devpath (typically something under /sys) to
189    /// exist.
190    ///
191    /// If it does not immediately exist, this will poll the path for existence
192    /// each time a new uevent arrives.
193    ///
194    /// This is inefficient if there are lots of waiters and lots of incoming
195    /// uevents, but this is not an expected use case.
196    pub async fn wait_for_devpath(&self, path: &Path) -> io::Result<()> {
197        // Fast path.
198        if path.exists() {
199            return Ok(());
200        }
201
202        // Register the listener.
203        let (send, recv) = mesh::oneshot();
204        let _handle = self
205            .add_custom_callback({
206                let path = path.to_owned();
207                let mut send = Some(send);
208                move |event| {
209                    if send.is_none() {
210                        return;
211                    }
212                    match event {
213                        Notification::Event(uevent) => {
214                            if (uevent.get("ACTION") == Some("add"))
215                                || (uevent.get("RESCAN") == Some("true"))
216                            {
217                                let r = path.fs_err_symlink_metadata();
218                                if !matches!(&r, Err(err) if err.kind() == io::ErrorKind::NotFound)
219                                {
220                                    send.take().unwrap().send(r);
221                                }
222                            }
223                        }
224                    }
225                }
226            })
227            .await;
228
229        // Check for the path again in case it arrived before the listener was
230        // registered.
231        let r = match path.fs_err_symlink_metadata() {
232            Ok(m) => Ok(m),
233            Err(err) if err.kind() == io::ErrorKind::NotFound => {
234                tracing::debug!(path = %path.display(), "waiting for devpath");
235                recv.await.unwrap()
236            }
237            Err(err) => Err(err),
238        };
239        r?;
240        Ok(())
241    }
242}
243
244/// A notification for a [`UeventListener`] callback to process.
245pub enum Notification<'a> {
246    /// An event arrived.
247    Event(&'a Uevent<'a>),
248}
249
/// A device event.
///
/// Borrows its header and `KEY=value` properties from the buffer the event was
/// parsed from.
#[derive(Debug)]
pub struct Uevent<'a> {
    // First NUL-separated field of the message.
    header: &'a str,
    // Key/value pairs, kept sorted by key so lookups can binary search.
    properties: Vec<(&'a str, &'a str)>,
}
255
256impl Uevent<'_> {
257    /// Gets the header.
258    pub fn header(&self) -> &str {
259        self.header
260    }
261
262    /// Gets a property by key.
263    pub fn get(&self, key: &str) -> Option<&str> {
264        let i = self
265            .properties
266            .binary_search_by_key(&key, |(k, _)| k)
267            .ok()?;
268        Some(self.properties[i].1)
269    }
270}
271
272/// A callback handle from [`UeventListener`].
273///
274/// When dropped, it will unregister the callback. This is asynchronous, so the
275/// callback may be called several more times after this.
276#[must_use]
277#[derive(Debug)]
278pub struct CallbackHandle {
279    id: u64,
280    send: mesh::Sender<TaskRequest>,
281}
282
283impl Drop for CallbackHandle {
284    fn drop(&mut self) {
285        self.send.send(TaskRequest::RemoveFilter(self.id))
286    }
287}
288
289enum TaskRequest {
290    NewFilter(Rpc<Box<dyn Send + FnMut(Notification<'_>)>, CallbackHandle>),
291    RemoveFilter(u64),
292}
293
294struct ListenerTask {
295    socket: PolledSocket<Socket>,
296    callbacks: Vec<Filter>,
297    recv: mesh::Receiver<TaskRequest>,
298    next_id: u64,
299}
300
301struct Filter {
302    id: u64,
303    func: Box<dyn Send + FnMut(Notification<'_>)>,
304}
305
306impl ListenerTask {
307    async fn run(self) {
308        if let Err(err) = self.run_inner().await {
309            tracing::error!(
310                error = err.as_ref() as &dyn std::error::Error,
311                "uevent failure"
312            );
313        }
314    }
315
316    async fn run_inner(mut self) -> anyhow::Result<()> {
317        let mut buf = [0; 4096];
318
319        enum Event {
320            Request(Option<TaskRequest>),
321            Read(io::Result<usize>),
322        }
323
324        loop {
325            let event = (
326                self.socket.read(&mut buf).map(Event::Read),
327                self.recv.next().map(Event::Request),
328            )
329                .race()
330                .await;
331
332            match event {
333                Event::Request(Some(request)) => match request {
334                    TaskRequest::NewFilter(rpc) => rpc.handle_sync(|filter_fn| {
335                        let id = self.next_id;
336                        self.next_id += 1;
337                        self.callbacks.push(Filter {
338                            func: filter_fn,
339                            id,
340                        });
341                        CallbackHandle {
342                            id,
343                            send: self.recv.sender(),
344                        }
345                    }),
346                    TaskRequest::RemoveFilter(id) => {
347                        self.callbacks
348                            .swap_remove(self.callbacks.iter().position(|f| f.id == id).unwrap());
349                    }
350                },
351                Event::Request(None) => break Ok(()),
352                Event::Read(r) => {
353                    match r {
354                        Ok(n) => {
355                            let buf = std::str::from_utf8(&buf[..n])
356                                .context("failed to parse uevent as utf-8 string")?;
357                            let uevent = parse_uevent(buf)?;
358                            for callback in &mut self.callbacks {
359                                (callback.func)(Notification::Event(&uevent));
360                            }
361                        }
362                        Err(e) => {
363                            // uevent socket is an unreliable source and in some cases (such as an
364                            // uevent flood) can overflow. Two ways to handle that. Either increase
365                            // the socket buffer size and hope that buffer doesn't overflow or wake up
366                            // the callers to have them rescan for the condition. We went with the latter
367                            // here as that has a higher degree of reliability.
368                            if let Some(libc::ENOBUFS) = e.raw_os_error() {
369                                tracing::info!("uevent socket read error: {:?}", e);
370                                let properties: Vec<(&str, &str)> = vec![("RESCAN", "true")];
371                                let uevent = Uevent {
372                                    header: "rescan",
373                                    properties,
374                                };
375                                for callback in &mut self.callbacks {
376                                    (callback.func)(Notification::Event(&uevent));
377                                }
378                            } else {
379                                Err(e).context("uevent read failure")?;
380                            }
381                        }
382                    };
383                }
384            }
385        }
386    }
387}
388
389fn parse_uevent(buf: &str) -> anyhow::Result<Uevent<'_>> {
390    let mut lines = buf.split('\0');
391    let header = lines.next().context("missing event header")?;
392    let properties = lines.filter_map(|line| line.split_once('=')).collect();
393    tracing::debug!(header, ?properties, "uevent");
394    let mut uevent = Uevent { header, properties };
395    uevent.properties.sort_by_key(|(k, _)| *k);
396    Ok(uevent)
397}