hyperv_ic/
common.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4//! Common code for IC implementations.
5
6use anyhow::Context as _;
7use hyperv_ic_protocol::FRAMEWORK_VERSION_1;
8use hyperv_ic_protocol::FRAMEWORK_VERSION_3;
9use hyperv_ic_protocol::HeaderFlags;
10use hyperv_ic_protocol::MessageType;
11use hyperv_ic_protocol::Status;
12use hyperv_ic_protocol::Version;
13use inspect::Inspect;
14use inspect::InspectMut;
15use std::io::IoSlice;
16use vmbus_async::async_dgram::AsyncRecvExt;
17use vmbus_async::async_dgram::AsyncSendExt;
18use vmbus_async::pipe::MessagePipe;
19use vmbus_channel::RawAsyncChannel;
20use vmbus_channel::gpadl_ring::GpadlRingMem;
21use zerocopy::FromBytes;
22use zerocopy::FromZeros;
23use zerocopy::IntoBytes;
24
25/// Supported framework versions.
26const FRAMEWORK_VERSIONS: &[Version] = &[FRAMEWORK_VERSION_1, FRAMEWORK_VERSION_3];
27
28#[derive(InspectMut)]
29pub(crate) struct IcPipe {
30    #[inspect(mut)]
31    pub pipe: MessagePipe<GpadlRingMem>,
32    #[inspect(skip)]
33    buf: Vec<u8>,
34}
35
36#[derive(Inspect, Default)]
37pub(crate) enum NegotiateState {
38    #[default]
39    SendVersion,
40    WaitVersion,
41    Invalid,
42}
43
44#[derive(Copy, Clone, Debug, Inspect)]
45pub(crate) struct Versions {
46    #[inspect(display)]
47    pub framework_version: Version,
48    #[inspect(display)]
49    pub message_version: Version,
50}
51
52impl IcPipe {
53    pub fn new(raw: RawAsyncChannel<GpadlRingMem>) -> Result<Self, std::io::Error> {
54        let pipe = MessagePipe::new(raw)?;
55        let buf = vec![0; hyperv_ic_protocol::MAX_MESSAGE_SIZE];
56        Ok(Self { pipe, buf })
57    }
58
59    pub async fn negotiate(
60        &mut self,
61        state: &mut NegotiateState,
62        message_versions: &[Version],
63    ) -> anyhow::Result<Option<Versions>> {
64        match state {
65            NegotiateState::SendVersion => {
66                let message = hyperv_ic_protocol::NegotiateMessage {
67                    framework_version_count: FRAMEWORK_VERSIONS.len() as u16,
68                    message_version_count: message_versions.len() as u16,
69                    ..FromZeros::new_zeroed()
70                };
71
72                let header = hyperv_ic_protocol::Header {
73                    message_type: MessageType::VERSION_NEGOTIATION,
74                    message_size: (size_of_val(&message)
75                        + size_of_val(FRAMEWORK_VERSIONS)
76                        + size_of_val(message_versions)) as u16,
77                    status: Status::SUCCESS,
78                    transaction_id: 0,
79                    flags: HeaderFlags::new().with_transaction(true).with_request(true),
80                    ..FromZeros::new_zeroed()
81                };
82
83                self.pipe
84                    .send_vectored(&[
85                        IoSlice::new(header.as_bytes()),
86                        IoSlice::new(message.as_bytes()),
87                        IoSlice::new(FRAMEWORK_VERSIONS.as_bytes()),
88                        IoSlice::new(message_versions.as_bytes()),
89                    ])
90                    .await
91                    .context("ring buffer error")?;
92
93                *state = NegotiateState::WaitVersion;
94                Ok(None)
95            }
96            NegotiateState::WaitVersion => {
97                let (_result, buf) = self.read_response().await?;
98                let (message, rest) = hyperv_ic_protocol::NegotiateMessage::read_from_prefix(buf)
99                    .ok()
100                    .context("missing negotiate message")?;
101                if message.framework_version_count != 1 || message.message_version_count != 1 {
102                    anyhow::bail!("no supported versions");
103                }
104                let ([framework_version, message_version], _) =
105                    <[Version; 2]>::read_from_prefix(rest)
106                        .ok()
107                        .context("missing version table")?;
108
109                *state = NegotiateState::Invalid;
110                Ok(Some(Versions {
111                    framework_version,
112                    message_version,
113                }))
114            }
115            NegotiateState::Invalid => {
116                unreachable!()
117            }
118        }
119    }
120
121    pub async fn write_message(
122        &mut self,
123        versions: &Versions,
124        message_type: MessageType,
125        flags: HeaderFlags,
126        message: &[u8],
127    ) -> anyhow::Result<()> {
128        let header = hyperv_ic_protocol::Header {
129            framework_version: versions.framework_version,
130            message_type,
131            message_size: message.len() as u16,
132            message_version: versions.message_version,
133            status: Status::SUCCESS,
134            transaction_id: 0,
135            flags,
136            ..FromZeros::new_zeroed()
137        };
138
139        self.pipe
140            .send_vectored(&[IoSlice::new(header.as_bytes()), IoSlice::new(message)])
141            .await
142            .context("ring buffer error")
143    }
144
145    pub async fn read_response(&mut self) -> anyhow::Result<(Status, &[u8])> {
146        let n = self
147            .pipe
148            .recv(&mut self.buf)
149            .await
150            .context("ring buffer error")?;
151        let buf = &self.buf[..n];
152        let (header, rest) = hyperv_ic_protocol::Header::read_from_prefix(buf)
153            .ok()
154            .context("missing header")?;
155
156        if header.transaction_id != 0 || !header.flags.transaction() || !header.flags.response() {
157            anyhow::bail!("invalid transaction response");
158        }
159
160        let rest = rest
161            .get(..header.message_size as usize)
162            .context("missing message body")?;
163
164        Ok((header.status, rest))
165    }
166}