1#![cfg(target_os = "linux")]
11#![forbid(unsafe_code)]
12
13use disk_backend::DiskError;
14use disk_backend::DiskIo;
15use guest_emulation_transport::GuestEmulationTransportClient;
16use guestmem::MemoryRead;
17use guestmem::MemoryWrite;
18use inspect::Inspect;
19use save_restore::SavedBlockStorageMetadata;
20use scsi_buffers::RequestBuffers;
21use std::io;
22use thiserror::Error;
23
#[derive(Clone, Debug, Inspect)]
/// A VMGS backing disk whose reads, writes, and flushes are forwarded to the
/// host over the Guest Emulation Transport (GET).
pub struct GetVmgsDisk {
    /// GET client used for every read, write, and flush issued by this disk.
    get: GuestEmulationTransportClient,
    /// Logical sector size in bytes; `new_inner` only accepts powers of two.
    sector_size: u32,
    /// `log2(sector_size)`, used to convert byte lengths into sector counts.
    sector_shift: u32,
    /// Physical sector size in bytes; a power of two that is `>= sector_size`.
    physical_sector_size: u32,
    /// Disk size in logical sectors.
    sector_count: u64,
    /// Largest number of logical sectors sent in a single GET read or write:
    /// the maximum transfer size rounded down to whole physical sectors.
    max_transfer_sectors: u32,
    /// Maximum transfer size in bytes as passed to the constructor, kept so
    /// that `save_meta` can record it.
    max_transfer_size_bytes: u32,
}
35
#[derive(Debug, Error)]
/// Errors returned when creating or restoring a [`GetVmgsDisk`].
pub enum NewGetVmgsDiskError {
    /// Querying the VMGS device info over the GET failed.
    #[error("GET VMGS IO error")]
    Io(#[source] guest_emulation_transport::error::VmgsIoError),
    /// The logical sector size is not a power of two.
    #[error("invalid sector size")]
    InvalidSectorSize,
    /// The physical sector size is not a power of two, or is smaller than the
    /// logical sector size.
    #[error("invalid physical sector size")]
    InvalidPhysicalSectorSize,
    /// `sector_count * sector_size` (the disk size in bytes) overflows `u64`.
    #[error("invalid sector count")]
    InvalidSectorCount,
    /// The sector count is not a whole multiple of logical sectors per
    /// physical sector.
    #[error("disk ends with a partial physical sector")]
    IncompletePhysicalSector,
    /// The maximum transfer size is smaller than one physical sector.
    #[error("transfer size is smaller than the physical sector size")]
    InvalidMaxTransferSize,
}
58
59impl GetVmgsDisk {
60 pub async fn new(get: GuestEmulationTransportClient) -> Result<Self, NewGetVmgsDiskError> {
63 let response = get
64 .vmgs_get_device_info()
65 .await
66 .map_err(NewGetVmgsDiskError::Io)?;
67 Self::new_inner(
68 get,
69 response.bytes_per_logical_sector.into(),
70 response.bytes_per_physical_sector.into(),
71 response.capacity,
72 response.maximum_transfer_size_bytes,
73 )
74 }
75
76 pub fn restore_with_meta(
88 get: GuestEmulationTransportClient,
89 meta: SavedBlockStorageMetadata,
90 ) -> Result<Self, NewGetVmgsDiskError> {
91 Self::new_inner(
92 get,
93 meta.sector_size,
94 meta.physical_sector_size,
95 meta.sector_count,
96 meta.max_transfer_size_bytes,
97 )
98 }
99
100 pub fn save_meta(&self) -> SavedBlockStorageMetadata {
103 SavedBlockStorageMetadata {
104 capacity: self.sector_count * self.sector_size as u64,
105 logical_sector_size: self.sector_size,
106 sector_count: self.sector_count,
107 sector_size: self.sector_size,
108 physical_sector_size: self.physical_sector_size,
109 max_transfer_size_bytes: self.max_transfer_size_bytes,
110 }
111 }
112
113 fn new_inner(
114 get: GuestEmulationTransportClient,
115 sector_size: u32,
116 physical_sector_size: u32,
117 sector_count: u64,
118 max_transfer_size: u32,
119 ) -> Result<Self, NewGetVmgsDiskError> {
120 if !sector_size.is_power_of_two() {
121 Err(NewGetVmgsDiskError::InvalidSectorSize)
122 } else if !physical_sector_size.is_power_of_two() || physical_sector_size < sector_size {
123 Err(NewGetVmgsDiskError::InvalidPhysicalSectorSize)
124 } else if sector_count.checked_mul(sector_size as u64).is_none() {
125 Err(NewGetVmgsDiskError::InvalidSectorCount)
126 } else if sector_count % (physical_sector_size / sector_size) as u64 != 0 {
127 Err(NewGetVmgsDiskError::IncompletePhysicalSector)
128 } else if max_transfer_size < physical_sector_size {
129 Err(NewGetVmgsDiskError::InvalidMaxTransferSize)
130 } else {
131 Ok(GetVmgsDisk {
132 get,
133 sector_size,
134 sector_shift: sector_size.trailing_zeros(),
135 physical_sector_size,
136 sector_count,
137 max_transfer_size_bytes: max_transfer_size,
138 max_transfer_sectors: max_transfer_size / physical_sector_size
140 * physical_sector_size
141 / sector_size,
142 })
143 }
144 }
145}
146
147impl DiskIo for GetVmgsDisk {
148 fn disk_type(&self) -> &str {
149 "vmgs-get"
150 }
151
152 fn sector_count(&self) -> u64 {
153 self.sector_count
154 }
155
156 fn sector_size(&self) -> u32 {
157 self.sector_size
158 }
159
160 fn disk_id(&self) -> Option<[u8; 16]> {
161 None
162 }
163
164 fn physical_sector_size(&self) -> u32 {
165 self.physical_sector_size
166 }
167
168 fn is_fua_respected(&self) -> bool {
169 false
170 }
171
172 fn is_read_only(&self) -> bool {
173 false
174 }
175
176 async fn read_vectored(
177 &self,
178 buffers: &RequestBuffers<'_>,
179 mut sector: u64,
180 ) -> Result<(), DiskError> {
181 let mut writer = buffers.writer();
182 let mut remaining_sectors = buffers.len() >> self.sector_shift;
183 if sector + remaining_sectors as u64 > self.sector_count {
184 return Err(DiskError::IllegalBlock);
185 }
186 while remaining_sectors != 0 {
187 let this_sector_count = remaining_sectors.min(self.max_transfer_sectors as usize);
188 let data = self
189 .get
190 .vmgs_read(sector, this_sector_count as u32, self.sector_size)
191 .await
192 .map_err(|err| DiskError::Io(io::Error::other(err)))?;
193
194 writer.write(&data)?;
195 sector += this_sector_count as u64;
196 remaining_sectors -= this_sector_count;
197 }
198 Ok(())
199 }
200
201 async fn write_vectored(
202 &self,
203 buffers: &RequestBuffers<'_>,
204 mut sector: u64,
205 _fua: bool,
206 ) -> Result<(), DiskError> {
207 let mut reader = buffers.reader();
208 let mut remaining_sector_count = buffers.len() >> self.sector_shift;
209 if sector + remaining_sector_count as u64 > self.sector_count {
210 return Err(DiskError::IllegalBlock);
211 }
212 while remaining_sector_count != 0 {
213 let this_sector_count = remaining_sector_count.min(self.max_transfer_sectors as usize);
214 let data = reader.read_n(this_sector_count << self.sector_shift)?;
215 self.get
216 .vmgs_write(sector, data, self.sector_size)
217 .await
218 .map_err(|err| DiskError::Io(io::Error::other(err)))?;
219
220 remaining_sector_count -= this_sector_count;
221 sector += this_sector_count as u64;
222 }
223 Ok(())
224 }
225
226 async fn sync_cache(&self) -> Result<(), DiskError> {
228 self.get
229 .vmgs_flush()
230 .await
231 .map_err(|err| DiskError::Io(io::Error::other(err)))
232 }
233
234 async fn unmap(
235 &self,
236 _sector: u64,
237 _count: u64,
238 _block_level_only: bool,
239 ) -> Result<(), DiskError> {
240 Ok(())
241 }
242
243 fn unmap_behavior(&self) -> disk_backend::UnmapBehavior {
244 disk_backend::UnmapBehavior::Ignored
245 }
246}
247
/// Saved-state types for `GetVmgsDisk`.
pub mod save_restore {
    use mesh::payload::Protobuf;

    /// Disk geometry written by `GetVmgsDisk::save_meta` and read by
    /// `GetVmgsDisk::restore_with_meta`.
    ///
    /// NOTE(review): the `#[mesh(N)]` numbers identify fields in the encoded
    /// saved state; changing them presumably breaks compatibility with
    /// previously saved data.
    #[derive(Protobuf, Clone)]
    #[mesh(package = "vmgs")]
    pub struct SavedBlockStorageMetadata {
        /// Disk size. `save_meta` stores `sector_count * sector_size` (bytes)
        /// here. It is not read by `restore_with_meta`.
        #[mesh(1)]
        pub capacity: u64,
        /// Logical sector size in bytes. `save_meta` stores the same value as
        /// `sector_size`. It is not read by `restore_with_meta`.
        #[mesh(2)]
        pub logical_sector_size: u32,
        /// Disk size in logical sectors.
        #[mesh(3)]
        pub sector_count: u64,
        /// Logical sector size in bytes.
        #[mesh(4)]
        pub sector_size: u32,
        /// Physical sector size in bytes.
        #[mesh(5)]
        pub physical_sector_size: u32,
        /// Maximum transfer size in bytes.
        #[mesh(6)]
        pub max_transfer_size_bytes: u32,
    }
}
276
#[cfg(test)]
mod tests {
    use super::*;
    use disk_backend::Disk;
    use guest_emulation_transport::api::ProtocolVersion;
    use guest_emulation_transport::test_utilities::TestGet;
    use guest_emulation_transport::test_utilities::new_transport_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Task;
    use vmgs::FileId;
    use vmgs::Vmgs;
    use vmgs_broker::VmgsClient;
    use vmgs_broker::spawn_vmgs_broker;

    /// Length of the "large" payloads: 16 KiB plus one byte.
    const LARGE_LEN: usize = 1024 * 4 * 4 + 1;

    /// Builds a deterministic `len`-byte pattern: 0, 1, ..., 255, 0, 1, ...
    fn byte_pattern(len: usize) -> Vec<u8> {
        (0..len).map(|i| i as u8).collect()
    }

    /// Creates a GET transport pair, formats a fresh VMGS on a `GetVmgsDisk`
    /// backed by it, and hands the VMGS to a broker.
    ///
    /// Tests bind the returned `TestGet` and `Task` to `_`-prefixed names
    /// (not `_`) so that they are not dropped before the test ends.
    async fn spawn_vmgs(driver: &DefaultDriver) -> (VmgsClient, TestGet, Task<()>) {
        let transport = new_transport_pair(driver, None, ProtocolVersion::NICKEL_REV2).await;
        let disk = GetVmgsDisk::new(transport.client.clone()).await.unwrap();
        let formatted = Vmgs::format_new(Disk::new(disk).unwrap(), None)
            .await
            .unwrap();
        let (client, broker_task) = spawn_vmgs_broker(driver, formatted);
        (client, transport, broker_task)
    }

    /// Writes `data` to `file_id`. It then checks that the reported valid
    /// byte count and the read-back contents both match, and returns the
    /// read-back contents.
    async fn write_and_verify(vmgs: &VmgsClient, file_id: FileId, data: &[u8]) -> Vec<u8> {
        let expected = data.to_vec();
        vmgs.write_file(file_id, expected.clone()).await.unwrap();

        let info = vmgs.get_file_info(file_id).await.unwrap();
        assert_eq!(info.valid_bytes as usize, expected.len());

        let actual = vmgs.read_file(file_id).await.unwrap();
        assert_eq!(expected, actual);
        actual
    }

    #[async_test]
    async fn basic_read_write(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        write_and_verify(&vmgs, FileId::BIOS_NVRAM, b"hello world").await;
    }

    #[async_test]
    async fn multiple_read_write(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        let first: &[u8] = b"Data data data";
        let second: &[u8] = b"password";
        let third: &[u8] = b"other data data";

        // Alternate between two files, overwriting BIOS_NVRAM with payloads of
        // different lengths; the whole sequence is run twice.
        let steps = [
            (FileId::BIOS_NVRAM, first),
            (FileId::TPM_PPI, second),
            (FileId::BIOS_NVRAM, third),
        ];
        for _round in 0..2 {
            for &(file_id, data) in &steps {
                write_and_verify(&vmgs, file_id, data).await;
            }
        }
    }

    #[async_test]
    async fn test_empty_write(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        let read_back = write_and_verify(&vmgs, FileId::BIOS_NVRAM, &[]).await;
        assert_eq!(read_back.len(), 0);
    }

    #[async_test]
    async fn test_read_write_large(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        write_and_verify(&vmgs, FileId::BIOS_NVRAM, &byte_pattern(LARGE_LEN)).await;
    }

    #[async_test]
    async fn test_read_write_encryption(driver: DefaultDriver) {
        let get = new_transport_pair(&driver, None, ProtocolVersion::NICKEL_REV2).await;
        let file_id = FileId::BIOS_NVRAM;
        let encryption_key = vec![1; 32];
        let payload = byte_pattern(LARGE_LEN);

        // Phase 1: format a new VMGS, add a key, and write an encrypted file.
        // The VMGS handle is dropped at the end of this scope.
        {
            let disk = GetVmgsDisk::new(get.client.clone()).await.unwrap();
            let mut vmgs = Vmgs::format_new(Disk::new(disk).unwrap(), None)
                .await
                .unwrap();
            vmgs.add_new_encryption_key(&encryption_key, vmgs::EncryptionAlgorithm::AES_GCM)
                .await
                .unwrap();

            vmgs.write_file_encrypted(file_id, &payload).await.unwrap();

            let info = vmgs.get_file_info(file_id).unwrap();
            assert_eq!(info.valid_bytes as usize, payload.len());
            let read_back = vmgs.read_file(file_id).await.unwrap();
            assert_eq!(payload, read_back);
        }

        // Phase 2: reopen the same backing store. While still locked, the
        // file must not read back as the plaintext.
        let disk = GetVmgsDisk::new(get.client.clone()).await.unwrap();
        let mut reopened = Vmgs::open(Disk::new(disk).unwrap(), None)
            .await
            .unwrap();
        let locked_read = reopened.read_file(file_id).await.unwrap();
        assert_ne!(payload, locked_read);

        reopened
            .unlock_with_encryption_key(&encryption_key)
            .await
            .unwrap();

        // Phase 3: once unlocked, the broker must return the plaintext.
        let (client, _task) = spawn_vmgs_broker(&driver, reopened);
        let info = client.get_file_info(file_id).await.unwrap();
        assert_eq!(info.valid_bytes as usize, payload.len());
        let unlocked_read = client.read_file(file_id).await.unwrap();
        assert_eq!(payload, unlocked_read);
    }
}