pal_uring/uring.rs

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

//! Driver implementation for the `pal` crate's io-uring threadpool.

use super::threadpool::Io;
use super::threadpool::IoInitiator;
use futures::FutureExt;
use io_uring::opcode;
use io_uring::types::TimeoutFlags;
use io_uring::types::Timespec;
use pal_async::fd::FdReadyDriver;
use pal_async::fd::PollFdReady;
use pal_async::interest::InterestSlot;
use pal_async::interest::PollEvents;
use pal_async::interest::SLOT_COUNT;
use pal_async::timer::Instant;
use pal_async::timer::PollTimer;
use pal_async::timer::TimerDriver;
use pal_async::wait::MAXIMUM_WAIT_READ_SIZE;
use pal_async::wait::PollWait;
use pal_async::wait::WaitDriver;
use std::fmt::Debug;
use std::io;
use std::os::unix::prelude::*;
use std::sync::OnceLock;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;

/// An object that can be used to initiate an IO, by returning a reference to an
/// [`IoInitiator`].
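///
/// As an illustrative sketch (not an item in this crate), a wrapper type that
/// owns an initiator and can later be retargeted to a different ring might
/// implement this as:
///
/// ```ignore
/// struct Retargetable {
///     current: IoInitiator,
/// }
///
/// impl Initiate for Retargetable {
///     fn initiator(&self) -> &IoInitiator {
///         &self.current
///     }
/// }
/// ```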
pub trait Initiate: 'static + Send + Sync + Unpin {
    /// Returns a reference to the initiator to use for IO operations.
    ///
    /// A different initiator may be returned each time this is called, allowing
    /// an object (timer, socket, etc.) to be moved between initiators.
    fn initiator(&self) -> &IoInitiator;
}

impl Initiate for IoInitiator {
    fn initiator(&self) -> &IoInitiator {
        self
    }
}

/// A [`pal_async::fd::PollFdReady`] implementation for io_uring.
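///
/// Each interest slot tracks its own `IORING_OP_POLL_ADD` submission and the
/// events observed for it, so the slots can be polled independently.
///
/// ```ignore
/// // Illustrative sketch only; the slot and event names here are assumptions
/// // about pal_async's interest API rather than something defined in this file.
/// let mut ready = initiator.new_fd_ready(fd)?;
/// let revents = std::future::poll_fn(|cx| {
///     ready.poll_fd_ready(cx, InterestSlot::Read, PollEvents::IN)
/// })
/// .await;
/// ```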
#[derive(Debug)]
pub struct FdReady<T: Initiate> {
    fd: RawFd,
    initiator: T,
    interests: [Interest; SLOT_COUNT],
}

impl<T: Initiate> FdReady<T> {
    /// Creates a new `FdReady` for the given file descriptor and initiator.
    pub fn new(initiator: T, fd: RawFd) -> Self {
        FdReady {
            fd,
            initiator,
            interests: Default::default(),
        }
    }
}

impl FdReadyDriver for IoInitiator {
    type FdReady = FdReady<Self>;

    fn new_fd_ready(&self, fd: RawFd) -> io::Result<Self::FdReady> {
        Ok(FdReady::new(self.clone(), fd))
    }
}

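/// Per-slot poll state: the in-flight poll operation (if any), whether a
/// cancel has been issued for it, the events it was submitted with, and the
/// events observed so far but not yet cleared.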
#[derive(Debug, Default)]
struct Interest {
    io: Option<Io<()>>,
    cancelled: bool,
    events: PollEvents,
    revents: PollEvents,
}

impl<T: Initiate> PollFdReady for FdReady<T> {
    fn poll_fd_ready(
        &mut self,
        cx: &mut Context<'_>,
        slot: InterestSlot,
        events: PollEvents,
    ) -> Poll<PollEvents> {
        let interest = &mut self.interests[slot as usize];
        loop {
            if !(interest.revents & events).is_empty() {
                break Poll::Ready(interest.revents & events);
            } else if let Some(io) = &mut interest.io {
                // Cancel the current operation if not all the requested events
                // are included in the current IO.
                //
                // FUTURE: just update the current poll operation. This requires
                // >= Linux 5.11.
                if interest.events & events != events && !interest.cancelled {
                    io.cancel_poll();
                    interest.cancelled = true;
                }
                let result = std::task::ready!(io.poll_unpin(cx));
                interest.io = None;
                match result {
                    Ok(poll_revents) => {
                        interest.revents |= PollEvents::from_poll_events(poll_revents as i16);
                    }
                    Err(err) if err.raw_os_error() == Some(libc::ECANCELED) => {}
                    Err(err) => panic!("poll failed: {}", err),
                }
            } else {
                interest.events = events;
                let sqe = opcode::PollAdd::new(
                    io_uring::types::Fd(self.fd),
                    events.to_poll_events() as u32,
                )
                .build();
                // SAFETY: the PollAdd entry does not reference any external
                // memory.
                let io = unsafe { Io::new(self.initiator.initiator().clone(), sqe, ()) };
                interest.io = Some(io);
                interest.cancelled = false;
            }
        }
    }

    fn clear_fd_ready(&mut self, slot: InterestSlot) {
        let interest = &mut self.interests[slot as usize];
        interest.revents = PollEvents::EMPTY;
    }
}

/// A [`pal_async::wait::PollWait`] implementation for io_uring.
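///
/// Depending on kernel support, probed in [`FdWait::new`], the wait is backed
/// either by a nonblocking read submitted directly to the ring or by a
/// poll-then-read fallback built on [`FdReady`].
///
/// ```ignore
/// // Illustrative usage sketch; the eventfd setup is an assumption made for
/// // the example and is not taken from this file.
/// let fd = unsafe { libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK) };
/// let mut wait = initiator.new_wait(fd, 8)?;
/// std::future::poll_fn(|cx| wait.poll_wait(cx)).await?;
/// ```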
#[derive(Debug)]
pub struct FdWait<T: Initiate> {
    inner: FdWaitInner<T>,
}

#[derive(Debug)]
enum FdWaitInner<T: Initiate> {
    ViaPoll(pal_async::unix::FdWait<FdReady<T>>),
    ViaRead(FdWaitViaRead<T>),
}

impl WaitDriver for IoInitiator {
    type Wait = FdWait<Self>;

    fn new_wait(&self, fd: RawFd, read_size: usize) -> io::Result<Self::Wait> {
        Ok(FdWait::new(self.clone(), fd, read_size))
    }
}

impl<T: Initiate> FdWait<T> {
    /// Creates a new instance for the given file descriptor and initiator.
    pub fn new(initiator: T, fd: RawFd, read_size: usize) -> Self {
        static SUPPORTS_NONBLOCK_READ: OnceLock<bool> = OnceLock::new();
        // There is no easy way to detect whether the ring supports nonblocking
        // reads, but the functionality was added in the same release as linkat
        // (5.15), so that's probably as close as we're getting.
        const LINKAT: u8 = 39;
        let supports_nonblock_read =
            *SUPPORTS_NONBLOCK_READ.get_or_init(|| initiator.initiator().probe(LINKAT));

        let inner = if supports_nonblock_read {
            assert!(read_size <= MAXIMUM_WAIT_READ_SIZE);
            FdWaitInner::ViaRead(FdWaitViaRead {
                fd,
                read_size,
                initiator,
                state: FdWaitViaReadState::Idle(Box::new(0)),
            })
        } else {
            FdWaitInner::ViaPoll(pal_async::unix::FdWait::new(
                fd,
                FdReady::new(initiator, fd),
                read_size,
            ))
        };
        FdWait { inner }
    }
}

impl<T: Initiate> PollWait for FdWait<T> {
    fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.inner {
            FdWaitInner::ViaPoll(wait) => wait.poll_wait(cx),
            FdWaitInner::ViaRead(wait) => wait.poll_wait(cx),
        }
    }

    fn poll_cancel_wait(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
        match &mut self.inner {
            FdWaitInner::ViaPoll(wait) => wait.poll_cancel_wait(cx),
            FdWaitInner::ViaRead(wait) => wait.poll_cancel_wait(cx),
        }
    }
}

#[derive(Debug)]
struct FdWaitViaRead<T: Initiate> {
    fd: RawFd,
    read_size: usize,
    initiator: T,
    state: FdWaitViaReadState,
}

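/// State machine for the read-based wait. The 8-byte buffer is boxed so that
/// its address stays stable while a read referencing it is in flight; it moves
/// into the pending [`Io`] and is recovered via `into_mem` on completion.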
#[derive(Debug)]
enum FdWaitViaReadState {
    Idle(Box<u64>),
    ReadPending { io: Io<Box<u64>>, cancelling: bool },
    Invalid,
}

impl<T: Initiate> PollWait for FdWaitViaRead<T> {
    fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        loop {
            match std::mem::replace(&mut self.state, FdWaitViaReadState::Invalid) {
                FdWaitViaReadState::Idle(mut buf) => {
                    assert!(self.read_size <= 8);
                    let sqe = opcode::Read::new(
                        io_uring::types::Fd(self.fd),
                        std::ptr::from_mut(&mut *buf).cast(),
                        self.read_size as u32,
                    )
                    .build();
                    // SAFETY: the sqe's buffer is kept alive in `buf` for the
                    // lifetime of the IO.
                    let io = unsafe { Io::new(self.initiator.initiator().clone(), sqe, buf) };
                    self.state = FdWaitViaReadState::ReadPending {
                        io,
                        cancelling: false,
                    };
                }
                FdWaitViaReadState::ReadPending { mut io, cancelling } => match io.poll_unpin(cx) {
                    Poll::Ready(r) => {
                        self.state = FdWaitViaReadState::Idle(io.into_mem());
                        match r {
                            Ok(_) => break Poll::Ready(Ok(())),
                            Err(err) if err.raw_os_error() == Some(libc::ECANCELED) => {}
                            Err(err) => return Poll::Ready(Err(err)),
                        }
                    }
                    Poll::Pending => {
                        self.state = FdWaitViaReadState::ReadPending { io, cancelling };
                        return Poll::Pending;
                    }
                },
                FdWaitViaReadState::Invalid => unreachable!(),
            }
        }
    }

    fn poll_cancel_wait(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
        loop {
            match std::mem::replace(&mut self.state, FdWaitViaReadState::Invalid) {
                FdWaitViaReadState::Idle(buf) => {
                    self.state = FdWaitViaReadState::Idle(buf);
                    break Poll::Ready(false);
                }
                FdWaitViaReadState::ReadPending { mut io, cancelling } => {
                    if cancelling {
                        match io.poll_unpin(cx) {
                            Poll::Ready(r) => {
                                self.state = FdWaitViaReadState::Idle(io.into_mem());
                                // If `r` is an error, it was either `ECANCELED`
                                // (so do nothing), or it was a real error. We
                                // assume that subsequent reads will return the
                                // same error, so we can ignore those here to
                                // keep the cancel contract simple for the
                                // caller.
                                break Poll::Ready(r.is_ok());
                            }
                            Poll::Pending => {
                                self.state = FdWaitViaReadState::ReadPending { io, cancelling };
                                break Poll::Pending;
                            }
                        }
                    } else {
                        io.cancel();
                        self.state = FdWaitViaReadState::ReadPending {
                            io,
                            cancelling: true,
                        };
                    }
                }
                FdWaitViaReadState::Invalid => unreachable!(),
            }
        }
    }
}

impl<T: Initiate> Drop for FdWaitViaRead<T> {
    fn drop(&mut self) {
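        // Request cancellation of any outstanding read; the result is not
        // awaited, so a noop waker is sufficient.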
        let _ = self.poll_cancel_wait(&mut Context::from_waker(Waker::noop()));
    }
}

/// A [`pal_async::timer::PollTimer`] implementation for io_uring.
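///
/// ```ignore
/// // Illustrative sketch only; the `Duration` arithmetic on `Instant` is an
/// // assumption about pal_async, not something defined in this file. Assumes
/// // `initiator: IoInitiator`.
/// let mut timer = initiator.new_timer();
/// let deadline = Instant::now() + std::time::Duration::from_millis(10);
/// let fired_at = std::future::poll_fn(|cx| timer.poll_timer(cx, Some(deadline))).await;
/// ```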
#[derive(Debug)]
pub struct Timer<T: Initiate> {
    initiator: T,
    target_deadline: Instant,
    state: Option<TimerState>,
}

impl<T: Initiate> Timer<T> {
    /// Creates a new instance for the given initiator.
    pub fn new(initiator: T) -> Self {
        Timer {
            initiator,
            target_deadline: Instant::from_nanos(0),
            state: None,
        }
    }
}

#[derive(Debug)]
struct TimerState {
    io: Io<Box<Timespec>>,
    cancelled: bool,
}

impl TimerDriver for IoInitiator {
    type Timer = Timer<Self>;

    fn new_timer(&self) -> Self::Timer {
        Timer::new(self.clone())
    }
}

impl<T: Initiate> PollTimer for Timer<T> {
    fn poll_timer(&mut self, cx: &mut Context<'_>, deadline: Option<Instant>) -> Poll<Instant> {
        if let Some(deadline) = deadline {
            self.set_deadline(deadline);
        }
        loop {
            let now = Instant::now();
            if self.target_deadline <= now {
                break Poll::Ready(now);
            } else if let Some(state) = &mut self.state {
                let _ = std::task::ready!(state.io.poll_unpin(cx));
                self.state = None;
            } else {
                // Compute an absolute timeout. Note that pal's Instant is
                // CLOCK_MONOTONIC, which is exactly what io_uring supports.
                let absolute_timeout = self.target_deadline - Instant::from_nanos(0);
                let timespec = Box::new(
                    Timespec::new()
                        .sec(absolute_timeout.as_secs())
                        .nsec(absolute_timeout.subsec_nanos()),
                );
                let sqe = {
                    opcode::Timeout::new(&*timespec)
                        .flags(TimeoutFlags::ABS)
                        .build()
                };
                // SAFETY: the operation references timespec, which is boxed for
                // the duration of the IO.
                let io = unsafe { Io::new(self.initiator.initiator().clone(), sqe, timespec) };
                let state = TimerState {
                    io,
                    cancelled: false,
                };
                self.state = Some(state);
            }
        }
    }

    fn set_deadline(&mut self, deadline: Instant) {
        if let Some(state) = &mut self.state {
            // Cancel the current operation if the deadline is later than
            // the current one.
            //
            // FUTURE: just update the current operation. This requires >=
            // Linux 5.11.
            if self.target_deadline > deadline && !state.cancelled {
                state.io.cancel_timeout();
                state.cancelled = true;
            }
        }
        self.target_deadline = deadline;
    }
}

#[cfg(test)]
pub(crate) mod tests {
    use crate::IoInitiator;
    use crate::IoUringPool;
    use futures::executor::block_on;
    use once_cell::sync::OnceCell;
    use pal_async::executor_tests;
    use pal_async::task::Spawn;
    use std::future::Future;
    use std::io;
    use std::thread::JoinHandle;

    pub struct SingleThreadPool {
        _thread: JoinHandle<()>,
        initiator: IoInitiator,
    }

    impl SingleThreadPool {
        pub fn new() -> io::Result<Self> {
            let pool = IoUringPool::new("test", 16)?;
            let initiator = pool.client().initiator().clone();
            let thread = std::thread::spawn(move || pool.run());
            Ok(Self {
                _thread: thread,
                initiator,
            })
        }

        pub fn initiator(&self) -> &IoInitiator {
            &self.initiator
        }
    }

    fn test_pool() -> io::Result<&'static SingleThreadPool> {
        // TODO: switch to std::sync::OnceLock once `get_or_try_init` is stable
        static POOL: OnceCell<SingleThreadPool> = OnceCell::new();
        POOL.get_or_try_init(SingleThreadPool::new)
    }

    macro_rules! get_pool_or_skip {
        () => {
            match test_pool() {
                Ok(pool) => pool,
                Err(err) if err.raw_os_error() == Some(libc::ENOSYS) => {
                    println!("Test case skipped (no IO-Uring support)");
                    return;
                }
                Err(err) => panic!("{}", err),
            }
        };
    }

    fn run_until<F>(pool: &SingleThreadPool, fut: F) -> F::Output
    where
        F: 'static + Future + Send,
        F::Output: Send,
    {
        block_on(pool.initiator().spawn("test", fut))
    }

    #[test]
    fn waker_works() {
        run_until(get_pool_or_skip!(), executor_tests::waker_tests());
    }

    #[test]
    fn spawn_works() {
        let pool = get_pool_or_skip!();
        executor_tests::spawn_tests(|| (pool.initiator(), || ()))
    }

    #[test]
    fn sleep_works() {
        let pool = get_pool_or_skip!();
        run_until(pool, executor_tests::sleep_tests(pool.initiator().clone()))
    }

    #[test]
    fn wait_works() {
        let pool = get_pool_or_skip!();
        run_until(pool, executor_tests::wait_tests(pool.initiator().clone()))
    }

    #[test]
    fn socket_works() {
        let pool = get_pool_or_skip!();
        run_until(pool, executor_tests::socket_tests(pool.initiator().clone()))
    }
}