use crate::OfferInfo;
use inspect::Inspect;
use mesh::rpc::Rpc;
use vmbus_core::HvsockConnectRequest;
use vmbus_core::protocol;
/// Tracks pending hvsock connect requests until each one is matched either by
/// a failed `TlConnectResult` or by a corresponding channel offer.
#[derive(Inspect)]
pub(crate) struct HvsockRequestTracker {
    /// Requests still awaiting a result or offer. Inspected as an index-keyed
    /// list showing each request's input.
    #[inspect(with = "|requests| inspect::iter_by_index(requests).map_value(|rpc| rpc.input())")]
    pending_requests: Vec<Request>,
}
/// An hvsock connect request RPC: the input is the connect request and the
/// response type is `Option<OfferInfo>` (presumably `None` when the connection
/// attempt fails — confirm against the code that completes these RPCs).
pub(crate) type Request = Rpc<HvsockConnectRequest, Option<OfferInfo>>;
impl HvsockRequestTracker {
    /// Creates a tracker with no pending requests.
    pub fn new() -> Self {
        Self {
            pending_requests: Vec::new(),
        }
    }

    /// Records a connect request so that a later `TlConnectResult` or channel
    /// offer can be matched against it.
    pub fn add_request(&mut self, request: Request) {
        self.pending_requests.push(request);
    }

    /// Matches a `TlConnectResult` message against the pending requests.
    ///
    /// Only failure statuses (negative values) are accepted; a non-negative
    /// status is logged as a protocol violation and ignored. Otherwise, returns
    /// the oldest pending request whose service and endpoint ids match the
    /// result, removing it from the tracker, or logs a warning and returns
    /// `None` if there is no such request.
    pub fn check_result(&mut self, result: &protocol::TlConnectResult) -> Option<Request> {
        if result.status >= 0 {
            tracing::warn!(
                status = result.status,
                "protocol violation: unexpected tl connect result success status"
            );
            return None;
        }

        let Some(rpc) = self.take_request(|rpc| {
            rpc.input().service_id == result.service_id
                && rpc.input().endpoint_id == result.endpoint_id
        }) else {
            tracing::warn!(?result, "Result for unknown hvsock request");
            return None;
        };
        Some(rpc)
    }

    /// Matches a channel offer against the pending requests.
    ///
    /// Offers without the `tlnpi_provider` flag, or whose hvsock parameters
    /// have `is_for_guest_accept` set, are not treated as responses to a
    /// tracked request and return `None` without logging. Otherwise, returns
    /// the oldest pending request whose service id equals the offer's
    /// interface id and whose endpoint id equals the offer's instance id,
    /// removing it from the tracker, or logs a warning and returns `None` if
    /// there is no such request.
    pub fn check_offer(&mut self, offer: &protocol::OfferChannel) -> Option<Request> {
        if !offer.flags.tlnpi_provider() {
            return None;
        }
        let params = offer.user_defined.as_hvsock_params();
        if params.is_for_guest_accept != 0 {
            return None;
        }

        let Some(rpc) = self.take_request(|rpc| {
            rpc.input().service_id == offer.interface_id
                && rpc.input().endpoint_id == offer.instance_id
        }) else {
            tracing::warn!(?offer, "Channel offer for unknown hvsock request");
            return None;
        };
        tracing::debug!(request = ?rpc.input(), "channel offer matches hvsocket request");
        Some(rpc)
    }

    /// Removes and returns the first (oldest) pending request satisfying
    /// `matches`, if any.
    fn take_request(&mut self, matches: impl Fn(&Request) -> bool) -> Option<Request> {
        let index = self.pending_requests.iter().position(matches)?;
        // `remove` (not `swap_remove`) keeps the vector in insertion order, so
        // `position` keeps finding the oldest request when several share ids.
        Some(self.pending_requests.remove(index))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use guid::Guid;
    use vmbus_core::protocol::HvsockUserDefinedParameters;
    use vmbus_core::protocol::OfferFlags;
    use vmbus_core::protocol::UserDefinedData;
    use zerocopy::FromZeros;

    /// Creates a connect request whose ids are all random.
    fn create_request() -> HvsockConnectRequest {
        HvsockConnectRequest {
            service_id: Guid::new_random(),
            endpoint_id: Guid::new_random(),
            silo_id: Guid::new_random(),
            hosted_silo_unaware: false,
        }
    }

    #[test]
    fn test_check_result() {
        let mut tracker = HvsockRequestTracker::new();
        let request = create_request();
        tracker.add_request(Rpc::detached(request));
        assert_eq!(1, tracker.pending_requests.len());

        // A non-negative status is a protocol violation and is rejected even
        // when both ids match.
        let result = protocol::TlConnectResult {
            service_id: request.service_id,
            endpoint_id: request.endpoint_id,
            status: 0,
        };
        assert!(tracker.check_result(&result).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Only the service id matches.
        let result = protocol::TlConnectResult {
            service_id: request.service_id,
            endpoint_id: Guid::new_random(),
            status: -1,
        };
        assert!(tracker.check_result(&result).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Only the endpoint id matches.
        let result = protocol::TlConnectResult {
            service_id: Guid::new_random(),
            endpoint_id: request.endpoint_id,
            status: -1,
        };
        assert!(tracker.check_result(&result).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Both ids match: the request is returned and removed.
        let result = protocol::TlConnectResult {
            service_id: request.service_id,
            endpoint_id: request.endpoint_id,
            status: -1,
        };
        let found = tracker.check_result(&result).unwrap();
        assert_eq!(*found.input(), request);
        assert_eq!(0, tracker.pending_requests.len());

        // The request was consumed, so the same result no longer matches.
        assert!(tracker.check_result(&result).is_none());
    }

    #[test]
    fn test_check_offer() {
        let mut tracker = HvsockRequestTracker::new();
        let request = create_request();
        tracker.add_request(Rpc::detached(request));
        assert_eq!(1, tracker.pending_requests.len());

        // Only the service id matches.
        let offer = create_offer(request.service_id, Guid::new_random(), true, false);
        assert!(tracker.check_offer(&offer).is_none());

        // Only the endpoint id matches.
        let offer = create_offer(Guid::new_random(), request.endpoint_id, true, false);
        assert!(tracker.check_offer(&offer).is_none());

        // Not an hvsocket offer.
        let offer = create_offer(request.service_id, request.endpoint_id, false, false);
        assert!(tracker.check_offer(&offer).is_none());

        // A guest-accept offer is not a response to a connect request.
        let offer = create_offer(request.service_id, request.endpoint_id, true, true);
        assert!(tracker.check_offer(&offer).is_none());
        assert_eq!(1, tracker.pending_requests.len());

        // Matching offer: the request is returned and removed.
        let offer = create_offer(request.service_id, request.endpoint_id, true, false);
        let found = tracker.check_offer(&offer).unwrap();
        assert_eq!(*found.input(), request);
        assert_eq!(0, tracker.pending_requests.len());

        // The request was consumed, so the same offer no longer matches.
        assert!(tracker.check_offer(&offer).is_none());
    }

    /// Builds a channel offer with the given ids and hvsock-related settings.
    fn create_offer(
        interface_id: Guid,
        instance_id: Guid,
        hvsock: bool,
        is_for_guest_accept: bool,
    ) -> protocol::OfferChannel {
        let mut user_defined = UserDefinedData::new_zeroed();
        *user_defined.as_hvsock_params_mut() =
            HvsockUserDefinedParameters::new(is_for_guest_accept, true, Guid::new_random());
        protocol::OfferChannel {
            interface_id,
            instance_id,
            flags: OfferFlags::new()
                .with_enumerate_device_interface(true)
                .with_named_pipe_mode(true)
                .with_tlnpi_provider(hvsock),
            user_defined,
            ..FromZeros::new_zeroed()
        }
    }
}