mesh_channel/
cancel.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Cancel context support.
5
6use super::bidir::Channel;
7use super::deadline::DeadlineId;
8use super::deadline::DeadlineSet;
9use mesh_node::local_node::Port;
10use mesh_node::resource::Resource;
11use mesh_protobuf::EncodeAs;
12use mesh_protobuf::Protobuf;
13use mesh_protobuf::SerializedMessage;
14use mesh_protobuf::Timestamp;
15use mesh_protobuf::encoding::IgnoreField;
16use parking_lot::Mutex;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::sync::Weak;
21use std::task::Context;
22use std::task::Poll;
23use std::task::Wake;
24use std::time::Duration;
25use std::time::Instant;
26use std::time::SystemTime;
27use std::time::UNIX_EPOCH;
28use thiserror::Error;
29
/// A cancellation context.
///
/// This is used to get a notification when an operation has been cancelled. It
/// can be cloned or sent across process boundaries.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct CancelContext {
    /// Either the cached cancellation reason, or the channels whose readiness
    /// signals that a controlling `Cancel` has gone away.
    state: CancelState,
    /// Optional deadline. Encoded as a `Timestamp`, which carries only the
    /// deadline's wall-clock time.
    deadline: Option<EncodeAs<Deadline, Timestamp>>,
    /// Registration id used with the global `DeadlineSet`. It is not encoded
    /// (see `Ignore`), and clones start from `Default::default()`.
    deadline_id: Ignore<DeadlineId>,
}

/// Cancellation state of a [`CancelContext`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum CancelState {
    /// Not cancelled yet. The peer of each channel lives on the `Cancel` side.
    /// The context counts as cancelled as soon as any of these channels
    /// becomes ready when polled.
    NotCancelled { ports: Vec<Channel> },
    /// Cancellation has been observed; the reason is cached here.
    Cancelled(CancelReason),
}

/// Wrapper for a field that is skipped by mesh encoding.
#[derive(Debug, Default)]
struct Ignore<T>(T);

// Encodes `Ignore<T>` with `IgnoreField`, so the wrapped value is never
// transmitted. NOTE(review): the `T: Default` bound presumably lets the decoder
// materialize a default value; confirm against `IgnoreField`.
impl<T: Default> mesh_protobuf::DefaultEncoding for Ignore<T> {
    type Encoding = IgnoreField;
}
55
56impl Clone for CancelContext {
57    fn clone(&self) -> Self {
58        let state = match &self.state {
59            CancelState::Cancelled(reason) => CancelState::Cancelled(*reason),
60            CancelState::NotCancelled { ports, .. } => CancelState::NotCancelled {
61                ports: ports
62                    .iter()
63                    .map(|port| {
64                        // Each port has a peer port in a Cancel object; these
65                        // ports will be closed on cancellation. Send a new port
66                        // to each Cancel object; these new ports will be closed
67                        // on cancel, too.
68                        let (send, recv) = <Channel>::new_pair();
69                        port.send(SerializedMessage {
70                            data: vec![],
71                            resources: vec![Resource::Port(recv.into())],
72                        });
73                        send
74                    })
75                    .collect(),
76            },
77        };
78        Self {
79            state,
80            deadline: self.deadline,
81            deadline_id: Default::default(),
82        }
83    }
84}
85
86impl CancelContext {
87    /// Returns a new context that is never notified of cancellation.
88    pub fn new() -> Self {
89        Self {
90            state: CancelState::NotCancelled { ports: Vec::new() },
91            deadline: None,
92            deadline_id: Default::default(),
93        }
94    }
95
96    fn add_cancel(&mut self) -> Cancel {
97        let (send, recv) = Channel::new_pair();
98        match &mut self.state {
99            CancelState::Cancelled(_) => {}
100            CancelState::NotCancelled { ports, .. } => ports.push(send),
101        }
102        Cancel::new(recv)
103    }
104
105    /// Returns a new child context and a cancel function.
106    ///
107    /// The new context is notified when either this context is cancelled, or
108    /// the returned `Cancel` object's `cancel` method is called.
109    pub fn with_cancel(&self) -> (Self, Cancel) {
110        let mut ctx = self.clone();
111        let cancel = ctx.add_cancel();
112        (ctx, cancel)
113    }
114
115    /// Returns a new child context with a deadline.
116    ///
117    /// The new context is notified when either this context is cancelled, or
118    /// the deadline is exceeded.
119    pub fn with_deadline(&self, deadline: Deadline) -> Self {
120        let mut ctx = self.clone();
121        ctx.deadline = Some(
122            self.deadline
123                .map_or(deadline, |old| old.min(deadline))
124                .into(),
125        );
126        ctx
127    }
128
129    /// Returns a new child context with a timeout.
130    ///
131    /// The new context is notified when either this context is cancelled, or
132    /// the timeout has expired.
133    pub fn with_timeout(&self, timeout: Duration) -> Self {
134        match Deadline::now().checked_add(timeout) {
135            Some(deadline) => self.with_deadline(deadline),
136            None => self.clone(),
137        }
138    }
139
140    /// Returns the current deadline, if there is one.
141    pub fn deadline(&self) -> Option<Deadline> {
142        self.deadline.as_deref().copied()
143    }
144
145    /// Returns a future that completes when the context is cancelled.
146    pub fn cancelled(&mut self) -> Cancelled<'_> {
147        Cancelled(self)
148    }
149
150    /// Runs `fut` until this context is cancelled.
151    pub async fn until_cancelled<F: Future>(&mut self, fut: F) -> Result<F::Output, CancelReason> {
152        let mut fut = core::pin::pin!(fut);
153        let mut cancelled = core::pin::pin!(self.cancelled());
154        std::future::poll_fn(|cx| {
155            if let Poll::Ready(r) = fut.as_mut().poll(cx) {
156                return Poll::Ready(Ok(r));
157            }
158            if let Poll::Ready(reason) = cancelled.as_mut().poll(cx) {
159                return Poll::Ready(Err(reason));
160            }
161            Poll::Pending
162        })
163        .await
164    }
165
166    /// Runs a failable future until this context is cancelled, merging the
167    /// result with the cancellation reason.
168    pub async fn until_cancelled_failable<F: Future<Output = Result<T, E>>, T, E>(
169        &mut self,
170        fut: F,
171    ) -> Result<T, ErrorOrCancelled<E>> {
172        match self.until_cancelled(fut).await {
173            Ok(Ok(r)) => Ok(r),
174            Ok(Err(e)) => Err(ErrorOrCancelled::Error(e)),
175            Err(reason) => Err(ErrorOrCancelled::Cancelled(reason)),
176        }
177    }
178
179    /// Returns true if the context has been cancelled.
180    pub fn is_cancelled(&mut self) -> bool {
181        Pin::new(&mut self.cancelled())
182            .poll(&mut Context::from_waker(std::task::Waker::noop()))
183            .is_ready()
184    }
185}
186
187impl Default for CancelContext {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193#[must_use]
194#[derive(Debug)]
195pub struct Cancelled<'a>(&'a mut CancelContext);
196
197#[derive(Debug, Protobuf, Copy, Clone, PartialEq, Eq)]
198pub enum CancelReason {
199    Cancelled,
200    DeadlineExceeded,
201}
202
203impl std::fmt::Display for CancelReason {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        f.pad(match *self {
206            CancelReason::Cancelled => "cancelled",
207            CancelReason::DeadlineExceeded => "deadline exceeded",
208        })
209    }
210}
211
212impl std::error::Error for CancelReason {}
213
214#[derive(Error, Debug)]
215pub enum ErrorOrCancelled<E> {
216    #[error(transparent)]
217    Error(E),
218    #[error(transparent)]
219    Cancelled(CancelReason),
220}
221
222impl Future for Cancelled<'_> {
223    type Output = CancelReason;
224
225    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226        let this = Pin::get_mut(self);
227        match &mut this.0.state {
228            CancelState::Cancelled(reason) => return Poll::Ready(*reason),
229            CancelState::NotCancelled { ports } => {
230                for p in ports.iter_mut() {
231                    if p.poll_recv(cx).is_ready() {
232                        let reason = CancelReason::Cancelled;
233                        this.0.state = CancelState::Cancelled(reason);
234                        return Poll::Ready(reason);
235                    }
236                }
237            }
238        }
239        if let Some(deadline) = this.0.deadline {
240            if DeadlineSet::global()
241                .poll(cx, &mut this.0.deadline_id.0, *deadline)
242                .is_ready()
243            {
244                let reason = CancelReason::DeadlineExceeded;
245                this.0.state = CancelState::Cancelled(reason);
246                return Poll::Ready(reason);
247            }
248        }
249        Poll::Pending
250    }
251}
252
253impl Drop for CancelContext {
254    fn drop(&mut self) {
255        // Drop the deadline.
256        DeadlineSet::global().remove(&mut self.deadline_id.0);
257    }
258}
259
/// A cancel notifier.
///
/// The associated [`CancelContext`] will be cancelled when this object is
/// dropped or [`Cancel::cancel()`] is called.
// `Cancel` holds the only strong reference it owns to the list; the
// `ListWaker` that polls the ports holds a `Weak`. Dropping `Cancel` therefore
// releases the list, and with it every port, once any in-progress wake that
// upgraded the `Weak` has finished.
#[derive(Debug)]
pub struct Cancel(Arc<CancelList>);

/// A list of ports to be closed in order to cancel a chain of cancel contexts.
///
/// A cancel context can send additional ports to any one of these ports in
/// order to register additional cancel contexts.
#[derive(Debug)]
struct CancelList {
    // Locked both from `ListWaker::wake` (which may run on whichever thread
    // wakes it) and from `Cancel::cancel`.
    ports: Mutex<Vec<Channel>>,
}
275
276impl CancelList {
277    /// Polls each port to accumulate more ports and to garbage collect any
278    /// closed ports.
279    fn poll(&self, cx: &mut Context<'_>) {
280        let mut to_drop = Vec::new();
281        let mut ports = self.ports.lock();
282        let mut i = 0;
283        'outer: while i < ports.len() {
284            while let Poll::Ready(message) = ports[i].poll_recv(cx) {
285                match message {
286                    Ok(message) => {
287                        // Accumulate any ports onto the main list so that they
288                        // can be polled.
289                        let resources = message.resources;
290                        tracing::trace!(count = resources.len(), "adding ports");
291                        ports.extend(resources.into_iter().filter_map(|resource| {
292                            Port::try_from(resource).ok().map(|port| port.into())
293                        }));
294                    }
295                    Err(_) => {
296                        // The peer port is gone. Remove the port from the list.
297                        // Push it onto a new list to be dropped outside the lock.
298                        to_drop.push(ports.swap_remove(i));
299                        continue 'outer;
300                    }
301                }
302            }
303            i += 1;
304        }
305        if !to_drop.is_empty() {
306            tracing::trace!(count = to_drop.len(), "dropping ports");
307        }
308    }
309
310    /// Returns all the accumulated ports.
311    fn drain(&self) -> Vec<Channel> {
312        std::mem::take(&mut self.ports.lock())
313    }
314}
315
316/// Waker that processes the list inline.
317struct ListWaker {
318    list: Weak<CancelList>,
319}
320
321impl Wake for ListWaker {
322    fn wake(self: Arc<Self>) {
323        if let Some(list) = self.list.upgrade() {
324            // Ordinarily it is a bad idea to do anything like this in a waker,
325            // since this could run on an arbitrary thread, under locks, etc.
326            // However, since we're within the same crate we know that it should
327            // be safe to run CancelList::poll on the waking thread.
328            let waker = self.into();
329            let mut cx = Context::from_waker(&waker);
330            list.poll(&mut cx);
331        }
332    }
333}
334
335impl Cancel {
336    fn new(port: Channel) -> Self {
337        let inner = Arc::new(CancelList {
338            ports: Mutex::new(vec![port]),
339        });
340        // The waker is used to poll the port. This is done to detect when the
341        // port is closed so that it can be dropped, and to accumulate any
342        // incoming ports (due to CancelContext::clone) for the same.
343        let waker = Arc::new(ListWaker {
344            list: Arc::downgrade(&inner),
345        });
346        waker.wake();
347        Self(inner)
348    }
349
350    /// Cancels the associated port context and any children contexts.
351    pub fn cancel(&mut self) {
352        drop(self.0.drain());
353    }
354}
355
/// A point in time that acts as a deadline for an operation.
///
/// Each deadline carries a wall-clock time and, optionally, an OS monotonic
/// instant. Comparisons and differences use the monotonic instants when both
/// deadlines have one, and fall back to wall-clock time otherwise.
///
/// Only the wall-clock time is serialized. A monotonic instant is meaningless
/// outside the process that captured it, because monotonic time is not
/// guaranteed to be consistent across processes.
#[derive(Debug, Copy, Clone, Eq)]
pub struct Deadline {
    system_time: SystemTime,
    instant: Option<Instant>,
}

impl Deadline {
    /// Captures the current moment, recording both wall-clock and monotonic
    /// time.
    pub fn now() -> Self {
        let system_time = SystemTime::now();
        let instant = Some(Instant::now());
        Self {
            system_time,
            instant,
        }
    }

    /// The monotonic OS instant of the deadline, if one was captured.
    pub fn instant(&self) -> Option<Instant> {
        self.instant
    }

    /// The wall-clock time of the deadline.
    pub fn system_time(&self) -> SystemTime {
        self.system_time
    }

    /// Moves the deadline `duration` into the future.
    ///
    /// Returns `None` if the wall-clock time overflows. If only the monotonic
    /// instant overflows, the result simply carries no instant.
    pub fn checked_add(&self, duration: Duration) -> Option<Self> {
        let system_time = self.system_time.checked_add(duration)?;
        let instant = self
            .instant
            .and_then(|instant| instant.checked_add(duration));
        Some(Self {
            system_time,
            instant,
        })
    }
}

impl std::ops::Add<Duration> for Deadline {
    type Output = Self;

    /// # Panics
    ///
    /// Panics if the wall-clock time overflows; see [`Deadline::checked_add`].
    fn add(self, rhs: Duration) -> Self::Output {
        Deadline::checked_add(&self, rhs).expect("overflow when adding duration to deadline")
    }
}

impl std::ops::Sub<Duration> for Deadline {
    type Output = Deadline;

    /// Moves the deadline `rhs` into the past.
    fn sub(self, rhs: Duration) -> Self::Output {
        // If the wall-clock time cannot represent the result, fall back to the
        // UNIX epoch. The result is then not exactly `self - rhs`, but for a
        // deadline any time in the past is equivalent. A monotonic underflow
        // just drops the instant.
        let system_time = match self.system_time.checked_sub(rhs) {
            Some(time) => time,
            None => UNIX_EPOCH,
        };
        let instant = self.instant.and_then(|instant| instant.checked_sub(rhs));
        Deadline {
            system_time,
            instant,
        }
    }
}

impl std::ops::Sub<Deadline> for Deadline {
    type Output = Duration;

    /// Time from `rhs` until `self`, saturating to zero if `rhs` is later.
    fn sub(self, rhs: Deadline) -> Self::Output {
        let elapsed = match (self.instant, rhs.instant) {
            (Some(end), Some(start)) => end.checked_duration_since(start),
            _ => self.system_time.duration_since(rhs.system_time).ok(),
        };
        elapsed.unwrap_or_default()
    }
}

impl PartialEq for Deadline {
    /// Compares monotonic instants when both sides have one, and wall-clock
    /// times otherwise.
    fn eq(&self, other: &Self) -> bool {
        match (self.instant, other.instant) {
            (Some(a), Some(b)) => a == b,
            _ => self.system_time == other.system_time,
        }
    }
}

impl PartialOrd for Deadline {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Deadline {
    /// Orders by monotonic instant when both sides have one, and by
    /// wall-clock time otherwise.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self.instant, other.instant) {
            (Some(a), Some(b)) => a.cmp(&b),
            _ => self.system_time.cmp(&other.system_time),
        }
    }
}

impl From<SystemTime> for Deadline {
    /// Builds a wall-clock-only deadline with no monotonic instant.
    fn from(system_time: SystemTime) -> Self {
        Deadline {
            instant: None,
            system_time,
        }
    }
}
478
479impl From<Deadline> for Timestamp {
480    fn from(deadline: Deadline) -> Self {
481        deadline.system_time.into()
482    }
483}
484
485impl From<Timestamp> for Deadline {
486    fn from(timestamp: Timestamp) -> Self {
487        Self {
488            system_time: timestamp.try_into().unwrap_or(UNIX_EPOCH),
489            instant: None,
490        }
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::CancelContext;
    use super::CancelReason;
    use super::Deadline;
    use pal_async::async_test;
    // Shadows the built-in `#[test]` attribute; presumably this sets up
    // tracing output for the sync tests below (confirm in test_with_tracing).
    use test_with_tracing::test;

    /// A root context has no cancel sources and no deadline, so `cancelled()`
    /// stays pending.
    #[async_test]
    async fn no_cancel() {
        assert!(futures::poll!(CancelContext::new().cancelled()).is_pending());
    }

    /// Calling `Cancel::cancel` makes the child context's `cancelled()` ready.
    #[async_test]
    async fn basic_cancel() {
        let (mut ctx, mut cancel) = CancelContext::new().with_cancel();
        cancel.cancel();
        assert!(futures::poll!(ctx.cancelled()).is_ready());
    }

    /// Builds two long chains of clones from a cancellable context and checks
    /// that both observe cancellation.
    ///
    /// With `use_cancel`, cancellation comes from `cancel.cancel()`.
    /// Otherwise it comes from a 15ms timeout. `cancel` stays alive until the
    /// end of the function, so in the timeout case dropping it is not what
    /// cancels the contexts.
    #[expect(clippy::redundant_clone, reason = "explicitly testing chained clones")]
    async fn chain(use_cancel: bool) {
        let ctx = CancelContext::new();
        let (mut ctx, mut cancel) = ctx.with_cancel();
        if !use_cancel {
            ctx = ctx.with_timeout(std::time::Duration::from_millis(15));
        }
        // Each clone forwards a new port to the `Cancel` side through the
        // previous clone's ports.
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let mut ctx = ctx.clone();
        let ctx2 = ctx.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let mut ctx2 = ctx2.clone();
        // Short-lived clones that are dropped immediately; their closed ports
        // must not disturb the surviving contexts.
        let _ = ctx2
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone();
        // Lets the 15ms timeout elapse in the deadline case.
        // NOTE(review): this blocks the executor thread; its purpose in the
        // cancel case is not evident from this file.
        std::thread::sleep(std::time::Duration::from_millis(100));
        if use_cancel {
            cancel.cancel();
        }
        assert!(futures::poll!(ctx.cancelled()).is_ready());
        assert!(futures::poll!(ctx2.cancelled()).is_ready());
    }

    #[async_test]
    async fn chain_cancel() {
        chain(true).await
    }

    #[async_test]
    async fn chain_deadline() {
        chain(false).await
    }

    /// Both an already-expired (0ms) and a short (100ms) timeout resolve
    /// with `DeadlineExceeded`.
    #[async_test]
    async fn cancel_deadline() {
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(0));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(100));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
    }

    /// Round-trips deadlines through `Timestamp`. The decoded deadline has no
    /// monotonic instant, so equality falls back to wall-clock time. Cases
    /// cover sub-second offsets on both sides of the UNIX epoch.
    #[test]
    fn test_encode_deadline() {
        let check = |deadline: Deadline| {
            let timestamp: super::Timestamp = deadline.into();
            let deadline2: Deadline = timestamp.into();
            assert_eq!(deadline, deadline2);
        };

        check(Deadline::now());
        check(Deadline::now() + std::time::Duration::from_secs(1));
        check(Deadline::now() - std::time::Duration::from_secs(1));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH - std::time::Duration::from_nanos(1_500_000_000),
        ));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_nanos(1_500_000_000),
        ));
    }
}