disk_get_vmgs/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! A disk backend using the GET's VMGS block interface.
5//!
//! It is implemented as a general-purpose block device (in principle a VM
//! could even boot from it), but in practice it is likely only useful as the
//! VMGS backing device.
9
10#![cfg(target_os = "linux")]
11#![forbid(unsafe_code)]
12
13use disk_backend::DiskError;
14use disk_backend::DiskIo;
15use guest_emulation_transport::GuestEmulationTransportClient;
16use guestmem::MemoryRead;
17use guestmem::MemoryWrite;
18use inspect::Inspect;
19use save_restore::SavedBlockStorageMetadata;
20use scsi_buffers::RequestBuffers;
21use std::io;
22use thiserror::Error;
23
/// An implementation of [`DiskIo`] backed by the GET.
///
/// Every read, write, and flush is forwarded to the host as a GET VMGS
/// request. Build one with [`GetVmgsDisk::new`], or rebuild one from
/// previously saved metadata with [`GetVmgsDisk::restore_with_meta`]; both
/// paths validate the geometry before constructing the disk.
#[derive(Clone, Debug, Inspect)]
pub struct GetVmgsDisk {
    /// Transport used for every VMGS request (read, write, flush).
    get: GuestEmulationTransportClient,
    /// Logical sector size in bytes; validated to be a power of two.
    sector_size: u32,
    /// `log2(sector_size)`, used to convert between byte lengths and sectors.
    sector_shift: u32,
    /// Physical sector size in bytes; validated to be a power of two that is
    /// no smaller than `sector_size`.
    physical_sector_size: u32,
    /// Disk size in logical sectors; validated to be a whole number of
    /// physical sectors.
    sector_count: u64,
    /// Largest number of logical sectors sent in a single GET request.
    /// Derived from `max_transfer_size_bytes`, rounded down to whole physical
    /// sectors.
    max_transfer_sectors: u32,
    /// Maximum transfer size in bytes as supplied at construction; kept so it
    /// can be written out by `save_meta`.
    max_transfer_size_bytes: u32,
}
35
/// An error that can occur when creating a new [`GetVmgsDisk`], either via
/// [`GetVmgsDisk::new`] or [`GetVmgsDisk::restore_with_meta`].
#[derive(Debug, Error)]
pub enum NewGetVmgsDiskError {
    /// An IO error occurred while fetching the device info from the GET
    /// (only produced by [`GetVmgsDisk::new`]).
    #[error("GET VMGS IO error")]
    Io(#[source] guest_emulation_transport::error::VmgsIoError),
    /// The logical sector size is not a power of two (a size of zero is also
    /// rejected).
    #[error("invalid sector size")]
    InvalidSectorSize,
    /// The physical sector size is not a power of two or is smaller than the
    /// logical sector size.
    #[error("invalid physical sector size")]
    InvalidPhysicalSectorSize,
    /// The sector count is too large: `sector_count * sector_size` overflows
    /// a `u64`.
    #[error("invalid sector count")]
    InvalidSectorCount,
    /// The sector count is not a multiple of the number of logical sectors per
    /// physical sector, so the disk would end with a partial physical sector.
    #[error("disk ends with a partial physical sector")]
    IncompletePhysicalSector,
    /// The maximum transfer size is smaller than one physical sector.
    #[error("transfer size is smaller than the physical sector size")]
    InvalidMaxTransferSize,
}
58
59impl GetVmgsDisk {
60    /// Returns a new disk instance, communicating read and write IOs over the
61    /// `get` transport.
62    pub async fn new(get: GuestEmulationTransportClient) -> Result<Self, NewGetVmgsDiskError> {
63        let response = get
64            .vmgs_get_device_info()
65            .await
66            .map_err(NewGetVmgsDiskError::Io)?;
67        Self::new_inner(
68            get,
69            response.bytes_per_logical_sector.into(),
70            response.bytes_per_physical_sector.into(),
71            response.capacity,
72            response.maximum_transfer_size_bytes,
73        )
74    }
75
76    /// Create a disk using metadata previously-fetched via [`Self::save_meta`].
77    ///
78    /// # Caution
79    ///
80    /// This method does not confirm that the provided `meta` is what would be
81    /// provided by `get`. Callers MUST ensure that the provided `meta` matches
82    /// the provided `get` instance.
83    ///
84    /// Failing to do so may result in data corruption/loss, though, notably: it
85    /// will _not_ result in any memory-unsafety (hence why the function isn't
86    /// marked `unsafe`).
87    pub fn restore_with_meta(
88        get: GuestEmulationTransportClient,
89        meta: SavedBlockStorageMetadata,
90    ) -> Result<Self, NewGetVmgsDiskError> {
91        Self::new_inner(
92            get,
93            meta.sector_size,
94            meta.physical_sector_size,
95            meta.sector_count,
96            meta.max_transfer_size_bytes,
97        )
98    }
99
100    /// Save the metadata for this disk, for use in passing to
101    /// [`Self::restore_with_meta`]
102    pub fn save_meta(&self) -> SavedBlockStorageMetadata {
103        SavedBlockStorageMetadata {
104            capacity: self.sector_count * self.sector_size as u64,
105            logical_sector_size: self.sector_size,
106            sector_count: self.sector_count,
107            sector_size: self.sector_size,
108            physical_sector_size: self.physical_sector_size,
109            max_transfer_size_bytes: self.max_transfer_size_bytes,
110        }
111    }
112
113    fn new_inner(
114        get: GuestEmulationTransportClient,
115        sector_size: u32,
116        physical_sector_size: u32,
117        sector_count: u64,
118        max_transfer_size: u32,
119    ) -> Result<Self, NewGetVmgsDiskError> {
120        if !sector_size.is_power_of_two() {
121            Err(NewGetVmgsDiskError::InvalidSectorSize)
122        } else if !physical_sector_size.is_power_of_two() || physical_sector_size < sector_size {
123            Err(NewGetVmgsDiskError::InvalidPhysicalSectorSize)
124        } else if sector_count.checked_mul(sector_size as u64).is_none() {
125            Err(NewGetVmgsDiskError::InvalidSectorCount)
126        } else if sector_count % (physical_sector_size / sector_size) as u64 != 0 {
127            Err(NewGetVmgsDiskError::IncompletePhysicalSector)
128        } else if max_transfer_size < physical_sector_size {
129            Err(NewGetVmgsDiskError::InvalidMaxTransferSize)
130        } else {
131            Ok(GetVmgsDisk {
132                get,
133                sector_size,
134                sector_shift: sector_size.trailing_zeros(),
135                physical_sector_size,
136                sector_count,
137                max_transfer_size_bytes: max_transfer_size,
138                // Always transfer in multiples of the physical sector size, if possible.
139                max_transfer_sectors: max_transfer_size / physical_sector_size
140                    * physical_sector_size
141                    / sector_size,
142            })
143        }
144    }
145}
146
147impl DiskIo for GetVmgsDisk {
148    fn disk_type(&self) -> &str {
149        "vmgs-get"
150    }
151
152    fn sector_count(&self) -> u64 {
153        self.sector_count
154    }
155
156    fn sector_size(&self) -> u32 {
157        self.sector_size
158    }
159
160    fn disk_id(&self) -> Option<[u8; 16]> {
161        None
162    }
163
164    fn physical_sector_size(&self) -> u32 {
165        self.physical_sector_size
166    }
167
168    fn is_fua_respected(&self) -> bool {
169        false
170    }
171
172    fn is_read_only(&self) -> bool {
173        false
174    }
175
176    async fn read_vectored(
177        &self,
178        buffers: &RequestBuffers<'_>,
179        mut sector: u64,
180    ) -> Result<(), DiskError> {
181        let mut writer = buffers.writer();
182        let mut remaining_sectors = buffers.len() >> self.sector_shift;
183        if sector + remaining_sectors as u64 > self.sector_count {
184            return Err(DiskError::IllegalBlock);
185        }
186        while remaining_sectors != 0 {
187            let this_sector_count = remaining_sectors.min(self.max_transfer_sectors as usize);
188            let data = self
189                .get
190                .vmgs_read(sector, this_sector_count as u32, self.sector_size)
191                .await
192                .map_err(|err| DiskError::Io(io::Error::other(err)))?;
193
194            writer.write(&data)?;
195            sector += this_sector_count as u64;
196            remaining_sectors -= this_sector_count;
197        }
198        Ok(())
199    }
200
201    async fn write_vectored(
202        &self,
203        buffers: &RequestBuffers<'_>,
204        mut sector: u64,
205        _fua: bool,
206    ) -> Result<(), DiskError> {
207        let mut reader = buffers.reader();
208        let mut remaining_sector_count = buffers.len() >> self.sector_shift;
209        if sector + remaining_sector_count as u64 > self.sector_count {
210            return Err(DiskError::IllegalBlock);
211        }
212        while remaining_sector_count != 0 {
213            let this_sector_count = remaining_sector_count.min(self.max_transfer_sectors as usize);
214            let data = reader.read_n(this_sector_count << self.sector_shift)?;
215            self.get
216                .vmgs_write(sector, data, self.sector_size)
217                .await
218                .map_err(|err| DiskError::Io(io::Error::other(err)))?;
219
220            remaining_sector_count -= this_sector_count;
221            sector += this_sector_count as u64;
222        }
223        Ok(())
224    }
225
226    /// Issues an asynchronous flush operation to the disk.
227    async fn sync_cache(&self) -> Result<(), DiskError> {
228        self.get
229            .vmgs_flush()
230            .await
231            .map_err(|err| DiskError::Io(io::Error::other(err)))
232    }
233
234    async fn unmap(
235        &self,
236        _sector: u64,
237        _count: u64,
238        _block_level_only: bool,
239    ) -> Result<(), DiskError> {
240        Ok(())
241    }
242
243    fn unmap_behavior(&self) -> disk_backend::UnmapBehavior {
244        disk_backend::UnmapBehavior::Ignored
245    }
246}
247
248/// Save/restore structure definitions.
249pub mod save_restore {
250    use mesh::payload::Protobuf;
251
252    /// Metadata for a saved block storage device.
253    #[derive(Protobuf, Clone)]
254    #[mesh(package = "vmgs")]
255    pub struct SavedBlockStorageMetadata {
256        /// The byte capacity. Redundant with sector_count * sector_size.
257        #[mesh(1)]
258        pub capacity: u64,
259        /// The logical sector size. Identical to sector_size.
260        #[mesh(2)]
261        pub logical_sector_size: u32,
262        /// The number of sectors.
263        #[mesh(3)]
264        pub sector_count: u64,
265        /// The sector size in bytes.
266        #[mesh(4)]
267        pub sector_size: u32,
268        /// The physical sector size in bytes.
269        #[mesh(5)]
270        pub physical_sector_size: u32,
271        /// The maximum transfer size in bytes.
272        #[mesh(6)]
273        pub max_transfer_size_bytes: u32,
274    }
275}
276
// TODO: remove the VMGS specific tests and just test the `DiskIo` interfaces.
#[cfg(test)]
mod tests {
    use super::*;
    use disk_backend::Disk;
    use guest_emulation_transport::api::ProtocolVersion;
    use guest_emulation_transport::test_utilities::TestGet;
    use guest_emulation_transport::test_utilities::new_transport_pair;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Task;
    use vmgs::FileId;
    use vmgs::Vmgs;
    use vmgs_broker::VmgsClient;
    use vmgs_broker::spawn_vmgs_broker;

    /// Size of the "large" payload: 16 KiB plus one trailing byte.
    const LARGE_LEN: usize = 1024 * 4 * 4 + 1;

    /// Builds a `len`-byte buffer whose bytes count up from zero, wrapping at 256.
    fn counting_bytes(len: usize) -> Vec<u8> {
        (0..len).map(|i| i as u8).collect()
    }

    /// Formats a fresh VMGS on a GET-backed disk and hands it to a broker task.
    async fn spawn_vmgs(driver: &DefaultDriver) -> (VmgsClient, TestGet, Task<()>) {
        let test_get = new_transport_pair(driver, None, ProtocolVersion::NICKEL_REV2).await;
        let disk = GetVmgsDisk::new(test_get.client.clone()).await.unwrap();
        let formatted = Vmgs::format_new(Disk::new(disk).unwrap(), None)
            .await
            .unwrap();
        let (client, broker_task) = spawn_vmgs_broker(driver, formatted);
        (client, test_get, broker_task)
    }

    /// Writes `data` to `file_id`, then checks both the reported size and the
    /// contents read back.
    async fn write_then_verify(vmgs: &VmgsClient, file_id: FileId, data: &[u8]) {
        vmgs.write_file(file_id, data.to_vec()).await.unwrap();
        let info = vmgs.get_file_info(file_id).await.unwrap();
        assert_eq!(info.valid_bytes as usize, data.len());
        let read_back = vmgs.read_file(file_id).await.unwrap();
        assert_eq!(data.to_vec(), read_back);
    }

    #[async_test]
    async fn basic_read_write(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        write_then_verify(&vmgs, FileId::BIOS_NVRAM, b"hello world").await;
    }

    #[async_test]
    async fn multiple_read_write(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        let first: &[u8] = b"Data data data";
        let second: &[u8] = b"password";
        let third: &[u8] = b"other data data";

        // Interleave writes to two files, overwriting the first file with
        // contents of different lengths; the whole sequence runs twice.
        let sequence = [
            (FileId::BIOS_NVRAM, first),
            (FileId::TPM_PPI, second),
            (FileId::BIOS_NVRAM, third),
        ];
        for _round in 0..2 {
            for &(file_id, data) in &sequence {
                write_then_verify(&vmgs, file_id, data).await;
            }
        }
    }

    #[async_test]
    async fn test_empty_write(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        write_then_verify(&vmgs, FileId::BIOS_NVRAM, &[]).await;
    }

    #[async_test]
    async fn test_read_write_large(driver: DefaultDriver) {
        let (vmgs, _get, _task) = spawn_vmgs(&driver).await;
        let payload = counting_bytes(LARGE_LEN);
        write_then_verify(&vmgs, FileId::BIOS_NVRAM, &payload).await;
    }

    #[async_test]
    async fn test_read_write_encryption(driver: DefaultDriver) {
        let test_get = new_transport_pair(&driver, None, ProtocolVersion::NICKEL_REV2).await;
        let file_id = FileId::BIOS_NVRAM;
        let key = vec![1; 32];
        let payload = counting_bytes(LARGE_LEN);

        // Format a new VMGS, add an encryption key, and write an encrypted file.
        let disk = GetVmgsDisk::new(test_get.client.clone()).await.unwrap();
        let mut store = Vmgs::format_new(Disk::new(disk).unwrap(), None)
            .await
            .unwrap();
        store
            .add_new_encryption_key(&key, vmgs::EncryptionAlgorithm::AES_GCM)
            .await
            .unwrap();
        store.write_file_encrypted(file_id, &payload).await.unwrap();

        let info = store.get_file_info(file_id).unwrap();
        assert_eq!(info.valid_bytes as usize, payload.len());
        let read_back = store.read_file(file_id).await.unwrap();
        assert_eq!(payload, read_back);

        drop(store);

        // Reopen the same GET-backed storage without the key.
        let disk = GetVmgsDisk::new(test_get.client.clone()).await.unwrap();
        let mut store = Vmgs::open(Disk::new(disk).unwrap(), None)
            .await
            .unwrap();

        // Before unlocking, reading the file must not yield the plaintext.
        let locked_read = store.read_file(file_id).await.unwrap();
        assert_ne!(payload, locked_read);

        store.unlock_with_encryption_key(&key).await.unwrap();

        // After unlocking, the broker-served view returns the plaintext again.
        let (client, _task) = spawn_vmgs_broker(&driver, store);
        let info = client.get_file_info(file_id).await.unwrap();
        assert_eq!(info.valid_bytes as usize, payload.len());
        let unlocked_read = client.read_file(file_id).await.unwrap();
        assert_eq!(payload, unlocked_read);
    }
}