pal_async/
multi_waker.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A multi-waker that multiplexes multiple wakers onto a single waker.
5
6// UNSAFETY: Implementing a `RawWakerVTable`.
7#![expect(unsafe_code)]
8
9use parking_lot::Mutex;
10use std::sync::Arc;
11use std::task::Context;
12use std::task::Poll;
13use std::task::RawWaker;
14use std::task::RawWakerVTable;
15use std::task::Waker;
16
17/// Object to multiplex multiple wakers onto a single waker.
18#[derive(Debug)]
19pub struct MultiWaker<const N: usize> {
20    inner: Arc<Inner<N>>,
21}
22
23#[derive(Debug)]
24struct Inner<const N: usize> {
25    wakers: Mutex<[Option<Waker>; N]>,
26}
27
28impl<const N: usize> Inner<N> {
29    /// Sets the waker for index `i`.
30    fn set(&self, i: usize, waker: &Waker) {
31        let mut wakers = self.wakers.lock();
32        if !wakers[i].as_ref().is_some_and(|old| old.will_wake(waker)) {
33            let _old = wakers[i].replace(waker.clone());
34            drop(wakers);
35        }
36    }
37
38    /// Wakes any wakers that have been set.
39    fn wake(&self) {
40        let wakers = std::mem::replace(&mut *self.wakers.lock(), [(); N].map(|_| None));
41        for waker in wakers.into_iter().flatten() {
42            waker.wake();
43        }
44    }
45}
46
47struct Ref<'a, 'b, const N: usize> {
48    inner: &'a Arc<Inner<N>>,
49    cx_waker: &'b Waker,
50    index: usize,
51}
52
53impl<const N: usize> MultiWaker<N> {
54    /// Creates a new instance.
55    pub fn new() -> Self {
56        Self {
57            inner: Arc::new(Inner {
58                wakers: Mutex::new([(); N].map(|_| None)),
59            }),
60        }
61    }
62
63    /// Calls a poll function on behalf of entry `index`, passing a `Context` that
64    /// ensures that each index's waker is called on wake.
65    pub fn poll_wrapped<R>(
66        &self,
67        cx: &mut Context<'_>,
68        index: usize,
69        f: impl FnOnce(&mut Context<'_>) -> Poll<R>,
70    ) -> Poll<R> {
71        let waker_ref = Ref {
72            inner: &self.inner,
73            index,
74            cx_waker: cx.waker(),
75        };
76        // SAFETY:
77        // - waker_ref and its contents are valid for the duration of the call.
78        // - The waker is only used for the duration of the call.
79        // - Ref is Send + Sync, enforced by a test.
80        // - All functions passed in the vtable expect a pointer to a Ref.
81        // - All functions passed in the vtable perform only thread-safe operations.
82        let waker = unsafe {
83            Waker::from_raw(RawWaker::new(
84                std::ptr::from_ref(&waker_ref).cast(),
85                &RawWakerVTable::new(ref_clone::<N>, ref_wake::<N>, ref_wake::<N>, ref_drop),
86            ))
87        };
88        let mut cx = Context::from_waker(&waker);
89        f(&mut cx)
90    }
91}
92
93/// # Safety
94///
95/// The caller must guarantee that the pointer is valid and pointing to a Ref<N>.
96unsafe fn ref_clone<const N: usize>(ptr: *const ()) -> RawWaker {
97    // SAFETY: This function is only called through our own waker, which guarantees that the
98    // pointer is valid and pointing to a Ref.
99    let thing: &Ref<'_, '_, N> = unsafe { &*(ptr.cast()) };
100    thing.inner.set(thing.index, thing.cx_waker);
101    let waker = thing.inner.clone();
102    RawWaker::new(
103        Arc::into_raw(waker).cast(),
104        &RawWakerVTable::new(
105            val_clone::<N>,
106            val_wake::<N>,
107            val_wake_by_ref::<N>,
108            val_drop::<N>,
109        ),
110    )
111}
112
113/// # Safety
114///
115/// The caller must guarantee that the pointer is valid and pointing to a Ref<N>.
116unsafe fn ref_wake<const N: usize>(ptr: *const ()) {
117    // SAFETY: This function is only called through our own waker, which guarantees that the
118    // pointer is valid and pointing to a Ref.
119    let thing: &Ref<'_, '_, N> = unsafe { &*(ptr.cast()) };
120    thing.inner.wake();
121    thing.cx_waker.wake_by_ref();
122}
123
/// `drop` entry of the borrowed waker's vtable.
///
/// The data pointer refers to a `Ref` that the waker does not own, so there is
/// nothing to release.
fn ref_drop(_ptr: *const ()) {}
125
126/// # Safety
127///
128/// The caller must guarantee that the pointer is valid and pointing to a Arc<Inner>.
129unsafe fn val_drop<const N: usize>(ptr: *const ()) {
130    // SAFETY: This function is only called through our own waker, which guarantees that the
131    // pointer is valid and pointing to a Arc<Inner>.
132    unsafe { Arc::decrement_strong_count(ptr.cast::<Inner<N>>()) };
133}
134
135/// # Safety
136///
137/// The caller must guarantee that the pointer is valid and pointing to a Arc<Inner>.
138unsafe fn val_wake_by_ref<const N: usize>(ptr: *const ()) {
139    // SAFETY: This function is only called through our own waker, which guarantees that the
140    // pointer is valid and pointing to a Arc<Inner>.
141    let waker = unsafe { &*ptr.cast::<Inner<N>>() };
142    waker.wake();
143}
144
145/// # Safety
146///
147/// The caller must guarantee that the pointer is valid and pointing to a Arc<Inner>.
148unsafe fn val_wake<const N: usize>(ptr: *const ()) {
149    // SAFETY: This function is only called through our own waker, which guarantees that the
150    // pointer is valid and pointing to a Arc<Inner>.
151    let waker = unsafe { Arc::from_raw(ptr.cast::<Inner<N>>()) };
152    waker.wake();
153}
154
155/// # Safety
156///
157/// The caller must guarantee that the pointer is valid and pointing to a Arc<Inner>.
158unsafe fn val_clone<const N: usize>(ptr: *const ()) -> RawWaker {
159    // SAFETY: This function is only called through our own waker, which guarantees that the
160    // pointer is valid and pointing to a Arc<Inner>.
161    unsafe {
162        Arc::increment_strong_count(ptr.cast::<Inner<N>>());
163    }
164    RawWaker::new(
165        ptr,
166        &RawWakerVTable::new(
167            val_clone::<N>,
168            val_wake::<N>,
169            val_wake_by_ref::<N>,
170            val_drop::<N>,
171        ),
172    )
173}
174
#[cfg(test)]
mod tests {
    use super::MultiWaker;
    use futures::executor::block_on;
    use parking_lot::Mutex;
    use std::future::poll_fn;
    use std::sync::Arc;
    use std::task::Context;
    use std::task::Poll;
    use std::task::Waker;
    use std::time::Duration;

    /// Minimal one-shot event: once signaled, every subsequent poll is ready.
    #[derive(Default)]
    struct SlimEvent {
        state: Mutex<SlimEventState>,
    }

    #[derive(Default)]
    struct SlimEventState {
        /// Set once `signal` has been called.
        done: bool,
        /// Waker from the most recent pending poll; only one is kept.
        waker: Option<Waker>,
    }

    impl SlimEvent {
        /// Marks the event done and wakes the stored waker after unlocking.
        fn signal(&self) {
            let mut state = self.state.lock();
            state.done = true;
            let waker = state.waker.take();
            drop(state);
            if let Some(waker) = waker {
                waker.wake();
            }
        }

        /// Ready once signaled; otherwise stores `cx`'s waker and is pending.
        fn poll_wait(&self, cx: &mut Context<'_>) -> Poll<()> {
            let mut state = self.state.lock();
            if state.done {
                Poll::Ready(())
            } else {
                // `replace` (not `insert`, which drops the previous waker in
                // place) hands back the old waker, so that it is dropped only
                // after the lock has been released below.
                let _old = state.waker.replace(cx.waker().clone());
                drop(state);
                Poll::Pending
            }
        }
    }

    #[test]
    fn test_multiwaker() {
        let mw = Arc::new(MultiWaker::<2>::new());
        let event = Arc::new(SlimEvent::default());
        let f = |index| {
            let mw = mw.clone();
            let event = event.clone();
            move || {
                block_on(async {
                    poll_fn(|cx| mw.poll_wrapped(cx, index, |cx| event.poll_wait(cx))).await
                })
            }
        };
        let t1 = std::thread::spawn(f(0));
        let t2 = std::thread::spawn(f(1));
        std::thread::sleep(Duration::from_millis(100));
        event.signal();

        t1.join().unwrap();
        t2.join().unwrap();
    }

    #[test]
    fn ref_is_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<super::Ref<'_, '_, 1>>();
        assert_sync::<super::Ref<'_, '_, 1>>();
    }
}