vmbus_core/
lib.rs

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4#![expect(missing_docs)]
5#![forbid(unsafe_code)]
6
7pub mod protocol;
8
9use futures::FutureExt;
10use futures::StreamExt;
11use guid::Guid;
12use inspect::Inspect;
13use protocol::HEADER_SIZE;
14use protocol::MAX_MESSAGE_SIZE;
15use protocol::MessageHeader;
16use protocol::VmbusMessage;
17use std::future::Future;
18use std::str::FromStr;
19use std::task::Poll;
20use thiserror::Error;
21use zerocopy::Immutable;
22use zerocopy::IntoBytes;
23use zerocopy::KnownLayout;
24
/// Synthetic interrupt (SINT) number that VMBus uses in its standard,
/// non-redirected configuration.
pub const VMBUS_SINT: u8 = 2;
27
28#[derive(Debug)]
29pub struct TaggedStream<T, S>(Option<T>, S);
30
31impl<T: Clone, S: futures::Stream + Unpin> TaggedStream<T, S> {
32    pub fn new(t: T, s: S) -> Self {
33        Self(Some(t), s)
34    }
35
36    pub fn value(&self) -> Option<&T> {
37        self.0.as_ref()
38    }
39}
40
41impl<T: Clone, S: futures::Stream + Unpin> futures::Stream for TaggedStream<T, S>
42where
43    Self: Unpin,
44{
45    type Item = (T, Option<S::Item>);
46
47    fn poll_next(
48        self: std::pin::Pin<&mut Self>,
49        cx: &mut std::task::Context<'_>,
50    ) -> Poll<Option<Self::Item>> {
51        let this = self.get_mut();
52        if let Some(t) = this.0.clone() {
53            let v = std::task::ready!(this.1.poll_next_unpin(cx));
54            if v.is_none() {
55                // Return `None` next time poll_next is called.
56                this.0 = None;
57            }
58            Poll::Ready(Some((t, v)))
59        } else {
60            Poll::Ready(None)
61        }
62    }
63}
64
65#[derive(Debug)]
66pub struct TaggedFuture<T, F>(T, F);
67
68impl<T: Clone, F: Future + Unpin> Future for TaggedFuture<T, F>
69where
70    Self: Unpin,
71{
72    type Output = (T, F::Output);
73
74    fn poll(
75        mut self: std::pin::Pin<&mut Self>,
76        cx: &mut std::task::Context<'_>,
77    ) -> Poll<Self::Output> {
78        let r = std::task::ready!(self.1.poll_unpin(cx));
79        Poll::Ready((self.0.clone(), r))
80    }
81}
82
83/// Represents information about a negotiated version.
84#[derive(Copy, Clone, Debug, PartialEq, Eq, Inspect)]
85pub struct VersionInfo {
86    pub version: protocol::Version,
87    pub feature_flags: protocol::FeatureFlags,
88}
89
90/// Represents a constraint on the version or features allowed.
91#[derive(Copy, Clone, Debug)]
92pub struct MaxVersionInfo {
93    pub version: u32,
94    pub feature_flags: protocol::FeatureFlags,
95}
96
97impl MaxVersionInfo {
98    pub fn new(version: u32) -> Self {
99        Self {
100            version,
101            feature_flags: protocol::FeatureFlags::new(),
102        }
103    }
104}
105
106impl From<VersionInfo> for MaxVersionInfo {
107    fn from(info: VersionInfo) -> Self {
108        Self {
109            version: info.version as u32,
110            feature_flags: info.feature_flags,
111        }
112    }
113}
114
115/// Parses a string of the form "major.minor" (e.g "5.3") into a vmbus version number.
116///
117/// N.B. This doesn't check whether the specified version actually exists.
118pub fn parse_vmbus_version(value: &str) -> Result<u32, String> {
119    || -> Option<u32> {
120        let (major, minor) = value.split_once('.')?;
121        let major = u16::from_str(major).ok()?;
122        let minor = u16::from_str(minor).ok()?;
123        Some(protocol::make_version(major, minor))
124    }()
125    .ok_or_else(|| format!("invalid vmbus version '{}'", value))
126}
127
128#[derive(Clone, Debug)]
129pub struct OutgoingMessage {
130    data: [u8; MAX_MESSAGE_SIZE],
131    len: u8,
132}
133
134/// Represents a vmbus message to be sent using the synic.
135impl OutgoingMessage {
136    /// Creates a new `OutgoingMessage` for the specified protocol message.
137    pub fn new<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(message: &T) -> Self {
138        let mut data = [0; MAX_MESSAGE_SIZE];
139        let header = MessageHeader::new(T::MESSAGE_TYPE);
140        let message_bytes = message.as_bytes();
141        let len = HEADER_SIZE + message_bytes.len();
142        data[..HEADER_SIZE].copy_from_slice(header.as_bytes());
143        data[HEADER_SIZE..len].copy_from_slice(message_bytes);
144        Self {
145            data,
146            len: len as u8,
147        }
148    }
149
150    /// Creates a new `OutgoingMessage` for the specified protocol message, including additional
151    /// data at the end of the message.
152    pub fn with_data<T: IntoBytes + Immutable + KnownLayout + VmbusMessage>(
153        message: &T,
154        data: &[u8],
155    ) -> Self {
156        let mut message = OutgoingMessage::new(message);
157        let old_len = message.len as usize;
158        let len = old_len + data.len();
159        message.data[old_len..len].copy_from_slice(data);
160        message.len = len as u8;
161        message
162    }
163
164    /// Converts an existing binary message to an `OutgoingMessage`. The slice
165    /// is assumed to contain a valid message.
166    pub fn from_message(message: &[u8]) -> Result<Self, MessageTooLarge> {
167        if message.len() > MAX_MESSAGE_SIZE {
168            return Err(MessageTooLarge);
169        }
170        let mut data = [0; MAX_MESSAGE_SIZE];
171        data[0..message.len()].copy_from_slice(message);
172        Ok(Self {
173            data,
174            len: message.len() as u8,
175        })
176    }
177
178    /// Gets the binary representation of the message.
179    pub fn data(&self) -> &[u8] {
180        &self.data[..self.len as usize]
181    }
182}
183
184impl PartialEq for OutgoingMessage {
185    fn eq(&self, other: &Self) -> bool {
186        self.len == other.len && self.data[..self.len as usize] == other.data[..self.len as usize]
187    }
188}
189
190#[derive(Debug, Error)]
191#[error("a synic message exceeds the maximum length")]
192pub struct MessageTooLarge;
193
194/// A request from the guest to connect to the specified hvsocket endpoint.
195#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Inspect)]
196pub struct HvsockConnectRequest {
197    pub service_id: Guid,
198    pub endpoint_id: Guid,
199    pub silo_id: Guid,
200    pub hosted_silo_unaware: bool,
201}
202
203impl HvsockConnectRequest {
204    pub fn from_message(value: protocol::TlConnectRequest2, hosted_silo_unaware: bool) -> Self {
205        Self {
206            service_id: value.base.service_id,
207            endpoint_id: value.base.endpoint_id,
208            silo_id: value.silo_id,
209            hosted_silo_unaware,
210        }
211    }
212}
213
214impl From<HvsockConnectRequest> for protocol::TlConnectRequest2 {
215    fn from(value: HvsockConnectRequest) -> Self {
216        Self {
217            base: protocol::TlConnectRequest {
218                endpoint_id: value.endpoint_id,
219                service_id: value.service_id,
220            },
221            silo_id: value.silo_id,
222        }
223    }
224}
225
226/// A notification from the host that a connection request has been handled.
227#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
228pub struct HvsockConnectResult {
229    pub service_id: Guid,
230    pub endpoint_id: Guid,
231    pub success: bool,
232}
233
234impl HvsockConnectResult {
235    /// Create a new result using the service and endpoint ID from the specified request.
236    pub fn from_request(request: &HvsockConnectRequest, success: bool) -> Self {
237        Self {
238            service_id: request.service_id,
239            endpoint_id: request.endpoint_id,
240            success,
241        }
242    }
243}
244
245impl From<protocol::TlConnectResult> for HvsockConnectResult {
246    fn from(value: protocol::TlConnectResult) -> Self {
247        Self {
248            service_id: value.service_id,
249            endpoint_id: value.endpoint_id,
250            success: value.status == protocol::STATUS_SUCCESS,
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::ChannelId;
    use crate::protocol::GpadlId;

    #[test]
    fn test_outgoing_message() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });

        assert_eq!(&[0x7, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_empty() {
        let message = OutgoingMessage::new(&protocol::Unload {});

        assert_eq!(&[0x10, 0, 0, 0, 0, 0, 0, 0], message.data())
    }

    #[test]
    fn test_outgoing_message_with_data() {
        let message = OutgoingMessage::with_data(
            &protocol::GpadlHeader {
                channel_id: ChannelId(5),
                gpadl_id: GpadlId(1),
                len: 7,
                count: 6,
            },
            &[0xa, 0xb, 0xc, 0xd],
        );

        assert_eq!(
            &[
                0x8, 0, 0, 0, 0, 0, 0, 0, 0x5, 0, 0, 0, 0x1, 0, 0, 0, 0x7, 0, 0x6, 0, 0xa, 0xb,
                0xc, 0xd
            ],
            message.data()
        )
    }

    /// Appending data past `MAX_MESSAGE_SIZE` must panic rather than truncate.
    #[test]
    #[should_panic]
    fn test_outgoing_message_with_data_too_large() {
        let _ = OutgoingMessage::with_data(
            &protocol::GpadlHeader {
                channel_id: ChannelId(5),
                gpadl_id: GpadlId(1),
                len: 7,
                count: 6,
            },
            &[0; protocol::MAX_MESSAGE_SIZE],
        );
    }

    /// A message rebuilt from its own bytes compares equal to the original.
    #[test]
    fn test_outgoing_message_from_message_round_trip() {
        let message = OutgoingMessage::new(&protocol::CloseChannel {
            channel_id: ChannelId(5),
        });
        let copy = OutgoingMessage::from_message(message.data()).unwrap();

        assert_eq!(message.data(), copy.data());
        assert_eq!(message, copy);
    }

    /// Exactly `MAX_MESSAGE_SIZE` bytes are accepted; one more is rejected.
    #[test]
    fn test_outgoing_message_from_message_limits() {
        let bytes = [0xab; protocol::MAX_MESSAGE_SIZE];
        let message = OutgoingMessage::from_message(&bytes).unwrap();
        assert_eq!(&bytes[..], message.data());

        assert!(OutgoingMessage::from_message(&[0; protocol::MAX_MESSAGE_SIZE + 1]).is_err());
    }

    #[test]
    fn test_parse_vmbus_version() {
        assert_eq!(parse_vmbus_version("5.3"), Ok(protocol::make_version(5, 3)));
        assert!(parse_vmbus_version("5").is_err());
        assert!(parse_vmbus_version("5.3.1").is_err());
        assert!(parse_vmbus_version("a.b").is_err());
        assert!(parse_vmbus_version("").is_err());
    }
}