use super::bidir::Channel;
use super::deadline::DeadlineId;
use super::deadline::DeadlineSet;
use mesh_node::local_node::Port;
use mesh_node::resource::Resource;
use mesh_protobuf::encoding::IgnoreField;
use mesh_protobuf::EncodeAs;
use mesh_protobuf::Protobuf;
use mesh_protobuf::SerializedMessage;
use mesh_protobuf::Timestamp;
use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Weak;
use std::task::Context;
use std::task::Poll;
use std::task::Wake;
use std::time::Duration;
use std::time::Instant;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
/// A cancellation context.
///
/// Tracks whether an operation has been cancelled, either through a cancel
/// source created by [`CancelContext::with_cancel`] or because its optional
/// [`Deadline`] has passed. It derives `Protobuf`, so it can be encoded as a
/// mesh message (using `Resource` as its resource type).
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
pub struct CancelContext {
    // Either the channels whose readiness signals cancellation, or the
    // reason the context has already been cancelled.
    state: CancelState,
    // The earliest deadline applied to this context (see `with_deadline`),
    // encoded on the wire as a `Timestamp`.
    deadline: Option<EncodeAs<Deadline, Timestamp>>,
    // Handle for this context's registration in the global `DeadlineSet`.
    // Not encoded (`Ignore` uses `IgnoreField`) and reset to its default
    // value when the context is cloned.
    deadline_id: Ignore<DeadlineId>,
}
/// Cancellation state of a [`CancelContext`].
#[derive(Debug, Protobuf)]
#[mesh(resource = "Resource")]
enum CancelState {
    // Not cancelled yet. The context counts as cancelled as soon as
    // `poll_recv` on any of these channels returns `Ready` (with either a
    // message or an error).
    NotCancelled { ports: Vec<Channel> },
    // Already cancelled; records why.
    Cancelled(CancelReason),
}
/// Wrapper whose contents are excluded from mesh encoding.
///
/// Its default encoding is `IgnoreField`, so the wrapped value is not
/// serialized. NOTE(review): presumably the value is rebuilt with `Default`
/// on decode (hence the `T: Default` bound); confirm against `IgnoreField`.
#[derive(Debug, Default)]
struct Ignore<T>(T);
impl<T: Default> mesh_protobuf::DefaultEncoding for Ignore<T> {
    type Encoding = IgnoreField;
}
impl Clone for CancelContext {
fn clone(&self) -> Self {
let state = match &self.state {
CancelState::Cancelled(reason) => CancelState::Cancelled(*reason),
CancelState::NotCancelled { ports, .. } => CancelState::NotCancelled {
ports: ports
.iter()
.map(|port| {
let (send, recv) = <Channel>::new_pair();
port.send(SerializedMessage {
data: vec![],
resources: vec![Resource::Port(recv.into())],
});
send
})
.collect(),
},
};
Self {
state,
deadline: self.deadline,
deadline_id: Default::default(),
}
}
}
impl CancelContext {
pub fn new() -> Self {
Self {
state: CancelState::NotCancelled { ports: Vec::new() },
deadline: None,
deadline_id: Default::default(),
}
}
fn add_cancel(&mut self) -> Cancel {
let (send, recv) = Channel::new_pair();
match &mut self.state {
CancelState::Cancelled(_) => {}
CancelState::NotCancelled { ports, .. } => ports.push(send),
}
Cancel::new(recv)
}
pub fn with_cancel(&self) -> (Self, Cancel) {
let mut ctx = self.clone();
let cancel = ctx.add_cancel();
(ctx, cancel)
}
pub fn with_deadline(&self, deadline: Deadline) -> Self {
let mut ctx = self.clone();
ctx.deadline = Some(
self.deadline
.map_or(deadline, |old| old.min(deadline))
.into(),
);
ctx
}
pub fn with_timeout(&self, timeout: Duration) -> Self {
match Deadline::now().checked_add(timeout) {
Some(deadline) => self.with_deadline(deadline),
None => self.clone(),
}
}
pub fn deadline(&self) -> Option<Deadline> {
self.deadline.as_deref().copied()
}
pub fn cancelled(&mut self) -> Cancelled<'_> {
Cancelled(self)
}
pub async fn until_cancelled<F: Future>(&mut self, fut: F) -> Result<F::Output, CancelReason> {
let mut fut = core::pin::pin!(fut);
let mut cancelled = core::pin::pin!(self.cancelled());
std::future::poll_fn(|cx| {
if let Poll::Ready(r) = fut.as_mut().poll(cx) {
return Poll::Ready(Ok(r));
}
if let Poll::Ready(reason) = cancelled.as_mut().poll(cx) {
return Poll::Ready(Err(reason));
}
Poll::Pending
})
.await
}
}
impl Default for CancelContext {
fn default() -> Self {
Self::new()
}
}
/// Future returned by [`CancelContext::cancelled`].
///
/// Resolves to the [`CancelReason`] once the borrowed context is cancelled.
#[must_use]
#[derive(Debug)]
pub struct Cancelled<'a>(&'a mut CancelContext);
/// Why a [`CancelContext`] was cancelled.
#[derive(Debug, Protobuf, Copy, Clone, PartialEq, Eq)]
pub enum CancelReason {
    /// One of the context's cancel channels became ready, for example
    /// because [`Cancel::cancel`] was called.
    Cancelled,
    /// The context's deadline passed.
    DeadlineExceeded,
}

impl std::fmt::Display for CancelReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Cancelled => "cancelled",
            Self::DeadlineExceeded => "deadline exceeded",
        };
        // `pad` honors width/alignment flags from the format string.
        f.pad(text)
    }
}

impl std::error::Error for CancelReason {}
/// Polling checks, in order:
/// 1. a cancellation already recorded in the context;
/// 2. each watched channel, stopping at the first one that is ready;
/// 3. the deadline, if any, via the global `DeadlineSet`.
///
/// The first hit is recorded in the context, so later polls return it
/// immediately.
impl Future for Cancelled<'_> {
    type Output = CancelReason;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let ctx = &mut *self.get_mut().0;
        let reason = match &mut ctx.state {
            CancelState::Cancelled(reason) => return Poll::Ready(*reason),
            CancelState::NotCancelled { ports } => {
                if ports.iter_mut().any(|port| port.poll_recv(cx).is_ready()) {
                    CancelReason::Cancelled
                } else {
                    let expired = match ctx.deadline {
                        Some(deadline) => DeadlineSet::global()
                            .poll(cx, &mut ctx.deadline_id.0, *deadline)
                            .is_ready(),
                        None => false,
                    };
                    if !expired {
                        return Poll::Pending;
                    }
                    CancelReason::DeadlineExceeded
                }
            }
        };
        // Replacing the state drops the watched channels.
        ctx.state = CancelState::Cancelled(reason);
        Poll::Ready(reason)
    }
}
impl Drop for CancelContext {
fn drop(&mut self) {
DeadlineSet::global().remove(&mut self.deadline_id.0);
}
}
/// Handle that cancels the [`CancelContext`] it was created with (see
/// [`CancelContext::with_cancel`]), along with that context's clones.
///
/// It owns the far ends of the context's cancel channels. [`Cancel::cancel`]
/// drops them, and the context treats any readiness on its own side of those
/// channels as cancellation (presumably the peer closing shows up as an
/// error from `poll_recv`). Dropping the last `Arc` to the list likewise
/// drops its channels.
#[derive(Debug)]
pub struct Cancel(Arc<CancelList>);
/// Lock-protected list of channels backing a [`Cancel`].
///
/// Besides its initial channel, it adopts channel ends received over its
/// existing channels; `CancelContext::clone` sends these.
#[derive(Debug)]
struct CancelList {
    ports: Mutex<Vec<Channel>>,
}
impl CancelList {
    /// Polls every channel in the list, registering `cx`'s waker with each.
    ///
    /// Every received message carries channel ends (sent by
    /// `CancelContext::clone`). Any `Port` resources in it are converted to
    /// channels and appended to the list, and they are polled later in this
    /// same pass. Channels whose `poll_recv` yields an error are removed.
    fn poll(&self, cx: &mut Context<'_>) {
        // Declared before the lock guard so that, by reverse drop order, the
        // guard is released before the removed channels are dropped.
        let mut to_drop = Vec::new();
        let mut ports = self.ports.lock();
        let mut i = 0;
        // `ports.len()` is re-read each iteration, so newly appended channels
        // are visited too.
        'outer: while i < ports.len() {
            // Drain every message currently available on channel `i`.
            while let Poll::Ready(message) = ports[i].poll_recv(cx) {
                match message {
                    Ok(message) => {
                        let resources = message.resources;
                        tracing::trace!(count = resources.len(), "adding ports");
                        ports.extend(resources.into_iter().filter_map(|resource| {
                            Port::try_from(resource).ok().map(|port| port.into())
                        }));
                    }
                    Err(_) => {
                        // `swap_remove` moves the last channel into slot `i`,
                        // so revisit `i` without incrementing it.
                        to_drop.push(ports.swap_remove(i));
                        continue 'outer;
                    }
                }
            }
            i += 1;
        }
        if !to_drop.is_empty() {
            tracing::trace!(count = to_drop.len(), "dropping ports");
        }
    }
    /// Removes and returns every channel, leaving the list empty. The lock is
    /// released before the caller drops the returned channels.
    fn drain(&self) -> Vec<Channel> {
        std::mem::take(&mut self.ports.lock())
    }
}
/// Waker that re-polls a [`CancelList`] whenever it is woken.
///
/// It holds only a weak reference, so outstanding wakers do not keep the
/// list alive. Once the list is gone, waking does nothing.
struct ListWaker {
    list: Weak<CancelList>,
}

impl Wake for ListWaker {
    fn wake(self: Arc<Self>) {
        let Some(list) = self.list.upgrade() else {
            return;
        };
        // Poll with this same waker so that every channel (including newly
        // adopted ones) re-arms it for the next signal.
        let waker = self.into();
        list.poll(&mut Context::from_waker(&waker));
    }
}
impl Cancel {
    /// Wraps `port` (the far end of a context's cancel channel) in a new list
    /// and polls it once, so the list's waker is registered from the start.
    fn new(port: Channel) -> Self {
        let list = Arc::new(CancelList {
            ports: Mutex::new(vec![port]),
        });
        Arc::new(ListWaker {
            list: Arc::downgrade(&list),
        })
        .wake();
        Self(list)
    }

    /// Cancels the associated context(s) by dropping every channel currently
    /// in the list.
    pub fn cancel(&mut self) {
        let ports = self.0.drain();
        drop(ports);
    }
}
/// A point in time after which an operation should be abandoned.
///
/// Stores a wall-clock [`SystemTime`] and, when available, a monotonic
/// [`Instant`]. Comparisons and differences use the `Instant`s when both
/// operands have one, and fall back to the `SystemTime`s otherwise.
/// Deadlines built from a `SystemTime` carry no `Instant`.
#[derive(Debug, Copy, Clone, Eq)]
pub struct Deadline {
    system_time: SystemTime,
    instant: Option<Instant>,
}

impl Deadline {
    /// Returns a deadline for the current moment, capturing both clocks.
    pub fn now() -> Self {
        let system_time = SystemTime::now();
        let instant = Some(Instant::now());
        Deadline {
            system_time,
            instant,
        }
    }

    /// Returns the monotonic instant, if this deadline has one.
    pub fn instant(&self) -> Option<Instant> {
        self.instant
    }

    /// Returns the wall-clock time of this deadline.
    pub fn system_time(&self) -> SystemTime {
        self.system_time
    }

    /// Returns this deadline moved `duration` later. Returns `None` if the
    /// wall-clock time overflows.
    ///
    /// If only the `Instant` overflows, the result has no `Instant`.
    pub fn checked_add(&self, duration: Duration) -> Option<Self> {
        let system_time = self.system_time.checked_add(duration)?;
        Some(Deadline {
            system_time,
            instant: self.instant.and_then(|at| at.checked_add(duration)),
        })
    }
}

/// # Panics
///
/// Panics if the wall-clock time overflows; use [`Deadline::checked_add`]
/// to handle that case instead.
impl std::ops::Add<Duration> for Deadline {
    type Output = Self;

    fn add(self, rhs: Duration) -> Self::Output {
        self.checked_add(rhs)
            .expect("overflow when adding duration to deadline")
    }
}

/// Moves the deadline `rhs` earlier without panicking on underflow. An
/// underflowing wall-clock time becomes [`UNIX_EPOCH`], and an underflowing
/// `Instant` is dropped.
impl std::ops::Sub<Duration> for Deadline {
    type Output = Deadline;

    fn sub(self, rhs: Duration) -> Self::Output {
        let instant = self.instant.and_then(|at| at.checked_sub(rhs));
        let system_time = match self.system_time.checked_sub(rhs) {
            Some(earlier) => earlier,
            None => UNIX_EPOCH,
        };
        Deadline {
            system_time,
            instant,
        }
    }
}

/// Returns how far `self` lies after `rhs`, or zero if it does not.
impl std::ops::Sub<Deadline> for Deadline {
    type Output = Duration;

    fn sub(self, rhs: Deadline) -> Self::Output {
        match (self.instant, rhs.instant) {
            (Some(later), Some(earlier)) => later.saturating_duration_since(earlier),
            _ => self
                .system_time
                .duration_since(rhs.system_time)
                .unwrap_or(Duration::ZERO),
        }
    }
}

impl PartialEq for Deadline {
    fn eq(&self, other: &Self) -> bool {
        match (self.instant, other.instant) {
            (Some(a), Some(b)) => a == b,
            _ => self.system_time == other.system_time,
        }
    }
}

impl PartialOrd for Deadline {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders by `Instant` when both deadlines have one, otherwise by wall-clock
/// time. Because the clock used depends on the operands, ordering across a
/// mix of deadlines with and without an `Instant` may not be transitive.
impl Ord for Deadline {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match (self.instant, other.instant) {
            (Some(a), Some(b)) => a.cmp(&b),
            _ => self.system_time.cmp(&other.system_time),
        }
    }
}

/// Builds a deadline from a wall-clock time alone; it has no `Instant`.
impl From<SystemTime> for Deadline {
    fn from(system_time: SystemTime) -> Self {
        Deadline {
            system_time,
            instant: None,
        }
    }
}
impl From<Deadline> for Timestamp {
fn from(deadline: Deadline) -> Self {
deadline.system_time.into()
}
}
impl From<Timestamp> for Deadline {
fn from(timestamp: Timestamp) -> Self {
Self {
system_time: timestamp.try_into().unwrap_or(UNIX_EPOCH),
instant: None,
}
}
}
#[cfg(test)]
mod tests {
    use super::CancelContext;
    use super::CancelReason;
    use super::Deadline;
    use pal_async::async_test;
    use test_with_tracing::test;

    // A fresh context has no cancel source and no deadline, so its
    // `cancelled()` future stays pending.
    #[async_test]
    async fn no_cancel() {
        assert!(futures::poll!(CancelContext::new().cancelled()).is_pending());
    }

    // Calling `Cancel::cancel` makes the paired context's `cancelled()`
    // future ready.
    #[async_test]
    async fn basic_cancel() {
        let (mut ctx, mut cancel) = CancelContext::new().with_cancel();
        cancel.cancel();
        assert!(futures::poll!(ctx.cancelled()).is_ready());
    }

    // Shared body for `chain_cancel` / `chain_deadline`.
    //
    // Builds long chains of clones, including a chain of temporaries that
    // are dropped at the end of their statement. It then checks that both
    // surviving leaves (`ctx`, `ctx2`) observe cancellation: through
    // `cancel.cancel()` when `use_cancel` is true, otherwise through a 15 ms
    // timeout. The clones are the point of the test, hence the
    // `redundant_clone` allow.
    //
    // NOTE(review): `std::thread::sleep` blocks the executor thread. The
    // 100 ms sleep presumably gives forwarded ports (and, in the deadline
    // case, the timeout) time to take effect before the single poll below.
    // Confirm this is not timing-sensitive on slow machines.
    #[allow(clippy::redundant_clone)]
    async fn chain(use_cancel: bool) {
        let ctx = CancelContext::new();
        let (mut ctx, mut cancel) = ctx.with_cancel();
        if !use_cancel {
            ctx = ctx.with_timeout(std::time::Duration::from_millis(15));
        }
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let ctx = ctx.clone();
        let mut ctx = ctx.clone();
        let ctx2 = ctx.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let ctx2 = ctx2.clone();
        let mut ctx2 = ctx2.clone();
        // Each intermediate clone here is a temporary that is dropped at the
        // end of the statement.
        let _ = ctx2
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone()
            .clone();
        std::thread::sleep(std::time::Duration::from_millis(100));
        if use_cancel {
            cancel.cancel();
        }
        assert!(futures::poll!(ctx.cancelled()).is_ready());
        assert!(futures::poll!(ctx2.cancelled()).is_ready());
    }

    // Cancellation propagates through clone chains via `Cancel::cancel`.
    #[async_test]
    async fn chain_cancel() {
        chain(true).await
    }

    // Cancellation propagates through clone chains via a timeout.
    #[async_test]
    async fn chain_deadline() {
        chain(false).await
    }

    // Both an immediate (0 ms) and a short (100 ms) timeout resolve with
    // `DeadlineExceeded`.
    #[async_test]
    async fn cancel_deadline() {
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(0));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
        let mut ctx = CancelContext::new().with_timeout(std::time::Duration::from_millis(100));
        assert_eq!(ctx.cancelled().await, CancelReason::DeadlineExceeded);
    }

    // Round-trips deadlines through `Timestamp`. Decoding drops the
    // `Instant`, so equality here compares wall-clock times. Covers now,
    // +/- 1 s, and points 1.5 s before and after the Unix epoch.
    #[test]
    fn test_encode_deadline() {
        let check = |deadline: Deadline| {
            let timestamp: super::Timestamp = deadline.into();
            let deadline2: Deadline = timestamp.into();
            assert_eq!(deadline, deadline2);
        };
        check(Deadline::now());
        check(Deadline::now() + std::time::Duration::from_secs(1));
        check(Deadline::now() - std::time::Duration::from_secs(1));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH - std::time::Duration::from_nanos(1_500_000_000),
        ));
        check(Deadline::from(
            std::time::SystemTime::UNIX_EPOCH + std::time::Duration::from_nanos(1_500_000_000),
        ));
    }
}