1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7pub mod protocol;
8
9use futures::FutureExt;
10use futures::StreamExt;
11use guid::Guid;
12use inspect::Inspect;
13use protocol::HEADER_SIZE;
14use protocol::MAX_MESSAGE_SIZE;
15use protocol::MessageHeader;
16use protocol::VmbusMessage;
17use std::future::Future;
18use std::str::FromStr;
19use std::task::Poll;
20use thiserror::Error;
21use zerocopy::Immutable;
22use zerocopy::IntoBytes;
23use zerocopy::KnownLayout;
24
/// The SynIC synthetic interrupt source (SINT) index used for VMBus.
pub const VMBUS_SINT: u8 = 2;
27
28#[derive(Debug)]
29pub struct TaggedStream<T, S>(Option<T>, S);
30
31impl<T: Clone, S: futures::Stream + Unpin> TaggedStream<T, S> {
32 pub fn new(t: T, s: S) -> Self {
33 Self(Some(t), s)
34 }
35
36 pub fn value(&self) -> Option<&T> {
37 self.0.as_ref()
38 }
39}
40
41impl<T: Clone, S: futures::Stream + Unpin> futures::Stream for TaggedStream<T, S>
42where
43 Self: Unpin,
44{
45 type Item = (T, Option<S::Item>);
46
47 fn poll_next(
48 self: std::pin::Pin<&mut Self>,
49 cx: &mut std::task::Context<'_>,
50 ) -> Poll<Option<Self::Item>> {
51 let this = self.get_mut();
52 if let Some(t) = this.0.clone() {
53 let v = std::task::ready!(this.1.poll_next_unpin(cx));
54 if v.is_none() {
55 this.0 = None;
57 }
58 Poll::Ready(Some((t, v)))
59 } else {
60 Poll::Ready(None)
61 }
62 }
63}
64
65#[derive(Debug)]
66pub struct TaggedFuture<T, F>(T, F);
67
68impl<T: Clone, F: Future + Unpin> Future for TaggedFuture<T, F>
69where
70 Self: Unpin,
71{
72 type Output = (T, F::Output);
73
74 fn poll(
75 mut self: std::pin::Pin<&mut Self>,
76 cx: &mut std::task::Context<'_>,
77 ) -> Poll<Self::Output> {
78 let r = std::task::ready!(self.1.poll_unpin(cx));
79 Poll::Ready((self.0.clone(), r))
80 }
81}
82
/// A VMBus protocol version paired with a set of feature flags.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Inspect)]
pub struct VersionInfo {
    /// The protocol version, as a typed [`protocol::Version`].
    pub version: protocol::Version,
    /// The feature flags associated with `version`.
    pub feature_flags: protocol::FeatureFlags,
}
89
/// A VMBus version limit: a raw `u32` version number plus feature flags.
///
/// Unlike [`VersionInfo`], the version is stored as a plain `u32` rather than a
/// [`protocol::Version`]. NOTE(review): the name suggests this is the highest
/// version to accept or offer during negotiation; confirm against callers.
#[derive(Copy, Clone, Debug)]
pub struct MaxVersionInfo {
    /// The raw version number.
    pub version: u32,
    /// The feature flags that accompany `version`.
    pub feature_flags: protocol::FeatureFlags,
}
96
97impl MaxVersionInfo {
98 pub fn new(version: u32) -> Self {
99 Self {
100 version,
101 feature_flags: protocol::FeatureFlags::new(),
102 }
103 }
104}
105
106impl From<VersionInfo> for MaxVersionInfo {
107 fn from(info: VersionInfo) -> Self {
108 Self {
109 version: info.version as u32,
110 feature_flags: info.feature_flags,
111 }
112 }
113}
114
115pub fn parse_vmbus_version(value: &str) -> Result<u32, String> {
119 || -> Option<u32> {
120 let (major, minor) = value.split_once('.')?;
121 let major = u16::from_str(major).ok()?;
122 let minor = u16::from_str(minor).ok()?;
123 Some(protocol::make_version(major, minor))
124 }()
125 .ok_or_else(|| format!("invalid vmbus version '{}'", value))
126}
127
128#[derive(Clone, Debug)]
129pub struct OutgoingMessage {
130 data: [u8; MAX_MESSAGE_SIZE],
131 len: u8,
132}
133
134impl OutgoingMessage {
136 pub fn new<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(message: &T) -> Self {
138 let mut data = [0; MAX_MESSAGE_SIZE];
139 let header = MessageHeader::new(T::MESSAGE_TYPE);
140 let message_bytes = message.as_bytes();
141 let len = HEADER_SIZE + message_bytes.len();
142 data[..HEADER_SIZE].copy_from_slice(header.as_bytes());
143 data[HEADER_SIZE..len].copy_from_slice(message_bytes);
144 Self {
145 data,
146 len: len as u8,
147 }
148 }
149
150 pub fn with_data<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(
153 message: &T,
154 data: &[u8],
155 ) -> Self {
156 let mut message = OutgoingMessage::new(message);
157 let old_len = message.len as usize;
158 let len = old_len + data.len();
159 message.data[old_len..len].copy_from_slice(data);
160 message.len = len as u8;
161 message
162 }
163
164 pub fn from_message(message: &[u8]) -> Result<Self, MessageTooLarge> {
167 if message.len() > MAX_MESSAGE_SIZE {
168 return Err(MessageTooLarge);
169 }
170 let mut data = [0; MAX_MESSAGE_SIZE];
171 data[0..message.len()].copy_from_slice(message);
172 Ok(Self {
173 data,
174 len: message.len() as u8,
175 })
176 }
177
178 pub fn data(&self) -> &[u8] {
180 &self.data[..self.len as usize]
181 }
182}
183
184impl PartialEq for OutgoingMessage {
185 fn eq(&self, other: &Self) -> bool {
186 self.len == other.len && self.data[..self.len as usize] == other.data[..self.len as usize]
187 }
188}
189
/// Error returned by [`OutgoingMessage::from_message`] when the input is longer
/// than [`protocol::MAX_MESSAGE_SIZE`] bytes.
#[derive(Debug, Error)]
#[error("a synic message exceeds the maximum length")]
pub struct MessageTooLarge;
193
/// An hvsock connect request, combining the fields of a
/// `protocol::TlConnectRequest2` with a caller-supplied flag.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Inspect)]
pub struct HvsockConnectRequest {
    /// Service ID from the request's `base` (`TlConnectRequest`).
    pub service_id: Guid,
    /// Endpoint ID from the request's `base` (`TlConnectRequest`).
    pub endpoint_id: Guid,
    /// Silo ID from `TlConnectRequest2`.
    pub silo_id: Guid,
    /// Flag passed to [`HvsockConnectRequest::from_message`] by the caller; it
    /// is not part of `TlConnectRequest2` and is dropped when converting back.
    pub hosted_silo_unaware: bool,
}
202
203impl HvsockConnectRequest {
204 pub fn from_message(value: protocol::TlConnectRequest2, hosted_silo_unaware: bool) -> Self {
205 Self {
206 service_id: value.base.service_id,
207 endpoint_id: value.base.endpoint_id,
208 silo_id: value.silo_id,
209 hosted_silo_unaware,
210 }
211 }
212}
213
214impl From<HvsockConnectRequest> for protocol::TlConnectRequest2 {
215 fn from(value: HvsockConnectRequest) -> Self {
216 Self {
217 base: protocol::TlConnectRequest {
218 endpoint_id: value.endpoint_id,
219 service_id: value.service_id,
220 },
221 silo_id: value.silo_id,
222 }
223 }
224}
225
/// The outcome of an hvsock connect request.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct HvsockConnectResult {
    /// Service ID of the request this result corresponds to.
    pub service_id: Guid,
    /// Endpoint ID of the request this result corresponds to.
    pub endpoint_id: Guid,
    /// Whether the connection succeeded.
    pub success: bool,
}
233
234impl HvsockConnectResult {
235 pub fn from_request(request: &HvsockConnectRequest, success: bool) -> Self {
237 Self {
238 service_id: request.service_id,
239 endpoint_id: request.endpoint_id,
240 success,
241 }
242 }
243}
244
245impl From<protocol::TlConnectResult> for HvsockConnectResult {
246 fn from(value: protocol::TlConnectResult) -> Self {
247 Self {
248 service_id: value.service_id,
249 endpoint_id: value.endpoint_id,
250 success: value.status == protocol::STATUS_SUCCESS,
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ChannelId;
    use crate::protocol::GpadlId;

    #[test]
    fn test_outgoing_message() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });

        assert_eq!(&[0x7, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_empty() {
        let message = OutgoingMessage::new(&protocol::Unload {});

        assert_eq!(&[0x10, 0, 0, 0, 0, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_with_data() {
        let message = OutgoingMessage::with_data(
            &protocol::GpadlHeader {
                channel_id: ChannelId(5),
                gpadl_id: GpadlId(1),
                len: 7,
                count: 6,
            },
            &[0xa, 0xb, 0xc, 0xd],
        );

        assert_eq!(
            &[
                0x8, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0, 0x1, 0, 0, 0, 0x7, 0, 0x6, 0, 0xa, 0xb,
                0xc, 0xd
            ],
            message.data()
        )
    }

    /// `from_message` on the bytes of an existing message reproduces it.
    #[test]
    fn test_outgoing_message_from_message_round_trip() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });
        let copy = OutgoingMessage::from_message(message.data()).unwrap();
        assert_eq!(copy.data(), message.data());
        assert_eq!(copy, message);
    }

    /// Exactly `MAX_MESSAGE_SIZE` bytes is accepted; one more is rejected.
    #[test]
    fn test_outgoing_message_from_message_limits() {
        let max = [0xffu8; protocol::MAX_MESSAGE_SIZE];
        let message = OutgoingMessage::from_message(&max).unwrap();
        assert_eq!(message.data(), &max[..]);

        let too_large = [0u8; protocol::MAX_MESSAGE_SIZE + 1];
        assert!(OutgoingMessage::from_message(&too_large).is_err());
    }

    /// Equality considers only the valid bytes, including their length.
    #[test]
    fn test_outgoing_message_eq() {
        let short = OutgoingMessage::from_message(&[1, 2]).unwrap();
        let long = OutgoingMessage::from_message(&[1, 2, 0]).unwrap();
        assert_ne!(short, long);
        assert_eq!(short, OutgoingMessage::from_message(&[1, 2]).unwrap());
    }

    #[test]
    fn test_parse_vmbus_version() {
        assert_eq!(parse_vmbus_version("3.0"), Ok(protocol::make_version(3, 0)));
        assert_eq!(parse_vmbus_version("5.3"), Ok(protocol::make_version(5, 3)));
        for bad in ["", "5", "5.", ".3", "5.3.1", "a.b", "65536.0", "5 .3"] {
            assert!(parse_vmbus_version(bad).is_err(), "input: {:?}", bad);
        }
    }
}