minircu/lib.rs
1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Minimal RCU (Read-Copy-Update) implementation
5//!
6//! This crate provides a minimal Read-Copy-Update (RCU) synchronization
7//! mechanism specifically designed for OpenVMM use cases. RCU is a
8//! synchronization technique that allows multiple readers to access shared data
9//! concurrently with writers by ensuring that writers create new versions of
10//! data while readers continue using old versions.
11//!
12//! This is similar to a reader-writer lock except that readers never wait:
13//! writers publish the new version of the data and then wait for all readers to
14//! finish using the old version before freeing it. This allows for very low
15//! overhead on the read side, as readers do not need to acquire locks.
16//!
17//! ## Usage
18//!
19//! Basic usage with the global domain:
20//!
21//! ```rust
22//! // Execute code in a read-side critical section
23//! let result = minircu::global().run(|| {
24//! // Access shared data safely here.
25//! 42
26//! });
27//!
28//! // Wait for all current readers to finish their critical sections.
29//! // This is typically called by writers after updating data.
30//! minircu::global().synchronize_blocking();
31//! ```
32//!
33//! ## Quiescing
34//!
//! To optimize synchronization, a thread can explicitly quiesce when it is not
//! expected to enter a critical section for a while. The RCU domain can skip
//! issuing a memory barrier when all threads are quiesced.
38//!
39//! ```rust
40//! use minircu::global;
41//!
42//! // Mark the current thread as quiesced.
43//! global().quiesce();
44//! ```
45//!
46//! ## Asynchronous Support
47//!
48//! The crate provides async-compatible methods for quiescing and
49//! synchronization:
50//!
51//! ```rust
52//! use minircu::global;
53//!
54//! async fn example() {
55//! // Quiesce whenever future returns Poll::Pending
56//! global().quiesce_on_pending(async {
57//! loop {
58//! // Async code here.
59//! global().run(|| {
60//! // Access shared data safely here.
61//! });
62//! }
63//! }).await;
64//!
65//! // Asynchronous synchronization
66//! global().synchronize(|duration| async move {
67//! // This should be a sleep call, e.g. using tokio::time::sleep.
68//! std::future::pending().await
69//! }).await;
70//! }
71//! ```
72//!
73//! ## Gotchas
74//!
75//! * Avoid blocking or long-running operations in critical sections as they can
76//! delay writers or cause deadlocks.
77//! * Never call [`synchronize`](RcuDomain::synchronize) or
78//! [`synchronize_blocking`](RcuDomain::synchronize_blocking) from within a critical
79//! section (will panic).
//! * For best performance, have every thread in your process call `quiesce`
//!   before it goes to sleep or blocks.
82//!
83//! ## Implementation Notes
84//!
85//! On Windows and Linux, the read-side critical section avoids any processor
86//! memory barriers. It achieves this by having the write side broadcast a
87//! memory barrier to all threads in the process when needed for
88//! synchronization, via the `membarrier` syscall on Linux and
89//! `FlushProcessWriteBuffers` on Windows.
90//!
91//! On other platforms, which do not support this functionality, the read-side
92//! critical section uses a memory fence. This makes the read side more
93//! expensive on these platforms, but it is still cheaper than a mutex or
94//! reader-writer lock.
95
96// UNSAFETY: needed to access TLS from a remote thread and to call platform APIs
97// for issuing process-wide memory barriers.
98#![expect(unsafe_code)]
99
/// Provides the environment-specific `membarrier` and `access_fence`
/// implementations.
///
/// The backing source file is chosen at compile time: `linux.rs` on Linux,
/// `windows.rs` on Windows, and `other.rs` on every other target.
#[cfg_attr(target_os = "linux", path = "linux.rs")]
#[cfg_attr(windows, path = "windows.rs")]
#[cfg_attr(not(any(windows, target_os = "linux")), path = "other.rs")]
mod sys;
106
107use event_listener::Event;
108use event_listener::Listener;
109use parking_lot::Mutex;
110use std::cell::Cell;
111use std::future::Future;
112use std::future::poll_fn;
113use std::ops::Deref;
114use std::pin::pin;
115use std::sync::atomic::AtomicU64;
116use std::sync::atomic::Ordering::Acquire;
117use std::sync::atomic::Ordering::Relaxed;
118use std::sync::atomic::Ordering::Release;
119use std::sync::atomic::Ordering::SeqCst;
120use std::sync::atomic::fence;
121use std::task::Poll;
122use std::thread::LocalKey;
123use std::thread::Thread;
124use std::time::Duration;
125use std::time::Instant;
126
/// Declares a function that returns a handle to a new RCU domain.
///
/// The generated `const fn` owns one static [`RcuData`](crate::RcuData) and
/// one thread-local [`ThreadData`](crate::ThreadData), so every domain defined
/// this way is synchronized independently of all others. Attributes (including
/// doc comments) and a visibility may precede the function name.
///
/// Usually you just want to use [`global`], the global domain.
///
/// This macro is intentionally not exported until there is a use case for
/// it; `quiesce` may yet need to apply to all domains, or something similar.
macro_rules! define_rcu_domain {
    ($(#[$attr:meta])* $vis:vis $name:ident) => {
        $(#[$attr])*
        $vis const fn $name() -> $crate::RcuDomain {
            thread_local! {
                static THREAD_STATE: $crate::ThreadData = const { $crate::ThreadData::new() };
            }
            static DOMAIN_STATE: $crate::RcuData = $crate::RcuData::new();
            $crate::RcuDomain::new(&THREAD_STATE, &DOMAIN_STATE)
        }
    };
}
146
define_rcu_domain! {
    /// The global RCU domain.
    ///
    /// Every call returns a handle to the same underlying domain state, so
    /// readers and writers anywhere in the process that use `global()`
    /// synchronize with each other.
    pub global
}
151
/// An RCU synchronization domain.
///
/// This is a lightweight handle made of two `'static` references, so it is
/// `Copy` and is passed by value to every method.
#[derive(Copy, Clone)]
pub struct RcuDomain {
    /// The thread-local slot holding the calling thread's state for this
    /// domain.
    tls: &'static LocalKey<ThreadData>,
    /// The state shared by every thread that uses this domain.
    data: &'static RcuData,
}
158
159impl std::fmt::Debug for RcuDomain {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 let Self { tls: _, data } = self;
162 f.debug_struct("RcuDomain").field("data", data).finish()
163 }
164}
165
/// Domain-global RCU state.
///
/// One instance exists per domain, as a `static` created by
/// [`define_rcu_domain!`].
#[doc(hidden)]
#[derive(Debug)]
pub struct RcuData {
    /// The threads that have registered with this domain. A thread is added
    /// the first time it runs a critical section and removed by its TLS
    /// destructor.
    threads: Mutex<Vec<ThreadEntry>>,
    /// The current sequence number. Advanced by `SEQ_INCREMENT` each time a
    /// synchronize operation starts.
    seq: AtomicU64,
    /// The event that is notified whenever a thread may have become ready:
    /// when it adopts a new sequence number after leaving a critical section,
    /// when it quiesces, and when it exits.
    event: Event,
    /// The number of membarriers issued. Kept for diagnostics only.
    membarriers: AtomicU64,
}
180
/// The entry in the thread list for a registered thread.
///
/// Entries live in [`RcuData`]'s `threads` list and are only accessed while
/// that list's mutex is held.
#[derive(Debug)]
struct ThreadEntry {
    /// The pointer to the sequence number for this thread. The [`ThreadData`]
    /// TLS destructor will remove this entry, so this is safe to dereference.
    seq_ptr: TlsRef<AtomicU64>,
    /// The last sequence number that a synchronizer can know this thread has
    /// observed, without issuing membarriers or looking at the thread's TLS
    /// data. Starts at `SEQ_NONE` and is raised to the synchronizer's target
    /// once the thread has been found ready for that target.
    observed_seq: u64,
    /// The thread that this entry is for. Used for debugging and tracing.
    thread: Thread,
}
194
/// A raw pointer that stands in for a shared reference to a value.
///
/// The creator promises that the pointee stays valid for as long as the
/// `TlsRef` exists; in this crate the pointee is another thread's TLS data.
struct TlsRef<T>(*const T);

impl<T> Deref for TlsRef<T> {
    type Target = T;

    fn deref(&self) -> &T {
        let target = self.0;
        // SAFETY: the pointer refers to TLS data that is still alive, because
        // the TLS drop implementation removes the entry holding this value
        // from the thread list before the data goes away.
        unsafe { &*target }
    }
}

// SAFETY: a `TlsRef<T>` behaves like `&T`, so it may be sent to another thread
// exactly when `&T` may be.
unsafe impl<T> Send for TlsRef<T> where for<'a> &'a T: Send {}
// SAFETY: a `TlsRef<T>` behaves like `&T`, so it may be shared between threads
// exactly when `&T` may be.
unsafe impl<T> Sync for TlsRef<T> where for<'a> &'a T: Sync {}

impl<T: std::fmt::Debug> std::fmt::Debug for TlsRef<T> {
    /// Formats the pointee, not the pointer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pointee: &T = self;
        std::fmt::Debug::fmt(pointee, f)
    }
}
220
221impl RcuData {
222 /// Used by [`define_rcu_domain!`] to create a new RCU domain.
223 #[doc(hidden)]
224 pub const fn new() -> Self {
225 RcuData {
226 threads: Mutex::new(Vec::new()),
227 seq: AtomicU64::new(SEQ_FIRST),
228 event: Event::new(),
229 membarriers: AtomicU64::new(0),
230 }
231 }
232}
233
/// The per-thread TLS data.
///
/// One instance exists per thread per domain. Its destructor unregisters the
/// thread from the domain it registered with, if any.
#[doc(hidden)]
pub struct ThreadData {
    /// The current sequence number for the thread: `SEQ_NONE` until the thread
    /// first runs a critical section, `SEQ_QUIESCED` while quiesced, and
    /// otherwise the domain sequence number the thread last adopted. The
    /// `SEQ_MASK_BUSY` bit is set while the thread is in a critical section.
    current_seq: AtomicU64,
    /// The RCU domain this thread is registered with, or `None` if it has not
    /// registered yet.
    data: Cell<Option<&'static RcuData>>,
}
242
243impl std::fmt::Debug for ThreadData {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 let Self {
246 current_seq: my_seq,
247 data: _,
248 } = self;
249 f.debug_struct("ThreadData")
250 .field("my_seq", my_seq)
251 .finish()
252 }
253}
254
255impl Drop for ThreadData {
256 fn drop(&mut self) {
257 if let Some(data) = self.data.get() {
258 {
259 let mut threads = data.threads.lock();
260 let i = threads
261 .iter()
262 .position(|x| x.seq_ptr.0 == &self.current_seq)
263 .unwrap();
264 threads.swap_remove(i);
265 }
266 data.event.notify(!0usize);
267 }
268 }
269}
270
271impl ThreadData {
272 /// Used by [`define_rcu_domain!`] to create a new RCU domain.
273 #[doc(hidden)]
274 pub const fn new() -> Self {
275 ThreadData {
276 current_seq: AtomicU64::new(SEQ_NONE),
277 data: Cell::new(None),
278 }
279 }
280}
281
/// The thread has not yet registered with the RCU domain. Also the initial
/// `observed_seq` of a newly registered thread.
const SEQ_NONE: u64 = 0;
/// The bit that is set while the thread is in a critical section.
const SEQ_MASK_BUSY: u64 = 1;
/// The value the sequence number is incremented by each synchronize call.
/// Being even, it never touches the busy bit.
const SEQ_INCREMENT: u64 = 2;
/// The sequence value for a quiesced thread. The thread will issue a full
/// memory barrier when leaving this state.
const SEQ_QUIESCED: u64 = 2;
/// The first actual sequence number. Any non-busy value below this is one of
/// the special states above.
const SEQ_FIRST: u64 = 4;
293
impl RcuDomain {
    /// Builds a domain handle from its thread-local key and shared state.
    ///
    /// Used by [`define_rcu_domain!`] to create a new RCU domain.
    #[doc(hidden)]
    pub const fn new(tls: &'static LocalKey<ThreadData>, data: &'static RcuData) -> Self {
        RcuDomain { tls, data }
    }

    /// Runs `f` in a critical section and returns its result. Calls to
    /// [`synchronize`](Self::synchronize) or
    /// [`synchronize_blocking`](Self::synchronize_blocking) for the same RCU
    /// domain will block until `f` returns.
    ///
    /// The first call on a thread registers that thread with the domain.
    ///
    /// In general, you should avoid blocking the thread in `f`, since that can
    /// slow calls to [`synchronize`](Self::synchronize) and can potentially
    /// cause deadlocks.
    pub fn run<F, R>(self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        self.tls.with(|x| x.run(self.data, f))
    }

    /// Quiesce the current thread.
    ///
    /// This can speed up calls to [`synchronize`](Self::synchronize) or
    /// [`synchronize_blocking`](Self::synchronize_blocking) by allowing the RCU domain
    /// to skip issuing a membarrier if all threads are quiesced. In return, the
    /// first call to [`run`](Self::run) after this will be slower, as it will
    /// need to issue a memory barrier to leave the quiesced state.
    ///
    /// This has no effect if the thread has not registered with the domain, is
    /// already quiesced, or is currently inside a critical section.
    pub fn quiesce(self) {
        self.tls.with(|x| {
            x.quiesce(self.data);
        });
    }

    /// Runs `fut`, calling [`quiesce`](Self::quiesce) on the current thread
    /// after every poll of `fut`.
    ///
    /// The thread is therefore quiesced each time `fut` returns
    /// `Poll::Pending`, and also once more when `fut` completes.
    pub async fn quiesce_on_pending<Fut>(self, fut: Fut) -> Fut::Output
    where
        Fut: Future,
    {
        let mut fut = pin!(fut);
        poll_fn(|cx| {
            self.tls.with(|x| {
                let r = fut.as_mut().poll(cx);
                // Quiesce regardless of the poll result.
                x.quiesce(self.data);
                r
            })
        })
        .await
    }

    /// Starts a synchronize operation.
    ///
    /// Quiesces the calling thread, advances the domain's sequence number, and
    /// checks whether every registered thread is already ready. Returns `None`
    /// if so (no waiting is needed). Otherwise issues a membarrier and returns
    /// `Some(seq)`, where `seq` is the new sequence number to wait for.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread is inside a critical section.
    #[track_caller]
    fn prepare_to_wait(&self) -> Option<u64> {
        // Quiesce this thread so we don't wait on ourselves.
        {
            let this_seq = self.tls.with(|x| x.quiesce(self.data));
            // `quiesce` returns the unchanged sequence number, busy bit
            // included, when the thread is inside a critical section.
            assert!(
                this_seq == SEQ_NONE || this_seq == SEQ_QUIESCED,
                "called synchronize() inside a critical section, {this_seq:#x}",
            );
        }
        // Update the domain's sequence number.
        let seq = self.data.seq.fetch_add(SEQ_INCREMENT, SeqCst) + SEQ_INCREMENT;
        // We need to make sure all threads are quiesced, not busy, or have
        // observed the new sequence number. To do this, we must synchronize the
        // global sequence number update with changes to each thread's local
        // sequence number. To do that, we will issue a membarrier, to broadcast
        // a memory barrier to all threads in the process.
        //
        // First, try to avoid the membarrier if possible--if all threads are quiesced,
        // then there is no need to issue a membarrier, because quiesced threads will issue
        // a memory barrier when they leave the quiesced state.
        if self
            .data
            .threads
            .lock()
            .iter_mut()
            .all(|t| Self::is_thread_ready(t, seq, false))
        {
            return None;
        }
        // Keep a count for diagnostics purposes.
        self.data.membarriers.fetch_add(1, Relaxed);
        sys::membarrier();
        Some(seq)
    }

    /// Synchronizes the RCU domain, blocking asynchronously until all threads
    /// have exited their critical sections and observed the new sequence
    /// number.
    ///
    /// `sleep` should be a function that sleeps for the specified duration.
    /// It is used only to time out the wait so that stalled threads can be
    /// reported: the timeout starts at 100ms and doubles after each expiry
    /// until it reaches at least 10 seconds. Waiting continues after each
    /// report.
    ///
    /// # Panics
    ///
    /// Panics if called from within a critical section of this domain.
    pub async fn synchronize(self, mut sleep: impl AsyncFnMut(Duration)) {
        let Some(seq) = self.prepare_to_wait() else {
            return;
        };
        let mut wait = pin!(self.wait_threads_ready(seq));
        let mut timeout = Duration::from_millis(100);
        loop {
            let mut sleep = pin!(sleep(timeout));
            // Race the readiness wait against the sleep, preferring the
            // readiness wait if both are ready.
            let ready = poll_fn(|cx| {
                if let Poll::Ready(()) = wait.as_mut().poll(cx) {
                    Poll::Ready(true)
                } else if let Poll::Ready(()) = sleep.as_mut().poll(cx) {
                    Poll::Ready(false)
                } else {
                    Poll::Pending
                }
            })
            .await;
            if ready {
                break;
            }
            self.warn_stall(seq);
            if timeout < Duration::from_secs(10) {
                timeout *= 2;
            }
        }
    }

    /// Like [`synchronize`](Self::synchronize), but blocks the current thread
    /// synchronously.
    ///
    /// Stalled threads are reported each time the wait times out; the timeout
    /// starts at 10ms and doubles after each expiry until it reaches at least
    /// 10 seconds.
    ///
    /// # Panics
    ///
    /// Panics if called from within a critical section of this domain.
    #[track_caller]
    pub fn synchronize_blocking(self) {
        let Some(seq) = self.prepare_to_wait() else {
            return;
        };
        let mut timeout = Duration::from_millis(10);
        while !self.wait_threads_ready_sync(seq, Instant::now() + timeout) {
            self.warn_stall(seq);
            if timeout < Duration::from_secs(10) {
                timeout *= 2;
            }
        }
    }

    /// Logs a rate-limited warning for every registered thread that is not
    /// yet ready for sequence number `target`.
    fn warn_stall(&self, target: u64) {
        for thread in &mut *self.data.threads.lock() {
            if !Self::is_thread_ready(thread, target, true) {
                tracelimit::warn_ratelimited!(thread = thread.thread.name(), "rcu stall");
            }
        }
    }

    /// Waits asynchronously until every registered thread is ready for
    /// sequence number `target`. A membarrier must already have been issued.
    async fn wait_threads_ready(&self, target: u64) {
        loop {
            // Start listening before checking, so that a notification sent
            // between the check and the wait is not missed.
            let event = self.data.event.listen();
            if self.all_threads_ready(target, true) {
                break;
            }
            event.await;
        }
    }

    /// Blocks until every registered thread is ready for sequence number
    /// `target`, or until `deadline` passes. A membarrier must already have
    /// been issued.
    ///
    /// Returns `true` if all threads became ready and `false` on timeout.
    #[must_use]
    fn wait_threads_ready_sync(&self, target: u64, deadline: Instant) -> bool {
        loop {
            // Start listening before checking, so that a notification sent
            // between the check and the wait is not missed.
            let event = self.data.event.listen();
            if self.all_threads_ready(target, true) {
                break;
            }
            if event.wait_deadline(deadline).is_none() {
                return false;
            }
        }
        true
    }

    /// Returns whether every registered thread is ready for sequence number
    /// `target`. See [`is_thread_ready`](Self::is_thread_ready) for the
    /// meaning of `issued_barrier`.
    fn all_threads_ready(&self, target: u64, issued_barrier: bool) -> bool {
        self.data
            .threads
            .lock()
            .iter_mut()
            .all(|thread| Self::is_thread_ready(thread, target, issued_barrier))
    }

    /// Returns whether `thread` no longer blocks a synchronize operation
    /// waiting for sequence number `target`.
    ///
    /// `issued_barrier` states whether a membarrier has been issued since the
    /// domain's sequence number was advanced to `target`. When the thread is
    /// found ready, the result is cached in `thread.observed_seq`.
    fn is_thread_ready(thread: &mut ThreadEntry, target: u64, issued_barrier: bool) -> bool {
        // Fast path: a previous check already established readiness.
        if thread.observed_seq >= target {
            return true;
        }
        let seq = thread.seq_ptr.load(Relaxed);
        assert_ne!(seq, SEQ_NONE);
        if seq & !SEQ_MASK_BUSY < target {
            if seq & SEQ_MASK_BUSY != 0 {
                // The thread is actively running in a critical section.
                return false;
            }
            if seq != SEQ_QUIESCED {
                // The thread is not quiesced. If a barrier was issued, then it
                // has observed the new sequence number. It may be busy (but
                // this CPU has not observed the write yet), but it must be busy
                // with a newer sequence number.
                //
                // If a barrier was not issued, then it is possible that the
                // thread is busy with an older sequence number. In this case,
                // we will need to issue a membarrier to observe the value of
                // the busy bit accurately.
                assert!(seq >= SEQ_FIRST, "{seq}");
                if !issued_barrier {
                    return false;
                }
            }
        }
        thread.observed_seq = target;
        true
    }
}
501
502impl ThreadData {
503 fn run<F, R>(&self, data: &'static RcuData, f: F) -> R
504 where
505 F: FnOnce() -> R,
506 {
507 // Mark the thread as busy.
508 let seq = self.current_seq.load(Relaxed);
509 self.current_seq.store(seq | SEQ_MASK_BUSY, Relaxed);
510 if seq < SEQ_FIRST {
511 // The thread was quiesced or not registered. Register it now.
512 if seq == SEQ_NONE {
513 self.start(data, seq);
514 } else {
515 debug_assert!(seq == SEQ_QUIESCED || seq & SEQ_MASK_BUSY != 0, "{seq:#x}");
516 }
517 // Use a full memory barrier to ensure the write side observes that
518 // the thread is no longer quiesced before calling `f`.
519 fence(SeqCst);
520 }
521 // Ensure accesses in `f` are bounded by setting the busy bit. Note that
522 // this and other fences are just compiler fences; the write side must
523 // call `membarrier` to dynamically turn them into processor memory
524 // barriers, so to speak.
525 sys::access_fence(Acquire);
526 let r = f();
527 sys::access_fence(Release);
528 // Clear the busy bit.
529 self.current_seq.store(seq, Relaxed);
530 // Ensure the busy bit clear is visible to the write side, then read the
531 // new sequence number, to synchronize with the sequence update path.
532 sys::access_fence(SeqCst);
533 let new_seq = data.seq.load(Relaxed);
534 if new_seq != seq {
535 // The domain's current sequence number has changed. Update it and
536 // wake up any waiters.
537 self.update_seq(data, seq, new_seq);
538 }
539 r
540 }
541
542 #[inline(never)]
543 fn start(&self, data: &'static RcuData, seq: u64) {
544 if seq == SEQ_NONE {
545 // Add the thread to the list of known threads in this domain.
546 assert!(self.data.get().is_none());
547 data.threads.lock().push(ThreadEntry {
548 seq_ptr: TlsRef(&self.current_seq),
549 observed_seq: SEQ_NONE,
550 thread: std::thread::current(),
551 });
552 // Remember the domain so that we can remove the thread from the list
553 // when it exits.
554 self.data.set(Some(data));
555 }
556 }
557
558 #[inline(never)]
559 fn update_seq(&self, data: &'static RcuData, seq: u64, new_seq: u64) {
560 if seq & SEQ_MASK_BUSY != 0 {
561 // Nested call. Skip.
562 return;
563 }
564 assert!(
565 new_seq >= SEQ_FIRST && new_seq & SEQ_MASK_BUSY == 0,
566 "{new_seq}"
567 );
568 self.current_seq.store(new_seq, Relaxed);
569 // Wake up any waiters. We don't know how many threads are still in a
570 // critical section, so just wake up the writers every time and let them
571 // figure it out.
572 data.event.notify(!0usize);
573 }
574
575 fn quiesce(&self, data: &'static RcuData) -> u64 {
576 let seq = self.current_seq.load(Relaxed);
577 if seq >= SEQ_FIRST && seq & SEQ_MASK_BUSY == 0 {
578 self.current_seq.store(SEQ_QUIESCED, Relaxed);
579 data.event.notify(!0usize);
580 SEQ_QUIESCED
581 } else {
582 seq
583 }
584 }
585}
586
#[cfg(test)]
mod tests {
    use crate::RcuDomain;
    use pal_async::DefaultDriver;
    use pal_async::DefaultPool;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use pal_async::timer::PolledTimer;
    use std::sync::atomic::Ordering;
    use test_with_tracing::test;

    /// Synchronizes `rcu`, using a timer on `driver` as the sleep function
    /// for stall detection.
    async fn sync(driver: &DefaultDriver, rcu: RcuDomain) {
        let mut timer = PolledTimer::new(driver);
        rcu.synchronize(async |timeout| {
            timer.sleep(timeout).await;
        })
        .await
    }

    /// A thread can run a critical section and then synchronize with itself.
    #[async_test]
    async fn test_rcu_single(driver: DefaultDriver) {
        // Each test defines its own domain so that tests do not interfere.
        define_rcu_domain!(test_rcu);

        test_rcu().run(|| {});
        sync(&driver, test_rcu()).await;
    }

    /// Critical sections can be nested on the same thread.
    #[async_test]
    async fn test_rcu_nested(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        test_rcu().run(|| {
            test_rcu().run(|| {});
        });
        sync(&driver, test_rcu()).await;
    }

    /// Synchronizing with another thread that has run a critical section but
    /// has not quiesced issues exactly one membarrier.
    #[async_test]
    async fn test_rcu_multi(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        // Run a critical section on a second thread, which registers it.
        let (thread, thread_driver) = DefaultPool::spawn_on_thread("test");
        thread_driver
            .spawn("test", async { test_rcu().run(|| {}) })
            .await;

        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0);
        sync(&driver, test_rcu()).await;
        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 1);

        drop(thread_driver);
        thread.join().unwrap();
    }

    /// Synchronizing when the other thread has quiesced does not need a
    /// membarrier.
    #[async_test]
    async fn test_rcu_multi_quiesce(driver: DefaultDriver) {
        define_rcu_domain!(test_rcu);

        // Run a critical section on a second thread under
        // `quiesce_on_pending`, which quiesces that thread after the poll.
        let (thread, thread_driver) = DefaultPool::spawn_on_thread("test");
        thread_driver
            .spawn(
                "test",
                test_rcu().quiesce_on_pending(async { test_rcu().run(|| {}) }),
            )
            .await;

        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0);
        // Quiesce the test thread as well.
        test_rcu().quiesce();
        sync(&driver, test_rcu()).await;
        assert_eq!(test_rcu().data.membarriers.load(Ordering::Relaxed), 0);

        drop(thread_driver);
        thread.join().unwrap();
    }
}