1use anyhow::Context as _;
7use hyperv_ic_protocol::FRAMEWORK_VERSION_1;
8use hyperv_ic_protocol::FRAMEWORK_VERSION_3;
9use hyperv_ic_protocol::HeaderFlags;
10use hyperv_ic_protocol::MessageType;
11use hyperv_ic_protocol::Status;
12use hyperv_ic_protocol::Version;
13use inspect::Inspect;
14use inspect::InspectMut;
15use std::io::IoSlice;
16use vmbus_async::async_dgram::AsyncRecvExt;
17use vmbus_async::async_dgram::AsyncSendExt;
18use vmbus_async::pipe::MessagePipe;
19use vmbus_channel::RawAsyncChannel;
20use vmbus_channel::gpadl_ring::GpadlRingMem;
21use zerocopy::FromBytes;
22use zerocopy::FromZeros;
23use zerocopy::IntoBytes;
24
/// Framework versions advertised in the version negotiation request sent by
/// `IcPipe::negotiate`.
const FRAMEWORK_VERSIONS: &[Version] = &[FRAMEWORK_VERSION_1, FRAMEWORK_VERSION_3];
27
/// A vmbus message pipe paired with a reusable receive buffer, used to
/// exchange IC protocol messages (header followed by payload).
#[derive(InspectMut)]
pub(crate) struct IcPipe {
    /// The underlying message pipe used for all sends and receives.
    #[inspect(mut)]
    pub pipe: MessagePipe<GpadlRingMem>,
    /// Scratch buffer that incoming messages are received into; allocated
    /// once in `new` with `hyperv_ic_protocol::MAX_MESSAGE_SIZE` bytes.
    #[inspect(skip)]
    buf: Vec<u8>,
}
35
/// Progress of the version negotiation driven by `IcPipe::negotiate`.
#[derive(Inspect, Default)]
pub(crate) enum NegotiateState {
    /// The negotiation request has not been sent yet (initial state).
    #[default]
    SendVersion,
    /// The negotiation request has been sent; the response is still pending.
    WaitVersion,
    /// Negotiation has produced its result; `IcPipe::negotiate` must not be
    /// called again in this state (it panics via `unreachable!`).
    Invalid,
}
43
/// The version pair returned by a completed `IcPipe::negotiate` and stamped
/// into the header of every message sent by `IcPipe::write_message`.
#[derive(Copy, Clone, Debug, Inspect)]
pub(crate) struct Versions {
    /// Framework version read from the negotiation response.
    #[inspect(display)]
    pub framework_version: Version,
    /// Message version read from the negotiation response.
    #[inspect(display)]
    pub message_version: Version,
}
51
52impl IcPipe {
53 pub fn new(raw: RawAsyncChannel<GpadlRingMem>) -> Result<Self, std::io::Error> {
54 let pipe = MessagePipe::new(raw)?;
55 let buf = vec![0; hyperv_ic_protocol::MAX_MESSAGE_SIZE];
56 Ok(Self { pipe, buf })
57 }
58
59 pub async fn negotiate(
60 &mut self,
61 state: &mut NegotiateState,
62 message_versions: &[Version],
63 ) -> anyhow::Result<Option<Versions>> {
64 match state {
65 NegotiateState::SendVersion => {
66 let message = hyperv_ic_protocol::NegotiateMessage {
67 framework_version_count: FRAMEWORK_VERSIONS.len() as u16,
68 message_version_count: message_versions.len() as u16,
69 ..FromZeros::new_zeroed()
70 };
71
72 let header = hyperv_ic_protocol::Header {
73 message_type: MessageType::VERSION_NEGOTIATION,
74 message_size: (size_of_val(&message)
75 + size_of_val(FRAMEWORK_VERSIONS)
76 + size_of_val(message_versions)) as u16,
77 status: Status::SUCCESS,
78 transaction_id: 0,
79 flags: HeaderFlags::new().with_transaction(true).with_request(true),
80 ..FromZeros::new_zeroed()
81 };
82
83 self.pipe
84 .send_vectored(&[
85 IoSlice::new(header.as_bytes()),
86 IoSlice::new(message.as_bytes()),
87 IoSlice::new(FRAMEWORK_VERSIONS.as_bytes()),
88 IoSlice::new(message_versions.as_bytes()),
89 ])
90 .await
91 .context("ring buffer error")?;
92
93 *state = NegotiateState::WaitVersion;
94 Ok(None)
95 }
96 NegotiateState::WaitVersion => {
97 let (_result, buf) = self.read_response().await?;
98 let (message, rest) = hyperv_ic_protocol::NegotiateMessage::read_from_prefix(buf)
99 .ok()
100 .context("missing negotiate message")?;
101 if message.framework_version_count != 1 || message.message_version_count != 1 {
102 anyhow::bail!("no supported versions");
103 }
104 let ([framework_version, message_version], _) =
105 <[Version; 2]>::read_from_prefix(rest)
106 .ok()
107 .context("missing version table")?;
108
109 *state = NegotiateState::Invalid;
110 Ok(Some(Versions {
111 framework_version,
112 message_version,
113 }))
114 }
115 NegotiateState::Invalid => {
116 unreachable!()
117 }
118 }
119 }
120
121 pub async fn write_message(
122 &mut self,
123 versions: &Versions,
124 message_type: MessageType,
125 flags: HeaderFlags,
126 message: &[u8],
127 ) -> anyhow::Result<()> {
128 let header = hyperv_ic_protocol::Header {
129 framework_version: versions.framework_version,
130 message_type,
131 message_size: message.len() as u16,
132 message_version: versions.message_version,
133 status: Status::SUCCESS,
134 transaction_id: 0,
135 flags,
136 ..FromZeros::new_zeroed()
137 };
138
139 self.pipe
140 .send_vectored(&[IoSlice::new(header.as_bytes()), IoSlice::new(message)])
141 .await
142 .context("ring buffer error")
143 }
144
145 pub async fn read_response(&mut self) -> anyhow::Result<(Status, &[u8])> {
146 let n = self
147 .pipe
148 .recv(&mut self.buf)
149 .await
150 .context("ring buffer error")?;
151 let buf = &self.buf[..n];
152 let (header, rest) = hyperv_ic_protocol::Header::read_from_prefix(buf)
153 .ok()
154 .context("missing header")?;
155
156 if header.transaction_id != 0 || !header.flags.transaction() || !header.flags.response() {
157 anyhow::bail!("invalid transaction response");
158 }
159
160 let rest = rest
161 .get(..header.message_size as usize)
162 .context("missing message body")?;
163
164 Ok((header.status, rest))
165 }
166}