1use super::bidir::Channel;
7use super::deadline::DeadlineId;
8use super::deadline::DeadlineSet;
9use mesh_node::local_node::Port;
10use mesh_node::resource::Resource;
11use mesh_protobuf::EncodeAs;
12use mesh_protobuf::Protobuf;
13use mesh_protobuf::SerializedMessage;
14use mesh_protobuf::Timestamp;
15use mesh_protobuf::encoding::IgnoreField;
16use parking_lot::Mutex;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::sync::Weak;
21use std::task::Context;
22use std::task::Poll;
23use std::task::Wake;
24use std::time::Duration;
25use std::time::Instant;
26use std::time::SystemTime;
27use std::time::UNIX_EPOCH;
28use thiserror::Error;
29
/// A cancellation context that is cancelled either through a [`Cancel`]
/// handle (see [`CancelContext::with_cancel`]) or when its deadline passes.
///
/// The type derives `Protobuf` with mesh resources, so it can be encoded as a
/// mesh message together with the channels it holds.
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct CancelContext {
    /// Whether the context is cancelled and, if not, the channels whose
    /// readiness signals cancellation.
    state: CancelState,
    /// Optional deadline. It is encoded as a `Timestamp`, so only its
    /// wall-clock part is carried through encoding.
    deadline: Option<EncodeAs<Deadline, Timestamp>>,
    /// Registration handle in the global `DeadlineSet`. It is not encoded
    /// (see `Ignore`) and is reset to its default on clone.
    deadline_id: Ignore<DeadlineId>,
}
41
/// Cancellation status of a [`CancelContext`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum CancelState {
    /// Not cancelled yet. The peers of these channels are held, directly or
    /// via forwarded ports, by [`Cancel`] handles. The context is considered
    /// cancelled as soon as any of these channels becomes ready.
    NotCancelled { ports: Vec<Channel> },
    /// Cancelled for the given reason. Once set, this state is never left.
    Cancelled(CancelReason),
}
48
/// Wrapper marking a field that mesh encoding skips entirely.
///
/// Its `DefaultEncoding` is `IgnoreField`. Presumably the wrapped value
/// therefore takes its `Default` when decoded (hence the `T: Default`
/// bound) — confirm against `IgnoreField`.
#[derive(Default)]
#[derive(Debug)]
struct Ignore<T>(T);
51
52impl<T: Default> mesh_protobuf::DefaultEncoding for Ignore<T> {
53 type Encoding = IgnoreField;
54}
55
56impl Clone for CancelContext {
57 fn clone(&self) -> Self {
58 let state = match &self.state {
59 CancelState::Cancelled(reason) => CancelState::Cancelled(*reason),
60 CancelState::NotCancelled { ports, .. } => CancelState::NotCancelled {
61 ports: ports
62 .iter()
63 .map(|port| {
64 let (send, recv) = <Channel>::new_pair();
69 port.send(SerializedMessage {
70 data: vec![],
71 resources: vec![Resource::Port(recv.into())],
72 });
73 send
74 })
75 .collect(),
76 },
77 };
78 Self {
79 state,
80 deadline: self.deadline,
81 deadline_id: Default::default(),
82 }
83 }
84}
85
86impl CancelContext {
87 pub fn new() -> Self {
89 Self {
90 state: CancelState::NotCancelled { ports: Vec::new() },
91 deadline: None,
92 deadline_id: Default::default(),
93 }
94 }
95
96 fn add_cancel(&mut self) -> Cancel {
97 let (send, recv) = Channel::new_pair();
98 match &mut self.state {
99 CancelState::Cancelled(_) => {}
100 CancelState::NotCancelled { ports, .. } => ports.push(send),
101 }
102 Cancel::new(recv)
103 }
104
105 pub fn with_cancel(&self) -> (Self, Cancel) {
110 let mut ctx = self.clone();
111 let cancel = ctx.add_cancel();
112 (ctx, cancel)
113 }
114
115 pub fn with_deadline(&self, deadline: Deadline) -> Self {
120 let mut ctx = self.clone();
121 ctx.deadline = Some(
122 self.deadline
123 .map_or(deadline, |old| old.min(deadline))
124 .into(),
125 );
126 ctx
127 }
128
129 pub fn with_timeout(&self, timeout: Duration) -> Self {
134 match Deadline::now().checked_add(timeout) {
135 Some(deadline) => self.with_deadline(deadline),
136 None => self.clone(),
137 }
138 }
139
140 pub fn deadline(&self) -> Option<Deadline> {
142 self.deadline.as_deref().copied()
143 }
144
145 pub fn cancelled(&mut self) -> Cancelled<'_> {
147 Cancelled(self)
148 }
149
150 pub async fn until_cancelled<F: Future>(&mut self, fut: F) -> Result<F::Output, CancelReason> {
152 let mut fut = core::pin::pin!(fut);
153 let mut cancelled = core::pin::pin!(self.cancelled());
154 std::future::poll_fn(|cx| {
155 if let Poll::Ready(r) = fut.as_mut().poll(cx) {
156 return Poll::Ready(Ok(r));
157 }
158 if let Poll::Ready(reason) = cancelled.as_mut().poll(cx) {
159 return Poll::Ready(Err(reason));
160 }
161 Poll::Pending
162 })
163 .await
164 }
165
166 pub async fn until_cancelled_failable<F: Future<Output = Result<T, E>>, T, E>(
169 &mut self,
170 fut: F,
171 ) -> Result<T, ErrorOrCancelled<E>> {
172 match self.until_cancelled(fut).await {
173 Ok(Ok(r)) => Ok(r),
174 Ok(Err(e)) => Err(ErrorOrCancelled::Error(e)),
175 Err(reason) => Err(ErrorOrCancelled::Cancelled(reason)),
176 }
177 }
178
179 pub fn is_cancelled(&mut self) -> bool {
181 Pin::new(&mut self.cancelled())
182 .poll(&mut Context::from_waker(std::task::Waker::noop()))
183 .is_ready()
184 }
185}
186
187impl Default for CancelContext {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193#[must_use]
194#[derive(Debug)]
195pub struct Cancelled<'a>(&'a mut CancelContext);
196
/// Why a [`CancelContext`] was cancelled.
#[derive(Debug, Protobuf, Copy, Clone, PartialEq, Eq)]
pub enum CancelReason {
    /// One of the context's cancellation channels became ready, for example
    /// because a [`Cancel`] handle dropped its end via [`Cancel::cancel`].
    Cancelled,
    /// The context's deadline passed.
    DeadlineExceeded,
}
202
203impl std::fmt::Display for CancelReason {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.pad(match *self {
206 CancelReason::Cancelled => "cancelled",
207 CancelReason::DeadlineExceeded => "deadline exceeded",
208 })
209 }
210}
211
212impl std::error::Error for CancelReason {}
213
214#[derive(Error, Debug)]
215pub enum ErrorOrCancelled<E> {
216 #[error(transparent)]
217 Error(E),
218 #[error(transparent)]
219 Cancelled(CancelReason),
220}
221
222impl Future for Cancelled<'_> {
223 type Output = CancelReason;
224
225 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226 let this = Pin::get_mut(self);
227 match &mut this.0.state {
228 CancelState::Cancelled(reason) => return Poll::Ready(*reason),
229 CancelState::NotCancelled { ports } => {
230 for p in ports.iter_mut() {
231 if p.poll_recv(cx).is_ready() {
232 let reason = CancelReason::Cancelled;
233 this.0.state = CancelState::Cancelled(reason);
234 return Poll::Ready(reason);
235 }
236 }
237 }
238 }
239 if let Some(deadline) = this.0.deadline {
240 if DeadlineSet::global()
241 .poll(cx, &mut this.0.deadline_id.0, *deadline)
242 .is_ready()
243 {
244 let reason = CancelReason::DeadlineExceeded;
245 this.0.state = CancelState::Cancelled(reason);
246 return Poll::Ready(reason);
247 }
248 }
249 Poll::Pending
250 }
251}
252
253impl Drop for CancelContext {
254 fn drop(&mut self) {
255 DeadlineSet::global().remove(&mut self.deadline_id.0);
257 }
258}
259
260#[derive(Debug)]
265pub struct Cancel(Arc<CancelList>);
266
267#[derive(Debug)]
272struct CancelList {
273 ports: Mutex<Vec<Channel>>,
274}
275
276impl CancelList {
277 fn poll(&self, cx: &mut Context<'_>) {
280 let mut to_drop = Vec::new();
281 let mut ports = self.ports.lock();
282 let mut i = 0;
283 'outer: while i < ports.len() {
284 while let Poll::Ready(message) = ports[i].poll_recv(cx) {
285 match message {
286 Ok(message) => {
287 let resources = message.resources;
290 tracing::trace!(count = resources.len(), "adding ports");
291 ports.extend(resources.into_iter().filter_map(|resource| {
292 Port::try_from(resource).ok().map(|port| port.into())
293 }));
294 }
295 Err(_) => {
296 to_drop.push(ports.swap_remove(i));
299 continue 'outer;
300 }
301 }
302 }
303 i += 1;
304 }
305 if !to_drop.is_empty() {
306 tracing::trace!(count = to_drop.len(), "dropping ports");
307 }
308 }
309
310 fn drain(&self) -> Vec<Channel> {
312 std::mem::take(&mut self.ports.lock())
313 }
314}
315
316struct ListWaker {
318 list: Weak<CancelList>,
319}
320
321impl Wake for ListWaker {
322 fn wake(self: Arc<Self>) {
323 if let Some(list) = self.list.upgrade() {
324 let waker = self.into();
329 let mut cx = Context::from_waker(&waker);
330 list.poll(&mut cx);
331 }
332 }
333}
334
335impl Cancel {
336 fn new(port: Channel) -> Self {
337 let inner = Arc::new(CancelList {
338 ports: Mutex::new(vec![port]),
339 });
340 let waker = Arc::new(ListWaker {
344 list: Arc::downgrade(&inner),
345 });
346 waker.wake();
347 Self(inner)
348 }
349
350 pub fn cancel(&mut self) {
352 drop(self.0.drain());
353 }
354}
355
/// A point in time after which an operation is considered timed out.
///
/// A deadline always has a wall-clock [`SystemTime`]. When created from the
/// local clock (see [`Deadline::now`]), it also has a monotonic [`Instant`];
/// deadlines built from a bare `SystemTime` have no instant.
///
/// Equality, ordering and differences use the instants when *both* operands
/// have one, and fall back to the system times otherwise. Mixing deadlines
/// with and without instants can therefore make equality non-transitive.
#[derive(Debug, Copy, Clone, Eq)]
pub struct Deadline {
    system_time: SystemTime,
    instant: Option<Instant>,
}

impl Deadline {
    /// Returns a deadline for the current moment, sampling both clocks.
    pub fn now() -> Self {
        Self {
            system_time: SystemTime::now(),
            instant: Some(Instant::now()),
        }
    }

    /// Returns the monotonic instant, if this deadline has one.
    pub fn instant(&self) -> Option<Instant> {
        self.instant
    }

    /// Returns the wall-clock time of this deadline.
    pub fn system_time(&self) -> SystemTime {
        self.system_time
    }

    /// Returns the deadline `duration` later, or `None` if the wall-clock
    /// time would overflow.
    ///
    /// If only the instant overflows, the result keeps the new system time
    /// and drops the instant.
    pub fn checked_add(&self, duration: Duration) -> Option<Self> {
        let instant = self.instant.and_then(|i| i.checked_add(duration));
        let system_time = self.system_time.checked_add(duration)?;
        Some(Self {
            system_time,
            instant,
        })
    }
}

impl std::ops::Add<Duration> for Deadline {
    type Output = Self;

    /// # Panics
    ///
    /// Panics if the wall-clock time overflows. Use
    /// [`Deadline::checked_add`] to handle that case instead.
    fn add(self, rhs: Duration) -> Self::Output {
        self.checked_add(rhs)
            .expect("overflow when adding duration to deadline")
    }
}

impl std::ops::Sub<Duration> for Deadline {
    type Output = Deadline;

    /// Returns the deadline `rhs` earlier, saturating instead of panicking.
    ///
    /// If the wall-clock subtraction underflows, the result saturates to
    /// `UNIX_EPOCH`. If `self` is already before the epoch, it keeps `self`'s
    /// time instead, so subtraction never produces a *later* deadline. If
    /// the instant underflows, it is dropped and comparisons fall back to
    /// the system time.
    fn sub(self, rhs: Duration) -> Self::Output {
        Self {
            system_time: self
                .system_time
                .checked_sub(rhs)
                .unwrap_or_else(|| self.system_time.min(UNIX_EPOCH)),
            instant: self.instant.and_then(|i| i.checked_sub(rhs)),
        }
    }
}

impl std::ops::Sub<Deadline> for Deadline {
    type Output = Duration;

    /// Returns how far `self` lies after `rhs`, or zero if it does not.
    fn sub(self, rhs: Deadline) -> Self::Output {
        if let Some((lhs, rhs)) = self.instant.zip(rhs.instant) {
            // Both are local: prefer the monotonic clock.
            lhs.checked_duration_since(rhs).unwrap_or_default()
        } else {
            self.system_time
                .duration_since(rhs.system_time)
                .unwrap_or_default()
        }
    }
}

impl PartialEq for Deadline {
    /// Compares instants when both sides have one, otherwise system times.
    fn eq(&self, other: &Self) -> bool {
        if let Some((lhs, rhs)) = self.instant.zip(other.instant) {
            lhs.eq(&rhs)
        } else {
            self.system_time.eq(&other.system_time)
        }
    }
}

impl PartialOrd for Deadline {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Deadline {
    /// Orders by instants when both sides have one, otherwise by system
    /// times (consistent with the `PartialEq` impl).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        if let Some((lhs, rhs)) = self.instant.zip(other.instant) {
            lhs.cmp(&rhs)
        } else {
            self.system_time.cmp(&other.system_time)
        }
    }
}

impl From<SystemTime> for Deadline {
    /// Builds a wall-clock-only deadline (no monotonic instant).
    fn from(system_time: SystemTime) -> Self {
        Self {
            system_time,
            instant: None,
        }
    }
}
478
479impl From<Deadline> for Timestamp {
480 fn from(deadline: Deadline) -> Self {
481 deadline.system_time.into()
482 }
483}
484
485impl From<Timestamp> for Deadline {
486 fn from(timestamp: Timestamp) -> Self {
487 Self {
488 system_time: timestamp.try_into().unwrap_or(UNIX_EPOCH),
489 instant: None,
490 }
491 }
492}
493
#[cfg(test)]
mod tests {
    use super::CancelContext;
    use super::CancelReason;
    use super::Deadline;
    use pal_async::async_test;
    use test_with_tracing::test;

    // A fresh context has no cancel channels and no deadline, so it stays pending.
    #[async_test]
    async fn no_cancel() {
        assert!(futures::poll!(CancelContext::new().cancelled()).is_pending());
    }

    // Cancelling the handle must be visible to the very next poll.
    #[async_test]
    async fn basic_cancel() {
        let (mut ctx, mut cancel) = CancelContext::new().with_cancel();
        cancel.cancel();
        assert!(futures::poll!(ctx.cancelled()).is_ready());
    }

    /// Builds long clone chains (each clone forwards a new channel end
    /// towards the cancel side). It then checks that cancellation, triggered
    /// either by the handle or by a 15ms timeout, reaches the ends of both
    /// chains.
    #[expect(clippy::redundant_clone, reason = "explicitly testing chained clones")]
    async fn chain(use_cancel: bool) {
        let ctx = CancelContext::new();
        let (mut ctx, mut cancel) = ctx.with_cancel();
        if !use_cancel {
            ctx = ctx.with_timeout(std::time::Duration::from_millis(15));
        }
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let mut ctx = ctx.clone();
        let ctx2 = ctx.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let mut ctx2 = ctx2.clone();
        // Temporary clones, dropped immediately; exercises forwarding for
        // contexts that go away.
        let _ = ctx2
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone();
        // Deliberately blocking: the 15ms timeout must have elapsed before
        // the polls below. NOTE(review): presumably this also gives forwarded
        // channels time to reach the cancel side — confirm.
        std::thread::sleep(std::time::Duration::from_millis(100));
        if use_cancel {
            cancel.cancel();
        }
        assert!(futures::poll!(ctx.cancelled()).is_ready());
        assert!(futures::poll!(ctx2.cancelled()).is_ready());
    }

    #[async_test]
    async fn chain_cancel() {
        chain(true).await
    }

    #[async_test]
    async fn chain_deadline() {
        chain(false).await
    }

    // A zero timeout is exceeded at once; a 100ms timeout resolves once the
    // deadline set fires. Both report `DeadlineExceeded`.
    #[async_test]
    async fn cancel_deadline() {
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(0));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(100));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
    }

    // Round-tripping through `Timestamp` drops the instant, so equality
    // falls back to the system time. That time must survive exactly,
    // including sub-second and pre-epoch values.
    #[test]
    fn test_encode_deadline() {
        let check = |deadline: Deadline| {
            let timestamp: super::Timestamp = deadline.into();
            let deadline2: Deadline = timestamp.into();
            assert_eq!(deadline, deadline2);
        };

        check(Deadline::now());
        check(Deadline::now() + std::time::Duration::from_secs(1));
        check(Deadline::now() - std::time::Duration::from_secs(1));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH - std::time::Duration::from_nanos(1_500_000_000),
        ));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_nanos(1_500_000_000),
        ));
    }
}