1#![forbid(unsafe_code)]
7
8use crate::local_only::LocalOnly;
9use mesh::MeshPayload;
10use pal_async::driver::SpawnDriver;
11use pal_async::task::Task;
12use pal_async::wait::PolledWait;
13use pal_event::Event;
14use std::fmt::Debug;
15use std::sync::Arc;
16
/// A cloneable handle used to deliver an interrupt-like notification.
///
/// How the notification is delivered depends on the wrapped
/// `InterruptInner` variant: signaling a shared [`Event`], signaling the
/// event currently held by a `mesh::Cell<Event>`, or invoking a closure.
#[derive(Clone, Debug, MeshPayload)]
pub struct Interrupt {
    /// The concrete delivery mechanism behind this handle.
    inner: InterruptInner,
}
30
31impl Default for Interrupt {
32 fn default() -> Self {
33 Self::null()
34 }
35}
36
37impl Interrupt {
38 pub fn null() -> Self {
43 Self::from_event(Event::new())
45 }
46
47 pub fn null_event() -> Self {
52 Self::from_event(Event::new())
53 }
54
55 pub fn from_event(event: Event) -> Self {
59 Self {
60 inner: InterruptInner::Event(Arc::new(event)),
61 }
62 }
63
64 pub fn from_cell(cell: mesh::Cell<Event>) -> Self {
69 Self {
70 inner: InterruptInner::Cell(Arc::new(cell)),
71 }
72 }
73
74 pub fn from_fn<F>(f: F) -> Self
79 where
80 F: 'static + Send + Sync + Fn(),
81 {
82 Self {
83 inner: InterruptInner::Fn(LocalOnly(Arc::new(f))),
84 }
85 }
86
87 pub fn deliver(&self) {
89 match &self.inner {
90 InterruptInner::Event(event) => event.signal(),
91 InterruptInner::Cell(cell) => cell.with(|event| event.signal()),
92 InterruptInner::Fn(LocalOnly(f)) => f(),
93 }
94 }
95
96 pub fn event(&self) -> Option<&Event> {
98 match &self.inner {
99 InterruptInner::Event(event) => Some(event.as_ref()),
100 _ => None,
101 }
102 }
103
104 pub fn event_or_proxy(
112 &self,
113 driver: &impl SpawnDriver,
114 ) -> std::io::Result<(Event, Option<EventProxy>)> {
115 if let Some(event) = self.event() {
116 Ok((event.clone(), None))
117 } else {
118 let (proxy, event) = EventProxy::new(driver, self.clone())?;
119 Ok((event, Some(proxy)))
120 }
121 }
122}
123
/// The delivery mechanism backing an [`Interrupt`].
#[derive(Clone, MeshPayload)]
enum InterruptInner {
    /// Delivery signals this shared event.
    Event(Arc<Event>),
    /// Delivery signals whichever event the cell holds at that moment.
    Cell(Arc<mesh::Cell<Event>>),
    /// Delivery invokes this closure.
    // NOTE(review): `LocalOnly` presumably prevents the closure from being
    // sent to another process — confirm against `crate::local_only`.
    Fn(LocalOnly<Arc<dyn Send + Sync + Fn()>>),
}
130
131impl Debug for InterruptInner {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 match self {
134 InterruptInner::Event(_) => f.pad("Event"),
135 InterruptInner::Cell(_) => f.pad("Cell"),
136 InterruptInner::Fn(_) => f.pad("Fn"),
137 }
138 }
139}
140
/// Owns the background task, spawned by [`EventProxy::new`], that delivers an
/// [`Interrupt`] each time the associated proxy event is signaled.
pub struct EventProxy {
    /// Handle to the spawned forwarding task; stored but never read.
    // NOTE(review): presumably dropping this handle stops the task — confirm
    // against `pal_async::task::Task`.
    _task: Task<()>,
}
150
151impl EventProxy {
152 pub fn new(driver: &impl SpawnDriver, interrupt: Interrupt) -> std::io::Result<(Self, Event)> {
155 let event = Event::new();
156 let wait = PolledWait::new(driver, event.clone())?;
157 let task = driver.spawn("interrupt-event-proxy", async move {
158 Self::run(wait, interrupt).await;
159 });
160 Ok((Self { _task: task }, event))
161 }
162
163 async fn run(mut wait: PolledWait<Event>, interrupt: Interrupt) {
164 loop {
165 wait.wait().await.expect("wait should not fail");
166 interrupt.deliver();
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::Interrupt;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Delivering an event-backed interrupt signals the backing event.
    #[test]
    fn test_interrupt_event() {
        let backing = pal_event::Event::new();
        let interrupt = Interrupt::from_event(backing.clone());

        interrupt.deliver();

        assert!(backing.try_wait());
    }

    /// A cell-backed interrupt signals whichever event the cell currently
    /// holds, and only that event.
    #[async_test]
    async fn test_interrupt_cell() {
        let initial = pal_event::Event::new();
        let (mut updater, cell) = mesh::cell(initial.clone());
        let interrupt = Interrupt::from_cell(cell);

        interrupt.deliver();
        assert!(initial.try_wait());

        // An event that is not yet in the cell must not observe a delivery.
        let replacement = pal_event::Event::new();
        interrupt.deliver();
        assert!(!replacement.try_wait());

        // Once installed in the cell, it receives subsequent deliveries.
        updater.set(replacement.clone()).await;
        interrupt.deliver();
        assert!(replacement.try_wait());
    }

    /// An event-backed interrupt hands back its own event and needs no proxy.
    #[async_test]
    async fn test_event_or_proxy_event_backed(driver: DefaultDriver) {
        let backing = pal_event::Event::new();
        let interrupt = Interrupt::from_event(backing.clone());

        let (returned, proxy) = interrupt.event_or_proxy(&driver).unwrap();
        assert!(proxy.is_none());

        returned.signal();
        assert!(backing.try_wait());
    }

    /// A closure-backed interrupt gets a proxy whose event, once signaled,
    /// results in exactly one invocation of the closure.
    #[async_test]
    async fn test_event_or_proxy_fn_backed(driver: DefaultDriver) {
        let deliveries = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&deliveries);
        let interrupt = Interrupt::from_fn(move || {
            counter.fetch_add(1, Ordering::SeqCst);
        });

        let (proxy_event, proxy) = interrupt.event_or_proxy(&driver).unwrap();
        assert!(proxy.is_some());

        proxy_event.signal();

        // Poll for up to 100 rounds of 10 ms for the proxy task to run.
        let mut attempts = 0;
        while attempts < 100 && deliveries.load(Ordering::SeqCst) == 0 {
            pal_async::timer::PolledTimer::new(&driver)
                .sleep(std::time::Duration::from_millis(10))
                .await;
            attempts += 1;
        }

        assert_eq!(deliveries.load(Ordering::SeqCst), 1);
    }
}