1#![forbid(unsafe_code)]
8
9pub mod resolver;
10
11use block_crypto::XtsAes256;
12use disk_backend::Disk;
13use disk_backend::DiskError;
14use disk_backend::DiskIo;
15use disk_backend::UnmapBehavior;
16use guestmem::GuestMemory;
17use guestmem::MemoryRead;
18use guestmem::MemoryWrite;
19use inspect::Inspect;
20use scsi_buffers::OwnedRequestBuffers;
21use scsi_buffers::RequestBuffers;
22use thiserror::Error;
23
24#[derive(Inspect)]
26pub struct CryptDisk {
27 inner: Disk,
28 #[inspect(skip)]
29 cipher: XtsAes256,
30}
31
32#[derive(Debug, Error)]
34pub enum NewDiskError {
35 #[error("crypto error")]
37 Crypto(#[source] block_crypto::Error),
38 #[error("invalid key size for cipher")]
40 InvalidKeySize,
41}
42
43impl CryptDisk {
44 pub fn new(
47 cipher: disk_crypt_resources::Cipher,
48 key: &[u8],
49 inner: Disk,
50 ) -> Result<Self, NewDiskError> {
51 match cipher {
52 disk_crypt_resources::Cipher::XtsAes256 => {}
53 }
54 let cipher = XtsAes256::new(
55 key.try_into().map_err(|_| NewDiskError::InvalidKeySize)?,
56 inner.sector_size(),
57 )
58 .map_err(NewDiskError::Crypto)?;
59 Ok(Self { inner, cipher })
60 }
61}
62
63impl DiskIo for CryptDisk {
64 fn disk_type(&self) -> &str {
65 "crypt"
66 }
67
68 fn sector_count(&self) -> u64 {
69 self.inner.sector_count()
70 }
71
72 fn sector_size(&self) -> u32 {
73 self.inner.sector_size()
74 }
75
76 fn disk_id(&self) -> Option<[u8; 16]> {
77 self.inner.disk_id()
78 }
79
80 fn physical_sector_size(&self) -> u32 {
81 self.inner.physical_sector_size()
82 }
83
84 fn is_fua_respected(&self) -> bool {
85 self.inner.is_fua_respected()
86 }
87
88 fn is_read_only(&self) -> bool {
89 self.inner.is_read_only()
90 }
91
92 fn pr(&self) -> Option<&dyn disk_backend::pr::PersistentReservation> {
95 self.inner.pr()
96 }
97
98 async fn read_vectored(
99 &self,
100 buffers: &RequestBuffers<'_>,
101 sector: u64,
102 ) -> Result<(), DiskError> {
103 self.inner.read_vectored(buffers, sector).await?;
106
107 let mut ctx = self.cipher.decrypt().map_err(crypto_error)?;
109 let mut buf = vec![0; self.sector_size() as usize];
110 let mut reader = buffers.reader();
111 let mut writer = buffers.writer();
112 for i in 0..buffers.len() >> self.inner.sector_shift() {
113 reader.read(&mut buf)?;
114 ctx.cipher((sector + i as u64).into(), &mut buf)
115 .map_err(crypto_error)?;
116 writer.write(&buf)?;
117 }
118 Ok(())
119 }
120
121 async fn write_vectored(
122 &self,
123 buffers: &RequestBuffers<'_>,
124 sector: u64,
125 fua: bool,
126 ) -> Result<(), DiskError> {
127 let mut mem = GuestMemory::allocate(buffers.len());
134 let buf = mem.inner_buf_mut().unwrap();
135 let staged = OwnedRequestBuffers::linear(0, buffers.len(), true);
136
137 let mut ctx = self.cipher.encrypt().map_err(crypto_error)?;
139 let mut reader = buffers.reader();
140 let sector_size = self.inner.sector_size() as usize;
141 let mut offset = 0;
142 let mut tweak = sector;
143 while offset < buffers.len() {
144 let this_buf = &mut buf[offset..][..sector_size];
145 reader.read(this_buf)?;
146 ctx.cipher(tweak.into(), this_buf).map_err(crypto_error)?;
147 offset += sector_size;
148 tweak += 1;
149 }
150
151 self.inner
153 .write_vectored(&staged.buffer(&mem), sector, fua)
154 .await?;
155 Ok(())
156 }
157
158 async fn sync_cache(&self) -> Result<(), DiskError> {
159 self.inner.sync_cache().await
160 }
161
162 async fn wait_resize(&self, sector_count: u64) -> u64 {
164 self.inner.wait_resize(sector_count).await
165 }
166
167 fn unmap(
168 &self,
169 sector: u64,
170 count: u64,
171 block_level_only: bool,
172 ) -> impl std::future::Future<Output = Result<(), DiskError>> + Send {
173 self.inner.unmap(sector, count, block_level_only)
174 }
175
176 fn unmap_behavior(&self) -> UnmapBehavior {
177 match self.inner.unmap_behavior() {
178 UnmapBehavior::Unspecified | UnmapBehavior::Zeroes => UnmapBehavior::Unspecified,
181 UnmapBehavior::Ignored => UnmapBehavior::Ignored,
182 }
183 }
184
185 fn optimal_unmap_sectors(&self) -> u32 {
186 self.inner.optimal_unmap_sectors()
187 }
188}
189
190fn crypto_error(err: block_crypto::Error) -> DiskError {
191 DiskError::Io(std::io::Error::other(err))
192}
193
#[cfg(test)]
mod tests {
    use crate::CryptDisk;
    use disk_backend::Disk;
    use guestmem::GuestMemory;
    use pal_async::async_test;
    use scsi_buffers::OwnedRequestBuffers;

    /// Round-trips a pseudo-random pattern through an encrypted RAM disk and
    /// checks that reading it back yields the original plaintext.
    #[async_test]
    async fn test_basic_read_write() {
        const LEN: usize = 0x10000;
        const START_SECTOR: u64 = 10;

        let key: [[u8; 32]; 2] = [[0; 32], [1; 32]];
        let ram = disklayer_ram::ram_disk(0x200000, false).unwrap();
        let crypt = CryptDisk::new(
            disk_crypt_resources::Cipher::XtsAes256,
            key.as_flattened(),
            ram,
        )
        .unwrap();
        let disk = Disk::new(crypt).unwrap();

        // Deterministic bytes: the low byte of 3 * 7^k (mod 2^32) for k = 1, 2, ...
        let pattern: Vec<u8> = (0..LEN)
            .scan(3u32, |state, _| {
                *state = state.wrapping_mul(7);
                Some(*state as u8)
            })
            .collect();

        let buffers = OwnedRequestBuffers::linear(0, LEN, true);
        let mut mem = GuestMemory::allocate(LEN);
        mem.inner_buf_mut().unwrap().copy_from_slice(&pattern);
        disk.write_vectored(&buffers.buffer(&mem), START_SECTOR, false)
            .await
            .unwrap();

        // Wipe the buffer so the read has to repopulate it.
        mem.inner_buf_mut().unwrap().fill(0);
        disk.read_vectored(&buffers.buffer(&mem), START_SECTOR)
            .await
            .unwrap();
        assert_eq!(mem.inner_buf_mut().unwrap(), &pattern);
    }
}