Skip to main content

pal_async/
driver.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Driver trait.
5
// UNSAFETY: Needed to define and implement the unsafe `new_dyn_overlapped_file`
// (Windows) and `io_uring_submit` (Linux) methods.
7#![cfg_attr(any(windows, target_os = "linux"), expect(unsafe_code))]
8
9#[cfg(unix)]
10use crate::fd::FdReadyDriver;
11#[cfg(unix)]
12use crate::fd::PollFdReady;
13#[cfg(target_os = "linux")]
14use crate::io_uring::IoUringDriver;
15use crate::socket::PollSocketReady;
16use crate::socket::SocketReadyDriver;
17#[cfg(windows)]
18use crate::sys::overlapped::IoOverlapped;
19#[cfg(windows)]
20use crate::sys::overlapped::OverlappedIoDriver;
21use crate::task::Spawn;
22use crate::timer::PollTimer;
23use crate::timer::TimerDriver;
24use crate::wait::PollWait;
25use crate::wait::WaitDriver;
26use smallbox::SmallBox;
27use smallbox::space::S4;
28use std::io;
29#[cfg(unix)]
30use std::os::unix::prelude::*;
31#[cfg(windows)]
32use std::os::windows::prelude::*;
33use std::sync::Arc;
34
35/// A generic `Box`-like container of one of the polled types.
36pub type PollImpl<T> = SmallBox<T, S4>;
37
38/// A driver that supports polled IO.
39pub trait Driver: 'static + Send + Sync {
40    /// Returns a new timer.
41    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer>;
42
43    /// Returns a new object for polling file descriptor readiness.
44    #[cfg(unix)]
45    fn new_dyn_fd_ready(&self, fd: RawFd) -> io::Result<PollImpl<dyn PollFdReady>>;
46
47    /// Creates a new object for polling socket readiness.
48    #[cfg(windows)]
49    fn new_dyn_socket_ready(&self, socket: RawSocket) -> io::Result<PollImpl<dyn PollSocketReady>>;
50
51    /// Creates a new object for polling socket readiness.
52    #[cfg(unix)]
53    fn new_dyn_socket_ready(&self, socket: RawFd) -> io::Result<PollImpl<dyn PollSocketReady>>;
54
55    /// Creates a new wait.
56    #[cfg(windows)]
57    fn new_dyn_wait(&self, handle: RawHandle) -> io::Result<PollImpl<dyn PollWait>>;
58
59    /// Creates a new wait.
60    ///
61    /// Signals will be consumed using reads of `read_size` bytes, with 8-byte
62    /// buffer alignment. `read_size` must be at most
63    /// [`MAXIMUM_WAIT_READ_SIZE`](super::wait::MAXIMUM_WAIT_READ_SIZE) bytes.
64    #[cfg(unix)]
65    fn new_dyn_wait(&self, fd: RawFd, read_size: usize) -> io::Result<PollImpl<dyn PollWait>>;
66
67    /// Creates a new overlapped file handler.
68    ///
69    /// # Safety
70    /// The caller must ensure that they exclusively own `handle`, and that
71    /// `handle` stays alive until the new handler is dropped.
72    #[cfg(windows)]
73    unsafe fn new_dyn_overlapped_file(
74        &self,
75        handle: RawHandle,
76    ) -> io::Result<PollImpl<dyn IoOverlapped>>;
77
78    /// Returns whether the given opcode is supported by the ring.
79    #[cfg(target_os = "linux")]
80    fn io_uring_probe(&self, opcode: u8) -> bool;
81
82    /// Submits an io-uring SQE for asynchronous execution.
83    ///
84    /// Returns a future that completes with the IO result. The future **aborts
85    /// the process** if dropped while the IO is in flight, since there is no
86    /// way to synchronously cancel an in-flight io-uring operation.
87    ///
88    /// # Safety
89    ///
90    /// All memory referenced by the SQE must remain valid for the lifetime of
91    /// the returned future.
92    ///
93    /// This can be hard to do safely; in particular, if this future can be
94    /// leaked (via [`std::mem::forget`] or otherwise) then the caller must
95    /// ensure that any referenced memory also leaks. The easiest way to do that
96    /// is to ensure that the future is `await`ed in an async function or block
97    /// that owns the underlying memory. So, this is safe:
98    ///
99    /// ```rust,ignore
100    /// async fn write(driver: &impl Driver, file: &File, buf: Vec<u8>) -> io::Result<usize> {
101    ///     let sqe = opcode::Write::new(
102    ///         types::Fd(file.as_raw_fd()), buf.as_ptr(), buf.len() as u32,
103    ///     ).build();
104    ///     // SAFETY: `buf` is owned by this async function's state machine.
105    ///     // If the outer future is leaked, `buf` leaks with it, so the
106    ///     // memory remains valid for the io-uring operation.
107    ///     unsafe { driver.io_uring_submit(sqe).await? };
108    ///     Ok(buf.len())
109    /// }
110    /// ```
111    ///
112    /// But this is not:
113    ///
114    /// ```rust,ignore
115    /// async fn write(driver: &impl Driver, file: &File, buf: &[u8]) -> io::Result<usize> {
116    ///     let sqe = opcode::Write::new(
117    ///         types::Fd(file.as_raw_fd()), buf.as_ptr(), buf.len() as u32,
118    ///     ).build();
119    ///     // NOT SAFE: `buf` is a borrow. If the outer future is leaked,
120    ///     // the referent can be freed while the io-uring operation is
121    ///     // still in flight.
122    ///     unsafe { driver.io_uring_submit(sqe).await? };
123    ///     Ok(buf.len())
124    /// }
125    /// ```
126    #[cfg(target_os = "linux")]
127    unsafe fn io_uring_submit(
128        &self,
129        sqe: crate::io_uring::Entry,
130    ) -> std::pin::Pin<Box<dyn Future<Output = io::Result<i32>> + Send + '_>>;
131}
132
// Blanket implementation for Unix targets other than Linux: any type that
// provides the fd-readiness, timer, socket-readiness and wait drivers is a
// `Driver`. Each method builds the concrete object and type-erases it.
#[cfg(all(unix, not(target_os = "linux")))]
impl<T> Driver for T
where
    T: 'static + Send + Sync + FdReadyDriver + TimerDriver + SocketReadyDriver + WaitDriver,
{
    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
        let timer = self.new_timer();
        smallbox::smallbox!(timer)
    }

    fn new_dyn_fd_ready(&self, fd: RawFd) -> io::Result<PollImpl<dyn PollFdReady>> {
        let fd_ready = self.new_fd_ready(fd)?;
        Ok(smallbox::smallbox!(fd_ready))
    }

    fn new_dyn_socket_ready(&self, socket: RawFd) -> io::Result<PollImpl<dyn PollSocketReady>> {
        let socket_ready = self.new_socket_ready(socket)?;
        Ok(smallbox::smallbox!(socket_ready))
    }

    fn new_dyn_wait(&self, fd: RawFd, read_size: usize) -> io::Result<PollImpl<dyn PollWait>> {
        let wait = self.new_wait(fd, read_size)?;
        Ok(smallbox::smallbox!(wait))
    }
}
154
155#[cfg(target_os = "linux")]
156impl<T> Driver for T
157where
158    T: 'static
159        + Send
160        + Sync
161        + FdReadyDriver
162        + TimerDriver
163        + SocketReadyDriver
164        + WaitDriver
165        + IoUringDriver,
166{
167    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
168        smallbox::smallbox!(self.new_timer())
169    }
170
171    fn new_dyn_fd_ready(&self, fd: RawFd) -> io::Result<PollImpl<dyn PollFdReady>> {
172        Ok(smallbox::smallbox!(self.new_fd_ready(fd)?))
173    }
174
175    fn new_dyn_socket_ready(&self, socket: RawFd) -> io::Result<PollImpl<dyn PollSocketReady>> {
176        Ok(smallbox::smallbox!(self.new_socket_ready(socket)?))
177    }
178
179    fn new_dyn_wait(&self, fd: RawFd, read_size: usize) -> io::Result<PollImpl<dyn PollWait>> {
180        Ok(smallbox::smallbox!(self.new_wait(fd, read_size)?))
181    }
182
183    fn io_uring_probe(&self, opcode: u8) -> bool {
184        use crate::io_uring::IoUringSubmit as _;
185
186        self.io_uring_submitter()
187            .is_some_and(|submitter| submitter.probe(opcode))
188    }
189
190    unsafe fn io_uring_submit(
191        &self,
192        sqe: crate::io_uring::Entry,
193    ) -> std::pin::Pin<Box<dyn Future<Output = io::Result<i32>> + Send + '_>> {
194        use crate::io_uring::IoUringSubmit as _;
195
196        Box::pin(async move {
197            // SAFETY: caller guarantees contract
198            unsafe {
199                self.io_uring_submitter()
200                    .ok_or(io::ErrorKind::Unsupported)?
201                    .submit(sqe)
202            }
203            .await
204        })
205    }
206}
207
// Blanket implementation for Windows: any type that provides the timer,
// socket-readiness, wait and overlapped-IO drivers is a `Driver`. Each method
// builds the concrete object and type-erases it.
#[cfg(windows)]
impl<T> Driver for T
where
    T: 'static + Send + Sync + TimerDriver + SocketReadyDriver + WaitDriver + OverlappedIoDriver,
{
    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
        let timer = self.new_timer();
        smallbox::smallbox!(timer)
    }

    fn new_dyn_socket_ready(&self, socket: RawSocket) -> io::Result<PollImpl<dyn PollSocketReady>> {
        let socket_ready = self.new_socket_ready(socket)?;
        Ok(smallbox::smallbox!(socket_ready))
    }

    fn new_dyn_wait(&self, handle: RawHandle) -> io::Result<PollImpl<dyn PollWait>> {
        let wait = self.new_wait(handle)?;
        Ok(smallbox::smallbox!(wait))
    }

    unsafe fn new_dyn_overlapped_file(
        &self,
        handle: RawHandle,
    ) -> io::Result<PollImpl<dyn IoOverlapped>> {
        // SAFETY: the caller of `new_dyn_overlapped_file` upholds the
        // contract required by `new_overlapped_file`.
        let file = unsafe { self.new_overlapped_file(handle) }?;
        Ok(smallbox::smallbox!(file))
    }
}
235
236#[cfg(unix)]
237impl Driver for Box<dyn Driver> {
238    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
239        self.as_ref().new_dyn_timer()
240    }
241
242    fn new_dyn_fd_ready(&self, fd: RawFd) -> io::Result<PollImpl<dyn PollFdReady>> {
243        self.as_ref().new_dyn_fd_ready(fd)
244    }
245
246    fn new_dyn_socket_ready(&self, socket: RawFd) -> io::Result<PollImpl<dyn PollSocketReady>> {
247        self.as_ref().new_dyn_socket_ready(socket)
248    }
249
250    fn new_dyn_wait(&self, fd: RawFd, read_size: usize) -> io::Result<PollImpl<dyn PollWait>> {
251        self.as_ref().new_dyn_wait(fd, read_size)
252    }
253
254    #[cfg(target_os = "linux")]
255    fn io_uring_probe(&self, opcode: u8) -> bool {
256        self.as_ref().io_uring_probe(opcode)
257    }
258
259    #[cfg(target_os = "linux")]
260    unsafe fn io_uring_submit(
261        &self,
262        sqe: crate::io_uring::Entry,
263    ) -> std::pin::Pin<Box<dyn Future<Output = io::Result<i32>> + Send + '_>> {
264        // SAFETY: caller guarantees contract
265        unsafe { self.as_ref().io_uring_submit(sqe) }
266    }
267}
268
// Lets a boxed trait object be used wherever a `Driver` is expected: every
// method forwards to the driver stored in the box.
#[cfg(windows)]
impl Driver for Box<dyn Driver> {
    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
        (**self).new_dyn_timer()
    }

    fn new_dyn_socket_ready(&self, socket: RawSocket) -> io::Result<PollImpl<dyn PollSocketReady>> {
        (**self).new_dyn_socket_ready(socket)
    }

    fn new_dyn_wait(&self, handle: RawHandle) -> io::Result<PollImpl<dyn PollWait>> {
        (**self).new_dyn_wait(handle)
    }

    unsafe fn new_dyn_overlapped_file(
        &self,
        handle: RawHandle,
    ) -> io::Result<PollImpl<dyn IoOverlapped>> {
        // SAFETY: the contract is forwarded unchanged; our caller upholds it.
        unsafe { (**self).new_dyn_overlapped_file(handle) }
    }
}
291
292#[cfg(unix)]
293impl Driver for Arc<dyn Driver> {
294    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
295        self.as_ref().new_dyn_timer()
296    }
297
298    fn new_dyn_fd_ready(&self, fd: RawFd) -> io::Result<PollImpl<dyn PollFdReady>> {
299        self.as_ref().new_dyn_fd_ready(fd)
300    }
301
302    fn new_dyn_socket_ready(&self, socket: RawFd) -> io::Result<PollImpl<dyn PollSocketReady>> {
303        self.as_ref().new_dyn_socket_ready(socket)
304    }
305
306    fn new_dyn_wait(&self, fd: RawFd, read_size: usize) -> io::Result<PollImpl<dyn PollWait>> {
307        self.as_ref().new_dyn_wait(fd, read_size)
308    }
309
310    #[cfg(target_os = "linux")]
311    fn io_uring_probe(&self, opcode: u8) -> bool {
312        self.as_ref().io_uring_probe(opcode)
313    }
314
315    #[cfg(target_os = "linux")]
316    unsafe fn io_uring_submit(
317        &self,
318        sqe: crate::io_uring::Entry,
319    ) -> std::pin::Pin<Box<dyn Future<Output = io::Result<i32>> + Send + '_>> {
320        // SAFETY: caller guarantees contract
321        unsafe { self.as_ref().io_uring_submit(sqe) }
322    }
323}
324
// Lets a reference-counted trait object be used wherever a `Driver` is
// expected: every method forwards to the driver behind the `Arc`.
#[cfg(windows)]
impl Driver for Arc<dyn Driver> {
    fn new_dyn_timer(&self) -> PollImpl<dyn PollTimer> {
        (**self).new_dyn_timer()
    }

    fn new_dyn_socket_ready(&self, socket: RawSocket) -> io::Result<PollImpl<dyn PollSocketReady>> {
        (**self).new_dyn_socket_ready(socket)
    }

    fn new_dyn_wait(&self, handle: RawHandle) -> io::Result<PollImpl<dyn PollWait>> {
        (**self).new_dyn_wait(handle)
    }

    unsafe fn new_dyn_overlapped_file(
        &self,
        handle: RawHandle,
    ) -> io::Result<PollImpl<dyn IoOverlapped>> {
        // SAFETY: the contract is forwarded unchanged; our caller upholds it.
        unsafe { (**self).new_dyn_overlapped_file(handle) }
    }
}
347
348/// Trait for [`Driver`]s that also implement [`Spawn`].
349pub trait SpawnDriver: Spawn + Driver {}
350
351impl<T: Spawn + Driver> SpawnDriver for T {}