pal/
windows.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![cfg(windows)]
5// UNSAFETY: Calls to Win32 functions to handle delay loading, interacting
6// with low level primitives, and memory management.
7#![expect(unsafe_code)]
8#![expect(clippy::undocumented_unsafe_blocks, clippy::missing_safety_doc)]
9
10pub mod afd;
11pub mod affinity;
12pub mod alpc;
13pub mod fs;
14pub mod job;
15pub mod pipe;
16pub mod process;
17pub mod security;
18pub mod tp;
19
20use self::security::SecurityDescriptor;
21use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE;
22// TODO: Revert this ntapi fallback once windows/windows-sys expose
23// NtCreateWaitCompletionPacket, NtAssociateWaitCompletionPacket, and
24// NtCancelWaitCompletionPacket.
25use ntapi::ntioapi::FILE_COMPLETION_INFORMATION;
26use ntapi::ntioapi::FileReplaceCompletionInformation;
27use ntapi::ntioapi::IO_STATUS_BLOCK;
28use ntapi::ntioapi::NtAssociateWaitCompletionPacket;
29use ntapi::ntioapi::NtCancelWaitCompletionPacket;
30use ntapi::ntioapi::NtCreateWaitCompletionPacket;
31use ntapi::ntioapi::NtSetInformationFile;
32use std::cell::UnsafeCell;
33use std::ffi::OsStr;
34use std::ffi::c_void;
35use std::fs::File;
36use std::io;
37use std::io::Error;
38use std::io::Result;
39use std::marker::PhantomData;
40use std::mem::zeroed;
41use std::os::windows::prelude::*;
42use std::path::Path;
43use std::ptr::NonNull;
44use std::ptr::addr_of;
45use std::ptr::null_mut;
46use std::sync::Once;
47use std::sync::atomic::AtomicUsize;
48use std::sync::atomic::Ordering;
49use std::time::Duration;
50use widestring::U16CString;
51use widestring::Utf16Str;
52use windows_sys::Wdk::Storage::FileSystem::NtCreateDirectoryObject;
53use windows_sys::Wdk::Storage::FileSystem::NtOpenDirectoryObject;
54use windows_sys::Wdk::Storage::FileSystem::RtlAllocateHeap;
55use windows_sys::Wdk::Storage::FileSystem::RtlDosPathNameToNtPathName_U_WithStatus;
56use windows_sys::Wdk::Storage::FileSystem::RtlFreeHeap;
57use windows_sys::Wdk::Storage::FileSystem::RtlNtStatusToDosErrorNoTeb;
58use windows_sys::Win32::Foundation::CloseHandle;
59use windows_sys::Win32::Foundation::ERROR_BAD_PATHNAME;
60use windows_sys::Win32::Foundation::NTSTATUS;
61use windows_sys::Win32::Foundation::STATUS_PENDING;
62use windows_sys::Win32::Foundation::UNICODE_STRING;
63use windows_sys::Win32::Security::SECURITY_DESCRIPTOR;
64use windows_sys::Win32::Storage::FileSystem::SetFileCompletionNotificationModes;
65use windows_sys::Win32::System::Console::STD_OUTPUT_HANDLE;
66use windows_sys::Win32::System::Console::SetStdHandle;
67use windows_sys::Win32::System::Diagnostics::Debug::GetErrorMode;
68use windows_sys::Win32::System::Diagnostics::Debug::SEM_FAILCRITICALERRORS;
69use windows_sys::Win32::System::Diagnostics::Debug::SetErrorMode;
70use windows_sys::Win32::System::IO::CreateIoCompletionPort;
71use windows_sys::Win32::System::IO::GetQueuedCompletionStatusEx;
72use windows_sys::Win32::System::IO::OVERLAPPED;
73use windows_sys::Win32::System::IO::OVERLAPPED_ENTRY;
74use windows_sys::Win32::System::IO::PostQueuedCompletionStatus;
75use windows_sys::Win32::System::Kernel::STRING;
76use windows_sys::Win32::System::Memory::GetProcessHeap;
77use windows_sys::Win32::System::Threading::GetExitCodeProcess;
78use windows_sys::Win32::System::Threading::INFINITE;
79use windows_sys::Win32::System::Threading::TerminateProcess;
80use windows_sys::Win32::System::WindowsProgramming::RtlFreeUnicodeString;
81
82#[repr(transparent)]
83#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
84pub struct SendSyncRawHandle(pub RawHandle);
85
86unsafe impl Send for SendSyncRawHandle {}
87unsafe impl Sync for SendSyncRawHandle {}
88
89pub trait BorrowedHandleExt: Sized {
90    fn duplicate(&self, inherit: bool, access: Option<u32>) -> Result<OwnedHandle>;
91}
92
93impl BorrowedHandleExt for BorrowedHandle<'_> {
94    fn duplicate(&self, inherit: bool, access: Option<u32>) -> Result<OwnedHandle> {
95        let mut handle = null_mut();
96        let options = if access.is_some() {
97            0
98        } else {
99            windows_sys::Win32::Foundation::DUPLICATE_SAME_ACCESS
100        };
101        unsafe {
102            let process = windows_sys::Win32::System::Threading::GetCurrentProcess();
103            if windows_sys::Win32::Foundation::DuplicateHandle(
104                process,
105                self.as_raw_handle(),
106                process,
107                &mut handle,
108                access.unwrap_or(0),
109                inherit.into(),
110                options,
111            ) == 0
112            {
113                return Err(Error::last_os_error());
114            }
115            Ok(OwnedHandle::from_raw_handle(handle))
116        }
117    }
118}
119
120pub trait OwnedSocketExt: Sized {
121    /// Prepares the socket for being sent to another process.
122    ///
123    /// After calling this, the socket should not be used for anything other
124    /// than duplicating to a handle (which can then be converted back to a
125    /// socket with `from_handle`).
126    fn prepare_to_send(&mut self) -> Result<BorrowedHandle<'_>>;
127
128    /// Converts a handle, originally duplicated from another socket, to a
129    /// socket. The original socket should have been prepared with
130    /// [`Self::prepare_to_send`].
131    fn from_handle(handle: OwnedHandle) -> Result<Self>;
132}
133
134const SIO_SOCKET_TRANSFER_BEGIN: u32 = windows_sys::Win32::Networking::WinSock::IOC_IN
135    | windows_sys::Win32::Networking::WinSock::IOC_VENDOR
136    | 301;
137const SIO_SOCKET_TRANSFER_END: u32 = windows_sys::Win32::Networking::WinSock::IOC_IN
138    | windows_sys::Win32::Networking::WinSock::IOC_VENDOR
139    | 302;
140
141/// Ensures WSAStartup has been called for the process.
142fn init_winsock() {
143    static INIT: Once = Once::new();
144
145    INIT.call_once(|| {
146        // Initialize a dummy socket, then throw away the result, to get the
147        // socket library to call WSAStartup for us.
148        let _ = socket2::Socket::new(socket2::Domain::IPV4, socket2::Type::DGRAM, None);
149    });
150}
151
152impl OwnedSocketExt for OwnedSocket {
153    fn prepare_to_send(&mut self) -> Result<BorrowedHandle<'_>> {
154        let mut catalog_id: u32 = 0;
155        let mut bytes = 0;
156        // SAFETY: calling the ioctl according to implementation requirements
157        unsafe {
158            if windows_sys::Win32::Networking::WinSock::WSAIoctl(
159                self.as_raw_socket() as _,
160                SIO_SOCKET_TRANSFER_BEGIN,
161                null_mut(),
162                0,
163                std::ptr::from_mut(&mut catalog_id).cast(),
164                size_of_val(&catalog_id) as u32,
165                &mut bytes,
166                null_mut(),
167                None,
168            ) != 0
169            {
170                return Err(Error::from_raw_os_error(
171                    windows_sys::Win32::Networking::WinSock::WSAGetLastError(),
172                ));
173            }
174            Ok(BorrowedHandle::borrow_raw(self.as_raw_socket() as RawHandle))
175        }
176    }
177
178    fn from_handle(handle: OwnedHandle) -> Result<Self> {
179        // This could be the first winsock interaction for the process.
180        init_winsock();
181
182        let mut catalog_id: u32 = 0;
183        let mut bytes = 0;
184        let mut socket = handle.as_raw_handle() as windows_sys::Win32::Networking::WinSock::SOCKET;
185        // SAFETY: calling the ioctl according to implementation requirements
186        unsafe {
187            if windows_sys::Win32::Networking::WinSock::WSAIoctl(
188                socket,
189                SIO_SOCKET_TRANSFER_END,
190                std::ptr::from_mut(&mut catalog_id).cast(),
191                size_of_val(&catalog_id) as u32,
192                std::ptr::from_mut(&mut socket).cast(),
193                size_of_val(&socket) as u32,
194                &mut bytes,
195                null_mut(),
196                None,
197            ) != 0
198            {
199                return Err(Error::from_raw_os_error(
200                    windows_sys::Win32::Networking::WinSock::WSAGetLastError(),
201                ));
202            }
203            // In theory SIO_SOCKET_TRANSFER_END could have changed `socket`, so
204            // forget the handle and use the socket instead.
205            let _gone = handle.into_raw_handle();
206            Ok(Self::from_raw_socket(socket as RawSocket))
207        }
208    }
209}
210
211#[repr(transparent)]
212#[derive(Debug)]
213struct WaitObject(OwnedHandle);
214
215impl WaitObject {
216    fn wait(&self) {
217        assert!(
218            unsafe {
219                windows_sys::Win32::System::Threading::WaitForSingleObject(
220                    self.0.as_raw_handle(),
221                    INFINITE,
222                )
223            } == 0
224        );
225    }
226}
227
228impl Clone for WaitObject {
229    fn clone(&self) -> Self {
230        Self(
231            self.0
232                .try_clone()
233                .expect("out of resources cloning wait object"),
234        )
235    }
236}
237
238#[derive(Debug, Clone)]
239pub struct Process(WaitObject);
240
241impl Process {
242    pub fn wait(&self) {
243        self.0.wait()
244    }
245
246    pub fn id(&self) -> u32 {
247        unsafe {
248            let pid = windows_sys::Win32::System::Threading::GetProcessId(
249                self.as_handle().as_raw_handle(),
250            );
251            assert_ne!(pid, 0);
252            pid
253        }
254    }
255
256    pub fn exit_code(&self) -> u32 {
257        let mut code = 0;
258        unsafe {
259            assert!(GetExitCodeProcess(self.as_handle().as_raw_handle(), &mut code) != 0);
260        }
261        code
262    }
263
264    /// Terminates the process immediately, setting its exit code to `exit_code`.
265    pub fn kill(&self, exit_code: u32) -> Result<()> {
266        // SAFETY: calling TerminateProcess according to API docs.
267        unsafe {
268            if TerminateProcess(self.as_handle().as_raw_handle(), exit_code) == 0 {
269                return Err(Error::last_os_error());
270            }
271        }
272        Ok(())
273    }
274}
275
276impl From<OwnedHandle> for Process {
277    fn from(handle: OwnedHandle) -> Self {
278        Self(WaitObject(handle))
279    }
280}
281
282impl AsHandle for Process {
283    fn as_handle(&self) -> BorrowedHandle<'_> {
284        (self.0).0.as_handle()
285    }
286}
287
288impl From<Process> for OwnedHandle {
289    fn from(value: Process) -> OwnedHandle {
290        (value.0).0
291    }
292}
293
294#[derive(Debug)]
295pub struct IoCompletionPort(OwnedHandle);
296
297impl IoCompletionPort {
298    pub fn new() -> Self {
299        unsafe {
300            let handle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, null_mut(), 0, 0);
301            if handle.is_null() {
302                panic!("oom allocating completion port");
303            }
304            Self(OwnedHandle::from_raw_handle(handle))
305        }
306    }
307
308    pub fn get(&self, entries: &mut [OVERLAPPED_ENTRY], timeout: Option<Duration>) -> usize {
309        unsafe {
310            let mut n = 0;
311            if GetQueuedCompletionStatusEx(
312                self.0.as_raw_handle(),
313                entries.as_mut_ptr(),
314                entries.len().try_into().expect("too many entries"),
315                &mut n,
316                timeout
317                    .map(|t| t.as_millis().try_into().unwrap_or(INFINITE - 1))
318                    .unwrap_or(INFINITE),
319                false.into(),
320            ) != 0
321            {
322                n as usize
323            } else {
324                // TODO: assert timeout
325                assert!(timeout.is_some());
326                0
327            }
328        }
329    }
330
331    // Per MSDN, overlapped values are not dereferenced by PostQueuedCompletionStatus,
332    // they are passed as-is to the caller of GetQueuedCompletionStatus.
333    #[expect(clippy::not_unsafe_ptr_arg_deref)]
334    pub fn post(&self, bytes: u32, key: usize, overlapped: *mut OVERLAPPED) {
335        unsafe {
336            if PostQueuedCompletionStatus(self.0.as_raw_handle(), bytes, key, overlapped) == 0 {
337                panic!("oom posting completion port");
338            }
339        }
340    }
341
342    /// # Safety
343    ///
344    /// The caller must ensure that `handle` is valid.
345    pub unsafe fn associate(&self, handle: RawHandle, key: usize) -> Result<()> {
346        if unsafe { CreateIoCompletionPort(handle, self.0.as_raw_handle(), key, 0).is_null() } {
347            return Err(Error::last_os_error());
348        }
349        Ok(())
350    }
351}
352
353impl From<OwnedHandle> for IoCompletionPort {
354    fn from(handle: OwnedHandle) -> Self {
355        Self(handle)
356    }
357}
358
359impl AsHandle for IoCompletionPort {
360    fn as_handle(&self) -> BorrowedHandle<'_> {
361        self.0.as_handle()
362    }
363}
364
365impl From<IoCompletionPort> for OwnedHandle {
366    fn from(value: IoCompletionPort) -> OwnedHandle {
367        value.0
368    }
369}
370
371/// Sets file completion notification modes for `handle`.
372///
373/// # Safety
374/// The caller must ensure that `handle` is valid and that changing the
375/// notification modes does not cause a safety issue elsewhere (e.g. by causing
376/// unexpected IO completions in a completion port).
377pub unsafe fn set_file_completion_notification_modes(handle: RawHandle, flags: u32) -> Result<()> {
378    let flags = u8::try_from(flags).map_err(|_| {
379        Error::new(
380            io::ErrorKind::InvalidInput,
381            "file completion flags out of range",
382        )
383    })?;
384    // SAFETY: caller guarantees contract.
385    if unsafe { SetFileCompletionNotificationModes(handle, flags) } == 0 {
386        return Err(Error::last_os_error());
387    }
388    Ok(())
389}
390
391/// Disassociates `handle` from its completion port.
392///
393/// # Safety
394///
395/// The caller must ensure that `handle` is valid.
396pub unsafe fn disassociate_completion_port(handle: RawHandle) -> Result<()> {
397    let mut info = FILE_COMPLETION_INFORMATION {
398        Port: null_mut(),
399        Key: null_mut(),
400    };
401    let mut iosb = IO_STATUS_BLOCK::default();
402    // SAFETY: caller guarantees contract.
403    unsafe {
404        chk_status(NtSetInformationFile(
405            handle.cast::<c_void>(),
406            &mut iosb,
407            std::ptr::from_mut(&mut info).cast(),
408            size_of_val(&info) as u32,
409            FileReplaceCompletionInformation,
410        ))?;
411    }
412    Ok(())
413}
414
415/// Wrapper around an NT IO completion packet, used to deliver wait results to
416/// IO completion ports.
417#[derive(Debug)]
418pub struct WaitPacket(OwnedHandle);
419
420impl WaitPacket {
421    /// Creates a new wait copmletion packet.
422    pub fn new() -> Result<Self> {
423        unsafe {
424            let mut handle = null_mut();
425            chk_status(NtCreateWaitCompletionPacket(&mut handle, 1, null_mut()))?;
426            Ok(Self(OwnedHandle::from_raw_handle(handle.cast())))
427        }
428    }
429
430    /// Initiates a wait on `handle`. When `handle` becomes signaled, the packet
431    /// information will be delivered via `iocp`. Returns true if the handle was
432    /// already signaled (in which case the packet will still be delivered
433    /// through the IOCP).
434    ///
435    /// Panics if the wait could not be associated (e.g. invalid handle or wait
436    /// already in progress).
437    ///
438    /// # Safety
439    ///
440    /// The caller must ensure that `handle` is valid.
441    pub unsafe fn associate(
442        &self,
443        iocp: &IoCompletionPort,
444        handle: RawHandle,
445        key: usize,
446        apc: usize,
447        status: i32,
448        information: usize,
449    ) -> bool {
450        // SAFETY: API is being used as documented, and handle is valid
451        unsafe {
452            let mut already_signaled = 0;
453            chk_status(NtAssociateWaitCompletionPacket(
454                self.0.as_raw_handle().cast::<c_void>(),
455                iocp.as_handle().as_raw_handle().cast::<c_void>(),
456                handle.cast::<c_void>(),
457                key as *mut c_void,
458                apc as *mut c_void,
459                status,
460                information,
461                &mut already_signaled,
462            ))
463            .expect("failed to associate wait completion packet");
464            already_signaled != 0
465        }
466    }
467
468    /// Cancels a pending wait. Returns true if the wait was successfully
469    /// cancelled. If `remove_signaled_packet`, then the packet will be removed
470    /// from the IOCP (in which case it may have already consumed the signal
471    /// state of the object that was being waited upon).
472    pub fn cancel(&self, remove_signaled_packet: bool) -> bool {
473        match unsafe {
474            NtCancelWaitCompletionPacket(
475                self.0.as_raw_handle().cast::<c_void>(),
476                if remove_signaled_packet { 1 } else { 0 },
477            )
478        } {
479            windows_sys::Win32::Foundation::STATUS_SUCCESS => true,
480            STATUS_PENDING => false,
481            windows_sys::Win32::Foundation::STATUS_CANCELLED => false,
482            s => panic!(
483                "unexpected failure in NtCancelWaitCompletionPacket: {:?}",
484                chk_status(s).unwrap_err()
485            ),
486        }
487    }
488}
489
490impl AsHandle for WaitPacket {
491    fn as_handle(&self) -> BorrowedHandle<'_> {
492        self.0.as_handle()
493    }
494}
495
496// Represents a UNICODE_STRING that owns its buffer, where the buffer is
497// allocated on the Windows heap.
498#[repr(transparent)]
499pub struct UnicodeString(UNICODE_STRING);
500
501// SAFETY: UnicodeString owns its heap-allocated pointers, which can be safely
502//         aliased and sent between threads.
503unsafe impl Send for UnicodeString {}
504unsafe impl Sync for UnicodeString {}
505
506#[derive(Debug)]
507pub struct StringTooLong;
508
509impl UnicodeString {
510    pub fn new(s: &[u16]) -> std::result::Result<Self, StringTooLong> {
511        let byte_count: u16 = (s.len() * 2).try_into().map_err(|_| StringTooLong)?;
512        // FUTURE: use RtlProcessHeap instead of GetProcessHeap. This relies on
513        // unstable Rust features to get the PEB.
514        unsafe {
515            let buf = RtlAllocateHeap(GetProcessHeap().cast::<c_void>(), 0, byte_count.into())
516                .cast::<u16>();
517            assert!(!buf.is_null(), "out of memory");
518            std::ptr::copy(s.as_ptr(), buf, s.len());
519            Ok(Self(UNICODE_STRING {
520                Length: byte_count,
521                MaximumLength: byte_count,
522                Buffer: buf,
523            }))
524        }
525    }
526
527    pub fn empty() -> Self {
528        Self(unsafe { zeroed() })
529    }
530
531    pub fn is_empty(&self) -> bool {
532        self.0.Buffer.is_null()
533    }
534
535    /// The length of the string in bytes.
536    pub fn length(&self) -> usize {
537        self.0.Length as usize
538    }
539
540    pub fn as_ptr(&self) -> *const UNICODE_STRING {
541        &self.0
542    }
543
544    pub fn as_mut_ptr(&mut self) -> *mut UNICODE_STRING {
545        &mut self.0
546    }
547
548    pub fn into_raw(mut self) -> UNICODE_STRING {
549        let raw = self.0;
550        self.0.Length = 0;
551        self.0.MaximumLength = 0;
552        self.0.Buffer = null_mut();
553        raw
554    }
555
556    pub fn as_slice(&self) -> &[u16] {
557        let buffer = NonNull::new(self.0.Buffer).unwrap_or_else(NonNull::dangling);
558        unsafe { std::slice::from_raw_parts(buffer.as_ptr(), self.0.Length as usize / 2) }
559    }
560
561    pub fn as_mut_slice(&mut self) -> &mut [u16] {
562        let buffer = NonNull::new(self.0.Buffer).unwrap_or_else(NonNull::dangling);
563        unsafe { std::slice::from_raw_parts_mut(buffer.as_ptr(), self.0.Length as usize / 2) }
564    }
565}
566
567impl Drop for UnicodeString {
568    fn drop(&mut self) {
569        unsafe {
570            RtlFreeUnicodeString(&mut self.0);
571        }
572    }
573}
574
575impl<'a> TryFrom<&'a OsStr> for UnicodeString {
576    type Error = StringTooLong;
577    fn try_from(value: &'a OsStr) -> std::result::Result<Self, Self::Error> {
578        // FUTURE: figure out how to do this without a second allocation.
579        let value16: Vec<_> = value.encode_wide().collect();
580        Self::new(&value16)
581    }
582}
583
584impl<'a> TryFrom<&'a str> for UnicodeString {
585    type Error = StringTooLong;
586    fn try_from(value: &'a str) -> std::result::Result<Self, Self::Error> {
587        Self::try_from(OsStr::new(value))
588    }
589}
590
591impl TryFrom<String> for UnicodeString {
592    type Error = StringTooLong;
593    fn try_from(value: String) -> std::result::Result<Self, Self::Error> {
594        Self::try_from(OsStr::new(&value))
595    }
596}
597
598impl<'a> TryFrom<&'a Path> for UnicodeString {
599    type Error = StringTooLong;
600    fn try_from(value: &'a Path) -> std::result::Result<Self, Self::Error> {
601        Self::try_from(OsStr::new(value))
602    }
603}
604
605#[repr(transparent)]
606#[derive(Copy, Clone)]
607pub struct UnicodeStringRef<'a>(UNICODE_STRING, PhantomData<&'a [u16]>);
608
609impl<'a> UnicodeStringRef<'a> {
610    pub fn new(s: &'a [u16]) -> Option<Self> {
611        let len: u16 = (s.len() * 2).try_into().ok()?;
612        Some(Self(
613            UNICODE_STRING {
614                Length: len,
615                MaximumLength: len,
616                Buffer: s.as_ptr().cast_mut(),
617            },
618            PhantomData,
619        ))
620    }
621
622    pub fn empty() -> Self {
623        Self(unsafe { zeroed() }, PhantomData)
624    }
625
626    pub fn is_empty(&self) -> bool {
627        self.0.Buffer.is_null()
628    }
629
630    pub fn as_ptr(&self) -> *const UNICODE_STRING {
631        &self.0
632    }
633
634    pub fn as_mut_ptr(&mut self) -> *mut UNICODE_STRING {
635        &mut self.0
636    }
637
638    pub fn as_slice(&self) -> &[u16] {
639        let buffer = NonNull::new(self.0.Buffer).unwrap_or_else(NonNull::dangling);
640        unsafe { std::slice::from_raw_parts(buffer.as_ptr(), self.0.Length as usize / 2) }
641    }
642}
643
644pub trait AsUnicodeStringRef {
645    fn as_unicode_string_ref(&self) -> &UnicodeStringRef<'_>;
646}
647
648impl<T: AsUnicodeStringRef> AsUnicodeStringRef for &T {
649    fn as_unicode_string_ref(&self) -> &UnicodeStringRef<'_> {
650        (*self).as_unicode_string_ref()
651    }
652}
653
654impl AsUnicodeStringRef for UnicodeString {
655    fn as_unicode_string_ref(&self) -> &UnicodeStringRef<'_> {
656        // SAFETY: &UnicodeStringRef can be safely transmuted from
657        // &UNICODE_STRING as long as the lifetimes are correct, and they are
658        // here because the UnicodeStringRef will live no longer than self.
659        unsafe { std::mem::transmute(&self.0) }
660    }
661}
662
663impl AsUnicodeStringRef for UnicodeStringRef<'_> {
664    fn as_unicode_string_ref(&self) -> &UnicodeStringRef<'_> {
665        self
666    }
667}
668
669impl AsRef<windows::Win32::Foundation::UNICODE_STRING> for UnicodeStringRef<'_> {
670    fn as_ref(&self) -> &windows::Win32::Foundation::UNICODE_STRING {
671        // SAFETY: These are different definitions of the same type, so the memory layout is the
672        // same.
673        unsafe { std::mem::transmute(&self.0) }
674    }
675}
676
677impl<'a> TryFrom<&'a Utf16Str> for UnicodeStringRef<'a> {
678    type Error = StringTooLong;
679
680    fn try_from(value: &'a Utf16Str) -> std::result::Result<Self, Self::Error> {
681        UnicodeStringRef::new(value.as_slice()).ok_or(StringTooLong)
682    }
683}
684
685/// Associates a STRING with the lifetime of the buffer.
686#[repr(transparent)]
687pub struct AnsiStringRef<'a>(STRING, PhantomData<&'a [u8]>);
688
689impl<'a> AnsiStringRef<'a> {
690    /// Creates a new `AnsiStringRef` using the specified buffer.
691    ///
692    /// Returns `None` if the buffer is too big for a STRING's maximum length.
693    pub fn new(s: &'a [u8]) -> Option<Self> {
694        let len: u16 = s.len().try_into().ok()?;
695        Some(Self(
696            STRING {
697                Length: len,
698                MaximumLength: len,
699                Buffer: s.as_ptr().cast_mut(),
700            },
701            PhantomData,
702        ))
703    }
704
705    /// Creates an empty `AnsiStringRef` with no buffer.
706    pub fn empty() -> Self {
707        Self(unsafe { zeroed() }, PhantomData)
708    }
709
710    /// Gets a value which indicates whether this instance does not contain a buffer.
711    pub fn is_empty(&self) -> bool {
712        self.0.Buffer.is_null()
713    }
714
715    /// Returns a pointer to the contained STRING.
716    pub fn as_ptr(&self) -> *const STRING {
717        &self.0
718    }
719
720    /// Returns a mutable pointer to the contained STRING.
721    pub fn as_mut_ptr(&mut self) -> *mut STRING {
722        &mut self.0
723    }
724
725    /// Returns the valid part of a STRING's buffer as a slice.
726    pub fn as_slice(&self) -> &[u8] {
727        let buffer = NonNull::new(self.0.Buffer).unwrap_or_else(NonNull::dangling);
728        unsafe { std::slice::from_raw_parts(buffer.as_ptr(), self.0.Length as usize) }
729    }
730}
731
732impl AsRef<STRING> for AnsiStringRef<'_> {
733    fn as_ref(&self) -> &STRING {
734        &self.0
735    }
736}
737
738pub fn status_to_error(status: i32) -> Error {
739    Error::from_raw_os_error(unsafe { RtlNtStatusToDosErrorNoTeb(status) } as i32)
740}
741
742pub fn chk_status(status: i32) -> Result<i32> {
743    if status >= 0 {
744        Ok(status)
745    } else {
746        Err(status_to_error(status))
747    }
748}
749
750pub fn dos_to_nt_path<P: AsRef<Path>>(path: P) -> Result<UnicodeString> {
751    let path16 = U16CString::from_os_str(path.as_ref().as_os_str())
752        .map_err(|_| Error::from_raw_os_error(ERROR_BAD_PATHNAME as i32))?;
753    let mut pathu = UnicodeString::empty();
754    unsafe {
755        chk_status(RtlDosPathNameToNtPathName_U_WithStatus(
756            path16.as_ptr().cast_mut(),
757            pathu.as_mut_ptr(),
758            null_mut(),
759            null_mut(),
760        ))?;
761    }
762    Ok(pathu)
763}
764
765/// A wrapper around OBJECT_ATTRIBUTES.
766#[repr(transparent)]
767pub struct ObjectAttributes<'a> {
768    attributes: windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES,
769    phantom: PhantomData<&'a ()>,
770}
771
772impl Default for ObjectAttributes<'_> {
773    fn default() -> Self {
774        Self::new()
775    }
776}
777
778impl<'a> ObjectAttributes<'a> {
779    /// Constructs the default object attributes, with no name, root directory,
780    /// attributes, or security information.
781    pub fn new() -> Self {
782        Self {
783            attributes: windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES {
784                Length: size_of::<windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES>() as u32,
785                RootDirectory: null_mut(),
786                ObjectName: null_mut(),
787                Attributes: 0,
788                SecurityDescriptor: null_mut(),
789                SecurityQualityOfService: null_mut(),
790            },
791            phantom: PhantomData,
792        }
793    }
794
795    /// Sets the object name to `name`.
796    pub fn name<P>(&mut self, name: &'a P) -> &mut Self
797    where
798        P: AsUnicodeStringRef,
799    {
800        self.attributes.ObjectName = name.as_unicode_string_ref().as_ptr().cast_mut();
801        self
802    }
803
804    /// Sets the root directory to `root`.
805    pub fn root(&mut self, root: BorrowedHandle<'a>) -> &mut Self {
806        self.attributes.RootDirectory = root.as_raw_handle().cast::<c_void>();
807        self
808    }
809
810    /// Sets the attributes to `attributes`.
811    pub fn attributes(&mut self, attributes: u32) -> &mut Self {
812        self.attributes.Attributes = attributes;
813        self
814    }
815
816    /// Sets the security descriptor to `sd`.
817    pub fn security_descriptor(&mut self, sd: &'a SecurityDescriptor) -> &mut Self {
818        self.attributes.SecurityDescriptor = sd.as_ptr().cast::<SECURITY_DESCRIPTOR>();
819        self
820    }
821
822    /// Returns the OBJECT_ATTRIBUTES pointer for passing to an NT syscall.
823    pub fn as_ptr(&self) -> *mut windows_sys::Wdk::Foundation::OBJECT_ATTRIBUTES {
824        std::ptr::from_ref(&self.attributes).cast_mut()
825    }
826}
827
828impl AsRef<windows::Wdk::Foundation::OBJECT_ATTRIBUTES> for ObjectAttributes<'_> {
829    fn as_ref(&self) -> &windows::Wdk::Foundation::OBJECT_ATTRIBUTES {
830        // SAFETY: These are different definitions of the same type, so the memory layout is the
831        // same.
832        unsafe { std::mem::transmute(&self.attributes) }
833    }
834}
835
836pub fn open_object_directory(obj_attr: &ObjectAttributes<'_>, access: u32) -> Result<OwnedHandle> {
837    // SAFETY: calling the API according to the NT API
838    unsafe {
839        let mut handle = null_mut();
840        chk_status(NtOpenDirectoryObject(
841            &mut handle,
842            access,
843            obj_attr.as_ptr(),
844        ))?;
845        Ok(OwnedHandle::from_raw_handle(handle.cast()))
846    }
847}
848
849pub fn create_object_directory(
850    obj_attr: &ObjectAttributes<'_>,
851    access: u32,
852) -> Result<OwnedHandle> {
853    // SAFETY: calling the API according to the NT API
854    unsafe {
855        let mut handle = null_mut();
856        chk_status(NtCreateDirectoryObject(
857            &mut handle,
858            access,
859            obj_attr.as_ptr(),
860        ))?;
861        Ok(OwnedHandle::from_raw_handle(handle.cast()))
862    }
863}
864
865/// A wrapper around memory that was allocated with `RtlAllocateHeap` and will be freed on drop with `RtlFreeHeap`,
866/// like [`std::boxed::Box`].
867pub struct RtlHeapBox<T: ?Sized> {
868    value: NonNull<T>,
869}
870
871impl<T> RtlHeapBox<T> {
872    /// Creates a new `RtlHeapBox` from a raw pointer.
873    ///
874    /// # Safety
875    ///
876    /// The caller must guarantee that the pointer was allocated with `RtlAllocateHeap` with the default heap of the current
877    /// process as the heap handle, returned by `GetProcessHeap`.
878    ///
879    /// The caller must not allow this pointer to be aliased anywhere else. Conceptually, by calling `from_raw`, the caller
880    /// must guarantee that ownership of the pointer `value` is transferred to this `RtlHeapBox`. This is to uphold the aliasing
881    /// requirements used by `RtlHeapBox` to implement various Deref and AsRef traits.
882    ///
883    /// On drop, this memory will be freed with `RtlFreeHeap`. The caller must not manually free this pointer.
884    pub unsafe fn from_raw(value: *mut T) -> Self {
885        Self {
886            value: NonNull::new(value).unwrap(),
887        }
888    }
889
890    /// Gets the contained pointer.
891    pub fn as_ptr(&self) -> *const T {
892        self.value.as_ptr()
893    }
894}
895
896impl<T> std::ops::Deref for RtlHeapBox<T> {
897    type Target = T;
898    fn deref(&self) -> &Self::Target {
899        // SAFETY: The pointer held by the RtlHeapBox is guaranteed to be valid and conform to the rules required by NonNull::as_ref.
900        unsafe { self.value.as_ref() }
901    }
902}
903
904impl<T> std::ops::DerefMut for RtlHeapBox<T> {
905    fn deref_mut(&mut self) -> &mut T {
906        // SAFETY: The pointer held by the RtlHeapBox is guaranteed to be valid and conform to the rules required by NonNull::as_mut.
907        unsafe { self.value.as_mut() }
908    }
909}
910
911impl<T> AsRef<T> for RtlHeapBox<T> {
912    fn as_ref(&self) -> &T {
913        // SAFETY: The pointer held by the RtlHeapBox is guaranteed to be valid and conform to the rules required by NonNull::as_ref.
914        unsafe { self.value.as_ref() }
915    }
916}
917
918impl<T> AsMut<T> for RtlHeapBox<T> {
919    fn as_mut(&mut self) -> &mut T {
920        // SAFETY: The pointer held by the RtlHeapBox is guaranteed to be valid and conform to the rules required by NonNull::as_mut.
921        unsafe { self.value.as_mut() }
922    }
923}
924
925impl<T: ?Sized> Drop for RtlHeapBox<T> {
926    fn drop(&mut self) {
927        // SAFETY: The pointer held by the RtlHeapBox must be allocated via RtlAllocateHeap from the constraints in
928        //         RtlHeapBox::from_raw.
929        unsafe {
930            RtlFreeHeap(
931                GetProcessHeap().cast::<c_void>(),
932                0,
933                self.value.as_ptr().cast::<c_void>(),
934            );
935        }
936    }
937}
938
/// A wrapper around a sized buffer that was allocated with RtlAllocateHeap. The buffer dereferences to a `[u8]` slice
/// of `size` bytes, so callers can work with it as a slice instead of using [`RtlHeapBox`] directly.
pub struct RtlHeapBuffer {
    // Owning pointer to the first byte of the allocation; freed on drop by `RtlHeapBox`.
    buffer: RtlHeapBox<u8>,
    // Number of bytes in the buffer; used as the slice length by `Deref`.
    size: usize,
}
945
946impl RtlHeapBuffer {
947    /// Creates a new `HeapBuffer` from a raw pointer and size.
948    ///
949    /// # Safety
950    ///
951    /// The caller must guarantee that the pointer `buffer` conforms to the safety requirements imposed by [`RtlHeapBox::from_raw`].
952    ///
953    /// Additionally, the pointer described by `buffer` must describe a [`u8`] array of count `size`.
954    pub unsafe fn from_raw(buffer: *mut u8, size: usize) -> RtlHeapBuffer {
955        Self {
956            // SAFETY: The caller has guaranteed that this pointer is an RtlAllocateHeap pointer and should be managed
957            //         via a RtlHeapBox.
958            buffer: unsafe { RtlHeapBox::from_raw(buffer) },
959            size,
960        }
961    }
962}
963
964impl std::ops::Deref for RtlHeapBuffer {
965    type Target = [u8];
966    fn deref(&self) -> &Self::Target {
967        // SAFETY: The pointer described by buffer is a u8 array of self.size, as required in RtlHeapBuffer::from_raw.
968        unsafe { std::slice::from_raw_parts(self.buffer.as_ptr(), self.size) }
969    }
970}
971
/// `Send`+`Sync` wrapper around `OVERLAPPED`.
///
/// Internally uses an UnsafeCell since this may be concurrently updated by the
/// kernel. The wrapper is `#[repr(transparent)]`, so it has the same layout as
/// the `OVERLAPPED` it contains.
#[repr(transparent)]
#[derive(Default, Debug)]
pub struct Overlapped(UnsafeCell<OVERLAPPED>);
979
980impl Overlapped {
981    pub fn new() -> Self {
982        Default::default()
983    }
984
985    /// Sets the offset for the IO request.
986    pub fn set_offset(&mut self, offset: i64) {
987        let overlapped = self.0.get_mut();
988        overlapped.Anonymous.Anonymous.Offset = offset as u32;
989        overlapped.Anonymous.Anonymous.OffsetHigh = (offset >> 32) as u32;
990    }
991
992    pub fn set_event(&mut self, event: RawHandle) {
993        self.0.get_mut().hEvent = event;
994    }
995
996    pub fn as_ptr(&self) -> *mut OVERLAPPED {
997        self.0.get()
998    }
999
1000    /// Polls the current operation status.
1001    pub fn io_status(&self) -> Option<(NTSTATUS, usize)> {
1002        let overlapped = self.0.get();
1003        // SAFETY: The kernel might be mutating the overlapped structure right
1004        // now, so this gets a &AtomicUsize just to the Internal field that
1005        // contains the completion status.
1006        let internal = unsafe { &*addr_of!((*overlapped).Internal).cast::<AtomicUsize>() };
1007        let status = internal.load(Ordering::Acquire) as NTSTATUS;
1008        if status != STATUS_PENDING {
1009            // SAFETY: the IO is complete so it's safe to read this value directly.
1010            let information = unsafe { (*self.0.get()).InternalHigh };
1011            Some((status, information))
1012        } else {
1013            None
1014        }
1015    }
1016}
1017
// SAFETY: By itself, an overlapped structure can be safely sent or shared
// across multiple threads. Of course, while it is owned by the kernel it cannot
// be concurrently accessed, but this has no bearing on its Send/Sync-ness.
unsafe impl Send for Overlapped {}
// SAFETY: See above comment.
unsafe impl Sync for Overlapped {}
1024
/// Declares functions that are resolved lazily from a DLL on first use.
///
/// ```ignore
/// delayload! { "kernel32.dll" {
///     pub fn GetCurrentProcessId() -> u32;
/// }}
/// ```
///
/// The expansion, placed in the invoking module, contains:
/// - a private `get_module()` that loads the DLL with `LoadLibraryA` and caches the module handle;
/// - a `funcs` module with one function per declaration that resolves and caches the address via `GetProcAddress`;
/// - an `is_supported` module with one `fn <name>() -> bool` per declaration, reporting whether the DLL and
///   function could be resolved;
/// - an `unsafe fn <name>(...)` per declaration with the declared signature, calling through the resolved pointer.
///
/// When resolution fails, the wrapper hands the Win32 error to the internal `@result_from_win32` rules to
/// produce the return value (see the notes on those rules).
#[macro_export]
macro_rules! delayload {
    {$dll:literal {
        $(
            $(#[$a:meta])*
            $visibility:vis fn $name:ident($($params:ident : $types:ty),* $(,)?) -> $result:ty;
        )*
    }} => {
        // Loads `$dll` on first use and caches the handle in `MODULE`. Racing threads may each call
        // LoadLibraryA; the loser of the compare_exchange frees its extra reference and adopts the winner's
        // handle. The cached handle is never freed.
        // NOTE(review): `Result` below resolves at the invocation site; a one-parameter alias in scope there
        // (e.g. `std::io::Result`) would break this signature.
        fn get_module() -> Result<$crate::windows_sys::Win32::Foundation::HMODULE, $crate::windows_sys::Win32::Foundation::WIN32_ERROR> {
            use ::std::ffi::c_void;
            use ::std::ptr::null_mut;
            use ::std::sync::atomic::{AtomicPtr, Ordering};
            use $crate::windows_sys::Win32::{
                Foundation::{FreeLibrary, GetLastError},
                System::LibraryLoader::{LoadLibraryA},
            };

            static MODULE: AtomicPtr<c_void> = AtomicPtr::new(null_mut());
            let mut module = MODULE.load(Ordering::Acquire);
            if module.is_null() {
                // The DLL name is passed as a NUL-terminated ANSI string.
                let new_module = unsafe { LoadLibraryA(concat!($dll, "\0").as_ptr().cast::<u8>()) };
                if new_module.is_null() {
                    return Err(unsafe { GetLastError() });
                }
                match MODULE.compare_exchange(null_mut(), new_module, Ordering::Release, Ordering::Acquire) {
                    Ok(_) => module = new_module,
                    Err(old_module) => {
                        // Another thread won the race, use their module and free ours
                        unsafe { FreeLibrary(new_module) };
                        module = old_module;
                    }
                }
            }
            Ok(module)
        }

        // One resolver per declared function, returning the procedure address as a `usize`
        // or the Win32 error that prevented resolution.
        mod funcs {
            #![expect(non_snake_case)]
            $(
                $(#[$a])*
                pub fn $name() -> Result<usize, $crate::windows_sys::Win32::Foundation::WIN32_ERROR> {
                    use ::std::concat;
                    use ::std::sync::atomic::{AtomicUsize, Ordering};
                    use $crate::windows_sys::Win32::{
                        Foundation::ERROR_PROC_NOT_FOUND,
                        System::LibraryLoader::GetProcAddress,
                    };

                    // A FNCELL value 0 denotes that GetProcAddress has never been
                    // called for the given function.
                    // A FNCELL value 1 denotes that GetProcAddress has been called
                    // but the procedure does not exist.
                    // Any other FNCELL value denotes the result of GetProcAddress,
                    // the callable address of the function.
                    // Relaxed ordering is used: racing threads may each call
                    // GetProcAddress and store the same value.
                    // A failure of get_module() is returned without being cached,
                    // so loading the DLL is retried on the next call.
                    static FNCELL: AtomicUsize = AtomicUsize::new(0);
                    let mut fnval = FNCELL.load(Ordering::Relaxed);
                    if fnval == 0 {
                        let module = super::get_module()?;
                        fnval = unsafe { GetProcAddress(
                            module,
                            concat!(stringify!($name), "\0").as_ptr().cast::<u8>()) }
                            .map(|f| f as usize)
                            .unwrap_or(0);
                        if fnval == 0 {
                            fnval = 1;
                        }
                        FNCELL.store(fnval, Ordering::Relaxed);
                    }
                    if fnval == 1 {
                        return Err(ERROR_PROC_NOT_FOUND)
                    }
                    Ok(fnval)
                }
            )*
        }

        // Reports, per declared function, whether it could be resolved (triggers resolution if needed).
        pub mod is_supported {
            #![expect(non_snake_case)]
            #![allow(dead_code)]
            $(
                $(#[$a])*
                pub fn $name() -> bool {
                    super::funcs::$name().is_ok()
                }
            )*
        }

        // The public wrappers: call through the resolved pointer, or convert the resolution
        // error into the declared return type.
        $(
            $(#[$a])*
            #[expect(non_snake_case)]
            $visibility unsafe fn $name($($params: $types,)*) -> $result {
                match funcs::$name() {
                    Ok(fnval) => {
                        type FnType = unsafe extern "system" fn($($params: $types,)*) -> $result;
                        // SAFETY: fnval is a valid function pointer obtained from GetProcAddress
                        let fnptr: FnType = unsafe { ::std::mem::transmute(fnval) };
                        // SAFETY: The function pointer is valid and the caller must uphold the function's safety contract
                        unsafe { fnptr($($params,)*) }
                    },
                    Err(win32) => {
                        $crate::delayload!(@result_from_win32(($result), win32))
                    }
                }
            }
        )*
    };

    // Internal rules: map the Win32 error from a failed load to a value of the declared return type.
    //
    // NOTE(review): `$result` is captured above as a `ty` fragment and forwarded here as `($result)`.
    // Per the Rust reference, a forwarded `ty` fragment is opaque and cannot be matched against literal
    // tokens such as `i32` or `u32`. The typed arms below therefore appear unreachable from the main
    // rule, and every failed load falls through to the `$t:tt` arm and panics. Confirm this.
    // NOTE(review): if `$crate::windows_result::HRESULT` is the newtype struct `HRESULT(i32)`, the
    // `(i32)` arm returns a struct rather than an `i32` and would not type-check if selected. Verify this.
    (@result_from_win32((i32), $val:expr)) => { $crate::windows_result::HRESULT::from_win32($val) };
    (@result_from_win32((u32), $val:expr)) => { $val };
    (@result_from_win32((DWORD), $val:expr)) => { $val };
    (@result_from_win32((HRESULT), $val:expr)) => { $crate::windows_result::HRESULT::from_win32($val) };
    (@result_from_win32(($t:tt), $val:expr)) => { panic!("could not load: {}", $val) };
}
1135    (@result_from_win32((HRESULT), $val:expr)) => { $crate::windows_result::HRESULT::from_win32($val) };
1136    (@result_from_win32(($t:tt), $val:expr)) => { panic!("could not load: {}", $val) };
1137}
1138
1139/// Closes stdout, replacing it with the null device.
1140pub fn close_stdout() -> Result<()> {
1141    let new_stdout = File::open("nul")?;
1142    let stdout = io::stdout();
1143    // Prevent concurrent accesses to stdout.
1144    let _locked = stdout.lock();
1145    let old_handle = stdout.as_raw_handle();
1146    // SAFETY: transferring ownership of the new handle.
1147    unsafe {
1148        if SetStdHandle(STD_OUTPUT_HANDLE, new_stdout.into_raw_handle()) == 0 {
1149            panic!("failed to set handle");
1150        }
1151    }
1152    drop(_locked);
1153    unsafe {
1154        // SAFETY: the old handle is no longer referenced anywhere.
1155        CloseHandle(old_handle);
1156    }
1157
1158    Ok(())
1159}
1160
1161/// Disables the hard error dialog on "critical errors".
1162pub fn disable_hard_error_dialog() {
1163    // SAFETY: This Win32 API has no safety requirements.
1164    unsafe {
1165        SetErrorMode(GetErrorMode() | SEM_FAILCRITICALERRORS);
1166    }
1167}
1168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dos_to_nt_path() {
        let nt_path = dos_to_nt_path("c:\\foo").unwrap();
        let actual: Vec<u16> = nt_path.as_slice().iter().copied().collect();
        let expected: Vec<u16> = "\\??\\c:\\foo".encode_utf16().collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_alloc_unicode_string() {
        let unicode: UnicodeString = "abc".try_into().unwrap();
        let actual: Vec<u16> = unicode.as_slice().iter().copied().collect();
        assert_eq!(actual, "abc".encode_utf16().collect::<Vec<u16>>());
    }

    #[test]
    fn test_delayload() {
        // Test delayloading kernel32.dll functions that are guaranteed to exist
        mod kernel32_delayload {
            delayload! { "kernel32.dll" {
                pub fn GetCurrentProcessId() -> u32;
                pub fn GetCurrentThreadId() -> u32;
            }}
        }

        // Both functions exist, so resolution must succeed.
        assert!(kernel32_delayload::is_supported::GetCurrentProcessId());
        assert!(kernel32_delayload::is_supported::GetCurrentThreadId());

        // Calling through the resolved pointers must produce real IDs.
        let pid = unsafe { kernel32_delayload::GetCurrentProcessId() };
        assert_ne!(pid, 0, "Process ID should not be zero");

        let tid = unsafe { kernel32_delayload::GetCurrentThreadId() };
        assert_ne!(tid, 0, "Thread ID should not be zero");
    }

    #[test]
    fn test_delayload_missing_function() {
        // Test delayloading a function that doesn't exist
        mod kernel32_missing {
            delayload! { "kernel32.dll" {
                #[allow(dead_code)]
                pub fn NonExistentFunction() -> u32;
            }}
        }

        // The DLL loads, but the procedure lookup must fail.
        let supported = kernel32_missing::is_supported::NonExistentFunction();
        assert!(!supported);
    }

    #[test]
    fn test_delayload_missing_dll() {
        // Test delayloading from a DLL that doesn't exist
        mod missing_dll {
            delayload! { "this_dll_does_not_exist_12345.dll" {
                #[allow(dead_code)]
                pub fn SomeFunction() -> u32;
            }}
        }

        // Loading the DLL itself fails, so nothing can be supported.
        let supported = missing_dll::is_supported::SomeFunction();
        assert!(!supported);
    }

    #[test]
    fn test_delayload_multiple_calls() {
        // Test that multiple calls to the same delayloaded function work
        // and that the function pointer is cached correctly
        mod kernel32_cached {
            delayload! { "kernel32.dll" {
                pub fn GetCurrentProcessId() -> u32;
            }}
        }

        // Every later call (served from the cached pointer) must agree with the first.
        let first = unsafe { kernel32_cached::GetCurrentProcessId() };
        for _ in 0..2 {
            let again = unsafe { kernel32_cached::GetCurrentProcessId() };
            assert_eq!(first, again);
        }
    }
}