mod saved_state;
use crate::Guid;
use crate::SINT;
use crate::SynicMessage;
use crate::monitor::AssignedMonitors;
use crate::protocol::Version;
use hvdef::Vtl;
use inspect::Inspect;
pub use saved_state::RestoreError;
pub use saved_state::SavedState;
use slab::Slab;
use std::cmp::min;
use std::collections::VecDeque;
use std::collections::hash_map::Entry;
use std::collections::hash_map::HashMap;
use std::fmt::Display;
use std::ops::Index;
use std::ops::IndexMut;
use std::task::Poll;
use std::task::ready;
use thiserror::Error;
use vmbus_channel::bus::ChannelType;
use vmbus_channel::bus::GpadlRequest;
use vmbus_channel::bus::OfferKey;
use vmbus_channel::bus::OfferParams;
use vmbus_channel::bus::OpenData;
use vmbus_channel::bus::RestoredGpadl;
use vmbus_core::HvsockConnectRequest;
use vmbus_core::HvsockConnectResult;
use vmbus_core::MaxVersionInfo;
use vmbus_core::MonitorPageGpas;
use vmbus_core::OutgoingMessage;
use vmbus_core::VersionInfo;
use vmbus_core::protocol;
use vmbus_core::protocol::ChannelId;
use vmbus_core::protocol::ConnectionId;
use vmbus_core::protocol::FeatureFlags;
use vmbus_core::protocol::GpadlId;
use vmbus_core::protocol::Message;
use vmbus_core::protocol::OfferFlags;
use vmbus_core::protocol::UserDefinedData;
use vmbus_ring::gparange;
use vmcore::monitor::MonitorId;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
/// Errors returned while handling channel- and GPADL-related operations.
#[derive(Debug, Error)]
pub enum ChannelError {
/// The channel ID is not assigned to any offered channel.
#[error("unknown channel ID")]
UnknownChannelId,
/// No GPADL with the given ID is known.
#[error("unknown GPADL ID")]
UnknownGpadlId,
/// A protocol message could not be parsed.
#[error("parse error")]
ParseError(#[from] protocol::ParseError),
/// A completed GPADL buffer failed GPA range validation.
#[error("invalid gpa range")]
InvalidGpaRange(#[source] gparange::Error),
/// A GPADL with this ID already exists.
#[error("duplicate GPADL ID")]
DuplicateGpadlId,
/// More GPADL data arrived after the GPADL was already complete.
#[error("GPADL is already complete")]
GpadlAlreadyComplete,
/// The GPADL does not belong to the channel referenced by the request.
#[error("GPADL channel ID mismatch")]
WrongGpadlChannelId,
/// An open was requested for a channel that is already open.
#[error("trying to open an open channel")]
ChannelAlreadyOpen,
/// A close was requested for a channel that is not open.
#[error("trying to close a closed channel")]
ChannelNotOpen,
/// The GPADL is not in a state that permits the requested operation.
#[error("invalid GPADL state for operation")]
InvalidGpadlState,
/// The channel is not in a state that permits the requested operation.
#[error("invalid channel state for operation")]
InvalidChannelState,
/// The channel ID refers to a channel the client has already released.
#[error("channel ID has already been released")]
ChannelReleased,
/// Channel offers were already sent on this connection.
#[error("channel offers have already been sent")]
OffersAlreadySent,
/// The operation is not allowed on a reserved channel.
#[error("invalid operation on reserved channel")]
ChannelReserved,
/// The operation is only allowed on a reserved channel.
#[error("invalid operation on non-reserved channel")]
ChannelNotReserved,
/// A message not marked as trusted arrived on a trusted connection.
#[error("received untrusted message for trusted connection")]
UntrustedMessage,
/// A message other than a resume arrived while the connection was paused.
#[error("received a non-resuming message while paused")]
Paused,
}
/// Errors returned when offering a channel or matching an offer against saved state.
#[derive(Debug, Error)]
pub enum OfferError {
/// The channel ID is outside the assignable range (see `AssignedChannels::set`).
#[error("the channel ID {} is not valid for this operation", (.0).0)]
InvalidChannelId(ChannelId),
/// The channel ID is already assigned to another offer.
#[error("the channel ID {} is already in use", (.0).0)]
ChannelIdInUse(ChannelId),
/// An offer with the same key already exists.
#[error("offer {0} already exists")]
AlreadyExists(OfferKey),
/// The new offer's resources differ from those of the saved or revoked offer it replaces.
#[error("specified resources do not match those of the existing saved or revoked offer")]
IncompatibleResources,
/// No more channels can be offered.
#[error("too many channels have been offered")]
TooManyChannels,
/// The monitor ID in saved state differs from the offer's monitor ID.
#[error("mismatched monitor ID from saved state; expected {0:?}, actual {1:?}")]
MismatchedMonitorId(Option<MonitorId>, MonitorId),
}
/// Handle to an offered channel: the channel's key in the `ChannelList` slab.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OfferId(usize);
/// GPADL IDs that are not yet complete (per the name), mapped to the offer
/// they belong to. NOTE(review): population/removal happens outside this view.
type IncompleteGpadlMap = HashMap<GpadlId, OfferId>;
/// All tracked GPADLs, keyed by (GPADL ID, owning offer).
type GpadlMap = HashMap<(GpadlId, OfferId), Gpadl>;
/// Core VMBus server state: connection state, offered channels, channel ID
/// and monitor ID assignments, GPADLs, and queued outgoing messages.
pub struct Server {
/// Current connection state with the client.
state: ConnectionState,
/// All offered channels, indexed by `OfferId`.
channels: ChannelList,
/// Channel ID -> offer assignments.
assigned_channels: AssignedChannels,
/// Monitor IDs handed out to channels.
assigned_monitors: AssignedMonitors,
/// GPADLs keyed by (GPADL ID, owning offer).
gpadls: GpadlMap,
/// GPADLs not yet complete.
incomplete_gpadls: IncompleteGpadlMap,
/// Stored from `Server::new`; its use is outside this view.
child_connection_id: u32,
/// Maximum protocol version to negotiate (set by `set_compatibility_version`).
max_version: Option<MaxVersionInfo>,
/// Maximum version recorded when `set_compatibility_version` is called with
/// `delay == true`; when it takes effect is decided outside this view.
delayed_max_version: Option<MaxVersionInfo>,
/// Outgoing messages queued until `poll_flush_pending_messages` sends them.
pending_messages: PendingMessages,
}
/// A `Server` borrowed together with a `Notifier`, as returned by
/// `Server::with_notifier`; notifier-driven operations live on this type.
pub struct ServerWithNotifier<'a, T> {
inner: &'a mut Server,
notifier: &'a mut T,
}
/// Re-checks the channel invariants (debug builds only; see `Server::validate`)
/// when the borrow ends.
impl<T> Drop for ServerWithNotifier<'_, T> {
fn drop(&mut self) {
self.inner.validate();
}
}
impl<T: Notifier> Inspect for ServerWithNotifier<'_, T> {
    /// Reports the connection state, the action queued behind a pending
    /// disconnect, the monitor ID allocation bitmap, and a `channels` child
    /// holding per-channel state plus each channel's GPADLs.
    fn inspect(&self, req: inspect::Request<'_>) {
        let server = &*self.inner;
        let mut resp = req.respond();

        // Label the connection state and pull out whichever detail it carries.
        let (state, info, next_action) = match &server.state {
            ConnectionState::Disconnected => ("disconnected", None, None),
            ConnectionState::Connecting { info, .. } => ("connecting", Some(info), None),
            ConnectionState::Connected(info) => {
                // `offers_sent` separates a fully connected client from one
                // that has only negotiated a version.
                let label = if info.offers_sent {
                    "connected"
                } else {
                    "negotiated"
                };
                (label, Some(info), None)
            }
            ConnectionState::Disconnecting { next_action, .. } => {
                ("disconnecting", None, Some(next_action))
            }
        };
        resp.field("connection_info", info);

        let next_action = next_action.map(|action| match action {
            ConnectionAction::None => "disconnect",
            ConnectionAction::Reset => "reset",
            ConnectionAction::SendUnloadComplete => "unload",
            ConnectionAction::Reconnect { .. } => "reconnect",
            ConnectionAction::SendFailedVersionResponse => "send_version_response",
        });

        resp.field("state", state)
            .field("next_action", next_action)
            .field(
                "assigned_monitors_bitmap",
                format_args!("{:x}", server.assigned_monitors.bitmap()),
            )
            .child("channels", |req| {
                let mut resp = req.respond();
                server
                    .channels
                    .inspect(self.notifier, server.get_version(), &mut resp);
                // Each GPADL is reported underneath the channel that owns it.
                for (&(gpadl_id, offer_id), gpadl) in &server.gpadls {
                    let path = channel_inspect_path(
                        &server.channels[offer_id].offer,
                        format_args!("/gpadls/{}", gpadl_id.0),
                    );
                    resp.field(&path, gpadl);
                }
            });
    }
}
/// Parameters of the current (or in-progress) client connection.
#[derive(Debug, Copy, Clone, Inspect)]
struct ConnectionInfo {
/// Negotiated protocol version and feature flags.
version: VersionInfo,
/// Whether the connection is trusted (reported by `ConnectionState::is_trusted`).
trusted: bool,
/// Whether channel offers have been sent; inspect reports "negotiated" until they are.
offers_sent: bool,
/// Interrupt page GPA, forwarded in `post_restore`'s modify-connection request.
interrupt_page: Option<u64>,
/// Monitor page GPAs, forwarded in `post_restore`'s modify-connection request.
monitor_page: Option<MonitorPageGpas>,
/// VP targeted for connection messages.
target_message_vp: u32,
/// Used as `notify_relay` when `post_restore` re-issues the connection
/// parameters; presumably set while a modify-connection is in flight — confirm.
modifying: bool,
/// Client identifier.
client_id: Guid,
/// While set, queued messages are neither reported nor flushed.
paused: bool,
}
/// Connection state machine.
#[derive(Debug)]
enum ConnectionState {
/// No client is connected.
Disconnected,
/// Tearing down the connection; `next_action` is performed afterwards.
Disconnecting {
next_action: ConnectionAction,
/// NOTE(review): presumably records whether a modify request was already
/// sent during teardown — usage not visible here.
modify_sent: bool,
},
/// A connection is being established with `info`.
Connecting {
info: ConnectionInfo,
next_action: ConnectionAction,
},
/// A connection is established.
Connected(ConnectionInfo),
}
impl ConnectionState {
    /// Returns true only when connected with a negotiated version of at
    /// least `min_version`.
    fn check_version(&self, min_version: Version) -> bool {
        self.get_version()
            .is_some_and(|negotiated| negotiated.version >= min_version)
    }

    /// Returns true only when connected and `flags` accepts the negotiated
    /// feature flags.
    fn check_feature_flags(&self, flags: impl Fn(FeatureFlags) -> bool) -> bool {
        self.get_version()
            .is_some_and(|negotiated| flags(negotiated.feature_flags))
    }

    /// Returns the negotiated version info, or `None` unless fully connected.
    fn get_version(&self) -> Option<VersionInfo> {
        match self {
            ConnectionState::Connected(info) => Some(info.version),
            _ => None,
        }
    }

    /// Returns the `trusted` flag while connecting or connected; false otherwise.
    fn is_trusted(&self) -> bool {
        match self {
            ConnectionState::Connected(info) | ConnectionState::Connecting { info, .. } => {
                info.trusted
            }
            _ => false,
        }
    }

    /// Returns true only when connected and the connection is paused.
    fn is_paused(&self) -> bool {
        matches!(self, ConnectionState::Connected(info) if info.paused)
    }
}
/// Action to take once the current connection has been torn down (carried by
/// `ConnectionState::Disconnecting` and `ConnectionState::Connecting`).
#[derive(Debug, Copy, Clone)]
enum ConnectionAction {
/// Just disconnect (inspect reports "disconnect").
None,
/// Complete a server reset (requested by `reset`).
Reset,
/// Send an unload-complete response.
SendUnloadComplete,
/// Reconnect using the given initiate-contact request.
Reconnect {
initiate_contact: InitiateContactRequest,
},
/// Send a version response reporting failure.
SendFailedVersionResponse,
}
/// Monitor page part of an initiate-contact request.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub enum MonitorPageRequest {
/// No monitor pages were provided.
None,
/// Monitor page GPAs were provided.
Some(MonitorPageGpas),
/// NOTE(review): presumably the guest supplied unusable monitor page values — confirm at the parser.
Invalid,
}
/// A client's initiate-contact request, as passed to `Notifier::forward_unhandled`
/// and stored in `ConnectionAction::Reconnect`.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
pub struct InitiateContactRequest {
/// Requested protocol version (raw value).
pub version_requested: u32,
/// VP to target for connection messages.
pub target_message_vp: u32,
/// Requested monitor pages.
pub monitor_page: MonitorPageRequest,
/// Requested SINT for messages.
pub target_sint: u8,
/// Requested VTL for messages.
pub target_vtl: u8,
/// Requested feature flags (raw bits).
pub feature_flags: u32,
/// Interrupt page GPA, if any.
pub interrupt_page: Option<u64>,
/// Client identifier.
pub client_id: Guid,
/// Whether the request arrived on a trusted path.
pub trusted: bool,
}
/// A guest request to open a channel.
#[derive(Debug, Copy, Clone)]
pub struct OpenRequest {
/// NOTE(review): presumably echoed back in the open result — confirm.
pub open_id: u32,
/// GPADL describing the ring buffer.
pub ring_buffer_gpadl_id: GpadlId,
/// VP to target for channel interrupts.
pub target_vp: u32,
/// Page offset of the downstream ring within the ring buffer GPADL.
pub downstream_ring_buffer_page_offset: u32,
/// Opaque user data passed through to the device.
pub user_data: UserDefinedData,
/// When present, overrides the default event flag and connection ID
/// (see `OpenParams::from_request`).
pub guest_specified_interrupt_info: Option<SignalInfo>,
/// Open flags; unused bits are cleared before reaching the device.
pub flags: protocol::OpenChannelFlags,
}
/// A three-way change to a single setting: keep it, clear it, or replace it.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Update<T: std::fmt::Debug + Copy + Clone> {
    /// Leave the current value as it is.
    Unchanged,
    /// Clear the current value.
    Reset,
    /// Replace the current value with this one.
    Set(T),
}

impl<T: std::fmt::Debug + Copy + Clone> From<Option<T>> for Update<T> {
    /// Maps `Some(v)` to [`Update::Set`] and `None` to [`Update::Reset`].
    /// This conversion never produces [`Update::Unchanged`].
    fn from(value: Option<T>) -> Self {
        value.map_or(Update::Reset, Update::Set)
    }
}
/// Requested changes to the connection parameters, passed to
/// `Notifier::modify_connection`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct ModifyConnectionRequest {
/// New protocol version (raw value), if it changes.
pub version: Option<u32>,
/// Monitor page change.
pub monitor_page: Update<MonitorPageGpas>,
/// Interrupt page change.
pub interrupt_page: Update<u64>,
/// New target VP for connection messages, if it changes.
pub target_message_vp: Option<u32>,
/// `post_restore` sets this to true when re-applying restored parameters;
/// NOTE(review): exact meaning is defined by the notifier implementation.
pub force: bool,
/// NOTE(review): presumably whether a relay should be notified — confirm with the notifier.
pub notify_relay: bool,
}
/// The default request changes nothing and has `notify_relay` set.
impl Default for ModifyConnectionRequest {
fn default() -> Self {
Self {
version: None,
monitor_page: Update::Unchanged,
interrupt_page: Update::Unchanged,
target_message_vp: None,
force: false,
notify_relay: true,
}
}
}
impl From<protocol::ModifyConnection> for ModifyConnectionRequest {
    /// Builds a request that only touches the monitor pages.
    ///
    /// A zero parent-to-child GPA clears the monitor pages. Any other value
    /// sets both pages from the message. Every other field keeps its
    /// [`Default`] value.
    fn from(value: protocol::ModifyConnection) -> Self {
        let parent_to_child = value.parent_to_child_monitor_page_gpa;
        let monitor_page = if parent_to_child == 0 {
            Update::Reset
        } else {
            Update::Set(MonitorPageGpas {
                parent_to_child,
                child_to_parent: value.child_to_parent_monitor_page_gpa,
            })
        };
        Self {
            monitor_page,
            ..Self::default()
        }
    }
}
/// Outcome of a modify-connection operation.
#[derive(Debug, Copy, Clone)]
pub enum ModifyConnectionResponse {
/// Supported; carries the resulting connection state and the feature flags to report.
Supported(protocol::ConnectionState, FeatureFlags),
/// The modify-connection operation is not supported.
Unsupported,
}
/// Whether a modify operation is in progress on an open channel
/// (stored in `ChannelState::Open`).
#[derive(Debug, Copy, Clone)]
pub enum ModifyState {
    /// No modify operation is in progress.
    NotModifying,
    /// A modify operation is in progress.
    /// NOTE(review): `pending_target_vp` presumably holds a target VP change
    /// requested while modifying — confirm at the call site.
    Modifying { pending_target_vp: Option<u32> },
}

impl ModifyState {
    /// Returns true for any `Modifying` state, whatever its pending target VP.
    pub fn is_modifying(&self) -> bool {
        match self {
            ModifyState::NotModifying => false,
            ModifyState::Modifying { .. } => true,
        }
    }
}
/// Guest-specified interrupt parameters; when present they replace the
/// channel's default event flag and connection ID (see `OpenParams::from_request`).
#[derive(Debug, Copy, Clone)]
pub struct SignalInfo {
pub event_flag: u16,
pub connection_id: u32,
}
/// Progress of a channel through save/restore.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum RestoreState {
/// Not part of the restored state (`complete_reset` returns channels here).
New,
/// In saved state and waiting for `restore_channel`; channels still here at
/// `post_restore` are revoked and marked `Reoffered`.
Restoring,
/// Revoked by `post_restore`. NOTE(review): presumably saved state without a
/// matching live offer — set outside this view.
Unmatched,
/// `restore_channel` has completed for this channel.
Restored,
}
/// Per-channel state machine.
///
/// `ClientReleased`, `ClosingClientRelease` and `OpeningClientRelease` are the
/// released states (`is_released`). Per `Server::validate`, exactly these
/// states have no assigned channel ID. `Revoked` and `Reoffered` are the
/// revoked states (`is_revoked`).
#[derive(Debug, Clone)]
enum ChannelState {
/// Released by the client; holds no channel ID. When connected,
/// `post_restore` assigns a new ID, re-offers the channel and moves it to `Closed`.
ClientReleased,
/// Offered and not open.
Closed,
/// An open is in flight; restoring it as open sends the open result and
/// moves it to `Open`.
Opening {
request: OpenRequest,
/// Present when the channel is being opened as a reserved channel.
reserved_state: Option<ReservedState>,
},
/// Open with `params`.
Open {
params: OpenRequest,
modify_state: ModifyState,
/// Present for reserved channels.
reserved_state: Option<ReservedState>,
},
/// A close is in progress.
Closing {
params: OpenRequest,
reserved_state: Option<ReservedState>,
},
/// A close is in progress and `request` is a follow-up open to issue
/// afterwards (see `restore_channel`).
ClosingReopen {
params: OpenRequest,
request: OpenRequest,
},
/// The offer has been revoked.
Revoked,
/// Revoked and offered again (`post_restore` sets this for channels left in
/// `RestoreState::Restoring`).
Reoffered,
/// Released by the client while a close was in progress (presumably — confirm).
ClosingClientRelease,
/// Released by the client while an open was in progress (presumably — confirm).
OpeningClientRelease,
}
impl ChannelState {
    /// Returns true for the client-released states.
    ///
    /// These are the only states in which a channel holds no channel ID
    /// (enforced by `Server::validate`). The match is deliberately exhaustive,
    /// so a new variant must be classified here.
    fn is_released(&self) -> bool {
        match self {
            ChannelState::ClientReleased
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => true,
            ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered => false,
        }
    }

    /// Returns true for `Revoked` and `Reoffered` (exhaustive by design).
    fn is_revoked(&self) -> bool {
        match self {
            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::Opening { .. }
            | ChannelState::Open { .. }
            | ChannelState::Closing { .. }
            | ChannelState::ClosingReopen { .. }
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => false,
            ChannelState::Revoked | ChannelState::Reoffered => true,
        }
    }

    /// Returns true when an opening, open, or closing channel carries a
    /// reserved state; every other state is never reserved.
    fn is_reserved(&self) -> bool {
        match self {
            ChannelState::Opening { reserved_state, .. }
            | ChannelState::Open { reserved_state, .. }
            | ChannelState::Closing { reserved_state, .. } => reserved_state.is_some(),
            ChannelState::ClientReleased
            | ChannelState::Closed
            | ChannelState::ClosingReopen { .. }
            | ChannelState::Revoked
            | ChannelState::Reoffered
            | ChannelState::ClosingClientRelease
            | ChannelState::OpeningClientRelease => false,
        }
    }
}
impl Display for ChannelState {
    /// Writes the variant name only, without its payload and without applying
    /// width or fill options.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::ClientReleased => "ClientReleased",
            Self::Closed => "Closed",
            Self::Opening { .. } => "Opening",
            Self::Open { .. } => "Open",
            Self::Closing { .. } => "Closing",
            Self::ClosingReopen { .. } => "ClosingReopen",
            Self::Revoked => "Revoked",
            Self::Reoffered => "Reoffered",
            Self::ClosingClientRelease => "ClosingClientRelease",
            Self::OpeningClientRelease => "OpeningClientRelease",
        };
        f.write_str(name)
    }
}
/// Offer parameters as stored by the server, with flags and user-defined data
/// already derived from the channel type (see `From<OfferParams>`).
#[derive(Debug, Clone, Default, mesh::MeshPayload)]
pub struct OfferParamsInternal {
/// Human-readable name, reported by inspect.
pub interface_name: String,
pub instance_id: Guid,
pub interface_id: Guid,
pub mmio_megabytes: u16,
pub mmio_megabytes_optional: u16,
/// Zero for the primary channel; nonzero values identify subchannels.
pub subchannel_index: u16,
/// When set, a monitor ID is allocated from `AssignedMonitors` when the channel is prepared.
pub use_mnf: bool,
pub offer_order: Option<u32>,
pub flags: OfferFlags,
pub user_defined: UserDefinedData,
/// Fixed monitor ID, used only when `use_mnf` is false (see `Channel::prepare_channel`).
pub monitor_id: Option<u8>,
}
impl OfferParamsInternal {
    /// Returns the key that identifies this offer: interface ID, instance ID,
    /// and subchannel index.
    pub fn key(&self) -> OfferKey {
        let Self {
            interface_id,
            instance_id,
            subchannel_index,
            ..
        } = self;
        OfferKey {
            interface_id: *interface_id,
            instance_id: *instance_id,
            subchannel_index: *subchannel_index,
        }
    }
}
impl From<OfferParams> for OfferParamsInternal {
    /// Converts bus-level offer parameters to the internal form, deriving the
    /// offer flags and user-defined data from the channel type.
    ///
    /// Every offer advertises a confidential ring buffer. Confidential
    /// external memory is advertised only when the caller allows it. The
    /// monitor ID is left unset.
    fn from(value: OfferParams) -> Self {
        let base_flags = OfferFlags::new()
            .with_confidential_ring_buffer(true)
            .with_confidential_external_memory(value.allow_confidential_external_memory);

        // Zeroed user-defined data describing a named pipe of the given type.
        let pipe_data = |pipe_type: protocol::PipeType| {
            let mut data = UserDefinedData::new_zeroed();
            data.as_pipe_params_mut().pipe_type = pipe_type;
            data
        };

        let (flags, user_defined) = match value.channel_type {
            ChannelType::Device { pipe_packets } => {
                if pipe_packets {
                    (
                        base_flags.with_named_pipe_mode(true),
                        pipe_data(protocol::PipeType::MESSAGE),
                    )
                } else {
                    (base_flags, UserDefinedData::new_zeroed())
                }
            }
            ChannelType::Interface { user_defined } => {
                (base_flags.with_enumerate_device_interface(true), user_defined)
            }
            ChannelType::Pipe { message_mode } => {
                let pipe_type = if message_mode {
                    protocol::PipeType::MESSAGE
                } else {
                    protocol::PipeType::BYTE
                };
                (
                    base_flags
                        .with_enumerate_device_interface(true)
                        .with_named_pipe_mode(true),
                    pipe_data(pipe_type),
                )
            }
            ChannelType::HvSocket {
                is_connect,
                is_for_container,
                silo_id,
            } => {
                let mut data = UserDefinedData::new_zeroed();
                *data.as_hvsock_params_mut() = protocol::HvsockUserDefinedParameters::new(
                    is_connect,
                    is_for_container,
                    silo_id,
                );
                (
                    base_flags
                        .with_enumerate_device_interface(true)
                        .with_tlnpi_provider(true)
                        .with_named_pipe_mode(true),
                    data,
                )
            }
        };

        Self {
            interface_name: value.interface_name,
            instance_id: value.instance_id,
            interface_id: value.interface_id,
            mmio_megabytes: value.mmio_megabytes,
            mmio_megabytes_optional: value.mmio_megabytes_optional,
            subchannel_index: value.subchannel_index,
            use_mnf: value.use_mnf,
            offer_order: value.offer_order,
            user_defined,
            flags,
            monitor_id: None,
        }
    }
}
/// A VP and SINT to which messages are delivered.
#[derive(Debug, Copy, Clone, Inspect, PartialEq, Eq)]
pub struct ConnectionTarget {
pub vp: u32,
pub sint: u8,
}
/// Where an outgoing message should be delivered (passed to `Notifier::send_message`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MessageTarget {
/// The connection's default target.
Default,
/// The reserved connection target of the given channel.
ReservedChannel(OfferId, ConnectionTarget),
/// An explicitly specified target.
Custom(ConnectionTarget),
}
impl MessageTarget {
    /// Picks the target for messages about `offer_id`: the channel's reserved
    /// connection target when it has a reserved state, otherwise the default.
    pub fn for_offer(offer_id: OfferId, reserved_state: &Option<ReservedState>) -> Self {
        reserved_state.as_ref().map_or(Self::Default, |reserved| {
            Self::ReservedChannel(offer_id, reserved.target)
        })
    }
}
/// State of a channel opened as reserved: its protocol version and message target.
#[derive(Debug, Copy, Clone)]
pub struct ReservedState {
version: VersionInfo,
target: ConnectionTarget,
}
/// A single offered channel.
#[derive(Debug)]
struct Channel {
/// Assigned IDs. Set by `prepare_channel` and cleared by `release_channel`;
/// `Server::validate` requires it to be `Some` exactly when not released.
info: Option<OfferedInfo>,
offer: OfferParamsInternal,
state: ChannelState,
restore_state: RestoreState,
}
/// IDs assigned to a channel while it holds a channel ID.
#[derive(Debug, Copy, Clone)]
struct OfferedInfo {
channel_id: ChannelId,
/// Raw connection ID derived from the channel ID, the VTL and `SINT`.
connection_id: u32,
monitor_id: Option<MonitorId>,
}
impl Channel {
    /// Writes this channel's IDs, state, restore state, and offer details into `resp`.
    ///
    /// Target VP and guest-specified interrupt info are reported only for open
    /// channels. A reserved connection target is reported for opening, open,
    /// and closing channels that have one.
    fn inspect_state(&self, resp: &mut inspect::Response<'_>) {
        let (state, target_vp, interrupt_info, reserved_target) = match &self.state {
            ChannelState::ClientReleased => ("client_released", None, None, None),
            ChannelState::Closed => ("closed", None, None, None),
            ChannelState::Opening { reserved_state, .. } => {
                ("opening", None, None, reserved_state.map(|r| r.target))
            }
            ChannelState::Open {
                params,
                reserved_state,
                ..
            } => (
                "open",
                Some(params.target_vp),
                params.guest_specified_interrupt_info,
                reserved_state.map(|r| r.target),
            ),
            ChannelState::Closing { reserved_state, .. } => {
                ("closing", None, None, reserved_state.map(|r| r.target))
            }
            ChannelState::ClosingReopen { .. } => ("closing_reopen", None, None, None),
            ChannelState::Revoked => ("revoked", None, None, None),
            ChannelState::Reoffered => ("reoffered", None, None, None),
            ChannelState::ClosingClientRelease => ("closing_client_release", None, None, None),
            ChannelState::OpeningClientRelease => ("opening_client_release", None, None, None),
        };
        let event_flag = interrupt_info.map(|signal| signal.event_flag);
        let connection_id = interrupt_info.map(|signal| signal.connection_id);

        let restore_state = match self.restore_state {
            RestoreState::New => "new",
            RestoreState::Restoring => "restoring",
            RestoreState::Restored => "restored",
            RestoreState::Unmatched => "unmatched",
        };

        if let Some(info) = &self.info {
            resp.field("channel_id", info.channel_id.0)
                .field("offered_connection_id", info.connection_id)
                .field("monitor_id", info.monitor_id.map(|id| id.0));
        }
        resp.field("state", state)
            .field("restore_state", restore_state)
            .field("interface_name", self.offer.interface_name.clone())
            .display("instance_id", &self.offer.instance_id)
            .display("interface_id", &self.offer.interface_id)
            .field("mmio_megabytes", self.offer.mmio_megabytes)
            .field("target_vp", target_vp)
            .field("guest_specified_event_flag", event_flag)
            .field("guest_specified_connection_id", connection_id)
            .field("reserved_connection_target", reserved_target)
            .binary("offer_flags", self.offer.flags.into_bits());
    }

    /// Returns the monitor ID the server handles for this channel.
    ///
    /// Only MNF offers (`use_mnf`) that are not reserved qualify. Otherwise,
    /// or if no monitor ID was assigned, this returns `None`.
    fn handled_monitor_id(&self) -> Option<MonitorId> {
        if !self.offer.use_mnf || self.state.is_reserved() {
            return None;
        }
        self.info.and_then(|info| info.monitor_id)
    }

    /// Assigns a channel ID, connection ID, and (optionally) monitor ID to
    /// this channel.
    ///
    /// MNF offers draw a monitor ID from `assigned_monitors`; running out is
    /// logged and leaves the channel without one. Other offers use their fixed
    /// `offer.monitor_id`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the channel already has IDs assigned or no channel ID is free.
    fn prepare_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        assert!(self.info.is_none());
        let channel_id = {
            let entry = assigned_channels
                .allocate()
                .expect("there are enough channel IDs for everything in ChannelList");
            let id = entry.id();
            entry.insert(offer_id);
            id
        };
        let connection_id = ConnectionId::new(channel_id.0, assigned_channels.vtl, SINT).0;
        let monitor_id = if self.offer.use_mnf {
            assigned_monitors.assign_monitor().or_else(|| {
                tracelimit::warn_ratelimited!("Out of monitor IDs.");
                None
            })
        } else {
            self.offer.monitor_id.map(MonitorId)
        };
        self.info = Some(OfferedInfo {
            channel_id,
            connection_id,
            monitor_id,
        });
    }

    /// Returns this channel's IDs to their allocators. This is a no-op if none
    /// are assigned.
    ///
    /// Only monitor IDs of MNF offers are released, because only those were
    /// allocated from `assigned_monitors` in `prepare_channel`.
    fn release_channel(
        &mut self,
        offer_id: OfferId,
        assigned_channels: &mut AssignedChannels,
        assigned_monitors: &mut AssignedMonitors,
    ) {
        let Some(info) = self.info.take() else {
            return;
        };
        assigned_channels.free(info.channel_id, offer_id);
        if let Some(monitor_id) = info.monitor_id.filter(|_| self.offer.use_mnf) {
            assigned_monitors.release_monitor(monitor_id);
        }
    }
}
/// Table mapping channel IDs (`1..=MAX_CHANNELS`) to the offers that hold them.
#[derive(Debug)]
struct AssignedChannels {
/// Slot `i` holds the offer assigned channel ID `i + 1`.
assignments: Vec<Option<OfferId>>,
/// VTL used when forming connection IDs for assigned channels.
vtl: Vtl,
/// Slots below this index are reserved: `allocate` skips them, `set` may claim them.
reserved_offset: usize,
/// Number of currently assigned slots below `reserved_offset`.
count_in_reserved_range: usize,
}
impl AssignedChannels {
    /// Creates an empty assignment table covering channel IDs `1..=MAX_CHANNELS`.
    ///
    /// IDs `1..=channel_id_offset` form a reserved range. [`Self::allocate`]
    /// never hands them out, but [`Self::set`] may claim them explicitly.
    /// `vtl` is stored for use when building connection IDs.
    fn new(vtl: Vtl, channel_id_offset: u16) -> Self {
        Self {
            assignments: vec![None; MAX_CHANNELS],
            vtl,
            reserved_offset: channel_id_offset as usize,
            count_in_reserved_range: 0,
        }
    }

    /// Returns how many channels may be assigned in total.
    ///
    /// This counts every ID at or above the reserved range, plus the reserved
    /// IDs already claimed through [`Self::set`].
    ///
    /// The subtraction saturates, so an offset at or beyond [`MAX_CHANNELS`]
    /// leaves only the explicitly claimed IDs. Without saturation it would
    /// underflow: a panic in debug builds and a huge bogus limit in release.
    fn allowable_channel_count(&self) -> usize {
        MAX_CHANNELS.saturating_sub(self.reserved_offset) + self.count_in_reserved_range
    }

    /// Returns the offer currently assigned `channel_id`, if any.
    ///
    /// IDs outside `1..=MAX_CHANNELS`, including 0, yield `None`.
    fn get(&self, channel_id: ChannelId) -> Option<OfferId> {
        self.assignments
            .get(Self::index(channel_id))
            .copied()
            .flatten()
    }

    /// Prepares to claim the specific `channel_id`, which may lie in the
    /// reserved range. The claim takes effect on [`AssignmentEntry::insert`].
    ///
    /// # Errors
    ///
    /// Returns [`OfferError::InvalidChannelId`] if the ID is outside the table,
    /// or [`OfferError::ChannelIdInUse`] if it is already assigned.
    fn set(&mut self, channel_id: ChannelId) -> Result<AssignmentEntry<'_>, OfferError> {
        let index = Self::index(channel_id);
        if self
            .assignments
            .get(index)
            .ok_or(OfferError::InvalidChannelId(channel_id))?
            .is_some()
        {
            return Err(OfferError::ChannelIdInUse(channel_id));
        }
        Ok(AssignmentEntry { list: self, index })
    }

    /// Finds the lowest free channel ID outside the reserved range.
    ///
    /// Returns `None` when none is free. This includes the case where the
    /// reserved range covers the whole table.
    fn allocate(&mut self) -> Option<AssignmentEntry<'_>> {
        // Use `get` rather than indexing: an offset past the end of the table
        // means there is nothing to allocate, not an out-of-bounds panic.
        let unreserved = self.assignments.get(self.reserved_offset..)?;
        let index = self.reserved_offset + unreserved.iter().position(Option::is_none)?;
        Some(AssignmentEntry { list: self, index })
    }

    /// Frees `channel_id`, which must currently be assigned to `offer_id`.
    ///
    /// # Panics
    ///
    /// Panics if the ID is out of range or not assigned to `offer_id`.
    fn free(&mut self, channel_id: ChannelId, offer_id: OfferId) {
        let index = Self::index(channel_id);
        let slot = &mut self.assignments[index];
        assert_eq!(slot.take(), Some(offer_id));
        if index < self.reserved_offset {
            self.count_in_reserved_range -= 1;
        }
    }

    /// Converts a 1-based channel ID to a slot index.
    ///
    /// ID 0 wraps to a huge value, so checked lookups treat it as out of range.
    fn index(channel_id: ChannelId) -> usize {
        channel_id.0.wrapping_sub(1) as usize
    }
}
/// A free slot found by `AssignedChannels::set` or `allocate`; the ID is only
/// claimed once `insert` is called.
struct AssignmentEntry<'a> {
list: &'a mut AssignedChannels,
index: usize,
}
impl AssignmentEntry<'_> {
/// The channel ID this slot corresponds to (slot index + 1).
pub fn id(&self) -> ChannelId {
ChannelId(self.index as u32 + 1)
}
/// Assigns the slot to `offer_id`, counting it if it lies in the reserved range.
/// Panics if the slot was already occupied.
pub fn insert(self, offer_id: OfferId) {
assert!(
self.list.assignments[self.index]
.replace(offer_id)
.is_none()
);
if self.index < self.list.reserved_offset {
self.list.count_in_reserved_range += 1;
}
}
}
/// All offered channels, stored in a slab whose keys are the `OfferId`s.
struct ChannelList {
channels: Slab<Channel>,
}
/// Builds a channel's inspect path.
///
/// The path is the instance ID, then `/subchannels/<index>` for a nonzero
/// subchannel index, then `suffix`.
fn channel_inspect_path(offer: &OfferParamsInternal, suffix: std::fmt::Arguments<'_>) -> String {
    match offer.subchannel_index {
        0 => format!("{}{suffix}", offer.instance_id),
        index => format!("{}/subchannels/{index}{suffix}", offer.instance_id),
    }
}
impl ChannelList {
    /// Adds one inspect child per channel, named by `channel_inspect_path`.
    ///
    /// Every child reports the channel's own state. For any channel not in
    /// `ChannelState::Revoked`, the notifier may also add details.
    fn inspect(
        &self,
        notifier: &impl Notifier,
        version: Option<VersionInfo>,
        resp: &mut inspect::Response<'_>,
    ) {
        for (offer_id, channel) in self.iter() {
            let path = channel_inspect_path(&channel.offer, format_args!(""));
            let is_revoked = matches!(channel.state, ChannelState::Revoked);
            resp.child(&path, |req| {
                let mut resp = req.respond();
                channel.inspect_state(&mut resp);
                if !is_revoked {
                    notifier.inspect(version, offer_id, resp.request());
                }
            });
        }
    }
}
/// Size of the channel ID table; assignable channel IDs are `1..=MAX_CHANNELS`.
pub const MAX_CHANNELS: usize = 2047;
impl ChannelList {
    /// Creates an empty channel list.
    fn new() -> Self {
        Self {
            channels: Slab::new(),
        }
    }

    /// Returns the number of channels in the list.
    fn len(&self) -> usize {
        self.channels.len()
    }

    /// Adds `new_channel` and returns the `OfferId` that now refers to it.
    fn offer(&mut self, new_channel: Channel) -> OfferId {
        let slot = self.channels.insert(new_channel);
        OfferId(slot)
    }

    /// Removes the channel identified by `offer_id`.
    ///
    /// # Panics
    ///
    /// Panics if `offer_id` is not present, or if the channel still holds
    /// assigned IDs (`info` is `Some`).
    fn remove(&mut self, offer_id: OfferId) {
        let channel = self.channels.remove(offer_id.0);
        assert!(channel.info.is_none());
    }

    /// Resolves `channel_id` to the offer that holds it.
    ///
    /// Returns an error if the ID is unassigned or the channel has been
    /// released by the client. Panics if the channel's recorded ID disagrees
    /// with the assignment table.
    fn lookup_unreleased_offer_id(
        &self,
        assigned_channels: &AssignedChannels,
        channel_id: ChannelId,
    ) -> Result<OfferId, ChannelError> {
        let offer_id = assigned_channels
            .get(channel_id)
            .ok_or(ChannelError::UnknownChannelId)?;
        let channel = &self[offer_id];
        if channel.state.is_released() {
            return Err(ChannelError::ChannelReleased);
        }
        assert_eq!(
            channel.info.as_ref().map(|info| info.channel_id),
            Some(channel_id)
        );
        Ok(offer_id)
    }

    /// Mutable lookup by channel ID; see `lookup_unreleased_offer_id` for errors.
    fn get_by_channel_id_mut(
        &mut self,
        assigned_channels: &AssignedChannels,
        channel_id: ChannelId,
    ) -> Result<(OfferId, &mut Channel), ChannelError> {
        let offer_id = self.lookup_unreleased_offer_id(assigned_channels, channel_id)?;
        Ok((offer_id, &mut self[offer_id]))
    }

    /// Shared lookup by channel ID; see `lookup_unreleased_offer_id` for errors.
    fn get_by_channel_id(
        &self,
        assigned_channels: &AssignedChannels,
        channel_id: ChannelId,
    ) -> Result<(OfferId, &Channel), ChannelError> {
        let offer_id = self.lookup_unreleased_offer_id(assigned_channels, channel_id)?;
        Ok((offer_id, &self[offer_id]))
    }

    /// Finds the channel whose instance ID, interface ID, and subchannel index
    /// all match `key`, using a linear scan.
    fn get_by_key_mut(&mut self, key: &OfferKey) -> Option<(OfferId, &mut Channel)> {
        self.iter_mut().find(|(_, channel)| {
            channel.offer.instance_id == key.instance_id
                && channel.offer.interface_id == key.interface_id
                && channel.offer.subchannel_index == key.subchannel_index
        })
    }

    /// Iterates over all channels with their offer IDs.
    fn iter(&self) -> impl Iterator<Item = (OfferId, &Channel)> {
        self.channels
            .iter()
            .map(|(slot, channel)| (OfferId(slot), channel))
    }

    /// Mutably iterates over all channels with their offer IDs.
    fn iter_mut(&mut self) -> impl Iterator<Item = (OfferId, &mut Channel)> {
        self.channels
            .iter_mut()
            .map(|(slot, channel)| (OfferId(slot), channel))
    }

    /// Keeps only the channels for which `f` returns true.
    ///
    /// # Panics
    ///
    /// Panics if `f` drops a channel that still holds assigned IDs.
    fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(OfferId, &mut Channel) -> bool,
    {
        self.channels.retain(|slot, channel| {
            let keep = f(OfferId(slot), channel);
            if !keep {
                assert!(channel.info.is_none());
            }
            keep
        })
    }
}
/// Indexes the list by `OfferId`; panics if the ID is not present.
impl Index<OfferId> for ChannelList {
type Output = Channel;
fn index(&self, offer_id: OfferId) -> &Self::Output {
&self.channels[offer_id.0]
}
}
/// Mutable indexing by `OfferId`; panics if the ID is not present.
impl IndexMut<OfferId> for ChannelList {
fn index_mut(&mut self, offer_id: OfferId) -> &mut Self::Output {
&mut self.channels[offer_id.0]
}
}
/// A guest physical address descriptor list (GPADL) being built or in use.
#[derive(Debug, Inspect)]
struct Gpadl {
/// Range count, passed to the GPA range validation.
count: u16,
/// Raw u64 entries. `append` treats the capacity as the expected total length.
#[inspect(skip)]
buf: Vec<u64>,
state: GpadlState,
}
/// Lifecycle of a GPADL.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Inspect)]
enum GpadlState {
/// Still receiving data via `append`.
InProgress,
/// Complete; `post_restore` re-sends `Action::Gpadl` for GPADLs in this state.
Offered,
/// Complete, with teardown requested before acceptance (presumably); must not
/// exist at `post_restore`.
OfferedTearingDown,
/// Accepted (reported as `accepted` by `channel_gpadls`).
Accepted,
/// Being torn down; `post_restore` re-sends `Action::TeardownGpadl`.
TearingDown,
}
impl Gpadl {
/// Creates an in-progress GPADL expecting `len` u64 entries.
///
/// NOTE(review): `append` uses `buf.capacity()` as the expected length, but
/// `Vec::with_capacity` only guarantees capacity of *at least* `len`.
/// Consider storing the expected length explicitly.
fn new(count: u16, len: usize) -> Self {
Self {
state: GpadlState::InProgress,
count,
buf: Vec::with_capacity(len),
}
}
/// Appends the u64 entries contained in `data`.
///
/// Only whole 8-byte words are consumed, and no more than fit in the
/// remaining capacity; any excess bytes are ignored.
///
/// Returns `Ok(true)` once the buffer is full and its ranges validate, and
/// moves the GPADL to `Offered`. Returns `Ok(false)` while more data is
/// expected.
///
/// # Errors
///
/// `GpadlAlreadyComplete` if not `InProgress`; `InvalidGpaRange` if the
/// completed buffer fails validation (the state is then left unchanged).
fn append(&mut self, data: &[u8]) -> Result<bool, ChannelError> {
if self.state == GpadlState::InProgress {
let buf = &mut self.buf;
// Clamp to whole words and to the space left before the expected length.
let len = min(data.len() & !7, (buf.capacity() - buf.len()) * 8);
let data = &data[..len];
let start = buf.len();
buf.resize(buf.len() + data.len() / 8, 0);
buf[start..].as_mut_bytes().copy_from_slice(data);
Ok(if buf.len() == buf.capacity() {
gparange::MultiPagedRangeBuf::<Vec<u64>>::validate(self.count as usize, buf)
.map_err(ChannelError::InvalidGpaRange)?;
self.state = GpadlState::Offered;
true
} else {
false
})
} else {
Err(ChannelError::GpadlAlreadyComplete)
}
}
}
/// Parameters handed to the device when a channel is opened (see `Action::Open`).
#[derive(Debug, Copy, Clone)]
pub struct OpenParams {
pub open_data: OpenData,
/// Effective connection ID (guest-specified or the channel's own).
pub connection_id: u32,
/// Effective event flag (guest-specified or the channel ID).
pub event_flag: u16,
/// Monitor ID handled by the server for this channel, if any.
pub monitor_id: Option<MonitorId>,
/// Open flags with unused bits cleared.
pub flags: protocol::OpenChannelFlags,
/// Reserved connection target, for reserved channels.
pub reserved_target: Option<ConnectionTarget>,
}
impl OpenParams {
    /// Builds open parameters from a guest open request.
    ///
    /// The guest-specified event flag and connection ID are used when present.
    /// Otherwise the event flag defaults to the channel ID and the connection
    /// ID to the one assigned at offer time. Unused open-flag bits are cleared.
    fn from_request(
        info: &OfferedInfo,
        request: &OpenRequest,
        monitor_id: Option<MonitorId>,
        reserved_target: Option<ConnectionTarget>,
    ) -> Self {
        let (event_flag, connection_id) = request.guest_specified_interrupt_info.map_or(
            (info.channel_id.0 as u16, info.connection_id),
            |signal| (signal.event_flag, signal.connection_id),
        );
        let open_data = OpenData {
            target_vp: request.target_vp,
            ring_offset: request.downstream_ring_buffer_page_offset,
            ring_gpadl_id: request.ring_buffer_gpadl_id,
            user_data: request.user_data,
            event_flag,
            connection_id,
        };
        Self {
            open_data,
            connection_id,
            event_flag,
            monitor_id,
            flags: request.flags.with_unused(0),
            reserved_target,
        }
    }
}
/// Work the server asks the device side to perform, via `Notifier::notify`.
#[derive(Debug)]
pub enum Action {
/// Open the channel with these parameters at the negotiated version.
Open(OpenParams, VersionInfo),
/// Close the channel.
Close,
/// Create a GPADL: (ID, range count, raw u64 buffer).
Gpadl(GpadlId, u16, Vec<u64>),
/// Tear down a GPADL; `post_restore` is true when re-issued after a restore.
TeardownGpadl {
gpadl_id: GpadlId,
post_restore: bool,
},
/// Retarget the channel to a new VP (presumably — usage not visible here).
Modify {
target_vp: u32,
},
}
/// Protocol versions this server can negotiate, listed from `V1` through `Copper`.
static SUPPORTED_VERSIONS: &[Version] = &[
Version::V1,
Version::Win7,
Version::Win8,
Version::Win8_1,
Version::Win10,
Version::Win10Rs3_0,
Version::Win10Rs3_1,
Version::Win10Rs4,
Version::Win10Rs5,
Version::Iron,
Version::Copper,
];
/// Feature flags this server supports.
const SUPPORTED_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
.with_guest_specified_signal_parameters(true)
.with_channel_interrupt_redirection(true)
.with_modify_connection(true)
.with_client_id(true)
.with_pause_resume(true);
/// Callbacks through which the server drives devices and the transport.
pub trait Notifier: Send {
/// Delivers `action` for the channel identified by `offer_id`.
fn notify(&mut self, offer_id: OfferId, action: Action);
/// Hands off an initiate-contact request the server does not handle itself
/// (presumably to another server — confirm with implementations).
fn forward_unhandled(&mut self, request: InitiateContactRequest);
/// Applies connection parameter changes; errors propagate from `post_restore`.
fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()>;
/// Adds device-specific inspect data for a channel. Does nothing by default.
fn inspect(&self, version: Option<VersionInfo>, offer_id: OfferId, req: inspect::Request<'_>) {
let _ = (version, offer_id, req);
}
/// Sends `message` to `target`. NOTE(review): presumably returns false if the
/// message could not be sent immediately — the result must be handled.
#[must_use]
fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool;
/// Reports an hvsocket connect request.
fn notify_hvsock(&mut self, request: &HvsockConnectRequest);
/// Signals that a requested reset has completed.
fn reset_complete(&mut self);
/// Signals that an unload has completed.
fn unload_complete(&mut self);
}
impl Server {
pub fn new(vtl: Vtl, child_connection_id: u32, channel_id_offset: u16) -> Self {
Server {
state: ConnectionState::Disconnected,
channels: ChannelList::new(),
assigned_channels: AssignedChannels::new(vtl, channel_id_offset),
assigned_monitors: AssignedMonitors::new(),
gpadls: Default::default(),
incomplete_gpadls: Default::default(),
child_connection_id,
max_version: None,
delayed_max_version: None,
pending_messages: PendingMessages(VecDeque::new()),
}
}
pub fn with_notifier<'a, T: Notifier>(
&'a mut self,
notifier: &'a mut T,
) -> ServerWithNotifier<'a, T> {
self.validate();
ServerWithNotifier {
inner: self,
notifier,
}
}
fn validate(&self) {
#[cfg(debug_assertions)]
for (_, channel) in self.channels.iter() {
let should_have_info = !channel.state.is_released();
if channel.info.is_some() != should_have_info {
panic!("channel invariant violation: {channel:?}");
}
}
}
pub fn set_compatibility_version(&mut self, version: MaxVersionInfo, delay: bool) {
if delay {
self.delayed_max_version = Some(version)
} else {
tracing::info!(?version, "Limiting VmBus connections to version");
self.max_version = Some(version);
}
}
pub fn channel_gpadls(&self, offer_id: OfferId) -> Vec<RestoredGpadl> {
self.gpadls
.iter()
.filter_map(|(&(gpadl_id, gpadl_offer_id), gpadl)| {
if offer_id != gpadl_offer_id {
return None;
}
let accepted = match gpadl.state {
GpadlState::Offered | GpadlState::OfferedTearingDown => false,
GpadlState::Accepted => true,
GpadlState::InProgress | GpadlState::TearingDown => return None,
};
Some(RestoredGpadl {
request: GpadlRequest {
id: gpadl_id,
count: gpadl.count,
buf: gpadl.buf.clone(),
},
accepted,
})
})
.collect()
}
pub fn get_version(&self) -> Option<VersionInfo> {
self.state.get_version()
}
pub fn get_restore_open_params(&self, offer_id: OfferId) -> Result<OpenParams, RestoreError> {
let channel = &self.channels[offer_id];
match channel.restore_state {
RestoreState::New => {
return Err(RestoreError::MissingChannel(channel.offer.key()));
}
RestoreState::Restoring => {}
RestoreState::Unmatched => unreachable!(),
RestoreState::Restored => {
return Err(RestoreError::AlreadyRestored(channel.offer.key()));
}
}
let info = channel
.info
.ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
let (request, reserved_state) = match channel.state {
ChannelState::Closed => {
return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
}
ChannelState::Closing { params, .. } | ChannelState::ClosingReopen { params, .. } => {
(params, None)
}
ChannelState::Opening {
request,
reserved_state,
} => (request, reserved_state),
ChannelState::Open {
params,
reserved_state,
..
} => (params, reserved_state),
ChannelState::ClientReleased | ChannelState::Reoffered => {
return Err(RestoreError::MissingChannel(channel.offer.key()));
}
ChannelState::Revoked
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
};
Ok(OpenParams::from_request(
&info,
&request,
channel.handled_monitor_id(),
reserved_state.map(|state| state.target),
))
}
pub fn has_pending_messages(&self) -> bool {
!self.pending_messages.0.is_empty() && !self.state.is_paused()
}
pub fn poll_flush_pending_messages(
&mut self,
mut send: impl FnMut(&OutgoingMessage) -> Poll<()>,
) -> Poll<()> {
if !self.state.is_paused() {
while let Some(message) = self.pending_messages.0.front() {
ready!(send(message));
self.pending_messages.0.pop_front();
}
}
Poll::Ready(())
}
}
impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> {
/// Reconciles one channel's saved state with whether the device reports it open.
///
/// `open` is whether the caller considers the channel open (presumably from the
/// device's own restored state — confirm at call sites). Channels not part of
/// the saved state (`RestoreState::New`) are left alone when `!open`. On
/// success the channel is marked `RestoreState::Restored`.
///
/// # Errors
///
/// - `MissingChannel` if the channel is absent from saved state but `open`,
///   has no assigned IDs, or is in a state that cannot be restored.
/// - `AlreadyRestored` if called twice for the same channel.
/// - `DuplicateMonitorId` if its monitor ID was already claimed.
/// - `MismatchedOpenState` if the saved state and `open` disagree.
pub fn restore_channel(&mut self, offer_id: OfferId, open: bool) -> Result<(), RestoreError> {
let channel = &mut self.inner.channels[offer_id];
match channel.restore_state {
RestoreState::New => {
if open {
return Err(RestoreError::MissingChannel(channel.offer.key()));
} else {
return Ok(());
}
}
RestoreState::Restoring => {}
RestoreState::Unmatched => unreachable!(),
RestoreState::Restored => {
return Err(RestoreError::AlreadyRestored(channel.offer.key()));
}
}
let info = channel
.info
.ok_or_else(|| RestoreError::MissingChannel(channel.offer.key()))?;
// NOTE(review): the monitor ID is claimed before the state checks below; if
// one of them fails, this claim is not undone here. Confirm callers abandon
// the whole restore on error.
if let Some(monitor_id) = channel.handled_monitor_id() {
if !self.inner.assigned_monitors.claim_monitor(monitor_id) {
return Err(RestoreError::DuplicateMonitorId(monitor_id.0));
}
}
if open {
match channel.state {
ChannelState::Closed => {
return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
}
// A close was in flight: ask the device to close.
ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
self.notifier.notify(offer_id, Action::Close);
}
// The device finished opening: report success to the guest and mark open.
ChannelState::Opening {
request,
reserved_state,
} => {
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused())
.send_open_result(
info.channel_id,
&request,
protocol::STATUS_SUCCESS,
MessageTarget::for_offer(offer_id, &reserved_state),
);
channel.state = ChannelState::Open {
params: request,
modify_state: ModifyState::NotModifying,
reserved_state,
};
}
ChannelState::Open { .. } => {}
ChannelState::ClientReleased | ChannelState::Reoffered => {
return Err(RestoreError::MissingChannel(channel.offer.key()));
}
ChannelState::Revoked
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
};
} else {
match channel.state {
ChannelState::Closed => {}
ChannelState::Reoffered => {}
// The device already closed: finish the close.
ChannelState::Closing { .. } => {
channel.state = ChannelState::Closed;
}
// Closed on the device side; issue the queued reopen (without a reserved target).
ChannelState::ClosingReopen { request, .. } => {
self.notifier.notify(
offer_id,
Action::Open(
OpenParams::from_request(
&info,
&request,
channel.handled_monitor_id(),
None,
),
self.inner.state.get_version().expect("must be connected"),
),
);
channel.state = ChannelState::Opening {
request,
reserved_state: None,
};
}
// The device did not complete the open: re-issue it.
ChannelState::Opening {
request,
reserved_state,
} => {
self.notifier.notify(
offer_id,
Action::Open(
OpenParams::from_request(
&info,
&request,
channel.handled_monitor_id(),
reserved_state.map(|state| state.target),
),
self.inner.state.get_version().expect("must be connected"),
),
);
}
ChannelState::Open { .. } => {
return Err(RestoreError::MismatchedOpenState(channel.offer.key()));
}
ChannelState::ClientReleased => {
return Err(RestoreError::MissingChannel(channel.offer.key()));
}
ChannelState::Revoked
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
}
}
channel.restore_state = RestoreState::Restored;
Ok(())
}
/// Finishes a restore after all `restore_channel` calls have been made.
///
/// It runs in three stages:
/// - Channels that were never restored are revoked (`Restoring` channels also
///   become `Reoffered`). When connected, client-released new channels are
///   re-offered.
/// - Pending GPADL creations and teardowns are re-issued to the notifier.
/// - The connection parameters are re-applied with `force` set.
///
/// # Errors
///
/// Propagates a failure from `Notifier::modify_connection`.
pub fn post_restore(&mut self) -> Result<(), RestoreError> {
for (offer_id, channel) in self.inner.channels.iter_mut() {
match channel.restore_state {
RestoreState::Restored => {
// Already handled by restore_channel; nothing left to do.
}
RestoreState::New => {
// Not in saved state: if connected and the channel has no ID,
// assign one and offer it to the guest.
if let ConnectionState::Connected(info) = &self.inner.state {
if matches!(channel.state, ChannelState::ClientReleased) {
channel.prepare_channel(
offer_id,
&mut self.inner.assigned_channels,
&mut self.inner.assigned_monitors,
);
channel.state = ChannelState::Closed;
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused())
.send_offer(channel, info.version);
}
}
}
RestoreState::Restoring => {
// In saved state but never restored: revoke it and mark it reoffered.
let retain = revoke(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
&mut self.inner.gpadls,
);
assert!(retain, "channel has not been released");
channel.state = ChannelState::Reoffered;
}
RestoreState::Unmatched => {
let retain = revoke(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
&mut self.inner.gpadls,
);
assert!(retain, "channel has not been released");
}
}
}
// Re-issue GPADL operations that were in flight when state was saved.
for (&(gpadl_id, offer_id), gpadl) in self.inner.gpadls.iter_mut() {
match gpadl.state {
GpadlState::InProgress | GpadlState::Accepted => {}
GpadlState::Offered => {
self.notifier.notify(
offer_id,
Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
);
}
GpadlState::TearingDown => {
self.notifier.notify(
offer_id,
Action::TeardownGpadl {
gpadl_id,
post_restore: true,
},
);
}
GpadlState::OfferedTearingDown => unreachable!(),
}
}
// Re-apply the connection parameters. The version is included only while
// still connecting.
let request = match self.inner.state {
ConnectionState::Connecting {
info,
next_action: _,
} => Some(ModifyConnectionRequest {
version: Some(info.version.version as u32),
interrupt_page: info.interrupt_page.into(),
monitor_page: info.monitor_page.into(),
target_message_vp: Some(info.target_message_vp),
force: true,
notify_relay: true,
}),
ConnectionState::Connected(info) => Some(ModifyConnectionRequest {
version: None,
monitor_page: info.monitor_page.into(),
interrupt_page: info.interrupt_page.into(),
target_message_vp: Some(info.target_message_vp),
force: true,
notify_relay: info.modifying,
}),
ConnectionState::Disconnected | ConnectionState::Disconnecting { .. } => None,
};
if let Some(request) = request {
self.notifier.modify_connection(request)?;
}
self.check_disconnected();
Ok(())
}
/// Resets the connection, releasing every channel including reserved ones.
///
/// If the disconnect finishes synchronously, the reset is completed right
/// away. Otherwise `ConnectionAction::Reset` is queued and runs once the
/// pending disconnect completes.
///
/// # Panics
/// Panics if a reset is already in progress.
pub fn reset(&mut self) {
    assert!(!self.is_resetting());
    let disconnected = self.request_disconnect(ConnectionAction::Reset);
    if disconnected {
        self.complete_reset();
    }
}
/// Finishes a reset. Every channel is marked `RestoreState::New`, queued
/// outgoing messages are discarded, and the notifier is told the reset is
/// complete.
fn complete_reset(&mut self) {
    self.inner
        .channels
        .iter_mut()
        .for_each(|(_, channel)| channel.restore_state = RestoreState::New);
    self.inner.pending_messages.0.clear();
    self.notifier.reset_complete();
}
/// Offers a channel to the guest, or attaches the offer to an existing
/// channel with the same key.
///
/// If a channel with the same key already exists:
/// - An `Unmatched` channel from the saved state is matched to this offer
///   and its `restore_state` becomes `Restoring`. This happens only after
///   checking that any host-specified monitor ID agrees with the restored
///   one.
/// - A `Revoked` channel the guest has not released yet is marked
///   `Reoffered`, so it is offered again after the guest releases it.
/// - Anything else is rejected with `OfferError::AlreadyExists`.
///
/// A brand-new channel is offered immediately if the guest is connected and
/// has already requested offers. Otherwise it starts in `ClientReleased`.
///
/// # Errors
/// - `AlreadyExists` if a live channel with the same key exists.
/// - `MismatchedMonitorId` if the host-specified monitor ID differs from the
///   restored one. The restored channel is left `Unmatched` in that case.
/// - `TooManyChannels` if the channel limit has been reached.
pub fn offer_channel(&mut self, offer: OfferParamsInternal) -> Result<OfferId, OfferError> {
    if let Some((offer_id, channel)) = self.inner.channels.get_by_key_mut(&offer.key()) {
        // Only an unmatched restored channel, or a revoked channel the guest
        // still holds, may take over an existing key.
        if channel.restore_state != RestoreState::Unmatched
            && !matches!(channel.state, ChannelState::Revoked)
        {
            return Err(OfferError::AlreadyExists(offer.key()));
        }
        let info = channel.info.expect("assigned");
        if channel.restore_state == RestoreState::Unmatched {
            assert!(!matches!(channel.state, ChannelState::Revoked));
            // Validate before mutating any state. If the channel were marked
            // `Restoring` and this check then failed, post_restore would mark
            // it `Reoffered`. The stale offer would later be re-sent to the
            // guest even though this call reported failure. Leaving the
            // channel `Unmatched` makes post_restore revoke it instead.
            if let Some(monitor_id) = offer.monitor_id {
                if info.monitor_id != Some(MonitorId(monitor_id)) {
                    return Err(OfferError::MismatchedMonitorId(
                        info.monitor_id,
                        MonitorId(monitor_id),
                    ));
                }
            }
            tracing::debug!(
                offer_id = offer_id.0,
                key = %channel.offer.key(),
                "matched channel"
            );
            channel.restore_state = RestoreState::Restoring;
        } else {
            // Revoked but not yet released by the guest. Keep the new offer so
            // it can be sent again once the guest releases the old channel.
            channel.state = ChannelState::Reoffered;
            tracing::info!(?offer_id, key = %channel.offer.key(), "channel marked for reoffer");
        }
        channel.offer = offer;
        return Ok(offer_id);
    }

    // A new channel. Offer it right away only if the guest has already
    // requested offers on the current connection.
    let mut connected_version = None;
    let state = match self.inner.state {
        ConnectionState::Connected(ConnectionInfo {
            offers_sent: true,
            version,
            ..
        }) => {
            connected_version = Some(version);
            ChannelState::Closed
        }
        ConnectionState::Connected(ConnectionInfo {
            offers_sent: false, ..
        })
        | ConnectionState::Connecting { .. }
        | ConnectionState::Disconnecting { .. }
        | ConnectionState::Disconnected => ChannelState::ClientReleased,
    };
    if self.inner.channels.len() >= self.inner.assigned_channels.allowable_channel_count() {
        return Err(OfferError::TooManyChannels);
    }
    let key = offer.key();
    let confidential_ring_buffer = offer.flags.confidential_ring_buffer();
    let confidential_external_memory = offer.flags.confidential_external_memory();
    let channel = Channel {
        info: None,
        offer,
        state,
        restore_state: RestoreState::New,
    };
    let offer_id = self.inner.channels.offer(channel);
    if let Some(version) = connected_version {
        let channel = &mut self.inner.channels[offer_id];
        channel.prepare_channel(
            offer_id,
            &mut self.inner.assigned_channels,
            &mut self.inner.assigned_monitors,
        );
        self.inner
            .pending_messages
            .sender(self.notifier, self.inner.state.is_paused())
            .send_offer(channel, version);
    }
    tracing::info!(?offer_id, %key, confidential_ring_buffer, confidential_external_memory, "new channel");
    Ok(offer_id)
}
/// Revokes a channel offer on behalf of the host.
///
/// The shared `revoke` helper updates the channel and its GPADLs. If it
/// returns `false`, the channel is removed from the table right away.
/// Afterwards a pending disconnect is re-checked, because the revocation
/// may have released a channel it was waiting on.
pub fn revoke_channel(&mut self, offer_id: OfferId) {
let channel = &mut self.inner.channels[offer_id];
let retain = revoke(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
&mut self.inner.gpadls,
);
if !retain {
self.inner.channels.remove(offer_id);
}
self.check_disconnected();
}
/// Handles completion of a channel open previously requested via
/// `Action::Open`.
///
/// `result` is a status code; non-negative values mean success.
///
/// - `Opening`: the open result is sent to the guest, with
///   `MessageTarget::for_offer` choosing the target from the reserved state.
///   The channel then becomes `Open` on success or `Closed` on failure.
/// - `OpeningClientRelease`: the guest released the channel while it was
///   opening. A successful open is closed again; a failed one becomes
///   `ClientReleased`.
/// - Any other state is unexpected and is only logged.
pub fn open_complete(&mut self, offer_id: OfferId, result: i32) {
tracing::debug!(offer_id = offer_id.0, result, "open complete");
let channel = &mut self.inner.channels[offer_id];
match channel.state {
ChannelState::Opening {
request,
reserved_state,
} => {
let channel_id = channel.info.expect("assigned").channel_id;
if result >= 0 {
tracelimit::info_ratelimited!(
offer_id = offer_id.0,
channel_id = channel_id.0,
result,
"opened channel"
);
} else {
tracelimit::error_ratelimited!(
offer_id = offer_id.0,
channel_id = channel_id.0,
result,
"failed to open channel"
);
}
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused())
.send_open_result(
channel_id,
&request,
result,
MessageTarget::for_offer(offer_id, &reserved_state),
);
channel.state = if result >= 0 {
ChannelState::Open {
params: request,
modify_state: ModifyState::NotModifying,
reserved_state,
}
} else {
ChannelState::Closed
};
}
ChannelState::OpeningClientRelease => {
tracing::info!(
offer_id = offer_id.0,
result,
"opened channel (client released)"
);
if result >= 0 {
// The guest no longer holds the channel, so close it again.
channel.state = ChannelState::ClientReleased;
channel.state = ChannelState::ClosingClientRelease;
self.notifier.notify(offer_id, Action::Close);
} else {
channel.state = ChannelState::ClientReleased;
// This may have been the last channel a pending disconnect waited on.
self.check_disconnected();
}
}
ChannelState::ClientReleased
| ChannelState::Closed
| ChannelState::Open { .. }
| ChannelState::Closing { .. }
| ChannelState::ClosingReopen { .. }
| ChannelState::Revoked
| ChannelState::Reoffered
| ChannelState::ClosingClientRelease => {
tracing::error!(?offer_id, state = ?channel.state, "invalid open complete")
}
}
}
/// Returns whether every channel and GPADL has been released by the guest.
///
/// With `include_reserved == false`, reserved channels and their GPADLs are
/// ignored, matching `request_disconnect`, which keeps reserved channels
/// unless resetting. With `include_reserved == true`, any remaining GPADL
/// means the channels are not yet reset.
fn are_channels_reset(&self, include_reserved: bool) -> bool {
    let channels = &self.inner.channels;
    let gpadls_released = self.inner.gpadls.keys().all(|&(_, owner)| {
        !include_reserved && channels[owner].state.is_reserved()
    });
    if !gpadls_released {
        return false;
    }
    channels.iter().all(|(_, channel)| {
        let released = matches!(channel.state, ChannelState::ClientReleased);
        released || (!include_reserved && channel.state.is_reserved())
    })
}
/// Advances a pending disconnect when possible.
///
/// This only acts when the connection is `Disconnecting` and the
/// modify-connection request has not been sent yet. Once every relevant
/// channel and GPADL is released, it marks the request as sent and asks the
/// notifier to reset the monitor and interrupt pages. Reserved channels
/// only count when the queued action is a VM reset.
///
/// # Panics
/// Panics if `modify_connection` rejects the reset request.
fn check_disconnected(&mut self) {
    let ConnectionState::Disconnecting {
        next_action,
        modify_sent: false,
    } = self.inner.state
    else {
        // Not disconnecting, or the modify request is already outstanding.
        return;
    };
    let include_reserved = matches!(next_action, ConnectionAction::Reset);
    if !self.are_channels_reset(include_reserved) {
        return;
    }
    self.inner.state = ConnectionState::Disconnecting {
        next_action,
        modify_sent: true,
    };
    let reset_request = ModifyConnectionRequest {
        monitor_page: Update::Reset,
        interrupt_page: Update::Reset,
        ..Default::default()
    };
    self.notifier
        .modify_connection(reset_request)
        .expect("resetting state should not fail");
}
/// Returns whether a VM reset is queued behind an in-progress connect or
/// disconnect.
fn is_resetting(&self) -> bool {
    let queued = match &self.inner.state {
        ConnectionState::Connecting { next_action, .. }
        | ConnectionState::Disconnecting { next_action, .. } => next_action,
        ConnectionState::Connected { .. } | ConnectionState::Disconnected => return false,
    };
    matches!(queued, ConnectionAction::Reset)
}
/// Handles completion of a channel close previously requested via
/// `Action::Close`.
///
/// - `Closing` for a reserved channel: the channel becomes `Closed`. If the
///   guest is connected, a `CloseReservedChannelResponse` is sent to the
///   reserved target. Otherwise the channel is released on the guest's
///   behalf, and removed if `client_release_channel` asks for it.
/// - `Closing`: the channel becomes `Closed`.
/// - `ClosingClientRelease`: the channel becomes `ClientReleased`, and a
///   pending disconnect is re-checked.
/// - `ClosingReopen`: the guest sent an open while the close was in flight,
///   so the channel is reopened with the stored request.
/// - Any other state is unexpected and is only logged.
pub fn close_complete(&mut self, offer_id: OfferId) {
let channel = &mut self.inner.channels[offer_id];
tracing::info!(offer_id = offer_id.0, "closed channel");
match channel.state {
ChannelState::Closing {
reserved_state: Some(reserved_state),
..
} => {
channel.state = ChannelState::Closed;
if matches!(self.inner.state, ConnectionState::Connected { .. }) {
let channel_id = channel.info.expect("assigned").channel_id;
self.send_close_reserved_channel_response(
channel_id,
offer_id,
reserved_state.target,
);
} else {
// No connected guest to answer; release the channel on its behalf.
if Self::client_release_channel(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
&mut self.inner.gpadls,
&mut self.inner.assigned_channels,
&mut self.inner.assigned_monitors,
None,
) {
self.inner.channels.remove(offer_id);
}
}
}
ChannelState::Closing { .. } => {
channel.state = ChannelState::Closed;
}
ChannelState::ClosingClientRelease => {
channel.state = ChannelState::ClientReleased;
self.check_disconnected();
}
ChannelState::ClosingReopen { request, .. } => {
channel.state = ChannelState::Closed;
self.open_channel(offer_id, &request, None);
}
ChannelState::Closed
| ChannelState::ClientReleased
| ChannelState::Opening { .. }
| ChannelState::Open { .. }
| ChannelState::Revoked
| ChannelState::Reoffered
| ChannelState::OpeningClientRelease => {
tracing::error!(?offer_id, state = ?channel.state, "invalid close complete")
}
}
}
/// Sends a `CloseReservedChannelResponse` for `channel_id` to the reserved
/// channel's `target` VP/SINT.
fn send_close_reserved_channel_response(
    &mut self,
    channel_id: ChannelId,
    offer_id: OfferId,
    target: ConnectionTarget,
) {
    let response = protocol::CloseReservedChannelResponse { channel_id };
    let message_target = MessageTarget::ReservedChannel(offer_id, target);
    self.sender()
        .send_message_with_target(&response, message_target);
}
/// Converts a guest `InitiateContact2` message into an
/// `InitiateContactRequest` and passes it to `initiate_contact`.
///
/// Each field is only honored from the protocol version that introduced it:
/// - target SINT (Win10Rs3_1+) and target VTL (Win10Rs4+), taken from the
///   target info and only for multiclient messages;
/// - feature flags from Copper;
/// - target message VP from Win8_1;
/// - the legacy interrupt page only before Win8.
///
/// The two monitor page GPAs must be both zero or both non-zero. A mismatch
/// becomes `MonitorPageRequest::Invalid`, which `initiate_contact` rejects.
///
/// `includes_client_id` presumably says whether the received message was
/// long enough to contain the client ID field.
///
/// # Errors
/// Returns `ParseError::MessageTooSmall` when the client-ID feature flag is
/// requested but `includes_client_id` is false.
fn handle_initiate_contact(
&mut self,
input: &protocol::InitiateContact2,
message: &SynicMessage,
includes_client_id: bool,
) -> Result<(), ChannelError> {
let target_info =
protocol::TargetInfo::from(input.initiate_contact.interrupt_page_or_target_info);
let target_sint = if message.multiclient
&& input.initiate_contact.version_requested >= Version::Win10Rs3_1 as u32
{
target_info.sint()
} else {
SINT
};
let target_vtl = if message.multiclient
&& input.initiate_contact.version_requested >= Version::Win10Rs4 as u32
{
target_info.vtl()
} else {
0
};
let feature_flags = if input.initiate_contact.version_requested >= Version::Copper as u32 {
target_info.feature_flags()
} else {
0
};
let target_message_vp =
if input.initiate_contact.version_requested >= Version::Win8_1 as u32 {
input.initiate_contact.target_message_vp
} else {
0
};
// Before Win8 this field is the interrupt page GPA; zero means none.
let interrupt_page = (input.initiate_contact.version_requested < Version::Win8 as u32
&& input.initiate_contact.interrupt_page_or_target_info != 0)
.then_some(input.initiate_contact.interrupt_page_or_target_info);
// Both monitor pages must be provided together, or neither.
let monitor_page = if (input.initiate_contact.parent_to_child_monitor_page_gpa == 0)
!= (input.initiate_contact.child_to_parent_monitor_page_gpa == 0)
{
MonitorPageRequest::Invalid
} else if input.initiate_contact.parent_to_child_monitor_page_gpa != 0 {
MonitorPageRequest::Some(MonitorPageGpas {
parent_to_child: input.initiate_contact.parent_to_child_monitor_page_gpa,
child_to_parent: input.initiate_contact.child_to_parent_monitor_page_gpa,
})
} else {
MonitorPageRequest::None
};
let client_id = if FeatureFlags::from(feature_flags).client_id() {
if includes_client_id {
input.client_id
} else {
return Err(ChannelError::ParseError(
protocol::ParseError::MessageTooSmall(Some(
protocol::MessageType::INITIATE_CONTACT,
)),
));
}
} else {
Guid::ZERO
};
let request = InitiateContactRequest {
version_requested: input.initiate_contact.version_requested,
target_message_vp,
monitor_page,
target_sint,
target_vtl,
feature_flags,
interrupt_page,
client_id,
trusted: message.trusted,
};
self.initiate_contact(request);
Ok(())
}
/// Starts establishing a connection for an `InitiateContact` request.
///
/// The request is handled in stages:
/// - A request for another VTL is passed to `forward_unhandled`.
/// - A request for a SINT other than `SINT` gets a zeroed ("unsupported")
///   version response sent to the requested VP/SINT.
/// - Otherwise any existing connection is torn down first. If that cannot
///   finish synchronously, the request is queued as the next action and
///   this call returns.
///
/// Once disconnected:
/// - The requested version is validated.
/// - The state becomes `Connecting`.
/// - `modify_connection` is asked to apply the new parameters.
///
/// If validation or `modify_connection` fails, a failure response is sent
/// immediately. On success, the version response is sent later by
/// `complete_initiate_contact`, presumably once the modify request
/// completes.
pub fn initiate_contact(&mut self, request: InitiateContactRequest) {
let vtl = self.inner.assigned_channels.vtl as u8;
if request.target_vtl != vtl {
self.notifier.forward_unhandled(request);
return;
}
if request.target_sint != SINT {
tracelimit::warn_ratelimited!(
"unsupported multiclient request for VTL {} SINT {}, version {:#x}",
request.target_vtl,
request.target_sint,
request.version_requested,
);
self.send_version_response_with_target(
None,
MessageTarget::Custom(ConnectionTarget {
vp: request.target_message_vp,
sint: request.target_sint,
}),
);
return;
}
// Tear down any existing connection first; if that is still in progress
// the request is queued and retried when the disconnect completes.
if !self.request_disconnect(ConnectionAction::Reconnect {
initiate_contact: request,
}) {
return;
}
let Some(version) = self.check_version_supported(&request) else {
tracelimit::warn_ratelimited!(
vtl,
version = request.version_requested,
client_id = ?request.client_id,
"Guest requested unsupported version"
);
self.send_version_response(None);
return;
};
tracelimit::info_ratelimited!(
vtl,
?version,
client_id = ?request.client_id,
trusted = request.trusted,
"Guest negotiated version"
);
let monitor_page = match request.monitor_page {
MonitorPageRequest::Some(mp) => Some(mp),
MonitorPageRequest::None => None,
MonitorPageRequest::Invalid => {
self.send_version_response(Some((
version,
protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
)));
return;
}
};
self.inner.state = ConnectionState::Connecting {
info: ConnectionInfo {
version,
trusted: request.trusted,
interrupt_page: request.interrupt_page,
monitor_page,
target_message_vp: request.target_message_vp,
modifying: false,
offers_sent: false,
client_id: request.client_id,
paused: false,
},
next_action: ConnectionAction::None,
};
if let Err(err) = self.notifier.modify_connection(ModifyConnectionRequest {
version: Some(request.version_requested),
monitor_page: monitor_page.into(),
interrupt_page: request.interrupt_page.into(),
target_message_vp: Some(request.target_message_vp),
force: false,
notify_relay: true,
}) {
tracelimit::error_ratelimited!(?err, "server failed to change state");
self.inner.state = ConnectionState::Disconnected;
self.send_version_response(Some((
version,
protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
)));
}
}
/// Completes a pending `InitiateContact` using the result of the
/// modify-connection request.
///
/// On success:
/// - The negotiated feature flags are narrowed to those the relay reported,
///   plus `LOCAL_FEATURE_FLAGS`.
/// - The connection becomes `Connected`.
/// - A successful version response is sent.
/// - Any action queued while connecting is then started.
///
/// On failure, a failing (or "unsupported") version response is sent and
/// the connection becomes `Disconnected`.
///
/// # Panics
/// Panics if the connection is not in the `Connecting` state.
pub(crate) fn complete_initiate_contact(&mut self, response: ModifyConnectionResponse) {
let ConnectionState::Connecting {
mut info,
next_action,
} = self.inner.state
else {
panic!("Invalid state for completing InitiateContact.");
};
// Flags handled by this server itself, kept regardless of relay support.
const LOCAL_FEATURE_FLAGS: FeatureFlags = FeatureFlags::new()
.with_client_id(true)
.with_confidential_channels(true);
let relay_feature_flags = match response {
ModifyConnectionResponse::Supported(
protocol::ConnectionState::SUCCESSFUL,
feature_flags,
) => feature_flags,
ModifyConnectionResponse::Supported(connection_state, feature_flags) => {
tracelimit::error_ratelimited!(
?connection_state,
"initiate contact failed because relay request failed"
);
info.version.feature_flags &= feature_flags | LOCAL_FEATURE_FLAGS;
self.send_version_response(Some((info.version, connection_state)));
self.inner.state = ConnectionState::Disconnected;
return;
}
ModifyConnectionResponse::Unsupported => {
self.send_version_response(None);
self.inner.state = ConnectionState::Disconnected;
return;
}
};
info.version.feature_flags &= relay_feature_flags | LOCAL_FEATURE_FLAGS;
self.inner.state = ConnectionState::Connected(info);
self.send_version_response(Some((info.version, protocol::ConnectionState::SUCCESSFUL)));
// An action (e.g. reconnect, unload, reset) may have been queued while connecting.
if !matches!(next_action, ConnectionAction::None) && self.request_disconnect(next_action) {
self.do_next_action(next_action);
}
}
/// Chooses the protocol version and feature flags for an `InitiateContact`
/// request.
///
/// Returns `None` in two cases:
/// - the requested version is not in `SUPPORTED_VERSIONS`;
/// - it exceeds the configured maximum version.
///
/// Feature flags are only negotiated for Copper and later. The granted set
/// is the intersection of:
/// - the requested flags;
/// - the locally supported flags (confidential channels only for trusted
///   requests);
/// - the configured maximum's flags, when one is set.
///
/// A warning is logged if some requested flags were not granted.
fn check_version_supported(&self, request: &InitiateContactRequest) -> Option<VersionInfo> {
    let version = SUPPORTED_VERSIONS
        .iter()
        .copied()
        .find(|&candidate| candidate as u32 == request.version_requested)?;
    let max_version = self.inner.max_version;
    if matches!(max_version, Some(max) if version as u32 > max.version) {
        return None;
    }
    let supported_flags = if version >= Version::Copper {
        let local_flags = SUPPORTED_FEATURE_FLAGS.with_confidential_channels(request.trusted);
        match max_version {
            Some(max) => local_flags & max.feature_flags,
            None => local_flags,
        }
    } else {
        FeatureFlags::new()
    };
    let feature_flags = supported_flags & request.feature_flags.into();
    assert!(version >= Version::Copper || feature_flags == FeatureFlags::new());
    let granted_bits = feature_flags.into_bits();
    if granted_bits != request.feature_flags {
        tracelimit::warn_ratelimited!(
            supported = granted_bits,
            requested = request.feature_flags,
            "Guest requested unsupported feature flags."
        );
    }
    Some(VersionInfo {
        version,
        feature_flags,
    })
}
/// Sends a version response to the default message target.
///
/// `data` is encoded as described on `send_version_response_with_target`;
/// `None` means "version not supported".
fn send_version_response(&mut self, data: Option<(VersionInfo, protocol::ConnectionState)>) {
    let target = MessageTarget::Default;
    self.send_version_response_with_target(data, target)
}
/// Sends a version response to `target`.
///
/// With `data == None` a zeroed response is sent, i.e. the version is not
/// supported. Otherwise the version is reported as supported when the state
/// is `SUCCESSFUL` or the version is at least Win8. Presumably older
/// versions cannot carry a failure state; TODO confirm.
///
/// Version-dependent encoding:
/// - From Win10Rs3_1 on, the response carries the child connection ID
///   instead of the version number.
/// - From Copper on, the extended `VersionResponse2` with the negotiated
///   feature flags is sent instead of the plain response.
fn send_version_response_with_target(
&mut self,
data: Option<(VersionInfo, protocol::ConnectionState)>,
target: MessageTarget,
) {
let mut response2 = protocol::VersionResponse2::new_zeroed();
let response = &mut response2.version_response;
let mut send_response2 = false;
if let Some((version, state)) = data {
if state == protocol::ConnectionState::SUCCESSFUL || version.version >= Version::Win8 {
response.version_supported = 1;
response.connection_state = state;
response.selected_version_or_connection_id =
if version.version >= Version::Win10Rs3_1 {
self.inner.child_connection_id
} else {
version.version as u32
};
if version.version >= Version::Copper {
response2.supported_features = version.feature_flags.into();
send_response2 = true;
}
}
}
if send_response2 {
self.sender().send_message_with_target(&response2, target);
} else {
self.sender().send_message_with_target(response, target);
}
}
/// Releases the guest's channels and moves the connection toward
/// `Disconnected`, so that `new_action` can run afterwards.
///
/// Reserved channels are kept unless `new_action` is
/// `ConnectionAction::Reset`.
///
/// Returns `true` when the connection is `Disconnected` on return. The
/// caller must then perform the action itself. Otherwise `new_action` is
/// stored in the `Connecting`/`Disconnecting` state and runs later.
///
/// # Panics
/// Panics if a reset is already in progress. Also panics if the state is
/// already `Disconnected`, this is not a reset, and a non-reserved channel
/// or GPADL is still held.
fn request_disconnect(&mut self, new_action: ConnectionAction) -> bool {
assert!(!self.is_resetting());
let gpadls = &mut self.inner.gpadls;
let vm_reset = matches!(new_action, ConnectionAction::Reset);
// Keep reserved channels (unless resetting the VM); release all others and
// drop the ones `client_release_channel` says to remove.
self.inner.channels.retain(|offer_id, channel| {
(!vm_reset && channel.state.is_reserved())
|| !Self::client_release_channel(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
gpadls,
&mut self.inner.assigned_channels,
&mut self.inner.assigned_monitors,
None,
)
});
match &mut self.inner.state {
ConnectionState::Disconnected => {
if vm_reset {
if !self.are_channels_reset(true) {
self.inner.state = ConnectionState::Disconnecting {
next_action: ConnectionAction::Reset,
modify_sent: false,
};
}
} else {
assert!(self.are_channels_reset(false));
}
}
ConnectionState::Connected { .. } => {
if self.are_channels_reset(vm_reset) {
self.inner.state = ConnectionState::Disconnected;
} else {
self.inner.state = ConnectionState::Disconnecting {
next_action: new_action,
modify_sent: false,
};
}
}
// Already transitioning: just replace the queued action.
ConnectionState::Connecting { next_action, .. }
| ConnectionState::Disconnecting { next_action, .. } => {
*next_action = new_action;
}
}
matches!(self.inner.state, ConnectionState::Disconnected)
}
/// Finishes a pending disconnect and runs the queued next action.
///
/// This is presumably invoked when the modify-connection request sent by
/// `check_disconnected` completes. A warning is logged if no such request
/// had been sent.
///
/// # Panics
/// Panics in two cases:
/// - the connection is not `Disconnecting`;
/// - channels or GPADLs that must be released for the queued action are
///   still held.
pub(crate) fn complete_disconnect(&mut self) {
    let previous = std::mem::replace(&mut self.inner.state, ConnectionState::Disconnected);
    let ConnectionState::Disconnecting {
        next_action,
        modify_sent,
    } = previous
    else {
        unreachable!("not ready for disconnect");
    };
    let include_reserved = matches!(next_action, ConnectionAction::Reset);
    assert!(self.are_channels_reset(include_reserved));
    if !modify_sent {
        tracelimit::warn_ratelimited!("unexpected modify response");
    }
    // The state is already `Disconnected` thanks to the `mem::replace` above.
    self.do_next_action(next_action);
}
/// Runs an action that was deferred until the connection reached
/// `Disconnected`.
fn do_next_action(&mut self, action: ConnectionAction) {
    match action {
        ConnectionAction::None => (),
        ConnectionAction::Reset => self.complete_reset(),
        ConnectionAction::SendUnloadComplete => self.complete_unload(),
        ConnectionAction::Reconnect {
            initiate_contact: pending_request,
        } => self.initiate_contact(pending_request),
        ConnectionAction::SendFailedVersionResponse => self.send_version_response(None),
    }
}
/// Handles a guest `Unload` request by releasing the guest's channels.
///
/// `UnloadComplete` is sent right away if the connection can be torn down
/// synchronously. Otherwise it is sent once the pending disconnect
/// finishes.
fn handle_unload(&mut self) {
    let vtl = self.inner.assigned_channels.vtl as u8;
    tracing::debug!(
        vtl,
        state = ?self.inner.state,
        "VmBus received unload request from guest",
    );
    if !self.request_disconnect(ConnectionAction::SendUnloadComplete) {
        return;
    }
    self.complete_unload();
}
/// Finishes an unload. In order, it:
/// - notifies the notifier;
/// - applies a pending (`delayed_max_version`) compatibility version, if
///   any;
/// - sends `UnloadComplete` to the guest;
/// - logs the disconnect.
fn complete_unload(&mut self) {
    self.notifier.unload_complete();
    let pending_version = self.inner.delayed_max_version.take();
    if let Some(max_version) = pending_version {
        self.inner.set_compatibility_version(max_version, false);
    }
    self.sender().send_message(&protocol::UnloadComplete {});
    tracelimit::info_ratelimited!("Vmbus disconnected");
}
/// Handles the guest's `RequestOffers` message by offering every
/// non-reserved channel and then sending `AllOffersDelivered`.
///
/// Offers go out in a deterministic order:
/// 1. interface ID;
/// 2. `offer_order`, with unspecified orders last;
/// 3. instance ID.
///
/// Each channel is passed through `prepare_channel` with the channel and
/// monitor allocators before its offer is sent.
///
/// # Errors
/// Returns `OffersAlreadySent` if offers were already requested on this
/// connection.
///
/// # Panics
/// Panics in two cases:
/// - the connection is not `Connected` (message parsing is expected to
///   prevent this);
/// - an offered channel is not `ClientReleased`, or already has assigned
///   info.
fn handle_request_offers(&mut self) -> Result<(), ChannelError> {
let ConnectionState::Connected(info) = &mut self.inner.state else {
unreachable!(
"in unexpected state {:?}, should be prevented by Message::parse()",
self.inner.state
);
};
if info.offers_sent {
return Err(ChannelError::OffersAlreadySent);
}
info.offers_sent = true;
let mut sorted_channels: Vec<_> = self
.inner
.channels
.iter_mut()
.filter(|(_, channel)| !channel.state.is_reserved())
.collect();
sorted_channels.sort_unstable_by_key(|(_, channel)| {
(
channel.offer.interface_id,
channel.offer.offer_order.unwrap_or(u32::MAX),
channel.offer.instance_id,
)
});
for (offer_id, channel) in sorted_channels {
assert!(matches!(channel.state, ChannelState::ClientReleased));
assert!(channel.info.is_none());
channel.prepare_channel(
offer_id,
&mut self.inner.assigned_channels,
&mut self.inner.assigned_monitors,
);
channel.state = ChannelState::Closed;
self.inner
.pending_messages
.sender(self.notifier, info.paused)
.send_offer(channel, info.version);
}
self.sender().send_message(&protocol::AllOffersDelivered {});
Ok(())
}
/// Reports a GPADL whose range data is complete.
///
/// If the channel has been revoked, the guest is told that GPADL creation
/// failed (`STATUS_UNSUCCESSFUL`), and `false` is returned so the caller
/// drops the GPADL. Otherwise the notifier is sent `Action::Gpadl` and
/// `true` is returned.
#[must_use]
fn gpadl_updated(
mut sender: MessageSender<'_, N>,
offer_id: OfferId,
channel: &Channel,
gpadl_id: GpadlId,
gpadl: &Gpadl,
) -> bool {
if channel.state.is_revoked() {
let channel_id = channel.info.as_ref().expect("assigned").channel_id;
sender.send_gpadl_created(channel_id, gpadl_id, protocol::STATUS_UNSUCCESSFUL);
false
} else {
sender.notifier.notify(
offer_id,
Action::Gpadl(gpadl_id, gpadl.count, gpadl.buf.clone()),
);
true
}
}
/// Handles a `GpadlHeader` message, which starts a new GPADL for a channel.
///
/// `range` is the GPA range data carried by the header. If it completes the
/// GPADL, the GPADL is reported through `gpadl_updated` immediately.
/// Otherwise it is recorded in `incomplete_gpadls`, so that later
/// `GpadlBody` messages (looked up by GPADL ID alone) can finish it.
///
/// # Errors
/// - Unknown channel ID, or `ChannelReserved` for a reserved channel.
/// - Any error from `Gpadl::append` for the range data.
/// - `DuplicateGpadlId` if the ID is already used by a GPADL of this
///   channel. For a GPADL that still needs body messages, also if the ID
///   belongs to a GPADL being assembled for any channel.
fn handle_gpadl_header(
    &mut self,
    input: &protocol::GpadlHeader,
    range: &[u8],
) -> Result<(), ChannelError> {
    let gpadl_id = input.gpadl_id;
    let (offer_id, channel) = self
        .inner
        .channels
        .get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
    if channel.state.is_reserved() {
        return Err(ChannelError::ChannelReserved);
    }
    // NOTE(review): presumably `len` is a byte count of 8-byte entries; confirm.
    let mut gpadl = Gpadl::new(input.count, input.len as usize / 8);
    let done = gpadl.append(range)?;
    // `gpadls` is keyed by (GPADL ID, offer), but `incomplete_gpadls` is keyed
    // by GPADL ID alone. Without this check, a guest could start incomplete
    // GPADLs with the same ID on two channels; the second insert would
    // collide and previously panicked. An entry whose GPADL no longer exists
    // is stale (presumably dropped by `client_release_channel` while still
    // `InProgress`) and may be replaced.
    if !done
        && matches!(
            self.inner.incomplete_gpadls.get(&gpadl_id),
            Some(&other_offer_id)
                if self.inner.gpadls.contains_key(&(gpadl_id, other_offer_id))
        )
    {
        return Err(ChannelError::DuplicateGpadlId);
    }
    let gpadl = match self.inner.gpadls.entry((gpadl_id, offer_id)) {
        Entry::Vacant(entry) => entry.insert(gpadl),
        Entry::Occupied(_) => return Err(ChannelError::DuplicateGpadlId),
    };
    if !done {
        // Any previous entry for this ID was verified to be stale above.
        self.inner.incomplete_gpadls.insert(gpadl_id, offer_id);
    }
    if done
        && !Self::gpadl_updated(
            self.inner
                .pending_messages
                .sender(self.notifier, self.inner.state.is_paused()),
            offer_id,
            channel,
            gpadl_id,
            gpadl,
        )
    {
        // The channel was revoked; the guest has been told creation failed.
        self.inner.gpadls.remove(&(gpadl_id, offer_id));
    }
    Ok(())
}
/// Handles a `GpadlBody` message with more range data for a GPADL started
/// by `GpadlHeader`.
///
/// The owning channel is found through `incomplete_gpadls`; the lookup
/// uses only the GPADL ID. Once all data has arrived, the GPADL leaves the
/// incomplete set and is reported via `gpadl_updated`. That call drops it
/// if the channel was revoked.
///
/// # Errors
/// - `UnknownGpadlId` if no GPADL with this ID is being assembled.
/// - Any error from `Gpadl::append`.
fn handle_gpadl_body(
&mut self,
input: &protocol::GpadlBody,
range: &[u8],
) -> Result<(), ChannelError> {
let &offer_id = self
.inner
.incomplete_gpadls
.get(&input.gpadl_id)
.ok_or(ChannelError::UnknownGpadlId)?;
let gpadl = self
.inner
.gpadls
.get_mut(&(input.gpadl_id, offer_id))
.ok_or(ChannelError::UnknownGpadlId)?;
let channel = &mut self.inner.channels[offer_id];
if gpadl.append(range)? {
self.inner.incomplete_gpadls.remove(&input.gpadl_id);
if !Self::gpadl_updated(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
input.gpadl_id,
gpadl,
) {
self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
}
}
Ok(())
}
/// Handles a guest `GpadlTeardown` request.
///
/// Only `Accepted` GPADLs can be torn down:
/// - On a revoked channel the GPADL is dropped and the teardown is
///   acknowledged immediately.
/// - Otherwise the GPADL moves to `TearingDown`, and the notifier is asked
///   to tear it down.
///
/// # Errors
/// - Unknown channel or GPADL ID.
/// - `InvalidGpadlState` if the GPADL is not `Accepted`.
/// - `WrongGpadlChannelId` on a channel ID mismatch.
/// - `ChannelReserved` for a reserved channel.
fn handle_gpadl_teardown(
&mut self,
input: &protocol::GpadlTeardown,
) -> Result<(), ChannelError> {
tracing::debug!(
channel_id = input.channel_id.0,
gpadl_id = input.gpadl_id.0,
"Received GPADL teardown request"
);
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
let gpadl = self
.inner
.gpadls
.get_mut(&(input.gpadl_id, offer_id))
.ok_or(ChannelError::UnknownGpadlId)?;
match gpadl.state {
GpadlState::InProgress
| GpadlState::Offered
| GpadlState::OfferedTearingDown
| GpadlState::TearingDown => {
return Err(ChannelError::InvalidGpadlState);
}
GpadlState::Accepted => {
if channel.info.as_ref().map(|info| info.channel_id) != Some(input.channel_id) {
return Err(ChannelError::WrongGpadlChannelId);
}
if channel.state.is_reserved() {
return Err(ChannelError::ChannelReserved);
}
if channel.state.is_revoked() {
// Nothing to notify for a revoked channel; acknowledge directly.
tracing::trace!(
channel_id = input.channel_id.0,
gpadl_id = input.gpadl_id.0,
"Gpadl teardown for revoked channel"
);
self.inner.gpadls.remove(&(input.gpadl_id, offer_id));
self.sender().send_gpadl_torndown(input.gpadl_id);
} else {
gpadl.state = GpadlState::TearingDown;
self.notifier.notify(
offer_id,
Action::TeardownGpadl {
gpadl_id: input.gpadl_id,
post_restore: false,
},
);
}
}
}
Ok(())
}
/// Moves a `Closed` channel to `Opening` and sends the notifier
/// `Action::Open` for `input`.
///
/// `reserved_state` is `Some` when opening via `OpenReservedChannel`.
///
/// # Panics
/// Panics in three cases:
/// - the channel is not `Closed`;
/// - the channel has no assigned info;
/// - no protocol version has been negotiated.
fn open_channel(
&mut self,
offer_id: OfferId,
input: &OpenRequest,
reserved_state: Option<ReservedState>,
) {
let channel = &mut self.inner.channels[offer_id];
assert!(matches!(channel.state, ChannelState::Closed));
channel.state = ChannelState::Opening {
request: *input,
reserved_state,
};
let info = channel.info.as_ref().expect("assigned");
self.notifier.notify(
offer_id,
Action::Open(
OpenParams::from_request(
info,
input,
channel.handled_monitor_id(),
reserved_state.map(|state| state.target),
),
self.inner.state.get_version().expect("must be connected"),
),
);
}
/// Handles a guest `OpenChannel` request.
///
/// Guest-specified signal parameters and channel flags are only honored
/// when the matching feature flags were negotiated. Then, by state:
/// - `Closed`: the channel starts opening.
/// - `Closing`: the request is stored, so the channel is reopened when the
///   close completes.
/// - Revoked or reoffered: the request is silently ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelAlreadyOpen` if the channel is open, opening, or already
///   queued to reopen.
///
/// Client-released states hit `unreachable!`.
fn handle_open_channel(&mut self, input: &protocol::OpenChannel2) -> Result<(), ChannelError> {
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, input.open_channel.channel_id)?;
let guest_specified_interrupt_info = self
.inner
.state
.check_feature_flags(|ff| ff.guest_specified_signal_parameters())
.then_some(SignalInfo {
event_flag: input.event_flag,
connection_id: input.connection_id,
});
let flags = if self
.inner
.state
.check_feature_flags(|ff| ff.channel_interrupt_redirection())
{
input.flags
} else {
Default::default()
};
let request = OpenRequest {
open_id: input.open_channel.open_id,
ring_buffer_gpadl_id: input.open_channel.ring_buffer_gpadl_id,
target_vp: input.open_channel.target_vp,
downstream_ring_buffer_page_offset: input
.open_channel
.downstream_ring_buffer_page_offset,
user_data: input.open_channel.user_data,
guest_specified_interrupt_info,
flags,
};
match channel.state {
ChannelState::Closed => self.open_channel(offer_id, &request, None),
// Reopen once the in-flight close completes (see close_complete).
ChannelState::Closing { params, .. } => {
channel.state = ChannelState::ClosingReopen { params, request }
}
ChannelState::Revoked | ChannelState::Reoffered => {}
ChannelState::Open { .. }
| ChannelState::Opening { .. }
| ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelAlreadyOpen),
ChannelState::ClientReleased
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
}
Ok(())
}
/// Handles a guest `CloseChannel` request for a non-reserved channel.
///
/// An open channel moves to `Closing`, and the notifier is sent
/// `Action::Close`; a modify still in progress only produces a warning.
/// Requests for revoked or reoffered channels are ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelReserved` for a reserved channel.
/// - `ChannelNotOpen` if the channel is not open.
///
/// Client-released states hit `unreachable!`.
fn handle_close_channel(&mut self, input: &protocol::CloseChannel) -> Result<(), ChannelError> {
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
match channel.state {
ChannelState::Open {
params,
modify_state,
reserved_state: None,
} => {
if modify_state.is_modifying() {
tracelimit::warn_ratelimited!(
?modify_state,
"Client is closing the channel with a modify in progress"
)
}
channel.state = ChannelState::Closing {
params,
reserved_state: None,
};
self.notifier.notify(offer_id, Action::Close);
}
ChannelState::Open {
reserved_state: Some(_),
..
} => return Err(ChannelError::ChannelReserved),
ChannelState::Revoked | ChannelState::Reoffered => {}
ChannelState::Closed
| ChannelState::Opening { .. }
| ChannelState::Closing { .. }
| ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
ChannelState::ClientReleased
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
}
Ok(())
}
/// Handles an `OpenReservedChannel` request.
///
/// The channel is opened with `VP_INDEX_DISABLE_INTERRUPT` as its target VP
/// and a zeroed open ID and user data. The request's target VP/SINT and
/// `version` are recorded as the reserved state. Requests for revoked or
/// reoffered channels are ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelAlreadyOpen` if the channel is open or opening.
/// - `InvalidChannelState` while the channel is closing.
///
/// Client-released states hit `unreachable!`.
fn handle_open_reserved_channel(
&mut self,
input: &protocol::OpenReservedChannel,
version: VersionInfo,
) -> Result<(), ChannelError> {
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
let target = ConnectionTarget {
vp: input.target_vp,
sint: input.target_sint as u8,
};
let reserved_state = Some(ReservedState { version, target });
let request = OpenRequest {
ring_buffer_gpadl_id: input.ring_buffer_gpadl,
target_vp: protocol::VP_INDEX_DISABLE_INTERRUPT,
downstream_ring_buffer_page_offset: input.downstream_page_offset,
open_id: 0,
user_data: UserDefinedData::new_zeroed(),
guest_specified_interrupt_info: None,
flags: Default::default(),
};
match channel.state {
ChannelState::Closed => self.open_channel(offer_id, &request, reserved_state),
ChannelState::Revoked | ChannelState::Reoffered => {}
ChannelState::Open { .. } | ChannelState::Opening { .. } => {
return Err(ChannelError::ChannelAlreadyOpen);
}
ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
return Err(ChannelError::InvalidChannelState);
}
ChannelState::ClientReleased
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
}
Ok(())
}
/// Handles a `CloseReservedChannel` request.
///
/// An open reserved channel moves to `Closing`, and the notifier is sent
/// `Action::Close`. The reserved target is updated to the VP/SINT in the
/// request, which is where `close_complete` sends the response. Requests
/// for revoked or reoffered channels are ignored.
///
/// # Errors
/// - Unknown channel ID.
/// - `ChannelNotReserved` for an open non-reserved channel.
/// - `ChannelNotOpen` if the channel is not open.
///
/// Client-released states hit `unreachable!`.
fn handle_close_reserved_channel(
&mut self,
input: &protocol::CloseReservedChannel,
) -> Result<(), ChannelError> {
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, input.channel_id)?;
match channel.state {
ChannelState::Open {
params,
reserved_state: Some(mut resvd),
..
} => {
resvd.target.vp = input.target_vp;
resvd.target.sint = input.target_sint as u8;
channel.state = ChannelState::Closing {
params,
reserved_state: Some(resvd),
};
self.notifier.notify(offer_id, Action::Close);
}
ChannelState::Open {
reserved_state: None,
..
} => return Err(ChannelError::ChannelNotReserved),
ChannelState::Revoked | ChannelState::Reoffered => {}
ChannelState::Closed
| ChannelState::Opening { .. }
| ChannelState::Closing { .. }
| ChannelState::ClosingReopen { .. } => return Err(ChannelError::ChannelNotOpen),
ChannelState::ClientReleased
| ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease => unreachable!(),
}
Ok(())
}
/// Releases a channel on the guest's behalf. It is called for
/// `RelIdReleased`, when disconnecting, and when closing a reserved channel
/// with no connected guest.
///
/// GPADLs of this channel are cleaned up by state:
/// - `InProgress` GPADLs are dropped.
/// - `Offered` GPADLs become `OfferedTearingDown`.
/// - `Accepted` GPADLs are torn down via the notifier, or dropped if the
///   channel is revoked.
///
/// The channel then moves to a released state; an open channel is closed
/// first. A `Reoffered` channel with `Some(version)` is different: it
/// becomes `Closed`, is offered to the guest again with its existing
/// assignment, and `false` is returned without releasing it.
///
/// Returns `true` if the caller should remove the channel (it was
/// `Revoked`).
///
/// # Panics
/// Panics if the resulting state is not a released state.
#[must_use]
fn client_release_channel(
mut sender: MessageSender<'_, N>,
offer_id: OfferId,
channel: &mut Channel,
gpadls: &mut GpadlMap,
assigned_channels: &mut AssignedChannels,
assigned_monitors: &mut AssignedMonitors,
version: Option<VersionInfo>,
) -> bool {
gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
if gpadl_offer_id != offer_id {
return true;
}
match gpadl.state {
GpadlState::InProgress => false,
GpadlState::Offered => {
gpadl.state = GpadlState::OfferedTearingDown;
true
}
GpadlState::Accepted => {
if channel.state.is_revoked() {
false
} else {
gpadl.state = GpadlState::TearingDown;
sender.notifier.notify(
offer_id,
Action::TeardownGpadl {
gpadl_id,
post_restore: false,
},
);
true
}
}
GpadlState::OfferedTearingDown | GpadlState::TearingDown => true,
}
});
let remove = match &mut channel.state {
ChannelState::Closed => {
channel.state = ChannelState::ClientReleased;
false
}
ChannelState::Reoffered => {
// The host re-offered this channel while the guest held it; offer
// it again now if connected.
if let Some(version) = version {
channel.state = ChannelState::Closed;
channel.restore_state = RestoreState::New;
sender.send_offer(channel, version);
return false;
}
channel.state = ChannelState::ClientReleased;
false
}
ChannelState::Revoked => {
channel.state = ChannelState::ClientReleased;
true
}
ChannelState::Opening { .. } => {
channel.state = ChannelState::OpeningClientRelease;
false
}
ChannelState::Open { .. } => {
channel.state = ChannelState::ClosingClientRelease;
sender.notifier.notify(offer_id, Action::Close);
false
}
ChannelState::Closing { .. } | ChannelState::ClosingReopen { .. } => {
channel.state = ChannelState::ClosingClientRelease;
false
}
ChannelState::ClosingClientRelease
| ChannelState::OpeningClientRelease
| ChannelState::ClientReleased => false,
};
assert!(channel.state.is_released());
channel.release_channel(offer_id, assigned_channels, assigned_monitors);
remove
}
/// Handles a guest `RelIdReleased` message, releasing the channel with the
/// given channel ID.
///
/// The negotiated version is passed along, so a `Reoffered` channel is
/// offered again right away. A channel that `client_release_channel` asks
/// to remove is dropped. Afterwards a pending disconnect is re-checked.
///
/// # Errors
/// - Unknown channel ID.
/// - `InvalidChannelState` if the channel is open, opening, or waiting to
///   reopen.
///
/// Client-released states hit `unreachable!`.
fn handle_rel_id_released(
&mut self,
input: &protocol::RelIdReleased,
) -> Result<(), ChannelError> {
let channel_id = input.channel_id;
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, channel_id)?;
match channel.state {
ChannelState::Closed
| ChannelState::Revoked
| ChannelState::Closing { .. }
| ChannelState::Reoffered => {
if Self::client_release_channel(
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused()),
offer_id,
channel,
&mut self.inner.gpadls,
&mut self.inner.assigned_channels,
&mut self.inner.assigned_monitors,
self.inner.state.get_version(),
) {
self.inner.channels.remove(offer_id);
}
self.check_disconnected();
}
ChannelState::Opening { .. }
| ChannelState::Open { .. }
| ChannelState::ClosingReopen { .. } => return Err(ChannelError::InvalidChannelState),
ChannelState::ClientReleased
| ChannelState::OpeningClientRelease
| ChannelState::ClosingClientRelease => unreachable!(),
}
Ok(())
}
/// Forwards a guest hvsocket connect request to the notifier.
///
/// # Panics
///
/// Panics if no protocol version has been negotiated.
fn handle_tl_connect_request(&mut self, request: protocol::TlConnectRequest2) {
    let negotiated = self
        .inner
        .state
        .get_version()
        .expect("must be connected");
    // Connections negotiated below Win10Rs5 are flagged as hosted-silo unaware.
    let hosted_silo_unaware = negotiated.version < Version::Win10Rs5;
    let hvsock_request = HvsockConnectRequest::from_message(request, hosted_silo_unaware);
    self.notifier.notify_hvsock(&hvsock_request);
}
/// Reports the outcome of a guest hvsocket connect request.
///
/// Only refusals are reported. A failed `result` is answered with a
/// `TlConnectResult` carrying `STATUS_CONNECTION_REFUSED`, provided
/// `check_version(Version::Win10Rs3_0)` holds for the connection. A successful
/// result sends nothing.
pub fn send_tl_connect_result(&mut self, result: HvsockConnectResult) {
    if result.success || !self.inner.state.check_version(Version::Win10Rs3_0) {
        return;
    }
    let refused = protocol::TlConnectResult {
        service_id: result.service_id,
        endpoint_id: result.endpoint_id,
        status: protocol::STATUS_CONNECTION_REFUSED,
    };
    self.sender().send_message(&refused);
}
/// Handles a `ModifyChannel` message from the guest.
///
/// If the request is rejected, a `ModifyChannelResponse` carrying
/// `STATUS_UNSUCCESSFUL` is sent before the error is returned. Whether that
/// message goes out depends on the protocol version
/// (`send_modify_channel_response`).
fn handle_modify_channel(
    &mut self,
    request: &protocol::ModifyChannel,
) -> Result<(), ChannelError> {
    match self.modify_channel(request) {
        Ok(()) => Ok(()),
        Err(err) => {
            self.send_modify_channel_response(request.channel_id, protocol::STATUS_UNSUCCESSFUL);
            Err(err)
        }
    }
}
/// Changes the target VP of an open channel.
///
/// If no modification is in flight, this does three things:
/// - notifies the device with `Action::Modify`;
/// - records the new target VP in the channel's open parameters;
/// - marks the channel as modifying.
///
/// If a modification is already in flight, the new request is only logged
/// when `check_version(Version::Iron)` holds. Otherwise it is stored as the
/// pending target VP and replayed by `modify_channel_complete`.
///
/// # Errors
///
/// Returns `InvalidChannelState` if the channel is not open or has reserved
/// state. Errors from the channel-ID lookup are propagated.
fn modify_channel(&mut self, request: &protocol::ModifyChannel) -> Result<(), ChannelError> {
let (offer_id, channel) = self
.inner
.channels
.get_by_channel_id_mut(&self.inner.assigned_channels, request.channel_id)?;
let (open_request, modify_state) = match &mut channel.state {
ChannelState::Open {
params,
modify_state,
reserved_state: None,
} => (params, modify_state),
_ => return Err(ChannelError::InvalidChannelState),
};
if let ModifyState::Modifying { pending_target_vp } = modify_state {
if self.inner.state.check_version(Version::Iron) {
// From Iron on, the guest receives a ModifyChannelResponse, so a second
// request before that response is a protocol violation; it is only logged.
tracelimit::warn_ratelimited!(
"Client sent new ModifyChannel before receiving ModifyChannelResponse."
);
} else {
// Older guests get no response and cannot tell when the first request
// finished, so remember the latest target to apply afterwards.
*pending_target_vp = Some(request.target_vp);
}
} else {
self.notifier.notify(
offer_id,
Action::Modify {
target_vp: request.target_vp,
},
);
// Keep the stored open parameters in sync with the new target VP.
open_request.target_vp = request.target_vp;
*modify_state = ModifyState::Modifying {
pending_target_vp: None,
};
}
Ok(())
}
/// Completes a target-VP change on channel `offer_id` with `status`.
///
/// This only acts if the channel is open (without reserved state) and
/// modifying. In that case it:
/// - returns the channel to `NotModifying`;
/// - sends the guest a `ModifyChannelResponse` where the protocol version
///   allows it;
/// - replays any target VP deferred while the modification was in flight.
///
/// In every other state the call is ignored.
pub fn modify_channel_complete(&mut self, offer_id: OfferId, status: i32) {
let channel = &mut self.inner.channels[offer_id];
if let ChannelState::Open {
params,
modify_state: ModifyState::Modifying { pending_target_vp },
reserved_state: None,
} = channel.state
{
channel.state = ChannelState::Open {
params,
modify_state: ModifyState::NotModifying,
reserved_state: None,
};
let channel_id = channel.info.as_ref().expect("assigned").channel_id;
self.send_modify_channel_response(channel_id, status);
// A request deferred by `modify_channel` (pre-Iron guests) is replayed now,
// through the normal handler so it gets the same validation.
if let Some(target_vp) = pending_target_vp {
let request = protocol::ModifyChannel {
channel_id,
target_vp,
};
if let Err(error) = self.handle_modify_channel(&request) {
tracelimit::warn_ratelimited!(?error, "Pending ModifyChannel request failed.")
}
}
}
}
/// Sends a `ModifyChannelResponse` for `channel_id` with `status`.
///
/// Nothing is sent unless `check_version(Version::Iron)` holds for the
/// connection.
fn send_modify_channel_response(&mut self, channel_id: ChannelId, status: i32) {
    if !self.inner.state.check_version(Version::Iron) {
        return;
    }
    let response = protocol::ModifyChannelResponse { channel_id, status };
    self.sender().send_message(&response);
}
fn handle_modify_connection(&mut self, request: protocol::ModifyConnection) {
if let Err(err) = self.modify_connection(request) {
tracelimit::error_ratelimited!(?err, "modifying connection failed");
self.complete_modify_connection(ModifyConnectionResponse::Supported(
protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
FeatureFlags::new(),
));
}
}
fn modify_connection(&mut self, request: protocol::ModifyConnection) -> anyhow::Result<()> {
let ConnectionState::Connected(info) = &mut self.inner.state else {
anyhow::bail!(
"Invalid state for ModifyConnection request: {:?}",
self.inner.state
);
};
if info.modifying {
anyhow::bail!(
"Duplicate ModifyConnection request, state: {:?}",
self.inner.state
);
}
if (request.child_to_parent_monitor_page_gpa == 0)
!= (request.parent_to_child_monitor_page_gpa == 0)
{
anyhow::bail!("Guest must specify either both or no monitor pages, {request:?}");
}
let monitor_page =
(request.child_to_parent_monitor_page_gpa != 0).then_some(MonitorPageGpas {
child_to_parent: request.child_to_parent_monitor_page_gpa,
parent_to_child: request.parent_to_child_monitor_page_gpa,
});
info.modifying = true;
info.monitor_page = monitor_page;
tracing::debug!("modifying connection parameters.");
self.notifier.modify_connection(request.into())?;
Ok(())
}
/// Completes a connection-modification request previously handed to the
/// notifier.
///
/// The current server state determines what the request was for:
/// - while connecting, it completes `InitiateContact`;
/// - while disconnecting, it completes the disconnect;
/// - while connected, it answers the guest's `ModifyConnection` with a
///   `ModifyConnectionResponse` and clears the in-flight flag.
///
/// # Panics
///
/// Panics in any of these cases:
/// - connected, and `response` is not `Supported`;
/// - connected, and no modification is in flight;
/// - any other connection state.
pub fn complete_modify_connection(&mut self, response: ModifyConnectionResponse) {
tracing::debug!(?response, "modifying connection parameters complete");
match &mut self.inner.state {
ConnectionState::Connecting { .. } => self.complete_initiate_contact(response),
ConnectionState::Disconnecting { .. } => self.complete_disconnect(),
ConnectionState::Connected(info) => {
let ModifyConnectionResponse::Supported(connection_state, ..) = response else {
panic!(
"Relay should not return {:?} for a modify request with no version.",
response
);
};
if !info.modifying {
panic!(
"ModifyConnection response while not modifying, state: {:?}",
self.inner.state
);
}
info.modifying = false;
self.sender()
.send_message(&protocol::ModifyConnectionResponse { connection_state });
}
_ => panic!(
"Invalid state for ModifyConnection response: {:?}",
self.inner.state
),
}
}
/// Handles a guest `Pause` request.
///
/// Replies with `PauseResponse`, then marks the connection as paused.
///
/// # Panics
///
/// Panics if the server is not connected; `Message::parse` is expected to
/// prevent that.
fn handle_pause(&mut self) {
    tracelimit::info_ratelimited!("pausing sending messages");
    // Respond before setting the flag: `sender()` queues instead of sending
    // when `is_paused()` reports true (presumably driven by this flag).
    self.sender().send_message(&protocol::PauseResponse {});
    match &mut self.inner.state {
        ConnectionState::Connected(info) => info.paused = true,
        state => unreachable!(
            "in unexpected state {:?}, should be prevented by Message::parse()",
            state
        ),
    }
}
/// Parses and dispatches one SynIC message from the guest.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the message fails to parse for the negotiated version;
/// - an untrusted message arrives on a trusted connection;
/// - while paused, a message arrives other than `Resume`, `Unload`, or
///   `InitiateContact`/`InitiateContact2`;
/// - the specific message handler fails.
///
/// # Panics
///
/// Panics if called while the server is resetting. It also panics (via
/// `unreachable!`) on host-to-guest message types; presumably `Message::parse`
/// never yields those here — TODO confirm.
pub fn handle_synic_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
assert!(!self.is_resetting());
let version = self.inner.state.get_version();
let msg = Message::parse(&message.data, version)?;
tracing::trace!(?msg, message.trusted, "received vmbus message");
// Once the connection is trusted, untrusted messages are rejected.
if self.inner.state.is_trusted() && !message.trusted {
tracelimit::warn_ratelimited!(?msg, "Received untrusted message");
return Err(ChannelError::UntrustedMessage);
}
// While paused, only messages that end the pause are accepted. Accepting one
// clears the paused flag before the message itself is handled.
match &mut self.inner.state {
ConnectionState::Connected(info) if info.paused => {
if !matches!(
msg,
Message::Resume(..)
| Message::Unload(..)
| Message::InitiateContact { .. }
| Message::InitiateContact2 { .. }
) {
tracelimit::warn_ratelimited!(?msg, "Received message while paused");
return Err(ChannelError::Paused);
}
tracelimit::info_ratelimited!("resuming sending messages");
info.paused = false;
}
_ => {}
}
match msg {
Message::InitiateContact2(input, ..) => {
self.handle_initiate_contact(&input, &message, true)?
}
Message::InitiateContact(input, ..) => {
self.handle_initiate_contact(&input.into(), &message, false)?
}
Message::Unload(..) => self.handle_unload(),
Message::RequestOffers(..) => self.handle_request_offers()?,
Message::GpadlHeader(input, range) => self.handle_gpadl_header(&input, range)?,
Message::GpadlBody(input, range) => self.handle_gpadl_body(&input, range)?,
Message::GpadlTeardown(input, ..) => self.handle_gpadl_teardown(&input)?,
Message::OpenChannel(input, ..) => self.handle_open_channel(&input.into())?,
Message::OpenChannel2(input, ..) => self.handle_open_channel(&input)?,
Message::CloseChannel(input, ..) => self.handle_close_channel(&input)?,
Message::RelIdReleased(input, ..) => self.handle_rel_id_released(&input)?,
Message::TlConnectRequest(input, ..) => self.handle_tl_connect_request(input.into()),
Message::TlConnectRequest2(input, ..) => self.handle_tl_connect_request(input),
Message::ModifyChannel(input, ..) => self.handle_modify_channel(&input)?,
Message::ModifyConnection(input, ..) => self.handle_modify_connection(input),
Message::OpenReservedChannel(input, ..) => self.handle_open_reserved_channel(
&input,
version.expect("version validated by Message::parse"),
)?,
Message::CloseReservedChannel(input, ..) => {
self.handle_close_reserved_channel(&input)?
}
Message::Pause(protocol::Pause, ..) => self.handle_pause(),
// The paused flag was already cleared above; nothing else to do.
Message::Resume(protocol::Resume, ..) => {}
// Host-to-guest messages: never valid input to the server.
Message::OfferChannel(..)
| Message::RescindChannelOffer(..)
| Message::AllOffersDelivered(..)
| Message::OpenResult(..)
| Message::GpadlCreated(..)
| Message::GpadlTorndown(..)
| Message::VersionResponse(..)
| Message::VersionResponse2(..)
| Message::UnloadComplete(..)
| Message::CloseReservedChannelResponse(..)
| Message::TlConnectResult(..)
| Message::ModifyChannelResponse(..)
| Message::ModifyConnectionResponse(..)
| Message::PauseResponse(..) => {
unreachable!("Server received client message {:?}", msg);
}
}
Ok(())
}
/// Looks up GPADL `gpadl_id` belonging to channel `offer_id`.
///
/// Returns `None`, after logging an error, if there is no such GPADL.
fn get_gpadl(
    gpadls: &mut GpadlMap,
    offer_id: OfferId,
    gpadl_id: GpadlId,
) -> Option<&mut Gpadl> {
    gpadls.get_mut(&(gpadl_id, offer_id)).or_else(|| {
        tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, "invalid gpadl ID for channel");
        None
    })
}
/// Completes creation of GPADL `gpadl_id` on channel `offer_id` with `status`,
/// where a non-negative `status` means success.
///
/// Behavior depends on the GPADL's state:
/// - `Offered`: reports `GpadlCreated` with `status` to the guest. The GPADL
///   becomes `Accepted` on success and is removed on failure.
/// - `OfferedTearingDown`: presumably the guest asked for teardown before
///   creation finished. On success the device is told to tear it down and the
///   GPADL becomes `TearingDown`; on failure it is removed.
/// - Any other state is logged as an error and the GPADL is left alone.
///
/// Unknown GPADLs are ignored after `get_gpadl` logs them.
pub fn gpadl_create_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId, status: i32) {
let gpadl = if let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id)
{
gpadl
} else {
return;
};
// `retain` is false when the GPADL should be dropped from the map.
let retain = match gpadl.state {
GpadlState::InProgress | GpadlState::TearingDown | GpadlState::Accepted => {
tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
return;
}
GpadlState::Offered => {
let channel_id = self.inner.channels[offer_id]
.info
.as_ref()
.expect("assigned")
.channel_id;
// Built from individual fields rather than `self.sender()` because
// `gpadl` still borrows `self.inner.gpadls`.
self.inner
.pending_messages
.sender(self.notifier, self.inner.state.is_paused())
.send_gpadl_created(channel_id, gpadl_id, status);
if status >= 0 {
gpadl.state = GpadlState::Accepted;
true
} else {
false
}
}
GpadlState::OfferedTearingDown => {
if status >= 0 {
self.notifier.notify(
offer_id,
Action::TeardownGpadl {
gpadl_id,
post_restore: false,
},
);
gpadl.state = GpadlState::TearingDown;
true
} else {
// NOTE(review): on failure no GpadlTorndown is sent for the teardown the
// guest requested; confirm the guest does not wait for one.
false
}
}
};
if !retain {
self.inner
.gpadls
.remove(&(gpadl_id, offer_id))
.expect("gpadl validated above");
self.check_disconnected();
}
}
/// Called when the teardown of GPADL `gpadl_id` on channel `offer_id` has
/// finished.
///
/// A GPADL in the `TearingDown` state is removed from the map. `GpadlTorndown`
/// is sent to the guest unless the channel's state `is_released()`. Any other
/// state is logged as an error and left alone. Unknown GPADLs are ignored
/// after `get_gpadl` logs them.
pub fn gpadl_teardown_complete(&mut self, offer_id: OfferId, gpadl_id: GpadlId) {
    tracing::debug!(
        offer_id = offer_id.0,
        gpadl_id = gpadl_id.0,
        "Gpadl teardown complete"
    );
    let Some(gpadl) = Self::get_gpadl(&mut self.inner.gpadls, offer_id, gpadl_id) else {
        return;
    };
    let channel = &self.inner.channels[offer_id];
    match gpadl.state {
        GpadlState::InProgress
        | GpadlState::Offered
        | GpadlState::OfferedTearingDown
        | GpadlState::Accepted => {
            tracelimit::error_ratelimited!(?offer_id, ?gpadl_id, ?gpadl, "invalid gpadl state");
            return;
        }
        GpadlState::TearingDown => {}
    }
    // A channel the client has released no longer gets GPADL notifications.
    let notify_guest = !channel.state.is_released();
    if notify_guest {
        self.sender().send_gpadl_torndown(gpadl_id);
    }
    self.inner
        .gpadls
        .remove(&(gpadl_id, offer_id))
        .expect("gpadl validated above");
    self.check_disconnected();
}
/// Returns a [`MessageSender`] that delivers through this server's notifier.
///
/// The sender queues into `pending_messages` while the connection reports
/// itself paused.
fn sender(&mut self) -> MessageSender<'_, N> {
    let is_paused = self.inner.state.is_paused();
    self.inner.pending_messages.sender(self.notifier, is_paused)
}
}
/// Revokes offer `offer_id`.
///
/// What happens to the channel depends on its state:
/// - Closed, opening, open, or closing: the channel becomes `Revoked`, and the
///   guest is sent `RescindChannelOffer` (plus GPADL notifications below).
/// - Reoffered: the channel becomes `Revoked` without notifying the guest.
/// - Already released by the client: the state is left as-is.
///
/// This channel's GPADLs are then filtered by state:
/// - `InProgress` and `Accepted` are kept.
/// - `Offered` is dropped, reporting `GpadlCreated` with `STATUS_UNSUCCESSFUL`
///   when the guest is notified.
/// - `OfferedTearingDown` is dropped silently.
/// - `TearingDown` is dropped, sending `GpadlTorndown` when the guest is
///   notified.
///
/// Returns `!channel.state.is_released()` after the transition, or `true` if
/// the channel was already revoked. Callers presumably drop the channel entry
/// when this is `false`.
fn revoke<N: Notifier>(
mut sender: MessageSender<'_, N>,
offer_id: OfferId,
channel: &mut Channel,
gpadls: &mut GpadlMap,
) -> bool {
// `info` is `Some` only when the guest should be told about the revocation.
let info = match channel.state {
ChannelState::Closed
| ChannelState::Open { .. }
| ChannelState::Opening { .. }
| ChannelState::Closing { .. }
| ChannelState::ClosingReopen { .. } => {
channel.state = ChannelState::Revoked;
Some(channel.info.as_ref().expect("assigned"))
}
ChannelState::Reoffered => {
channel.state = ChannelState::Revoked;
None
}
ChannelState::ClientReleased
| ChannelState::OpeningClientRelease
| ChannelState::ClosingClientRelease => None,
ChannelState::Revoked => return true,
};
let retain = !channel.state.is_released();
gpadls.retain(|&(gpadl_id, gpadl_offer_id), gpadl| {
if gpadl_offer_id != offer_id {
return true;
}
match gpadl.state {
GpadlState::InProgress => true,
GpadlState::Offered => {
if let Some(info) = info {
sender.send_gpadl_created(
info.channel_id,
gpadl_id,
protocol::STATUS_UNSUCCESSFUL,
);
}
false
}
GpadlState::OfferedTearingDown => false,
GpadlState::Accepted => true,
GpadlState::TearingDown => {
if info.is_some() {
sender.send_gpadl_torndown(gpadl_id);
}
false
}
}
});
if let Some(info) = info {
sender.send_rescind(info);
}
// Any channel not created fresh is marked `Restored` once revoked.
if channel.restore_state != RestoreState::New {
channel.restore_state = RestoreState::Restored;
}
retain
}
/// FIFO of outgoing guest messages that could not be sent right away. This
/// happens when earlier messages are still queued, sending is paused, or the
/// notifier refused delivery.
struct PendingMessages(VecDeque<OutgoingMessage>);
impl PendingMessages {
    /// Pairs this queue with `notifier` in a [`MessageSender`].
    ///
    /// When `is_paused` is true, the sender queues default-target messages
    /// instead of sending them.
    fn sender<'a, N: Notifier>(
        &'a mut self,
        notifier: &'a mut N,
        is_paused: bool,
    ) -> MessageSender<'a, N> {
        let pending_messages = self;
        MessageSender {
            notifier,
            pending_messages,
            is_paused,
        }
    }
}
/// Borrowed handle for sending messages to the guest. It pairs the notifier
/// with the queue of messages still waiting to be delivered.
struct MessageSender<'a, N> {
notifier: &'a mut N,
pending_messages: &'a mut PendingMessages,
// When set, default-target messages are queued rather than sent.
is_paused: bool,
}
impl<N: Notifier> MessageSender<'_, N> {
fn send_message<
T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
>(
&mut self,
msg: &T,
) {
let message = OutgoingMessage::new(msg);
tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
if !self.pending_messages.0.is_empty()
|| self.is_paused
|| !self.notifier.send_message(&message, MessageTarget::Default)
{
tracing::trace!("message queued");
self.pending_messages.0.push_back(message);
}
}
fn send_message_with_target<
T: IntoBytes + protocol::VmbusMessage + std::fmt::Debug + Immutable + KnownLayout,
>(
&mut self,
msg: &T,
target: MessageTarget,
) {
if target == MessageTarget::Default {
self.send_message(msg);
} else {
tracing::trace!(typ = ?T::MESSAGE_TYPE, ?msg, "sending message");
let message = OutgoingMessage::new(msg);
if !self.notifier.send_message(&message, target) {
tracelimit::warn_ratelimited!(?target, "failed to send message");
}
}
}
fn send_offer(&mut self, channel: &mut Channel, version: VersionInfo) {
let info = channel.info.as_ref().expect("assigned");
let mut flags = channel.offer.flags;
if !version.feature_flags.confidential_channels() {
flags.set_confidential_ring_buffer(false);
flags.set_confidential_external_memory(false);
}
let msg = protocol::OfferChannel {
interface_id: channel.offer.interface_id,
instance_id: channel.offer.instance_id,
rsvd: [0; 4],
flags,
mmio_megabytes: channel.offer.mmio_megabytes,
user_defined: channel.offer.user_defined,
subchannel_index: channel.offer.subchannel_index,
mmio_megabytes_optional: channel.offer.mmio_megabytes_optional,
channel_id: info.channel_id,
monitor_id: info.monitor_id.unwrap_or(MonitorId::INVALID).0,
monitor_allocated: info.monitor_id.is_some() as u8,
is_dedicated: 1,
connection_id: info.connection_id,
};
tracing::info!(
channel_id = msg.channel_id.0,
connection_id = msg.connection_id,
key = %channel.offer.key(),
"sending offer to guest"
);
self.send_message(&msg);
}
fn send_open_result(
&mut self,
channel_id: ChannelId,
open_request: &OpenRequest,
result: i32,
target: MessageTarget,
) {
self.send_message_with_target(
&protocol::OpenResult {
channel_id,
open_id: open_request.open_id,
status: result as u32,
},
target,
);
}
fn send_gpadl_created(&mut self, channel_id: ChannelId, gpadl_id: GpadlId, status: i32) {
self.send_message(&protocol::GpadlCreated {
channel_id,
gpadl_id,
status,
});
}
fn send_gpadl_torndown(&mut self, gpadl_id: GpadlId) {
self.send_message(&protocol::GpadlTorndown { gpadl_id });
}
fn send_rescind(&mut self, info: &OfferedInfo) {
tracing::info!(
channel_id = info.channel_id.0,
"rescinding channel from guest"
);
self.send_message(&protocol::RescindChannelOffer {
channel_id: info.channel_id,
});
}
}
#[cfg(test)]
mod tests {
use crate::MESSAGE_CONNECTION_ID;
use super::*;
use guid::Guid;
use protocol::VmbusMessage;
use std::collections::VecDeque;
use std::sync::mpsc;
use test_with_tracing::test;
use vmbus_core::protocol::TargetInfo;
use zerocopy::FromBytes;
/// Builds an incoming SynIC message of `message_type` carrying `t`, marked as
/// neither multiclient nor trusted.
fn in_msg<T: IntoBytes + Immutable + KnownLayout>(
    message_type: protocol::MessageType,
    t: T,
) -> SynicMessage {
    let multiclient = false;
    let trusted = false;
    in_msg_ex(message_type, t, multiclient, trusted)
}
/// Builds an incoming SynIC message of `message_type` carrying `t`, with
/// explicit `multiclient` and `trusted` flags.
///
/// The data is the message type (native-endian `u32`), then a zero `u32`
/// (presumably the header's reserved field), then the bytes of `t`.
fn in_msg_ex<T: IntoBytes + Immutable + KnownLayout>(
    message_type: protocol::MessageType,
    t: T,
    multiclient: bool,
    trusted: bool,
) -> SynicMessage {
    let type_bytes = message_type.0.to_ne_bytes();
    let zero_word = 0u32.to_ne_bytes();
    let data = [type_bytes.as_slice(), zero_word.as_slice(), t.as_bytes()].concat();
    SynicMessage {
        data,
        multiclient,
        trusted,
    }
}
/// An unknown protocol version (0xffffffff) must be refused.
#[test]
fn test_version_negotiation_not_supported() {
let (mut notifier, _recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
test_initiate_contact(&mut server, &mut notifier, 0xffffffff, 0, false, 0);
}
/// Negotiating `Win10` succeeds with no feature flags.
#[test]
fn test_version_negotiation_success() {
let (mut notifier, _recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
test_initiate_contact(
&mut server,
&mut notifier,
Version::Win10 as u32,
0,
true,
0,
);
}
/// A multiclient `InitiateContact` targeting SINT 3 has three effects:
/// - it gets a `VersionResponse` with `version_supported: 0`, sent to the
///   (vp 0, sint 3) target;
/// - no connection change is requested;
/// - the server stays disconnected.
///
/// A regular negotiation with the same target info then succeeds.
#[test]
fn test_version_negotiation_multiclient_sint() {
let (mut notifier, _recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
let target_info = TargetInfo::new()
.with_sint(3)
.with_vtl(0)
.with_feature_flags(FeatureFlags::new().into());
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg_ex(
protocol::MessageType::INITIATE_CONTACT,
protocol::InitiateContact {
version_requested: Version::Win10Rs3_1 as u32,
target_message_vp: 0,
interrupt_page_or_target_info: target_info.into(),
parent_to_child_monitor_page_gpa: 0,
child_to_parent_monitor_page_gpa: 0,
},
true,
false,
))
.unwrap();
assert!(notifier.modify_requests.is_empty());
assert!(matches!(server.state, ConnectionState::Disconnected));
notifier.check_message_with_target(
OutgoingMessage::new(&protocol::VersionResponse {
version_supported: 0,
connection_state: protocol::ConnectionState::SUCCESSFUL,
padding: 0,
selected_version_or_connection_id: 0,
}),
MessageTarget::Custom(ConnectionTarget { vp: 0, sint: 3 }),
);
test_initiate_contact(
&mut server,
&mut notifier,
Version::Win10Rs3_1 as u32,
target_info.into(),
true,
0,
);
}
/// A multiclient `InitiateContact` for VTL 2 is handed to
/// `forward_unhandled`. No message is sent and the server stays disconnected.
/// A regular negotiation with the same target info then succeeds without
/// forwarding anything further.
#[test]
fn test_version_negotiation_multiclient_vtl() {
let (mut notifier, _recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
let target_info = TargetInfo::new()
.with_sint(SINT)
.with_vtl(2)
.with_feature_flags(FeatureFlags::new().into());
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg_ex(
protocol::MessageType::INITIATE_CONTACT,
protocol::InitiateContact {
version_requested: Version::Win10Rs4 as u32,
target_message_vp: 0,
interrupt_page_or_target_info: target_info.into(),
parent_to_child_monitor_page_gpa: 0,
child_to_parent_monitor_page_gpa: 0,
},
true,
false,
))
.unwrap();
let action = notifier.forward_request.take().unwrap();
assert!(matches!(action, InitiateContactRequest { .. }));
assert!(notifier.messages.is_empty());
assert!(matches!(server.state, ConnectionState::Disconnected));
test_initiate_contact(
&mut server,
&mut notifier,
Version::Win10Rs4 as u32,
target_info.into(),
true,
0,
);
assert!(notifier.forward_request.is_none());
}
/// Checks feature-flag negotiation at `Copper` across repeated negotiations
/// on the same server:
/// - no flags gives 0;
/// - `guest_specified_signal_parameters` is echoed back;
/// - the unknown high bits 0xf0000000 are dropped from the response;
/// - `client_id` is echoed back.
#[test]
fn test_version_negotiation_feature_flags() {
let (mut notifier, _recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
let mut target_info = TargetInfo::new()
.with_sint(SINT)
.with_vtl(0)
.with_feature_flags(FeatureFlags::new().into());
test_initiate_contact(
&mut server,
&mut notifier,
Version::Copper as u32,
target_info.into(),
true,
0,
);
target_info.set_feature_flags(
FeatureFlags::new()
.with_guest_specified_signal_parameters(true)
.into(),
);
test_initiate_contact(
&mut server,
&mut notifier,
Version::Copper as u32,
target_info.into(),
true,
FeatureFlags::new()
.with_guest_specified_signal_parameters(true)
.into(),
);
// Unknown bits requested by the guest must not appear in the response.
target_info.set_feature_flags(
u32::from(FeatureFlags::new().with_guest_specified_signal_parameters(true))
| 0xf0000000,
);
test_initiate_contact(
&mut server,
&mut notifier,
Version::Copper as u32,
target_info.into(),
true,
FeatureFlags::new()
.with_guest_specified_signal_parameters(true)
.into(),
);
target_info.set_feature_flags(FeatureFlags::new().with_client_id(true).into());
test_initiate_contact(
&mut server,
&mut notifier,
Version::Copper as u32,
target_info.into(),
true,
FeatureFlags::new().with_client_id(true).into(),
);
}
/// Negotiates each of `V1`, `Win7` and `Win8` on a fresh server, passing 1234
/// as the interrupt page / target info. `test_initiate_contact` checks whether
/// that value is applied as the interrupt page for the version in question.
#[test]
fn test_version_negotiation_interrupt_page() {
    for version in [Version::V1, Version::Win7, Version::Win8] {
        let (mut notifier, _recv) = TestNotifier::new();
        let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
        test_initiate_contact(&mut server, &mut notifier, version as u32, 1234, true, 0);
    }
}
/// Drives one `InitiateContact2` negotiation (target VP 1) and checks the
/// outcome.
///
/// When `expect_supported` is true, the helper checks that:
/// - the notifier gets a `ModifyConnectionRequest` that sets the interrupt
///   page only below `Win8` and uses target VP 0 below `Win8_1`, 1 otherwise;
/// - after completion, the guest gets `VersionResponse2` (at `Copper` or later,
///   with `expected_features`) or `VersionResponse`, selecting id 1 at
///   `Win10Rs3_1` or later and the version itself before that;
/// - the server ends up connected.
///
/// When `expect_supported` is false, the helper checks that an unsupported
/// `VersionResponse` is sent and the server stays disconnected.
fn test_initiate_contact(
server: &mut Server,
notifier: &mut TestNotifier,
version: u32,
target_info: u64,
expect_supported: bool,
expected_features: u32,
) {
server
.with_notifier(notifier)
.handle_synic_message(in_msg(
protocol::MessageType::INITIATE_CONTACT,
protocol::InitiateContact2 {
initiate_contact: protocol::InitiateContact {
version_requested: version,
target_message_vp: 1,
interrupt_page_or_target_info: target_info,
parent_to_child_monitor_page_gpa: 0,
child_to_parent_monitor_page_gpa: 0,
},
client_id: guid::guid!("e6e6e6e6-e6e6-e6e6-e6e6-e6e6e6e6e6e6"),
},
))
.unwrap();
let selected_version_or_connection_id = if expect_supported {
let request = notifier.next_action();
// Below Win8, the field carries the interrupt page; otherwise the page is reset.
let interrupt_page = if version < Version::Win8 as u32 {
Update::Set(target_info)
} else {
Update::Reset
};
// The requested target VP (1) is only honored from Win8_1 on.
let target_message_vp = if version < Version::Win8_1 as u32 {
Some(0)
} else {
Some(1)
};
assert_eq!(
request,
ModifyConnectionRequest {
version: Some(version),
monitor_page: Update::Reset,
interrupt_page,
target_message_vp,
..Default::default()
}
);
server.with_notifier(notifier).complete_initiate_contact(
ModifyConnectionResponse::Supported(
protocol::ConnectionState::SUCCESSFUL,
SUPPORTED_FEATURE_FLAGS,
),
);
if version >= Version::Win10Rs3_1 as u32 {
1
} else {
version
}
} else {
0
};
let version_response = protocol::VersionResponse {
version_supported: if expect_supported { 1 } else { 0 },
connection_state: protocol::ConnectionState::SUCCESSFUL,
padding: 0,
selected_version_or_connection_id,
};
if version >= Version::Copper as u32 && expect_supported {
notifier.check_message(OutgoingMessage::new(&protocol::VersionResponse2 {
version_response,
supported_features: expected_features,
}));
} else {
notifier.check_message(OutgoingMessage::new(&version_response));
assert_eq!(expected_features, 0);
}
assert!(notifier.messages.is_empty());
if expect_supported {
assert!(matches!(server.state, ConnectionState::Connected { .. }));
if version < Version::Win8_1 as u32 {
assert_eq!(Some(0), notifier.target_message_vp);
} else {
assert_eq!(Some(1), notifier.target_message_vp);
}
} else {
assert!(matches!(server.state, ConnectionState::Disconnected));
assert!(notifier.target_message_vp.is_none());
}
if version < Version::Win8 as u32 {
assert_eq!(notifier.interrupt_page, Some(target_info));
} else {
assert!(notifier.interrupt_page.is_none());
}
}
/// Test double for `Notifier` that records everything the server asks of it.
struct TestNotifier {
// `notify` calls are sent here as (offer, action) pairs.
send: mpsc::Sender<(OfferId, Action)>,
// Every `modify_connection` request, oldest first.
modify_requests: VecDeque<ModifyConnectionRequest>,
// Messages accepted by `send_message`, with their targets.
messages: VecDeque<(OutgoingMessage, MessageTarget)>,
// Requests passed to `notify_hvsock`.
hvsock_requests: Vec<HvsockConnectRequest>,
// The request passed to `forward_unhandled`, until a test takes it.
forward_request: Option<InitiateContactRequest>,
// Connection parameters as last applied through `modify_connection`.
interrupt_page: Option<u64>,
// Set by `reset_complete`, consumed by `is_reset`.
reset: bool,
monitor_page: Option<MonitorPageGpas>,
target_message_vp: Option<u32>,
// When true, `send_message` refuses every message so the server queues it.
pend_messages: bool,
}
impl TestNotifier {
    /// Creates a notifier with empty records. Also returns the receiving end of
    /// the channel that `Notifier::notify` actions are sent on.
    fn new() -> (Self, mpsc::Receiver<(OfferId, Action)>) {
        let (send, recv) = mpsc::channel();
        let notifier = Self {
            send,
            modify_requests: VecDeque::new(),
            messages: VecDeque::new(),
            hvsock_requests: Vec::new(),
            forward_request: None,
            interrupt_page: None,
            reset: false,
            monitor_page: None,
            target_message_vp: None,
            pend_messages: false,
        };
        (notifier, recv)
    }

    /// Asserts that the only recorded message is `message`, sent to the default
    /// target, and consumes it.
    fn check_message(&mut self, message: OutgoingMessage) {
        self.check_message_with_target(message, MessageTarget::Default)
    }

    /// Asserts that the only recorded message is `message`, sent to `target`,
    /// and consumes it.
    fn check_message_with_target(&mut self, message: OutgoingMessage, target: MessageTarget) {
        let sent = self.messages.pop_front().unwrap();
        assert_eq!(sent, (message, target));
        assert!(self.messages.is_empty());
    }

    /// Pops the oldest recorded message and decodes its payload as `T`,
    /// asserting that the header carries `T`'s message type.
    fn get_message<T: VmbusMessage + FromBytes + Immutable + KnownLayout>(&mut self) -> T {
        let (message, _target) = self.messages.pop_front().unwrap();
        let (header, payload) =
            protocol::MessageHeader::read_from_prefix(message.data()).unwrap();
        assert_eq!(header.message_type(), T::MESSAGE_TYPE);
        let (decoded, _rest) = T::read_from_prefix(payload).unwrap();
        decoded
    }

    /// Asserts that the recorded messages are exactly `messages`, in order and
    /// all sent to the default target, then clears the record.
    fn check_messages(&mut self, messages: &[OutgoingMessage]) {
        let expected = messages
            .iter()
            .cloned()
            .map(|m| (m, MessageTarget::Default))
            .collect::<Vec<_>>();
        assert_eq!(self.messages, expected.as_slice());
        self.messages.clear();
    }

    /// Returns whether `reset_complete` was called since the last check, and
    /// clears the flag.
    fn is_reset(&mut self) -> bool {
        std::mem::take(&mut self.reset)
    }

    /// Asserts that a reset completed and left no monitor page or target VP.
    fn check_reset(&mut self) {
        assert!(self.is_reset());
        assert!(self.monitor_page.is_none());
        assert!(self.target_message_vp.is_none());
    }

    /// Pops the oldest recorded `ModifyConnectionRequest`.
    fn next_action(&mut self) -> ModifyConnectionRequest {
        self.modify_requests.pop_front().unwrap()
    }
}
/// `Notifier` implementation that records each call for later assertions.
impl Notifier for TestNotifier {
// Forwards device actions to the channel returned by `TestNotifier::new`.
fn notify(&mut self, offer_id: OfferId, action: Action) {
tracing::debug!(?offer_id, ?action, "notify");
self.send.send((offer_id, action)).unwrap()
}
// Records a forwarded request. A second one before the test takes the first
// fails the assertion.
fn forward_unhandled(&mut self, request: InitiateContactRequest) {
assert!(self.forward_request.is_none());
self.forward_request = Some(request);
}
// Applies the requested parameter changes to the recorded state, then queues
// the request itself for `next_action`. Never fails.
fn modify_connection(&mut self, request: ModifyConnectionRequest) -> anyhow::Result<()> {
match request.monitor_page {
Update::Unchanged => (),
Update::Reset => self.monitor_page = None,
Update::Set(value) => self.monitor_page = Some(value),
}
if let Some(vp) = request.target_message_vp {
self.target_message_vp = Some(vp);
}
match request.interrupt_page {
Update::Unchanged => (),
Update::Reset => self.interrupt_page = None,
Update::Set(value) => self.interrupt_page = Some(value),
}
self.modify_requests.push_back(request);
Ok(())
}
// Records the message. If `pend_messages` is set, it reports failure instead,
// so the server keeps the message queued.
fn send_message(&mut self, message: &OutgoingMessage, target: MessageTarget) -> bool {
if self.pend_messages {
return false;
}
self.messages.push_back((message.clone(), target));
true
}
fn notify_hvsock(&mut self, request: &HvsockConnectRequest) {
tracing::debug!(?request, "notify_hvsock");
self.hvsock_requests.push(*request);
}
// Clears the recorded monitor page and target VP and flags the reset for
// `is_reset`.
fn reset_complete(&mut self) {
self.monitor_page = None;
self.target_message_vp = None;
self.reset = true;
}
fn unload_complete(&mut self) {}
}
/// Channel lifetime negotiated at `Win10Rs5`.
#[test]
fn test_channel_lifetime() {
test_channel_lifetime_helper(Version::Win10Rs5, FeatureFlags::new());
}
/// Channel lifetime negotiated at `Iron`.
#[test]
fn test_channel_lifetime_iron() {
test_channel_lifetime_helper(Version::Iron, FeatureFlags::new());
}
/// Channel lifetime negotiated at `Copper` with no feature flags.
#[test]
fn test_channel_lifetime_copper() {
test_channel_lifetime_helper(Version::Copper, FeatureFlags::new());
}
/// `Copper` with guest-specified signal parameters (opened via `OpenChannel2`).
#[test]
fn test_channel_lifetime_copper_guest_signal() {
test_channel_lifetime_helper(
Version::Copper,
FeatureFlags::new().with_guest_specified_signal_parameters(true),
);
}
/// `Copper` with channel interrupt redirection (opened via `OpenChannel2`).
#[test]
fn test_channel_lifetime_copper_open_flags() {
test_channel_lifetime_helper(
Version::Copper,
FeatureFlags::new().with_channel_interrupt_redirection(true),
);
}
/// Walks one channel through its lifetime at `version` with the guest
/// requesting `feature_flags`:
/// 1. offer the channel and connect;
/// 2. request offers and open the channel;
/// 3. check the resulting `Action::Open`;
/// 4. complete the open and retarget the channel to VP 4;
/// 5. revoke the channel and have the guest release it.
fn test_channel_lifetime_helper(version: Version, feature_flags: FeatureFlags) {
let (mut notifier, recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
let interface_id = Guid::new_random();
let instance_id = Guid::new_random();
let offer_id = server
.with_notifier(&mut notifier)
.offer_channel(OfferParamsInternal {
interface_name: "test".to_owned(),
instance_id,
interface_id,
..Default::default()
})
.unwrap();
let mut target_info = TargetInfo::new()
.with_sint(SINT)
.with_vtl(2)
.with_feature_flags(FeatureFlags::new().into());
// Feature flags are only requested from Copper on.
if version >= Version::Copper {
target_info.set_feature_flags(feature_flags.into());
}
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(
protocol::MessageType::INITIATE_CONTACT,
protocol::InitiateContact {
version_requested: version as u32,
target_message_vp: 0,
interrupt_page_or_target_info: target_info.into(),
parent_to_child_monitor_page_gpa: 0,
child_to_parent_monitor_page_gpa: 0,
},
))
.unwrap();
let request = notifier.next_action();
assert_eq!(
request,
ModifyConnectionRequest {
version: Some(version as u32),
monitor_page: Update::Reset,
interrupt_page: Update::Reset,
target_message_vp: Some(0),
..Default::default()
}
);
server
.with_notifier(&mut notifier)
.complete_initiate_contact(ModifyConnectionResponse::Supported(
protocol::ConnectionState::SUCCESSFUL,
SUPPORTED_FEATURE_FLAGS,
));
let version_response = protocol::VersionResponse {
version_supported: 1,
selected_version_or_connection_id: 1,
..FromZeros::new_zeroed()
};
if version >= Version::Copper {
notifier.check_message(OutgoingMessage::new(&protocol::VersionResponse2 {
version_response,
supported_features: feature_flags.into(),
}));
} else {
notifier.check_message(OutgoingMessage::new(&version_response));
}
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(protocol::MessageType::REQUEST_OFFERS, ()))
.unwrap();
let channel_id = ChannelId(1);
notifier.check_messages(&[
OutgoingMessage::new(&protocol::OfferChannel {
interface_id,
instance_id,
channel_id,
connection_id: 0x2001,
is_dedicated: 1,
monitor_id: 0xff,
..protocol::OfferChannel::new_zeroed()
}),
OutgoingMessage::new(&protocol::AllOffersDelivered {}),
]);
let open_channel = protocol::OpenChannel {
channel_id,
open_id: 1,
ring_buffer_gpadl_id: GpadlId(1),
target_vp: 3,
downstream_ring_buffer_page_offset: 2,
user_data: UserDefinedData::new_zeroed(),
};
// Defaults the server uses when the guest cannot choose signal parameters.
let mut event_flag = 1;
let mut connection_id = 0x2001;
let mut expected_flags = protocol::OpenChannelFlags::new();
if version >= Version::Copper
&& (feature_flags.guest_specified_signal_parameters()
|| feature_flags.channel_interrupt_redirection())
{
// The guest always asks for event flag 2, connection 0x2002, and redirect
// plus extra bits 0xabc. Only what the negotiated features allow is
// expected in the resulting open request.
if feature_flags.channel_interrupt_redirection() {
expected_flags.set_redirect_interrupt(true);
}
if feature_flags.guest_specified_signal_parameters() {
event_flag = 2;
connection_id = 0x2002;
}
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(
protocol::MessageType::OPEN_CHANNEL,
protocol::OpenChannel2 {
open_channel,
event_flag: 2,
connection_id: 0x2002,
flags: (u16::from(
protocol::OpenChannelFlags::new().with_redirect_interrupt(true),
) | 0xabc)
.into(), },
))
.unwrap();
} else {
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(protocol::MessageType::OPEN_CHANNEL, open_channel))
.unwrap();
}
let (id, action) = recv.recv().unwrap();
assert_eq!(id, offer_id);
let Action::Open(op, ..) = action else {
panic!("unexpected action: {:?}", action);
};
assert_eq!(op.open_data.ring_gpadl_id, GpadlId(1));
assert_eq!(op.open_data.ring_offset, 2);
assert_eq!(op.open_data.target_vp, 3);
assert_eq!(op.open_data.event_flag, event_flag);
assert_eq!(op.open_data.connection_id, connection_id);
assert_eq!(op.connection_id, connection_id);
assert_eq!(op.event_flag, event_flag);
assert_eq!(op.monitor_id, None);
assert_eq!(op.flags, expected_flags);
server
.with_notifier(&mut notifier)
.open_complete(offer_id, 0);
notifier.check_message(OutgoingMessage::new(&protocol::OpenResult {
channel_id,
open_id: 1,
status: 0,
}));
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(
protocol::MessageType::MODIFY_CHANNEL,
protocol::ModifyChannel {
channel_id,
target_vp: 4,
},
))
.unwrap();
let (id, action) = recv.recv().unwrap();
assert_eq!(id, offer_id);
assert!(matches!(action, Action::Modify { target_vp: 4 }));
server
.with_notifier(&mut notifier)
.modify_channel_complete(id, 0);
// ModifyChannelResponse is only sent from Iron on.
if version >= Version::Iron {
notifier.check_message(OutgoingMessage::new(&protocol::ModifyChannelResponse {
channel_id,
status: 0,
}));
}
assert!(notifier.messages.is_empty());
server.with_notifier(&mut notifier).revoke_channel(offer_id);
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(
protocol::MessageType::REL_ID_RELEASED,
protocol::RelIdReleased { channel_id },
))
.unwrap();
}
/// hvsocket connect at `Win10`; no `TlConnectResult` is expected.
#[test]
fn test_hvsock() {
test_hvsock_helper(Version::Win10, false);
}
/// hvsocket connect at `Win10Rs3_0`, where refusals are reported.
#[test]
fn test_hvsock_rs3() {
test_hvsock_helper(Version::Win10Rs3_0, false);
}
/// hvsocket connect at `Win10Rs5`, with both the large and the small request.
#[test]
fn test_hvsock_rs5() {
test_hvsock_helper(Version::Win10Rs5, false);
test_hvsock_helper(Version::Win10Rs5, true);
}
/// Connects at `version`, sends a guest hvsocket connect request, and checks
/// two things:
/// - the request reaches `notify_hvsock`;
/// - a refused result produces `TlConnectResult` only at `Win10Rs3_0` or
///   later.
///
/// `TlConnectRequest2` is used at `Win10Rs5` or later unless
/// `force_small_message` selects the original `TlConnectRequest`.
fn test_hvsock_helper(version: Version, force_small_message: bool) {
let (mut notifier, _recv) = TestNotifier::new();
let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
server
.with_notifier(&mut notifier)
.handle_synic_message(in_msg(
protocol::MessageType::INITIATE_CONTACT,
protocol::InitiateContact {
version_requested: version as u32,
target_message_vp: 0,
interrupt_page_or_target_info: 0,
parent_to_child_monitor_page_gpa: 0,
child_to_parent_monitor_page_gpa: 0,
},
))
.unwrap();
let request = notifier.next_action();
assert_eq!(
request,
ModifyConnectionRequest {
version: Some(version as u32),
monitor_page: Update::Reset,
interrupt_page: Update::Reset,
target_message_vp: Some(0),
..Default::default()
}
);
server
.with_notifier(&mut notifier)
.complete_initiate_contact(ModifyConnectionResponse::Supported(
protocol::ConnectionState::SUCCESSFUL,
SUPPORTED_FEATURE_FLAGS,
));
// Discard the version response; negotiation itself is covered elsewhere.
notifier.messages.pop_front();
let service_id = Guid::new_random();
let endpoint_id = Guid::new_random();
let request_msg = if version >= Version::Win10Rs5 && !force_small_message {
in_msg(
protocol::MessageType::TL_CONNECT_REQUEST,
protocol::TlConnectRequest2 {
base: protocol::TlConnectRequest {
service_id,
endpoint_id,
},
silo_id: Guid::ZERO,
},
)
} else {
in_msg(
protocol::MessageType::TL_CONNECT_REQUEST,
protocol::TlConnectRequest {
service_id,
endpoint_id,
},
)
};
server
.with_notifier(&mut notifier)
.handle_synic_message(request_msg)
.unwrap();
let request = notifier.hvsock_requests.pop().unwrap();
assert_eq!(request.service_id, service_id);
assert_eq!(request.endpoint_id, endpoint_id);
assert!(notifier.hvsock_requests.is_empty());
// Refuse the connection; only Win10Rs3_0 and later get a TlConnectResult.
server
.with_notifier(&mut notifier)
.send_tl_connect_result(HvsockConnectResult::from_request(&request, false));
if version >= Version::Win10Rs3_0 {
notifier.check_message(OutgoingMessage::new(&protocol::TlConnectResult {
service_id: request.service_id,
endpoint_id: request.endpoint_id,
status: protocol::STATUS_CONNECTION_REFUSED,
}));
}
assert!(notifier.messages.is_empty());
}
/// Test harness bundling a `Server` with the `TestNotifier` it reports to,
/// plus the version the harness last asked to connect with.
struct TestEnv {
/// The server under test.
server: Server,
/// Notifier passed to the server via `with_notifier`; tests inspect its
/// recorded messages, modify-connection requests and monitor page state.
notifier: TestNotifier,
/// Version and feature flags from the most recent `start_connect`, narrowed
/// by `complete_connect` for Copper and later; `None` until then.
version: Option<VersionInfo>,
/// Receiving end returned by `TestNotifier::new`; kept alive for the
/// lifetime of the harness (presumably so the notifier's sends of channel
/// actions do not fail — confirm against `TestNotifier`).
_recv: mpsc::Receiver<(OfferId, Action)>,
}
impl TestEnv {
fn new() -> Self {
let (notifier, _recv) = TestNotifier::new();
let server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0);
Self {
server,
notifier,
version: None,
_recv,
}
}
fn c(&mut self) -> ServerWithNotifier<'_, TestNotifier> {
self.server.with_notifier(&mut self.notifier)
}
fn complete_reset(&mut self) {
let _ = self.next_action();
self.c()
.complete_modify_connection(ModifyConnectionResponse::Supported(
protocol::ConnectionState::SUCCESSFUL,
SUPPORTED_FEATURE_FLAGS,
));
}
fn offer(&mut self, id: u32) -> OfferId {
self.offer_inner(id, id, false, None, None, OfferFlags::new())
}
fn offer_with_mnf(&mut self, id: u32) -> OfferId {
self.offer_inner(id, id, true, None, None, OfferFlags::new())
}
fn offer_with_preset_mnf(&mut self, id: u32, monitor_id: u8) -> OfferId {
self.offer_inner(id, id, false, None, Some(monitor_id), OfferFlags::new())
}
fn offer_with_order(
&mut self,
interface_id: u32,
instance_id: u32,
order: Option<u32>,
) -> OfferId {
self.offer_inner(
interface_id,
instance_id,
false,
order,
None,
OfferFlags::new(),
)
}
fn offer_with_flags(&mut self, id: u32, flags: OfferFlags) -> OfferId {
self.offer_inner(id, id, false, None, None, flags)
}
fn offer_inner(
&mut self,
interface_id: u32,
instance_id: u32,
use_mnf: bool,
offer_order: Option<u32>,
monitor_id: Option<u8>,
flags: OfferFlags,
) -> OfferId {
self.c()
.offer_channel(OfferParamsInternal {
instance_id: Guid {
data1: instance_id,
..Guid::ZERO
},
interface_id: Guid {
data1: interface_id,
..Guid::ZERO
},
use_mnf,
offer_order,
monitor_id,
flags,
..Default::default()
})
.unwrap()
}
fn open(&mut self, id: u32) {
self.c()
.handle_open_channel(&protocol::OpenChannel2 {
open_channel: protocol::OpenChannel {
channel_id: ChannelId(id),
..FromZeros::new_zeroed()
},
..FromZeros::new_zeroed()
})
.unwrap()
}
fn close(&mut self, id: u32) -> Result<(), ChannelError> {
self.c().handle_close_channel(&protocol::CloseChannel {
channel_id: ChannelId(id),
})
}
fn open_reserved(&mut self, id: u32, target_vp: u32, target_sint: u32) {
let version = self.server.state.get_version().expect("vmbus connected");
self.c()
.handle_open_reserved_channel(
&protocol::OpenReservedChannel {
channel_id: ChannelId(id),
target_vp,
target_sint,
ring_buffer_gpadl: GpadlId(id),
..FromZeros::new_zeroed()
},
version,
)
.unwrap()
}
fn close_reserved(&mut self, id: u32, target_vp: u32, target_sint: u32) {
self.c()
.handle_close_reserved_channel(&protocol::CloseReservedChannel {
channel_id: ChannelId(id),
target_vp,
target_sint,
})
.unwrap();
}
fn gpadl(&mut self, channel_id: u32, gpadl_id: u32) {
self.c()
.handle_gpadl_header(
&protocol::GpadlHeader {
channel_id: ChannelId(channel_id),
gpadl_id: GpadlId(gpadl_id),
count: 1,
len: 16,
},
[1u64, 0u64].as_bytes(),
)
.unwrap();
}
fn teardown_gpadl(&mut self, channel_id: u32, gpadl_id: u32) {
self.c()
.handle_gpadl_teardown(&protocol::GpadlTeardown {
channel_id: ChannelId(channel_id),
gpadl_id: GpadlId(gpadl_id),
})
.unwrap();
}
fn release(&mut self, id: u32) {
self.c()
.handle_rel_id_released(&protocol::RelIdReleased {
channel_id: ChannelId(id),
})
.unwrap();
}
fn connect(&mut self, version: Version, feature_flags: FeatureFlags) {
self.start_connect(version, feature_flags, false);
self.complete_connect();
}
fn connect_trusted(&mut self, version: Version, feature_flags: FeatureFlags) {
self.start_connect(version, feature_flags, true);
self.complete_connect();
}
fn start_connect(&mut self, version: Version, feature_flags: FeatureFlags, trusted: bool) {
self.version = Some(VersionInfo {
version,
feature_flags,
});
let result = self.c().handle_synic_message(in_msg_ex(
protocol::MessageType::INITIATE_CONTACT,
protocol::InitiateContact2 {
initiate_contact: protocol::InitiateContact {
version_requested: version as u32,
interrupt_page_or_target_info: TargetInfo::new()
.with_sint(SINT)
.with_vtl(0)
.with_feature_flags(feature_flags.into())
.into(),
child_to_parent_monitor_page_gpa: 0x123f000,
parent_to_child_monitor_page_gpa: 0x321f000,
..FromZeros::new_zeroed()
},
client_id: Guid::ZERO,
},
false,
trusted,
));
assert!(result.is_ok());
let request = self.notifier.next_action();
assert_eq!(
request,
ModifyConnectionRequest {
version: Some(version as u32),
monitor_page: Update::Set(MonitorPageGpas {
child_to_parent: 0x123f000,
parent_to_child: 0x321f000,
}),
interrupt_page: Update::Reset,
target_message_vp: Some(0),
..Default::default()
}
);
}
fn complete_connect(&mut self) {
self.c()
.complete_initiate_contact(ModifyConnectionResponse::Supported(
protocol::ConnectionState::SUCCESSFUL,
SUPPORTED_FEATURE_FLAGS,
));
let version = self.version.unwrap();
if version.version >= Version::Copper {
let response = self.notifier.get_message::<protocol::VersionResponse2>();
assert_eq!(response.version_response.version_supported, 1);
self.version = Some(VersionInfo {
version: version.version,
feature_flags: version.feature_flags & response.supported_features.into(),
})
} else {
let response = self.notifier.get_message::<protocol::VersionResponse>();
assert_eq!(response.version_supported, 1);
}
}
fn send_message(&mut self, message: SynicMessage) {
self.try_send_message(message).unwrap();
}
fn try_send_message(&mut self, message: SynicMessage) -> Result<(), ChannelError> {
self.c().handle_synic_message(message)
}
fn next_action(&mut self) -> ModifyConnectionRequest {
self.notifier.next_action()
}
}
/// Offers channels at each stage of connection setup, then opens all four
/// channels, completes the opens, and checks that a reset closes them:
/// before INITIATE_CONTACT, while it is pending, after it completes, and
/// after REQUEST_OFFERS.
#[test]
fn test_hot_add() {
    let mut env = TestEnv::new();
    // Offered before the guest connects.
    let offer_id1 = env.offer(1);
    // Surface the actual error if INITIATE_CONTACT is rejected, rather than a
    // bare `is_ok()` assertion failure.
    env.c()
        .handle_initiate_contact(
            &protocol::InitiateContact2 {
                initiate_contact: protocol::InitiateContact {
                    version_requested: Version::Win10 as u32,
                    ..FromZeros::new_zeroed()
                },
                ..FromZeros::new_zeroed()
            },
            &SynicMessage::default(),
            true,
        )
        .expect("INITIATE_CONTACT should be accepted");
    // Offered while the connection is still being established.
    let offer_id2 = env.offer(2);
    env.c()
        .complete_initiate_contact(ModifyConnectionResponse::Supported(
            protocol::ConnectionState::SUCCESSFUL,
            SUPPORTED_FEATURE_FLAGS,
        ));
    // Offered after connecting but before the guest requests offers.
    let offer_id3 = env.offer(3);
    env.c().handle_request_offers().unwrap();
    // Hot-added after offers were requested.
    let offer_id4 = env.offer(4);

    let offers = [offer_id1, offer_id2, offer_id3, offer_id4];
    for channel_id in 1..=4 {
        env.open(channel_id);
    }
    for &offer_id in &offers {
        env.c().open_complete(offer_id, 0);
    }

    env.c().reset();
    for &offer_id in &offers {
        env.c().close_complete(offer_id);
    }
    env.complete_reset();
    assert!(env.notifier.is_reset());
}
/// Saves a server that has offers but no guest connection, resets it, and
/// restores it, reclaiming only the first channel.
#[test]
fn test_save_restore_with_no_connection() {
    let mut env = TestEnv::new();
    let first = env.offer(1);
    env.offer(2);

    let snapshot = env.server.save();
    env.c().reset();
    assert!(env.notifier.is_reset());

    env.server.restore(snapshot).unwrap();
    env.c().restore_channel(first, false).unwrap();
    env.c().post_restore().unwrap();
}
/// Saves a connected server with channels in many intermediate states, then
/// restores it twice while offers change around the restore. The states are:
/// open, opening, reserved, closing, GPADLs created and pending, and revoked
/// then re-offered.
#[test]
fn test_save_restore_with_connection() {
let mut env = TestEnv::new();
// Channels 1, 3 and 5 request MNF (server-assigned monitor IDs).
let offer_id1 = env.offer_with_mnf(1);
let offer_id2 = env.offer(2);
let offer_id3 = env.offer_with_mnf(3);
let offer_id4 = env.offer(4);
let offer_id5 = env.offer_with_mnf(5);
let offer_id6 = env.offer(6);
let offer_id7 = env.offer(7);
let offer_id8 = env.offer(8);
let offer_id9 = env.offer(9);
let offer_id10 = env.offer(10);
// Matches the monitor page GPAs sent by `TestEnv::start_connect`.
let expected_monitor = MonitorPageGpas {
child_to_parent: 0x123f000,
parent_to_child: 0x321f000,
};
env.connect(Version::Win10, FeatureFlags::new());
assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
env.c().handle_request_offers().unwrap();
// Three monitor bits (0-2) assigned, one per MNF channel.
assert_eq!(env.server.assigned_monitors.bitmap(), 7);
// Open 1, 2, 3 and 5; channel 3's open is deliberately left incomplete.
env.open(1);
env.open(2);
env.open(3);
env.open(5);
env.c().open_complete(offer_id1, 0);
env.c().open_complete(offer_id2, 0);
env.c().open_complete(offer_id5, 0);
// Channels 1-3 each get one completed GPADL and one still pending.
env.gpadl(1, 10);
env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
env.gpadl(1, 11);
env.gpadl(2, 20);
env.c().gpadl_create_complete(offer_id2, GpadlId(20), 0);
env.gpadl(2, 21);
env.gpadl(3, 30);
env.c().gpadl_create_complete(offer_id3, GpadlId(30), 0);
env.gpadl(3, 31);
// Reserved channels on VPs 1-3. At save time, 7 is still opening, 8 is
// open, and 9 is closing.
env.open_reserved(7, 1, SINT.into());
env.open_reserved(8, 2, SINT.into());
env.open_reserved(9, 3, SINT.into());
env.c().open_complete(offer_id8, 0);
env.c().open_complete(offer_id9, 0);
env.close_reserved(9, 3, SINT.into());
// Revoke channel 10 and offer it again before saving.
env.c().revoke_channel(offer_id10);
let offer_id10 = env.offer(10);
let state = env.server.save();
// Reset, failing or finishing every outstanding device-side operation.
env.c().reset();
env.c().close_complete(offer_id1);
env.c().close_complete(offer_id2);
env.c().open_complete(offer_id3, -1);
env.c().close_complete(offer_id5);
env.c().open_complete(offer_id7, -1);
env.c().close_complete(offer_id8);
env.c().close_complete(offer_id9);
env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));
env.c().gpadl_create_complete(offer_id1, GpadlId(11), -1);
env.c().gpadl_teardown_complete(offer_id2, GpadlId(20));
env.c().gpadl_create_complete(offer_id2, GpadlId(21), -1);
env.c().gpadl_teardown_complete(offer_id3, GpadlId(30));
env.c().gpadl_create_complete(offer_id3, GpadlId(31), -1);
env.complete_reset();
env.notifier.check_reset();
// Revoke 5 and 6 before restoring; 5 is offered again below.
env.c().revoke_channel(offer_id5);
env.c().revoke_channel(offer_id6);
// `state` is cloned because it is restored a second time at the end.
env.server.restore(state.clone()).unwrap();
// Revoke 1 and 4 after restore but before post_restore.
env.c().revoke_channel(offer_id1);
env.c().revoke_channel(offer_id4);
// NOTE(review): the bool presumably says whether the device reports the
// channel as open — confirm against `restore_channel`.
env.c().restore_channel(offer_id3, false).unwrap();
let offer_id5 = env.offer_with_mnf(5);
env.c().restore_channel(offer_id5, true).unwrap();
env.c().restore_channel(offer_id7, false).unwrap();
env.c().restore_channel(offer_id8, true).unwrap();
env.c().restore_channel(offer_id9, true).unwrap();
env.c().restore_channel(offer_id10, false).unwrap();
// Channel 10's pre-save re-offer is restored as Reoffered.
assert!(matches!(
env.server.channels[offer_id10].state,
ChannelState::Reoffered
));
env.c().post_restore().unwrap();
assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
assert_eq!(env.notifier.target_message_vp, Some(0));
// Monitor bit 0 is no longer assigned after restore.
assert_eq!(env.server.assigned_monitors.bitmap(), 6);
// Guest releases 1 and 4 (revoked) and 2 (never passed to restore_channel).
env.release(1);
env.release(2);
env.release(4);
// Finish the restored in-flight reserved-channel operations.
env.c().open_complete(offer_id7, 0);
env.close_reserved(8, 2, SINT.into());
env.c().close_complete(offer_id8);
env.c().close_complete(offer_id9);
// Second cycle: reset again, then restore the same saved state, this time
// reclaiming only channel 3.
env.c().reset();
env.c().open_complete(offer_id3, -1);
env.c().gpadl_teardown_complete(offer_id3, GpadlId(30));
env.c().gpadl_create_complete(offer_id3, GpadlId(31), -1);
env.c().close_complete(offer_id5);
env.c().close_complete(offer_id7);
env.complete_reset();
env.notifier.check_reset();
env.server.restore(state).unwrap();
env.c().restore_channel(offer_id3, false).unwrap();
env.c().post_restore().unwrap();
assert_eq!(env.notifier.monitor_page, Some(expected_monitor));
assert_eq!(env.notifier.target_message_vp, Some(0));
}
/// Saves while INITIATE_CONTACT is still pending. After restore, the server
/// re-issues the connection change as a forced `ModifyConnectionRequest`,
/// which can then be completed normally.
#[test]
fn test_save_restore_connecting() {
    // Monitor page GPAs that `TestEnv::start_connect` sends.
    let monitor_gpas = MonitorPageGpas {
        child_to_parent: 0x123f000,
        parent_to_child: 0x321f000,
    };
    let mut env = TestEnv::new();
    let mnf_offer = env.offer_with_mnf(1);
    env.offer(2);
    env.start_connect(Version::Win10, FeatureFlags::new(), false);
    assert_eq!(env.notifier.monitor_page, Some(monitor_gpas));

    let saved = env.server.save();
    env.c().reset();
    env.complete_connect();
    env.notifier.check_reset();

    env.server.restore(saved).unwrap();
    env.c().restore_channel(mnf_offer, false).unwrap();
    env.c().post_restore().unwrap();
    assert_eq!(env.notifier.monitor_page, Some(monitor_gpas));

    let replayed = env.next_action();
    assert_eq!(
        replayed,
        ModifyConnectionRequest {
            version: Some(Version::Win10 as u32),
            monitor_page: Update::Set(monitor_gpas),
            interrupt_page: Update::Reset,
            target_message_vp: Some(0),
            force: true,
            ..Default::default()
        }
    );
    assert_eq!(Some(0), env.notifier.target_message_vp);
    env.complete_connect();
}
/// Saves while a guest MODIFY_CONNECTION is still in flight. After restore,
/// the server re-issues it as a forced `ModifyConnectionRequest`; completing
/// it sends a successful `ModifyConnectionResponse`.
#[test]
fn test_save_restore_modifying() {
    let mut env = TestEnv::new();
    env.connect(
        Version::Copper,
        FeatureFlags::new().with_modify_connection(true),
    );
    let requested = MonitorPageGpas {
        parent_to_child: 0x123f000,
        child_to_parent: 0x321f000,
    };
    env.send_message(in_msg(
        protocol::MessageType::MODIFY_CONNECTION,
        protocol::ModifyConnection {
            parent_to_child_monitor_page_gpa: requested.parent_to_child,
            child_to_parent_monitor_page_gpa: requested.child_to_parent,
        },
    ));
    // Drain the request but leave it uncompleted, so the save captures it.
    let _ = env.next_action();
    assert_eq!(env.notifier.monitor_page, Some(requested));

    let saved = env.server.save();
    env.c().reset();
    env.notifier.check_reset();
    env.server.restore(saved).unwrap();
    env.c().post_restore().unwrap();

    let expected_request = ModifyConnectionRequest {
        monitor_page: Update::Set(requested),
        interrupt_page: Update::Reset,
        target_message_vp: Some(0),
        force: true,
        ..Default::default()
    };
    assert_eq!(env.next_action(), expected_request);
    assert_eq!(env.notifier.monitor_page, Some(requested));

    env.c()
        .complete_modify_connection(ModifyConnectionResponse::Supported(
            protocol::ConnectionState::SUCCESSFUL,
            SUPPORTED_FEATURE_FLAGS,
        ));
    let expected_response = OutgoingMessage::new(&protocol::ModifyConnectionResponse {
        connection_state: protocol::ConnectionState::SUCCESSFUL,
    });
    env.notifier.check_message(expected_response);
}
/// Saves a server whose guest has unloaded while a reserved channel (with its
/// GPADL) is still open. That state is restored into a fresh harness, and the
/// GPADL must survive.
#[test]
fn test_save_restore_disconnected_reserved() {
    let mut source = TestEnv::new();
    let source_offer1 = source.offer(1);
    source.offer(2);
    source.offer(3);
    source.connect(Version::Copper, FeatureFlags::new());
    source.c().handle_request_offers().unwrap();
    source.gpadl(1, 1);
    source.c().gpadl_create_complete(source_offer1, GpadlId(1), 0);
    source.open_reserved(1, 0, 3);
    source
        .c()
        .open_complete(source_offer1, protocol::STATUS_SUCCESS);
    source.c().handle_unload();
    let saved = source.server.save();

    let mut target = TestEnv::new();
    let target_offer1 = target.offer(1);
    let target_offer2 = target.offer(2);
    let target_offer3 = target.offer(3);
    target.server.restore(saved).unwrap();
    // Only the reserved channel is restored as open.
    let restored = [
        (target_offer1, true),
        (target_offer2, false),
        (target_offer3, false),
    ];
    for &(offer_id, open) in &restored {
        target.c().restore_channel(offer_id, open).unwrap();
    }
    target.c().post_restore().unwrap();
    assert!(target
        .server
        .gpadls
        .contains_key(&(GpadlId(1), target_offer1)));
}
/// Messages the notifier could not deliver are held by the server as pending.
/// They survive save/restore and are flushed in generation order.
#[test]
fn test_pending_messages() {
let mut env = TestEnv::new();
let offer_id1 = env.offer(1);
let offer_id2 = env.offer(2);
let offer_id3 = env.offer(3);
env.connect(Version::Copper, FeatureFlags::new());
env.c().handle_request_offers().unwrap();
env.notifier.messages.clear();
// NOTE(review): `pend_messages` presumably makes the test notifier refuse
// deliveries so the server must queue them — confirm in `TestNotifier`.
env.notifier.pend_messages = true;
// A reserved channel's OpenResult does not end up in the server's pending
// queue.
env.open_reserved(2, 4, SINT.into());
env.c().open_complete(offer_id2, protocol::STATUS_SUCCESS);
assert!(env.notifier.messages.is_empty());
assert!(!env.server.has_pending_messages());
env.gpadl(1, 10);
env.c()
.gpadl_create_complete(offer_id1, GpadlId(10), protocol::STATUS_SUCCESS);
// Re-arm pending mode before generating the next message.
env.notifier.pend_messages = true;
env.open(3);
env.c().open_complete(offer_id3, protocol::STATUS_SUCCESS);
assert!(env.notifier.messages.is_empty());
assert!(env.server.has_pending_messages());
env.notifier.pend_messages = false;
// Save, then restore into a fresh harness (the first `env` is shadowed).
let state = env.server.save();
let mut env = TestEnv::new();
let offer_id1 = env.offer(1);
let offer_id2 = env.offer(2);
let offer_id3 = env.offer(3);
env.server.restore(state).unwrap();
env.c().restore_channel(offer_id1, false).unwrap();
env.c().restore_channel(offer_id2, true).unwrap();
env.c().restore_channel(offer_id3, true).unwrap();
env.c().post_restore().unwrap();
assert!(env.server.has_pending_messages());
// Flush with a sink that accepts everything, capturing each message.
let mut pending_messages = Vec::new();
let r = env.server.poll_flush_pending_messages(|msg| {
pending_messages.push(msg.clone());
Poll::Ready(())
});
assert!(r.is_ready());
// Exactly the two held messages, in the order they were generated.
assert_eq!(pending_messages.len(), 2);
assert_eq!(
protocol::MessageHeader::read_from_prefix(pending_messages[0].data())
.unwrap()
.0
.message_type(),
protocol::MessageType::GPADL_CREATED
);
assert_eq!(
protocol::MessageHeader::read_from_prefix(pending_messages[1].data())
.unwrap()
.0
.message_type(),
protocol::MessageType::OPEN_CHANNEL_RESULT
);
assert!(!env.server.has_pending_messages());
}
/// A guest MODIFY_CONNECTION updates the notifier's monitor page and yields
/// a `ModifyConnectionRequest`. A failed completion is reported back to the
/// guest in `ModifyConnectionResponse`.
#[test]
fn test_modify_connection() {
    let mut env = TestEnv::new();
    env.connect(
        Version::Copper,
        FeatureFlags::new().with_modify_connection(true),
    );
    let new_pages = MonitorPageGpas {
        parent_to_child: 5,
        child_to_parent: 6,
    };
    env.send_message(in_msg(
        protocol::MessageType::MODIFY_CONNECTION,
        protocol::ModifyConnection {
            parent_to_child_monitor_page_gpa: new_pages.parent_to_child,
            child_to_parent_monitor_page_gpa: new_pages.child_to_parent,
        },
    ));
    assert_eq!(env.notifier.monitor_page, Some(new_pages));

    let expected_request = ModifyConnectionRequest {
        monitor_page: Update::Set(new_pages),
        ..Default::default()
    };
    assert_eq!(env.next_action(), expected_request);

    env.c()
        .complete_modify_connection(ModifyConnectionResponse::Supported(
            protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
            SUPPORTED_FEATURE_FLAGS,
        ));
    let expected_response = OutgoingMessage::new(&protocol::ModifyConnectionResponse {
        connection_state: protocol::ConnectionState::FAILED_UNKNOWN_FAILURE,
    });
    env.notifier.check_message(expected_response);
}
/// Without the `modify_connection` feature negotiated, a MODIFY_CONNECTION
/// message is rejected as an invalid message type.
#[test]
fn test_modify_connection_unsupported() {
    let mut env = TestEnv::new();
    env.connect(Version::Copper, FeatureFlags::new());
    let modify = in_msg(
        protocol::MessageType::MODIFY_CONNECTION,
        protocol::ModifyConnection {
            parent_to_child_monitor_page_gpa: 5,
            child_to_parent_monitor_page_gpa: 6,
        },
    );
    let outcome = env.try_send_message(modify);
    let err = outcome.unwrap_err();
    let is_invalid_type = matches!(
        err,
        ChannelError::ParseError(protocol::ParseError::InvalidMessageType(
            protocol::MessageType::MODIFY_CONNECTION
        ))
    );
    assert!(is_invalid_type);
}
/// Reserved channels: responses go to the VP/SINT named in the request, a
/// regular CloseChannel is refused, and the channels persist across unload
/// and reconnect until closed or reset.
#[test]
fn test_reserved_channels() {
let mut env = TestEnv::new();
let offer_id1 = env.offer(1);
let offer_id2 = env.offer(2);
let offer_id3 = env.offer(3);
env.connect(Version::Win10, FeatureFlags::new());
env.c().handle_request_offers().unwrap();
env.gpadl(1, 10);
env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
env.notifier.messages.clear();
// The OpenResult for reserved channel 1 is addressed to VP 1.
env.open_reserved(1, 1, SINT.into());
env.c().open_complete(offer_id1, 0);
env.notifier.check_message_with_target(
OutgoingMessage::new(&protocol::OpenResult {
channel_id: ChannelId(1),
..FromZeros::new_zeroed()
}),
MessageTarget::ReservedChannel(offer_id1, ConnectionTarget { vp: 1, sint: SINT }),
);
env.open_reserved(2, 2, SINT.into());
env.c().open_complete(offer_id2, 0);
env.open_reserved(3, 3, SINT.into());
env.c().open_complete(offer_id3, 0);
// Reserved channels cannot be closed with a regular CloseChannel.
assert!(matches!(env.close(2), Err(ChannelError::ChannelReserved)));
// After unload, reserved channels can still be closed individually.
env.c().handle_unload();
env.close_reserved(2, 2, SINT.into());
env.c().close_complete(offer_id2);
env.notifier.messages.clear();
// Reconnect at a newer version and reserve channel 2 again. GPADL id 10 is
// reused here, now on channel 2.
env.connect(Version::Copper, FeatureFlags::new());
env.c().handle_request_offers().unwrap();
env.gpadl(2, 10);
env.c().gpadl_create_complete(offer_id2, GpadlId(10), 0);
env.open_reserved(2, 3, SINT.into());
env.c().open_complete(offer_id2, 0);
env.notifier.messages.clear();
// Channel 1, reserved before the unload, is closed with a different target
// VP (4). The response goes to VP 4, not the VP used at open.
env.close_reserved(1, 4, SINT.into());
env.c().close_complete(offer_id1);
env.notifier.check_message_with_target(
OutgoingMessage::new(&protocol::CloseReservedChannelResponse {
channel_id: ChannelId(1),
}),
MessageTarget::ReservedChannel(offer_id1, ConnectionTarget { vp: 4, sint: SINT }),
);
env.teardown_gpadl(1, 10);
env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));
// Reset closes the remaining reserved channels (2 and 3) and channel 2's
// GPADL.
env.c().reset();
env.c().close_complete(offer_id2);
env.c().gpadl_teardown_complete(offer_id2, GpadlId(10));
env.c().close_complete(offer_id3);
env.complete_reset();
assert!(env.notifier.is_reset());
}
/// Reset while the guest has unloaded but a reserved channel is still open,
/// followed by a second cycle in which the reserved channel is closed before
/// the reset.
#[test]
fn test_disconnected_reset() {
let mut env = TestEnv::new();
let offer_id1 = env.offer(1);
env.connect(Version::Win10, FeatureFlags::new());
env.c().handle_request_offers().unwrap();
env.gpadl(1, 10);
env.c().gpadl_create_complete(offer_id1, GpadlId(10), 0);
env.open_reserved(1, 1, SINT.into());
env.c().open_complete(offer_id1, 0);
env.c().handle_unload();
// The reset must close the surviving reserved channel and tear down its
// GPADL before the reset can finish.
env.c().reset();
env.c().close_complete(offer_id1);
env.c().gpadl_teardown_complete(offer_id1, GpadlId(10));
env.complete_reset();
assert!(env.notifier.is_reset());
// Second cycle on the same server with a new offer.
let offer_id2 = env.offer(2);
env.notifier.messages.clear();
env.connect(Version::Win10, FeatureFlags::new());
env.c().handle_request_offers().unwrap();
env.gpadl(2, 20);
env.c().gpadl_create_complete(offer_id2, GpadlId(20), 0);
env.open_reserved(2, 2, SINT.into());
env.c().open_complete(offer_id2, 0);
env.c().handle_unload();
// Close the reserved channel while disconnected. No guest GpadlTeardown is
// sent, yet a teardown of GPADL 20 is completed: the server appears to
// start it itself.
env.close_reserved(2, 2, SINT.into());
env.c().close_complete(offer_id2);
env.c().gpadl_teardown_complete(offer_id2, GpadlId(20));
// Nothing is outstanding, so the reset is observed without `complete_reset`.
env.c().reset();
assert!(env.notifier.is_reset());
}
/// Offer messages reflect the monitor ID for each kind of channel:
/// - no MNF: `monitor_id` 0xff, not allocated;
/// - server-assigned MNF: first free monitor ID (0), tracked in
///   `assigned_monitors`;
/// - preset monitor ID: passed through as-is (5), not tracked.
#[test]
fn test_mnf_channel() {
    let mut env = TestEnv::new();
    env.offer(1);
    env.offer_with_mnf(2);
    env.offer_with_preset_mnf(3, 5);
    env.connect(Version::Copper, FeatureFlags::new());
    env.c().handle_request_offers().unwrap();
    assert_eq!(env.server.assigned_monitors.bitmap(), 1);

    // Builds the OfferChannel the server should emit for channel `id`
    // (interface and instance GUIDs both use `data1 == id`).
    let expected_offer = |id: u32, connection_id, monitor_id, monitor_allocated| {
        OutgoingMessage::new(&protocol::OfferChannel {
            interface_id: Guid {
                data1: id,
                ..Guid::ZERO
            },
            instance_id: Guid {
                data1: id,
                ..Guid::ZERO
            },
            channel_id: ChannelId(id),
            connection_id,
            is_dedicated: 1,
            monitor_id,
            monitor_allocated,
            ..protocol::OfferChannel::new_zeroed()
        })
    };
    env.notifier.check_messages(&[
        expected_offer(1, 0x2001, 0xff, 0),
        expected_offer(2, 0x2002, 0, 1),
        expected_offer(3, 0x2003, 5, 1),
        OutgoingMessage::new(&protocol::AllOffersDelivered {}),
    ])
}
/// Channel IDs are handed out in sorted offer order, checked against the
/// order the offers are sent in.
///
/// Offers are sorted by interface GUID. Within one interface, offers with an
/// explicit `offer_order` come first (ascending), followed by unordered
/// offers.
#[test]
fn test_channel_id_order() {
    let mut env = TestEnv::new();
    env.offer(3);
    env.offer(10);
    env.offer(5);
    env.offer(17);
    env.offer_with_order(5, 6, Some(2));
    env.offer_with_order(5, 8, Some(1));
    env.offer_with_order(5, 1, None);
    env.connect(Version::Win10, FeatureFlags::new());
    env.c().handle_request_offers().unwrap();

    // Builds the expected OfferChannel for a non-MNF channel.
    let expected_offer = |interface: u32, instance: u32, channel: u32, connection_id| {
        OutgoingMessage::new(&protocol::OfferChannel {
            interface_id: Guid {
                data1: interface,
                ..Guid::ZERO
            },
            instance_id: Guid {
                data1: instance,
                ..Guid::ZERO
            },
            channel_id: ChannelId(channel),
            connection_id,
            is_dedicated: 1,
            monitor_id: 0xff,
            ..protocol::OfferChannel::new_zeroed()
        })
    };
    env.notifier.check_messages(&[
        expected_offer(3, 3, 1, 0x2001),
        expected_offer(5, 8, 2, 0x2002),
        expected_offer(5, 6, 3, 0x2003),
        expected_offer(5, 1, 4, 0x2004),
        expected_offer(5, 5, 5, 0x2005),
        expected_offer(10, 10, 6, 0x2006),
        expected_offer(17, 17, 7, 0x2007),
        OutgoingMessage::new(&protocol::AllOffersDelivered {}),
    ])
}
/// A trusted connection that requests confidential channels keeps the
/// feature, and offers keep their confidential flags. An untrusted
/// REQUEST_OFFERS is rejected with `UntrustedMessage` and produces no
/// messages; a trusted one is then served in full.
#[test]
fn test_confidential_connection() {
    let mut env = TestEnv::new();
    env.connect_trusted(
        Version::Copper,
        FeatureFlags::new().with_confidential_channels(true),
    );
    assert_eq!(
        env.version.unwrap(),
        VersionInfo {
            version: Version::Copper,
            feature_flags: FeatureFlags::new().with_confidential_channels(true)
        }
    );
    // One statement per line (previously two were packed onto one line,
    // which rustfmt rejects).
    env.offer(1);
    env.offer_with_flags(2, OfferFlags::new().with_confidential_ring_buffer(true));
    env.offer_with_flags(
        3,
        OfferFlags::new()
            .with_confidential_ring_buffer(true)
            .with_confidential_external_memory(true),
    );
    // `in_msg` builds an untrusted message; on this connection it must be
    // refused without side effects.
    let error = env
        .try_send_message(in_msg(
            protocol::MessageType::REQUEST_OFFERS,
            protocol::RequestOffers {},
        ))
        .unwrap_err();
    assert!(matches!(error, ChannelError::UntrustedMessage));
    assert!(env.notifier.messages.is_empty());
    // The same request marked trusted (last argument) is accepted.
    env.send_message(in_msg_ex(
        protocol::MessageType::REQUEST_OFFERS,
        protocol::RequestOffers {},
        false,
        true,
    ));
    let offer = env.notifier.get_message::<protocol::OfferChannel>();
    assert_eq!(offer.channel_id, ChannelId(1));
    assert_eq!(offer.flags, OfferFlags::new());
    let offer = env.notifier.get_message::<protocol::OfferChannel>();
    assert_eq!(offer.channel_id, ChannelId(2));
    assert_eq!(
        offer.flags,
        OfferFlags::new().with_confidential_ring_buffer(true)
    );
    let offer = env.notifier.get_message::<protocol::OfferChannel>();
    assert_eq!(offer.channel_id, ChannelId(3));
    assert_eq!(
        offer.flags,
        OfferFlags::new()
            .with_confidential_ring_buffer(true)
            .with_confidential_external_memory(true)
    );
    env.notifier
        .check_message(OutgoingMessage::new(&protocol::AllOffersDelivered {}));
}
/// On a trusted connection that did not negotiate confidential channels, the
/// confidential offer flags are cleared. Unrelated flags
/// (`enumerate_device_interface`, `named_pipe_mode`) are preserved.
#[test]
fn test_confidential_channels_unsupported() {
    let mut env = TestEnv::new();
    env.connect_trusted(Version::Copper, FeatureFlags::new());
    assert_eq!(
        env.version.unwrap(),
        VersionInfo {
            version: Version::Copper,
            feature_flags: FeatureFlags::new()
        }
    );
    // One statement per line (previously two were packed onto one line,
    // which rustfmt rejects).
    env.offer_with_flags(1, OfferFlags::new().with_enumerate_device_interface(true));
    env.offer_with_flags(
        2,
        OfferFlags::new()
            .with_named_pipe_mode(true)
            .with_confidential_ring_buffer(true)
            .with_confidential_external_memory(true),
    );
    env.send_message(in_msg_ex(
        protocol::MessageType::REQUEST_OFFERS,
        protocol::RequestOffers {},
        false,
        true,
    ));
    let offer = env.notifier.get_message::<protocol::OfferChannel>();
    assert_eq!(offer.channel_id, ChannelId(1));
    assert_eq!(
        offer.flags,
        OfferFlags::new().with_enumerate_device_interface(true)
    );
    let offer = env.notifier.get_message::<protocol::OfferChannel>();
    assert_eq!(offer.channel_id, ChannelId(2));
    assert_eq!(offer.flags, OfferFlags::new().with_named_pipe_mode(true));
    env.notifier
        .check_message(OutgoingMessage::new(&protocol::AllOffersDelivered {}));
}
/// Confidential channels requested over an untrusted connection
/// (`TestEnv::connect`) are not granted: the negotiated feature flags come
/// back empty.
#[test]
fn test_confidential_channels_untrusted() {
    let requested = FeatureFlags::new().with_confidential_channels(true);
    let mut env = TestEnv::new();
    env.connect(Version::Copper, requested);
    let negotiated = env.version.unwrap();
    let expected = VersionInfo {
        version: Version::Copper,
        feature_flags: FeatureFlags::new(),
    };
    assert_eq!(negotiated, expected);
}
}