1#![cfg(target_os = "linux")]
10
11mod bind_kobject_uevent;
12
13use anyhow::Context;
14use fs_err::PathExt;
15use futures::AsyncReadExt;
16use futures::FutureExt;
17use futures::StreamExt;
18use futures_concurrency::future::Race;
19use mesh::rpc::Rpc;
20use mesh::rpc::RpcSend;
21use pal_async::driver::SpawnDriver;
22use pal_async::socket::PolledSocket;
23use pal_async::task::Task;
24use socket2::Socket;
25use std::future::Future;
26use std::io;
27use std::path::Path;
28use std::path::PathBuf;
29use thiserror::Error;
30
/// Listens for kernel uevents on a kobject uevent socket and dispatches them
/// to registered callbacks.
///
/// Reading and dispatching happen on a task spawned by
/// [`UeventListener::new`]; callbacks are registered through the methods on
/// this type.
pub struct UeventListener {
    /// The spawned listener task, owned by this object.
    _task: Task<()>,
    /// Channel for sending callback registration requests to the task.
    send: mesh::Sender<TaskRequest>,
}
36
/// Error returned by [`UeventListener::new`] when the uevent socket cannot be
/// bound or registered with the async driver.
#[derive(Debug, Error)]
#[error("failed to create uevent socket")]
pub struct NewUeventListenerError(#[source] io::Error);
41
42impl UeventListener {
43 pub fn new(driver: &impl SpawnDriver) -> Result<Self, NewUeventListenerError> {
45 let socket =
46 bind_kobject_uevent::bind_kobject_uevent_socket().map_err(NewUeventListenerError)?;
47 let socket = PolledSocket::new(driver, socket).map_err(NewUeventListenerError)?;
48 let (send, recv) = mesh::mpsc_channel();
49 let thing = ListenerTask {
50 socket,
51 callbacks: Vec::new(),
52 recv,
53 next_id: 0,
54 };
55 let task = driver.spawn("uevent", async move { thing.run().await });
56 Ok(Self { _task: task, send })
57 }
58
59 pub async fn add_custom_callback(
61 &self,
62 callback: impl 'static + Send + FnMut(Notification<'_>),
63 ) -> CallbackHandle {
64 self.send
65 .call(TaskRequest::NewFilter, Box::new(callback))
66 .await
67 .unwrap()
68 }
69
70 pub async fn add_block_resize_callback(
74 &self,
75 major: u32,
76 minor: u32,
77 mut notify: impl 'static + Send + FnMut(),
78 ) -> CallbackHandle {
79 self.add_custom_callback(move |event| match event {
80 Notification::Event(kvs) => {
81 if (kvs.get("RESCAN") == Some("true"))
82 || (kvs.get("RESIZE") == Some("1")
83 && kvs.get("SUBSYSTEM") == Some("block")
84 && kvs.get("ACTION") == Some("change")
85 && kvs.get("MAJOR").is_some_and(|x| x.parse() == Ok(major))
86 && kvs.get("MINOR").is_some_and(|x| x.parse() == Ok(minor)))
87 {
88 notify();
89 }
90 }
91 })
92 .await
93 }
94
95 pub async fn wait_for_matching_child<T, F, Fut>(&self, path: &Path, f: F) -> io::Result<T>
108 where
109 F: Fn(PathBuf, bool) -> Fut,
110 Fut: Future<Output = Option<T>>,
111 {
112 let scan_for_matching_child = async || {
113 for entry in path.fs_err_read_dir()? {
114 let entry = entry?;
115 if let Some(r) = f(entry.path(), false).await {
116 return Ok::<Option<T>, io::Error>(Some(r));
117 }
118 }
119 Ok(None)
120 };
121
122 if path.exists() {
124 if let Some(child) = scan_for_matching_child().await? {
125 return Ok(child);
126 }
127 }
128
129 self.wait_for_devpath(path).await?;
131 let path = path.fs_err_canonicalize()?;
132 let path_clone = path.clone();
133 let parent_devpath = path
134 .strip_prefix("/sys")
135 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid devpath"))?
136 .to_path_buf();
137
138 let (send, mut recv) = mesh::channel();
139 let _handle = self
140 .add_custom_callback({
141 move |notification| {
142 match notification {
143 Notification::Event(uevent) => {
144 if uevent.get("RESCAN") == Some("true") {
148 if let Ok(read_dir) = path_clone.fs_err_read_dir() {
149 for entry in read_dir {
150 if let Ok(sub_entry) = entry {
151 send.send((sub_entry.path(), false));
152 }
153 }
154 }
155 } else if uevent.get("ACTION") == Some("add") {
156 let Some(devpath) = uevent.get("DEVPATH") else {
157 return;
158 };
159 let devpath = Path::new(&devpath[1..]);
161 if devpath.parent() == Some(&parent_devpath) {
162 send.send((Path::new("/sys").join(devpath), true));
163 }
164 }
165 }
166 }
167 }
168 })
169 .await;
170
171 if let Some(child) = scan_for_matching_child().await? {
172 return Ok(child);
173 }
174
175 tracing::debug!(path = %path.display(), "waiting for child nodes");
176 while let Some((path, is_uevent)) = recv.next().await {
177 if let Some(r) = f(path, is_uevent).await {
178 return Ok(r);
179 }
180 }
181
182 Err(io::Error::new(
183 io::ErrorKind::InvalidInput,
184 "Did not find a matching path",
185 ))
186 }
187
188 pub async fn wait_for_devpath(&self, path: &Path) -> io::Result<()> {
197 if path.exists() {
199 return Ok(());
200 }
201
202 let (send, recv) = mesh::oneshot();
204 let _handle = self
205 .add_custom_callback({
206 let path = path.to_owned();
207 let mut send = Some(send);
208 move |event| {
209 if send.is_none() {
210 return;
211 }
212 match event {
213 Notification::Event(uevent) => {
214 if (uevent.get("ACTION") == Some("add"))
215 || (uevent.get("RESCAN") == Some("true"))
216 {
217 let r = path.fs_err_symlink_metadata();
218 if !matches!(&r, Err(err) if err.kind() == io::ErrorKind::NotFound)
219 {
220 send.take().unwrap().send(r);
221 }
222 }
223 }
224 }
225 }
226 })
227 .await;
228
229 let r = match path.fs_err_symlink_metadata() {
232 Ok(m) => Ok(m),
233 Err(err) if err.kind() == io::ErrorKind::NotFound => {
234 tracing::debug!(path = %path.display(), "waiting for devpath");
235 recv.await.unwrap()
236 }
237 Err(err) => Err(err),
238 };
239 r?;
240 Ok(())
241 }
242}
243
/// A notification delivered to callbacks registered with
/// [`UeventListener::add_custom_callback`].
pub enum Notification<'a> {
    /// A uevent. It was either received from the socket, or synthesized by
    /// the listener with the single property `RESCAN=true` after the socket
    /// reported `ENOBUFS`.
    Event(&'a Uevent<'a>),
}
249
/// A uevent message with borrowed header and property strings.
pub struct Uevent<'a> {
    /// The header field (for parsed messages, the text before the first NUL).
    header: &'a str,
    /// `(KEY, VALUE)` pairs, kept sorted by key so that [`Uevent::get`] can
    /// binary-search them.
    properties: Vec<(&'a str, &'a str)>,
}
255
256impl Uevent<'_> {
257 pub fn header(&self) -> &str {
259 self.header
260 }
261
262 pub fn get(&self, key: &str) -> Option<&str> {
264 let i = self
265 .properties
266 .binary_search_by_key(&key, |(k, _)| k)
267 .ok()?;
268 Some(self.properties[i].1)
269 }
270}
271
/// Handle to a callback registered with a [`UeventListener`].
///
/// Dropping the handle asks the listener task to unregister the callback.
#[must_use]
#[derive(Debug)]
pub struct CallbackHandle {
    /// ID assigned to the callback by the listener task.
    id: u64,
    /// Channel used to send the removal request when the handle is dropped.
    send: mesh::Sender<TaskRequest>,
}
282
283impl Drop for CallbackHandle {
284 fn drop(&mut self) {
285 self.send.send(TaskRequest::RemoveFilter(self.id))
286 }
287}
288
/// Requests handled by the listener task.
enum TaskRequest {
    /// Register a callback; the RPC response is the handle that removes it.
    NewFilter(Rpc<Box<dyn Send + FnMut(Notification<'_>)>, CallbackHandle>),
    /// Remove the callback with the given ID.
    RemoveFilter(u64),
}
293
/// State owned by the task spawned in [`UeventListener::new`].
struct ListenerTask {
    /// The uevent socket that events are read from.
    socket: PolledSocket<Socket>,
    /// Registered callbacks; each is invoked for every event.
    callbacks: Vec<Filter>,
    /// Incoming registration and removal requests.
    recv: mesh::Receiver<TaskRequest>,
    /// ID to assign to the next registered callback.
    next_id: u64,
}
300
/// A registered callback together with the ID used to remove it.
struct Filter {
    /// ID matched against `TaskRequest::RemoveFilter`.
    id: u64,
    /// The callback invoked for each notification.
    func: Box<dyn Send + FnMut(Notification<'_>)>,
}
305
306impl ListenerTask {
307 async fn run(self) {
308 if let Err(err) = self.run_inner().await {
309 tracing::error!(
310 error = err.as_ref() as &dyn std::error::Error,
311 "uevent failure"
312 );
313 }
314 }
315
316 async fn run_inner(mut self) -> anyhow::Result<()> {
317 let mut buf = [0; 4096];
318
319 enum Event {
320 Request(Option<TaskRequest>),
321 Read(io::Result<usize>),
322 }
323
324 loop {
325 let event = (
326 self.socket.read(&mut buf).map(Event::Read),
327 self.recv.next().map(Event::Request),
328 )
329 .race()
330 .await;
331
332 match event {
333 Event::Request(Some(request)) => match request {
334 TaskRequest::NewFilter(rpc) => rpc.handle_sync(|filter_fn| {
335 let id = self.next_id;
336 self.next_id += 1;
337 self.callbacks.push(Filter {
338 func: filter_fn,
339 id,
340 });
341 CallbackHandle {
342 id,
343 send: self.recv.sender(),
344 }
345 }),
346 TaskRequest::RemoveFilter(id) => {
347 self.callbacks
348 .swap_remove(self.callbacks.iter().position(|f| f.id == id).unwrap());
349 }
350 },
351 Event::Request(None) => break Ok(()),
352 Event::Read(r) => {
353 match r {
354 Ok(n) => {
355 let buf = std::str::from_utf8(&buf[..n])
356 .context("failed to parse uevent as utf-8 string")?;
357 let uevent = parse_uevent(buf)?;
358 for callback in &mut self.callbacks {
359 (callback.func)(Notification::Event(&uevent));
360 }
361 }
362 Err(e) => {
363 if let Some(libc::ENOBUFS) = e.raw_os_error() {
369 tracing::info!("uevent socket read error: {:?}", e);
370 let properties: Vec<(&str, &str)> = vec![("RESCAN", "true")];
371 let uevent = Uevent {
372 header: "rescan",
373 properties,
374 };
375 for callback in &mut self.callbacks {
376 (callback.func)(Notification::Event(&uevent));
377 }
378 } else {
379 Err(e).context("uevent read failure")?;
380 }
381 }
382 };
383 }
384 }
385 }
386 }
387}
388
389fn parse_uevent(buf: &str) -> anyhow::Result<Uevent<'_>> {
390 let mut lines = buf.split('\0');
391 let header = lines.next().context("missing event header")?;
392 let properties = lines.filter_map(|line| line.split_once('=')).collect();
393 tracing::debug!(header, ?properties, "uevent");
394 let mut uevent = Uevent { header, properties };
395 uevent.properties.sort_by_key(|(k, _)| *k);
396 Ok(uevent)
397}