1#![cfg(target_os = "linux")]
11#![forbid(unsafe_code)]
12
13pub mod api;
14pub mod error;
15pub mod resolver;
16
17mod client;
18mod process_loop;
19mod worker;
20
21pub use client::GuestEmulationTransportClient;
22
23#[derive(Debug, thiserror::Error)]
25#[error("failed to initialize GET worker")]
26pub struct SpawnGetError(#[source] process_loop::FatalError);
27
28#[derive(Debug, thiserror::Error)]
32#[error("encountered fatal GET error")]
33pub struct FatalGetError(#[source] process_loop::FatalError);
34
35pub async fn spawn_get_worker(
38 driver: impl pal_async::driver::SpawnDriver,
39) -> Result<
40 (
41 GuestEmulationTransportClient,
42 pal_async::task::Task<Result<(), FatalGetError>>,
43 ),
44 SpawnGetError,
45> {
46 let (worker, task) = worker::GuestEmulationTransportWorker::new(driver)
47 .await
48 .map_err(SpawnGetError)?;
49 Ok((worker.new_client(), task))
50}
51
#[cfg(any(feature = "test_utilities", test))]
#[expect(missing_docs)]
pub mod test_utilities {
    use super::*;
    use crate::worker::GuestEmulationTransportWorker;
    use client::GuestEmulationTransportClient;
    use get_protocol::ProtocolVersion;
    use guest_emulation_device::IgvmAgentTestPlan;
    use guest_emulation_device::test_utilities::TestGedClient;
    use guest_emulation_device::test_utilities::TestGetResponses;
    use mesh::Receiver;
    use pal_async::task::Spawn;
    use pal_async::task::Task;

    // 4 MiB plus 512 bytes (== 4194816).
    // NOTE(review): presumably the default VMGS size used by tests — confirm
    // against callers.
    pub const DEFAULT_SIZE: usize = 4 * 1024 * 1024 + 512;

    // Everything a test needs from a connected test-GED / GET pair, as built
    // by `new_transport_pair`.
    #[cfg_attr(not(test), expect(dead_code))]
    pub struct TestGet {
        pub client: GuestEmulationTransportClient,
        pub(crate) gen_id: Receiver<[u8; 16]>,
        pub(crate) guest_task: Task<Result<(), FatalGetError>>,
        pub(crate) test_ged_client: TestGedClient,
    }

    // Connects a test GED (host side) to a real GET worker (guest side) over
    // an in-memory message pipe and returns the resulting handles.
    //
    // Panics if the guest worker cannot be created or if the generation-ID
    // receiver is unavailable.
    pub async fn new_transport_pair(
        spawn: impl Spawn,
        ged_responses: Option<Vec<TestGetResponses>>,
        version: ProtocolVersion,
        guest_memory: Option<guestmem::GuestMemory>,
        igvm_agent_script: Option<IgvmAgentTestPlan>,
    ) -> TestGet {
        // Room for the largest GET message plus one VMBus ring page.
        let pipe_size = get_protocol::MAX_MESSAGE_SIZE + vmbus_ring::PAGE_SIZE;
        let (host_pipe, guest_pipe) = vmbus_async::pipe::connected_message_pipes(pipe_size);

        // The host end goes to the test GED together with the scripted responses.
        let test_ged_client = guest_emulation_device::test_utilities::create_host_channel(
            &spawn,
            host_pipe,
            ged_responses,
            version,
            guest_memory,
            igvm_agent_script,
        );

        // The guest end drives a real GET worker.
        let (get_worker, guest_task) = GuestEmulationTransportWorker::with_pipe(&spawn, guest_pipe)
            .await
            .unwrap();
        let client = get_worker.new_client();
        let gen_id = client.take_generation_id_recv().await.unwrap();

        TestGet {
            client,
            gen_id,
            guest_task,
            test_ged_client,
        }
    }
}
117
#[cfg(test)]
mod tests {
    use super::test_utilities::*;
    use super::worker::GuestEmulationTransportWorker;
    use crate::process_loop::FatalError;
    use get_protocol::ProtocolVersion;
    use get_protocol::VmgsIoStatus;
    use get_protocol::test_utilities::TEST_VMGS_SECTOR_SIZE;
    use guest_emulation_device::test_utilities::Event;
    use guest_emulation_device::test_utilities::TestGetResponses;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use test_with_tracing::test;
    use vmbus_async::async_dgram::AsyncRecvExt;
    use vmbus_async::async_dgram::AsyncSendExt;
    use zerocopy::FromZeros;
    use zerocopy::IntoBytes;

    // A fake host answers the guest's version request with
    // `VersionResponse::new(false)`; worker construction must then fail with
    // `FatalError::VersionNegotiationFailed`.
    #[async_test]
    async fn test_version_negotiation_failed(driver: DefaultDriver) {
        let (mut host_vmbus, guest_vmbus) =
            vmbus_async::pipe::connected_message_pipes(get_protocol::MAX_MESSAGE_SIZE);

        // Fake host: for each listed protocol, expect a VERSION request for it
        // and reply with a negative version response.
        let host_task = driver.spawn("host task", async move {
            for protocol in [ProtocolVersion::NICKEL_REV2] {
                let mut version_request = get_protocol::VersionRequest::new_zeroed();
                let len = version_request.as_bytes().len();
                assert_eq!(
                    len,
                    host_vmbus
                        .recv(version_request.as_mut_bytes())
                        .await
                        .unwrap()
                );

                assert_eq!(
                    version_request.message_header.message_id(),
                    get_protocol::HostRequests::VERSION
                );
                assert_eq!(version_request.version, protocol);

                let version_response = get_protocol::VersionResponse::new(false);

                host_vmbus.send(version_response.as_bytes()).await.unwrap();
            }
        });

        let transport = GuestEmulationTransportWorker::with_pipe(driver, guest_vmbus).await;

        match transport.unwrap_err() {
            FatalError::VersionNegotiationFailed => {}
            e => panic!("Wrong error type returned: {e}"),
        }

        host_task.await;
    }

    // Scripts one response per client request and checks that each client
    // call decodes the fields it was given.
    //
    // NOTE(review): the `igvm_attest` response is queued last, but this test
    // never issues an IGVM attest request, so that entry is never consumed.
    #[async_test]
    async fn test_all_basic(driver: DefaultDriver) {
        let time_zone = 5;
        let utc = 3;

        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, utc, time_zone, false)
                .as_bytes()
                .to_vec(),
        ));

        let vmgs_device_info_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsGetDeviceInfoResponse::new(VmgsIoStatus::SUCCESS, 1, 2, 3, 4)
                .as_bytes()
                .to_vec(),
        ));

        let flush_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsFlushResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));

        let guest_state_protection = TestGetResponses::new(Event::Response(
            get_protocol::GuestStateProtectionResponse {
                message_header: get_protocol::HeaderGeneric::new(
                    get_protocol::HostRequests::GUEST_STATE_PROTECTION,
                ),
                encrypted_gsp: get_protocol::GspCiphertextContent::new_zeroed(),
                decrypted_gsp: [get_protocol::GspCleartextContent::new_zeroed();
                    get_protocol::NUMBER_GSP as usize],
                extended_status_flags: get_protocol::GspExtendedStatusFlags::new()
                    .with_state_refresh_request(true),
            }
            .as_bytes()
            .to_vec(),
        ));

        let gsp_id = TestGetResponses::new(Event::Response(
            get_protocol::GuestStateProtectionByIdResponse {
                message_header: get_protocol::HeaderGeneric::new(
                    get_protocol::HostRequests::GUEST_STATE_PROTECTION_BY_ID,
                ),
                seed: get_protocol::GspCleartextContent::new_zeroed(),
                extended_status_flags: get_protocol::GspExtendedStatusFlags::new()
                    .with_no_registry_file(true)
                    .with_state_refresh_request(true),
            }
            .as_bytes()
            .to_vec(),
        ));

        let igvm_attest = TestGetResponses::new(Event::Response(
            get_protocol::IgvmAttestResponse {
                message_header: get_protocol::HeaderGeneric::new(
                    get_protocol::HostRequests::IGVM_ATTEST,
                ),
                length: 512,
            }
            .as_bytes()
            .to_vec(),
        ));

        let ged_responses = vec![
            time_response,
            vmgs_device_info_response,
            flush_response,
            guest_state_protection,
            gsp_id,
            igvm_attest,
        ];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let result = get.client.host_time().await;

        assert_eq!(result.utc, utc);
        assert_eq!(result.time_zone, time_zone);

        let response = get.client.vmgs_get_device_info().await.unwrap();
        assert_eq!(response.capacity, 1);
        assert_eq!(response.bytes_per_logical_sector, 2);
        assert_eq!(response.bytes_per_physical_sector, 3);
        assert_eq!(response.maximum_transfer_size_bytes, 4);

        get.client.vmgs_flush().await.unwrap();

        let gsp_response = get
            .client
            .guest_state_protection_data(
                [get_protocol::GspCiphertextContent::new_zeroed();
                    get_protocol::NUMBER_GSP as usize],
                get_protocol::GspExtendedStatusFlags::new().with_state_refresh_request(true),
            )
            .await;

        assert_eq!(
            gsp_response.extended_status_flags,
            get_protocol::GspExtendedStatusFlags::new().with_state_refresh_request(true)
        );

        let gsp_id_response = get
            .client
            .guest_state_protection_data_by_id()
            .await
            .unwrap();

        assert_eq!(
            gsp_id_response.extended_status_flags,
            get_protocol::GspExtendedStatusFlags::new()
                .with_no_registry_file(true)
                .with_state_refresh_request(true)
        );
    }

    // Writes a 512-byte buffer with `vmgs_write(0, ..)`, reads it back with
    // `vmgs_read(0, 1, ..)`, and expects identical bytes.
    #[async_test]
    async fn test_vmgs_basic_write(driver: DefaultDriver) {
        let vmgs_write_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsWriteResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));

        let vmgs_read_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsReadResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));
        let ged_responses = vec![vmgs_write_response, vmgs_read_response];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;
        let buf = (0..512).map(|x| x as u8).collect::<Vec<u8>>();
        get.client
            .vmgs_write(0, buf.clone(), TEST_VMGS_SECTOR_SIZE)
            .await
            .unwrap();

        let read_buf = get
            .client
            .vmgs_read(0, 1, TEST_VMGS_SECTOR_SIZE)
            .await
            .unwrap();
        assert_eq!(read_buf, buf);
    }

    // Serves a DEVICE_PLATFORM_SETTINGS_V2_REV1 response (fixed header followed
    // by a serialized JSON payload) and checks the decoded settings.
    #[async_test]
    async fn different_get_versions_nickel(driver: DefaultDriver) {
        let json = get_protocol::dps_json::DevicePlatformSettingsV2Json {
            v1: get_protocol::dps_json::HclDevicePlatformSettings {
                com1: get_protocol::dps_json::HclUartSettings {
                    enable_port: true,
                    debugger_mode: false,
                    enable_vmbus_redirector: true,
                },
                com2: get_protocol::dps_json::HclUartSettings {
                    enable_port: false,
                    debugger_mode: false,
                    enable_vmbus_redirector: false,
                },
                enable_firmware_debugging: true,
                ..Default::default()
            },
            v2: get_protocol::dps_json::HclDevicePlatformSettingsV2 {
                r#static: get_protocol::dps_json::HclDevicePlatformSettingsV2Static {
                    legacy_memory_map: true,
                    pxe_ip_v6: true,
                    ..Default::default()
                },
                ..Default::default()
            },
        };
        let json_data = serde_json::to_vec(&json).unwrap();

        let mut dps_response = get_protocol::DevicePlatformSettingsResponseV2Rev1::new_zeroed();
        dps_response.message_header = get_protocol::HeaderGeneric::new(
            get_protocol::HostRequests::DEVICE_PLATFORM_SETTINGS_V2_REV1,
        );
        dps_response.size = json_data.len() as u32;
        dps_response.payload_state = get_protocol::LargePayloadState::END;

        let device_platform_settings = TestGetResponses::new(Event::Response(
            [dps_response.as_bytes().to_vec(), json_data].concat(),
        ));

        let ged_responses = vec![device_platform_settings];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let dps = get.client.device_platform_settings().await.unwrap();
        assert!(!dps.general.tpm_enabled);
        assert!(dps.general.com1_enabled);
        assert!(!dps.general.secure_boot_enabled);

        assert!(dps.general.legacy_memory_map);
        assert!(dps.general.pxe_ip_v6);
        assert_eq!(dps.general.nvdimm_count, 0);

        assert_eq!(dps.general.generation_id, Some([0; 16]));
    }

    // Sends an event-log notification, then checks that a following VMGS read
    // is still answered.
    //
    // NOTE(review): the leading `TestGetResponses::default()` entry appears to
    // be a placeholder slot for the notification, and the expected first byte
    // `5` depends on the test GED's read handling — confirm against
    // `guest_emulation_device::test_utilities`.
    #[async_test]
    async fn test_send_notification(driver: DefaultDriver) {
        let vmgs_read_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsReadResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));
        let ged_responses = vec![TestGetResponses::default(), vmgs_read_response];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        get.client
            .event_log(get_protocol::EventLogId::NO_BOOT_DEVICE);

        let read_buf = get
            .client
            .vmgs_read(0, 1, TEST_VMGS_SECTOR_SIZE)
            .await
            .unwrap();

        assert_eq!(read_buf[0], 5);
    }

    // The scripted entry for a time request holds a generation-ID update
    // followed by the time response; the update must reach the `gen_id`
    // receiver and the time request must still resolve.
    #[async_test]
    async fn notification_in_between_requests(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::UpdateGenerationId::new([1; 16])
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![time_response];

        let mut get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let result = get.client.host_time().await;

        let gen_id = get.gen_id.recv().await.unwrap();

        assert_eq!(gen_id, [1; 16]);

        assert_eq!(result.utc, 1);
        assert_eq!(result.time_zone, 2);
    }

    // The scripted entry for a single time request holds two time responses;
    // the request resolves and the guest worker task must still be running.
    #[async_test]
    async fn host_send_multiple_response(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![time_response];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let result = get.client.host_time().await;

        assert_eq!(result.utc, 1);
        assert_eq!(result.time_zone, 2);

        assert!(futures::poll!(get.guest_task).is_pending());
    }

    // The scripted entry for a time request holds VPCI control responses
    // before and after the time response; the time request must still
    // resolve with the scripted time data.
    #[async_test]
    async fn host_send_mismatched_multiple_response(driver: DefaultDriver) {
        let responses = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let ged_responses = vec![responses];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let time_req = get.client.host_time();

        let result = time_req.await;
        assert_eq!(result.utc, 1);
        assert_eq!(result.time_zone, 2);
    }

    // A time request and a VPCI offer are both outstanding while the scripted
    // entry interleaves time and VPCI responses; both requests must complete.
    #[async_test]
    async fn host_send_mismatched_multiple_request_response(driver: DefaultDriver) {
        let responses = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let ged_responses = vec![responses];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let time_req = get.client.host_time();
        let mut vpci_req = std::pin::pin!(get.client.offer_vpci_device(guid::Guid::new_random()));

        // Poll the VPCI request once (it must not complete yet) before the
        // time request is awaited.
        assert!(futures::poll!(&mut vpci_req).is_pending());

        let result = time_req.await;
        assert_eq!(result.utc, 1);
        assert_eq!(result.time_zone, 2);

        vpci_req.await.unwrap();
    }

    // Only a time response is scripted, but the client issues a VMGS
    // device-info request; the guest worker task must end with
    // `FatalError::ResponseHeaderMismatchId`.
    #[async_test]
    async fn host_send_incorrect_response(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![time_response];

        let get = new_transport_pair(
            driver.clone(),
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        // The request is not expected to complete, so it runs on a detached
        // task while the test waits on the worker task instead.
        let _never_returns = driver.spawn("badness", async move {
            let _ = get.client.vmgs_get_device_info().await;
        });

        let internal_error = get.guest_task.await;

        assert!(matches!(
            internal_error.map_err(|x| x.0),
            Err(FatalError::ResponseHeaderMismatchId(_, _))
        ));
    }

    // Power-off and reset are sent without awaiting a reply; the test GED is
    // scripted with matching `Event::Halt` entries, and a following VMGS
    // device-info request must still be answered.
    #[async_test]
    async fn test_send_halt_reason(driver: DefaultDriver) {
        let power_off_check =
            TestGetResponses::new(Event::Halt(power_resources::PowerRequest::PowerOff));
        let reset_check = TestGetResponses::new(Event::Halt(power_resources::PowerRequest::Reset));

        let vmgs_device_info_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsGetDeviceInfoResponse::new(VmgsIoStatus::SUCCESS, 1, 2, 3, 4)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![power_off_check, reset_check, vmgs_device_info_response];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        get.client.send_power_off();
        get.client.send_reset();

        let response = get.client.vmgs_get_device_info().await.unwrap();
        assert_eq!(response.capacity, 1);
        assert_eq!(response.bytes_per_logical_sector, 2);
        assert_eq!(response.bytes_per_physical_sector, 3);
        assert_eq!(response.maximum_transfer_size_bytes, 4);
    }

    // Six tasks, each with its own cloned client, issue `host_time`
    // concurrently; every task must receive the scripted time response.
    #[async_test]
    async fn test_send_multiple_host_request(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![
            time_response.clone(),
            time_response.clone(),
            time_response.clone(),
            time_response.clone(),
            time_response.clone(),
            time_response,
        ];

        let get = new_transport_pair(
            driver.clone(),
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;

        let mut tasks = Vec::new();

        for i in 0..6 {
            let client = get.client.clone();
            tasks.push(driver.spawn(
                format!("task {i}"),
                async move { client.host_time().await },
            ));
        }

        for task in tasks {
            let time = task.await;
            assert_eq!(time.utc, 1);
            assert_eq!(time.time_zone, 2);
        }
    }

    // Round-trips VPCI offer, revoke, bind, and unbind requests, then connects
    // to and disconnects from the VPCI event source for the same bus.
    #[async_test]
    async fn test_vpci_control(driver: DefaultDriver) {
        let bus_id = guid::Guid::new_random();
        let vpci_offer_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let vpci_revoke_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let vpci_bind_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceBindingChangeResponse::new(
                bus_id,
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let vpci_unbind_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceBindingChangeResponse::new(
                bus_id,
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let ged_responses = vec![
            vpci_offer_response,
            vpci_revoke_response,
            vpci_bind_response,
            vpci_unbind_response,
        ];

        let get = new_transport_pair(
            driver,
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
            None,
            None,
        )
        .await;
        get.client.offer_vpci_device(bus_id).await.unwrap();
        get.client.revoke_vpci_device(bus_id).await.unwrap();
        get.client
            .report_vpci_device_binding_state(bus_id, true)
            .await
            .unwrap();
        get.client
            .report_vpci_device_binding_state(bus_id, false)
            .await
            .unwrap();

        get.client.connect_to_vpci_event_source(bus_id).await;
        get.client.disconnect_from_vpci_event_source(bus_id);
    }

    // Drives the test GED's save-guest-VTL2-state flow against an unscripted
    // transport pair. Ignored by default.
    #[ignore]
    #[async_test]
    async fn test_save_guest_vtl2_state(driver: DefaultDriver) {
        let mut get =
            new_transport_pair(driver, None, ProtocolVersion::NICKEL_REV2, None, None).await;

        get.test_ged_client.test_save_guest_vtl2_state().await;
    }
}