1pub mod serialize_proto;
34
35pub mod devicereport;
37
38#[cfg(test)]
39mod tests;
40
41pub mod test_helpers;
43
44use anyhow::Context;
45use parking_lot::Mutex;
46use std::sync::Arc;
47pub use tdisp_proto::GuestToHostCommand;
48pub use tdisp_proto::GuestToHostCommandExt;
49pub use tdisp_proto::GuestToHostResponse;
50pub use tdisp_proto::GuestToHostResponseExt;
51pub use tdisp_proto::TdispCommandResponseBind;
52pub use tdisp_proto::TdispCommandResponseGetDeviceInterfaceInfo;
53pub use tdisp_proto::TdispCommandResponseGetTdiReport;
54pub use tdisp_proto::TdispCommandResponseStartTdi;
55pub use tdisp_proto::TdispCommandResponseUnbind;
56pub use tdisp_proto::TdispDeviceInterfaceInfo;
57pub use tdisp_proto::TdispGuestOperationError;
58pub use tdisp_proto::TdispGuestOperationErrorCode;
59pub use tdisp_proto::TdispGuestProtocolType;
60pub use tdisp_proto::TdispGuestUnbindReason;
61pub use tdisp_proto::TdispReportType;
62pub use tdisp_proto::TdispTdiState;
63pub use tdisp_proto::guest_to_host_command::Command;
64pub use tdisp_proto::guest_to_host_response::Response;
65
66use tracing::instrument;
67
68pub type TdispCommandCallback = dyn Fn(&GuestToHostCommand) -> anyhow::Result<()> + Send + Sync;
70
71pub trait TdispHostDeviceInterface: Send + Sync {
75 fn tdisp_negotiate_protocol(
77 &mut self,
78 _requested_guest_protocol: TdispGuestProtocolType,
79 ) -> anyhow::Result<TdispDeviceInterfaceInfo>;
80
81 fn tdisp_bind_device(&mut self) -> anyhow::Result<()>;
84
85 fn tdisp_start_device(&mut self) -> anyhow::Result<()>;
88
89 fn tdisp_unbind_device(&mut self) -> anyhow::Result<()>;
91
92 fn tdisp_get_device_report(&mut self, _report_type: TdispReportType)
94 -> anyhow::Result<Vec<u8>>;
95}
96
97pub trait TdispHostDeviceTarget: Send + Sync {
99 fn tdisp_handle_guest_command(
101 &mut self,
102 _command: GuestToHostCommand,
103 ) -> anyhow::Result<GuestToHostResponse>;
104}
105
106pub struct TdispHostDeviceTargetEmulator {
108 machine: TdispHostStateMachine,
109 debug_device_id: String,
110}
111
112impl TdispHostDeviceTargetEmulator {
113 pub fn new(
115 host_interface: Arc<Mutex<dyn TdispHostDeviceInterface>>,
116 debug_device_id: &str,
117 ) -> Self {
118 Self {
119 machine: TdispHostStateMachine::new(host_interface),
120 debug_device_id: debug_device_id.to_owned(),
121 }
122 }
123
124 pub fn set_debug_device_id(&mut self, debug_device_id: &str) {
126 self.machine.set_debug_device_id(debug_device_id.to_owned());
127 self.debug_device_id = debug_device_id.to_owned();
128 }
129
130 pub fn reset(&self) {}
132}
133
134impl TdispHostDeviceTarget for TdispHostDeviceTargetEmulator {
135 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
139 fn tdisp_handle_guest_command(
140 &mut self,
141 command: GuestToHostCommand,
142 ) -> anyhow::Result<GuestToHostResponse> {
143 let mut error = TdispGuestOperationError::Success;
144 let mut response: Option<Response> = None;
145 let state_before = self.machine.state();
146 match &command.command {
147 Some(Command::GetDeviceInterfaceInfo(req)) => {
148 let protocol_type = TdispGuestProtocolType::from_i32(req.guest_protocol_type);
149
150 match protocol_type {
151 Some(protocol_type) => {
152 let interface_info = self.machine.tdisp_negotiate_protocol(protocol_type);
153 match interface_info {
154 Ok(interface_info) => {
155 response = Some(Response::GetDeviceInterfaceInfo(
156 TdispCommandResponseGetDeviceInterfaceInfo {
157 interface_info: Some(interface_info),
158 },
159 ));
160 }
161 Err(err) => {
162 error = err;
163 }
164 }
165 }
166 None => {
167 error = TdispGuestOperationError::InvalidGuestProtocolRequest;
168 }
169 }
170 }
171 Some(Command::Bind(_)) => {
172 let bind_res = self.machine.request_lock_device_resources();
173 if let Err(err) = bind_res {
174 error = err;
175 } else {
176 response = Some(Response::Bind(TdispCommandResponseBind {}));
177 }
178 }
179 Some(Command::StartTdi(_)) => {
180 let start_tdi_res = self.machine.request_start_tdi();
181 if let Err(err) = start_tdi_res {
182 error = err;
183 } else {
184 response = Some(Response::StartTdi(TdispCommandResponseStartTdi {}));
185 }
186 }
187 Some(Command::Unbind(cmd)) => {
188 let unbind_reason = TdispGuestUnbindReason::from_i32(cmd.unbind_reason);
189
190 match unbind_reason {
191 Some(reason) => {
192 let unbind_res = self.machine.request_unbind(reason);
193 if let Err(err) = unbind_res {
194 error = err;
195 }
196 response = Some(Response::Unbind(TdispCommandResponseUnbind {}));
197 }
198 None => {
199 error = TdispGuestOperationError::InvalidGuestUnbindReason;
200 }
201 }
202 }
203 Some(Command::GetTdiReport(cmd)) => {
204 let report_type = TdispReportType::from_i32(cmd.report_type);
205 match report_type {
206 Some(report_type) => {
207 let report_buffer = self.machine.request_attestation_report(report_type);
208
209 match report_buffer {
210 Ok(report_buffer) => {
211 response = Some(Response::GetTdiReport(
212 TdispCommandResponseGetTdiReport {
213 report_type: cmd.report_type,
214 report_buffer,
215 },
216 ));
217 }
218 Err(err) => {
219 error = err;
220 }
221 }
222 }
223 None => {
224 error = TdispGuestOperationError::InvalidGuestAttestationReportType;
225 }
226 }
227 }
228 _ => {
229 error = TdispGuestOperationError::InvalidGuestCommandId;
230 }
231 }
232 let state_after = self.machine.state();
233 let error_code: TdispGuestOperationErrorCode = error.into();
234 let resp = GuestToHostResponse {
235 result: error_code.into(),
236 tdi_state_before: state_before.into(),
237 tdi_state_after: state_after.into(),
238 response,
239 };
240
241 match error {
242 TdispGuestOperationError::Success => {
243 tracing::info!(?resp, "tdisp_handle_guest_command success");
244 }
245 _ => {
246 tracing::error!(?resp, "tdisp_handle_guest_command error");
247 }
248 }
249
250 Ok(resp)
251 }
252}
253
254pub trait TdispClientDevice: Send + Sync {
257 fn tdisp_command_to_host(&self, command: GuestToHostCommand) -> anyhow::Result<()>;
260}
261
/// Maximum number of entries kept in each of the state machine's history
/// buffers (previous TDI states and past unbind reasons).
const TDISP_STATE_HISTORY_LEN: usize = 10;
264
265#[derive(Debug)]
269pub enum TdispUnbindReason {
270 Unknown(anyhow::Error),
272
273 GuestInitiated(TdispGuestUnbindReason),
275
276 ImpossibleStateTransition(anyhow::Error),
278
279 InvalidGuestTransitionToLocked,
282
283 InvalidGuestTransitionToRun,
286
287 InvalidGuestGetAttestationReportState,
290
291 InvalidGuestAcceptAttestationReportState,
294
295 InvalidGuestUnbindReason(anyhow::Error),
299}
300
301pub struct TdispHostStateMachine {
304 current_state: TdispTdiState,
306 state_history: Vec<TdispTdiState>,
308 debug_device_id: String,
310 unbind_reason_history: Vec<TdispUnbindReason>,
312 host_interface: Arc<Mutex<dyn TdispHostDeviceInterface>>,
314 guest_protocol_type: TdispGuestProtocolType,
316}
317
318impl TdispHostStateMachine {
319 pub fn new(host_interface: Arc<Mutex<dyn TdispHostDeviceInterface>>) -> Self {
321 Self {
322 current_state: TdispTdiState::Unlocked,
323 state_history: Vec::new(),
324 debug_device_id: "".to_owned(),
325 unbind_reason_history: Vec::new(),
326 host_interface,
327 guest_protocol_type: TdispGuestProtocolType::Invalid,
328 }
329 }
330
331 pub fn set_debug_device_id(&mut self, debug_device_id: String) {
333 self.debug_device_id = debug_device_id;
334 }
335
336 fn state(&self) -> TdispTdiState {
338 self.current_state
339 }
340
341 fn ensure_negotiated_protocol(&self) -> anyhow::Result<()> {
342 if self.guest_protocol_type == TdispGuestProtocolType::Invalid {
343 tracing::error!(
344 "Guest tried to perform a state transition without negotiating a protocol with the host!"
345 );
346 return Err(anyhow::anyhow!(
347 "Guest tried to perform a state transition without negotiating a protocol with the host!"
348 ));
349 }
350 Ok(())
351 }
352
353 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
357 fn is_valid_state_transition(&self, new_state: &TdispTdiState) -> bool {
358 match self.ensure_negotiated_protocol() {
360 Ok(_) => {}
361 Err(e) => {
362 tracing::error!("Failed to transition state: {e:?}");
363 return false;
364 }
365 }
366
367 match (self.current_state, *new_state) {
368 (TdispTdiState::Unlocked, TdispTdiState::Locked) => true,
370 (TdispTdiState::Locked, TdispTdiState::Run) => true,
371
372 (TdispTdiState::Run, TdispTdiState::Unlocked) => true,
374 (TdispTdiState::Locked, TdispTdiState::Unlocked) => true,
375 (TdispTdiState::Unlocked, TdispTdiState::Unlocked) => true,
376
377 _ => false,
379 }
380 }
381
382 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
385 fn transition_state_to(&mut self, new_state: TdispTdiState) -> anyhow::Result<()> {
386 tracing::info!(
387 "Request to transition from {:?} -> {:?}",
388 self.current_state,
389 new_state
390 );
391
392 if !self.is_valid_state_transition(&new_state) {
394 tracing::info!(
395 "Invalid state transition {:?} -> {:?}",
396 self.current_state,
397 new_state
398 );
399 return Err(anyhow::anyhow!(
400 "Invalid state transition {:?} -> {:?}",
401 self.current_state,
402 new_state
403 ));
404 }
405
406 if self.state_history.len() == TDISP_STATE_HISTORY_LEN {
408 self.state_history.remove(0);
409 }
410 self.state_history.push(self.current_state);
411
412 self.current_state = new_state;
414 tracing::info!("Transitioned to {:?}", self.current_state);
415
416 Ok(())
417 }
418
419 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
421 fn unbind_all(&mut self, reason: TdispUnbindReason) -> anyhow::Result<()> {
422 tracing::info!("Unbind called with reason {:?}", reason);
423
424 if let Err(reason) = self.transition_state_to(TdispTdiState::Unlocked) {
427 return Err(anyhow::anyhow!(
428 "Impossible state machine violation during TDISP Unbind: {:?}",
429 reason
430 ));
431 }
432
433 let res = self
435 .host_interface
436 .lock()
437 .tdisp_unbind_device()
438 .context("host failed to unbind TDI");
439
440 if let Err(e) = res {
441 tracing::error!("Failed to unbind TDI: {:?}", e);
442 return Err(e);
443 }
444
445 if self.unbind_reason_history.len() == TDISP_STATE_HISTORY_LEN {
447 self.unbind_reason_history.remove(0);
448 }
449 self.unbind_reason_history.push(reason);
450
451 Ok(())
452 }
453}
454
455pub trait TdispGuestRequestInterface {
459 fn tdisp_negotiate_protocol(
469 &mut self,
470 requested_guest_protocol: TdispGuestProtocolType,
471 ) -> Result<TdispDeviceInterfaceInfo, TdispGuestOperationError>;
472
473 fn request_lock_device_resources(&mut self) -> Result<(), TdispGuestOperationError>;
482
483 fn request_start_tdi(&mut self) -> Result<(), TdispGuestOperationError>;
491
492 fn request_attestation_report(
500 &mut self,
501 report_type: TdispReportType,
502 ) -> Result<Vec<u8>, TdispGuestOperationError>;
503
504 fn request_unbind(
513 &mut self,
514 reason: TdispGuestUnbindReason,
515 ) -> Result<(), TdispGuestOperationError>;
516}
517
518impl TdispGuestRequestInterface for TdispHostStateMachine {
519 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
521 fn tdisp_negotiate_protocol(
522 &mut self,
523 requested_guest_protocol: TdispGuestProtocolType,
524 ) -> Result<TdispDeviceInterfaceInfo, TdispGuestOperationError> {
525 if self.guest_protocol_type != TdispGuestProtocolType::Invalid {
526 tracing::error!(
527 "Guest tried to negotiate a protocol with the host while a protocol was already negotiated!"
528 );
529 return Err(TdispGuestOperationError::InvalidGuestProtocolRequest);
530 }
531
532 if requested_guest_protocol == TdispGuestProtocolType::Invalid {
533 tracing::error!("Guest tried to negotiate Invalid as a protocol");
534 return Err(TdispGuestOperationError::InvalidGuestProtocolRequest);
535 }
536
537 let res = self
539 .host_interface
540 .lock()
541 .tdisp_negotiate_protocol(requested_guest_protocol)
542 .context("failed to call to negotiate protocol");
543
544 match res {
545 Ok(interface_info) => {
546 match TdispGuestProtocolType::from_i32(interface_info.guest_protocol_type) {
547 Some(guest_protocol_type) => {
548 if guest_protocol_type == TdispGuestProtocolType::Invalid {
549 tracing::error!(
550 ?guest_protocol_type,
551 "Guest protocol negotiated with invalid value"
552 );
553 Err(TdispGuestOperationError::InvalidGuestProtocolRequest)
554 } else {
555 self.guest_protocol_type = guest_protocol_type;
556 tracing::info!(
557 ?interface_info,
558 "Guest protocol negotiated successfully to"
559 );
560 Ok(interface_info)
561 }
562 }
563 None => {
564 tracing::error!(
565 ?interface_info,
566 "Guest protocol negotiated with none value"
567 );
568 Err(TdispGuestOperationError::InvalidGuestProtocolRequest)
569 }
570 }
571 }
572 Err(e) => {
573 tracing::error!(?e, "Failed to negotiate protocol with host interface");
574 Err(TdispGuestOperationError::HostFailedToProcessCommand)
575 }
576 }
577 }
578
579 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
580 fn request_lock_device_resources(&mut self) -> Result<(), TdispGuestOperationError> {
581 self.ensure_negotiated_protocol()
583 .map_err(|_| TdispGuestOperationError::InvalidDeviceState)?;
584
585 if self.current_state != TdispTdiState::Unlocked {
588 tracing::error!(
589 "Unlocked to Locked state called while device was not in Unlocked state."
590 );
591
592 self.unbind_all(TdispUnbindReason::InvalidGuestTransitionToLocked)
593 .map_err(|_| TdispGuestOperationError::HostFailedToProcessCommand)?;
594 return Err(TdispGuestOperationError::InvalidDeviceState);
595 }
596
597 tracing::info!("Device bind requested, trying to transition from Unlocked to Locked state");
598
599 let res = self
601 .host_interface
602 .lock()
603 .tdisp_bind_device()
604 .context("failed to call to bind TDI");
605
606 if let Err(e) = res {
607 tracing::error!("Failed to bind TDI: {e:?}");
608 return Err(TdispGuestOperationError::HostFailedToProcessCommand);
609 }
610
611 tracing::info!("Device transition from Unlocked to Locked state");
612 match self.transition_state_to(TdispTdiState::Locked) {
613 Ok(_) => {}
614 Err(e) => {
615 tracing::error!("Failed to transition to Locked state: {e:?}");
616 return Err(TdispGuestOperationError::HostFailedToProcessCommand);
617 }
618 }
619 Ok(())
620 }
621
622 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
623 fn request_start_tdi(&mut self) -> Result<(), TdispGuestOperationError> {
624 self.ensure_negotiated_protocol()
626 .map_err(|_| TdispGuestOperationError::InvalidDeviceState)?;
627
628 if self.current_state != TdispTdiState::Locked {
629 tracing::error!("StartTDI called while device was not in Locked state.");
630 self.unbind_all(TdispUnbindReason::InvalidGuestTransitionToRun)
631 .map_err(|_| TdispGuestOperationError::HostFailedToProcessCommand)?;
632
633 return Err(TdispGuestOperationError::InvalidDeviceState);
634 }
635
636 tracing::info!("Device start requested, trying to transition from Locked to Run state");
637
638 let res = self
640 .host_interface
641 .lock()
642 .tdisp_start_device()
643 .context("failed to call to start TDI");
644
645 if let Err(e) = res {
646 tracing::error!("Failed to start TDI: {e:?}");
647 return Err(TdispGuestOperationError::HostFailedToProcessCommand);
648 }
649
650 tracing::info!("Device transition from Locked to Run state");
651 match self.transition_state_to(TdispTdiState::Run) {
652 Ok(_) => {}
653 Err(e) => {
654 tracing::error!("Failed to transition to Run state: {e:?}");
655 return Err(TdispGuestOperationError::HostFailedToProcessCommand);
656 }
657 }
658
659 Ok(())
660 }
661
662 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
663 fn request_attestation_report(
664 &mut self,
665 report_type: TdispReportType,
666 ) -> Result<Vec<u8>, TdispGuestOperationError> {
667 self.ensure_negotiated_protocol()
669 .map_err(|_| TdispGuestOperationError::InvalidDeviceState)?;
670
671 if self.current_state != TdispTdiState::Locked && self.current_state != TdispTdiState::Run {
672 tracing::error!(
673 "Request to retrieve attestation report called while device was not in Locked or Run state."
674 );
675 self.unbind_all(TdispUnbindReason::InvalidGuestGetAttestationReportState)
676 .map_err(|_| TdispGuestOperationError::HostFailedToProcessCommand)?;
677
678 return Err(TdispGuestOperationError::InvalidGuestAttestationReportState);
679 }
680
681 if report_type == TdispReportType::Invalid {
682 tracing::error!("Invalid report type TdispReportId::INVALID requested");
683 return Err(TdispGuestOperationError::InvalidGuestAttestationReportType);
684 }
685
686 let report_buffer = self
687 .host_interface
688 .lock()
689 .tdisp_get_device_report(report_type)
690 .context("failed to call to get device report from host");
691
692 match report_buffer {
693 Ok(report_buffer) => {
694 tracing::info!("Retrieve attestation report called successfully");
695 Ok(report_buffer)
696 }
697 Err(e) => {
698 tracing::error!("Failed to get device report from host: {e:?}");
699 Err(TdispGuestOperationError::HostFailedToProcessCommand)
700 }
701 }
702 }
703
704 #[instrument(fields(device_id = %self.debug_device_id), skip(self))]
705 fn request_unbind(
706 &mut self,
707 reason: TdispGuestUnbindReason,
708 ) -> Result<(), TdispGuestOperationError> {
709 self.ensure_negotiated_protocol()
711 .map_err(|_| TdispGuestOperationError::InvalidDeviceState)?;
712
713 let reason = match reason {
717 TdispGuestUnbindReason::Graceful => TdispUnbindReason::GuestInitiated(reason),
718 _ => {
719 tracing::error!(
720 "Invalid guest unbind reason {} requested",
721 reason.as_str_name()
722 );
723 TdispUnbindReason::InvalidGuestUnbindReason(anyhow::anyhow!(
724 "Invalid guest unbind reason {} requested",
725 reason.as_str_name()
726 ))
727 }
728 };
729
730 tracing::info!(
731 "Guest request to unbind succeeds while device is in {:?} (reason: {:?})",
732 self.current_state,
733 reason
734 );
735
736 self.unbind_all(reason)
737 .map_err(|_| TdispGuestOperationError::HostFailedToProcessCommand)?;
738
739 Ok(())
740 }
741}