user_driver/
interrupt.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Interrupt handling for user-mode device drivers.
5
6use parking_lot::Mutex;
7use std::future::poll_fn;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering::Acquire;
11use std::sync::atomic::Ordering::Relaxed;
12use std::sync::atomic::Ordering::Release;
13use std::task::Context;
14use std::task::Poll;
15use std::task::Waker;
16
17/// A mapped device interrupt.
18///
19/// This interrupt can be cloned multiple times. Each clone will be separately
20/// pollable. Initially, the clone is in the not-signaled state, even if the
21/// original instance is signaled.
22pub struct DeviceInterrupt {
23    slot: Arc<DeviceInterruptSlot>,
24    inner: Arc<DeviceInterruptInner>,
25}
26
27impl Clone for DeviceInterrupt {
28    fn clone(&self) -> Self {
29        self.inner.new_interrupt()
30    }
31}
32
33impl Drop for DeviceInterrupt {
34    fn drop(&mut self) {
35        let mut slots = self.inner.slots.lock();
36        let i = slots
37            .iter()
38            .position(|s| Arc::ptr_eq(s, &self.slot))
39            .unwrap();
40        slots.swap_remove(i);
41        self.inner.slots_updated.store(true, Release);
42    }
43}
44
45impl DeviceInterrupt {
46    /// Polls the interrupt, returning `Poll::Ready` if the interrupt is
47    /// signaled.
48    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
49        self.slot.poll(cx)
50    }
51
52    /// Waits for the interrupt to be signaled.
53    pub async fn wait(&mut self) {
54        poll_fn(|cx| self.poll(cx)).await
55    }
56}
57
58struct DeviceInterruptSlot {
59    signaled: AtomicBool,
60    waker: Mutex<Option<Waker>>,
61}
62
63impl DeviceInterruptSlot {
64    fn new() -> Self {
65        Self {
66            signaled: AtomicBool::new(false),
67            waker: Mutex::new(None),
68        }
69    }
70
71    fn poll(&self, cx: &mut Context<'_>) -> Poll<()> {
72        if self.signaled.load(Acquire) {
73            self.signaled.store(false, Release);
74            Poll::Ready(())
75        } else {
76            let _old_waker;
77            let mut waker = self.waker.lock();
78            // Check again under the lock.
79            if self.signaled.load(Acquire) {
80                self.signaled.store(false, Release);
81                return Poll::Ready(());
82            }
83            if waker.as_ref().is_none_or(|w| !w.will_wake(cx.waker())) {
84                _old_waker = waker.replace(cx.waker().clone());
85            }
86            Poll::Pending
87        }
88    }
89
90    fn signal(&self) {
91        self.signaled.store(true, Release);
92        if let Some(waker) = self.waker.lock().take() {
93            waker.wake();
94        }
95    }
96}
97
98struct DeviceInterruptInner {
99    slots: Mutex<Vec<Arc<DeviceInterruptSlot>>>,
100    slots_updated: AtomicBool,
101}
102
103impl DeviceInterruptInner {
104    fn new_interrupt(self: &Arc<Self>) -> DeviceInterrupt {
105        let slot = Arc::new(DeviceInterruptSlot::new());
106        self.slots.lock().push(slot.clone());
107        self.slots_updated.store(true, Release);
108        DeviceInterrupt {
109            slot,
110            inner: self.clone(),
111        }
112    }
113}
114
115/// A source of device interrupts.
116///
117/// This is intended to be used by the device backends to signal the
118/// [`DeviceInterrupt`] instances used by the drivers.
119pub struct DeviceInterruptSource {
120    slots: Vec<Arc<DeviceInterruptSlot>>,
121    inner: Arc<DeviceInterruptInner>,
122}
123
124impl DeviceInterruptSource {
125    /// Creates a new interrupt source.
126    pub fn new() -> Self {
127        Self {
128            inner: Arc::new(DeviceInterruptInner {
129                slots: Mutex::new(Vec::new()),
130                slots_updated: false.into(),
131            }),
132            slots: Vec::new(),
133        }
134    }
135
136    /// Creates a new interrupt target, each of which is notified when `signal`
137    /// is called.
138    pub fn new_target(&self) -> DeviceInterrupt {
139        self.inner.new_interrupt()
140    }
141
142    /// Signals all interrupt targets.
143    pub fn signal(&mut self) {
144        if self.inner.slots_updated.load(Acquire) {
145            let slots = self.inner.slots.lock();
146            self.inner.slots_updated.store(false, Relaxed);
147            self.slots.clone_from(&*slots);
148        }
149        for slot in &self.slots {
150            slot.signal();
151        }
152    }
153
154    /// Signals all interrupt targets without using the target cache. Use
155    /// `signal` instead when you have a mutable reference.
156    pub fn signal_uncached(&self) {
157        for slot in &*self.inner.slots.lock() {
158            slot.signal();
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::DeviceInterruptSource;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;

    /// A signal reaches the original target, and a later signal reaches both
    /// the original and a clone created after the first signal was consumed.
    #[async_test]
    async fn test_interrupt(driver: DefaultDriver) {
        let mut source = DeviceInterruptSource::new();
        let mut primary = source.new_target();

        // The first signal is observed by the only target.
        source.signal();
        primary.wait().await;

        // A clone starts unsignaled; the next signal must wake both targets.
        let mut secondary = primary.clone();
        let waiter = driver.spawn("test", async move { secondary.wait().await });
        source.signal();
        waiter.await;
        primary.wait().await;
    }
}