vmbus_core/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7pub mod protocol;
8
9use futures::FutureExt;
10use futures::StreamExt;
11use guid::Guid;
12use inspect::Inspect;
13use protocol::HEADER_SIZE;
14use protocol::MAX_MESSAGE_SIZE;
15use protocol::MessageHeader;
16use protocol::VmbusMessage;
17use std::future::Future;
18use std::str::FromStr;
19use std::task::Poll;
20use thiserror::Error;
21use zerocopy::Immutable;
22use zerocopy::IntoBytes;
23use zerocopy::KnownLayout;
24
25#[derive(Debug)]
26pub struct TaggedStream<T, S>(Option<T>, S);
27
28impl<T: Clone, S: futures::Stream + Unpin> TaggedStream<T, S> {
29    pub fn new(t: T, s: S) -> Self {
30        Self(Some(t), s)
31    }
32
33    pub fn value(&self) -> Option<&T> {
34        self.0.as_ref()
35    }
36}
37
38impl<T: Clone, S: futures::Stream + Unpin> futures::Stream for TaggedStream<T, S>
39where
40    Self: Unpin,
41{
42    type Item = (T, Option<S::Item>);
43
44    fn poll_next(
45        self: std::pin::Pin<&mut Self>,
46        cx: &mut std::task::Context<'_>,
47    ) -> Poll<Option<Self::Item>> {
48        let this = self.get_mut();
49        if let Some(t) = this.0.clone() {
50            let v = std::task::ready!(this.1.poll_next_unpin(cx));
51            if v.is_none() {
52                // Return `None` next time poll_next is called.
53                this.0 = None;
54            }
55            Poll::Ready(Some((t, v)))
56        } else {
57            Poll::Ready(None)
58        }
59    }
60}
61
62#[derive(Debug)]
63pub struct TaggedFuture<T, F>(T, F);
64
65impl<T: Clone, F: Future + Unpin> Future for TaggedFuture<T, F>
66where
67    Self: Unpin,
68{
69    type Output = (T, F::Output);
70
71    fn poll(
72        mut self: std::pin::Pin<&mut Self>,
73        cx: &mut std::task::Context<'_>,
74    ) -> Poll<Self::Output> {
75        let r = std::task::ready!(self.1.poll_unpin(cx));
76        Poll::Ready((self.0.clone(), r))
77    }
78}
79
80/// Represents information about a negotiated version.
81#[derive(Copy, Clone, Debug, PartialEq, Eq, Inspect)]
82pub struct VersionInfo {
83    pub version: protocol::Version,
84    pub feature_flags: protocol::FeatureFlags,
85}
86
87/// Represents a constraint on the version or features allowed.
88#[derive(Copy, Clone, Debug)]
89pub struct MaxVersionInfo {
90    pub version: u32,
91    pub feature_flags: protocol::FeatureFlags,
92}
93
94impl MaxVersionInfo {
95    pub fn new(version: u32) -> Self {
96        Self {
97            version,
98            feature_flags: protocol::FeatureFlags::new(),
99        }
100    }
101}
102
103impl From<VersionInfo> for MaxVersionInfo {
104    fn from(info: VersionInfo) -> Self {
105        Self {
106            version: info.version as u32,
107            feature_flags: info.feature_flags,
108        }
109    }
110}
111
112/// Parses a string of the form "major.minor" (e.g "5.3") into a vmbus version number.
113///
114/// N.B. This doesn't check whether the specified version actually exists.
115pub fn parse_vmbus_version(value: &str) -> Result<u32, String> {
116    || -> Option<u32> {
117        let (major, minor) = value.split_once('.')?;
118        let major = u16::from_str(major).ok()?;
119        let minor = u16::from_str(minor).ok()?;
120        Some(protocol::make_version(major, minor))
121    }()
122    .ok_or_else(|| format!("invalid vmbus version '{}'", value))
123}
124
125#[derive(Clone, Debug)]
126pub struct OutgoingMessage {
127    data: [u8; MAX_MESSAGE_SIZE],
128    len: u8,
129}
130
131/// Represents a vmbus message to be sent using the synic.
132impl OutgoingMessage {
133    /// Creates a new `OutgoingMessage` for the specified protocol message.
134    pub fn new<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(message: &T) -> Self {
135        let mut data = [0; MAX_MESSAGE_SIZE];
136        let header = MessageHeader::new(T::MESSAGE_TYPE);
137        let message_bytes = message.as_bytes();
138        let len = HEADER_SIZE + message_bytes.len();
139        data[..HEADER_SIZE].copy_from_slice(header.as_bytes());
140        data[HEADER_SIZE..len].copy_from_slice(message_bytes);
141        Self {
142            data,
143            len: len as u8,
144        }
145    }
146
147    /// Creates a new `OutgoingMessage` for the specified protocol message, including additional
148    /// data at the end of the message.
149    pub fn with_data<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(
150        message: &T,
151        data: &[u8],
152    ) -> Self {
153        let mut message = OutgoingMessage::new(message);
154        let old_len = message.len as usize;
155        let len = old_len + data.len();
156        message.data[old_len..len].copy_from_slice(data);
157        message.len = len as u8;
158        message
159    }
160
161    /// Converts an existing binary message to an `OutgoingMessage`. The slice
162    /// is assumed to contain a valid message.
163    pub fn from_message(message: &[u8]) -> Result<Self, MessageTooLarge> {
164        if message.len() > MAX_MESSAGE_SIZE {
165            return Err(MessageTooLarge);
166        }
167        let mut data = [0; MAX_MESSAGE_SIZE];
168        data[0..message.len()].copy_from_slice(message);
169        Ok(Self {
170            data,
171            len: message.len() as u8,
172        })
173    }
174
175    /// Gets the binary representation of the message.
176    pub fn data(&self) -> &[u8] {
177        &self.data[..self.len as usize]
178    }
179}
180
181impl PartialEq for OutgoingMessage {
182    fn eq(&self, other: &Self) -> bool {
183        self.len == other.len && self.data[..self.len as usize] == other.data[..self.len as usize]
184    }
185}
186
187#[derive(Debug, Error)]
188#[error("a synic message exceeds the maximum length")]
189pub struct MessageTooLarge;
190
191/// A request from the guest to connect to the specified hvsocket endpoint.
192#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Inspect)]
193pub struct HvsockConnectRequest {
194    pub service_id: Guid,
195    pub endpoint_id: Guid,
196    pub silo_id: Guid,
197    pub hosted_silo_unaware: bool,
198}
199
200impl HvsockConnectRequest {
201    pub fn from_message(value: protocol::TlConnectRequest2, hosted_silo_unaware: bool) -> Self {
202        Self {
203            service_id: value.base.service_id,
204            endpoint_id: value.base.endpoint_id,
205            silo_id: value.silo_id,
206            hosted_silo_unaware,
207        }
208    }
209}
210
211impl From<HvsockConnectRequest> for protocol::TlConnectRequest2 {
212    fn from(value: HvsockConnectRequest) -> Self {
213        Self {
214            base: protocol::TlConnectRequest {
215                endpoint_id: value.endpoint_id,
216                service_id: value.service_id,
217            },
218            silo_id: value.silo_id,
219        }
220    }
221}
222
223/// A notification from the host that a connection request has been handled.
224#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
225pub struct HvsockConnectResult {
226    pub service_id: Guid,
227    pub endpoint_id: Guid,
228    pub success: bool,
229}
230
231impl HvsockConnectResult {
232    /// Create a new result using the service and endpoint ID from the specified request.
233    pub fn from_request(request: &HvsockConnectRequest, success: bool) -> Self {
234        Self {
235            service_id: request.service_id,
236            endpoint_id: request.endpoint_id,
237            success,
238        }
239    }
240}
241
242impl From<protocol::TlConnectResult> for HvsockConnectResult {
243    fn from(value: protocol::TlConnectResult) -> Self {
244        Self {
245            service_id: value.service_id,
246            endpoint_id: value.endpoint_id,
247            success: value.status == protocol::STATUS_SUCCESS,
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ChannelId;
    use crate::protocol::GpadlId;

    #[test]
    fn test_outgoing_message() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });

        assert_eq!(&[0x7, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_empty() {
        let message = OutgoingMessage::new(&protocol::Unload {});

        assert_eq!(&[0x10, 0, 0, 0, 0, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_with_data() {
        let message = OutgoingMessage::with_data(
            &protocol::GpadlHeader {
                channel_id: ChannelId(5),
                gpadl_id: GpadlId(1),
                len: 7,
                count: 6,
            },
            &[0xa, 0xb, 0xc, 0xd],
        );

        assert_eq!(
            &[
                0x8, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0, 0x1, 0, 0, 0, 0x7, 0, 0x6, 0, 0xa, 0xb,
                0xc, 0xd
            ],
            message.data()
        )
    }

    #[test]
    fn test_outgoing_message_from_message() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });

        // Round-tripping through the binary form yields an equal message.
        let copy = OutgoingMessage::from_message(message.data()).unwrap();
        assert_eq!(copy, message);

        // A message that exactly fills the buffer is accepted...
        let full = OutgoingMessage::from_message(&[0xff; MAX_MESSAGE_SIZE]).unwrap();
        assert_eq!(full.data().len(), MAX_MESSAGE_SIZE);

        // ...but one byte more is rejected.
        assert!(OutgoingMessage::from_message(&[0; MAX_MESSAGE_SIZE + 1]).is_err());
    }

    #[test]
    fn test_parse_vmbus_version() {
        assert_eq!(parse_vmbus_version("5.3"), Ok(protocol::make_version(5, 3)));
        assert_eq!(parse_vmbus_version("0.0"), Ok(protocol::make_version(0, 0)));

        // Missing components, extra components, non-numeric text and values
        // outside the u16 range are all rejected.
        for invalid in ["", "5", "5.", ".3", "5.3.1", "a.b", "65536.0", "-1.0"] {
            assert_eq!(
                parse_vmbus_version(invalid),
                Err(format!("invalid vmbus version '{}'", invalid))
            );
        }
    }
}