vmbus_client/
hvsock.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::OfferInfo;
use inspect::Inspect;
use mesh::rpc::Rpc;
use vmbus_core::HvsockConnectRequest;
use vmbus_core::protocol;

/// Holds the guest-to-host hvsocket connect requests that are still waiting on a host response.
#[derive(Inspect)]
pub(crate) struct HvsockRequestTracker {
    /// Outstanding requests; inspection reports each request's input, keyed by index.
    #[inspect(with = "|requests| inspect::iter_by_index(requests).map_value(|rpc| rpc.input())")]
    pending_requests: Vec<Request>,
}

/// A tracked hvsocket connect RPC: its input is the guest's [`HvsockConnectRequest`] and its
/// reply type is `Option<OfferInfo>`.
pub(crate) type Request = Rpc<HvsockConnectRequest, Option<OfferInfo>>;

impl HvsockRequestTracker {
    /// Creates a new tracker with no pending requests.
    pub fn new() -> Self {
        Self {
            pending_requests: Vec::new(),
        }
    }

    /// Starts tracking `request` until a matching connect result or channel offer arrives.
    pub fn add_request(&mut self, request: Request) {
        self.pending_requests.push(request);
    }

    /// Checks whether a connect result from the host matches a pending request.
    ///
    /// Only failure (negative) statuses are accepted here; a non-negative status is logged as a
    /// protocol violation and ignored without touching any pending request.
    ///
    /// A result matches a request when both its service ID and endpoint ID are equal to the
    /// request's. On a match the request is removed from the tracker and returned so the caller
    /// can complete it; otherwise a warning is logged and `None` is returned.
    pub fn check_result(&mut self, result: &protocol::TlConnectResult) -> Option<Request> {
        if result.status >= 0 {
            tracing::warn!(
                status = result.status,
                "protocol violation: unexpected tl connect result success status"
            );
            return None;
        }

        let Some(index) = self.pending_requests.iter().position(|request| {
            request.input().service_id == result.service_id
                && request.input().endpoint_id == result.endpoint_id
        }) else {
            tracing::warn!(?result, "Result for unknown hvsock request");
            return None;
        };

        // The order of pending requests does not matter, so an O(1) swap_remove is fine.
        Some(self.pending_requests.swap_remove(index))
    }

    /// Checks whether a channel offer from the host answers a pending hvsocket request.
    ///
    /// Only hvsocket offers (`tlnpi_provider` set) whose hvsock parameters are not marked
    /// `is_for_guest_accept` are considered; any other offer yields `None` without logging.
    ///
    /// A considered offer matches a request when the offer's interface ID equals the request's
    /// service ID and its instance ID equals the request's endpoint ID. On a match the request is
    /// removed from the tracker and returned so the caller can complete it; if nothing matches, a
    /// warning is logged and `None` is returned.
    pub fn check_offer(&mut self, offer: &protocol::OfferChannel) -> Option<Request> {
        if !offer.flags.tlnpi_provider() {
            return None;
        }

        let params = offer.user_defined.as_hvsock_params();
        if params.is_for_guest_accept != 0 {
            return None;
        }

        // Since silo_id isn't part of the result message, it doesn't need to be checked here
        // either.
        let Some(index) = self.pending_requests.iter().position(|request| {
            request.input().service_id == offer.interface_id
                && request.input().endpoint_id == offer.instance_id
        }) else {
            tracing::warn!(?offer, "Channel offer for unknown hvsock request");
            return None;
        };

        // The order of pending requests does not matter, so an O(1) swap_remove is fine.
        let rpc = self.pending_requests.swap_remove(index);
        tracing::debug!(request = ?rpc.input(), "channel offer matches hvsocket request");
        Some(rpc)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use guid::Guid;
    use vmbus_core::protocol::HvsockUserDefinedParameters;
    use vmbus_core::protocol::OfferFlags;
    use vmbus_core::protocol::UserDefinedData;
    use zerocopy::FromZeros;

    #[test]
    fn test_check_result() {
        let mut tracker = HvsockRequestTracker::new();
        let request = HvsockConnectRequest {
            service_id: Guid::new_random(),
            endpoint_id: Guid::new_random(),
            silo_id: Guid::new_random(),
            hosted_silo_unaware: false,
        };

        tracker.add_request(Rpc::detached(request));
        assert_eq!(1, tracker.pending_requests.len());

        // Endpoint ID mismatch.
        let result = protocol::TlConnectResult {
            service_id: request.service_id,
            endpoint_id: Guid::new_random(),
            status: -1,
        };
        assert!(tracker.check_result(&result).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Service ID mismatch.
        let result = protocol::TlConnectResult {
            service_id: Guid::new_random(),
            endpoint_id: request.endpoint_id,
            status: -1,
        };
        assert!(tracker.check_result(&result).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Matching IDs with a success status is a protocol violation and must not complete the
        // pending request.
        let result = protocol::TlConnectResult {
            service_id: request.service_id,
            endpoint_id: request.endpoint_id,
            status: 0,
        };
        assert!(tracker.check_result(&result).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Match.
        let result = protocol::TlConnectResult {
            service_id: request.service_id,
            endpoint_id: request.endpoint_id,
            status: -1,
        };
        let found = tracker.check_result(&result).unwrap();
        assert_eq!(*found.input(), request);
        assert_eq!(0, tracker.pending_requests.len());

        // It no longer exists.
        assert!(tracker.check_result(&result).is_none());
    }

    #[test]
    fn test_check_offer() {
        let mut tracker = HvsockRequestTracker::new();
        let request = HvsockConnectRequest {
            service_id: Guid::new_random(),
            endpoint_id: Guid::new_random(),
            silo_id: Guid::new_random(),
            hosted_silo_unaware: false,
        };

        tracker.add_request(Rpc::detached(request));
        assert_eq!(1, tracker.pending_requests.len());

        // Endpoint ID mismatch.
        let offer = create_offer(request.service_id, Guid::new_random(), true, false);
        assert!(tracker.check_offer(&offer).is_none());

        // Service ID mismatch.
        let offer = create_offer(Guid::new_random(), request.endpoint_id, true, false);
        assert!(tracker.check_offer(&offer).is_none());

        // Not a socket request.
        let offer = create_offer(request.service_id, request.endpoint_id, false, false);
        assert!(tracker.check_offer(&offer).is_none());

        // Accept request.
        let offer = create_offer(request.service_id, request.endpoint_id, true, true);
        assert!(tracker.check_offer(&offer).is_none());

        // None of the rejected offers above may have consumed the request.
        assert_eq!(1, tracker.pending_requests.len());

        // Match.
        let offer = create_offer(request.service_id, request.endpoint_id, true, false);
        let found = tracker.check_offer(&offer).unwrap();
        assert_eq!(*found.input(), request);
        assert_eq!(0, tracker.pending_requests.len());

        // It no longer exists.
        let offer = create_offer(request.service_id, request.endpoint_id, true, false);
        assert!(tracker.check_offer(&offer).is_none());
    }

    /// Builds a channel offer with the given IDs, optionally flagged as an hvsocket offer
    /// (`tlnpi_provider`) and optionally marked as being for guest accept.
    fn create_offer(
        interface_id: Guid,
        instance_id: Guid,
        hvsock: bool,
        is_for_guest_accept: bool,
    ) -> protocol::OfferChannel {
        let mut user_defined = UserDefinedData::new_zeroed();
        *user_defined.as_hvsock_params_mut() =
            HvsockUserDefinedParameters::new(is_for_guest_accept, true, Guid::new_random());

        protocol::OfferChannel {
            interface_id,
            instance_id,
            flags: OfferFlags::new()
                .with_enumerate_device_interface(true)
                .with_named_pipe_mode(true)
                .with_tlnpi_provider(hvsock),
            user_defined,
            ..FromZeros::new_zeroed()
        }
    }
}