1#![cfg(target_os = "linux")]
11#![forbid(unsafe_code)]
12
13pub mod api;
14pub mod error;
15pub mod resolver;
16
17mod client;
18mod process_loop;
19mod worker;
20
21pub use client::GuestEmulationTransportClient;
22
23#[derive(Debug, thiserror::Error)]
25#[error("failed to initialize GET worker")]
26pub struct SpawnGetError(#[source] process_loop::FatalError);
27
28#[derive(Debug, thiserror::Error)]
32#[error("encountered fatal GET error")]
33pub struct FatalGetError(#[source] process_loop::FatalError);
34
35pub async fn spawn_get_worker(
38 driver: impl pal_async::driver::SpawnDriver,
39) -> Result<
40 (
41 GuestEmulationTransportClient,
42 pal_async::task::Task<Result<(), FatalGetError>>,
43 ),
44 SpawnGetError,
45> {
46 let (worker, task) = worker::GuestEmulationTransportWorker::new(driver)
47 .await
48 .map_err(SpawnGetError)?;
49 Ok((worker.new_client(), task))
50}
51
// Helpers for building a GET client connected to an in-process test GED host.
// Only some items here are documented, so the `missing_docs` expectation below
// remains fulfilled.
#[cfg(any(feature = "test_utilities", test))]
#[expect(missing_docs)]
pub mod test_utilities {
    use super::*;
    use crate::worker::GuestEmulationTransportWorker;
    use get_protocol::ProtocolVersion;
    use guest_emulation_device::test_utilities::TestGedClient;
    use guest_emulation_device::test_utilities::TestGetResponses;
    use mesh::Receiver;
    use pal_async::task::Spawn;
    use pal_async::task::Task;

    // NOTE(review): 4194816 == 4 MiB + 512 bytes; presumably a default size
    // used by tests (e.g. for VMGS) — confirm against callers.
    pub const DEFAULT_SIZE: usize = 4194816;

    /// A guest-side GET client wired to a test GED host channel.
    ///
    /// Produced by [`new_transport_pair`].
    // The `pub(crate)` fields are only read by this crate's unit tests, so they
    // are dead code when built with just the `test_utilities` feature.
    #[cfg_attr(not(test), expect(dead_code))]
    pub struct TestGet {
        pub client: GuestEmulationTransportClient,
        pub(crate) gen_id: Receiver<[u8; 16]>,
        pub(crate) guest_task: Task<Result<(), FatalGetError>>,
        pub(crate) test_ged_client: TestGedClient,
    }

    /// Creates a test GED host and a GET guest worker connected to each other
    /// over an in-memory VMBus message pipe.
    ///
    /// `ged_responses` and `version` are forwarded to the test GED's
    /// `create_host_channel`. The host channel and the guest worker are both
    /// spawned on `spawn`.
    ///
    /// # Panics
    ///
    /// Panics if the guest worker fails to initialize, or if the new client's
    /// generation ID receiver is unavailable.
    pub async fn new_transport_pair(
        spawn: impl Spawn,
        ged_responses: Option<Vec<TestGetResponses>>,
        version: ProtocolVersion,
    ) -> TestGet {
        // Pipe capacity: one maximum-size GET message plus one page.
        let (host_vmbus, guest_vmbus) = vmbus_async::pipe::connected_message_pipes(
            get_protocol::MAX_MESSAGE_SIZE + vmbus_ring::PAGE_SIZE,
        );

        let test_ged_client = guest_emulation_device::test_utilities::create_host_channel(
            &spawn,
            host_vmbus,
            ged_responses,
            version,
        );

        let (guest_transport, guest_task) =
            GuestEmulationTransportWorker::with_pipe(&spawn, guest_vmbus)
                .await
                .expect("GET worker should initialize against the test GED");

        let client = guest_transport.new_client();

        // Taken before `client` is moved into the struct below.
        let gen_id = client
            .take_generation_id_recv()
            .await
            .expect("generation ID receiver should be available on a fresh client");

        TestGet {
            gen_id,
            client,
            guest_task,
            test_ged_client,
        }
    }
}
112
#[cfg(test)]
mod tests {
    use super::test_utilities::*;
    use super::worker::GuestEmulationTransportWorker;
    use crate::process_loop::FatalError;
    use get_protocol::ProtocolVersion;
    use get_protocol::VmgsIoStatus;
    use get_protocol::test_utilities::TEST_VMGS_SECTOR_SIZE;
    use guest_emulation_device::test_utilities::Event;
    use guest_emulation_device::test_utilities::TestGetResponses;
    use pal_async::DefaultDriver;
    use pal_async::async_test;
    use pal_async::task::Spawn;
    use test_with_tracing::test;
    use vmbus_async::async_dgram::AsyncRecvExt;
    use vmbus_async::async_dgram::AsyncSendExt;
    use zerocopy::FromZeros;
    use zerocopy::IntoBytes;

    /// Worker initialization must fail with `VersionNegotiationFailed` when
    /// the host answers each version request with `VersionResponse::new(false)`.
    #[async_test]
    async fn test_version_negotiation_failed(driver: DefaultDriver) {
        let (mut host_vmbus, guest_vmbus) =
            vmbus_async::pipe::connected_message_pipes(get_protocol::MAX_MESSAGE_SIZE);

        // Fake host: validate each version request the guest sends, in order,
        // and reply with a rejection.
        let host_task = driver.spawn("host task", async move {
            for protocol in [ProtocolVersion::NICKEL_REV2] {
                let mut version_request = get_protocol::VersionRequest::new_zeroed();
                let len = version_request.as_bytes().len();
                assert_eq!(
                    len,
                    host_vmbus
                        .recv(version_request.as_mut_bytes())
                        .await
                        .unwrap()
                );

                assert_eq!(
                    version_request.message_header.message_id(),
                    get_protocol::HostRequests::VERSION
                );
                assert_eq!(version_request.version, protocol);

                let version_response = get_protocol::VersionResponse::new(false);

                host_vmbus.send(version_response.as_bytes()).await.unwrap();
            }
        });

        let transport = GuestEmulationTransportWorker::with_pipe(driver, guest_vmbus).await;

        match transport.unwrap_err() {
            FatalError::VersionNegotiationFailed => {}
            e => panic!("Wrong error type returned: {e}"),
        }

        host_task.await;
    }

    /// Round-trips one request of each basic kind through the test GED and
    /// checks the decoded responses.
    #[async_test]
    async fn test_all_basic(driver: DefaultDriver) {
        let time_zone = 5;
        let utc = 3;

        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, utc, time_zone, false)
                .as_bytes()
                .to_vec(),
        ));

        let vmgs_device_info_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsGetDeviceInfoResponse::new(VmgsIoStatus::SUCCESS, 1, 2, 3, 4)
                .as_bytes()
                .to_vec(),
        ));

        let flush_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsFlushResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));

        let guest_state_protection = TestGetResponses::new(Event::Response(
            get_protocol::GuestStateProtectionResponse {
                message_header: get_protocol::HeaderGeneric::new(
                    get_protocol::HostRequests::GUEST_STATE_PROTECTION,
                ),
                encrypted_gsp: get_protocol::GspCiphertextContent::new_zeroed(),
                decrypted_gsp: [get_protocol::GspCleartextContent::new_zeroed();
                    get_protocol::NUMBER_GSP as usize],
                extended_status_flags: get_protocol::GspExtendedStatusFlags::new()
                    .with_state_refresh_request(true),
            }
            .as_bytes()
            .to_vec(),
        ));

        let gsp_id = TestGetResponses::new(Event::Response(
            get_protocol::GuestStateProtectionByIdResponse {
                message_header: get_protocol::HeaderGeneric::new(
                    get_protocol::HostRequests::GUEST_STATE_PROTECTION_BY_ID,
                ),
                seed: get_protocol::GspCleartextContent::new_zeroed(),
                extended_status_flags: get_protocol::GspExtendedStatusFlags::new()
                    .with_no_registry_file(true)
                    .with_state_refresh_request(true),
            }
            .as_bytes()
            .to_vec(),
        ));

        // Queued for the test GED but not requested by this test.
        let igvm_attest = TestGetResponses::new(Event::Response(
            get_protocol::IgvmAttestResponse {
                message_header: get_protocol::HeaderGeneric::new(
                    get_protocol::HostRequests::IGVM_ATTEST,
                ),
                length: 512,
            }
            .as_bytes()
            .to_vec(),
        ));

        let ged_responses = vec![
            time_response,
            vmgs_device_info_response,
            flush_response,
            guest_state_protection,
            gsp_id,
            igvm_attest,
        ];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;

        let result = get.client.host_time().await;

        assert_eq!(result.utc, utc);
        assert_eq!(result.time_zone, time_zone);

        let response = get.client.vmgs_get_device_info().await.unwrap();
        assert_eq!(response.capacity, 1);
        assert_eq!(response.bytes_per_logical_sector, 2);
        assert_eq!(response.bytes_per_physical_sector, 3);
        assert_eq!(response.maximum_transfer_size_bytes, 4);

        get.client.vmgs_flush().await.unwrap();

        let gsp_response = get
            .client
            .guest_state_protection_data(
                [get_protocol::GspCiphertextContent::new_zeroed();
                    get_protocol::NUMBER_GSP as usize],
                get_protocol::GspExtendedStatusFlags::new().with_state_refresh_request(true),
            )
            .await;

        assert_eq!(
            gsp_response.extended_status_flags,
            get_protocol::GspExtendedStatusFlags::new().with_state_refresh_request(true)
        );

        let gsp_id_response = get
            .client
            .guest_state_protection_data_by_id()
            .await
            .unwrap();

        assert_eq!(
            gsp_id_response.extended_status_flags,
            get_protocol::GspExtendedStatusFlags::new()
                .with_no_registry_file(true)
                .with_state_refresh_request(true)
        );
    }

    /// Writes one VMGS sector and reads it back, expecting the same bytes.
    #[async_test]
    async fn test_vmgs_basic_write(driver: DefaultDriver) {
        let vmgs_write_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsWriteResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));

        let vmgs_read_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsReadResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));
        let ged_responses = vec![vmgs_write_response, vmgs_read_response];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
        let buf = (0..512).map(|x| x as u8).collect::<Vec<u8>>();
        get.client
            .vmgs_write(0, buf.clone(), TEST_VMGS_SECTOR_SIZE)
            .await
            .unwrap();

        let read_buf = get
            .client
            .vmgs_read(0, 1, TEST_VMGS_SECTOR_SIZE)
            .await
            .unwrap();
        assert_eq!(read_buf, buf);
    }

    /// Decodes a V2 (rev1) device platform settings response carrying a JSON
    /// payload.
    #[async_test]
    async fn different_get_versions_nickel(driver: DefaultDriver) {
        let json = get_protocol::dps_json::DevicePlatformSettingsV2Json {
            v1: get_protocol::dps_json::HclDevicePlatformSettings {
                com1: get_protocol::dps_json::HclUartSettings {
                    enable_port: true,
                    debugger_mode: false,
                    enable_vmbus_redirector: true,
                },
                com2: get_protocol::dps_json::HclUartSettings {
                    enable_port: false,
                    debugger_mode: false,
                    enable_vmbus_redirector: false,
                },
                enable_firmware_debugging: true,
                ..Default::default()
            },
            v2: get_protocol::dps_json::HclDevicePlatformSettingsV2 {
                r#static: get_protocol::dps_json::HclDevicePlatformSettingsV2Static {
                    legacy_memory_map: true,
                    pxe_ip_v6: true,
                    ..Default::default()
                },
                ..Default::default()
            },
        };
        let json_data = serde_json::to_vec(&json).unwrap();

        // Header announcing the whole JSON payload in a single (final) chunk.
        let mut dps_response = get_protocol::DevicePlatformSettingsResponseV2Rev1::new_zeroed();
        dps_response.message_header = get_protocol::HeaderGeneric::new(
            get_protocol::HostRequests::DEVICE_PLATFORM_SETTINGS_V2_REV1,
        );
        dps_response.size = json_data.len() as u32;
        dps_response.payload_state = get_protocol::LargePayloadState::END;

        let device_platform_settings = TestGetResponses::new(Event::Response(
            [dps_response.as_bytes().to_vec(), json_data].concat(),
        ));

        let ged_responses = vec![device_platform_settings];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;

        let dps = get.client.device_platform_settings().await.unwrap();
        assert!(!dps.general.tpm_enabled);
        assert!(dps.general.com1_enabled);
        assert!(!dps.general.secure_boot_enabled);

        assert!(dps.general.legacy_memory_map);
        assert!(dps.general.pxe_ip_v6);
        assert_eq!(dps.general.nvdimm_count, 0);

        assert_eq!(dps.general.generation_id, Some([0; 16]));
    }

    /// A fire-and-forget event log notification must not disturb a
    /// subsequent request/response exchange.
    #[async_test]
    async fn test_send_notification(driver: DefaultDriver) {
        let vmgs_read_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsReadResponse::new(VmgsIoStatus::SUCCESS)
                .as_bytes()
                .to_vec(),
        ));
        let ged_responses = vec![TestGetResponses::default(), vmgs_read_response];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;

        get.client
            .event_log(get_protocol::EventLogId::NO_BOOT_DEVICE);

        let read_buf = get
            .client
            .vmgs_read(0, 1, TEST_VMGS_SECTOR_SIZE)
            .await
            .unwrap();

        assert_eq!(read_buf[0], 5);
    }

    /// A generation ID notification arriving before a response must be
    /// delivered on the generation ID channel without breaking the pending
    /// request.
    #[async_test]
    async fn notification_in_between_requests(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::UpdateGenerationId::new([1; 16])
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![time_response];

        let mut get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;

        let result = get.client.host_time().await;

        let gen_id = get.gen_id.recv().await.unwrap();

        assert_eq!(gen_id, [1; 16]);

        assert_eq!(result.utc, 1);
        assert_eq!(result.time_zone, 2);
    }

    /// A second response with no outstanding request must be fatal to the
    /// worker, with `FatalError::NoPendingRequest`.
    #[async_test]
    async fn host_send_multiple_response(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ))
        .add_response(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![time_response];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;

        let result = get.client.host_time().await;

        assert_eq!(result.utc, 1);
        assert_eq!(result.time_zone, 2);

        let host_result = get.guest_task.await;

        // Match on the worker's result, not on a binding pattern: a pattern
        // like `_host_result` would match anything and make this vacuous.
        assert!(matches!(
            host_result.map_err(|x| x.0),
            Err(FatalError::NoPendingRequest { .. })
        ));
    }

    /// A response whose ID does not match the outstanding request must be
    /// fatal to the worker, with `FatalError::ResponseHeaderMismatchId`.
    #[async_test]
    async fn host_send_incorrect_response(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![time_response];

        let get = new_transport_pair(
            driver.clone(),
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
        )
        .await;

        // The device-info request is answered with a time response, so it is
        // not expected to complete; the worker should fail instead.
        let _never_returns = driver.spawn("badness", async move {
            let _ = get.client.vmgs_get_device_info().await;
        });

        let internal_error = get.guest_task.await;

        assert!(matches!(
            internal_error.map_err(|x| x.0),
            Err(FatalError::ResponseHeaderMismatchId(_, _))
        ));
    }

    /// Power-off and reset halt notifications followed by a normal request.
    #[async_test]
    async fn test_send_halt_reason(driver: DefaultDriver) {
        let power_off_check =
            TestGetResponses::new(Event::Halt(power_resources::PowerRequest::PowerOff));
        let reset_check = TestGetResponses::new(Event::Halt(power_resources::PowerRequest::Reset));

        let vmgs_device_info_response = TestGetResponses::new(Event::Response(
            get_protocol::VmgsGetDeviceInfoResponse::new(VmgsIoStatus::SUCCESS, 1, 2, 3, 4)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![power_off_check, reset_check, vmgs_device_info_response];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;

        get.client.send_power_off();
        get.client.send_reset();

        let response = get.client.vmgs_get_device_info().await.unwrap();
        assert_eq!(response.capacity, 1);
        assert_eq!(response.bytes_per_logical_sector, 2);
        assert_eq!(response.bytes_per_physical_sector, 3);
        assert_eq!(response.maximum_transfer_size_bytes, 4);
    }

    /// Several clones of the client issue requests concurrently; each must
    /// receive a correct response.
    #[async_test]
    async fn test_send_multiple_host_request(driver: DefaultDriver) {
        let time_response = TestGetResponses::new(Event::Response(
            get_protocol::TimeResponse::new(0, 1, 2, false)
                .as_bytes()
                .to_vec(),
        ));

        let ged_responses = vec![
            time_response.clone(),
            time_response.clone(),
            time_response.clone(),
            time_response.clone(),
            time_response.clone(),
            time_response,
        ];

        let get = new_transport_pair(
            driver.clone(),
            Some(ged_responses),
            ProtocolVersion::NICKEL_REV2,
        )
        .await;

        let mut tasks = Vec::new();

        for i in 0..6 {
            let client = get.client.clone();
            tasks.push(driver.spawn(
                format!("task {i}"),
                async move { client.host_time().await },
            ));
        }

        for task in tasks {
            let time = task.await;
            assert_eq!(time.utc, 1);
            assert_eq!(time.time_zone, 2);
        }
    }

    /// Exercises the VPCI offer/revoke, binding-state, and event-source APIs.
    #[async_test]
    async fn test_vpci_control(driver: DefaultDriver) {
        let bus_id = guid::Guid::new_random();
        let vpci_offer_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let vpci_revoke_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceControlResponse::new(
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let vpci_bind_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceBindingChangeResponse::new(
                bus_id,
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let vpci_unbind_response = TestGetResponses::new(Event::Response(
            get_protocol::VpciDeviceBindingChangeResponse::new(
                bus_id,
                get_protocol::VpciDeviceControlStatus::SUCCESS,
            )
            .as_bytes()
            .to_vec(),
        ));

        let ged_responses = vec![
            vpci_offer_response,
            vpci_revoke_response,
            vpci_bind_response,
            vpci_unbind_response,
        ];

        let get =
            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
        get.client.offer_vpci_device(bus_id).await.unwrap();
        get.client.revoke_vpci_device(bus_id).await.unwrap();
        get.client
            .report_vpci_device_binding_state(bus_id, true)
            .await
            .unwrap();
        get.client
            .report_vpci_device_binding_state(bus_id, false)
            .await
            .unwrap();

        get.client.connect_to_vpci_event_source(bus_id).await;
        get.client.disconnect_from_vpci_event_source(bus_id);
    }

    // NOTE(review): the reason this test is ignored is not recorded here.
    #[ignore]
    #[async_test]
    async fn test_save_guest_vtl2_state(driver: DefaultDriver) {
        let mut get = new_transport_pair(driver, None, ProtocolVersion::NICKEL_REV2).await;

        get.test_ged_client.test_save_guest_vtl2_state().await;
    }
}