1use std::array;
8use std::ops::Deref;
9use std::ops::DerefMut;
10use x86defs::snp::SevEventInjectInfo;
11use x86defs::snp::SevFeatures;
12use x86defs::snp::SevSelector;
13use x86defs::snp::SevVirtualInterruptControl;
14use x86defs::snp::SevVmsa;
15use x86defs::snp::SevXmmRegister;
16use zerocopy::FromZeros;
17use zerocopy::IntoBytes;
18
/// A view of an SEV-SNP VMSA that transparently applies VMSA register
/// protection.
///
/// The protection bitmap holds one bit per 8-byte quadword of the VMSA
/// (64 bytes x 8 bits = 512 quadwords, i.e. 4 KiB). Quadwords whose bit is
/// set are stored XORed with the VMSA's `register_protection_nonce`; the
/// accessors on this type hide that encoding from callers.
pub struct VmsaWrapper<'a, T> {
    /// The wrapped VMSA: anything that dereferences to `SevVmsa`
    /// (for example `&mut SevVmsa`).
    vmsa: T,
    /// Register-protection bitmap, one bit per VMSA quadword.
    bitmap: &'a [u8; 64],
}
24
25impl<'a, T> VmsaWrapper<'a, T> {
26 pub(crate) fn new(vmsa: T, bitmap: &'a [u8; 64]) -> Self {
28 VmsaWrapper { vmsa, bitmap }
29 }
30}
31
32impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
34 fn get_u64(&self, offset: usize) -> u64 {
36 assert!(offset % 8 == 0);
37 let vmsa_raw = &self.vmsa;
38 let v = u64::from_ne_bytes(vmsa_raw.as_bytes()[offset..offset + 8].try_into().unwrap());
39 if is_protected(self.bitmap, offset) {
40 v ^ self.vmsa.register_protection_nonce
41 } else {
42 v
43 }
44 }
45 fn get_u32(&self, offset: usize) -> u32 {
47 assert!(offset % 4 == 0);
48 (self.get_u64(offset & !7) >> ((offset & 4) * 8)) as u32
49 }
50 fn get_u128(&self, offset: usize) -> u128 {
52 self.get_u64(offset) as u128 | ((self.get_u64(offset + 8) as u128) << 64)
53 }
54
55 pub fn xmm_registers(&self, n: usize) -> u128 {
57 assert!(n < 16);
58 let off = std::mem::offset_of!(SevVmsa, xmm_registers) + (n * 16);
59 self.get_u128(off)
60 }
61
62 pub fn ymm_registers(&self, n: usize) -> u128 {
64 assert!(n < 16);
65 let off = std::mem::offset_of!(SevVmsa, ymm_registers) + (n * 16);
66 self.get_u128(off)
67 }
68
69 pub fn x87_registers(&self) -> [u64; 10] {
71 let base = std::mem::offset_of!(SevVmsa, x87_registers);
72 array::from_fn(|i| i * 8).map(|offset| self.get_u64(base + offset))
73 }
74}
75
76impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
78 fn set_u64(&self, v: u64, offset: usize) -> u64 {
80 assert!(offset % 8 == 0);
81 if is_protected(self.bitmap, offset) {
82 v ^ self.vmsa.register_protection_nonce
83 } else {
84 v
85 }
86 }
87 fn set_u32(&self, v: u32, offset: usize) -> u32 {
89 assert!(offset % 4 == 0);
90 let val = (v as u64) << ((offset & 4) * 8);
91 (self.set_u64(val, offset & !7) >> ((offset & 4) * 8)) as u32
92 }
93 fn set_u128(&self, v: u128, offset: usize) -> u128 {
95 self.set_u64(v as u64, offset) as u128
96 | ((self.set_u64((v >> 64) as u64, offset + 8) as u128) << 64)
97 }
98
99 pub fn reset(&mut self, vmsa_reg_prot: bool) {
101 *self.vmsa = FromZeros::new_zeroed();
102 if vmsa_reg_prot {
103 getrandom::fill(self.vmsa.register_protection_nonce.as_mut_bytes())
105 .expect("rng failure");
106 let nonce = self.vmsa.register_protection_nonce;
107 let chunk_size = 8;
108 for (i, b) in self
109 .vmsa
110 .as_mut_bytes()
111 .chunks_exact_mut(chunk_size)
112 .enumerate()
113 {
114 let field_offset = i * chunk_size;
115 if field_offset == (std::mem::offset_of!(SevVmsa, vmpl) & !7)
117 || field_offset == std::mem::offset_of!(SevVmsa, exit_info1)
118 || field_offset == std::mem::offset_of!(SevVmsa, exit_info2)
119 || field_offset == std::mem::offset_of!(SevVmsa, exit_int_info)
120 || field_offset == std::mem::offset_of!(SevVmsa, sev_features)
121 || field_offset == std::mem::offset_of!(SevVmsa, v_intr_cntrl)
122 || field_offset == std::mem::offset_of!(SevVmsa, guest_error_code)
123 || field_offset == std::mem::offset_of!(SevVmsa, virtual_tom)
124 {
125 assert!(!is_protected(self.bitmap, field_offset));
126 }
127 if is_protected(self.bitmap, field_offset) {
128 b.copy_from_slice(&nonce.to_ne_bytes());
129 }
130 }
131 }
132 }
133
134 pub fn set_xmm_registers(&mut self, n: usize, v: u128) {
136 assert!(n < 16);
137 let off = std::mem::offset_of!(SevVmsa, xmm_registers) + (n * 16);
138 let val: SevXmmRegister = self.set_u128(v, off).into();
139 let vmsa_raw = &mut *self.vmsa;
140 vmsa_raw.xmm_registers[n] = val;
141 }
142
143 pub fn set_ymm_registers(&mut self, n: usize, v: u128) {
145 assert!(n < 16);
146 let off = std::mem::offset_of!(SevVmsa, ymm_registers) + (n * 16);
147 let val: SevXmmRegister = self.set_u128(v, off).into();
148 let vmsa_raw = &mut *self.vmsa;
149 vmsa_raw.ymm_registers[n] = val;
150 }
151
152 pub fn set_x87_registers(&mut self, v: &[u64; 10]) {
154 let base = std::mem::offset_of!(SevVmsa, x87_registers);
155 for (i, new_v) in v.iter().enumerate() {
156 let val = self.set_u64(*new_v, base + (i * 8));
157 self.vmsa.x87_registers[i] = val;
158 }
159 }
160}
161
/// Reports whether the 8-byte VMSA quadword containing `field_offset` is
/// marked as register-protected in `bitmap`.
///
/// Each bitmap byte covers 64 bytes of the VMSA (eight quadwords): bit `k`
/// of byte `b` describes the quadword at byte offset `b * 64 + k * 8`.
///
/// # Panics
///
/// Panics if `field_offset` is 4096 or greater, which is past the range the
/// 64-byte bitmap can describe.
fn is_protected(bitmap: &[u8; 64], field_offset: usize) -> bool {
    let quadword = field_offset / 8;
    let byte = bitmap[quadword / 8];
    ((byte >> (quadword % 8)) & 1) == 1
}
168
/// Generates a getter/setter pair for a 16-byte `SevSelector` field, with
/// register protection applied to both of its quadwords.
macro_rules! regss {
    ($field:ident, $setter:ident) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Returns the `",
                stringify!($field),
                "` selector with register protection removed."
            )]
            pub fn $field(&self) -> SevSelector {
                let decoded = self.get_u128(std::mem::offset_of!(SevVmsa, $field));
                SevSelector::from(decoded)
            }
        }
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Stores `v` into the `",
                stringify!($field),
                "` selector, applying register protection."
            )]
            pub fn $setter(&mut self, v: SevSelector) {
                let offset = std::mem::offset_of!(SevVmsa, $field);
                let encoded = SevSelector::from(self.set_u128(v.as_u128(), offset));
                self.vmsa.$field = encoded;
            }
        }
    };
}
/// Generates a getter/setter pair for a 64-bit register field, with
/// register protection applied.
macro_rules! reg64 {
    ($field:ident, $setter:ident) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Returns `",
                stringify!($field),
                "` with register protection removed."
            )]
            pub fn $field(&self) -> u64 {
                let offset = std::mem::offset_of!(SevVmsa, $field);
                self.get_u64(offset)
            }
        }
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Stores `v` into `",
                stringify!($field),
                "`, applying register protection."
            )]
            pub fn $setter(&mut self, v: u64) {
                let offset = std::mem::offset_of!(SevVmsa, $field);
                let encoded = self.set_u64(v, offset);
                self.vmsa.$field = encoded;
            }
        }
    };
}
/// Generates a getter/setter pair for a 32-bit register field.
///
/// Register protection is applied using the half of the nonce that overlaps
/// the field's position within its quadword.
macro_rules! reg32 {
    ($field:ident, $setter:ident) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Returns `",
                stringify!($field),
                "` with register protection removed."
            )]
            pub fn $field(&self) -> u32 {
                let offset = std::mem::offset_of!(SevVmsa, $field);
                self.get_u32(offset)
            }
        }
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Stores `v` into `",
                stringify!($field),
                "`, applying register protection."
            )]
            pub fn $setter(&mut self, v: u32) {
                let offset = std::mem::offset_of!(SevVmsa, $field);
                let encoded = self.set_u32(v, offset);
                self.vmsa.$field = encoded;
            }
        }
    };
}
/// Generates a getter that copies a VMSA field out as-is, with no nonce
/// translation.
macro_rules! get_reg_direct {
    ($field:ident, $ty:ty) => {
        impl<T: Deref<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Returns the raw `",
                stringify!($field),
                "` field; no register-protection translation is applied."
            )]
            pub fn $field(&self) -> $ty {
                let vmsa: &SevVmsa = &self.vmsa;
                vmsa.$field
            }
        }
    };
}
/// Generates a raw getter (via `get_reg_direct!`) plus a raw setter for a
/// VMSA field; neither applies nonce translation.
macro_rules! reg_direct {
    ($field:ident, $setter:ident, $ty:ty) => {
        get_reg_direct!($field, $ty);
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Overwrites the raw `",
                stringify!($field),
                "` field; no register-protection translation is applied."
            )]
            pub fn $setter(&mut self, v: $ty) {
                let vmsa: &mut SevVmsa = &mut self.vmsa;
                vmsa.$field = v;
            }
        }
    };
}
/// Generates a raw getter (via `get_reg_direct!`) plus an accessor that
/// returns a mutable reference to the raw VMSA field.
macro_rules! reg_direct_mut {
    ($field:ident, $accessor:ident, $ty:ty) => {
        get_reg_direct!($field, $ty);
        impl<T: DerefMut<Target = SevVmsa>> VmsaWrapper<'_, T> {
            #[doc = concat!(
                "Returns a mutable reference to the raw `",
                stringify!($field),
                "` field; no register-protection translation is applied."
            )]
            pub fn $accessor(&mut self) -> &mut $ty {
                let vmsa: &mut SevVmsa = &mut self.vmsa;
                &mut vmsa.$field
            }
        }
    };
}
259
260reg_direct!(vmpl, set_vmpl, u8);
261get_reg_direct!(cpl, u8);
262get_reg_direct!(exit_info1, u64);
263get_reg_direct!(exit_info2, u64);
264reg_direct!(exit_int_info, set_exit_int_info, u64);
265reg_direct_mut!(sev_features, sev_features_mut, SevFeatures);
266reg_direct_mut!(v_intr_cntrl, v_intr_cntrl_mut, SevVirtualInterruptControl);
267reg_direct!(virtual_tom, set_virtual_tom, u64);
268reg_direct!(event_inject, set_event_inject, SevEventInjectInfo);
269reg_direct!(guest_error_code, set_guest_error_code, u64);
270regss!(es, set_es);
271regss!(cs, set_cs);
272regss!(ss, set_ss);
273regss!(ds, set_ds);
274regss!(fs, set_fs);
275regss!(gs, set_gs);
276regss!(gdtr, set_gdtr);
277regss!(ldtr, set_ldtr);
278regss!(idtr, set_idtr);
279regss!(tr, set_tr);
280reg64!(pl0_ssp, set_pl0_ssp);
281reg64!(pl1_ssp, set_pl1_ssp);
282reg64!(pl2_ssp, set_pl2_ssp);
283reg64!(pl3_ssp, set_pl3_ssp);
284reg64!(u_cet, set_u_cet);
285reg64!(efer, set_efer);
286reg64!(xss, set_xss);
287reg64!(cr4, set_cr4);
288reg64!(cr3, set_cr3);
289reg64!(cr0, set_cr0);
290reg64!(dr7, set_dr7);
291reg64!(dr6, set_dr6);
292reg64!(rflags, set_rflags);
293reg64!(rip, set_rip);
294reg64!(dr0, set_dr0);
295reg64!(dr1, set_dr1);
296reg64!(dr2, set_dr2);
297reg64!(dr3, set_dr3);
298reg64!(rsp, set_rsp);
299reg64!(s_cet, set_s_cet);
300reg64!(ssp, set_ssp);
301reg64!(interrupt_ssp_table_addr, set_interrupt_ssp_table_addr);
302reg64!(rax, set_rax);
303reg64!(star, set_star);
304reg64!(lstar, set_lstar);
305reg64!(cstar, set_cstar);
306reg64!(sfmask, set_sfmask);
307reg64!(kernel_gs_base, set_kernel_gs_base);
308reg64!(sysenter_cs, set_sysenter_cs);
309reg64!(sysenter_esp, set_sysenter_esp);
310reg64!(sysenter_eip, set_sysenter_eip);
311reg64!(cr2, set_cr2);
312reg64!(pat, set_pat);
313reg64!(spec_ctrl, set_spec_ctrl);
314reg32!(tsc_aux, set_tsc_aux);
315reg64!(rcx, set_rcx);
316reg64!(rdx, set_rdx);
317reg64!(rbx, set_rbx);
318reg64!(rbp, set_rbp);
319reg64!(rsi, set_rsi);
320reg64!(rdi, set_rdi);
321reg64!(r8, set_r8);
322reg64!(r9, set_r9);
323reg64!(r10, set_r10);
324reg64!(r11, set_r11);
325reg64!(r12, set_r12);
326reg64!(r13, set_r13);
327reg64!(r14, set_r14);
328reg64!(r15, set_r15);
329reg64!(next_rip, set_next_rip);
330reg64!(pcpu_id, set_pcpu_id);
331reg64!(xcr0, set_xcr0);
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Marks the quadword containing `offset` as protected in `bitmap`,
    /// using the same bit layout that `is_protected` reads.
    fn protect(bitmap: &mut [u8; 64], offset: usize) {
        bitmap[offset / 64] |= 1 << ((offset % 64) / 8);
    }

    #[test]
    fn test_reg_access() {
        // The two 32-bit halves of this nonce differ. That lets the test
        // check that 32-bit accessors apply the half overlapping the field.
        let nonce = 0x0123_4567_89ab_cdefu64;
        let nonce128 = ((nonce as u128) << 64) | nonce as u128;
        let mut vmsa: SevVmsa = FromZeros::new_zeroed();
        vmsa.register_protection_nonce = nonce;
        // Every quadword is protected.
        let bitmap = [0xffu8; 64];
        let mut vmsa_wrapper = VmsaWrapper {
            vmsa: &mut vmsa,
            bitmap: &bitmap,
        };

        let val = 0x0000_0055_0000_0055u128;
        let val_xor = val ^ nonce128;
        let cs = SevSelector::from(val);
        let cs_xor = SevSelector::from(val_xor);
        let vmpl = 2u8;
        let rip = 0x55u64;
        let rip_xor = rip ^ nonce;
        let tsc = 0x55u32;
        // tsc_aux is a 32-bit field, so only the nonce half that overlaps its
        // position within the quadword is applied.
        let tsc_shift = (std::mem::offset_of!(SevVmsa, tsc_aux) & 4) * 8;
        let tsc_xor = tsc ^ (nonce >> tsc_shift) as u32;
        let xmm_idx = 1;
        let ymm_idx = 1;
        let x87 = [0x55u64; 10];
        let x87_xor = x87.map(|v| v ^ nonce);

        vmsa_wrapper.set_cs(cs);
        vmsa_wrapper.set_vmpl(vmpl);
        vmsa_wrapper.set_rip(rip);
        vmsa_wrapper.set_tsc_aux(tsc);
        vmsa_wrapper.set_xmm_registers(xmm_idx, val);
        vmsa_wrapper.set_ymm_registers(ymm_idx, val);
        vmsa_wrapper.set_x87_registers(&x87);

        // Reads through the wrapper round-trip the written values.
        assert!(vmsa_wrapper.cs() == cs);
        assert_eq!(vmsa_wrapper.vmpl(), vmpl);
        assert_eq!(vmsa_wrapper.rip(), rip);
        assert_eq!(vmsa_wrapper.xmm_registers(xmm_idx), val);
        assert_eq!(vmsa_wrapper.ymm_registers(ymm_idx), val);
        assert_eq!(vmsa_wrapper.tsc_aux(), tsc);
        assert_eq!(vmsa_wrapper.x87_registers(), x87);

        // Raw storage: protected registers hold value ^ nonce, while vmpl
        // (accessed directly) is stored as-is.
        assert!(vmsa.cs == cs_xor);
        assert_eq!(vmsa.vmpl, vmpl);
        assert_eq!(vmsa.rip, rip_xor);
        assert_eq!(vmsa.tsc_aux, tsc_xor);
        // pkru was never written. The 32-bit tsc_aux write must not leak
        // into it.
        assert_eq!(vmsa.pkru, 0);
        assert_eq!(vmsa.xmm_registers[xmm_idx].as_u128(), val_xor);
        assert_eq!(vmsa.ymm_registers[ymm_idx].as_u128(), val_xor);
        assert_eq!(vmsa.x87_registers, x87_xor);
    }

    #[test]
    fn test_init() {
        let mut vmsa: SevVmsa = FromZeros::new_zeroed();
        let xmm_idx = 1;
        // Protect rip and both quadwords of xmm_registers[1]. The offsets
        // come from the struct layout rather than hard-coded bitmap bytes.
        let mut bitmap = [0x0u8; 64];
        let xmm_offset = std::mem::offset_of!(SevVmsa, xmm_registers) + xmm_idx * 16;
        protect(&mut bitmap, std::mem::offset_of!(SevVmsa, rip));
        protect(&mut bitmap, xmm_offset);
        protect(&mut bitmap, xmm_offset + 8);
        let mut vmsa_wrapper = VmsaWrapper {
            vmsa: &mut vmsa,
            bitmap: &bitmap,
        };
        vmsa_wrapper.reset(true);

        // Protected registers decode to zero after reset...
        assert_eq!(vmsa_wrapper.rip(), 0);
        assert_eq!(vmsa_wrapper.xmm_registers(xmm_idx), 0);

        // ...because their raw storage was seeded with the nonce.
        let nonce = vmsa.register_protection_nonce;
        let xmm_val = ((nonce as u128) << 64) | nonce as u128;
        assert_eq!(vmsa.rip, nonce);
        assert_eq!(vmsa.xmm_registers[xmm_idx].as_u128(), xmm_val);
    }

    #[test]
    fn test_reset_without_protection() {
        let mut vmsa: SevVmsa = FromZeros::new_zeroed();
        vmsa.rip = 0x1234;
        let bitmap = [0xffu8; 64];
        let mut vmsa_wrapper = VmsaWrapper {
            vmsa: &mut vmsa,
            bitmap: &bitmap,
        };
        vmsa_wrapper.reset(false);

        assert_eq!(vmsa_wrapper.rip(), 0);
        // Without protection the nonce stays zero and nothing is seeded.
        assert_eq!(vmsa.register_protection_nonce, 0);
        assert_eq!(vmsa.rip, 0);
    }

    #[test]
    fn test_is_protected() {
        let mut bitmap = [0u8; 64];
        // Byte 1, bit 2 covers the quadword at 64 + 2 * 8 = 80.
        bitmap[1] = 0b0000_0100;
        assert!(is_protected(&bitmap, 80));
        assert!(is_protected(&bitmap, 84));
        assert!(!is_protected(&bitmap, 72));
        assert!(!is_protected(&bitmap, 88));
        assert!(!is_protected(&bitmap, 16));

        let all = [0xffu8; 64];
        assert!(is_protected(&all, 0));
        assert!(is_protected(&all, 4088));
    }
}
409}