// hcl/vmsa.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Interface to `VmsaWrapper`, which combines a SEV-SNP VMSA
5//! with a bitmap to allow for register protection.
6
7use std::array;
8use std::ops::Deref;
9use std::ops::DerefMut;
10use x86defs::snp::SevEventInjectInfo;
11use x86defs::snp::SevFeatures;
12use x86defs::snp::SevSelector;
13use x86defs::snp::SevVirtualInterruptControl;
14use x86defs::snp::SevVmsa;
15use x86defs::snp::SevXmmRegister;
16use zerocopy::FromZeros;
17use zerocopy::IntoBytes;
18
/// VMSA and register tweak bitmap.
pub struct VmsaWrapper<'a, T> {
    /// The VMSA page, accessed through `Deref`/`DerefMut` to [`SevVmsa`].
    vmsa: T,
    /// Register tweak bitmap: one bit per 8-byte VMSA chunk. A set bit means
    /// the chunk is stored XORed with `register_protection_nonce`.
    bitmap: &'a [u8; 64],
}
24
25impl<'a, T> VmsaWrapper<'a, T> {
26    /// Create a VmsaWrapper
27    pub(crate) fn new(vmsa: T, bitmap: &'a [u8; 64]) -> Self {
28        VmsaWrapper { vmsa, bitmap }
29    }
30}
31
32/// Wraps a SEV VMSA structure with the register tweak bitmap to provide safe access methods.
33impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
34    /// 64 bit register read
35    fn get_u64(&self, offset: usize) -> u64 {
36        assert!(offset % 8 == 0);
37        let vmsa_raw = &self.vmsa;
38        let v = u64::from_ne_bytes(vmsa_raw.as_bytes()[offset..offset + 8].try_into().unwrap());
39        if is_protected(self.bitmap, offset) {
40            v ^ self.vmsa.register_protection_nonce
41        } else {
42            v
43        }
44    }
45    /// 32 bit register read
46    fn get_u32(&self, offset: usize) -> u32 {
47        assert!(offset % 4 == 0);
48        (self.get_u64(offset & !7) >> ((offset & 4) * 8)) as u32
49    }
50    /// 128 bit register read
51    fn get_u128(&self, offset: usize) -> u128 {
52        self.get_u64(offset) as u128 | ((self.get_u64(offset + 8) as u128) << 64)
53    }
54
55    /// Gets an XMM VMSA register as u128
56    pub fn xmm_registers(&self, n: usize) -> u128 {
57        assert!(n < 16);
58        let off = std::mem::offset_of!(SevVmsa, xmm_registers) + (n * 16);
59        self.get_u128(off)
60    }
61
62    /// Gets a YMM VMSA register as u128
63    pub fn ymm_registers(&self, n: usize) -> u128 {
64        assert!(n < 16);
65        let off = std::mem::offset_of!(SevVmsa, ymm_registers) + (n * 16);
66        self.get_u128(off)
67    }
68
69    /// Gets the x87 VMSA registers
70    pub fn x87_registers(&self) -> [u64; 10] {
71        let base = std::mem::offset_of!(SevVmsa, x87_registers);
72        array::from_fn(|i| i * 8).map(|offset| self.get_u64(base + offset))
73    }
74}
75
/// Wraps a mutable SEV VMSA structure with the register tweak bitmap to provide safe access methods.
impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
    /// 64 bit value to set in register.
    ///
    /// Returns the value that should actually be stored at `offset`: `v`
    /// XORed with the protection nonce when the tweak bitmap covers the
    /// chunk, otherwise `v` unchanged. This does not write the VMSA itself;
    /// callers store the returned value.
    fn set_u64(&self, v: u64, offset: usize) -> u64 {
        assert!(offset % 8 == 0);
        if is_protected(self.bitmap, offset) {
            v ^ self.vmsa.register_protection_nonce
        } else {
            v
        }
    }
    /// 32 bit value to set in register.
    ///
    /// Positions `v` within its aligned 8-byte chunk, applies the nonce via
    /// `set_u64`, then shifts the relevant 32 bits back out.
    fn set_u32(&self, v: u32, offset: usize) -> u32 {
        assert!(offset % 4 == 0);
        // (offset & 4) * 8 is 0 or 32: the dword's bit position in the chunk.
        let val = (v as u64) << ((offset & 4) * 8);
        (self.set_u64(val, offset & !7) >> ((offset & 4) * 8)) as u32
    }
    /// 128 bit value to set in register.
    ///
    /// Applies the tweak independently to the low and high 8-byte halves.
    fn set_u128(&self, v: u128, offset: usize) -> u128 {
        self.set_u64(v as u64, offset) as u128
            | ((self.set_u64((v >> 64) as u64, offset + 8) as u128) << 64)
    }

    /// Create a new VMSA.
    ///
    /// Zeroes the entire VMSA. When `vmsa_reg_prot` is set, additionally
    /// generates a fresh random protection nonce and seeds every
    /// bitmap-protected 8-byte chunk with the nonce (the protected
    /// representation of zero).
    pub fn reset(&mut self, vmsa_reg_prot: bool) {
        *self.vmsa = FromZeros::new_zeroed();
        if vmsa_reg_prot {
            // Initialize nonce and all protected fields.
            getrandom::fill(self.vmsa.register_protection_nonce.as_mut_bytes())
                .expect("rng failure");
            let nonce = self.vmsa.register_protection_nonce;
            // One tweak-bitmap bit per 8-byte chunk. NOTE(review): assumes the
            // VMSA size is a multiple of 8, since chunks_exact_mut skips any
            // trailing remainder — confirm against the SevVmsa definition.
            let chunk_size = 8;
            for (i, b) in self
                .vmsa
                .as_mut_bytes()
                .chunks_exact_mut(chunk_size)
                .enumerate()
            {
                let field_offset = i * chunk_size;
                // Ensure direct accesses are not included in bitmap.
                // These fields are read/written without nonce masking (see the
                // *_direct macros below), so a bitmap that marked them
                // protected would corrupt them.
                if field_offset == (std::mem::offset_of!(SevVmsa, vmpl) & !7)
                    || field_offset == std::mem::offset_of!(SevVmsa, exit_info1)
                    || field_offset == std::mem::offset_of!(SevVmsa, exit_info2)
                    || field_offset == std::mem::offset_of!(SevVmsa, exit_int_info)
                    || field_offset == std::mem::offset_of!(SevVmsa, sev_features)
                    || field_offset == std::mem::offset_of!(SevVmsa, v_intr_cntrl)
                    || field_offset == std::mem::offset_of!(SevVmsa, guest_error_code)
                    || field_offset == std::mem::offset_of!(SevVmsa, virtual_tom)
                {
                    assert!(!is_protected(self.bitmap, field_offset));
                }
                if is_protected(self.bitmap, field_offset) {
                    b.copy_from_slice(&nonce.to_ne_bytes());
                }
            }
        }
    }

    /// Sets an XMM VMSA register from u128
    pub fn set_xmm_registers(&mut self, n: usize, v: u128) {
        assert!(n < 16);
        let off = std::mem::offset_of!(SevVmsa, xmm_registers) + (n * 16);
        let val: SevXmmRegister = self.set_u128(v, off).into();
        let vmsa_raw = &mut *self.vmsa;
        vmsa_raw.xmm_registers[n] = val;
    }

    /// Sets a YMM VMSA register from u128
    pub fn set_ymm_registers(&mut self, n: usize, v: u128) {
        assert!(n < 16);
        let off = std::mem::offset_of!(SevVmsa, ymm_registers) + (n * 16);
        let val: SevXmmRegister = self.set_u128(v, off).into();
        let vmsa_raw = &mut *self.vmsa;
        vmsa_raw.ymm_registers[n] = val;
    }

    /// Sets the x87 registers
    pub fn set_x87_registers(&mut self, v: &[u64; 10]) {
        let base = std::mem::offset_of!(SevVmsa, x87_registers);
        for (i, new_v) in v.iter().enumerate() {
            let val = self.set_u64(*new_v, base + (i * 8));
            self.vmsa.x87_registers[i] = val;
        }
    }
}
161
/// Check bitmap to see if a register is included in masking.
///
/// Each bitmap bit covers one 8-byte VMSA chunk, so each bitmap byte covers
/// 64 bytes of the VMSA.
fn is_protected(bitmap: &[u8; 64], field_offset: usize) -> bool {
    let chunk = field_offset / 8;
    (bitmap[chunk / 8] >> (chunk % 8)) & 1 != 0
}
168
macro_rules! regss {
    ($reg:ident, $set:ident) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Gets a SevSelector VMSA register
            pub fn $reg(&self) -> SevSelector {
                self.get_u128(std::mem::offset_of!(SevVmsa, $reg)).into()
            }
        }
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Sets a SevSelector VMSA register
            pub fn $set(&mut self, v: SevSelector) {
                let masked: SevSelector = self
                    .set_u128(v.as_u128(), std::mem::offset_of!(SevVmsa, $reg))
                    .into();
                self.vmsa.$reg = masked;
            }
        }
    };
}
macro_rules! reg64 {
    ($reg:ident, $set:ident) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Gets a VMSA register
            pub fn $reg(&self) -> u64 {
                self.get_u64(std::mem::offset_of!(SevVmsa, $reg))
            }
        }
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Sets a VMSA register
            pub fn $set(&mut self, v: u64) {
                let masked = self.set_u64(v, std::mem::offset_of!(SevVmsa, $reg));
                self.vmsa.$reg = masked;
            }
        }
    };
}
macro_rules! reg32 {
    ($reg:ident, $set:ident) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Gets a VMSA register
            pub fn $reg(&self) -> u32 {
                self.get_u32(std::mem::offset_of!(SevVmsa, $reg))
            }
        }
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Sets a VMSA register
            pub fn $set(&mut self, v: u32) {
                let masked = self.set_u32(v, std::mem::offset_of!(SevVmsa, $reg));
                self.vmsa.$reg = masked;
            }
        }
    };
}
macro_rules! get_reg_direct {
    ($reg:ident, $ty:ty) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Gets a VMSA register directly (no nonce masking)
            pub fn $reg(&self) -> $ty {
                self.vmsa.$reg
            }
        }
    };
}
macro_rules! reg_direct {
    ($reg:ident, $set:ident, $ty:ty) => {
        get_reg_direct!($reg, $ty);
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Sets a VMSA register directly (no nonce masking)
            pub fn $set(&mut self, v: $ty) {
                self.vmsa.$reg = v;
            }
        }
    };
}
// Generates a direct (unmasked) getter plus a `&mut` accessor for fields that
// callers manipulate in place (e.g. bitfield types).
macro_rules! reg_direct_mut {
    ($reg:ident, $set:ident, $ty:ty) => {
        get_reg_direct!($reg, $ty);
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            /// Access VMSA field directly in order to manipulate fields.
            pub fn $set(&mut self) -> &mut $ty {
                &mut self.vmsa.$reg
            }
        }
    };
}
259
// Fields accessed directly, with no nonce masking applied. `reset` asserts
// that the tweak bitmap does not cover (most of) these chunks.
reg_direct!(vmpl, set_vmpl, u8);
get_reg_direct!(cpl, u8);
get_reg_direct!(exit_info1, u64);
get_reg_direct!(exit_info2, u64);
reg_direct!(exit_int_info, set_exit_int_info, u64);
reg_direct_mut!(sev_features, sev_features_mut, SevFeatures);
reg_direct_mut!(v_intr_cntrl, v_intr_cntrl_mut, SevVirtualInterruptControl);
reg_direct!(virtual_tom, set_virtual_tom, u64);
reg_direct!(event_inject, set_event_inject, SevEventInjectInfo);
reg_direct!(guest_error_code, set_guest_error_code, u64);
// Segment and descriptor-table registers (128-bit `SevSelector`s, masked per
// the tweak bitmap).
regss!(es, set_es);
regss!(cs, set_cs);
regss!(ss, set_ss);
regss!(ds, set_ds);
regss!(fs, set_fs);
regss!(gs, set_gs);
regss!(gdtr, set_gdtr);
regss!(ldtr, set_ldtr);
regss!(idtr, set_idtr);
regss!(tr, set_tr);
// 64-bit registers, masked per the tweak bitmap.
reg64!(pl0_ssp, set_pl0_ssp);
reg64!(pl1_ssp, set_pl1_ssp);
reg64!(pl2_ssp, set_pl2_ssp);
reg64!(pl3_ssp, set_pl3_ssp);
reg64!(u_cet, set_u_cet);
reg64!(efer, set_efer);
reg64!(xss, set_xss);
reg64!(cr4, set_cr4);
reg64!(cr3, set_cr3);
reg64!(cr0, set_cr0);
reg64!(dr7, set_dr7);
reg64!(dr6, set_dr6);
reg64!(rflags, set_rflags);
reg64!(rip, set_rip);
reg64!(dr0, set_dr0);
reg64!(dr1, set_dr1);
reg64!(dr2, set_dr2);
reg64!(dr3, set_dr3);
reg64!(rsp, set_rsp);
reg64!(s_cet, set_s_cet);
reg64!(ssp, set_ssp);
reg64!(interrupt_ssp_table_addr, set_interrupt_ssp_table_addr);
reg64!(rax, set_rax);
reg64!(star, set_star);
reg64!(lstar, set_lstar);
reg64!(cstar, set_cstar);
reg64!(sfmask, set_sfmask);
reg64!(kernel_gs_base, set_kernel_gs_base);
reg64!(sysenter_cs, set_sysenter_cs);
reg64!(sysenter_esp, set_sysenter_esp);
reg64!(sysenter_eip, set_sysenter_eip);
reg64!(cr2, set_cr2);
reg64!(pat, set_pat);
reg64!(spec_ctrl, set_spec_ctrl);
// 32-bit register: shares its masked 8-byte chunk with a neighbor.
reg32!(tsc_aux, set_tsc_aux);
reg64!(rcx, set_rcx);
reg64!(rdx, set_rdx);
reg64!(rbx, set_rbx);
reg64!(rbp, set_rbp);
reg64!(rsi, set_rsi);
reg64!(rdi, set_rdi);
reg64!(r8, set_r8);
reg64!(r9, set_r9);
reg64!(r10, set_r10);
reg64!(r11, set_r11);
reg64!(r12, set_r12);
reg64!(r13, set_r13);
reg64!(r14, set_r14);
reg64!(r15, set_r15);
reg64!(next_rip, set_next_rip);
reg64!(pcpu_id, set_pcpu_id);
reg64!(xcr0, set_xcr0);
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips registers through the wrapper with a fully-protected
    /// bitmap and verifies that the raw VMSA holds the nonce-XORed values
    /// while the wrapper accessors return the plain values.
    #[test]
    fn test_reg_access() {
        let nonce = 0xffff_ffff_ffff_ffffu64;
        // 128-bit fields are masked as two independent 64-bit halves.
        let nonce128 = ((nonce as u128) << 64) | nonce as u128;
        let mut vmsa: SevVmsa = FromZeros::new_zeroed();
        vmsa.register_protection_nonce = nonce;
        // All-ones bitmap: every 8-byte chunk is protected.
        let bitmap = [0xffu8; 64];
        let mut vmsa_wrapper = VmsaWrapper {
            vmsa: &mut vmsa,
            bitmap: &bitmap,
        };

        let val = 0x0000_0055_0000_0055u128;
        let val_xor = val ^ nonce128;
        let cs = SevSelector::from(val);
        let cs_xor = SevSelector::from(val_xor);
        let vmpl = 2u8;
        let rip = 0x55u64;
        let rip_xor = rip ^ nonce;
        let tsc = 0x55u32;
        let tsc_xor = tsc ^ (nonce as u32);
        let xmm_idx = 1;
        let ymm_idx = 1;
        let x87 = [0x55u64; 10];
        let x87_xor = x87.map(|v| v ^ nonce);

        vmsa_wrapper.set_cs(cs);
        vmsa_wrapper.set_vmpl(vmpl);
        vmsa_wrapper.set_rip(rip);
        vmsa_wrapper.set_tsc_aux(tsc);
        vmsa_wrapper.set_xmm_registers(xmm_idx, val);
        vmsa_wrapper.set_ymm_registers(ymm_idx, val);
        vmsa_wrapper.set_x87_registers(&x87);

        // Wrapper reads undo the mask.
        assert!(vmsa_wrapper.cs() == cs);
        assert!(vmsa_wrapper.vmpl() == vmpl);
        assert!(vmsa_wrapper.rip() == rip);
        assert!(vmsa_wrapper.xmm_registers(xmm_idx) == val);
        assert!(vmsa_wrapper.ymm_registers(ymm_idx) == val);
        assert!(vmsa_wrapper.tsc_aux() == tsc);
        assert!(vmsa_wrapper.x87_registers() == x87);
        // Raw storage holds the masked values (except direct-access fields).
        assert!(vmsa.cs == cs_xor); // bitmask applied to u128
        assert!(vmsa.vmpl == vmpl); // no bitmask applied
        assert!(vmsa.rip == rip_xor); // bitmask applied
        assert!(vmsa.tsc_aux == tsc_xor); // bitmask applied to u32
        assert!(vmsa.pkru == 0); // untouched
        assert!(vmsa.xmm_registers[xmm_idx].as_u128() == val_xor); // bitmask applied to correct XMM offset
        assert!(vmsa.ymm_registers[ymm_idx].as_u128() == val_xor); // bitmask applied to correct YMM offset
        assert!(vmsa.x87_registers == x87_xor);
    }

    /// Verifies that `reset(true)` seeds protected chunks with the nonce so
    /// that wrapper reads of those fields observe zero.
    #[test]
    fn test_init() {
        let mut vmsa: SevVmsa = FromZeros::new_zeroed();
        let mut bitmap = [0x0u8; 64];
        let xmm_idx = 1;
        bitmap[5] = 0x80u8; // rip
        bitmap[18] = 0x03u8; // xmm_registers[1]
        let mut vmsa_wrapper = VmsaWrapper {
            vmsa: &mut vmsa,
            bitmap: &bitmap,
        };
        vmsa_wrapper.reset(true);

        // Through the wrapper, protected fields read back as zero.
        assert!(vmsa_wrapper.rip() == 0);
        assert!(vmsa_wrapper.xmm_registers(xmm_idx) == 0);

        // In raw storage, protected fields hold the (random) nonce.
        let nonce = vmsa.register_protection_nonce;
        let xmm_val = ((nonce as u128) << 64) | nonce as u128;
        assert!(vmsa.rip == nonce);
        assert!(vmsa.xmm_registers[xmm_idx].as_u128() == xmm_val);
    }
}