block_crypto/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Cryptography primitives for disk encryption.
5
6#[cfg(windows)]
7use bcrypt as sys;
8#[cfg(unix)]
9use ossl as sys;
10use thiserror::Error;
11
12/// XTS-AES-256 encryption/decryption.
13pub struct XtsAes256(sys::XtsAes256);
14
15/// An error for cryptographic operations.
16#[derive(Debug, Error)]
17#[error(transparent)]
18pub struct Error(sys::Error);
19
20impl XtsAes256 {
21    /// The required key length for the algorithm.
22    ///
23    /// Note that an XTS-AES-256 key contains two AES keys, each of which is 256 bits.
24    pub const KEY_LEN: usize = 64;
25
26    /// Creates a new XTS-AES-256 encryption/decryption context.
27    pub fn new(key: &[u8; Self::KEY_LEN], data_unit_size: u32) -> Result<Self, Error> {
28        sys::xts_aes_256(key, data_unit_size)
29            .map(Self)
30            .map_err(Error)
31    }
32
33    /// Returns a context for encrypting data.
34    pub fn encrypt(&self) -> Result<XtsAes256Ctx<'_>, Error> {
35        Ok(XtsAes256Ctx(self.0.ctx(true).map_err(Error)?))
36    }
37
38    /// Returns a context for decrypting data.
39    pub fn decrypt(&self) -> Result<XtsAes256Ctx<'_>, Error> {
40        Ok(XtsAes256Ctx(self.0.ctx(false).map_err(Error)?))
41    }
42}
43
44/// Context for XTS-AES-256 encryption/decryption.
45pub struct XtsAes256Ctx<'a>(sys::XtsAes256Ctx<'a>);
46
47impl XtsAes256Ctx<'_> {
48    /// Encrypts or decrypts `data` using the provided `tweak`.
49    pub fn cipher(&mut self, tweak: u128, data: &mut [u8]) -> Result<(), Error> {
50        self.0.cipher(&tweak.to_le_bytes(), data).map_err(Error)?;
51        Ok(())
52    }
53}
54
55#[cfg(unix)]
56mod ossl {
57    pub struct NonStreamingCipher {
58        enc: openssl::cipher_ctx::CipherCtx,
59        dec: openssl::cipher_ctx::CipherCtx,
60    }
61
62    pub struct NonStreamingCipherCtx<'a> {
63        ctx: openssl::cipher_ctx::CipherCtx,
64        enc: bool,
65        _dummy: &'a (),
66    }
67
68    pub type Error = openssl::error::ErrorStack;
69
70    pub type XtsAes256 = NonStreamingCipher;
71    pub type XtsAes256Ctx<'a> = NonStreamingCipherCtx<'a>;
72
73    pub fn xts_aes_256(key: &[u8], _data_unit_size: u32) -> Result<XtsAes256, Error> {
74        let mut enc = openssl::cipher_ctx::CipherCtx::new()?;
75        enc.encrypt_init(
76            Some(openssl::cipher::Cipher::aes_256_xts()),
77            Some(key),
78            None,
79        )?;
80        let mut dec = openssl::cipher_ctx::CipherCtx::new()?;
81        dec.decrypt_init(
82            Some(openssl::cipher::Cipher::aes_256_xts()),
83            Some(key),
84            None,
85        )?;
86        Ok(NonStreamingCipher { enc, dec })
87    }
88
89    impl NonStreamingCipher {
90        pub fn ctx(&self, enc: bool) -> Result<NonStreamingCipherCtx<'_>, Error> {
91            let mut ctx = openssl::cipher_ctx::CipherCtx::new()?;
92            ctx.copy(if enc { &self.enc } else { &self.dec })?;
93            Ok(NonStreamingCipherCtx {
94                ctx,
95                enc,
96                _dummy: &(),
97            })
98        }
99    }
100
101    impl NonStreamingCipherCtx<'_> {
102        pub fn cipher(&mut self, iv: &[u8], data: &mut [u8]) -> Result<(), Error> {
103            if self.enc {
104                self.ctx.encrypt_init(None, None, Some(iv))?;
105            } else {
106                self.ctx.decrypt_init(None, None, Some(iv))?;
107            }
108            self.ctx.cipher_update_inplace(data, data.len())?;
109            Ok(())
110        }
111    }
112}
113
#[cfg(windows)]
mod bcrypt {
    // UNSAFETY: calling bcrypt APIs
    #![expect(unsafe_code)]

    use std::sync::OnceLock;
    use thiserror::Error;
    use windows::Win32::Foundation::NTSTATUS;
    use windows::Win32::Foundation::RtlNtStatusToDosError;
    use windows::Win32::Security::Cryptography::BCRYPT_ALG_HANDLE;
    use windows::Win32::Security::Cryptography::BCRYPT_HANDLE;
    use windows::Win32::Security::Cryptography::BCRYPT_KEY_HANDLE;
    use windows::Win32::Security::Cryptography::BCRYPT_OPEN_ALGORITHM_PROVIDER_FLAGS;

    /// A failed BCrypt operation: the operation name plus the Win32 error that
    /// its NTSTATUS was mapped to (see [`bcrypt_result`]).
    #[derive(Debug, Error)]
    #[error("{op} failed")]
    pub struct Error {
        op: &'static str,
        #[source]
        err: std::io::Error,
    }

    /// An XTS-AES-256 key held by BCrypt.
    pub struct XtsAes256(Key);

    /// A borrowed view of an [`XtsAes256`] key, fixed to one direction.
    pub struct XtsAes256Ctx<'a> {
        key: &'a Key,
        /// `true` to encrypt, `false` to decrypt.
        enc: bool,
    }

    impl XtsAes256 {
        /// Returns a context that encrypts (`enc == true`) or decrypts data with
        /// this key.
        ///
        /// This cannot fail; it returns `Result` so its signature matches the
        /// OpenSSL backend.
        pub fn ctx(&self, enc: bool) -> Result<XtsAes256Ctx<'_>, Error> {
            Ok(XtsAes256Ctx { key: &self.0, enc })
        }
    }

    impl XtsAes256Ctx<'_> {
        /// Encrypts or decrypts `data` in place using the little-endian 128-bit
        /// `tweak`.
        ///
        /// # Errors
        ///
        /// Returns a "convert tweak" error (`InvalidInput`) if `tweak` does not fit
        /// in 64 bits, or the BCrypt error if the operation itself fails.
        pub fn cipher(&self, tweak: &[u8; 16], data: &mut [u8]) -> Result<(), Error> {
            // BCrypt only supports 64-bit tweaks, internally padding out the high 8
            // bytes with zeroes. Reject larger tweaks rather than silently
            // truncating them into a different tweak.
            let mut iv = u64::try_from(u128::from_le_bytes(*tweak))
                .map_err(|_| Error {
                    op: "convert tweak",
                    err: std::io::ErrorKind::InvalidInput.into(),
                })?
                .to_le_bytes();
            self.key.crypt(self.enc, &mut iv, data)
        }
    }

    /// Process-wide XTS-AES algorithm provider, opened on first use by
    /// [`xts_aes_alg`]. It is never closed and lives for the rest of the process.
    static XTS_AES_256: OnceLock<AlgHandle> = OnceLock::new();

    struct AlgHandle(BCRYPT_ALG_HANDLE);

    // SAFETY: the handle can be sent across threads.
    unsafe impl Send for AlgHandle {}
    // SAFETY: the handle can be shared across threads.
    unsafe impl Sync for AlgHandle {}

    /// Converts a BCrypt `status` into a `Result`, tagging failures with `op`.
    ///
    /// Any status for which `NTSTATUS::is_ok` returns true is treated as success;
    /// other statuses are mapped to a Win32 error via `RtlNtStatusToDosError`.
    fn bcrypt_result(op: &'static str, status: NTSTATUS) -> Result<(), Error> {
        if status.is_ok() {
            Ok(())
        } else {
            // SAFETY: no preconditions for this call.
            let err = unsafe { RtlNtStatusToDosError(status) };
            Err(Error {
                op,
                err: std::io::Error::from_raw_os_error(err as i32),
            })
        }
    }

    /// Owned BCrypt key handle, destroyed on drop.
    struct Key(BCRYPT_KEY_HANDLE);

    // SAFETY: the handle can be sent across threads.
    unsafe impl Send for Key {}
    // SAFETY: the handle can be shared across threads.
    unsafe impl Sync for Key {}

    impl Drop for Key {
        fn drop(&mut self) {
            // SAFETY: handle is valid and not aliased.
            let status =
                unsafe { windows::Win32::Security::Cryptography::BCryptDestroyKey(self.0) };
            bcrypt_result("destroy key", status).expect("owned key handle must be valid");
        }
    }

    impl Key {
        /// Encrypts (`encrypt == true`) or decrypts `data` in place.
        ///
        /// `iv` is handed to BCrypt as a mutable buffer.
        ///
        /// # Panics
        ///
        /// Panics if BCrypt reports an output length different from `data.len()`.
        fn crypt(&self, encrypt: bool, iv: &mut [u8], data: &mut [u8]) -> Result<(), Error> {
            // TODO: fix windows crate to allow aliased input and output, as
            // allowed by the API. Until then, read from a copy of the input.
            let input = data.to_vec();
            let mut n = 0;
            // SAFETY: key and buffers are valid for the duration of the call.
            let status = unsafe {
                if encrypt {
                    windows::Win32::Security::Cryptography::BCryptEncrypt(
                        self.0,
                        Some(&input),
                        None,
                        Some(iv),
                        Some(data),
                        &mut n,
                        windows::Win32::Security::Cryptography::BCRYPT_FLAGS(0),
                    )
                } else {
                    windows::Win32::Security::Cryptography::BCryptDecrypt(
                        self.0,
                        Some(&input),
                        None,
                        Some(iv),
                        Some(data),
                        &mut n,
                        windows::Win32::Security::Cryptography::BCRYPT_FLAGS(0),
                    )
                }
            };
            bcrypt_result(if encrypt { "encrypt" } else { "decrypt" }, status)?;
            assert_eq!(n as usize, data.len());
            Ok(())
        }
    }

    /// Returns the shared XTS-AES algorithm provider, opening it on first use.
    ///
    /// If two threads race to open it, the loser closes its redundant handle and
    /// uses the one that was stored.
    fn xts_aes_alg() -> Result<&'static AlgHandle, Error> {
        if let Some(alg) = XTS_AES_256.get() {
            return Ok(alg);
        }

        let mut handle = BCRYPT_ALG_HANDLE::default();
        // SAFETY: no safety requirements.
        let status = unsafe {
            windows::Win32::Security::Cryptography::BCryptOpenAlgorithmProvider(
                &mut handle,
                windows::Win32::Security::Cryptography::BCRYPT_XTS_AES_ALGORITHM,
                windows::Win32::Security::Cryptography::MS_PRIMITIVE_PROVIDER,
                BCRYPT_OPEN_ALGORITHM_PROVIDER_FLAGS(0),
            )
        };
        bcrypt_result("open algorithm provider", status)?;

        if let Err(AlgHandle(handle)) = XTS_AES_256.set(AlgHandle(handle)) {
            // Another thread initialized the provider first; release ours.
            // SAFETY: handle is valid and not aliased.
            let status = unsafe {
                windows::Win32::Security::Cryptography::BCryptCloseAlgorithmProvider(handle, 0)
            };
            bcrypt_result("close algorithm provider", status)
                .expect("freshly opened provider handle must be valid");
        }

        Ok(XTS_AES_256
            .get()
            .expect("provider was initialized by this call or a racing one"))
    }

    /// Creates an XTS-AES-256 key from `key` whose message block length (data
    /// unit size) is `data_unit_size` bytes.
    ///
    /// # Errors
    ///
    /// Returns an error if the provider cannot be opened, the key cannot be
    /// generated, or the message block length cannot be set.
    pub fn xts_aes_256(key: &[u8], data_unit_size: u32) -> Result<XtsAes256, Error> {
        let alg = xts_aes_alg()?;

        let mut handle = BCRYPT_KEY_HANDLE::default();
        // SAFETY: the algorithm handle is valid.
        let status = unsafe {
            windows::Win32::Security::Cryptography::BCryptGenerateSymmetricKey(
                alg.0,
                &mut handle,
                None,
                key,
                0,
            )
        };
        bcrypt_result("generate symmetric key", status)?;
        // Take ownership right away so the key is destroyed if a later step fails.
        let key = Key(handle);

        // SAFETY: the key handle is valid.
        let status = unsafe {
            windows::Win32::Security::Cryptography::BCryptSetProperty(
                BCRYPT_HANDLE(key.0.0),
                windows::Win32::Security::Cryptography::BCRYPT_MESSAGE_BLOCK_LENGTH,
                &data_unit_size.to_ne_bytes(),
                0,
            )
        };
        bcrypt_result("set message block length", status)?;

        Ok(XtsAes256(key))
    }
}
313        Ok(XtsAes256(key))
314    }
315}