minircu/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Minimal RCU (Read-Copy-Update) implementation
5//!
6//! This crate provides a minimal Read-Copy-Update (RCU) synchronization
7//! mechanism specifically designed for OpenVMM use cases. RCU is a
8//! synchronization technique that allows multiple readers to access shared data
9//! concurrently with writers by ensuring that writers create new versions of
10//! data while readers continue using old versions.
11//!
12//! This is similar to a reader-writer lock except that readers never wait:
13//! writers publish the new version of the data and then wait for all readers to
14//! finish using the old version before freeing it. This allows for very low
15//! overhead on the read side, as readers do not need to acquire locks.
16//!
17//! ## Usage
18//!
19//! Basic usage with the global domain:
20//!
21//! ```rust
22//! // Execute code in a read-side critical section
23//! let result = minircu::global().run(|| {
24//!     // Access shared data safely here.
25//!     42
26//! });
27//!
28//! // Wait for all current readers to finish their critical sections.
29//! // This is typically called by writers after updating data.
30//! minircu::global().synchronize_blocking();
31//! ```
32//!
33//! ## Quiescing
34//!
//! To optimize synchronization, a thread can explicitly quiesce when it is not
//! expected to enter a critical section for a while. The RCU domain can skip
//! issuing a memory barrier when all threads are quiesced.
38//!
39//! ```rust
40//! use minircu::global;
41//!
42//! // Mark the current thread as quiesced.
43//! global().quiesce();
44//! ```
45//!
46//! ## Asynchronous Support
47//!
48//! The crate provides async-compatible methods for quiescing and
49//! synchronization:
50//!
51//! ```rust
52//! use minircu::global;
53//!
54//! async fn example() {
55//!     // Quiesce whenever future returns Poll::Pending
56//!     global().quiesce_on_pending(async {
57//!         loop {
58//!             // Async code here.
59//!             global().run(|| {
60//!                 // Access shared data safely here.
61//!             });
62//!         }
63//!     }).await;
64//!
65//!     // Asynchronous synchronization
66//!     global().synchronize(|duration| async move {
67//!         // This should be a sleep call, e.g. using tokio::time::sleep.
68//!         std::future::pending().await
69//!     }).await;
70//! }
71//! ```
72//!
73//! ## Gotchas
74//!
75//! * Avoid blocking or long-running operations in critical sections as they can
76//!   delay writers or cause deadlocks.
77//! * Never call [`synchronize`](RcuDomain::synchronize) or
78//!   [`synchronize_blocking`](RcuDomain::synchronize_blocking) from within a critical
79//!   section (will panic).
//! * For best performance, ensure every thread in your process calls `quiesce`
//!   before it goes to sleep or blocks.
82//!
83//! ## Implementation Notes
84//!
85//! On Windows and Linux, the read-side critical section avoids any processor
86//! memory barriers. It achieves this by having the write side broadcast a
87//! memory barrier to all threads in the process when needed for
88//! synchronization, via the `membarrier` syscall on Linux and
89//! `FlushProcessWriteBuffers` on Windows.
90//!
91//! On other platforms, which do not support this functionality, the read-side
92//! critical section uses a memory fence. This makes the read side more
93//! expensive on these platforms, but it is still cheaper than a mutex or
94//! reader-writer lock.
95
96// UNSAFETY: needed to access TLS from a remote thread and to call platform APIs
97// for issuing process-wide memory barriers.
98#![expect(unsafe_code)]
99
100/// Provides the environment-specific `membarrier` and `access_fence`
101/// implementations.
102#[cfg_attr(target_os = "linux", path = "linux.rs")]
103#[cfg_attr(windows, path = "windows.rs")]
104#[cfg_attr(not(any(windows, target_os = "linux")), path = "other.rs")]
105mod sys;
106
107use event_listener::Event;
108use event_listener::Listener;
109use parking_lot::Mutex;
110use std::cell::Cell;
111use std::future::Future;
112use std::future::poll_fn;
113use std::ops::Deref;
114use std::pin::pin;
115use std::sync::atomic::AtomicU64;
116use std::sync::atomic::Ordering::Acquire;
117use std::sync::atomic::Ordering::Relaxed;
118use std::sync::atomic::Ordering::Release;
119use std::sync::atomic::Ordering::SeqCst;
120use std::sync::atomic::fence;
121use std::task::Poll;
122use std::thread::LocalKey;
123use std::thread::Thread;
124use std::time::Duration;
125use std::time::Instant;
126
/// Defines a new RCU domain, which can be synchronized with separately from
/// other domains.
///
/// The macro expands to a `const fn` named `$name`, carrying the supplied
/// attributes and visibility, that returns an [`RcuDomain`] handle. Each
/// expansion declares its own `static` [`RcuData`] and its own thread-local
/// [`ThreadData`], so domains created by separate invocations share no state.
///
/// Usually you just want to use [`global`], the global domain.
///
/// Don't export this until we have a use case. We may want to make `quiesce`
/// apply to all domains, or something like that.
macro_rules! define_rcu_domain {
    ($(#[$a:meta])* $vis:vis $name:ident) => {
        $(#[$a])*
        $vis const fn $name() -> $crate::RcuDomain {
            // Domain-wide state, shared by every thread that uses this domain.
            static DATA: $crate::RcuData = $crate::RcuData::new();
            // Per-thread state for this domain, const-initialized.
            thread_local! {
                static TLS: $crate::ThreadData = const { $crate::ThreadData::new() };
            }
            // The handle is just a pair of `'static` references, so it is
            // cheap to create on every call.
            $crate::RcuDomain::new(&TLS, &DATA)
        }
    };
}
146
// Expands to `pub const fn global() -> RcuDomain`, backed by its own static
// domain state and thread-local data.
define_rcu_domain! {
    /// The global RCU domain.
    pub global
}
151
/// An RCU synchronization domain.
///
/// This is a lightweight `Copy` handle made of two `'static` references: the
/// calling thread's per-thread state and the state shared by the whole
/// domain.
#[derive(Copy, Clone)]
pub struct RcuDomain {
    /// The thread-local key holding each thread's [`ThreadData`] for this
    /// domain.
    tls: &'static LocalKey<ThreadData>,
    /// The state shared by all threads using this domain.
    data: &'static RcuData,
}
158
159impl std::fmt::Debug for RcuDomain {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        let Self { tls: _, data } = self;
162        f.debug_struct("RcuDomain").field("data", data).finish()
163    }
164}
165
/// Domain-global RCU state.
///
/// Public only so that [`define_rcu_domain!`] can name it; it is hidden from
/// the documentation.
#[doc(hidden)]
#[derive(Debug)]
pub struct RcuData {
    /// The threads that have registered with this domain. A thread adds its
    /// entry the first time it runs a critical section and removes it when its
    /// [`ThreadData`] is dropped.
    threads: Mutex<Vec<ThreadEntry>>,
    /// The current sequence number. Starts at `SEQ_FIRST` and is advanced by
    /// `SEQ_INCREMENT` by each synchronize call.
    seq: AtomicU64,
    /// The event that is signaled when a thread exits a critical section and
    /// there has been a sequence number update. It is also signaled when a
    /// thread quiesces or unregisters.
    event: Event,
    /// The number of membarriers issued. Kept for diagnostics.
    membarriers: AtomicU64,
}
180
/// The entry in the thread list for a registered thread.
///
/// Entries live in [`RcuData::threads`] and are only accessed while that
/// mutex is held.
#[derive(Debug)]
struct ThreadEntry {
    /// The pointer to the sequence number for this thread. The [`ThreadData`]
    /// TLS destructor will remove this entry, so this is safe to dereference.
    seq_ptr: TlsRef<AtomicU64>,
    /// The last sequence number that a synchronizer can know this thread has
    /// observed, without issuing membarriers or looking at the thread's TLS
    /// data. Acts as a cache: once it reaches a synchronizer's target, the
    /// thread is treated as ready without reading `seq_ptr` again.
    observed_seq: u64,
    /// The thread that this entry is for. Used for debugging and tracing.
    thread: Thread,
}
194
/// A raw pointer that stands in for a shared reference to a `T`.
///
/// The creator is responsible for making sure the pointee stays alive for as
/// long as the `TlsRef` exists; every impl below relies on that.
struct TlsRef<T>(*const T);

impl<T> Deref for TlsRef<T> {
    type Target = T;

    /// Borrows the pointee.
    fn deref(&self) -> &T {
        let raw: *const T = self.0;
        // SAFETY: The pointer refers to TLS data that is valid for as long as
        // this value exists, because the TLS drop implementation removes the
        // owning entry from the thread list before the data goes away.
        unsafe { &*raw }
    }
}

// SAFETY: A `TlsRef<T>` behaves like a `&T`, so it may be sent to another
// thread exactly when `&T` may be.
unsafe impl<T> Send for TlsRef<T> where for<'a> &'a T: Send {}
// SAFETY: A `TlsRef<T>` behaves like a `&T`, so it may be shared between
// threads exactly when `&T` may be.
unsafe impl<T> Sync for TlsRef<T> where for<'a> &'a T: Sync {}

impl<T: std::fmt::Debug> std::fmt::Debug for TlsRef<T> {
    /// Formats the pointee itself, not the pointer value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let target: &T = self;
        std::fmt::Debug::fmt(target, f)
    }
}
220
221impl RcuData {
222    /// Used by [`define_rcu_domain!`] to create a new RCU domain.
223    #[doc(hidden)]
224    pub const fn new() -> Self {
225        RcuData {
226            threads: Mutex::new(Vec::new()),
227            seq: AtomicU64::new(SEQ_FIRST),
228            event: Event::new(),
229            membarriers: AtomicU64::new(0),
230        }
231    }
232}
233
/// The per-thread TLS data.
///
/// Public only so that [`define_rcu_domain!`] can name it; it is hidden from
/// the documentation.
#[doc(hidden)]
pub struct ThreadData {
    /// The current sequence number for the thread. The low bit
    /// (`SEQ_MASK_BUSY`) is set while the thread is in a critical section.
    /// Synchronizers on other threads read this value through the thread's
    /// [`ThreadEntry`].
    current_seq: AtomicU64,
    /// The RCU domain this thread is registered with. `None` until the thread
    /// registers; the `Drop` impl uses it to unregister the thread.
    data: Cell<Option<&'static RcuData>>,
}
242
243impl std::fmt::Debug for ThreadData {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        let Self {
246            current_seq: my_seq,
247            data: _,
248        } = self;
249        f.debug_struct("ThreadData")
250            .field("my_seq", my_seq)
251            .finish()
252    }
253}
254
255impl Drop for ThreadData {
256    fn drop(&mut self) {
257        if let Some(data) = self.data.get() {
258            {
259                let mut threads = data.threads.lock();
260                let i = threads
261                    .iter()
262                    .position(|x| x.seq_ptr.0 == &self.current_seq)
263                    .unwrap();
264                threads.swap_remove(i);
265            }
266            data.event.notify(!0usize);
267        }
268    }
269}
270
271impl ThreadData {
272    /// Used by [`define_rcu_domain!`] to create a new RCU domain.
273    #[doc(hidden)]
274    pub const fn new() -> Self {
275        ThreadData {
276            current_seq: AtomicU64::new(SEQ_NONE),
277            data: Cell::new(None),
278        }
279    }
280}
281
/// The thread has not yet registered with the RCU domain.
const SEQ_NONE: u64 = 0;
/// The bit set when the thread is in a critical section.
const SEQ_MASK_BUSY: u64 = 1;
/// The value the sequence number is incremented by each synchronize call.
/// Stepping by two leaves the low bit free for [`SEQ_MASK_BUSY`].
const SEQ_INCREMENT: u64 = 2;
/// The sequence value for a quiesced thread. The thread will issue a full
/// memory barrier when leaving this state.
const SEQ_QUIESCED: u64 = 2;
/// The first actual sequence number. Any value below this is one of the
/// special states above, possibly with the busy bit set.
const SEQ_FIRST: u64 = 4;
293
294impl RcuDomain {
295    #[doc(hidden)]
296    pub const fn new(tls: &'static LocalKey<ThreadData>, data: &'static RcuData) -> Self {
297        RcuDomain { tls, data }
298    }
299
300    /// Runs `f` in a critical section. Calls to
301    /// [`synchronize`](Self::synchronize) or
302    /// [`synchronize_blocking`](Self::synchronize_blocking) for the same RCU root will
303    /// block until `f` returns.
304    ///
305    /// In general, you should avoid blocking the thread in `f`, since that can
306    /// slow calls to [`synchronize`](Self::synchronize) and can potentially
307    /// cause deadlocks.
308    pub fn run<F, R>(self, f: F) -> R
309    where
310        F: FnOnce() -> R,
311    {
312        self.tls.with(|x| x.run(self.data, f))
313    }
314
315    /// Quiesce the current thread.
316    ///
317    /// This can speed up calls to [`synchronize`](Self::synchronize) or
318    /// [`synchronize_blocking`](Self::synchronize_blocking) by allowing the RCU domain
319    /// to skip issuing a membarrier if all threads are quiesced. In return, the
320    /// first call to [`run`](Self::run) after this will be slower, as it will
321    /// need to issue a memory barrier to leave the quiesced state.
322    pub fn quiesce(self) {
323        self.tls.with(|x| {
324            x.quiesce(self.data);
325        });
326    }
327
328    /// Runs `fut`, calling [`quiesce`](Self::quiesce) on the current thread
329    /// each time `fut` returns `Poll::Pending`.
330    pub async fn quiesce_on_pending<Fut>(self, fut: Fut) -> Fut::Output
331    where
332        Fut: Future,
333    {
334        let mut fut = pin!(fut);
335        poll_fn(|cx| {
336            self.tls.with(|x| {
337                let r = fut.as_mut().poll(cx);
338                x.quiesce(self.data);
339                r
340            })
341        })
342        .await
343    }
344
    /// Starts a synchronization: quiesces the calling thread, advances the
    /// domain's sequence number, and decides whether the caller has to wait.
    ///
    /// Returns `None` when every registered thread is already known to be
    /// ready without a membarrier, in which case the caller is done. Otherwise
    /// issues a membarrier and returns `Some(seq)`, where `seq` is the new
    /// sequence number the caller must wait for all threads to reach. The
    /// sequence number is advanced in both cases.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread is inside a critical section.
    #[track_caller]
    fn prepare_to_wait(&self) -> Option<u64> {
        // Quiesce this thread so we don't wait on ourselves.
        {
            let this_seq = self.tls.with(|x| x.quiesce(self.data));
            // `quiesce` returns the unchanged value when it cannot quiesce, so
            // anything other than "unregistered" or "quiesced" means the busy
            // bit is set.
            assert!(
                this_seq == SEQ_NONE || this_seq == SEQ_QUIESCED,
                "called synchronize() inside a critical section, {this_seq:#x}",
            );
        }
        // Update the domain's sequence number.
        let seq = self.data.seq.fetch_add(SEQ_INCREMENT, SeqCst) + SEQ_INCREMENT;
        // We need to make sure all threads are quiesced, not busy, or have
        // observed the new sequence number. To do this, we must synchronize the
        // global sequence number update with changes to each thread's local
        // sequence number. To do that, we will issue a membarrier, to broadcast
        // a memory barrier to all threads in the process.
        //
        // First, try to avoid the membarrier if possible--if all threads are quiesced,
        // then there is no need to issue a membarrier, because quiesced threads will issue
        // a memory barrier when they leave the quiesced state.
        if self
            .data
            .threads
            .lock()
            .iter_mut()
            .all(|t| Self::is_thread_ready(t, seq, false))
        {
            return None;
        }
        // Keep a count for diagnostics purposes.
        self.data.membarriers.fetch_add(1, Relaxed);
        sys::membarrier();
        Some(seq)
    }
380
381    /// Synchronizes the RCU domain, blocking asynchronously until all threads
382    /// have exited their critical sections and observed the new sequence
383    /// number.
384    ///
385    /// `sleep` should be a function that sleeps for the specified duration.
386    pub async fn synchronize(self, mut sleep: impl AsyncFnMut(Duration)) {
387        let Some(seq) = self.prepare_to_wait() else {
388            return;
389        };
390        let mut wait = pin!(self.wait_threads_ready(seq));
391        let mut timeout = Duration::from_millis(100);
392        loop {
393            let mut sleep = pin!(sleep(timeout));
394            let ready = poll_fn(|cx| {
395                if let Poll::Ready(()) = wait.as_mut().poll(cx) {
396                    Poll::Ready(true)
397                } else if let Poll::Ready(()) = sleep.as_mut().poll(cx) {
398                    Poll::Ready(false)
399                } else {
400                    Poll::Pending
401                }
402            })
403            .await;
404            if ready {
405                break;
406            }
407            self.warn_stall(seq);
408            if timeout < Duration::from_secs(10) {
409                timeout *= 2;
410            }
411        }
412    }
413
414    /// Like [`synchronize`](Self::synchronize), but blocks the current thread
415    /// synchronously.
416    #[track_caller]
417    pub fn synchronize_blocking(self) {
418        let Some(seq) = self.prepare_to_wait() else {
419            return;
420        };
421        let mut timeout = Duration::from_millis(10);
422        while !self.wait_threads_ready_sync(seq, Instant::now() + timeout) {
423            self.warn_stall(seq);
424            if timeout < Duration::from_secs(10) {
425                timeout *= 2;
426            }
427        }
428    }
429
430    fn warn_stall(&self, target: u64) {
431        for thread in &mut *self.data.threads.lock() {
432            if !Self::is_thread_ready(thread, target, true) {
433                tracelimit::warn_ratelimited!(thread = thread.thread.name(), "rcu stall");
434            }
435        }
436    }
437
438    async fn wait_threads_ready(&self, target: u64) {
439        loop {
440            let event = self.data.event.listen();
441            if self.all_threads_ready(target, true) {
442                break;
443            }
444            event.await;
445        }
446    }
447
448    #[must_use]
449    fn wait_threads_ready_sync(&self, target: u64, deadline: Instant) -> bool {
450        loop {
451            let event = self.data.event.listen();
452            if self.all_threads_ready(target, true) {
453                break;
454            }
455            if event.wait_deadline(deadline).is_none() {
456                return false;
457            }
458        }
459        true
460    }
461
462    fn all_threads_ready(&self, target: u64, issued_barrier: bool) -> bool {
463        self.data
464            .threads
465            .lock()
466            .iter_mut()
467            .all(|thread| Self::is_thread_ready(thread, target, issued_barrier))
468    }
469
470    fn is_thread_ready(thread: &mut ThreadEntry, target: u64, issued_barrier: bool) -> bool {
471        if thread.observed_seq >= target {
472            return true;
473        }
474        let seq = thread.seq_ptr.load(Relaxed);
475        assert_ne!(seq, SEQ_NONE);
476        if seq & !SEQ_MASK_BUSY < target {
477            if seq & SEQ_MASK_BUSY != 0 {
478                // The thread is actively running in a critical section.
479                return false;
480            }
481            if seq != SEQ_QUIESCED {
482                // The thread is not quiesced. If a barrier was issued, then it
483                // has observed the new sequence number. It may be busy (but
484                // this CPU has not observed the write yet), but it must be busy
485                // with a newer sequence number.
486                //
487                // If a barrier was not issued, then it is possible that the
488                // thread is busy with an older sequence number. In this case,
489                // we will need to issue a membarrier to observe the value of
490                // the busy bit accurately.
491                assert!(seq >= SEQ_FIRST, "{seq}");
492                if !issued_barrier {
493                    return false;
494                }
495            }
496        }
497        thread.observed_seq = target;
498        true
499    }
500}
501
502impl ThreadData {
    /// Runs `f` with this thread marked busy in `current_seq`, then clears the
    /// busy bit and, if the domain's sequence number differs from the value
    /// this thread started with, hands off to `update_seq` to adopt it and
    /// wake waiters.
    ///
    /// On a thread's first call the thread is registered with `data` via
    /// `start`. Nested calls are tolerated: an inner call sees the busy bit
    /// already set and restores that same value on exit.
    ///
    /// NOTE(review): on a thread's first call `seq` is `SEQ_NONE`, so the
    /// "clear the busy bit" store below writes `SEQ_NONE` while the thread is
    /// already in `data.threads`, and it stays that way until `update_seq`
    /// stores the new sequence number. `is_thread_ready` asserts that the
    /// value it loads is not `SEQ_NONE`; confirm that a concurrent
    /// synchronizer cannot observe this window.
    fn run<F, R>(&self, data: &'static RcuData, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        // Mark the thread as busy.
        let seq = self.current_seq.load(Relaxed);
        self.current_seq.store(seq | SEQ_MASK_BUSY, Relaxed);
        if seq < SEQ_FIRST {
            // The thread was quiesced or not registered. Register it now.
            if seq == SEQ_NONE {
                self.start(data, seq);
            } else {
                // Below `SEQ_FIRST` and not `SEQ_NONE`: either quiesced, or a
                // nested call made while the outer call has the busy bit set.
                debug_assert!(seq == SEQ_QUIESCED || seq & SEQ_MASK_BUSY != 0, "{seq:#x}");
            }
            // Use a full memory barrier to ensure the write side observes that
            // the thread is no longer quiesced before calling `f`.
            fence(SeqCst);
        }
        // Ensure accesses in `f` are bounded by setting the busy bit. Note that
        // this and other fences are just compiler fences; the write side must
        // call `membarrier` to dynamically turn them into processor memory
        // barriers, so to speak.
        sys::access_fence(Acquire);
        let r = f();
        sys::access_fence(Release);
        // Clear the busy bit.
        self.current_seq.store(seq, Relaxed);
        // Ensure the busy bit clear is visible to the write side, then read the
        // new sequence number, to synchronize with the sequence update path.
        sys::access_fence(SeqCst);
        let new_seq = data.seq.load(Relaxed);
        if new_seq != seq {
            // The domain's current sequence number has changed. Update it and
            // wake up any waiters.
            self.update_seq(data, seq, new_seq);
        }
        r
    }
541
542    #[inline(never)]
543    fn start(&self, data: &'static RcuData, seq: u64) {
544        if seq == SEQ_NONE {
545            // Add the thread to the list of known threads in this domain.
546            assert!(self.data.get().is_none());
547            data.threads.lock().push(ThreadEntry {
548                seq_ptr: TlsRef(&self.current_seq),
549                observed_seq: SEQ_NONE,
550                thread: std::thread::current(),
551            });
552            // Remember the domain so that we can remove the thread from the list
553            // when it exits.
554            self.data.set(Some(data));
555        }
556    }
557
558    #[inline(never)]
559    fn update_seq(&self, data: &'static RcuData, seq: u64, new_seq: u64) {
560        if seq & SEQ_MASK_BUSY != 0 {
561            // Nested call. Skip.
562            return;
563        }
564        assert!(
565            new_seq >= SEQ_FIRST && new_seq & SEQ_MASK_BUSY == 0,
566            "{new_seq}"
567        );
568        self.current_seq.store(new_seq, Relaxed);
569        // Wake up any waiters. We don't know how many threads are still in a
570        // critical section, so just wake up the writers every time and let them
571        // figure it out.
572        data.event.notify(!0usize);
573    }
574
575    fn quiesce(&self, data: &'static RcuData) -> u64 {
576        let seq = self.current_seq.load(Relaxed);
577        if seq >= SEQ_FIRST && seq & SEQ_MASK_BUSY == 0 {
578            self.current_seq.store(SEQ_QUIESCED, Relaxed);
579            data.event.notify(!0usize);
580            SEQ_QUIESCED
581        } else {
582            seq
583        }
584    }
585}
586
// Each test defines its own private RCU domain, so tests do not share state
// with one another or with the global domain.
#[cfg(test)]
mod tests {
    use crate::RcuDomain;
    use pal_async::DefaultDriver;
    use pal_async::DefaultPool;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::timer::PolledTimer;
    use std::sync::atomic::Ordering;
    use test_with_tracing::test;

    /// Synchronizes `rcu`, using a timer on `driver` as the sleep function.
    async fn sync(driver: &DefaultDriver, rcu: RcuDomain) {
        let mut timer = PolledTimer::new(driver);
        rcu.synchronize(async |timeout| {
            timer.sleep(timeout).await;
        })
        .await
    }

    /// A critical section followed by a synchronize on the same thread
    /// completes.
    #[async_test]
    async fn test_rcu_single(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        test_rcu().run(|| {});
        sync(&driver, test_rcu()).await;
    }

    /// Nested critical sections on one thread are allowed, and a later
    /// synchronize still completes.
    #[async_test]
    async fn test_rcu_nested(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        test_rcu().run(|| {
            test_rcu().run(|| {});
        });
        sync(&driver, test_rcu()).await;
    }

    /// When another thread has run a critical section without quiescing,
    /// synchronizing issues exactly one membarrier.
    #[async_test]
    async fn test_rcu_multi(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        let (thread, thread_driver) = DefaultPool::spawn_on_thread("test");
        thread_driver
            .spawn("test", async { test_rcu().run(|| {}) })
            .await;

        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0);
        sync(&driver, test_rcu()).await;
        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 1);

        drop(thread_driver);
        thread.join().unwrap();
    }

    /// When the other thread runs under `quiesce_on_pending` and this thread
    /// quiesces too, synchronizing issues no membarrier.
    #[async_test]
    async fn test_rcu_multi_quiesce(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        let (thread, thread_driver) = DefaultPool::spawn_on_thread("test");
        thread_driver
            .spawn(
                "test",
                test_rcu().quiesce_on_pending(async { test_rcu().run(|| {}) }),
            )
            .await;

        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0);
        test_rcu().quiesce();
        sync(&driver, test_rcu()).await;
        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0);

        drop(thread_driver);
        thread.join().unwrap();
    }
}
661}