1#![expect(unsafe_code)]
8
9use parking_lot::Mutex;
10use std::sync::Arc;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16
17#[derive(Debug)]
19pub struct MultiWaker<const N: usize> {
20 inner: Arc<Inner<N>>,
21}
22
23#[derive(Debug)]
24struct Inner<const N: usize> {
25 wakers: Mutex<[Option<Waker>; N]>,
26}
27
28impl<const N: usize> Inner<N> {
29 fn set(&self, i: usize, waker: &Waker) {
31 let mut wakers = self.wakers.lock();
32 if !wakers[i].as_ref().is_some_and(|old| old.will_wake(waker)) {
33 let _old = wakers[i].replace(waker.clone());
34 drop(wakers);
35 }
36 }
37
38 fn wake(&self) {
40 let wakers = std::mem::replace(&mut *self.wakers.lock(), [(); N].map(|_| None));
41 for waker in wakers.into_iter().flatten() {
42 waker.wake();
43 }
44 }
45}
46
47struct Ref<'a, 'b, const N: usize> {
48 inner: &'a Arc<Inner<N>>,
49 cx_waker: &'b Waker,
50 index: usize,
51}
52
53impl<const N: usize> MultiWaker<N> {
54 pub fn new() -> Self {
56 Self {
57 inner: Arc::new(Inner {
58 wakers: Mutex::new([(); N].map(|_| None)),
59 }),
60 }
61 }
62
63 pub fn poll_wrapped<R>(
66 &self,
67 cx: &mut Context<'_>,
68 index: usize,
69 f: impl FnOnce(&mut Context<'_>) -> Poll<R>,
70 ) -> Poll<R> {
71 let waker_ref = Ref {
72 inner: &self.inner,
73 index,
74 cx_waker: cx.waker(),
75 };
76 let waker = unsafe {
83 Waker::from_raw(RawWaker::new(
84 std::ptr::from_ref(&waker_ref).cast(),
85 &RawWakerVTable::new(ref_clone::<N>, ref_wake::<N>, ref_wake::<N>, ref_drop),
86 ))
87 };
88 let mut cx = Context::from_waker(&waker);
89 f(&mut cx)
90 }
91}
92
93unsafe fn ref_clone<const N: usize>(ptr: *const ()) -> RawWaker {
97 let thing: &Ref<'_, '_, N> = unsafe { &*(ptr.cast()) };
100 thing.inner.set(thing.index, thing.cx_waker);
101 let waker = thing.inner.clone();
102 RawWaker::new(
103 Arc::into_raw(waker).cast(),
104 &RawWakerVTable::new(
105 val_clone::<N>,
106 val_wake::<N>,
107 val_wake_by_ref::<N>,
108 val_drop::<N>,
109 ),
110 )
111}
112
113unsafe fn ref_wake<const N: usize>(ptr: *const ()) {
117 let thing: &Ref<'_, '_, N> = unsafe { &*(ptr.cast()) };
120 thing.inner.wake();
121 thing.cx_waker.wake_by_ref();
122}
123
/// `drop` entry of the borrowed waker built by `poll_wrapped`.
///
/// The pointed-to `Ref` belongs to `poll_wrapped` itself, so there is
/// nothing to release.
fn ref_drop(_ref_ptr: *const ()) {
    // Intentionally empty.
}
125
126unsafe fn val_drop<const N: usize>(ptr: *const ()) {
130 unsafe { Arc::decrement_strong_count(ptr.cast::<Inner<N>>()) };
133}
134
135unsafe fn val_wake_by_ref<const N: usize>(ptr: *const ()) {
139 let waker = unsafe { &*ptr.cast::<Inner<N>>() };
142 waker.wake();
143}
144
145unsafe fn val_wake<const N: usize>(ptr: *const ()) {
149 let waker = unsafe { Arc::from_raw(ptr.cast::<Inner<N>>()) };
152 waker.wake();
153}
154
155unsafe fn val_clone<const N: usize>(ptr: *const ()) -> RawWaker {
159 unsafe {
162 Arc::increment_strong_count(ptr.cast::<Inner<N>>());
163 }
164 RawWaker::new(
165 ptr,
166 &RawWakerVTable::new(
167 val_clone::<N>,
168 val_wake::<N>,
169 val_wake_by_ref::<N>,
170 val_drop::<N>,
171 ),
172 )
173}
174
#[cfg(test)]
mod tests {
    use super::MultiWaker;
    use futures::executor::block_on;
    use parking_lot::Mutex;
    use std::future::poll_fn;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::task::Context;
    use std::task::Poll;
    use std::task::Wake;
    use std::task::Waker;
    use std::time::Duration;

    /// One-shot event. `signal` marks it done and wakes the waker stored by
    /// the most recent pending `poll_wait`.
    #[derive(Default)]
    struct SlimEvent {
        state: Mutex<SlimEventState>,
    }

    #[derive(Default)]
    struct SlimEventState {
        done: bool,
        waker: Option<Waker>,
    }

    impl SlimEvent {
        fn signal(&self) {
            let mut state = self.state.lock();
            state.done = true;
            let waker = state.waker.take();
            drop(state);
            if let Some(waker) = waker {
                waker.wake();
            }
        }

        fn poll_wait(&self, cx: &mut Context<'_>) -> Poll<()> {
            let mut state = self.state.lock();
            if state.done {
                Poll::Ready(())
            } else {
                // `replace` hands back the previous waker, so it is dropped
                // after the lock is released. `Option::insert` would drop it
                // in place, under the lock.
                let _old = state.waker.replace(cx.waker().clone());
                drop(state);
                Poll::Pending
            }
        }
    }

    /// Outer waker that counts how many times it has been woken.
    #[derive(Default)]
    struct CountingWaker(AtomicUsize);

    impl CountingWaker {
        fn count(&self) -> usize {
            self.0.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_multiwaker() {
        let mw = Arc::new(MultiWaker::<2>::new());
        let event = Arc::new(SlimEvent::default());
        let f = |index| {
            let mw = mw.clone();
            let event = event.clone();
            move || {
                block_on(async {
                    poll_fn(|cx| mw.poll_wrapped(cx, index, |cx| event.poll_wait(cx))).await
                })
            }
        };
        let t1 = std::thread::spawn(f(0));
        let t2 = std::thread::spawn(f(1));
        std::thread::sleep(Duration::from_millis(100));
        event.signal();

        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    fn signal_wakes_every_registered_slot() {
        let mw = MultiWaker::<2>::new();
        let event = SlimEvent::default();
        let counters = [
            Arc::new(CountingWaker::default()),
            Arc::new(CountingWaker::default()),
        ];
        for (index, counter) in counters.iter().enumerate() {
            let waker = Waker::from(Arc::clone(counter));
            let mut cx = Context::from_waker(&waker);
            let poll = mw.poll_wrapped(&mut cx, index, |cx| event.poll_wait(cx));
            assert!(poll.is_pending());
        }
        assert!(counters.iter().all(|counter| counter.count() == 0));

        // The event only holds the waker from the last poll, but that waker
        // fans out to every registered slot.
        event.signal();
        for counter in &counters {
            assert_eq!(counter.count(), 1);
        }
    }

    #[test]
    fn waking_borrowed_waker_wakes_caller() {
        let mw = MultiWaker::<1>::new();
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);
        let poll = mw.poll_wrapped(&mut cx, 0, |cx| {
            cx.waker().wake_by_ref();
            Poll::Ready(())
        });
        assert!(poll.is_ready());
        assert_eq!(counter.count(), 1);
    }

    #[test]
    #[should_panic]
    fn out_of_range_index_panics() {
        let mw = MultiWaker::<1>::new();
        let event = SlimEvent::default();
        let waker = Waker::from(Arc::new(CountingWaker::default()));
        let mut cx = Context::from_waker(&waker);
        // `poll_wait` clones the waker, which requires a valid slot.
        let _ = mw.poll_wrapped(&mut cx, 1, |cx| event.poll_wait(cx));
    }

    #[test]
    fn ref_is_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<super::Ref<'_, '_, 1>>();
        assert_sync::<super::Ref<'_, '_, 1>>();
    }
}