arc_cyclic_builder/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! An extension to [`std::sync::Arc`] that adds
5//! [`Arc::new_cyclic_builder`](ArcCyclicBuilderExt::new_cyclic_builder), and
6//! [`ArcCyclicBuilder<T>`] - a generalization of [`Arc::new_cyclic`].
7//!
8//! This comes in handy when dealing with objects that have fallible / async
9//! constructors. In these cases, the fact that `Arc::new_cyclic` takes an
10//! infallible, synchronous closure precludes it from being used.
11//!
12//! # Example
13//!
14//! Constructing a self-referential `Gadget` with a fallible async constructor.
15//!
16//! ```
17//! use arc_cyclic_builder::ArcCyclicBuilderExt;
18//! use std::io;
19//! use std::sync::{Arc, Weak};
20//!
21//! struct Gadget {
22//!     me: Weak<Gadget>,
23//! }
24//!
25//! impl Gadget {
26//!     async fn new(me: Weak<Gadget>) -> io::Result<Self> {
27//!         Ok(Gadget { me })
28//!     }
29//! }
30//!
31//! async fn create_gadget() -> io::Result<Arc<Gadget>> {
32//!     let builder = Arc::new_cyclic_builder();
33//!     let gadget = Gadget::new(builder.weak()).await?;
34//!     Ok(builder.build(gadget))
35//! }
36//! ```
37//!
38//! # (Un)Safety
39//!
//! At the time of writing (August 22, 2022), the stable public APIs of `Arc`
//! and `Weak` are not sufficient to robustly implement `ArcCyclicBuilder`
//! outside of std itself. Instead, we've had to do something quite unsafe to
//! get this code working...
44//!
//! Namely, we've had to make assumptions about the internal representation of
//! `Arc` and `Weak`, and have written the code assuming they will not change
//! out from under us.
48//!
49//! This is, by all accounts, a Really Bad Idea™️, since the std makes no
//! guarantees as to the stability of these types' _internal_ representations,
51//! and could _silently_ change them at any point.
52//!
53//! # Road to Safety
54//!
55//! ...that said, we're willing to bet that it's _highly unlikely_ that the
56//! representation of `Arc`/`Weak` is going to change in the near future, and
57//! that this code will continue to work fine (at least for a while).
58//!
//! Of course, leaving this kind of risk in the codebase isn't a great idea.
//! While unit tests and Miri tests serve as a reasonable early-warning
//! indicator that the `Arc`/`Weak` representations have changed, ultimately
//! this code needs to land upstream in std.
63//!
64//! TODO: add links to any upstream PRs we end up sending out
65
66// UNSAFETY: See crate-level doccomment.
67#![expect(unsafe_code)]
68
69use std::mem;
70use std::ptr;
71use std::ptr::NonNull;
72use std::sync::Arc;
73use std::sync::Weak;
74use std::sync::atomic;
75use std::sync::atomic::Ordering::*;
76
// Matches the definition of `ArcInner` in the `std`
//
// `#[repr(C)]` pins the field order, so a pointer to this struct can be
// reinterpreted (via the transmutes in `ArcCyclicBuilder`) as a pointer to
// std's private `ArcInner<T>`. Reordering, adding, or removing fields here
// would silently break those transmutes.
//
// The other important assumption: both `Arc` and `Weak` share the same repr:
//
// `struct Arc<T>  { ptr: NonNull<ArcInner<T>> }`
// `struct Weak<T> { ptr: NonNull<ArcInner<T>> }`
//
// NOTE(review): newer std versions may attach extra attributes (e.g. an
// explicit alignment) to their `ArcInner`; confirm the layouts still agree
// whenever the toolchain is upgraded.
#[repr(C)]
struct ArcInner<T> {
    // Strong (`Arc`) reference count. Starts at 0 in the builder and is
    // bumped to 1 by `ArcCyclicBuilder::build`.
    strong: atomic::AtomicUsize,
    // Weak reference count. All strong references collectively own one of
    // these; the builder starts it at 1 for its own `Weak`.
    weak: atomic::AtomicUsize,
    // The payload. The builder allocates it as `MaybeUninit<T>`, and it stays
    // uninitialized until `build` writes the value in.
    data: T,
}
89
90/// Builder returned by [`Arc::new_cyclic_builder`](ArcCyclicBuilderExt::new_cyclic_builder)
91pub struct ArcCyclicBuilder<T> {
92    init_ptr: NonNull<ArcInner<T>>,
93    weak: Weak<T>,
94}
95
96// DEVNOTE: the bodies of `new` and `build` are essentially identical to the
97// implementation of `Arc::new_cyclic` in std, aside from the use of some
98// transmutes in liu of using Weak/Arc::from_inner (as ArcInner is not a
99// publicly exported type).
100impl<T> ArcCyclicBuilder<T> {
101    fn new() -> Self {
102        // Construct the inner in the "uninitialized" state with a single
103        // weak reference.
104        // NOTE: `Box::new` is replaced with the `box` keyword in std
105        let uninit_ptr: NonNull<_> = Box::leak(Box::new(ArcInner {
106            strong: atomic::AtomicUsize::new(0),
107            weak: atomic::AtomicUsize::new(1),
108            data: mem::MaybeUninit::<T>::uninit(),
109        }))
110        .into();
111        let init_ptr: NonNull<ArcInner<T>> = uninit_ptr.cast();
112
113        // SAFETY: equivalent of calling `Weak { ptr: init_ptr }`
114        let weak = unsafe { mem::transmute::<NonNull<ArcInner<T>>, Weak<T>>(init_ptr) };
115
116        Self { init_ptr, weak }
117    }
118
119    /// Obtain a `Weak<T>` to the allocation. Attempting to
120    /// [`upgrade`](Weak::upgrade) the weak reference prior to invoking
121    /// [`build`](Self::build) will fail and result in a `None` value.
122    pub fn weak(&self) -> Weak<T> {
123        self.weak.clone()
124    }
125
126    /// Finish construction of the `Arc<T>`
127    pub fn build(self, data: T) -> Arc<T> {
128        // Now we can properly initialize the inner value and turn our weak
129        // reference into a strong reference.
130        // SAFETY: self.init_ptr is guaranteed to point to our ArcInner,
131        // which has the same layout as std's.
132        let strong = unsafe {
133            let inner = self.init_ptr.as_ptr();
134            ptr::write(ptr::addr_of_mut!((*inner).data), data);
135
136            // The above write to the data field must be visible to any threads which
137            // observe a non-zero strong count. Therefore we need at least "Release" ordering
138            // in order to synchronize with the `compare_exchange_weak` in `Weak::upgrade`.
139            //
140            // "Acquire" ordering is not required. When considering the possible behaviours
141            // of `data_fn` we only need to look at what it could do with a reference to a
142            // non-upgradeable `Weak`:
143            // - It can *clone* the `Weak`, increasing the weak reference count.
144            // - It can drop those clones, decreasing the weak reference count (but never to zero).
145            //
146            // These side effects do not impact us in any way, and no other side effects are
147            // possible with safe code alone.
148            let prev_value = (*inner).strong.fetch_add(1, Release);
149            debug_assert_eq!(prev_value, 0, "No prior strong references should exist");
150
151            // SAFETY: equivalent of calling `Arc::from_inner`
152            mem::transmute::<NonNull<ArcInner<T>>, Arc<T>>(self.init_ptr)
153        };
154
155        // Strong references should collectively own a shared weak reference,
156        // so don't run the destructor for our old weak reference.
157        mem::forget(self.weak);
158        strong
159    }
160}
161
162/// An extension trait to [`Arc`] that adds
163/// [`new_cyclic_builder`](Self::new_cyclic_builder).
164pub trait ArcCyclicBuilderExt<T> {
165    /// Return a new [`ArcCyclicBuilder<T>`]
166    fn new_cyclic_builder() -> ArcCyclicBuilder<T>;
167}
168
169impl<T> ArcCyclicBuilderExt<T> for Arc<T> {
170    fn new_cyclic_builder() -> ArcCyclicBuilder<T> {
171        ArcCyclicBuilder::new()
172    }
173}
174
#[expect(clippy::disallowed_types)] // requiring parking_lot just for a test? nah
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Mutex;

    /// Self-referential test object that bumps a shared counter when dropped.
    struct Gadget {
        this: Weak<Gadget>,
        inner: Mutex<usize>,

        inc_on_drop: Arc<Mutex<usize>>,
    }

    #[derive(Debug)]
    struct PassedZero;

    impl Gadget {
        fn new(this: Weak<Gadget>, inner: usize, inc_on_drop: Arc<Mutex<usize>>) -> Gadget {
            Gadget {
                this,
                inner: Mutex::new(inner),
                inc_on_drop,
            }
        }

        /// Fails with `PassedZero` when `inner == 0`.
        fn try_new(
            this: Weak<Gadget>,
            inner: usize,
            inc_on_drop: Arc<Mutex<usize>>,
        ) -> Result<Gadget, PassedZero> {
            if inner == 0 {
                Err(PassedZero)
            } else {
                Ok(Gadget::new(this, inner, inc_on_drop))
            }
        }

        async fn async_new(
            this: Weak<Gadget>,
            inner: usize,
            inc_on_drop: Arc<Mutex<usize>>,
        ) -> Gadget {
            Gadget::new(this, inner, inc_on_drop)
        }

        fn val(&self) -> usize {
            *self.inner.lock().unwrap()
        }

        /// Increments `inner` through the self-referential `Weak`, which only
        /// works once the builder has finished building.
        fn bump_self(&self) {
            *self.this.upgrade().unwrap().inner.lock().unwrap() += 1;
        }
    }

    impl Drop for Gadget {
        fn drop(&mut self) {
            *self.inc_on_drop.lock().unwrap() += 1
        }
    }

    #[test]
    fn smoke() {
        let inc_on_drop = Arc::new(Mutex::new(0));

        let builder = Arc::new_cyclic_builder();
        let gadget = Gadget::new(builder.weak(), 1, inc_on_drop.clone());
        assert!(builder.weak().upgrade().is_none());
        let gadget = builder.build(gadget);

        gadget.bump_self();
        assert_eq!(gadget.val(), 2);

        drop(gadget);
        assert_eq!(*inc_on_drop.lock().unwrap(), 1);
    }

    // Showing off how the builder can be used with a fallible constructor
    // (something `Arc::new_cyclic` cannot express).
    #[test]
    fn smoke_fallible_ok() {
        let inc_on_drop = Arc::new(Mutex::new(0));

        let builder = Arc::new_cyclic_builder();
        let gadget = Gadget::try_new(builder.weak(), 1, inc_on_drop.clone()).unwrap();
        assert!(builder.weak().upgrade().is_none());
        let gadget = builder.build(gadget);

        gadget.bump_self();
        assert_eq!(gadget.val(), 2);

        drop(gadget);
        assert_eq!(*inc_on_drop.lock().unwrap(), 1);
    }

    // When the fallible constructor bails out, the builder is simply dropped:
    // the `Weak` it handed out never upgrades and no `Gadget` is ever dropped.
    #[test]
    fn smoke_fallible_err() {
        let inc_on_drop = Arc::new(Mutex::new(0));

        let builder: ArcCyclicBuilder<Gadget> = Arc::new_cyclic_builder();
        let weak = builder.weak();
        let res = Gadget::try_new(builder.weak(), 0, inc_on_drop.clone());
        assert!(matches!(res, Err(PassedZero)));
        drop(builder);

        assert!(weak.upgrade().is_none());
        assert_eq!(*inc_on_drop.lock().unwrap(), 0);
        assert_eq!(Arc::strong_count(&inc_on_drop), 1);
    }

    #[test]
    fn smoke_async_construction() {
        let inc_on_drop = Arc::new(Mutex::new(0));

        let builder = Arc::new_cyclic_builder();

        let gadget = futures_executor::block_on(async {
            Gadget::async_new(builder.weak(), 1, inc_on_drop.clone()).await
        });
        assert!(builder.weak().upgrade().is_none());
        let gadget = builder.build(gadget);

        gadget.bump_self();
        assert_eq!(gadget.val(), 2);

        drop(gadget);
        assert_eq!(*inc_on_drop.lock().unwrap(), 1);
    }

    #[test]
    fn drop_the_builder() {
        let builder: ArcCyclicBuilder<usize> = Arc::new_cyclic_builder();
        let weak = builder.weak();
        drop(builder);
        assert!(weak.upgrade().is_none());
        drop(weak);
    }
}