use super::Hcl;
use super::HclVp;
use super::MshvVtl;
use super::NoRunner;
use super::ProcessorRunner;
use super::hcl_tdcall;
use super::mshv_tdcall;
use crate::GuestVtl;
use crate::protocol::tdx_tdg_vp_enter_exit_info;
use crate::protocol::tdx_vp_context;
use crate::protocol::tdx_vp_state;
use crate::protocol::tdx_vp_state_flags;
use hv1_structs::VtlArray;
use hvdef::HvRegisterName;
use hvdef::HvRegisterValue;
use memory_range::MemoryRange;
use sidecar_client::SidecarVp;
use std::cell::UnsafeCell;
use std::os::fd::AsRawFd;
use tdcall::Tdcall;
use tdcall::tdcall_vp_invgla;
use tdcall::tdcall_vp_rd;
use tdcall::tdcall_vp_wr;
use x86defs::tdx::TdCallResult;
use x86defs::tdx::TdCallResultCode;
use x86defs::tdx::TdGlaVmAndFlags;
use x86defs::tdx::TdVpsClassCode;
use x86defs::tdx::TdgMemPageAttrWriteR8;
use x86defs::tdx::TdgMemPageGpaAttr;
use x86defs::tdx::TdxContextCode;
use x86defs::tdx::TdxExtendedFieldCode;
use x86defs::tdx::TdxGlaListInfo;
use x86defs::tdx::TdxGp;
use x86defs::tdx::TdxL2Ctls;
use x86defs::tdx::TdxL2EnterGuestState;
use x86defs::tdx::TdxVmFlags;
use x86defs::vmx::ApicPage;
use x86defs::vmx::VmcsField;

/// Runner backing for TDX-isolated partitions.
pub struct Tdx<'a> {
    apic_pages: VtlArray<&'a UnsafeCell<ApicPage>, 2>,
}

impl MshvVtl {
    /// Issues a tdcall to set page attributes.
    pub fn tdx_set_page_attributes(
        &self,
        range: MemoryRange,
        attributes: TdgMemPageGpaAttr,
        mask: TdgMemPageAttrWriteR8,
    ) -> Result<(), TdCallResultCode> {
        tdcall::set_page_attributes(&mut MshvVtlTdcall(self), range, attributes, mask)
    }

    /// Issues a tdcall to accept pages, optionally setting attributes on the
    /// accepted range.
    pub fn tdx_accept_pages(
        &self,
        range: MemoryRange,
        attributes: Option<(TdgMemPageGpaAttr, TdgMemPageAttrWriteR8)>,
    ) -> Result<(), tdcall::AcceptPagesError> {
        let attributes = attributes
            .map_or(tdcall::AcceptPagesAttributes::None, |(attributes, mask)| {
                tdcall::AcceptPagesAttributes::Set { attributes, mask }
            });

        tdcall::accept_pages(&mut MshvVtlTdcall(self), range, attributes)
    }
}
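
// Illustrative use of the helpers above (a sketch; the `mshv_vtl` handle and
// the range, attribute, and mask values are hypothetical):
//
// ```ignore
// let range = MemoryRange::new(0x1_0000_0000..0x1_0020_0000);
// // Accept the pages with default attributes...
// mshv_vtl.tdx_accept_pages(range, None)?;
// // ...or accept them and set attributes in a single call.
// mshv_vtl.tdx_accept_pages(range, Some((attributes, mask)))?;
// ```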

impl<'a> ProcessorRunner<'a, Tdx<'a>> {
    fn tdx_vp_context(&self) -> &tdx_vp_context {
        // SAFETY: the run page mapping is valid for the lifetime of this
        // runner, and the context is not concurrently modified while the VP
        // is not running.
        unsafe { &*(&raw const (*self.run.get()).context).cast() }
    }

    fn tdx_vp_context_mut(&mut self) -> &mut tdx_vp_context {
        // SAFETY: as above, with exclusive access guaranteed by `&mut self`.
        unsafe { &mut *(&raw mut (*self.run.get()).context).cast() }
    }

    fn tdx_enter_guest_state(&self) -> &TdxL2EnterGuestState {
        &self.tdx_vp_context().gpr_list
    }

    fn tdx_enter_guest_state_mut(&mut self) -> &mut TdxL2EnterGuestState {
        &mut self.tdx_vp_context_mut().gpr_list
    }

    /// Returns the guest general-purpose registers passed to TDG.VP.ENTER.
    pub fn tdx_enter_guest_gps(&self) -> &[u64; 16] {
        &self.tdx_enter_guest_state().gps
    }

    /// Returns a mutable reference to the guest general-purpose registers
    /// passed to TDG.VP.ENTER.
    pub fn tdx_enter_guest_gps_mut(&mut self) -> &mut [u64; 16] {
        &mut self.tdx_enter_guest_state_mut().gps
    }

    /// Returns the exit info reported by the last TDG.VP.ENTER.
    pub fn tdx_vp_enter_exit_info(&self) -> &tdx_tdg_vp_enter_exit_info {
        &self.tdx_vp_context().exit_info
    }

    /// Returns the virtual APIC page for the given VTL.
    pub fn tdx_apic_page(&self, vtl: GuestVtl) -> &ApicPage {
        // SAFETY: the APIC page mappings are held for the lifetime of this
        // runner and are not concurrently accessed while the VP is not
        // running.
        unsafe { &*self.state.apic_pages[vtl].get() }
    }

    /// Returns a mutable reference to the virtual APIC page for the given
    /// VTL.
    pub fn tdx_apic_page_mut(&mut self, vtl: GuestVtl) -> &mut ApicPage {
        // SAFETY: as above, with exclusive access guaranteed by `&mut self`.
        unsafe { &mut *self.state.apic_pages[vtl].get() }
    }

    fn tdx_vp_state(&self) -> &tdx_vp_state {
        &self.tdx_vp_context().vp_state
    }

    fn tdx_vp_state_mut(&mut self) -> &mut tdx_vp_state {
        &mut self.tdx_vp_context_mut().vp_state
    }

    /// Returns the guest CR2 value saved in the VP state.
    pub fn cr2(&self) -> u64 {
        self.tdx_vp_state().cr2
    }

    /// Sets the guest CR2 value to restore on the next entry.
    pub fn set_cr2(&mut self, value: u64) {
        self.tdx_vp_state_mut().cr2 = value;
    }

    /// Returns a mutable reference to the VP state flags.
    pub fn tdx_vp_state_flags_mut(&mut self) -> &mut tdx_vp_state_flags {
        &mut self.tdx_vp_state_mut().flags
    }

    fn tdx_vp_entry_flags(&self) -> &TdxVmFlags {
        &self.tdx_vp_context().entry_rcx
    }

    fn tdx_vp_entry_flags_mut(&mut self) -> &mut TdxVmFlags {
        &mut self.tdx_vp_context_mut().entry_rcx
    }

    /// Copies the private register state from the run page into `regs`.
    pub fn read_private_regs(&self, regs: &mut TdxPrivateRegs) {
        let TdxL2EnterGuestState {
            gps,
            rflags,
            rip,
            ssp,
            rvi,
            svi,
            reserved: _reserved,
        } = self.tdx_enter_guest_state();
        regs.rflags = *rflags;
        regs.rip = *rip;
        regs.rsp = gps[TdxGp::RSP];
        regs.ssp = *ssp;
        regs.rvi = *rvi;
        regs.svi = *svi;

        let tdx_vp_state {
            msr_kernel_gs_base,
            msr_star,
            msr_lstar,
            msr_sfmask,
            msr_xss,
            cr2: _cr2,
            msr_tsc_aux,
            flags: _flags,
        } = self.tdx_vp_state();
        regs.msr_kernel_gs_base = *msr_kernel_gs_base;
        regs.msr_star = *msr_star;
        regs.msr_lstar = *msr_lstar;
        regs.msr_sfmask = *msr_sfmask;
        regs.msr_xss = *msr_xss;
        regs.msr_tsc_aux = *msr_tsc_aux;

        regs.vp_entry_flags = *self.tdx_vp_entry_flags();
    }

    /// Writes the private register state from `regs` into the run page so
    /// that it is restored on the next entry.
    pub fn write_private_regs(&mut self, regs: &TdxPrivateRegs) {
        let TdxPrivateRegs {
            rflags,
            rip,
            rsp,
            ssp,
            rvi,
            svi,
            msr_kernel_gs_base,
            msr_star,
            msr_lstar,
            msr_sfmask,
            msr_xss,
            msr_tsc_aux,
            vp_entry_flags,
        } = regs;

        let enter_guest_state = self.tdx_enter_guest_state_mut();
        enter_guest_state.rflags = *rflags;
        enter_guest_state.rip = *rip;
        enter_guest_state.ssp = *ssp;
        enter_guest_state.rvi = *rvi;
        enter_guest_state.svi = *svi;
        enter_guest_state.gps[TdxGp::RSP] = *rsp;

        let vp_state = self.tdx_vp_state_mut();
        vp_state.msr_kernel_gs_base = *msr_kernel_gs_base;
        vp_state.msr_star = *msr_star;
        vp_state.msr_lstar = *msr_lstar;
        vp_state.msr_sfmask = *msr_sfmask;
        vp_state.msr_xss = *msr_xss;
        vp_state.msr_tsc_aux = *msr_tsc_aux;

        *self.tdx_vp_entry_flags_mut() = *vp_entry_flags;
    }
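
    // Illustrative save/restore flow (a sketch; `runner` and `vtl` are
    // hypothetical locals):
    //
    // ```ignore
    // let mut regs = TdxPrivateRegs::new(vtl);
    // runner.read_private_regs(&mut regs);  // snapshot state after an exit
    // regs.rip = regs.rip.wrapping_add(4);  // e.g. advance past an instruction
    // runner.write_private_regs(&regs);     // applied on the next entry
    // ```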

    fn write_vmcs(&mut self, vtl: GuestVtl, field: VmcsField, mask: u64, value: u64) -> u64 {
        tdcall_vp_wr(
            &mut MshvVtlTdcall(&self.hcl.mshv_vtl),
            vmcs_field_code(field, vtl),
            value,
            mask,
        )
        .expect("fatal vmcs access failure")
    }

    fn read_vmcs(&self, vtl: GuestVtl, field: VmcsField) -> u64 {
        tdcall_vp_rd(
            &mut MshvVtlTdcall(&self.hcl.mshv_vtl),
            vmcs_field_code(field, vtl),
        )
        .expect("fatal vmcs access failure")
    }

    /// Writes a 64-bit VMCS field for the given VTL, updating only the bits
    /// set in `mask`, and returns the field's previous value.
    ///
    /// Panics if the field is not 64-bit or natural width.
    pub fn write_vmcs64(&mut self, vtl: GuestVtl, field: VmcsField, mask: u64, value: u64) -> u64 {
        assert!(matches!(
            field.field_width(),
            x86defs::vmx::FieldWidth::WidthNatural | x86defs::vmx::FieldWidth::Width64
        ));
        self.write_vmcs(vtl, field, mask, value)
    }

    /// Reads a 64-bit VMCS field for the given VTL.
    ///
    /// Panics if the field is not 64-bit or natural width.
    pub fn read_vmcs64(&self, vtl: GuestVtl, field: VmcsField) -> u64 {
        assert!(matches!(
            field.field_width(),
            x86defs::vmx::FieldWidth::WidthNatural | x86defs::vmx::FieldWidth::Width64
        ));
        self.read_vmcs(vtl, field)
    }

    /// Writes a 32-bit VMCS field for the given VTL, updating only the bits
    /// set in `mask`, and returns the field's previous value.
    ///
    /// Panics if the field is not 32-bit.
    pub fn write_vmcs32(&mut self, vtl: GuestVtl, field: VmcsField, mask: u32, value: u32) -> u32 {
        assert_eq!(field.field_width(), x86defs::vmx::FieldWidth::Width32);
        self.write_vmcs(vtl, field, mask.into(), value.into()) as u32
    }

    /// Reads a 32-bit VMCS field for the given VTL.
    ///
    /// Panics if the field is not 32-bit.
    pub fn read_vmcs32(&self, vtl: GuestVtl, field: VmcsField) -> u32 {
        assert_eq!(field.field_width(), x86defs::vmx::FieldWidth::Width32);
        self.read_vmcs(vtl, field) as u32
    }

    /// Writes a 16-bit VMCS field for the given VTL, updating only the bits
    /// set in `mask`, and returns the field's previous value.
    ///
    /// Panics if the field is not 16-bit.
    pub fn write_vmcs16(&mut self, vtl: GuestVtl, field: VmcsField, mask: u16, value: u16) -> u16 {
        assert_eq!(field.field_width(), x86defs::vmx::FieldWidth::Width16);
        self.write_vmcs(vtl, field, mask.into(), value.into()) as u16
    }

    /// Reads a 16-bit VMCS field for the given VTL.
    ///
    /// Panics if the field is not 16-bit.
    pub fn read_vmcs16(&self, vtl: GuestVtl, field: VmcsField) -> u16 {
        assert_eq!(field.field_width(), x86defs::vmx::FieldWidth::Width16);
        self.read_vmcs(vtl, field) as u16
    }
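
    // Illustrative width-checked access (a sketch; `runner` and `vtl` are
    // hypothetical locals):
    //
    // ```ignore
    // // Read the full field, then write it back unchanged (mask = !0).
    // let addr = runner.read_vmcs64(vtl, VmcsField::VMX_VMCS_VIRTUAL_APIC_PAGE);
    // runner.write_vmcs64(vtl, VmcsField::VMX_VMCS_VIRTUAL_APIC_PAGE, !0, addr);
    // ```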

    /// Sets or clears the intercept for reads (`write == false`) or writes
    /// (`write == true`) of the given MSR in the MSR bitmap for the given
    /// VTL.
    pub fn set_msr_bit(&self, vtl: GuestVtl, msr_index: u32, write: bool, intercept: bool) {
        // Each 64-bit word of the bitmap covers 64 MSRs: the low range
        // (0..=0x1fff) occupies words 0..0x80, and the high range
        // (0xc0000000..=0xc0001fff) occupies words 0x80..0x100.
        let mut word_index = (msr_index & 0xFFFF) / 64;

        if msr_index & 0x80000000 == 0x80000000 {
            assert!((0xC0000000..=0xC0001FFF).contains(&msr_index));
            word_index += 0x80;
        } else {
            assert!(msr_index <= 0x00001FFF);
        }

        // The write bitmaps follow the read bitmaps at word offset 0x100.
        if write {
            word_index += 0x100;
        }

        self.write_msr_bitmap(
            vtl,
            word_index,
            1 << (msr_index as u64 & 0x3F),
            if intercept { !0 } else { 0 },
        );
    }
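
    // Worked example of the index math above: for MSR 0xC000_0080 (EFER),
    // `msr_index & 0xFFFF` is 0x80, so the base word is 0x80 / 64 = 2; the
    // high-range offset adds 0x80, giving word 0x82 for reads or 0x182 for
    // writes, and the mask selects bit 0x80 % 64 = 0 within that word:
    //
    // ```ignore
    // runner.set_msr_bit(vtl, 0xC000_0080, true, true); // intercept EFER writes
    // ```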

    /// Writes a 64-bit word of the MSR bitmap for the given VTL, updating
    /// only the bits set in `mask`. `i` is the index of the word within the
    /// bitmap. Returns the word's previous value.
    pub fn write_msr_bitmap(&self, vtl: GuestVtl, i: u32, mask: u64, word: u64) -> u64 {
        let class_code = match vtl {
            GuestVtl::Vtl0 => TdVpsClassCode::MSR_BITMAPS_1,
            GuestVtl::Vtl1 => TdVpsClassCode::MSR_BITMAPS_2,
        };
        let field_code = TdxExtendedFieldCode::new()
            .with_context_code(TdxContextCode::TD_VCPU)
            .with_field_size(x86defs::tdx::FieldSize::Size64Bit)
            .with_field_code(i)
            .with_class_code(class_code.0);

        tdcall_vp_wr(
            &mut MshvVtlTdcall(&self.hcl.mshv_vtl),
            field_code,
            word,
            mask,
        )
        .unwrap()
    }

    /// Sets the L2_CTLS field of the VP for the given VTL, returning the
    /// field's previous value.
    pub fn set_l2_ctls(&self, vtl: GuestVtl, value: TdxL2Ctls) -> Result<TdxL2Ctls, TdCallResult> {
        let field_code = match vtl {
            GuestVtl::Vtl0 => x86defs::tdx::TDX_FIELD_CODE_L2_CTLS_VM1,
            GuestVtl::Vtl1 => x86defs::tdx::TDX_FIELD_CODE_L2_CTLS_VM2,
        };
        tdcall_vp_wr(
            &mut MshvVtlTdcall(&self.hcl.mshv_vtl),
            field_code,
            value.into(),
            !0,
        )
        .map(Into::into)
    }

    /// Issues a TDG.VP.INVGLA call to invalidate guest linear address
    /// translations.
    pub fn invgla(
        &self,
        gla_flags: TdGlaVmAndFlags,
        gla_info: TdxGlaListInfo,
    ) -> Result<(), TdCallResult> {
        tdcall_vp_invgla(&mut MshvVtlTdcall(&self.hcl.mshv_vtl), gla_flags, gla_info)
    }

    /// Returns the FPU/SSE state saved in the run page.
    pub fn fx_state(&self) -> &x86defs::xsave::Fxsave {
        &self.tdx_vp_context().fx_state
    }

    /// Returns a mutable reference to the FPU/SSE state saved in the run
    /// page.
    pub fn fx_state_mut(&mut self) -> &mut x86defs::xsave::Fxsave {
        &mut self.tdx_vp_context_mut().fx_state
    }
}

/// Builds the TDX extended field code used to access the given VMCS field of
/// the L2 VM corresponding to `vtl`.
fn vmcs_field_code(field: VmcsField, vtl: GuestVtl) -> TdxExtendedFieldCode {
    let class_code = match vtl {
        GuestVtl::Vtl0 => TdVpsClassCode::VMCS_1,
        GuestVtl::Vtl1 => TdVpsClassCode::VMCS_2,
    };
    let field_size = match field.field_width() {
        x86defs::vmx::FieldWidth::Width16 => x86defs::tdx::FieldSize::Size16Bit,
        x86defs::vmx::FieldWidth::Width32 => x86defs::tdx::FieldSize::Size32Bit,
        x86defs::vmx::FieldWidth::Width64 => x86defs::tdx::FieldSize::Size64Bit,
        // Natural-width fields are accessed as 64-bit values.
        x86defs::vmx::FieldWidth::WidthNatural => x86defs::tdx::FieldSize::Size64Bit,
    };
    TdxExtendedFieldCode::new()
        .with_context_code(TdxContextCode::TD_VCPU)
        .with_class_code(class_code.0)
        .with_field_code(field.into())
        .with_field_size(field_size)
}
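
// For example, a 16-bit VMCS field accessed for VTL0 is encoded with context
// code TD_VCPU, class code VMCS_1, and field size Size16Bit, with the raw
// VMCS encoding as the field code; the same field for VTL1 differs only in
// the class code (VMCS_2).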

impl<'a> super::private::BackingPrivate<'a> for Tdx<'a> {
    fn new(vp: &'a HclVp, sidecar: Option<&SidecarVp<'_>>, hcl: &Hcl) -> Result<Self, NoRunner> {
        assert!(sidecar.is_none());
        let super::BackingState::Tdx {
            vtl0_apic_page,
            vtl1_apic_page,
        } = &vp.backing
        else {
            return Err(NoRunner::MismatchedIsolation);
        };

        // Register the VTL1 APIC page address with the TDX module.
        let vtl1_apic_page_addr = vtl1_apic_page.pfns()[0] * user_driver::memory::PAGE_SIZE64;
        tdcall_vp_wr(
            &mut MshvVtlTdcall(&hcl.mshv_vtl),
            vmcs_field_code(VmcsField::VMX_VMCS_VIRTUAL_APIC_PAGE, GuestVtl::Vtl1),
            vtl1_apic_page_addr,
            !0,
        )
        .expect("failed registering VTL1 APIC page");

        // SAFETY: the APIC page mapping is valid for the lifetime of `vp`,
        // which outlives this backing.
        let vtl1_apic_page = unsafe { &*vtl1_apic_page.base().cast() };

        Ok(Self {
            apic_pages: [vtl0_apic_page.as_ref(), vtl1_apic_page].into(),
        })
    }

    fn try_set_reg(
        _runner: &mut ProcessorRunner<'a, Self>,
        _vtl: GuestVtl,
        _name: HvRegisterName,
        _value: HvRegisterValue,
    ) -> Result<bool, super::Error> {
        Ok(false)
    }

    fn must_flush_regs_on(_runner: &ProcessorRunner<'a, Self>, _name: HvRegisterName) -> bool {
        false
    }

    fn try_get_reg(
        _runner: &ProcessorRunner<'a, Self>,
        _vtl: GuestVtl,
        _name: HvRegisterName,
    ) -> Result<Option<HvRegisterValue>, super::Error> {
        Ok(None)
    }

    fn flush_register_page(_runner: &mut ProcessorRunner<'a, Self>) {}
}

/// Private register state of a TDX VP, copied to and from the run page via
/// [`ProcessorRunner::read_private_regs`] and
/// [`ProcessorRunner::write_private_regs`].
#[derive(inspect::InspectMut)]
#[expect(missing_docs, reason = "Self-describing field names")]
pub struct TdxPrivateRegs {
    pub rflags: u64,
    pub rip: u64,
    pub rsp: u64,
    pub ssp: u64,
    pub rvi: u8,
    pub svi: u8,
    pub msr_kernel_gs_base: u64,
    pub msr_star: u64,
    pub msr_lstar: u64,
    pub msr_sfmask: u64,
    pub msr_xss: u64,
    pub msr_tsc_aux: u64,
    #[inspect(hex, with = "|x| x.into_bits()")]
    pub vp_entry_flags: TdxVmFlags,
}

impl TdxPrivateRegs {
    /// Creates a new register set with default values, targeting the given
    /// VTL on entry.
    pub fn new(vtl: GuestVtl) -> Self {
        Self {
            rflags: x86defs::RFlags::at_reset().into(),
            rip: 0,
            rsp: 0,
            ssp: 0,
            rvi: 0,
            svi: 0,
            msr_kernel_gs_base: 0,
            msr_star: 0,
            msr_lstar: 0,
            msr_sfmask: 0,
            msr_xss: 0,
            msr_tsc_aux: 0,
            // TDG.VP.ENTER numbers L2 VMs from 1, so VTL0 runs as VM 1 and
            // VTL1 as VM 2. Start with `TDX_VP_ENTER_INVD_INVEPT` so cached
            // translations are invalidated on the first entry.
            vp_entry_flags: TdxVmFlags::new()
                .with_vm_index(vtl as u8 + 1)
                .with_invd_translations(x86defs::tdx::TDX_VP_ENTER_INVD_INVEPT),
        }
    }
}
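
// Illustrative initialization (a sketch; a `vm_index` getter is assumed to
// mirror `with_vm_index`):
//
// ```ignore
// let regs = TdxPrivateRegs::new(GuestVtl::Vtl0);
// assert_eq!(regs.vp_entry_flags.vm_index(), 1); // VTL0 runs as VM index 1
// ```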

struct MshvVtlTdcall<'a>(&'a MshvVtl);

impl Tdcall for MshvVtlTdcall<'_> {
    fn tdcall(&mut self, input: tdcall::TdcallInput) -> tdcall::TdcallOutput {
        let mut mshv_tdcall_args = {
            let tdcall::TdcallInput {
                leaf,
                rcx,
                rdx,
                r8,
                r9,
                r10,
                r11,
                r12,
                r13,
                r14,
                r15,
            } = input;

            // The mshv tdcall interface does not support passing r10-r15 as
            // inputs, and VP_VMCALL (which uses them) is never issued through
            // this path.
            assert_ne!(leaf, x86defs::tdx::TdCallLeaf::VP_VMCALL);
            assert_eq!(r10, 0);
            assert_eq!(r11, 0);
            assert_eq!(r12, 0);
            assert_eq!(r13, 0);
            assert_eq!(r14, 0);
            assert_eq!(r15, 0);

            mshv_tdcall {
                rax: leaf.0,
                rcx,
                rdx,
                r8,
                r9,
                r10_out: 0,
                r11_out: 0,
            }
        };

        // SAFETY: the ioctl only writes the output fields of the passed
        // tdcall arguments and has no other memory-safety requirements.
        unsafe {
            hcl_tdcall(self.0.file.as_raw_fd(), &mut mshv_tdcall_args)
                .expect("todo handle tdcall ioctl error");
        }

        tdcall::TdcallOutput {
            rax: TdCallResult::from(mshv_tdcall_args.rax),
            rcx: mshv_tdcall_args.rcx,
            rdx: mshv_tdcall_args.rdx,
            r8: mshv_tdcall_args.r8,
            r10: mshv_tdcall_args.r10_out,
            r11: mshv_tdcall_args.r11_out,
        }
    }
}