1#[cfg(windows)]
7use bcrypt as sys;
8#[cfg(unix)]
9use ossl as sys;
10use thiserror::Error;
11
/// AES-256-XTS cipher holding the key material for the platform backend
/// selected at compile time (CNG/bcrypt on Windows, OpenSSL on Unix).
pub struct XtsAes256(sys::XtsAes256);
14
/// Error returned by [`XtsAes256`] and [`XtsAes256Ctx`] operations.
///
/// Wraps the active backend's error type; `Display` and `source` are forwarded
/// unchanged via `#[error(transparent)]`.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct Error(sys::Error);
19
20impl XtsAes256 {
21 pub const KEY_LEN: usize = 64;
25
26 pub fn new(key: &[u8; Self::KEY_LEN], data_unit_size: u32) -> Result<Self, Error> {
28 sys::xts_aes_256(key, data_unit_size)
29 .map(Self)
30 .map_err(Error)
31 }
32
33 pub fn encrypt(&self) -> Result<XtsAes256Ctx<'_>, Error> {
35 Ok(XtsAes256Ctx(self.0.ctx(true).map_err(Error)?))
36 }
37
38 pub fn decrypt(&self) -> Result<XtsAes256Ctx<'_>, Error> {
40 Ok(XtsAes256Ctx(self.0.ctx(false).map_err(Error)?))
41 }
42}
43
/// Direction-bound cipher context borrowed from an [`XtsAes256`]; obtained from
/// [`XtsAes256::encrypt`] or [`XtsAes256::decrypt`].
pub struct XtsAes256Ctx<'a>(sys::XtsAes256Ctx<'a>);
46
47impl XtsAes256Ctx<'_> {
48 pub fn cipher(&mut self, tweak: u128, data: &mut [u8]) -> Result<(), Error> {
50 self.0.cipher(&tweak.to_le_bytes(), data).map_err(Error)?;
51 Ok(())
52 }
53}
54
55#[cfg(unix)]
56mod ossl {
57 pub struct NonStreamingCipher {
58 enc: openssl::cipher_ctx::CipherCtx,
59 dec: openssl::cipher_ctx::CipherCtx,
60 }
61
62 pub struct NonStreamingCipherCtx<'a> {
63 ctx: openssl::cipher_ctx::CipherCtx,
64 enc: bool,
65 _dummy: &'a (),
66 }
67
68 pub type Error = openssl::error::ErrorStack;
69
70 pub type XtsAes256 = NonStreamingCipher;
71 pub type XtsAes256Ctx<'a> = NonStreamingCipherCtx<'a>;
72
73 pub fn xts_aes_256(key: &[u8], _data_unit_size: u32) -> Result<XtsAes256, Error> {
74 let mut enc = openssl::cipher_ctx::CipherCtx::new()?;
75 enc.encrypt_init(
76 Some(openssl::cipher::Cipher::aes_256_xts()),
77 Some(key),
78 None,
79 )?;
80 let mut dec = openssl::cipher_ctx::CipherCtx::new()?;
81 dec.decrypt_init(
82 Some(openssl::cipher::Cipher::aes_256_xts()),
83 Some(key),
84 None,
85 )?;
86 Ok(NonStreamingCipher { enc, dec })
87 }
88
89 impl NonStreamingCipher {
90 pub fn ctx(&self, enc: bool) -> Result<NonStreamingCipherCtx<'_>, Error> {
91 let mut ctx = openssl::cipher_ctx::CipherCtx::new()?;
92 ctx.copy(if enc { &self.enc } else { &self.dec })?;
93 Ok(NonStreamingCipherCtx {
94 ctx,
95 enc,
96 _dummy: &(),
97 })
98 }
99 }
100
101 impl NonStreamingCipherCtx<'_> {
102 pub fn cipher(&mut self, iv: &[u8], data: &mut [u8]) -> Result<(), Error> {
103 if self.enc {
104 self.ctx.encrypt_init(None, None, Some(iv))?;
105 } else {
106 self.ctx.decrypt_init(None, None, Some(iv))?;
107 }
108 self.ctx.cipher_update_inplace(data, data.len())?;
109 Ok(())
110 }
111 }
112}
113
#[cfg(windows)]
mod bcrypt {
    // CNG backend for AES-256-XTS. Every call into bcrypt.dll is `unsafe` FFI, so
    // the module opts in to `unsafe_code`.
    #![expect(unsafe_code)]

    use std::sync::OnceLock;
    use thiserror::Error;
    use windows::Win32::Foundation::NTSTATUS;
    use windows::Win32::Foundation::RtlNtStatusToDosError;
    use windows::Win32::Security::Cryptography::BCRYPT_ALG_HANDLE;
    use windows::Win32::Security::Cryptography::BCRYPT_HANDLE;
    use windows::Win32::Security::Cryptography::BCRYPT_KEY_HANDLE;
    use windows::Win32::Security::Cryptography::BCRYPT_OPEN_ALGORITHM_PROVIDER_FLAGS;

    /// Failure of a CNG operation: names the operation and carries the Win32 error
    /// translated from the NTSTATUS (exposed as `source()`).
    #[derive(Debug, Error)]
    #[error("{op} failed")]
    pub struct Error {
        /// Short label for the failing operation, e.g. "encrypt".
        op: &'static str,
        #[source]
        err: std::io::Error,
    }

    /// AES-256-XTS cipher that owns a CNG key handle.
    pub struct XtsAes256(Key);

    /// Direction-bound view of an [`XtsAes256`] key.
    pub struct XtsAes256Ctx<'a> {
        key: &'a Key,
        /// `true` to encrypt, `false` to decrypt.
        enc: bool,
    }

    impl XtsAes256 {
        /// Returns a context for the given direction. This never fails on this
        /// backend; the `Result` matches the signature the top-level wrapper expects
        /// from both backends.
        pub fn ctx(&self, enc: bool) -> Result<XtsAes256Ctx<'_>, Error> {
            Ok(XtsAes256Ctx { key: &self.0, enc })
        }
    }

    impl XtsAes256Ctx<'_> {
        /// Encrypts or decrypts `data` in place under the little-endian 128-bit
        /// `tweak`.
        ///
        /// # Errors
        /// Returns an `InvalidInput` error tagged "convert tweak" if the tweak does
        /// not fit in 64 bits. Otherwise returns any CNG encrypt/decrypt failure.
        pub fn cipher(&self, tweak: &[u8; 16], data: &mut [u8]) -> Result<(), Error> {
            // This backend passes the tweak to CNG as an 8-byte IV, so tweaks above
            // u64::MAX are rejected rather than silently truncated.
            let mut iv = u64::try_from(u128::from_le_bytes(*tweak))
                .map_err(|_| Error {
                    op: "convert tweak",
                    err: std::io::ErrorKind::InvalidInput.into(),
                })?
                .to_le_bytes();

            if self.enc {
                self.key.encrypt(&mut iv, data)
            } else {
                self.key.decrypt(&mut iv, data)
            }
        }
    }

    /// Process-wide XTS-AES algorithm provider, opened lazily by `xts_aes_256` and
    /// never closed.
    static XTS_AES_256: OnceLock<AlgHandle> = OnceLock::new();

    /// Newtype so the raw algorithm handle can be stored in a `static`.
    struct AlgHandle(BCRYPT_ALG_HANDLE);

    // SAFETY: the handle is an opaque CNG identifier, not Rust-owned memory.
    // NOTE(review): this relies on CNG algorithm handles being usable from any
    // thread; confirm against the CNG threading documentation.
    unsafe impl Send for AlgHandle {}
    // SAFETY: the handle is only read (passed to BCryptGenerateSymmetricKey) after
    // initialization. NOTE(review): assumes concurrent use of one algorithm handle
    // is permitted by CNG; confirm.
    unsafe impl Sync for AlgHandle {}

    /// Converts an NTSTATUS into `Ok(())` on success, or into an [`Error`] whose
    /// source is the Win32 error code obtained via `RtlNtStatusToDosError`.
    fn bcrypt_result(op: &'static str, status: NTSTATUS) -> Result<(), Error> {
        if status.is_ok() {
            Ok(())
        } else {
            // SAFETY: RtlNtStatusToDosError takes the status by value and touches no
            // caller memory.
            let err = unsafe { RtlNtStatusToDosError(status) };
            Err(Error {
                op,
                err: std::io::Error::from_raw_os_error(err as i32),
            })
        }
    }

    /// Owned CNG symmetric key handle, destroyed on drop.
    struct Key(BCRYPT_KEY_HANDLE);

    // SAFETY: the key handle is an opaque CNG identifier owned exclusively by this
    // value. NOTE(review): assumes CNG key handles may be used/destroyed from any
    // thread; confirm.
    unsafe impl Send for Key {}
    // SAFETY: `encrypt`/`decrypt` take `&self` and pass all per-call state (the IV)
    // explicitly. NOTE(review): relies on CNG allowing concurrent BCryptEncrypt/
    // BCryptDecrypt on one key handle; confirm.
    unsafe impl Sync for Key {}

    impl Drop for Key {
        /// Destroys the key handle. Panics if CNG reports a failure.
        fn drop(&mut self) {
            // SAFETY: `self.0` was produced by a successful
            // BCryptGenerateSymmetricKey and is destroyed exactly once, here.
            unsafe {
                bcrypt_result(
                    "destroy key",
                    windows::Win32::Security::Cryptography::BCryptDestroyKey(self.0),
                )
                .unwrap();
            }
        }
    }

    impl Key {
        /// Encrypts `data` in place with the given IV (tweak).
        ///
        /// # Panics
        /// Panics if CNG reports writing a different number of bytes than
        /// `data.len()`.
        fn encrypt(&self, iv: &mut [u8], data: &mut [u8]) -> Result<(), Error> {
            // The safe wrapper takes input as `&[u8]` and output as `&mut [u8]`,
            // which cannot alias in Rust, so the input is copied and `data` is used
            // as the output buffer.
            let input = data.to_vec();
            // Receives the number of bytes CNG wrote to `data`.
            let mut n = 0;
            // SAFETY: `self.0` is a live key handle; `input`, `iv`, `data` and `n`
            // are valid, non-overlapping buffers for the duration of the call.
            let status = unsafe {
                windows::Win32::Security::Cryptography::BCryptEncrypt(
                    self.0,
                    Some(&input),
                    None,
                    Some(iv),
                    Some(data),
                    &mut n,
                    windows::Win32::Security::Cryptography::BCRYPT_FLAGS(0),
                )
            };
            bcrypt_result("encrypt", status)?;
            assert_eq!(n as usize, data.len());
            Ok(())
        }

        /// Decrypts `data` in place with the given IV (tweak).
        ///
        /// # Panics
        /// Panics if CNG reports writing a different number of bytes than
        /// `data.len()`.
        fn decrypt(&self, iv: &mut [u8], data: &mut [u8]) -> Result<(), Error> {
            // Copy the input for the same aliasing reason as `encrypt`.
            let input = data.to_vec();
            // Receives the number of bytes CNG wrote to `data`.
            let mut n = 0;
            // SAFETY: `self.0` is a live key handle; `input`, `iv`, `data` and `n`
            // are valid, non-overlapping buffers for the duration of the call.
            let status = unsafe {
                windows::Win32::Security::Cryptography::BCryptDecrypt(
                    self.0,
                    Some(&input),
                    None,
                    Some(iv),
                    Some(data),
                    &mut n,
                    windows::Win32::Security::Cryptography::BCRYPT_FLAGS(0),
                )
            };
            bcrypt_result("decrypt", status)?;
            assert_eq!(n as usize, data.len());
            Ok(())
        }
    }

    /// Creates an AES-256-XTS key from `key` and sets its message block length to
    /// `data_unit_size`.
    ///
    /// The first call opens the XTS-AES algorithm provider and caches it in
    /// `XTS_AES_256`. If another thread caches its handle first, this thread's
    /// freshly opened handle is closed and the cached one is used.
    pub fn xts_aes_256(key: &[u8], data_unit_size: u32) -> Result<XtsAes256, Error> {
        let alg = if let Some(alg) = XTS_AES_256.get() {
            alg
        } else {
            let mut handle = BCRYPT_ALG_HANDLE::default();
            // SAFETY: `handle` is a valid out-pointer; the algorithm and provider
            // names are constants supplied by the `windows` crate.
            let status = unsafe {
                windows::Win32::Security::Cryptography::BCryptOpenAlgorithmProvider(
                    &mut handle,
                    windows::Win32::Security::Cryptography::BCRYPT_XTS_AES_ALGORITHM,
                    windows::Win32::Security::Cryptography::MS_PRIMITIVE_PROVIDER,
                    BCRYPT_OPEN_ALGORITHM_PROVIDER_FLAGS(0),
                )
            };
            bcrypt_result("open algorithm provider", status)?;
            // Lost the initialization race: close our duplicate handle.
            if let Err(AlgHandle(handle)) = XTS_AES_256.set(AlgHandle(handle)) {
                // SAFETY: `handle` was opened above, was never stored, and is
                // closed exactly once here.
                unsafe {
                    bcrypt_result(
                        "close algorithm provider",
                        windows::Win32::Security::Cryptography::BCryptCloseAlgorithmProvider(
                            handle, 0,
                        ),
                    )
                    .unwrap();
                }
            }
            // Either our `set` succeeded or another thread's did; it is populated.
            XTS_AES_256.get().unwrap()
        };
        let key = {
            let mut handle = BCRYPT_KEY_HANDLE::default();
            // SAFETY: `alg.0` is a live algorithm handle held in a static; `handle`
            // is a valid out-pointer; `key` is a valid byte slice.
            let status = unsafe {
                windows::Win32::Security::Cryptography::BCryptGenerateSymmetricKey(
                    alg.0,
                    &mut handle,
                    None,
                    key,
                    0,
                )
            };
            bcrypt_result("generate symmetric key", status)?;
            // Wrap immediately so the handle is destroyed on any later error.
            Key(handle)
        };

        // Configure the data unit size. NOTE(review): presumably CNG uses this to
        // split multi-unit buffers and advance the tweak per unit; confirm.
        // SAFETY: `key.0` is a live key handle; the property value is a 4-byte
        // buffer that outlives the call.
        let status = unsafe {
            windows::Win32::Security::Cryptography::BCryptSetProperty(
                BCRYPT_HANDLE(key.0.0),
                windows::Win32::Security::Cryptography::BCRYPT_MESSAGE_BLOCK_LENGTH,
                &data_unit_size.to_ne_bytes(),
                0,
            )
        };
        bcrypt_result("set message block length", status)?;

        Ok(XtsAes256(key))
    }
}