tdisp/
serialize_proto.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Protobuf serialization of TDISP guest-to-host commands and responses using
5//! the types defined in [`tdisp_proto`].
6
7use anyhow::Context;
8use prost::Message as _;
9use tdisp_proto::GuestToHostCommand;
10use tdisp_proto::GuestToHostResponse;
11use tdisp_proto::TdispGuestOperationErrorCode;
12use tdisp_proto::TdispGuestProtocolType;
13use tdisp_proto::TdispGuestUnbindReason;
14use tdisp_proto::TdispReportType;
15use tdisp_proto::TdispTdiState;
16use tdisp_proto::guest_to_host_command::Command;
17use tdisp_proto::guest_to_host_response::Response;
18
/// Borrow the contents of an `Option`-typed protobuf field, failing if the field is unset.
///
/// prost generates proto3 message-typed fields and `oneof`s as `Option<_>`, so the schema
/// itself cannot force a peer to populate them. This runtime treats every field passed here as
/// mandatory. The macro evaluates to `anyhow::Result<&T>`: `Ok` borrows the contained value,
/// and the error message names the field expression that was `None`.
macro_rules! require_field {
    ($val:expr) => {
        match $val.as_ref() {
            Some(present) => Ok(present),
            None => Err(anyhow::anyhow!(
                "proto validation: {} must be set",
                stringify!($val)
            )),
        }
    };
}
26
/// Convert a raw `i32` protobuf enum field into its generated Rust enum type, failing if the
/// value is not a known variant.
///
/// prost represents proto3 enum fields as plain `i32` and keeps whatever integer arrives on the
/// wire, so nothing prevents a peer from sending a value outside the enum's range. This runtime
/// requires every enum field it inspects to name a known variant. The macro evaluates to
/// `anyhow::Result<$enum_ty>`; the error message names the field expression, the expected enum
/// type and the offending value.
///
/// `$field` is evaluated exactly once, including on the error path.
macro_rules! require_enum {
    ($field:expr, $enum_ty:ty) => {{
        // Bind once so the error closure reuses the value instead of re-evaluating `$field`.
        let raw_value: i32 = $field;
        // NOTE(review): newer prost releases deprecate the generated `from_i32` in favour of
        // `TryFrom<i32>`; confirm against the prost version this crate builds with.
        <$enum_ty>::from_i32(raw_value).ok_or_else(|| {
            anyhow::anyhow!(
                "proto validation: {} is not a valid {}: {}",
                stringify!($field),
                stringify!($enum_ty),
                raw_value
            )
        })
    }};
}
41
42/// Serialize a [`GuestToHostCommand`] to a protobuf-encoded byte vector.
43pub fn serialize_command(command: &GuestToHostCommand) -> Vec<u8> {
44    command.encode_to_vec()
45}
46
47/// Deserialize a [`GuestToHostCommand`] from a protobuf-encoded byte slice.
48pub fn deserialize_command(bytes: &[u8]) -> anyhow::Result<GuestToHostCommand> {
49    let res = GuestToHostCommand::decode(bytes)
50        .map_err(|e| anyhow::anyhow!("failed to deserialize GuestToHostCommand: {e}"))?;
51
52    // Then, validate the command to ensure that it matches the expected format.
53    validate_command(&res).with_context(|| "failed to validate command in deserialize_command")?;
54
55    Ok(res)
56}
57
58/// Serialize a [`GuestToHostResponse`] to a protobuf-encoded byte vector.
59pub fn serialize_response(response: &GuestToHostResponse) -> Vec<u8> {
60    response.encode_to_vec()
61}
62
63/// Deserialize a [`GuestToHostResponse`] from a protobuf-encoded byte slice.
64pub fn deserialize_response(bytes: &[u8]) -> anyhow::Result<GuestToHostResponse> {
65    let res = GuestToHostResponse::decode(bytes).map_err(|e: prost::DecodeError| {
66        anyhow::anyhow!("failed to deserialize GuestToHostResponse: {e}")
67    })?;
68
69    // Then, validate the response to ensure that it matches the expected format.
70    validate_response(&res)
71        .with_context(|| "failed to validate response in deserialize_response")?;
72
73    Ok(res)
74}
75
76/// Validate the invariants of a [`GuestToHostCommand`] to ensure that it matches the
77/// expected required protocol format.
78pub fn validate_command(command: &GuestToHostCommand) -> anyhow::Result<()> {
79    require_field!(command.command)?;
80
81    if let Some(Command::GetDeviceInterfaceInfo(req)) = &command.command {
82        require_enum!(req.guest_protocol_type, TdispGuestProtocolType)?;
83    } else if let Some(Command::GetTdiReport(req)) = &command.command {
84        require_enum!(req.report_type, TdispReportType)?;
85    } else if let Some(Command::Unbind(req)) = &command.command {
86        require_enum!(req.unbind_reason, TdispGuestUnbindReason)?;
87    }
88
89    Ok(())
90}
91
92/// Validate the invariants of a [`GuestToHostResponse`] to ensure that it matches the
93/// expected required protocol format.
94pub fn validate_response(response: &GuestToHostResponse) -> anyhow::Result<()> {
95    require_enum!(response.result, TdispGuestOperationErrorCode)?;
96    require_enum!(response.tdi_state_before, TdispTdiState)?;
97    require_enum!(response.tdi_state_after, TdispTdiState)?;
98
99    // Only require a result field if the response is a success.
100    if response.result == TdispGuestOperationErrorCode::Success as i32 {
101        require_field!(response.response)?;
102        if let Some(Response::GetTdiReport(req)) = &response.response {
103            require_enum!(req.report_type, TdispReportType)?;
104            if req.report_buffer.is_empty() {
105                return Err(anyhow::anyhow!(
106                    "proto validation: report_buffer must not be empty"
107                ));
108            }
109        } else if let Some(Response::GetDeviceInterfaceInfo(req)) = &response.response {
110            require_field!(req.interface_info)?;
111        }
112    }
113
114    Ok(())
115}