guest_emulation_transport/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Guest Emulation Transport - GET
5//!
6//! The GET is the guest side of a communication channel that uses VMBUS to communicate between Guest and Host.
7//! The Guest sends messages through the GET to get information on the time, VMGS file, attestation,
8//! platform settings, bios boot settings, and guest state protection.
9
10#![cfg(target_os = "linux")]
11#![forbid(unsafe_code)]
12
13pub mod api;
14pub mod error;
15pub mod resolver;
16
17mod client;
18mod process_loop;
19mod worker;
20
21pub use client::GuestEmulationTransportClient;
22
/// Error while initializing the GET worker.
///
/// Returned by [`spawn_get_worker`] when constructing the worker fails; the
/// internal `process_loop::FatalError` is attached as the error source.
#[derive(Debug, thiserror::Error)]
#[error("failed to initialize GET worker")]
pub struct SpawnGetError(#[source] process_loop::FatalError);
27
/// Encountered fatal GET error.
///
/// This is the error type carried by the task that [`spawn_get_worker`]
/// returns alongside the client.
// DEVNOTE: this is a distinct type from `process_loop::FatalError`, as we don't
// want to leak the details of the internal FatalError type.
#[derive(Debug, thiserror::Error)]
#[error("encountered fatal GET error")]
pub struct FatalGetError(#[source] process_loop::FatalError);
34
35/// Takes in a driver and initializes the GET, returning a client that can be
36/// used to invoke requests to the GET worker.
37pub async fn spawn_get_worker(
38    driver: impl pal_async::driver::SpawnDriver,
39) -> Result<
40    (
41        GuestEmulationTransportClient,
42        pal_async::task::Task<Result<(), FatalGetError>>,
43    ),
44    SpawnGetError,
45> {
46    let (worker, task) = worker::GuestEmulationTransportWorker::new(driver)
47        .await
48        .map_err(SpawnGetError)?;
49    Ok((worker.new_client(), task))
50}
51
52#[cfg(any(feature = "test_utilities", test))]
53#[expect(missing_docs)]
54pub mod test_utilities {
55    use super::*;
56    use crate::worker::GuestEmulationTransportWorker;
57    use client::GuestEmulationTransportClient;
58    use get_protocol::ProtocolVersion;
59    use guest_emulation_device::test_utilities::TestGedClient;
60    use guest_emulation_device::test_utilities::TestGetResponses;
61    use mesh::Receiver;
62    use pal_async::task::Spawn;
63    use pal_async::task::Task;
64
65    pub const DEFAULT_SIZE: usize = 4194816; // 4 MB
66
    // Bundle produced by `new_transport_pair`: the guest-side client plus the
    // handles a test needs to observe the guest worker and the test host.
    //
    // NOTE: plain `//` comments are used on purpose; the enclosing module
    // carries `#[expect(missing_docs)]`.
    #[cfg_attr(not(test), expect(dead_code))]
    pub struct TestGet {
        // Client created from the guest worker via `new_client()`.
        pub client: GuestEmulationTransportClient,
        // Receiver obtained from `client.take_generation_id_recv()`.
        pub(crate) gen_id: Receiver<[u8; 16]>,
        // Task returned alongside the guest worker by `with_pipe`.
        pub(crate) guest_task: Task<Result<(), FatalGetError>>,
        // Client returned by `create_host_channel` for the host side.
        pub(crate) test_ged_client: TestGedClient,
    }
74
75    /// Creates a new host guest transport pair ready to send data.
76    ///
77    /// If `ged_responses` is Some(), then TestGedChannel will be used to
78    /// control what responses the Host sends. Otherwise, if `ged_responses` is
79    /// None, we will use the regular GedChannel to automate responses.
80    pub async fn new_transport_pair(
81        spawn: impl Spawn,
82        ged_responses: Option<Vec<TestGetResponses>>,
83        version: ProtocolVersion,
84    ) -> TestGet {
85        let (host_vmbus, guest_vmbus) = vmbus_async::pipe::connected_message_pipes(
86            get_protocol::MAX_MESSAGE_SIZE + vmbus_ring::PAGE_SIZE,
87        );
88
89        let test_ged_client = guest_emulation_device::test_utilities::create_host_channel(
90            &spawn,
91            host_vmbus,
92            ged_responses,
93            version,
94        );
95
96        // Create the GET
97        let (guest_transport, guest_task) =
98            GuestEmulationTransportWorker::with_pipe(&spawn, guest_vmbus)
99                .await
100                .unwrap();
101
102        let client = guest_transport.new_client();
103
104        TestGet {
105            gen_id: client.take_generation_id_recv().await.unwrap(),
106            client,
107            guest_task,
108            test_ged_client,
109        }
110    }
111}
112
113#[cfg(test)]
114mod tests {
115    use super::test_utilities::*;
116    use super::worker::GuestEmulationTransportWorker;
117    use crate::process_loop::FatalError;
118    use get_protocol::ProtocolVersion;
119    use get_protocol::VmgsIoStatus;
120    use get_protocol::test_utilities::TEST_VMGS_SECTOR_SIZE;
121    use guest_emulation_device::test_utilities::Event;
122    use guest_emulation_device::test_utilities::TestGetResponses;
123    use pal_async::DefaultDriver;
124    use pal_async::async_test;
125    use pal_async::task::Spawn;
126    use test_with_tracing::test;
127    use vmbus_async::async_dgram::AsyncRecvExt;
128    use vmbus_async::async_dgram::AsyncSendExt;
129    use zerocopy::FromZeros;
130    use zerocopy::IntoBytes;
131
132    #[async_test]
133    async fn test_version_negotiation_failed(driver: DefaultDriver) {
134        let (mut host_vmbus, guest_vmbus) =
135            vmbus_async::pipe::connected_message_pipes(get_protocol::MAX_MESSAGE_SIZE);
136
137        let host_task = driver.spawn("host task", async move {
138            for protocol in [ProtocolVersion::NICKEL_REV2] {
139                let mut version_request = get_protocol::VersionRequest::new_zeroed();
140                let len = version_request.as_bytes().len();
141                assert_eq!(
142                    len,
143                    host_vmbus
144                        .recv(version_request.as_mut_bytes())
145                        .await
146                        .unwrap()
147                );
148
149                assert_eq!(
150                    version_request.message_header.message_id(),
151                    get_protocol::HostRequests::VERSION
152                );
153                assert_eq!(version_request.version, protocol);
154
155                // Reject the request.
156                let version_response = get_protocol::VersionResponse::new(false);
157
158                host_vmbus.send(version_response.as_bytes()).await.unwrap();
159            }
160        });
161
162        let transport = GuestEmulationTransportWorker::with_pipe(driver, guest_vmbus).await;
163
164        match transport.unwrap_err() {
165            FatalError::VersionNegotiationFailed => {}
166            e => panic!("Wrong error type returned: {}", e),
167        }
168
169        host_task.await;
170    }
171
172    #[async_test]
173    async fn test_all_basic(driver: DefaultDriver) {
174        let time_zone = 5;
175        let utc = 3;
176
177        let time_response = TestGetResponses::new(Event::Response(
178            get_protocol::TimeResponse::new(0, utc, time_zone, false)
179                .as_bytes()
180                .to_vec(),
181        ));
182
183        let vmgs_device_info_response = TestGetResponses::new(Event::Response(
184            get_protocol::VmgsGetDeviceInfoResponse::new(VmgsIoStatus::SUCCESS, 1, 2, 3, 4)
185                .as_bytes()
186                .to_vec(),
187        ));
188
189        let flush_response = TestGetResponses::new(Event::Response(
190            get_protocol::VmgsFlushResponse::new(VmgsIoStatus::SUCCESS)
191                .as_bytes()
192                .to_vec(),
193        ));
194
195        let guest_state_protection = TestGetResponses::new(Event::Response(
196            get_protocol::GuestStateProtectionResponse {
197                message_header: get_protocol::HeaderGeneric::new(
198                    get_protocol::HostRequests::GUEST_STATE_PROTECTION,
199                ),
200                encrypted_gsp: get_protocol::GspCiphertextContent::new_zeroed(),
201                decrypted_gsp: [get_protocol::GspCleartextContent::new_zeroed();
202                    get_protocol::NUMBER_GSP as usize],
203                extended_status_flags: get_protocol::GspExtendedStatusFlags::new()
204                    .with_state_refresh_request(true),
205            }
206            .as_bytes()
207            .to_vec(),
208        ));
209
210        let gsp_id = TestGetResponses::new(Event::Response(
211            get_protocol::GuestStateProtectionByIdResponse {
212                message_header: get_protocol::HeaderGeneric::new(
213                    get_protocol::HostRequests::GUEST_STATE_PROTECTION_BY_ID,
214                ),
215                seed: get_protocol::GspCleartextContent::new_zeroed(),
216                extended_status_flags: get_protocol::GspExtendedStatusFlags::new()
217                    .with_no_registry_file(true)
218                    .with_state_refresh_request(true),
219            }
220            .as_bytes()
221            .to_vec(),
222        ));
223
224        let igvm_attest = TestGetResponses::new(Event::Response(
225            get_protocol::IgvmAttestResponse {
226                message_header: get_protocol::HeaderGeneric::new(
227                    get_protocol::HostRequests::IGVM_ATTEST,
228                ),
229                length: 512,
230            }
231            .as_bytes()
232            .to_vec(),
233        ));
234
235        let ged_responses = vec![
236            time_response,
237            vmgs_device_info_response,
238            flush_response,
239            guest_state_protection,
240            gsp_id,
241            igvm_attest,
242        ];
243
244        let get =
245            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
246
247        let result = get.client.host_time().await;
248
249        assert_eq!(result.utc, utc);
250        assert_eq!(result.time_zone, time_zone);
251
252        let response = get.client.vmgs_get_device_info().await.unwrap();
253        assert_eq!(response.capacity, 1);
254        assert_eq!(response.bytes_per_logical_sector, 2);
255        assert_eq!(response.bytes_per_physical_sector, 3);
256        assert_eq!(response.maximum_transfer_size_bytes, 4);
257
258        get.client.vmgs_flush().await.unwrap();
259
260        let gsp_response = get
261            .client
262            .guest_state_protection_data(
263                [get_protocol::GspCiphertextContent::new_zeroed();
264                    get_protocol::NUMBER_GSP as usize],
265                get_protocol::GspExtendedStatusFlags::new().with_state_refresh_request(true),
266            )
267            .await;
268
269        assert_eq!(
270            gsp_response.extended_status_flags,
271            get_protocol::GspExtendedStatusFlags::new().with_state_refresh_request(true)
272        );
273
274        let gsp_id_response = get
275            .client
276            .guest_state_protection_data_by_id()
277            .await
278            .unwrap();
279
280        assert_eq!(
281            gsp_id_response.extended_status_flags,
282            get_protocol::GspExtendedStatusFlags::new()
283                .with_no_registry_file(true)
284                .with_state_refresh_request(true)
285        );
286    }
287
288    #[async_test]
289    async fn test_vmgs_basic_write(driver: DefaultDriver) {
290        let vmgs_write_response = TestGetResponses::new(Event::Response(
291            get_protocol::VmgsWriteResponse::new(VmgsIoStatus::SUCCESS)
292                .as_bytes()
293                .to_vec(),
294        ));
295
296        let vmgs_read_response = TestGetResponses::new(Event::Response(
297            get_protocol::VmgsReadResponse::new(VmgsIoStatus::SUCCESS)
298                .as_bytes()
299                .to_vec(),
300        ));
301        let ged_responses = vec![vmgs_write_response, vmgs_read_response];
302
303        let get =
304            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
305        let buf = (0..512).map(|x| x as u8).collect::<Vec<u8>>();
306        get.client
307            .vmgs_write(0, buf.clone(), TEST_VMGS_SECTOR_SIZE)
308            .await
309            .unwrap();
310
311        let read_buf = get
312            .client
313            .vmgs_read(0, 1, TEST_VMGS_SECTOR_SIZE)
314            .await
315            .unwrap();
316        assert_eq!(read_buf, buf);
317    }
318
319    #[async_test]
320    async fn different_get_versions_nickel(driver: DefaultDriver) {
321        // NICKEL, json dps no aps
322        let json = get_protocol::dps_json::DevicePlatformSettingsV2Json {
323            v1: get_protocol::dps_json::HclDevicePlatformSettings {
324                com1: get_protocol::dps_json::HclUartSettings {
325                    enable_port: true,
326                    debugger_mode: false,
327                    enable_vmbus_redirector: true,
328                },
329                com2: get_protocol::dps_json::HclUartSettings {
330                    enable_port: false,
331                    debugger_mode: false,
332                    enable_vmbus_redirector: false,
333                },
334                enable_firmware_debugging: true,
335                ..Default::default()
336            },
337            v2: get_protocol::dps_json::HclDevicePlatformSettingsV2 {
338                r#static: get_protocol::dps_json::HclDevicePlatformSettingsV2Static {
339                    legacy_memory_map: true,
340                    pxe_ip_v6: true,
341                    ..Default::default()
342                },
343                ..Default::default()
344            },
345        };
346        let json_data = serde_json::to_vec(&json).unwrap();
347
348        let mut dps_response = get_protocol::DevicePlatformSettingsResponseV2Rev1::new_zeroed();
349        dps_response.message_header = get_protocol::HeaderGeneric::new(
350            get_protocol::HostRequests::DEVICE_PLATFORM_SETTINGS_V2_REV1,
351        );
352        dps_response.size = json_data.len() as u32;
353        dps_response.payload_state = get_protocol::LargePayloadState::END;
354
355        let device_platform_settings = TestGetResponses::new(Event::Response(
356            [dps_response.as_bytes().to_vec(), json_data].concat(),
357        ));
358
359        let ged_responses = vec![device_platform_settings];
360
361        let get =
362            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
363
364        let dps = get.client.device_platform_settings().await.unwrap();
365        assert_eq!(dps.general.tpm_enabled, false);
366        assert_eq!(dps.general.com1_enabled, true);
367        assert_eq!(dps.general.secure_boot_enabled, false);
368
369        assert_eq!(dps.general.legacy_memory_map, true);
370        assert_eq!(dps.general.pxe_ip_v6, true);
371        assert_eq!(dps.general.nvdimm_count, 0);
372
373        assert_eq!(dps.general.generation_id, Some([0; 16]));
374    }
375
376    #[async_test]
377    async fn test_send_notification(driver: DefaultDriver) {
378        // HACK: host notifications are programmed to set the first byte in the
379        // vmgs to a certain value.
380        let vmgs_read_response = TestGetResponses::new(Event::Response(
381            get_protocol::VmgsReadResponse::new(VmgsIoStatus::SUCCESS)
382                .as_bytes()
383                .to_vec(),
384        ));
385        let ged_responses = vec![TestGetResponses::default(), vmgs_read_response];
386
387        let get =
388            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
389
390        get.client
391            .event_log(get_protocol::EventLogId::NO_BOOT_DEVICE);
392
393        let read_buf = get
394            .client
395            .vmgs_read(0, 1, TEST_VMGS_SECTOR_SIZE)
396            .await
397            .unwrap();
398
399        assert_eq!(read_buf[0], 5);
400    }
401
402    #[async_test]
403    async fn notification_in_between_requests(driver: DefaultDriver) {
404        let time_response = TestGetResponses::new(Event::Response(
405            get_protocol::UpdateGenerationId::new([1; 16])
406                .as_bytes()
407                .to_vec(),
408        ))
409        .add_response(Event::Response(
410            get_protocol::TimeResponse::new(0, 1, 2, false)
411                .as_bytes()
412                .to_vec(),
413        ));
414
415        let ged_responses = vec![time_response];
416
417        let mut get =
418            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
419
420        let result = get.client.host_time().await;
421
422        let gen_id = get.gen_id.recv().await.unwrap();
423
424        assert_eq!(gen_id, [1; 16]);
425
426        assert_eq!(result.utc, 1);
427        assert_eq!(result.time_zone, 2);
428    }
429
430    #[async_test]
431    async fn host_send_multiple_response(driver: DefaultDriver) {
432        let time_response = TestGetResponses::new(Event::Response(
433            get_protocol::TimeResponse::new(0, 1, 2, false)
434                .as_bytes()
435                .to_vec(),
436        ))
437        .add_response(Event::Response(
438            get_protocol::TimeResponse::new(0, 1, 2, false)
439                .as_bytes()
440                .to_vec(),
441        ));
442
443        let ged_responses = vec![time_response];
444
445        let get =
446            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
447
448        let result = get.client.host_time().await;
449
450        assert_eq!(result.utc, 1);
451        assert_eq!(result.time_zone, 2);
452
453        let _host_result = get.guest_task.await;
454
455        assert!(matches!(FatalError::NoPendingRequest, _host_result));
456    }
457
458    #[async_test]
459    async fn host_send_incorrect_response(driver: DefaultDriver) {
460        let time_response = TestGetResponses::new(Event::Response(
461            get_protocol::TimeResponse::new(0, 1, 2, false)
462                .as_bytes()
463                .to_vec(),
464        ));
465
466        let ged_responses = vec![time_response];
467
468        let get = new_transport_pair(
469            driver.clone(),
470            Some(ged_responses),
471            ProtocolVersion::NICKEL_REV2,
472        )
473        .await;
474
475        let _never_returns = driver.spawn("badness", async move {
476            let _ = get.client.vmgs_get_device_info().await;
477        });
478
479        let internal_error = get.guest_task.await;
480
481        assert!(matches!(
482            internal_error.map_err(|x| x.0),
483            Err(FatalError::ResponseHeaderMismatchId(_, _))
484        ));
485    }
486
487    #[async_test]
488    async fn test_send_halt_reason(driver: DefaultDriver) {
489        let power_off_check =
490            TestGetResponses::new(Event::Halt(power_resources::PowerRequest::PowerOff));
491        let reset_check = TestGetResponses::new(Event::Halt(power_resources::PowerRequest::Reset));
492
493        let vmgs_device_info_response = TestGetResponses::new(Event::Response(
494            get_protocol::VmgsGetDeviceInfoResponse::new(VmgsIoStatus::SUCCESS, 1, 2, 3, 4)
495                .as_bytes()
496                .to_vec(),
497        ));
498
499        let ged_responses = vec![power_off_check, reset_check, vmgs_device_info_response];
500
501        let get =
502            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
503
504        get.client.send_power_off();
505        get.client.send_reset();
506
507        // We send a vmgs_get_device_info() so we can ensure the host
508        // finishes processing the notifications we previously sent.
509        // Otherwise, the socket may close before the host can finish
510        // handling both notifications.
511        let response = get.client.vmgs_get_device_info().await.unwrap();
512        assert_eq!(response.capacity, 1);
513        assert_eq!(response.bytes_per_logical_sector, 2);
514        assert_eq!(response.bytes_per_physical_sector, 3);
515        assert_eq!(response.maximum_transfer_size_bytes, 4);
516    }
517
518    #[async_test]
519    async fn test_send_multiple_host_request(driver: DefaultDriver) {
520        let time_response = TestGetResponses::new(Event::Response(
521            get_protocol::TimeResponse::new(0, 1, 2, false)
522                .as_bytes()
523                .to_vec(),
524        ));
525
526        let ged_responses = vec![
527            time_response.clone(),
528            time_response.clone(),
529            time_response.clone(),
530            time_response.clone(),
531            time_response.clone(),
532            time_response,
533        ];
534
535        let get = new_transport_pair(
536            driver.clone(),
537            Some(ged_responses),
538            ProtocolVersion::NICKEL_REV2,
539        )
540        .await;
541
542        let mut tasks = Vec::new();
543
544        for i in 0..6 {
545            let client = get.client.clone();
546            tasks.push(driver.spawn(
547                format!("task {}", i),
548                async move { client.host_time().await },
549            ));
550        }
551
552        // Sleep 1 second to let the host process tasks
553        // std::thread::sleep(std::time::Duration::new(1, 0));
554
555        for task in tasks {
556            let time = task.await;
557            assert_eq!(time.utc, 1);
558            assert_eq!(time.time_zone, 2);
559        }
560    }
561    #[async_test]
562    async fn test_vpci_control(driver: DefaultDriver) {
563        let bus_id = guid::Guid::new_random();
564        let vpci_offer_response = TestGetResponses::new(Event::Response(
565            get_protocol::VpciDeviceControlResponse::new(
566                get_protocol::VpciDeviceControlStatus::SUCCESS,
567            )
568            .as_bytes()
569            .to_vec(),
570        ));
571
572        let vpci_revoke_response = TestGetResponses::new(Event::Response(
573            get_protocol::VpciDeviceControlResponse::new(
574                get_protocol::VpciDeviceControlStatus::SUCCESS,
575            )
576            .as_bytes()
577            .to_vec(),
578        ));
579
580        let vpci_bind_response = TestGetResponses::new(Event::Response(
581            get_protocol::VpciDeviceBindingChangeResponse::new(
582                bus_id,
583                get_protocol::VpciDeviceControlStatus::SUCCESS,
584            )
585            .as_bytes()
586            .to_vec(),
587        ));
588
589        let vpci_unbind_response = TestGetResponses::new(Event::Response(
590            get_protocol::VpciDeviceBindingChangeResponse::new(
591                bus_id,
592                get_protocol::VpciDeviceControlStatus::SUCCESS,
593            )
594            .as_bytes()
595            .to_vec(),
596        ));
597
598        let ged_responses = vec![
599            vpci_offer_response,
600            vpci_revoke_response,
601            vpci_bind_response,
602            vpci_unbind_response,
603        ];
604
605        let get =
606            new_transport_pair(driver, Some(ged_responses), ProtocolVersion::NICKEL_REV2).await;
607        get.client.offer_vpci_device(bus_id).await.unwrap();
608        get.client.revoke_vpci_device(bus_id).await.unwrap();
609        get.client
610            .report_vpci_device_binding_state(bus_id, true)
611            .await
612            .unwrap();
613        get.client
614            .report_vpci_device_binding_state(bus_id, false)
615            .await
616            .unwrap();
617
618        get.client.connect_to_vpci_event_source(bus_id).await;
619        get.client.disconnect_from_vpci_event_source(bus_id);
620    }
621
    // Creates a transport pair without scripted responses (`None`) and asks
    // the test GED client to run its save-guest-VTL2-state check.
    //
    // Temporarily ignored until error handling is done better/hvlite as host flow is plumbed in.
    #[ignore]
    #[async_test]
    async fn test_save_guest_vtl2_state(driver: DefaultDriver) {
        let mut get = new_transport_pair(driver, None, ProtocolVersion::NICKEL_REV2).await;

        get.test_ged_client.test_save_guest_vtl2_state().await;
    }
630}