1#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7pub mod protocol;
8
9use futures::FutureExt;
10use futures::StreamExt;
11use guid::Guid;
12use inspect::Inspect;
13use protocol::HEADER_SIZE;
14use protocol::MAX_MESSAGE_SIZE;
15use protocol::MessageHeader;
16use protocol::VmbusMessage;
17use std::future::Future;
18use std::str::FromStr;
19use std::task::Poll;
20use thiserror::Error;
21use zerocopy::Immutable;
22use zerocopy::IntoBytes;
23use zerocopy::KnownLayout;
24
/// A stream adapter that pairs every item of an inner stream with a tag.
///
/// Each item of the inner stream is yielded together with the tag. When the
/// inner stream ends, the tag is yielded one final time with `None`, after
/// which this stream itself ends.
#[derive(Debug)]
pub struct TaggedStream<T, S>(Option<T>, S);

// Construction and tag access need no bounds; only polling (in the `Stream`
// impl) requires `T: Clone` and `S: Stream + Unpin`.
impl<T, S> TaggedStream<T, S> {
    /// Creates a tagged stream that wraps `s` and tags its items with `t`.
    pub fn new(t: T, s: S) -> Self {
        Self(Some(t), s)
    }

    /// Returns the tag, or `None` once the inner stream has ended.
    pub fn value(&self) -> Option<&T> {
        self.0.as_ref()
    }
}
37
38impl<T: Clone, S: futures::Stream + Unpin> futures::Stream for TaggedStream<T, S>
39where
40 Self: Unpin,
41{
42 type Item = (T, Option<S::Item>);
43
44 fn poll_next(
45 self: std::pin::Pin<&mut Self>,
46 cx: &mut std::task::Context<'_>,
47 ) -> Poll<Option<Self::Item>> {
48 let this = self.get_mut();
49 if let Some(t) = this.0.clone() {
50 let v = std::task::ready!(this.1.poll_next_unpin(cx));
51 if v.is_none() {
52 this.0 = None;
54 }
55 Poll::Ready(Some((t, v)))
56 } else {
57 Poll::Ready(None)
58 }
59 }
60}
61
62#[derive(Debug)]
63pub struct TaggedFuture<T, F>(T, F);
64
65impl<T: Clone, F: Future + Unpin> Future for TaggedFuture<T, F>
66where
67 Self: Unpin,
68{
69 type Output = (T, F::Output);
70
71 fn poll(
72 mut self: std::pin::Pin<&mut Self>,
73 cx: &mut std::task::Context<'_>,
74 ) -> Poll<Self::Output> {
75 let r = std::task::ready!(self.1.poll_unpin(cx));
76 Poll::Ready((self.0.clone(), r))
77 }
78}
79
/// A vmbus protocol version paired with a set of feature flags.
///
/// Can be converted into [`MaxVersionInfo`] via `From`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Inspect)]
pub struct VersionInfo {
    /// The protocol version.
    pub version: protocol::Version,
    /// The feature flags associated with `version`.
    pub feature_flags: protocol::FeatureFlags,
}
86
87#[derive(Copy, Clone, Debug)]
89pub struct MaxVersionInfo {
90 pub version: u32,
91 pub feature_flags: protocol::FeatureFlags,
92}
93
94impl MaxVersionInfo {
95 pub fn new(version: u32) -> Self {
96 Self {
97 version,
98 feature_flags: protocol::FeatureFlags::new(),
99 }
100 }
101}
102
103impl From<VersionInfo> for MaxVersionInfo {
104 fn from(info: VersionInfo) -> Self {
105 Self {
106 version: info.version as u32,
107 feature_flags: info.feature_flags,
108 }
109 }
110}
111
112pub fn parse_vmbus_version(value: &str) -> Result<u32, String> {
116 || -> Option<u32> {
117 let (major, minor) = value.split_once('.')?;
118 let major = u16::from_str(major).ok()?;
119 let minor = u16::from_str(minor).ok()?;
120 Some(protocol::make_version(major, minor))
121 }()
122 .ok_or_else(|| format!("invalid vmbus version '{}'", value))
123}
124
125#[derive(Clone, Debug)]
126pub struct OutgoingMessage {
127 data: [u8; MAX_MESSAGE_SIZE],
128 len: u8,
129}
130
131impl OutgoingMessage {
133 pub fn new<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(message: &T) -> Self {
135 let mut data = [0; MAX_MESSAGE_SIZE];
136 let header = MessageHeader::new(T::MESSAGE_TYPE);
137 let message_bytes = message.as_bytes();
138 let len = HEADER_SIZE + message_bytes.len();
139 data[..HEADER_SIZE].copy_from_slice(header.as_bytes());
140 data[HEADER_SIZE..len].copy_from_slice(message_bytes);
141 Self {
142 data,
143 len: len as u8,
144 }
145 }
146
147 pub fn with_data<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(
150 message: &T,
151 data: &[u8],
152 ) -> Self {
153 let mut message = OutgoingMessage::new(message);
154 let old_len = message.len as usize;
155 let len = old_len + data.len();
156 message.data[old_len..len].copy_from_slice(data);
157 message.len = len as u8;
158 message
159 }
160
161 pub fn from_message(message: &[u8]) -> Result<Self, MessageTooLarge> {
164 if message.len() > MAX_MESSAGE_SIZE {
165 return Err(MessageTooLarge);
166 }
167 let mut data = [0; MAX_MESSAGE_SIZE];
168 data[0..message.len()].copy_from_slice(message);
169 Ok(Self {
170 data,
171 len: message.len() as u8,
172 })
173 }
174
175 pub fn data(&self) -> &[u8] {
177 &self.data[..self.len as usize]
178 }
179}
180
181impl PartialEq for OutgoingMessage {
182 fn eq(&self, other: &Self) -> bool {
183 self.len == other.len && self.data[..self.len as usize] == other.data[..self.len as usize]
184 }
185}
186
187#[derive(Debug, Error)]
188#[error("a synic message exceeds the maximum length")]
189pub struct MessageTooLarge;
190
191#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Inspect)]
193pub struct HvsockConnectRequest {
194 pub service_id: Guid,
195 pub endpoint_id: Guid,
196 pub silo_id: Guid,
197 pub hosted_silo_unaware: bool,
198}
199
200impl HvsockConnectRequest {
201 pub fn from_message(value: protocol::TlConnectRequest2, hosted_silo_unaware: bool) -> Self {
202 Self {
203 service_id: value.base.service_id,
204 endpoint_id: value.base.endpoint_id,
205 silo_id: value.silo_id,
206 hosted_silo_unaware,
207 }
208 }
209}
210
211impl From<HvsockConnectRequest> for protocol::TlConnectRequest2 {
212 fn from(value: HvsockConnectRequest) -> Self {
213 Self {
214 base: protocol::TlConnectRequest {
215 endpoint_id: value.endpoint_id,
216 service_id: value.service_id,
217 },
218 silo_id: value.silo_id,
219 }
220 }
221}
222
223#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
225pub struct HvsockConnectResult {
226 pub service_id: Guid,
227 pub endpoint_id: Guid,
228 pub success: bool,
229}
230
231impl HvsockConnectResult {
232 pub fn from_request(request: &HvsockConnectRequest, success: bool) -> Self {
234 Self {
235 service_id: request.service_id,
236 endpoint_id: request.endpoint_id,
237 success,
238 }
239 }
240}
241
242impl From<protocol::TlConnectResult> for HvsockConnectResult {
243 fn from(value: protocol::TlConnectResult) -> Self {
244 Self {
245 service_id: value.service_id,
246 endpoint_id: value.endpoint_id,
247 success: value.status == protocol::STATUS_SUCCESS,
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ChannelId;
    use crate::protocol::GpadlId;

    #[test]
    fn test_outgoing_message() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });

        assert_eq!(&[0x7, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_empty() {
        let message = OutgoingMessage::new(&protocol::Unload {});

        assert_eq!(&[0x10, 0, 0, 0, 0, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_with_data() {
        let message = OutgoingMessage::with_data(
            &protocol::GpadlHeader {
                channel_id: ChannelId(5),
                gpadl_id: GpadlId(1),
                len: 7,
                count: 6,
            },
            &[0xa, 0xb, 0xc, 0xd],
        );

        assert_eq!(
            &[
                0x8, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0, 0x1, 0, 0, 0, 0x7, 0, 0x6, 0, 0xa, 0xb,
                0xc, 0xd
            ],
            message.data()
        )
    }

    /// Re-wrapping a serialized message yields an equal message.
    #[test]
    fn test_outgoing_message_from_message_round_trip() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });
        let copy = OutgoingMessage::from_message(message.data()).unwrap();

        assert_eq!(message.data(), copy.data());
        assert_eq!(message, copy);
    }

    /// `from_message` accepts lengths from zero up to `MAX_MESSAGE_SIZE` and
    /// rejects anything longer.
    #[test]
    fn test_outgoing_message_from_message_limits() {
        assert!(OutgoingMessage::from_message(&[]).unwrap().data().is_empty());

        let max = OutgoingMessage::from_message(&[0xff; MAX_MESSAGE_SIZE]).unwrap();
        assert_eq!(max.data().len(), MAX_MESSAGE_SIZE);

        assert!(matches!(
            OutgoingMessage::from_message(&[0; MAX_MESSAGE_SIZE + 1]),
            Err(MessageTooLarge)
        ));
    }

    /// Appending data past the maximum message size panics.
    #[test]
    #[should_panic]
    fn test_outgoing_message_with_data_too_large() {
        OutgoingMessage::with_data(&protocol::Unload {}, &[0; MAX_MESSAGE_SIZE]);
    }
}