mesh_channel/
cancel.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Cancel context support.
5
6use super::bidir::Channel;
7use super::deadline::DeadlineId;
8use super::deadline::DeadlineSet;
9use mesh_node::local_node::Port;
10use mesh_node::resource::Resource;
11use mesh_protobuf::EncodeAs;
12use mesh_protobuf::Protobuf;
13use mesh_protobuf::SerializedMessage;
14use mesh_protobuf::Timestamp;
15use mesh_protobuf::encoding::IgnoreField;
16use parking_lot::Mutex;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::sync::Weak;
21use std::task::Context;
22use std::task::Poll;
23use std::task::Wake;
24use std::time::Duration;
25use std::time::Instant;
26use std::time::SystemTime;
27use std::time::UNIX_EPOCH;
28use thiserror::Error;
29
/// A cancellation context.
///
/// This is used to get a notification when an operation has been cancelled. It
/// can be cloned or sent across process boundaries.
///
/// A context completes its [`CancelContext::cancelled`] future either when one
/// of its cancel ports becomes ready or when its deadline is reached.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct CancelContext {
    /// Either the set of ports being watched for cancellation, or the reason
    /// the context was already observed to be cancelled.
    state: CancelState,
    /// The optional deadline for this context. Declared as
    /// `EncodeAs<Deadline, Timestamp>`, so it is presumably encoded as a
    /// `Timestamp` on the wire.
    deadline: Option<EncodeAs<Deadline, Timestamp>>,
    /// The identifier handed to the global `DeadlineSet` when polling and when
    /// dropping the context. It uses the `IgnoreField` encoding and is reset to
    /// its default value whenever the context is cloned.
    deadline_id: Ignore<DeadlineId>,
}
41
/// The cancellation state of a [`CancelContext`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum CancelState {
    /// The context has not been observed as cancelled yet. Any of `ports`
    /// becoming ready when polled marks the context as cancelled.
    NotCancelled { ports: Vec<Channel> },
    /// The context has been cancelled for the given reason.
    Cancelled(CancelReason),
}
48
/// Wrapper that gives a field the `IgnoreField` encoding.
///
/// Used for process-local state (such as a `DeadlineId`) stored inside a type
/// that derives `Protobuf`.
#[derive(Debug, Default)]
struct Ignore<T>(T);

// Select `IgnoreField` as the default encoding for any `Ignore<T>` whose inner
// type can be default-constructed.
impl<T: Default> mesh_protobuf::DefaultEncoding for Ignore<T> {
    type Encoding = IgnoreField;
}
55
56impl Clone for CancelContext {
57    fn clone(&self) -> Self {
58        let state = match &self.state {
59            CancelState::Cancelled(reason) => CancelState::Cancelled(*reason),
60            CancelState::NotCancelled { ports, .. } => CancelState::NotCancelled {
61                ports: ports
62                    .iter()
63                    .map(|port| {
64                        // Each port has a peer port in a Cancel object; these
65                        // ports will be closed on cancellation. Send a new port
66                        // to each Cancel object; these new ports will be closed
67                        // on cancel, too.
68                        let (send, recv) = <Channel>::new_pair();
69                        port.send(SerializedMessage {
70                            data: vec![],
71                            resources: vec![Resource::Port(recv.into())],
72                        });
73                        send
74                    })
75                    .collect(),
76            },
77        };
78        Self {
79            state,
80            deadline: self.deadline,
81            deadline_id: Default::default(),
82        }
83    }
84}
85
86impl CancelContext {
87    /// Returns a new context that is never notified of cancellation.
88    pub fn new() -> Self {
89        Self {
90            state: CancelState::NotCancelled { ports: Vec::new() },
91            deadline: None,
92            deadline_id: Default::default(),
93        }
94    }
95
96    fn add_cancel(&mut self) -> Cancel {
97        let (send, recv) = Channel::new_pair();
98        match &mut self.state {
99            CancelState::Cancelled(_) => {}
100            CancelState::NotCancelled { ports, .. } => ports.push(send),
101        }
102        Cancel::new(recv)
103    }
104
105    /// Returns a new child context and a cancel function.
106    ///
107    /// The new context is notified when either this context is cancelled, or
108    /// the returned `Cancel` object's `cancel` method is called.
109    pub fn with_cancel(&self) -> (Self, Cancel) {
110        let mut ctx = self.clone();
111        let cancel = ctx.add_cancel();
112        (ctx, cancel)
113    }
114
115    /// Returns a new child context with a deadline.
116    ///
117    /// The new context is notified when either this context is cancelled, or
118    /// the deadline is exceeded.
119    pub fn with_deadline(&self, deadline: Deadline) -> Self {
120        let mut ctx = self.clone();
121        ctx.deadline = Some(
122            self.deadline
123                .map_or(deadline, |old| old.min(deadline))
124                .into(),
125        );
126        ctx
127    }
128
129    /// Returns a new child context with a timeout.
130    ///
131    /// The new context is notified when either this context is cancelled, or
132    /// the timeout has expired.
133    pub fn with_timeout(&self, timeout: Duration) -> Self {
134        match Deadline::now().checked_add(timeout) {
135            Some(deadline) => self.with_deadline(deadline),
136            None => self.clone(),
137        }
138    }
139
140    /// Returns the current deadline, if there is one.
141    pub fn deadline(&self) -> Option<Deadline> {
142        self.deadline.as_deref().copied()
143    }
144
145    /// Returns a future that completes when the context is cancelled.
146    pub fn cancelled(&mut self) -> Cancelled<'_> {
147        Cancelled(self)
148    }
149
150    /// Runs `fut` until this context is cancelled.
151    pub async fn until_cancelled<F: Future>(&mut self, fut: F) -> Result<F::Output, CancelReason> {
152        let mut fut = core::pin::pin!(fut);
153        let mut cancelled = core::pin::pin!(self.cancelled());
154        std::future::poll_fn(|cx| {
155            if let Poll::Ready(r) = fut.as_mut().poll(cx) {
156                return Poll::Ready(Ok(r));
157            }
158            if let Poll::Ready(reason) = cancelled.as_mut().poll(cx) {
159                return Poll::Ready(Err(reason));
160            }
161            Poll::Pending
162        })
163        .await
164    }
165
166    /// Runs a failable future until this context is cancelled, merging the
167    /// result with the cancellation reason.
168    pub async fn until_cancelled_failable<F: Future<Output = Result<T, E>>, T, E>(
169        &mut self,
170        fut: F,
171    ) -> Result<T, ErrorOrCancelled<E>> {
172        match self.until_cancelled(fut).await {
173            Ok(Ok(r)) => Ok(r),
174            Ok(Err(e)) => Err(ErrorOrCancelled::Error(e)),
175            Err(reason) => Err(ErrorOrCancelled::Cancelled(reason)),
176        }
177    }
178}
179
180impl Default for CancelContext {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
/// Future returned by [`CancelContext::cancelled`].
///
/// It mutably borrows the context and resolves to the [`CancelReason`] once
/// the context is cancelled.
#[must_use]
#[derive(Debug)]
pub struct Cancelled<'a>(&'a mut CancelContext);
189
/// The reason a [`CancelContext`] was cancelled.
#[derive(Debug, Protobuf, Copy, Clone, PartialEq, Eq)]
pub enum CancelReason {
    /// One of the context's cancel ports became ready.
    Cancelled,
    /// The context's deadline was reached.
    DeadlineExceeded,
}
195
196impl std::fmt::Display for CancelReason {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        f.pad(match *self {
199            CancelReason::Cancelled => "cancelled",
200            CancelReason::DeadlineExceeded => "deadline exceeded",
201        })
202    }
203}
204
// `CancelReason` can be used as an error value; it has no underlying source.
impl std::error::Error for CancelReason {}
206
/// The failure type of [`CancelContext::until_cancelled_failable`]: either the
/// wrapped future's own error or the reason the context was cancelled.
#[derive(Error, Debug)]
pub enum ErrorOrCancelled<E> {
    /// The future completed with an error.
    #[error(transparent)]
    Error(E),
    /// The context was cancelled before the future completed.
    #[error(transparent)]
    Cancelled(CancelReason),
}
214
215impl Future for Cancelled<'_> {
216    type Output = CancelReason;
217
218    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219        let this = Pin::get_mut(self);
220        match &mut this.0.state {
221            CancelState::Cancelled(reason) => return Poll::Ready(*reason),
222            CancelState::NotCancelled { ports } => {
223                for p in ports.iter_mut() {
224                    if p.poll_recv(cx).is_ready() {
225                        let reason = CancelReason::Cancelled;
226                        this.0.state = CancelState::Cancelled(reason);
227                        return Poll::Ready(reason);
228                    }
229                }
230            }
231        }
232        if let Some(deadline) = this.0.deadline {
233            if DeadlineSet::global()
234                .poll(cx, &mut this.0.deadline_id.0, *deadline)
235                .is_ready()
236            {
237                let reason = CancelReason::DeadlineExceeded;
238                this.0.state = CancelState::Cancelled(reason);
239                return Poll::Ready(reason);
240            }
241        }
242        Poll::Pending
243    }
244}
245
246impl Drop for CancelContext {
247    fn drop(&mut self) {
248        // Drop the deadline.
249        DeadlineSet::global().remove(&mut self.deadline_id.0);
250    }
251}
252
/// A cancel notifier.
///
/// The associated [`CancelContext`] will be cancelled when this object is
/// dropped or [`Cancel::cancel()`] is called.
///
/// Internally this holds the strong reference to the list of ports whose
/// closure signals cancellation; the list's waker only holds a weak reference.
#[derive(Debug)]
pub struct Cancel(Arc<CancelList>);
259
/// A list of ports to be closed in order to cancel a chain of cancel contexts.
///
/// A cancel context can send additional ports to any one of these ports in
/// order to register additional cancel contexts.
#[derive(Debug)]
struct CancelList {
    /// The ports currently tracked. Guarded by a mutex because the list is
    /// polled from the waker (see `ListWaker`) as well as drained by `Cancel`.
    ports: Mutex<Vec<Channel>>,
}
268
impl CancelList {
    /// Polls each port to accumulate more ports and to garbage collect any
    /// closed ports.
    ///
    /// Every port is polled until it reports `Pending` (which registers `cx`'s
    /// waker for it) or an error. Ports received in messages are appended to
    /// the list and are themselves polled later in the same pass, since the
    /// loop bound is re-evaluated on every iteration.
    fn poll(&self, cx: &mut Context<'_>) {
        // Declared before the lock guard so that it is dropped after the guard,
        // i.e. the removed ports are dropped once the lock has been released.
        let mut to_drop = Vec::new();
        let mut ports = self.ports.lock();
        let mut i = 0;
        'outer: while i < ports.len() {
            while let Poll::Ready(message) = ports[i].poll_recv(cx) {
                match message {
                    Ok(message) => {
                        // Accumulate any ports onto the main list so that they
                        // can be polled.
                        let resources = message.resources;
                        tracing::trace!(count = resources.len(), "adding ports");
                        ports.extend(resources.into_iter().filter_map(|resource| {
                            Port::try_from(resource).ok().map(|port| port.into())
                        }));
                    }
                    Err(_) => {
                        // The peer port is gone. Remove the port from the list.
                        // Push it onto a new list to be dropped outside the lock.
                        to_drop.push(ports.swap_remove(i));
                        // `swap_remove` moved the last port into slot `i`, so
                        // re-examine the same index without advancing.
                        continue 'outer;
                    }
                }
            }
            i += 1;
        }
        if !to_drop.is_empty() {
            tracing::trace!(count = to_drop.len(), "dropping ports");
        }
    }

    /// Returns all the accumulated ports.
    ///
    /// The list is left empty; dropping the returned vector drops the ports.
    fn drain(&self) -> Vec<Channel> {
        std::mem::take(&mut self.ports.lock())
    }
}
308
/// Waker that processes the list inline.
///
/// Waking it re-polls the associated `CancelList` directly on the waking
/// thread instead of scheduling a task.
struct ListWaker {
    /// Weak so that the waker does not keep the list alive once the owning
    /// `Cancel` has been dropped.
    list: Weak<CancelList>,
}
313
314impl Wake for ListWaker {
315    fn wake(self: Arc<Self>) {
316        if let Some(list) = self.list.upgrade() {
317            // Ordinarily it is a bad idea to do anything like this in a waker,
318            // since this could run on an arbitrary thread, under locks, etc.
319            // However, since we're within the same crate we know that it should
320            // be safe to run CancelList::poll on the waking thread.
321            let waker = self.into();
322            let mut cx = Context::from_waker(&waker);
323            list.poll(&mut cx);
324        }
325    }
326}
327
328impl Cancel {
329    fn new(port: Channel) -> Self {
330        let inner = Arc::new(CancelList {
331            ports: Mutex::new(vec![port]),
332        });
333        // The waker is used to poll the port. This is done to detect when the
334        // port is closed so that it can be dropped, and to accumulate any
335        // incoming ports (due to CancelContext::clone) for the same.
336        let waker = Arc::new(ListWaker {
337            list: Arc::downgrade(&inner),
338        });
339        waker.wake();
340        Self(inner)
341    }
342
343    /// Cancels the associated port context and any children contexts.
344    pub fn cancel(&mut self) {
345        drop(self.0.drain());
346    }
347}
348
/// A point in time that acts as a deadline for an operation.
///
/// A deadline always carries a wall-clock time and may additionally carry an
/// OS monotonic time. Comparisons and differences use the monotonic times when
/// both operands have one, and fall back to the wall-clock times otherwise.
///
/// When a deadline is serialized, only its wall-clock time is serialized. The
/// monotonic time is not useful outside of the process that generated it, since
/// the monotonic time is not guaranteed to be consistent across processes.
#[derive(Debug, Copy, Clone, Eq)]
pub struct Deadline {
    system_time: SystemTime,
    instant: Option<Instant>,
}

impl Deadline {
    /// Captures the current time as a deadline, recording both the wall-clock
    /// time and the monotonic time.
    pub fn now() -> Self {
        let system_time = SystemTime::now();
        let instant = Some(Instant::now());
        Self {
            system_time,
            instant,
        }
    }

    /// Returns the monotonic OS instant of the deadline, if it has one.
    pub fn instant(&self) -> Option<Instant> {
        self.instant
    }

    /// Returns the wall-clock time of the deadline.
    pub fn system_time(&self) -> SystemTime {
        self.system_time
    }

    /// Returns this deadline moved later by `duration`, or `None` if the
    /// wall-clock time overflows.
    ///
    /// If only the monotonic time overflows, it is discarded and the result
    /// carries the wall-clock time alone.
    pub fn checked_add(&self, duration: Duration) -> Option<Self> {
        let system_time = self.system_time.checked_add(duration)?;
        let instant = match self.instant {
            Some(instant) => instant.checked_add(duration),
            None => None,
        };
        Some(Self {
            system_time,
            instant,
        })
    }
}

impl std::ops::Add<Duration> for Deadline {
    type Output = Self;

    /// Moves the deadline later by `rhs`.
    ///
    /// # Panics
    ///
    /// Panics if the wall-clock time overflows.
    fn add(self, rhs: Duration) -> Self::Output {
        let sum = self.checked_add(rhs);
        sum.expect("overflow when adding duration to deadline")
    }
}

impl std::ops::Sub<Duration> for Deadline {
    type Output = Deadline;

    /// Moves the deadline earlier by `rhs`.
    ///
    /// If the wall-clock subtraction overflows, the wall-clock time snaps to
    /// the UNIX epoch (even though `SystemTime` can generally represent earlier
    /// times); for deadlines, how far in the past a time lies does not matter.
    /// If the monotonic subtraction overflows, the monotonic time is dropped.
    fn sub(self, rhs: Duration) -> Self::Output {
        let system_time = match self.system_time.checked_sub(rhs) {
            Some(earlier) => earlier,
            None => UNIX_EPOCH,
        };
        let instant = match self.instant {
            Some(instant) => instant.checked_sub(rhs),
            None => None,
        };
        Deadline {
            system_time,
            instant,
        }
    }
}

impl std::ops::Sub<Deadline> for Deadline {
    type Output = Duration;

    /// Returns how far `self` lies after `rhs`, saturating to zero when
    /// `self` is not later than `rhs`.
    fn sub(self, rhs: Deadline) -> Self::Output {
        match (self.instant, rhs.instant) {
            (Some(end), Some(start)) => end.checked_duration_since(start).unwrap_or_default(),
            _ => self
                .system_time
                .duration_since(rhs.system_time)
                .unwrap_or_default(),
        }
    }
}

impl PartialEq for Deadline {
    /// Compares monotonic times when both sides have one, and wall-clock times
    /// otherwise.
    fn eq(&self, other: &Self) -> bool {
        match (self.instant, other.instant) {
            (Some(ours), Some(theirs)) => ours == theirs,
            _ => self.system_time == other.system_time,
        }
    }
}

impl PartialOrd for Deadline {
    /// Delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for Deadline {
    /// Orders by monotonic time when both sides have one, and by wall-clock
    /// time otherwise.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self.instant, other.instant) {
            (Some(ours), Some(theirs)) => ours.cmp(&theirs),
            _ => self.system_time.cmp(&other.system_time),
        }
    }
}

impl From<SystemTime> for Deadline {
    /// Builds a deadline from a wall-clock time alone, with no monotonic time.
    fn from(system_time: SystemTime) -> Self {
        Self {
            instant: None,
            system_time,
        }
    }
}
471
472impl From<Deadline> for Timestamp {
473    fn from(deadline: Deadline) -> Self {
474        deadline.system_time.into()
475    }
476}
477
478impl From<Timestamp> for Deadline {
479    fn from(timestamp: Timestamp) -> Self {
480        Self {
481            system_time: timestamp.try_into().unwrap_or(UNIX_EPOCH),
482            instant: None,
483        }
484    }
485}
486
// Unit tests for cancel contexts and deadline encoding.
#[cfg(test)]
mod tests {
    use super::CancelContext;
    use super::CancelReason;
    use super::Deadline;
    use pal_async::async_test;
    use test_with_tracing::test;

    // A fresh context has no cancel source and no deadline, so its
    // cancellation future stays pending.
    #[async_test]
    async fn no_cancel() {
        assert!(futures::poll!(CancelContext::new().cancelled()).is_pending());
    }

    // Calling `cancel` on the paired `Cancel` makes the context ready.
    #[async_test]
    async fn basic_cancel() {
        let (mut ctx, mut cancel) = CancelContext::new().with_cancel();
        cancel.cancel();
        assert!(futures::poll!(ctx.cancelled()).is_ready());
    }

    // Builds long chains of clones and checks that cancellation reaches the
    // contexts at the end of two separate chains. With `use_cancel` the chain
    // is cancelled explicitly; otherwise a 15 ms timeout is attached first.
    #[expect(clippy::redundant_clone, reason = "explicitly testing chained clones")]
    async fn chain(use_cancel: bool) {
        let ctx = CancelContext::new();
        let (mut ctx, mut cancel) = ctx.with_cancel();
        if !use_cancel {
            ctx = ctx.with_timeout(std::time::Duration::from_millis(15));
        }
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let mut ctx = ctx.clone();
        let ctx2 = ctx.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let mut ctx2 = ctx2.clone();
        let _ = ctx2
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone();
        // Blocks well past the 15 ms timeout used in the deadline variant.
        // NOTE(review): presumably this also gives the cloned ports time to
        // propagate before cancelling -- confirm.
        std::thread::sleep(std::time::Duration::from_millis(100));
        if use_cancel {
            cancel.cancel();
        }
        assert!(futures::poll!(ctx.cancelled()).is_ready());
        assert!(futures::poll!(ctx2.cancelled()).is_ready());
    }

    // Chain cancelled explicitly through `Cancel::cancel`.
    #[async_test]
    async fn chain_cancel() {
        chain(true).await
    }

    // Chain cancelled by the deadline expiring.
    #[async_test]
    async fn chain_deadline() {
        chain(false).await
    }

    // Both an immediate and a 100 ms timeout resolve with `DeadlineExceeded`.
    #[async_test]
    async fn cancel_deadline() {
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(0));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(100));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
    }

    // Round-trips deadlines through `Timestamp`, including times before and
    // after the UNIX epoch, and checks that they compare equal afterwards.
    #[test]
    fn test_encode_deadline() {
        let check = |deadline: Deadline| {
            let timestamp: super::Timestamp = deadline.into();
            let deadline2: Deadline = timestamp.into();
            assert_eq!(deadline, deadline2);
        };

        check(Deadline::now());
        check(Deadline::now() + std::time::Duration::from_secs(1));
        check(Deadline::now() - std::time::Duration::from_secs(1));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH - std::time::Duration::from_nanos(1_500_000_000),
        ));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_nanos(1_500_000_000),
        ));
    }
}