1use super::bidir::Channel;
7use super::deadline::DeadlineId;
8use super::deadline::DeadlineSet;
9use mesh_node::local_node::Port;
10use mesh_node::resource::Resource;
11use mesh_protobuf::EncodeAs;
12use mesh_protobuf::Protobuf;
13use mesh_protobuf::SerializedMessage;
14use mesh_protobuf::Timestamp;
15use mesh_protobuf::encoding::IgnoreField;
16use parking_lot::Mutex;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20use std::sync::Weak;
21use std::task::Context;
22use std::task::Poll;
23use std::task::Wake;
24use std::time::Duration;
25use std::time::Instant;
26use std::time::SystemTime;
27use std::time::UNIX_EPOCH;
28use thiserror::Error;
29
/// A context used to observe cancellation of an operation.
///
/// A context becomes cancelled when one of its cancellation ports becomes
/// ready (driven by an associated [`Cancel`]) or when its deadline, if any,
/// passes.
///
/// Derived contexts are created with [`CancelContext::with_cancel`],
/// [`CancelContext::with_deadline`] and [`CancelContext::with_timeout`].
/// Cloning a context yields an independent context that is cancelled by the
/// same sources.
///
/// The type derives `Protobuf` with `Resource` resources, so it can be encoded
/// together with its ports (presumably to be sent over mesh channels).
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct CancelContext {
    // Cancellation status; while not cancelled, holds the ports whose readiness
    // signals cancellation.
    state: CancelState,
    // Optional deadline. It is encoded as a `Timestamp`, and only the wall-clock
    // part of `Deadline` is converted to one.
    deadline: Option<EncodeAs<Deadline, Timestamp>>,
    // This context's entry in the global `DeadlineSet`. It is wrapped in `Ignore`
    // so its encoding is `IgnoreField`.
    deadline_id: Ignore<DeadlineId>,
}
41
/// Cancellation status of a [`CancelContext`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum CancelState {
    /// Not cancelled yet. Each entry in `ports` is one end of a channel whose
    /// other end is tracked by a [`Cancel`]. The context counts as cancelled as
    /// soon as any of these ports is ready when polled.
    NotCancelled { ports: Vec<Channel> },
    /// Cancelled for the given reason. Nothing in this module transitions out
    /// of this state once it has been entered.
    Cancelled(CancelReason),
}
48
/// Wrapper whose mesh encoding is [`IgnoreField`]: as the name suggests, the
/// wrapped value is skipped when encoding.
///
/// NOTE(review): the `T: Default` bound suggests decoding rebuilds the value
/// with `T::default()`. Confirm against `IgnoreField`.
#[derive(Debug, Default)]
struct Ignore<T>(T);

impl<T: Default> mesh_protobuf::DefaultEncoding for Ignore<T> {
    type Encoding = IgnoreField;
}
55
56impl Clone for CancelContext {
57 fn clone(&self) -> Self {
58 let state = match &self.state {
59 CancelState::Cancelled(reason) => CancelState::Cancelled(*reason),
60 CancelState::NotCancelled { ports, .. } => CancelState::NotCancelled {
61 ports: ports
62 .iter()
63 .map(|port| {
64 let (send, recv) = <Channel>::new_pair();
69 port.send(SerializedMessage {
70 data: vec![],
71 resources: vec![Resource::Port(recv.into())],
72 });
73 send
74 })
75 .collect(),
76 },
77 };
78 Self {
79 state,
80 deadline: self.deadline,
81 deadline_id: Default::default(),
82 }
83 }
84}
85
86impl CancelContext {
87 pub fn new() -> Self {
89 Self {
90 state: CancelState::NotCancelled { ports: Vec::new() },
91 deadline: None,
92 deadline_id: Default::default(),
93 }
94 }
95
96 fn add_cancel(&mut self) -> Cancel {
97 let (send, recv) = Channel::new_pair();
98 match &mut self.state {
99 CancelState::Cancelled(_) => {}
100 CancelState::NotCancelled { ports, .. } => ports.push(send),
101 }
102 Cancel::new(recv)
103 }
104
105 pub fn with_cancel(&self) -> (Self, Cancel) {
110 let mut ctx = self.clone();
111 let cancel = ctx.add_cancel();
112 (ctx, cancel)
113 }
114
115 pub fn with_deadline(&self, deadline: Deadline) -> Self {
120 let mut ctx = self.clone();
121 ctx.deadline = Some(
122 self.deadline
123 .map_or(deadline, |old| old.min(deadline))
124 .into(),
125 );
126 ctx
127 }
128
129 pub fn with_timeout(&self, timeout: Duration) -> Self {
134 match Deadline::now().checked_add(timeout) {
135 Some(deadline) => self.with_deadline(deadline),
136 None => self.clone(),
137 }
138 }
139
140 pub fn deadline(&self) -> Option<Deadline> {
142 self.deadline.as_deref().copied()
143 }
144
145 pub fn cancelled(&mut self) -> Cancelled<'_> {
147 Cancelled(self)
148 }
149
150 pub async fn until_cancelled<F: Future>(&mut self, fut: F) -> Result<F::Output, CancelReason> {
152 let mut fut = core::pin::pin!(fut);
153 let mut cancelled = core::pin::pin!(self.cancelled());
154 std::future::poll_fn(|cx| {
155 if let Poll::Ready(r) = fut.as_mut().poll(cx) {
156 return Poll::Ready(Ok(r));
157 }
158 if let Poll::Ready(reason) = cancelled.as_mut().poll(cx) {
159 return Poll::Ready(Err(reason));
160 }
161 Poll::Pending
162 })
163 .await
164 }
165
166 pub async fn until_cancelled_failable<F: Future<Output = Result<T, E>>, T, E>(
169 &mut self,
170 fut: F,
171 ) -> Result<T, ErrorOrCancelled<E>> {
172 match self.until_cancelled(fut).await {
173 Ok(Ok(r)) => Ok(r),
174 Ok(Err(e)) => Err(ErrorOrCancelled::Error(e)),
175 Err(reason) => Err(ErrorOrCancelled::Cancelled(reason)),
176 }
177 }
178}
179
180impl Default for CancelContext {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
/// Future returned by [`CancelContext::cancelled`].
///
/// It resolves to the [`CancelReason`] once the borrowed context is cancelled.
/// Futures do nothing unless polled, hence `#[must_use]`.
#[must_use]
#[derive(Debug)]
pub struct Cancelled<'a>(&'a mut CancelContext);
189
/// Why a [`CancelContext`] was cancelled.
#[derive(Debug, Protobuf, Copy, Clone, PartialEq, Eq)]
pub enum CancelReason {
    /// One of the context's cancellation ports became ready. For example,
    /// [`Cancel::cancel`] closed it; dropping the `Cancel` presumably has the
    /// same effect.
    Cancelled,
    /// The context's deadline passed.
    DeadlineExceeded,
}
195
196impl std::fmt::Display for CancelReason {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 f.pad(match *self {
199 CancelReason::Cancelled => "cancelled",
200 CancelReason::DeadlineExceeded => "deadline exceeded",
201 })
202 }
203}
204
// Lets `CancelReason` be used directly as an error value (it already has
// `Debug` and `Display`).
impl std::error::Error for CancelReason {}
206
/// Error returned by [`CancelContext::until_cancelled_failable`].
///
/// It is either the wrapped future's own error or the reason the context was
/// cancelled. Both variants forward `Display`/`source` transparently to their
/// payload.
#[derive(Error, Debug)]
pub enum ErrorOrCancelled<E> {
    /// The future completed with an error.
    #[error(transparent)]
    Error(E),
    /// The context was cancelled before the future completed.
    #[error(transparent)]
    Cancelled(CancelReason),
}
214
215impl Future for Cancelled<'_> {
216 type Output = CancelReason;
217
218 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 let this = Pin::get_mut(self);
220 match &mut this.0.state {
221 CancelState::Cancelled(reason) => return Poll::Ready(*reason),
222 CancelState::NotCancelled { ports } => {
223 for p in ports.iter_mut() {
224 if p.poll_recv(cx).is_ready() {
225 let reason = CancelReason::Cancelled;
226 this.0.state = CancelState::Cancelled(reason);
227 return Poll::Ready(reason);
228 }
229 }
230 }
231 }
232 if let Some(deadline) = this.0.deadline {
233 if DeadlineSet::global()
234 .poll(cx, &mut this.0.deadline_id.0, *deadline)
235 .is_ready()
236 {
237 let reason = CancelReason::DeadlineExceeded;
238 this.0.state = CancelState::Cancelled(reason);
239 return Poll::Ready(reason);
240 }
241 }
242 Poll::Pending
243 }
244}
245
246impl Drop for CancelContext {
247 fn drop(&mut self) {
248 DeadlineSet::global().remove(&mut self.deadline_id.0);
250 }
251}
252
/// Handle that cancels the [`CancelContext`]s associated with it.
///
/// Created by [`CancelContext::with_cancel`]. [`Cancel::cancel`] drops every
/// tracked port, which makes the peer ports in the contexts ready and so
/// cancels them.
///
/// This handle holds the only long-lived strong reference to its port list;
/// the waker holds just a `Weak`. Dropping the handle therefore also drops
/// the ports, presumably cancelling the contexts as well.
#[derive(Debug)]
pub struct Cancel(Arc<CancelList>);
259
/// The set of cancellation ports owned by a [`Cancel`].
///
/// Each port's peer lives in some `CancelContext`. The list grows when a
/// context is cloned, because the clone sends a new port over an existing one.
/// It shrinks when a peer is observed closed.
#[derive(Debug)]
struct CancelList {
    // Locked because it is accessed both from waker callbacks (which may run on
    // other threads) and from `Cancel::cancel`.
    ports: Mutex<Vec<Channel>>,
}
268
impl CancelList {
    /// Polls every port in the list with `cx`, a waker that points back at
    /// this list.
    ///
    /// - A received message is treated as a hand-off from a cloned context.
    ///   Any ports among its resources are appended to the list. They are
    ///   polled later in this same pass, because the loop re-reads the length.
    /// - `Err` is treated as the peer having closed; that port is removed.
    fn poll(&self, cx: &mut Context<'_>) {
        // Declared before the lock guard so that it is dropped *after* the
        // guard. Removed ports are therefore dropped outside the lock.
        let mut to_drop = Vec::new();
        let mut ports = self.ports.lock();
        let mut i = 0;
        'outer: while i < ports.len() {
            // Drain every ready message on this port before moving on.
            while let Poll::Ready(message) = ports[i].poll_recv(cx) {
                match message {
                    Ok(message) => {
                        let resources = message.resources;
                        // The logged count includes any non-port resources;
                        // those are silently discarded by the filter below.
                        tracing::trace!(count = resources.len(), "adding ports");
                        ports.extend(resources.into_iter().filter_map(|resource| {
                            Port::try_from(resource).ok().map(|port| port.into())
                        }));
                    }
                    Err(_) => {
                        // `swap_remove` moves the last port, not yet polled in
                        // this pass, into slot `i`. Loop again without
                        // advancing `i` so that port gets polled.
                        to_drop.push(ports.swap_remove(i));
                        continue 'outer;
                    }
                }
            }
            i += 1;
        }
        if !to_drop.is_empty() {
            tracing::trace!(count = to_drop.len(), "dropping ports");
        }
    }

    /// Takes all ports out of the list, leaving it empty.
    ///
    /// The temporary lock guard is released before the caller receives the
    /// ports, so the caller can drop them without holding the lock.
    fn drain(&self) -> Vec<Channel> {
        std::mem::take(&mut self.ports.lock())
    }
}
308
/// Waker that re-polls a [`CancelList`] whenever one of its ports wakes it.
struct ListWaker {
    // Weak, so the waker (presumably stored by the ports it is registered
    // with) does not keep the list alive after its `Cancel` is dropped.
    list: Weak<CancelList>,
}
313
314impl Wake for ListWaker {
315 fn wake(self: Arc<Self>) {
316 if let Some(list) = self.list.upgrade() {
317 let waker = self.into();
322 let mut cx = Context::from_waker(&waker);
323 list.poll(&mut cx);
324 }
325 }
326}
327
328impl Cancel {
329 fn new(port: Channel) -> Self {
330 let inner = Arc::new(CancelList {
331 ports: Mutex::new(vec![port]),
332 });
333 let waker = Arc::new(ListWaker {
337 list: Arc::downgrade(&inner),
338 });
339 waker.wake();
340 Self(inner)
341 }
342
343 pub fn cancel(&mut self) {
345 drop(self.0.drain());
346 }
347}
348
/// A point in time used as a cancellation deadline.
///
/// Tracks a wall-clock [`SystemTime`] and, when available, a monotonic
/// [`Instant`]. Only the `SystemTime` is converted to a `Timestamp` for
/// encoding, so decoded deadlines carry no `Instant`.
///
/// Equality and ordering use the `Instant`s when *both* sides have one, and
/// fall back to the `SystemTime`s otherwise. As a result, ordering is not
/// guaranteed to be transitive when deadlines with and without an `Instant`
/// are mixed.
#[derive(Debug, Copy, Clone, Eq)]
pub struct Deadline {
    system_time: SystemTime,
    instant: Option<Instant>,
}

impl Deadline {
    /// Returns the current time, capturing both wall-clock and monotonic time.
    pub fn now() -> Self {
        Self {
            system_time: SystemTime::now(),
            instant: Some(Instant::now()),
        }
    }

    /// Returns the monotonic instant, if known.
    ///
    /// This is `None` in two cases:
    /// - the deadline was built from a `SystemTime` (including decoded
    ///   deadlines);
    /// - the instant over- or underflowed during arithmetic.
    pub fn instant(&self) -> Option<Instant> {
        self.instant
    }

    /// Returns the wall-clock time of the deadline.
    pub fn system_time(&self) -> SystemTime {
        self.system_time
    }

    /// Adds `duration`, returning `None` if the wall-clock time overflows.
    ///
    /// If only the monotonic instant overflows, it is discarded and the
    /// result carries just the wall-clock time.
    pub fn checked_add(&self, duration: Duration) -> Option<Self> {
        let instant = self.instant.and_then(|i| i.checked_add(duration));
        let system_time = self.system_time.checked_add(duration)?;
        Some(Self {
            system_time,
            instant,
        })
    }
}

impl std::ops::Add<Duration> for Deadline {
    type Output = Self;

    /// # Panics
    ///
    /// Panics if the wall-clock time overflows. Use [`Deadline::checked_add`]
    /// to handle that case instead.
    fn add(self, rhs: Duration) -> Self::Output {
        self.checked_add(rhs)
            .expect("overflow when adding duration to deadline")
    }
}

impl std::ops::Sub<Duration> for Deadline {
    type Output = Deadline;

    /// Subtracts `rhs`, saturating instead of panicking.
    ///
    /// If the wall-clock subtraction underflows, the result saturates to the
    /// earlier of the UNIX epoch and the original time. This keeps the usual
    /// epoch saturation for post-epoch times, and never makes a pre-epoch
    /// deadline *later*. The monotonic instant is dropped if it underflows.
    fn sub(self, rhs: Duration) -> Self::Output {
        Self {
            system_time: self
                .system_time
                .checked_sub(rhs)
                .unwrap_or(self.system_time.min(UNIX_EPOCH)),
            instant: self.instant.and_then(|i| i.checked_sub(rhs)),
        }
    }
}

impl std::ops::Sub<Deadline> for Deadline {
    type Output = Duration;

    /// Returns the time from `rhs` to `self`, saturating to zero if `rhs` is
    /// later.
    ///
    /// Uses the instants when both are present, otherwise the wall-clock times.
    fn sub(self, rhs: Deadline) -> Self::Output {
        if let Some((lhs, rhs)) = self.instant.zip(rhs.instant) {
            lhs.checked_duration_since(rhs).unwrap_or_default()
        } else {
            self.system_time
                .duration_since(rhs.system_time)
                .unwrap_or_default()
        }
    }
}

impl PartialEq for Deadline {
    /// Compares instants when both are present, otherwise wall-clock times.
    fn eq(&self, other: &Self) -> bool {
        if let Some((lhs, rhs)) = self.instant.zip(other.instant) {
            lhs.eq(&rhs)
        } else {
            self.system_time.eq(&other.system_time)
        }
    }
}

impl PartialOrd for Deadline {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Deadline {
    /// Orders by instants when both are present, otherwise by wall-clock times.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        if let Some((lhs, rhs)) = self.instant.zip(other.instant) {
            lhs.cmp(&rhs)
        } else {
            self.system_time.cmp(&other.system_time)
        }
    }
}

impl From<SystemTime> for Deadline {
    /// Builds a wall-clock-only deadline (no monotonic instant).
    fn from(system_time: SystemTime) -> Self {
        Self {
            system_time,
            instant: None,
        }
    }
}
471
472impl From<Deadline> for Timestamp {
473 fn from(deadline: Deadline) -> Self {
474 deadline.system_time.into()
475 }
476}
477
478impl From<Timestamp> for Deadline {
479 fn from(timestamp: Timestamp) -> Self {
480 Self {
481 system_time: timestamp.try_into().unwrap_or(UNIX_EPOCH),
482 instant: None,
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::CancelContext;
    use super::CancelReason;
    use super::Deadline;
    use pal_async::async_test;
    // Replaces the built-in `#[test]` attribute for the synchronous test below
    // (presumably to set up tracing output).
    use test_with_tracing::test;

    /// A fresh context has no cancellation ports and no deadline, so it stays
    /// pending.
    #[async_test]
    async fn no_cancel() {
        assert!(futures::poll!(CancelContext::new().cancelled()).is_pending());
    }

    /// Cancelling through the `Cancel` handle is visible on the first poll.
    #[async_test]
    async fn basic_cancel() {
        let (mut ctx, mut cancel) = CancelContext::new().with_cancel();
        cancel.cancel();
        assert!(futures::poll!(ctx.cancelled()).is_ready());
    }

    /// Builds long chains of clones, and clones of clones, of a cancellable
    /// context. It then checks that cancellation reaches the deepest clones.
    /// Cancellation comes from `Cancel::cancel` when `use_cancel` is true,
    /// otherwise from a 15 ms timeout.
    #[expect(clippy::redundant_clone, reason = "explicitly testing chained clones")]
    async fn chain(use_cancel: bool) {
        let ctx = CancelContext::new();
        let (mut ctx, mut cancel) = ctx.with_cancel();
        if !use_cancel {
            ctx = ctx.with_timeout(std::time::Duration::from_millis(15));
        }
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let mut ctx = ctx.clone();
        let ctx2 = ctx.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let mut ctx2 = ctx2.clone();
        // A chain of clones that are dropped immediately; `ctx` and `ctx2`
        // must still be cancellable afterwards.
        let _ = ctx2
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone();
        // Deliberately a blocking sleep: it lets the 15 ms deadline pass
        // before the single, non-waiting `poll!` checks below.
        std::thread::sleep(std::time::Duration::from_millis(100));
        if use_cancel {
            cancel.cancel();
        }
        assert!(futures::poll!(ctx.cancelled()).is_ready());
        assert!(futures::poll!(ctx2.cancelled()).is_ready());
    }

    #[async_test]
    async fn chain_cancel() {
        chain(true).await
    }

    #[async_test]
    async fn chain_deadline() {
        chain(false).await
    }

    /// Timeouts, both already expired (0 ms) and short (100 ms), resolve with
    /// `DeadlineExceeded`.
    #[async_test]
    async fn cancel_deadline() {
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(0));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(100));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
    }

    /// Round-trips deadlines through `Timestamp`. Decoded deadlines have no
    /// `Instant`, so equality falls back to comparing wall-clock times.
    /// Pre-epoch and fractional-second values are included.
    #[test]
    fn test_encode_deadline() {
        let check = |deadline: Deadline| {
            let timestamp: super::Timestamp = deadline.into();
            let deadline2: Deadline = timestamp.into();
            assert_eq!(deadline, deadline2);
        };

        check(Deadline::now());
        check(Deadline::now() + std::time::Duration::from_secs(1));
        check(Deadline::now() - std::time::Duration::from_secs(1));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH - std::time::Duration::from_nanos(1_500_000_000),
        ));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_nanos(1_500_000_000),
        ));
    }
}