1use super::threadpool::Io;
7use super::threadpool::IoInitiator;
8use futures::FutureExt;
9use io_uring::opcode;
10use io_uring::types::TimeoutFlags;
11use io_uring::types::Timespec;
12use pal_async::fd::FdReadyDriver;
13use pal_async::fd::PollFdReady;
14use pal_async::interest::InterestSlot;
15use pal_async::interest::PollEvents;
16use pal_async::interest::SLOT_COUNT;
17use pal_async::timer::Instant;
18use pal_async::timer::PollTimer;
19use pal_async::timer::TimerDriver;
20use pal_async::wait::MAXIMUM_WAIT_READ_SIZE;
21use pal_async::wait::PollWait;
22use pal_async::wait::WaitDriver;
23use std::fmt::Debug;
24use std::io;
25use std::os::unix::prelude::*;
26use std::sync::OnceLock;
27use std::task::Context;
28use std::task::Poll;
29use std::task::Waker;
30
31pub trait Initiate: 'static + Send + Sync + Unpin {
34 fn initiator(&self) -> &IoInitiator;
39}
40
41impl Initiate for IoInitiator {
42 fn initiator(&self) -> &IoInitiator {
43 self
44 }
45}
46
47#[derive(Debug)]
49pub struct FdReady<T: Initiate> {
50 fd: RawFd,
51 initiator: T,
52 interests: [Interest; SLOT_COUNT],
53}
54
55impl<T: Initiate> FdReady<T> {
56 pub fn new(initiator: T, fd: RawFd) -> Self {
58 FdReady {
59 fd,
60 initiator,
61 interests: Default::default(),
62 }
63 }
64}
65
66impl FdReadyDriver for IoInitiator {
67 type FdReady = FdReady<Self>;
68
69 fn new_fd_ready(&self, fd: RawFd) -> io::Result<Self::FdReady> {
70 Ok(FdReady::new(self.clone(), fd))
71 }
72}
73
74#[derive(Debug, Default)]
75struct Interest {
76 io: Option<Io<()>>,
77 cancelled: bool,
78 events: PollEvents,
79 revents: PollEvents,
80}
81
impl<T: Initiate> PollFdReady for FdReady<T> {
    /// Polls for any of `events` on the fd, using the state of `slot`.
    ///
    /// Returns `Ready` with `revents & events` as soon as a previously observed (and not yet
    /// cleared) event intersects the request. Otherwise makes sure an io_uring `PollAdd` for
    /// `events` is in flight and returns `Pending` until it completes.
    ///
    /// # Panics
    ///
    /// Panics if the poll completes with any error other than `ECANCELED`.
    fn poll_fd_ready(
        &mut self,
        cx: &mut Context<'_>,
        slot: InterestSlot,
        events: PollEvents,
    ) -> Poll<PollEvents> {
        let interest = &mut self.interests[slot as usize];
        loop {
            // Already-observed events satisfy the request: report them without new IO.
            if !(interest.revents & events).is_empty() {
                break Poll::Ready(interest.revents & events);
            } else if let Some(io) = &mut interest.io {
                // The in-flight poll was submitted for a set that does not cover everything
                // requested now. Cancel it (only once) so that, after it completes, the loop
                // resubmits with the current `events`.
                if interest.events & events != events && !interest.cancelled {
                    io.cancel_poll();
                    interest.cancelled = true;
                }
                let result = std::task::ready!(io.poll_unpin(cx));
                interest.io = None;
                match result {
                    // Accumulate the returned mask; the next iteration re-checks it.
                    Ok(poll_revents) => {
                        interest.revents |= PollEvents::from_poll_events(poll_revents as i16);
                    }
                    // Cancelled (e.g. by the `cancel_poll` above): resubmit on the next iteration.
                    Err(err) if err.raw_os_error() == Some(libc::ECANCELED) => {}
                    Err(err) => panic!("poll failed: {}", err),
                }
            } else {
                // No poll in flight: submit one for exactly the requested events.
                interest.events = events;
                let sqe = opcode::PollAdd::new(
                    io_uring::types::Fd(self.fd),
                    events.to_poll_events() as u32,
                )
                .build();
                // SAFETY (review note): the PollAdd SQE references no caller-owned memory, only
                // the fd and the event mask. Presumably this relies on `self.fd` staying open
                // while the IO is in flight -- confirm against `Io::new`'s safety contract.
                let io = unsafe { Io::new(self.initiator.initiator().clone(), sqe, ()) };
                interest.io = Some(io);
                interest.cancelled = false;
            }
        }
    }

    /// Discards the events accumulated for `slot`, so the next `poll_fd_ready` call must
    /// observe fresh readiness.
    fn clear_fd_ready(&mut self, slot: InterestSlot) {
        let interest = &mut self.interests[slot as usize];
        interest.revents = PollEvents::EMPTY;
    }
}
133
134#[derive(Debug)]
136pub struct FdWait<T: Initiate> {
137 inner: FdWaitInner<T>,
138}
139
140#[derive(Debug)]
141enum FdWaitInner<T: Initiate> {
142 ViaPoll(pal_async::unix::FdWait<FdReady<T>>),
143 ViaRead(FdWaitViaRead<T>),
144}
145
146impl WaitDriver for IoInitiator {
147 type Wait = FdWait<Self>;
148
149 fn new_wait(&self, fd: RawFd, read_size: usize) -> io::Result<Self::Wait> {
150 Ok(FdWait::new(self.clone(), fd, read_size))
151 }
152}
153
154impl<T: Initiate> FdWait<T> {
155 pub fn new(initiator: T, fd: RawFd, read_size: usize) -> Self {
157 static SUPPORTS_NONBLOCK_READ: OnceLock<bool> = OnceLock::new();
158 const LINKAT: u8 = 39;
162 let supports_nonblock_read =
163 *SUPPORTS_NONBLOCK_READ.get_or_init(|| initiator.initiator().probe(LINKAT));
164
165 let inner = if supports_nonblock_read {
166 assert!(read_size <= MAXIMUM_WAIT_READ_SIZE);
167 FdWaitInner::ViaRead(FdWaitViaRead {
168 fd,
169 read_size,
170 initiator,
171 state: FdWaitViaReadState::Idle(Box::new(0)),
172 })
173 } else {
174 FdWaitInner::ViaPoll(pal_async::unix::FdWait::new(
175 fd,
176 FdReady::new(initiator, fd),
177 read_size,
178 ))
179 };
180 FdWait { inner }
181 }
182}
183
184impl<T: Initiate> PollWait for FdWait<T> {
185 fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
186 match &mut self.inner {
187 FdWaitInner::ViaPoll(wait) => wait.poll_wait(cx),
188 FdWaitInner::ViaRead(wait) => wait.poll_wait(cx),
189 }
190 }
191
192 fn poll_cancel_wait(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
193 match &mut self.inner {
194 FdWaitInner::ViaPoll(wait) => wait.poll_cancel_wait(cx),
195 FdWaitInner::ViaRead(wait) => wait.poll_cancel_wait(cx),
196 }
197 }
198}
199
/// Waiter whose wait completes when an io_uring read of `read_size` bytes from `fd` succeeds.
#[derive(Debug)]
struct FdWaitViaRead<T: Initiate> {
    fd: RawFd,
    /// Bytes read per wait; must fit in the 8-byte `u64` read buffer (asserted in `poll_wait`).
    read_size: usize,
    initiator: T,
    state: FdWaitViaReadState,
}

/// State machine for [`FdWaitViaRead`].
#[derive(Debug)]
enum FdWaitViaReadState {
    /// No read in flight; holds the heap buffer the next read will target.
    Idle(Box<u64>),
    /// A read is in flight and the `Io` owns the buffer (returned via `into_mem`);
    /// `cancelling` records whether cancellation has already been requested.
    ReadPending { io: Io<Box<u64>>, cancelling: bool },
    /// Transient placeholder left while a poll method has moved the real state out.
    Invalid,
}
214
215impl<T: Initiate> PollWait for FdWaitViaRead<T> {
216 fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
217 loop {
218 match std::mem::replace(&mut self.state, FdWaitViaReadState::Invalid) {
219 FdWaitViaReadState::Idle(mut buf) => {
220 assert!(self.read_size <= 8);
221 let sqe = opcode::Read::new(
222 io_uring::types::Fd(self.fd),
223 std::ptr::from_mut(&mut *buf).cast(),
224 self.read_size as u32,
225 )
226 .build();
227 let io = unsafe { Io::new(self.initiator.initiator().clone(), sqe, buf) };
230 self.state = FdWaitViaReadState::ReadPending {
231 io,
232 cancelling: false,
233 };
234 }
235 FdWaitViaReadState::ReadPending { mut io, cancelling } => match io.poll_unpin(cx) {
236 Poll::Ready(r) => {
237 self.state = FdWaitViaReadState::Idle(io.into_mem());
238 match r {
239 Ok(_) => break Poll::Ready(Ok(())),
240 Err(err) if err.raw_os_error() == Some(libc::ECANCELED) => {}
241 Err(err) => return Poll::Ready(Err(err)),
242 }
243 }
244 Poll::Pending => {
245 self.state = FdWaitViaReadState::ReadPending { io, cancelling };
246 return Poll::Pending;
247 }
248 },
249 FdWaitViaReadState::Invalid => unreachable!(),
250 }
251 }
252 }
253
254 fn poll_cancel_wait(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
255 loop {
256 match std::mem::replace(&mut self.state, FdWaitViaReadState::Invalid) {
257 FdWaitViaReadState::Idle(buf) => {
258 self.state = FdWaitViaReadState::Idle(buf);
259 break Poll::Ready(false);
260 }
261 FdWaitViaReadState::ReadPending { mut io, cancelling } => {
262 if cancelling {
263 match io.poll_unpin(cx) {
264 Poll::Ready(r) => {
265 self.state = FdWaitViaReadState::Idle(io.into_mem());
266 break Poll::Ready(r.is_ok());
273 }
274 Poll::Pending => {
275 self.state = FdWaitViaReadState::ReadPending { io, cancelling };
276 break Poll::Pending;
277 }
278 }
279 } else {
280 io.cancel();
281 self.state = FdWaitViaReadState::ReadPending {
282 io,
283 cancelling: true,
284 };
285 }
286 }
287 FdWaitViaReadState::Invalid => unreachable!(),
288 }
289 }
290 }
291}
292
293impl<T: Initiate> Drop for FdWaitViaRead<T> {
294 fn drop(&mut self) {
295 let _ = self.poll_cancel_wait(&mut Context::from_waker(Waker::noop()));
296 }
297}
298
299#[derive(Debug)]
301pub struct Timer<T: Initiate> {
302 initiator: T,
303 target_deadline: Instant,
304 state: Option<TimerState>,
305}
306
307impl<T: Initiate> Timer<T> {
308 pub fn new(initiator: T) -> Self {
310 Timer {
311 initiator,
312 target_deadline: Instant::from_nanos(0),
313 state: None,
314 }
315 }
316}
317
318#[derive(Debug)]
319struct TimerState {
320 io: Io<Box<Timespec>>,
321 cancelled: bool,
322}
323
324impl TimerDriver for IoInitiator {
325 type Timer = Timer<Self>;
326
327 fn new_timer(&self) -> Self::Timer {
328 Timer::new(self.clone())
329 }
330}
331
332impl<T: Initiate> PollTimer for Timer<T> {
333 fn poll_timer(&mut self, cx: &mut Context<'_>, deadline: Option<Instant>) -> Poll<Instant> {
334 if let Some(deadline) = deadline {
335 self.set_deadline(deadline);
336 }
337 loop {
338 let now = Instant::now();
339 if self.target_deadline <= now {
340 break Poll::Ready(now);
341 } else if let Some(state) = &mut self.state {
342 let _ = std::task::ready!(state.io.poll_unpin(cx));
343 self.state = None;
344 } else {
345 let absolute_timeout = self.target_deadline - Instant::from_nanos(0);
348 let timespec = Box::new(
349 Timespec::new()
350 .sec(absolute_timeout.as_secs())
351 .nsec(absolute_timeout.subsec_nanos()),
352 );
353 let sqe = {
354 opcode::Timeout::new(&*timespec)
355 .flags(TimeoutFlags::ABS)
356 .build()
357 };
358 let io = unsafe { Io::new(self.initiator.initiator().clone(), sqe, timespec) };
361 let state = TimerState {
362 io,
363 cancelled: false,
364 };
365 self.state = Some(state);
366 }
367 }
368 }
369
370 fn set_deadline(&mut self, deadline: Instant) {
371 if let Some(state) = &mut self.state {
372 if self.target_deadline > deadline && !state.cancelled {
378 state.io.cancel_timeout();
379 state.cancelled = true;
380 }
381 }
382 self.target_deadline = deadline;
383 }
384}
385
#[cfg(test)]
pub(crate) mod tests {
    use crate::IoInitiator;
    use crate::IoUringPool;
    use futures::executor::block_on;
    use once_cell::sync::OnceCell;
    use pal_async::executor_tests;
    use pal_async::task::Spawn;
    use std::future::Future;
    use std::io;
    use std::thread::JoinHandle;

    /// An io_uring pool driven by its own background thread, for tests.
    pub struct SingleThreadPool {
        _thread: JoinHandle<()>,
        initiator: IoInitiator,
    }

    impl SingleThreadPool {
        /// Creates the "test" pool (`16` is presumably the ring size) and runs it on a new
        /// thread. The thread is never joined.
        pub fn new() -> io::Result<Self> {
            let pool = IoUringPool::new("test", 16)?;
            let initiator = pool.client().initiator().clone();
            Ok(Self {
                _thread: std::thread::spawn(move || pool.run()),
                initiator,
            })
        }

        /// The initiator used to submit work to the pool.
        pub fn initiator(&self) -> &IoInitiator {
            &self.initiator
        }
    }

    /// Returns the process-wide test pool, creating it on first successful call.
    fn test_pool() -> io::Result<&'static SingleThreadPool> {
        static POOL: OnceCell<SingleThreadPool> = OnceCell::new();
        POOL.get_or_try_init(SingleThreadPool::new)
    }

    /// Evaluates to the shared test pool. If creation fails with `ENOSYS` (no io_uring), it
    /// prints a notice and returns from the enclosing test; any other error panics.
    macro_rules! get_pool_or_skip {
        () => {
            match test_pool() {
                Ok(pool) => pool,
                Err(err) if err.raw_os_error() == Some(libc::ENOSYS) => {
                    println!("Test case skipped (no IO-Uring support)");
                    return;
                }
                Err(err) => panic!("{}", err),
            }
        };
    }

    /// Spawns `fut` on the pool and blocks the calling thread until it finishes.
    fn run_until<F>(pool: &SingleThreadPool, fut: F) -> F::Output
    where
        F: 'static + Future + Send,
        F::Output: Send,
    {
        let task = pool.initiator().spawn("test", fut);
        block_on(task)
    }

    #[test]
    fn waker_works() {
        let pool = get_pool_or_skip!();
        run_until(pool, executor_tests::waker_tests());
    }

    #[test]
    fn spawn_works() {
        let pool = get_pool_or_skip!();
        executor_tests::spawn_tests(|| (pool.initiator(), || ()))
    }

    #[test]
    fn sleep_works() {
        let pool = get_pool_or_skip!();
        let initiator = pool.initiator().clone();
        run_until(pool, executor_tests::sleep_tests(initiator));
    }

    #[test]
    fn wait_works() {
        let pool = get_pool_or_skip!();
        let initiator = pool.initiator().clone();
        run_until(pool, executor_tests::wait_tests(initiator));
    }

    #[test]
    fn socket_works() {
        let pool = get_pool_or_skip!();
        let initiator = pool.initiator().clone();
        run_until(pool, executor_tests::socket_tests(initiator));
    }
}