use crate::driver::Driver;
use crate::driver::PollImpl;
use crate::sparsevec::SparseVec;
use crate::waker::WakerList;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use std::time::Duration;
/// A point in time, stored as a raw nanosecond count obtained from
/// `crate::sys::monotonic_nanos_now`.
///
/// Equality and ordering compare the underlying nanosecond values directly.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Instant(u64);

impl Instant {
    /// Reads the current time from `crate::sys::monotonic_nanos_now`.
    pub fn now() -> Self {
        let nanos = crate::sys::monotonic_nanos_now();
        Self::from_nanos(nanos)
    }

    /// Returns the raw nanosecond value of this instant.
    pub fn as_nanos(&self) -> u64 {
        let Self(nanos) = *self;
        nanos
    }

    /// Builds an instant from a raw nanosecond value.
    pub fn from_nanos(nanos: u64) -> Self {
        Self(nanos)
    }

    /// Adds `duration`, clamping at `u64::MAX` nanoseconds instead of overflowing.
    ///
    /// A duration whose nanosecond count does not fit in a `u64` is treated
    /// as `u64::MAX` nanoseconds.
    pub fn saturating_add(self, duration: Duration) -> Self {
        let delta = u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX);
        Self(self.0.saturating_add(delta))
    }
}
impl std::ops::Sub for Instant {
    type Output = Duration;

    /// Returns the time elapsed from `rhs` to `self`.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` is later than `self`.
    fn sub(self, rhs: Instant) -> Self::Output {
        match self.0.checked_sub(rhs.0) {
            Some(nanos) => Duration::from_nanos(nanos),
            None => panic!("supplied instant {:#x} is later than {:#x}", rhs.0, self.0),
        }
    }
}

impl std::ops::Add<Duration> for Instant {
    type Output = Instant;

    /// Returns the instant `rhs` after `self`.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` does not fit in a `u64` nanosecond count, or if the
    /// sum overflows `u64`.
    fn add(self, rhs: Duration) -> Self::Output {
        let delta: u64 = rhs.as_nanos().try_into().expect("duration too large");
        let sum = self.0.checked_add(delta);
        Self(sum.expect("supplied duration causes overflow"))
    }
}

impl std::ops::Sub<Duration> for Instant {
    type Output = Instant;

    /// Returns the instant `rhs` before `self`.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` does not fit in a `u64` nanosecond count, or if it is
    /// larger than `self`'s nanosecond value.
    fn sub(self, rhs: Duration) -> Self::Output {
        let delta: u64 = rhs.as_nanos().try_into().expect("duration too large");
        let difference = self.0.checked_sub(delta);
        Self(difference.expect("supplied instant is later than self"))
    }
}
/// Creates timers that can be polled for expiry.
pub trait TimerDriver: Unpin {
    /// The concrete timer type returned by [`TimerDriver::new_timer`].
    type Timer: PollTimer + 'static;

    /// Creates a new timer owned by the caller.
    fn new_timer(&self) -> Self::Timer;
}
/// A timer that can be given a deadline and polled for expiry.
///
/// Implementations must be `Send`, `Sync`, and `Unpin`.
pub trait PollTimer: Send + Sync + Unpin {
    /// Polls the timer.
    ///
    /// When `deadline` is `Some`, it is the deadline the caller is waiting
    /// for; [`Sleep`] passes the deadline it was created with. Resolves to an
    /// [`Instant`] once the timer fires.
    /// NOTE(review): the meaning of `None` is left to implementations — confirm.
    fn poll_timer(&mut self, cx: &mut Context<'_>, deadline: Option<Instant>) -> Poll<Instant>;

    /// Sets the deadline at which the timer should fire.
    fn set_deadline(&mut self, deadline: Instant);
}
/// A type-erased timer obtained from a [`Driver`], used to create [`Sleep`] futures.
pub struct PolledTimer(PollImpl<dyn PollTimer>);

impl PolledTimer {
    /// Creates a new timer backed by `driver`.
    pub fn new(driver: &(impl ?Sized + Driver)) -> Self {
        Self(driver.new_dyn_timer())
    }

    /// Returns a future that completes once `duration` has elapsed from now.
    ///
    /// The deadline saturates instead of panicking. A very large duration
    /// (for example `Duration::MAX`) gives a deadline of `u64::MAX`
    /// nanoseconds, so the timer effectively never fires.
    pub fn sleep(&mut self, duration: Duration) -> Sleep<'_> {
        // `Instant + Duration` panics when the duration does not fit in u64
        // nanoseconds or the sum overflows. Saturate so that "sleep forever"
        // durations are accepted.
        self.sleep_until(Instant::now().saturating_add(duration))
    }

    /// Returns a future that completes once `deadline` is reached.
    ///
    /// The deadline is passed to the underlying timer right away, before the
    /// returned future is first polled.
    pub fn sleep_until(&mut self, deadline: Instant) -> Sleep<'_> {
        self.0.set_deadline(deadline);
        Sleep {
            timer: self,
            deadline,
        }
    }
}
/// Future returned by [`PolledTimer::sleep`] and [`PolledTimer::sleep_until`].
///
/// Resolves to the [`Instant`] returned by the underlying timer's
/// `poll_timer` once it becomes ready.
#[must_use]
pub struct Sleep<'a> {
    timer: &'a mut PolledTimer,
    deadline: Instant,
}

impl Future for Sleep<'_> {
    type Output = Instant;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Sleep` holds only a reference and a `Copy` value, so it is `Unpin`
        // and can be taken out of the pin directly.
        let this = self.get_mut();
        this.timer.0.poll_timer(cx, Some(this.deadline))
    }
}
/// Bookkeeping for one timer stored in a [`TimerQueue`].
#[derive(Debug)]
struct TimerEntry {
    /// The instant at or after which the timer counts as expired.
    deadline: Instant,
    /// Waker stored by the most recent pending poll, if it has not been taken yet.
    waker: Option<Waker>,
}

/// A collection of timers, each addressed by a [`TimerQueueId`].
#[derive(Default, Debug)]
pub(crate) struct TimerQueue {
    timers: SparseVec<TimerEntry>,
}

/// Handle to a timer registered in a [`TimerQueue`].
#[derive(Clone, Copy, Debug)]
pub(crate) struct TimerQueueId(usize);

/// Outcome of polling one timer in a [`TimerQueue`].
pub(crate) enum TimerResult {
    /// The deadline has passed; carries the current time observed by the poll.
    TimedOut(Instant),
    /// The deadline is still in the future; carries that deadline.
    Pending(Instant),
}
impl TimerQueue {
    /// Registers a new timer and returns its handle.
    ///
    /// The timer starts with a deadline of zero nanoseconds and no waker, so
    /// polling it reports a timeout until a later deadline is set.
    pub fn add(&mut self) -> TimerQueueId {
        let entry = TimerEntry {
            deadline: Instant::from_nanos(0),
            waker: None,
        };
        let index = self.timers.add(entry);
        TimerQueueId(index)
    }

    /// Removes timer `id` and returns the waker still stored on it, if any.
    #[must_use]
    pub fn remove(&mut self, id: TimerQueueId) -> Option<Waker> {
        let removed = self.timers.remove(id.0);
        removed.waker
    }

    /// Checks whether timer `id` has reached its deadline.
    ///
    /// If it has, returns [`TimerResult::TimedOut`] with the current time and
    /// leaves any stored waker in place. Otherwise, stores `cx`'s waker on the
    /// entry (refreshing any existing one) and returns
    /// [`TimerResult::Pending`] with the outstanding deadline.
    pub fn poll_deadline(&mut self, cx: &mut Context<'_>, id: TimerQueueId) -> TimerResult {
        let entry = &mut self.timers[id.0];
        let now = Instant::now();
        if entry.deadline > now {
            let current = cx.waker();
            match &mut entry.waker {
                // `clone_from` lets `Waker` reuse the stored value when it
                // already wakes the same task.
                Some(stored) => stored.clone_from(current),
                None => entry.waker = Some(current.clone()),
            }
            return TimerResult::Pending(entry.deadline);
        }
        TimerResult::TimedOut(now)
    }

    /// Replaces the deadline of timer `id`.
    ///
    /// Returns `true` only when a waker is stored on the timer and the new
    /// deadline is strictly earlier than the previous one. Presumably callers
    /// use this to decide whether to recompute their next wakeup — confirm.
    pub fn set_deadline(&mut self, id: TimerQueueId, deadline: Instant) -> bool {
        let entry = &mut self.timers[id.0];
        let previous = std::mem::replace(&mut entry.deadline, deadline);
        entry.waker.is_some() && deadline < previous
    }

    /// Takes the waker from every timer whose deadline has passed and appends
    /// those wakers to `wakers`.
    ///
    /// The current time is read at most once, and only if some timer has a
    /// stored waker.
    pub fn wake_expired(&mut self, wakers: &mut WakerList) {
        let mut cached_now: Option<Instant> = None;
        let expired = self.timers.iter_mut().filter_map(|(_, entry)| {
            if entry.waker.is_none() {
                return None;
            }
            let current = *cached_now.get_or_insert_with(Instant::now);
            if entry.deadline <= current {
                entry.waker.take()
            } else {
                None
            }
        });
        wakers.extend(expired);
    }

    /// Returns the earliest deadline among timers that have a stored waker,
    /// or `None` if no timer has one.
    pub fn next_deadline(&self) -> Option<Instant> {
        self.timers
            .iter()
            .filter(|(_, entry)| entry.waker.is_some())
            .map(|(_, entry)| entry.deadline)
            .min()
    }
}
#[cfg(test)]
mod tests {
    use super::Instant;
    use std::time::Duration;

    /// `Instant::now` advances by at least the time slept, and not wildly more.
    #[test]
    fn test_instant() {
        let start = Instant::now();
        std::thread::sleep(Duration::from_millis(100));
        let end = Instant::now();
        assert!(end - start >= Duration::from_millis(100));
        assert!(end - start < Duration::from_millis(400));
    }

    /// Adding and subtracting durations round-trips, and `saturating_add`
    /// clamps at the maximum instead of panicking.
    #[test]
    fn test_instant_arithmetic() {
        let base = Instant::from_nanos(1_000);
        let delta = Duration::from_nanos(250);
        assert_eq!((base + delta).as_nanos(), 1_250);
        assert_eq!((base - delta).as_nanos(), 750);
        assert_eq!((base + delta) - base, delta);
        assert_eq!(base.saturating_add(delta).as_nanos(), 1_250);
        assert_eq!(base.saturating_add(Duration::MAX).as_nanos(), u64::MAX);
    }

    /// Subtracting a later instant is a caller bug and panics.
    #[test]
    #[should_panic(expected = "is later than")]
    fn test_instant_sub_later_panics() {
        let _ = Instant::from_nanos(1) - Instant::from_nanos(2);
    }

    /// Adding a duration that overflows the nanosecond counter panics.
    #[test]
    #[should_panic(expected = "supplied duration causes overflow")]
    fn test_instant_add_overflow_panics() {
        let _ = Instant::from_nanos(u64::MAX) + Duration::from_nanos(1);
    }
}