use self::msg::Msg;
use crate::api::GuestSaveRequest;
use crate::client::ModifyVtl2SettingsRequest;
use crate::error::IgvmAttestError;
use crate::error::TryIntoProtocolBool;
use chipset_resources::battery::HostBatteryUpdate;
use futures::FutureExt;
use futures::TryFutureExt;
use futures_concurrency::future::Race;
use get_protocol::HostRequests;
use get_protocol::MAX_PAYLOAD_SIZE;
use guid::Guid;
use inspect::Inspect;
use inspect::InspectMut;
use inspect_counters::Counter;
use mesh::RecvError;
use mesh::rpc::Rpc;
use mesh::rpc::RpcError;
use mesh::rpc::TryRpcSend;
use parking_lot::Mutex;
use std::cmp::min;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::future::Future;
use std::future::pending;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
use underhill_config::Vtl2SettingsErrorInfo;
use underhill_config::Vtl2SettingsErrorInfoVec;
use unicycle::FuturesUnordered;
use user_driver::DmaClient;
use vmbus_async::async_dgram::AsyncRecvExt;
use vmbus_async::async_dgram::AsyncSendExt;
use vmbus_async::pipe::MessagePipe;
use vmbus_ring::RingMem;
use vpci::bus_control::VpciBusEvent;
use zerocopy::FromBytes;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
/// Errors that terminate the GET process loop.
///
/// Every variant is fatal: `ProcessLoop::run` returns as soon as one is
/// produced.
#[derive(Debug, Error)]
pub(crate) enum FatalError {
    /// Opening the GET vmbus pipe failed.
    #[error("open get pipe error")]
    OpenPipe(#[source] vmbus_user_channel::Error),
    /// Reading from or writing to the pipe failed.
    #[error("get fd io error")]
    FdIo(#[source] std::io::Error),
    /// A received message was shorter than the common message header.
    #[error("message size of {0} too small to read header")]
    MessageSizeHeader(usize),
    /// A host response did not have the size expected for its message type.
    #[error("message size of {len} was not correct to read host response {response:?}")]
    MessageSizeHostResponse { len: usize, response: HostRequests },
    /// A device platform settings v2 payload had an unexpected size.
    #[error("message size of {len} did not match dpsv2 size {expected}")]
    DevicePlatformSettingsV2Payload { expected: usize, len: usize },
    /// A modify-VTL2-settings notification payload length did not match the
    /// size advertised in its header.
    #[error("message size of {len} did not match vtl2 setting size {expected}")]
    ModifyVtl2SettingsNotification { expected: usize, len: usize },
    /// A guest notification did not have the size expected for its kind.
    #[error("message size of {len} was not correct to read guest notification {notification:?}")]
    MessageSizeGuestNotification {
        len: usize,
        notification: get_protocol::GuestNotifications,
    },
    /// A host response used an unsupported header version.
    #[error("response message version ({0:?}) not supported")]
    InvalidResponseVersion(get_protocol::MessageVersions),
    /// A guest notification used an unsupported header version.
    #[error("notification message version ({0:?}) not supported")]
    InvalidGuestNotificationVersion(get_protocol::MessageVersions),
    /// A message expected to be a host response had a different message type.
    #[error("response message type ({0:?}) is not HOST_RESPONSE")]
    InvalidResponseType(get_protocol::MessageTypes),
    /// The response message ID (first field) did not match the request
    /// message ID (second field).
    #[error("response header message ID {0:?} doesn't match request header message ID {1:?}")]
    ResponseHeaderMismatchId(HostRequests, HostRequests),
    /// A response contained an invalid field value.
    #[error("invalid response")]
    InvalidResponse,
    /// The host accepted none of the offered protocol versions.
    #[error("version negotiation failed")]
    VersionNegotiationFailed,
    /// Receiving on the control channel during version negotiation failed.
    #[error("control receive failed")]
    VersionNegotiationTryRecvFailed(#[source] RecvError),
    /// A host response arrived while no matching request was outstanding.
    #[error("received response with no pending request")]
    NoPendingRequest,
    /// Serializing VTL2 settings error details to JSON failed.
    #[error("failed to serialize VTL2 settings error info")]
    Vtl2SettingsErrorInfoJson(#[source] serde_json::error::Error),
    /// The bounded pre-init buffer for a guest notification kind overflowed
    /// before a downstream receiver was taken.
    #[error("received too many guest notifications of kind {0:?} prior to downstream worker init")]
    TooManyGuestNotifications(get_protocol::GuestNotifications),
    /// An IGVM attest request needed a GPA allocator, but none was set.
    #[error("failed to create IgvmAttest request because the gpa allocator is unavailable")]
    GpaAllocatorUnavailable,
    /// Allocating memory for an attestation request failed.
    #[error("failed to allocate memory for attestation request")]
    GpaMemoryAllocationError(#[source] anyhow::Error),
    /// An asynchronous `IGVM_ATTEST` response could not be deserialized.
    #[error("failed to deserialize the asynchronous `IGVM_ATTEST` response")]
    DeserializeIgvmAttestResponse,
    /// The size reported in an `IGVM_ATTEST` response exceeded the maximum.
    #[error(
        "malformed `IGVM_ATTEST` response - reported size {response_size} was larger than maximum size {maximum_size}"
    )]
    InvalidIgvmAttestResponseSize {
        response_size: usize,
        maximum_size: usize,
    },
    /// An `IGVM_ATTEST` response arrived while no IGVM attest request was queued.
    #[error("received an `IGVM_ATTEST` response with no pending `IGVM_ATTEST` request")]
    NoPendingIgvmAttestRequest,
}
/// Checks that a host response header has the supported header version
/// (`HEADER_VERSION_1`) and the `HOST_RESPONSE` message type.
///
/// The version is checked first, so a header that is wrong in both respects
/// reports `InvalidResponseVersion`.
fn validate_response(header: get_protocol::HeaderHostResponse) -> Result<(), FatalError> {
    let version = header.message_version;
    let kind = header.message_type;
    if version != get_protocol::MessageVersions::HEADER_VERSION_1 {
        Err(FatalError::InvalidResponseVersion(version))
    } else if kind != get_protocol::MessageTypes::HOST_RESPONSE {
        Err(FatalError::InvalidResponseType(kind))
    } else {
        Ok(())
    }
}
/// Parses a complete host response of type `T` from `buf`.
///
/// `buf` must be exactly `size_of::<T>()` bytes long.
///
/// # Errors
/// - `MessageSizeHostResponse` when the size is wrong but a response header
///   can still be read from the start of `buf`. The header's message ID is
///   reported for diagnostics.
/// - `MessageSizeHeader` when `buf` is too short even for the header.
fn read_host_response_validated<T: FromBytes + Immutable + KnownLayout>(
    buf: &[u8],
) -> Result<T, FatalError> {
    T::read_from_bytes(buf).map_err(|_| {
        // Use a prefix read: `read_from_bytes` would demand that `buf` be
        // exactly header-sized, which is never true for a full response.
        match get_protocol::HeaderHostResponse::read_from_prefix(buf) {
            Ok((header, _)) => FatalError::MessageSizeHostResponse {
                len: buf.len(),
                response: header.message_id,
            },
            Err(_) => FatalError::MessageSizeHeader(buf.len()),
        }
    })
}
/// Parses a guest notification of type `T` from `buf`.
///
/// `buf` must be exactly `size_of::<T>()` bytes long. `notification` is only
/// used to label the `MessageSizeGuestNotification` error when the size is
/// wrong.
fn read_guest_notification<T: FromBytes + Immutable + KnownLayout>(
    notification: get_protocol::GuestNotifications,
    buf: &[u8],
) -> Result<T, FatalError> {
    match T::read_from_bytes(buf) {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(FatalError::MessageSizeGuestNotification {
            len: buf.len(),
            notification,
        }),
    }
}
/// Requests delivered to the process loop over the `mesh` channel that
/// `ProcessLoop::run` consumes.
pub(crate) mod msg {
    use crate::api::GuestSaveRequest;
    use crate::client::ModifyVtl2SettingsRequest;
    use chipset_resources::battery::HostBatteryUpdate;
    use guid::Guid;
    use mesh::rpc::Rpc;
    use std::sync::Arc;
    use user_driver::DmaClient;
    use vpci::bus_control::VpciBusEvent;

    /// Input for [`Msg::VpciListenerRegistration`]: VPCI bus events for
    /// `bus_instance_id` are forwarded to `sender`.
    #[derive(Debug)]
    pub struct VpciListenerRegistrationInput {
        pub bus_instance_id: Guid,
        pub sender: mesh::Sender<VpciBusEvent>,
    }

    /// Payload of [`Msg::IgvmAttest`].
    ///
    /// NOTE(review): field semantics are defined by `request_igvm_attest`,
    /// which is outside this view. Presumably `response_buffer_len` bounds
    /// the response size; confirm there.
    #[derive(Debug)]
    pub(crate) struct IgvmAttestRequestData {
        pub(crate) agent_data: Vec<u8>,
        pub(crate) report: Vec<u8>,
        pub(crate) response_buffer_len: usize,
    }

    /// A request to the process loop.
    ///
    /// Variants carrying an `Rpc` are answered through it. The rest are
    /// fire-and-forget.
    pub(crate) enum Msg {
        /// Completes once all previously queued outgoing messages are written.
        FlushWrites(Rpc<(), ()>),
        /// Inspects the process loop state.
        Inspect(inspect::Deferred),
        /// Sets the allocator used by IGVM attest requests.
        SetGpaAllocator(Arc<dyn DmaClient>),
        // The `Take*Receiver` requests hand out the downstream receiver for a
        // guest notification kind. They return `None` if it was already taken.
        TakeGenIdReceiver(Rpc<(), Option<mesh::Receiver<[u8; 16]>>>),
        TakeSaveRequestReceiver(Rpc<(), Option<mesh::Receiver<GuestSaveRequest>>>),
        TakeVtl2SettingsReceiver(Rpc<(), Option<mesh::Receiver<ModifyVtl2SettingsRequest>>>),
        TakeBatteryStatusReceiver(Rpc<(), Option<mesh::Receiver<HostBatteryUpdate>>>),
        /// Registers a VPCI bus event listener.
        VpciListenerRegistration(Rpc<VpciListenerRegistrationInput, ()>),
        /// Removes the VPCI bus event listener for the given bus instance.
        VpciListenerDeregistration(Guid),
        CompleteStartVtl0(Rpc<Option<String>, ()>),
        // Request/response exchanges with the host.
        CreateRamGpaRange(Rpc<CreateRamGpaRangeInput, get_protocol::CreateRamGpaRangeResponse>),
        DevicePlatformSettingsV2(Rpc<(), Vec<u8>>),
        GetVtl2SavedStateFromHost(Rpc<(), Result<Vec<u8>, ()>>),
        GuestStateProtection(
            Rpc<
                Box<get_protocol::GuestStateProtectionRequest>,
                get_protocol::GuestStateProtectionResponse,
            >,
        ),
        GuestStateProtectionById(Rpc<(), get_protocol::GuestStateProtectionByIdResponse>),
        HostTime(Rpc<(), get_protocol::TimeResponse>),
        IgvmAttest(Rpc<Box<IgvmAttestRequestData>, Result<Vec<u8>, crate::error::IgvmAttestError>>),
        MapFramebuffer(Rpc<u64, get_protocol::MapFramebufferResponse>),
        ResetRamGpaRange(Rpc<u32, get_protocol::ResetRamGpaRangeResponse>),
        SendServicingState(Rpc<Result<Vec<u8>, String>, Result<(), ()>>),
        UnmapFramebuffer(Rpc<(), get_protocol::UnmapFramebufferResponse>),
        VgaProxyPciRead(Rpc<u16, get_protocol::VgaProxyPciReadResponse>),
        VgaProxyPciWrite(Rpc<VgaProxyPciWriteInput, get_protocol::VgaProxyPciWriteResponse>),
        VmgsFlush(Rpc<(), get_protocol::VmgsFlushResponse>),
        VmgsGetDeviceInfo(Rpc<(), get_protocol::VmgsGetDeviceInfoResponse>),
        VmgsRead(Rpc<VmgsReadInput, Result<Vec<u8>, get_protocol::VmgsReadResponse>>),
        VmgsWrite(Rpc<VmgsWriteInput, Result<(), get_protocol::VmgsWriteResponse>>),
        VpciDeviceBindingChange(
            Rpc<VpciDeviceBindingChangeInput, get_protocol::VpciDeviceBindingChangeResponse>,
        ),
        VpciDeviceControl(Rpc<VpciDeviceControlInput, get_protocol::VpciDeviceControlResponse>),
        // Notifications to the host (no response expected).
        EventLog(get_protocol::EventLogId),
        PowerState(PowerState),
        ReportRestoreResultToHost(bool),
        TripleFaultNotification(Vec<u8>),
        VtlCrashNotification(get_protocol::VtlCrashNotification),
    }

    /// Power transition to report to the host.
    ///
    /// `PowerOff` and `Hibernate` are both sent as a `PowerOffNotification`
    /// (hibernate flag false/true). `Reset` is sent as a `ResetNotification`.
    #[derive(Debug)]
    pub enum PowerState {
        PowerOff,
        Reset,
        Hibernate,
    }

    /// Input for [`Msg::VmgsRead`] (consumed by `request_vmgs_read`, outside this view).
    #[derive(Debug)]
    pub struct VmgsReadInput {
        pub sector_offset: u64,
        pub sector_count: u32,
        pub sector_size: u32,
    }

    /// Input for [`Msg::VmgsWrite`] (consumed by `request_vmgs_write`, outside this view).
    #[derive(Debug)]
    pub struct VmgsWriteInput {
        pub sector_offset: u64,
        pub buf: Vec<u8>,
        pub sector_size: u32,
    }

    /// Input for [`Msg::VpciDeviceControl`], forwarded to `VpciDeviceControlRequest::new`.
    #[derive(Debug)]
    pub struct VpciDeviceControlInput {
        pub code: get_protocol::VpciDeviceControlCode,
        pub bus_instance_id: Guid,
    }

    /// Input for [`Msg::VpciDeviceBindingChange`], forwarded to
    /// `VpciDeviceBindingChangeRequest::new`.
    #[derive(Debug)]
    pub struct VpciDeviceBindingChangeInput {
        pub bus_instance_id: Guid,
        pub binding_state: bool,
    }

    /// Input for [`Msg::VgaProxyPciWrite`], forwarded to `VgaProxyPciWriteRequest::new`.
    #[derive(Debug)]
    pub struct VgaProxyPciWriteInput {
        pub offset: u16,
        pub value: u32,
    }

    /// Input for [`Msg::CreateRamGpaRange`], forwarded field-by-field to
    /// `CreateRamGpaRangeRequest::new`.
    #[derive(Debug)]
    pub struct CreateRamGpaRangeInput {
        pub slot: u32,
        pub gpa_start: u64,
        pub gpa_count: u64,
        pub gpa_offset: u64,
        pub flags: crate::api::CreateRamGpaRangeFlags,
    }
}
/// A sender that buffers up to `MAX_SIZE` messages until a receiver is
/// created with `init_receiver`, then forwards directly.
#[derive(Inspect)]
#[inspect(external_tag)]
enum BufferedSender<const MAX_SIZE: usize, T> {
    /// No receiver yet. Messages are held here (inspected as their count).
    Buffered(#[inspect(rename = "len", with = "Vec::len")] Vec<T>),
    /// The receiver exists. Messages go straight to the channel.
    Ready(#[inspect(skip)] mesh::Sender<T>),
}
/// Error returned when a `BufferedSender` already holds `MAX_SIZE` messages.
struct BufferedSenderFull;
impl<const MAX_SIZE: usize, T: Send + 'static> BufferedSender<MAX_SIZE, T> {
    /// Creates a sender in the buffering state with no messages queued.
    fn new() -> Self {
        Self::Buffered(Vec::new())
    }

    /// Switches to direct delivery and returns the new receiver.
    ///
    /// Any buffered messages are replayed into the receiver first. Also
    /// returns the number of replayed messages. Returns `None` if a receiver
    /// was already created.
    fn init_receiver(&mut self) -> Option<(mesh::Receiver<T>, usize)> {
        let Self::Buffered(pending) = self else {
            return None;
        };
        let pending = std::mem::take(pending);
        let (sender, receiver) = mesh::channel();
        let replayed = pending.len();
        pending.into_iter().for_each(|msg| sender.send(msg));
        *self = Self::Ready(sender);
        Some((receiver, replayed))
    }

    /// Delivers `msg`, or buffers it if no receiver exists yet.
    ///
    /// Fails with `BufferedSenderFull` when the buffer already holds
    /// `MAX_SIZE` messages.
    fn send(&mut self, msg: T) -> Result<(), BufferedSenderFull> {
        match self {
            Self::Ready(sender) => sender.send(msg),
            Self::Buffered(pending) if pending.len() == MAX_SIZE => return Err(BufferedSenderFull),
            Self::Buffered(pending) => pending.push(msg),
        }
        Ok(())
    }
}
impl<const MAX_SIZE: usize, T: Send + 'static> TryRpcSend for &mut BufferedSender<MAX_SIZE, T> {
type Message = T;
type Error = BufferedSenderFull;
fn try_send_rpc(self, message: Self::Message) -> Result<(), Self::Error> {
self.send(message)
}
}
/// Per-kind guest notification sender. Up to 16 notifications are buffered
/// before a downstream receiver is taken.
type GuestNotificationSender<T> = BufferedSender<16, T>;
/// Builds a mapper for `BufferedSender::init_receiver` results.
///
/// The mapper logs how many buffered `kind` notifications were flushed
/// (only when nonzero) and returns the receiver.
fn log_buffered_guest_notifications<T>(
    kind: get_protocol::GuestNotifications,
) -> impl FnOnce((mesh::Receiver<T>, usize)) -> mesh::Receiver<T> {
    move |(receiver, flushed_count)| {
        if flushed_count != 0 {
            tracing::info!(
                ?kind,
                flushed = flushed_count,
                "flushing buffered guest notifications"
            )
        }
        receiver
    }
}
/// State of the GET message-processing loop that runs over the host pipe.
#[derive(InspectMut)]
pub(crate) struct ProcessLoop<T: RingMem> {
    /// Pipe to the host.
    #[inspect(mut)]
    pipe: MessagePipe<T>,
    /// Accumulates `MODIFY_VTL2_SETTINGS_REV1` payload chunks until the
    /// `END` chunk arrives.
    #[inspect(skip)]
    vtl2_settings_buf: Option<Vec<u8>>,
    /// Queued host-request futures. `run` drives only the front one, so they
    /// complete in FIFO order.
    #[inspect(skip)]
    host_requests: VecDeque<Pin<Box<dyn Future<Output = Result<(), FatalError>> + Send>>>,
    /// Channels shared with request handlers.
    #[inspect(skip)]
    pipe_channels: PipeChannels,
    /// Forwards non-IGVM host responses to the active host request.
    #[inspect(skip)]
    read_send: mesh::Sender<Vec<u8>>,
    /// Outgoing write queue, drained by the writer inside `run`.
    #[inspect(skip)]
    write_recv: mesh::Receiver<WriteRequest>,
    /// Queued IGVM attest futures. They are driven independently of
    /// `host_requests`.
    #[inspect(skip)]
    igvm_attest_requests: VecDeque<Pin<Box<dyn Future<Output = Result<(), FatalError>> + Send>>>,
    /// Forwards `IGVM_ATTEST` responses to the active IGVM attest request.
    #[inspect(skip)]
    igvm_attest_read_send: mesh::Sender<Vec<u8>>,
    /// Allocator set by `Msg::SetGpaAllocator` and passed to IGVM attest requests.
    gpa_allocator: Option<Arc<dyn DmaClient>>,
    stats: Stats,
    guest_notification_listeners: GuestNotificationListeners,
    /// Pending downstream completions for guest notifications that need a
    /// reply to the host.
    #[inspect(skip)]
    guest_notification_responses:
        FuturesUnordered<Pin<Box<dyn Send + Future<Output = GuestNotificationResponse>>>>,
}
/// Downstream destinations for each kind of guest notification.
#[derive(Inspect)]
struct GuestNotificationListeners {
    generation_id: GuestNotificationSender<[u8; 16]>,
    save_request: GuestNotificationSender<GuestSaveRequest>,
    vtl2_settings: GuestNotificationSender<ModifyVtl2SettingsRequest>,
    /// VPCI bus event listeners, keyed by bus instance ID. Notifications for
    /// unregistered buses are dropped.
    #[inspect(skip)]
    vpci: HashMap<Guid, mesh::Sender<VpciBusEvent>>,
    battery_status: GuestNotificationSender<HostBatteryUpdate>,
}
/// Result of a downstream handler for a guest notification that must be
/// reported back to the host.
enum GuestNotificationResponse {
    /// Outcome of applying a VTL2 settings change.
    ModifyVtl2Settings(Result<(), RpcError<Vec<Vtl2SettingsErrorInfo>>>),
}
/// Per-message-ID traffic counters, exposed through inspect.
#[derive(Default, Inspect)]
struct Stats {
    /// Host requests written to the pipe.
    #[inspect(with = "inspect_helpers::iter_by_debug_key")]
    host_requests: HashMap<HostRequests, Counter>,
    /// Host responses read from the pipe.
    #[inspect(with = "inspect_helpers::iter_by_debug_key")]
    host_responses: HashMap<HostRequests, Counter>,
    /// Host notifications written to the pipe.
    #[inspect(with = "inspect_helpers::iter_by_debug_key")]
    host_notifications: HashMap<get_protocol::HostNotifications, Counter>,
    /// Guest notifications read from the pipe.
    #[inspect(with = "inspect_helpers::iter_by_debug_key")]
    guest_notifications: HashMap<get_protocol::GuestNotifications, Counter>,
}
mod inspect_helpers {
    use super::*;

    /// Presents a counter map to inspect, naming each entry by the `Debug`
    /// rendering of its key.
    pub fn iter_by_debug_key<T: core::fmt::Debug>(map: &HashMap<T, Counter>) -> impl Inspect + '_ {
        inspect::iter_by_key(map).map_key(|key| format!("{key:?}"))
    }
}
/// Exclusive access to the pipe for one in-flight request handler.
///
/// Holds the response receiver taken from a shared slot, and puts it back
/// on drop.
struct HostRequestPipeAccess {
    /// Shared slot the receiver is taken from and returned to.
    response_message_recv_mutex: Arc<Mutex<Option<mesh::Receiver<Vec<u8>>>>>,
    /// The taken receiver. It is `Some` until `drop` returns it to the slot.
    response_message_recv: Option<mesh::Receiver<Vec<u8>>>,
    /// Outgoing write queue.
    request_message_send: mesh::Sender<WriteRequest>,
}
impl Drop for HostRequestPipeAccess {
    /// Puts the response receiver back into the shared slot so the next
    /// handler can take it.
    fn drop(&mut self) {
        let receiver = self.response_message_recv.take().unwrap();
        let mut slot = self.response_message_recv_mutex.lock();
        *slot = Some(receiver);
    }
}
/// Channel endpoints handed to request handlers through `HostRequestPipeAccess`.
struct PipeChannels {
    /// Slot holding the receiver for non-IGVM host responses.
    response_message_recv: Arc<Mutex<Option<mesh::Receiver<Vec<u8>>>>>,
    /// Slot holding the receiver for `IGVM_ATTEST` responses.
    igvm_attest_response_message_recv: Arc<Mutex<Option<mesh::Receiver<Vec<u8>>>>>,
    /// Outgoing write queue.
    message_send: mesh::Sender<WriteRequest>,
}
/// An item on the outgoing write queue.
enum WriteRequest {
    /// Raw message bytes to write to the pipe.
    Message(Vec<u8>),
    /// Completed when the writer dequeues it, which is after every earlier
    /// message has been written.
    Flush(Rpc<(), ()>),
}
impl HostRequestPipeAccess {
    /// Takes the response receiver out of its shared slot.
    ///
    /// # Panics
    /// Panics if the slot is empty, i.e. another access object currently
    /// holds the receiver.
    fn new(
        response_message_recv_mutex: Arc<Mutex<Option<mesh::Receiver<Vec<u8>>>>>,
        request_message_send: mesh::Sender<WriteRequest>,
    ) -> Self {
        let receiver = response_message_recv_mutex.lock().take().unwrap();
        Self {
            response_message_recv: Some(receiver),
            response_message_recv_mutex,
            request_message_send,
        }
    }

    /// Queues `message` for writing to the pipe.
    fn send_message(&mut self, message: Vec<u8>) {
        let request = WriteRequest::Message(message);
        self.request_message_send.send(request);
    }

    /// Waits for the next raw response routed to this handler.
    async fn recv_response(&mut self) -> Vec<u8> {
        let receiver = self.response_message_recv.as_mut().unwrap();
        receiver.recv().await.unwrap()
    }

    /// Waits for a response, checks that its message ID is `id`, and parses
    /// it as a fixed-size `T`.
    async fn recv_response_fixed_size<T: IntoBytes + FromBytes + Immutable + KnownLayout>(
        &mut self,
        id: HostRequests,
    ) -> Result<T, FatalError> {
        let response = self.recv_response().await;
        let (header, _) =
            get_protocol::HeaderHostRequest::read_from_prefix(response.as_bytes()).unwrap();
        if header.message_id == id {
            read_host_response_validated(&response)
        } else {
            Err(FatalError::ResponseHeaderMismatchId(header.message_id, id))
        }
    }

    /// Sends `data` as a request and waits for the matching fixed-size
    /// response. The expected response ID is taken from the request header.
    async fn send_request_fixed_size<
        T: IntoBytes + ?Sized + Immutable + KnownLayout,
        U: IntoBytes + FromBytes + Immutable + KnownLayout,
    >(
        &mut self,
        data: &T,
    ) -> Result<U, FatalError> {
        let bytes = data.as_bytes();
        self.send_message(bytes.to_vec());
        let (request_header, _) =
            get_protocol::HeaderHostRequest::read_from_prefix(bytes).unwrap();
        self.recv_response_fixed_size(request_header.message_id).await
    }

    /// Sends `data` without waiting for a response.
    async fn send_failed_save_state<T: IntoBytes + ?Sized + Immutable + KnownLayout>(
        &mut self,
        data: &T,
    ) -> Result<(), FatalError> {
        let payload = data.as_bytes().to_vec();
        self.send_message(payload);
        Ok(())
    }
}
impl<T: RingMem> ProcessLoop<T> {
pub(crate) fn new(pipe: MessagePipe<T>) -> Self {
let (read_send, read_recv) = mesh::channel();
let (igvm_attest_read_send, igvm_attest_read_recv) = mesh::channel();
let (write_send, write_recv) = mesh::channel();
Self {
pipe,
stats: Default::default(),
guest_notification_responses: Default::default(),
vtl2_settings_buf: None,
host_requests: Default::default(),
igvm_attest_requests: Default::default(),
pipe_channels: PipeChannels {
response_message_recv: Arc::new(Mutex::new(Some(read_recv))),
igvm_attest_response_message_recv: Arc::new(Mutex::new(Some(
igvm_attest_read_recv,
))),
message_send: write_send,
},
read_send,
write_recv,
igvm_attest_read_send,
guest_notification_listeners: GuestNotificationListeners {
generation_id: GuestNotificationSender::new(),
vtl2_settings: GuestNotificationSender::new(),
save_request: GuestNotificationSender::new(),
vpci: HashMap::new(),
battery_status: GuestNotificationSender::new(),
},
gpa_allocator: None,
}
}
/// Queues `buf` on the outgoing write queue. The writer inside `run`
/// sends it to the pipe.
fn send_message(&mut self, buf: Vec<u8>) {
    let request = WriteRequest::Message(buf);
    self.pipe_channels.message_send.send(request);
}
/// Reads one message from the pipe into `buf` and returns its length.
/// I/O failures become `FatalError::FdIo`.
async fn read_pipe(&mut self, buf: &mut [u8]) -> Result<usize, FatalError> {
    match self.pipe.recv(buf).await {
        Ok(len) => Ok(len),
        Err(err) => Err(FatalError::FdIo(err)),
    }
}
/// Negotiates the protocol version with the host.
///
/// Sends a `VersionRequest` for each candidate version (currently only
/// `NICKEL_REV2`) and returns the first version the host accepts.
///
/// # Errors
/// Fails on pipe I/O errors and on malformed responses: wrong header, wrong
/// size, mismatched ID, or an invalid accepted flag. Returns
/// `VersionNegotiationFailed` if no candidate is accepted.
pub(crate) async fn negotiate_version(
    &mut self,
) -> Result<get_protocol::ProtocolVersion, FatalError> {
    let candidates = [get_protocol::ProtocolVersion::NICKEL_REV2];
    for protocol in candidates {
        let request = get_protocol::VersionRequest::new(protocol);
        self.pipe
            .send(request.as_bytes())
            .await
            .map_err(FatalError::FdIo)?;

        let mut response = get_protocol::VersionResponse::new_zeroed();
        let received = self.read_pipe(response.as_mut_bytes()).await?;
        validate_response(response.message_header)?;

        let expected = response.as_bytes().len();
        if received != expected {
            return Err(FatalError::MessageSizeHostResponse {
                len: received,
                response: HostRequests::VERSION,
            });
        }

        let sent_id = request.message_header.message_id;
        let received_id = response.message_header.message_id;
        if received_id != sent_id {
            return Err(FatalError::ResponseHeaderMismatchId(received_id, sent_id));
        }

        let accepted: bool = response
            .version_accepted
            .into_bool()
            .map_err(|_| FatalError::InvalidResponse)?;
        if accepted {
            tracing::info!("[GET] version negotiated: {:?}", protocol);
            return Ok(protocol);
        }
    }
    Err(FatalError::VersionNegotiationFailed)
}
pub(crate) async fn run(&mut self, mut recv: mesh::Receiver<Msg>) -> Result<(), FatalError> {
let mut buf: Box<[u8; get_protocol::MAX_MESSAGE_SIZE]> =
Box::new([0; get_protocol::MAX_MESSAGE_SIZE]);
let mut outgoing = Vec::new();
loop {
enum Event {
Msg(Msg),
Done,
Header(Result<usize, FatalError>),
GuestNotificationResponse(GuestNotificationResponse),
Failure(FatalError),
}
let event = {
let (mut read, mut write) = self.pipe.split();
let read_fd = read
.recv(buf.as_mut())
.map_err(FatalError::FdIo)
.map(Event::Header);
let recv_msg = recv.recv().map(|r| r.map_or(Event::Done, Event::Msg));
let send_next = async {
loop {
if !outgoing.is_empty() {
if let Ok((header, _)) =
get_protocol::HeaderRaw::read_from_prefix(outgoing.as_ref())
{
match header.message_type {
get_protocol::MessageTypes::HOST_REQUEST => {
(self.stats.host_requests)
.entry(HostRequests(header.message_id))
.or_default()
.increment();
}
get_protocol::MessageTypes::HOST_NOTIFICATION => {
(self.stats.host_notifications)
.entry(get_protocol::HostNotifications(
header.message_id,
))
.or_default()
.increment();
}
_ => {}
}
}
if let Err(err) = write.send(&outgoing).await {
return FatalError::FdIo(err);
}
outgoing.clear();
}
match self.write_recv.recv().await.unwrap() {
WriteRequest::Message(message) => outgoing = message,
WriteRequest::Flush(send) => send.complete(()),
}
}
}
.map(Event::Failure);
let run_next = async {
while let Some(request) = self.host_requests.front_mut() {
if let Err(e) = request.as_mut().await {
return e;
}
self.host_requests.pop_front();
if self
.pipe_channels
.response_message_recv
.lock()
.as_mut()
.unwrap()
.try_recv()
.is_ok()
{
return FatalError::NoPendingRequest;
}
}
pending().await
}
.map(Event::Failure);
let run_next_igvm_attest = async {
while let Some(request) = self.igvm_attest_requests.front_mut() {
if let Err(e) = request.as_mut().await {
return e;
}
self.igvm_attest_requests.pop_front();
if self
.pipe_channels
.igvm_attest_response_message_recv
.lock()
.as_mut()
.unwrap()
.try_recv()
.is_ok()
{
return FatalError::NoPendingRequest;
}
}
pending().await
}
.map(Event::Failure);
let recv_response = async {
if self.guest_notification_responses.is_empty() {
pending().await
} else {
Event::GuestNotificationResponse(
self.guest_notification_responses.next().await.unwrap(),
)
}
};
(
read_fd,
recv_msg,
send_next,
run_next,
run_next_igvm_attest,
recv_response,
)
.race()
.await
};
match event {
Event::Done => break Ok(()),
Event::Failure(err) => return Err(err),
Event::Msg(message) => {
self.process_host_request(message)?;
}
Event::Header(len) => {
let len = len?;
let buf = &buf[..len];
let header = get_protocol::HeaderRaw::read_from_prefix(buf)
.map_err(|_| FatalError::MessageSizeHeader(len))?
.0; match header.message_type {
get_protocol::MessageTypes::HOST_RESPONSE => {
(self.stats.host_responses)
.entry(HostRequests(header.message_id))
.or_default()
.increment();
self.handle_host_response(
header.try_into().expect("validated message type"),
buf,
)?;
}
get_protocol::MessageTypes::GUEST_NOTIFICATION => {
(self.stats.guest_notifications)
.entry(get_protocol::GuestNotifications(header.message_id))
.or_default()
.increment();
self.handle_guest_notification(
header.try_into().expect("validated message type"),
buf,
)?;
}
_ => panic!("Unexpected header type received: {:?}!", header),
}
}
Event::GuestNotificationResponse(response) => match response {
GuestNotificationResponse::ModifyVtl2Settings(response) => {
self.complete_modify_vtl2_settings(response)?
}
},
}
}
}
fn push_host_request_handler<F, Fut>(&mut self, f: F)
where
F: 'static + Send + FnOnce(HostRequestPipeAccess) -> Fut,
Fut: 'static + Future<Output = Result<(), FatalError>> + Send,
{
let message_recv_mutex = self.pipe_channels.response_message_recv.clone();
let message_send = self.pipe_channels.message_send.clone();
let fut = async { f(HostRequestPipeAccess::new(message_recv_mutex, message_send)).await };
self.host_requests.push_back(Box::pin(fut));
}
fn push_igvm_attest_request_handler<F, Fut>(&mut self, f: F)
where
F: 'static + Send + FnOnce(HostRequestPipeAccess) -> Fut,
Fut: 'static + Future<Output = Result<(), FatalError>> + Send,
{
let message_recv_mutex = self.pipe_channels.igvm_attest_response_message_recv.clone();
let message_send = self.pipe_channels.message_send.clone();
let fut = async { f(HostRequestPipeAccess::new(message_recv_mutex, message_send)).await };
self.igvm_attest_requests.push_back(Box::pin(fut));
}
/// Queues a simple fixed-size request/response exchange for `req`.
///
/// `f` turns the RPC input into the wire request. The fixed-size response
/// is returned to the RPC caller.
fn push_basic_host_request_handler<Req, I, Resp>(
    &mut self,
    req: Rpc<I, Resp>,
    f: impl 'static + Send + FnOnce(I) -> Req,
) where
    Req: IntoBytes + 'static + Send + Sync + Immutable + KnownLayout,
    I: 'static + Send,
    Resp: 'static + IntoBytes + FromBytes + Send + Immutable + KnownLayout,
{
    self.push_host_request_handler(async move |mut pipe| {
        let outcome = req.handle_must_succeed(async |args| {
            let request = f(args);
            pipe.send_request_fixed_size(&request).await
        });
        outcome.await
    });
}
/// Handles a single request received on the `Msg` channel.
///
/// Requests that need a host response are queued as futures and driven by
/// `run`. `IgvmAttest` goes on its own queue; everything else goes on the
/// host-request queue. Listener bookkeeping, inspection, and
/// fire-and-forget notifications are handled immediately.
///
/// # Errors
/// Only `CompleteStartVtl0` can fail here, via `complete_start_vtl0`
/// (defined outside this view).
fn process_host_request(&mut self, message: Msg) -> Result<(), FatalError> {
    match message {
        // The flush marker joins the write queue, so it completes only after
        // every earlier queued message has been written.
        Msg::FlushWrites(rpc) => {
            self.pipe_channels
                .message_send
                .send(WriteRequest::Flush(rpc));
        }
        Msg::Inspect(req) => {
            req.inspect(self);
        }
        Msg::SetGpaAllocator(gpa_allocator) => {
            self.gpa_allocator = Some(gpa_allocator);
        }
        // The `Take*Receiver` arms switch a buffered listener to direct
        // delivery, flushing (and logging) anything buffered so far. They
        // answer `None` if the receiver was already taken.
        Msg::TakeVtl2SettingsReceiver(req) => req.handle_sync(|()| {
            self.guest_notification_listeners
                .vtl2_settings
                .init_receiver()
                .map(log_buffered_guest_notifications(
                    get_protocol::GuestNotifications::MODIFY_VTL2_SETTINGS,
                ))
        }),
        Msg::TakeGenIdReceiver(req) => req.handle_sync(|()| {
            self.guest_notification_listeners
                .generation_id
                .init_receiver()
                .map(log_buffered_guest_notifications(
                    get_protocol::GuestNotifications::UPDATE_GENERATION_ID,
                ))
        }),
        Msg::TakeSaveRequestReceiver(req) => req.handle_sync(|()| {
            self.guest_notification_listeners
                .save_request
                .init_receiver()
                .map(log_buffered_guest_notifications(
                    get_protocol::GuestNotifications::SAVE_GUEST_VTL2_STATE,
                ))
        }),
        Msg::TakeBatteryStatusReceiver(req) => req.handle_sync(|()| {
            self.guest_notification_listeners
                .battery_status
                .init_receiver()
                .map(log_buffered_guest_notifications(
                    get_protocol::GuestNotifications::BATTERY_STATUS,
                ))
        }),
        // Registering an already-registered bus replaces its sender.
        Msg::VpciListenerRegistration(req) => {
            req.handle_sync(|input| {
                self.guest_notification_listeners
                    .vpci
                    .insert(input.bus_instance_id, input.sender);
            });
        }
        Msg::VpciListenerDeregistration(bus_instance_id) => {
            self.guest_notification_listeners
                .vpci
                .remove(&bus_instance_id);
        }
        Msg::DevicePlatformSettingsV2(req) => {
            self.push_host_request_handler(|access| {
                req.handle_must_succeed(async |()| {
                    request_device_platform_settings_v2(access).await
                })
            });
        }
        Msg::VmgsFlush(req) => {
            self.push_basic_host_request_handler(req, |()| {
                get_protocol::VmgsFlushRequest::new()
            });
        }
        Msg::VmgsGetDeviceInfo(req) => {
            self.push_basic_host_request_handler(req, |()| {
                get_protocol::VmgsGetDeviceInfoRequest::new()
            });
        }
        Msg::GetVtl2SavedStateFromHost(req) => self.push_host_request_handler(|access| {
            req.handle_must_succeed(async |()| request_saved_state(access).await)
        }),
        Msg::GuestStateProtection(req) => {
            self.push_basic_host_request_handler(req, |request| *request);
        }
        Msg::GuestStateProtectionById(req) => {
            self.push_basic_host_request_handler(req, |()| {
                get_protocol::GuestStateProtectionByIdRequest::new()
            });
        }
        Msg::HostTime(req) => {
            self.push_basic_host_request_handler(req, |()| get_protocol::TimeRequest::new());
        }
        // IGVM attest uses a separate queue so it does not hold up the
        // ordinary host-request queue. It snapshots the allocator currently
        // set (if any).
        Msg::IgvmAttest(req) => {
            let shared_pool_allocator = self.gpa_allocator.clone();
            self.push_igvm_attest_request_handler(|access| {
                req.handle_must_succeed(async |request| {
                    request_igvm_attest(access, *request, shared_pool_allocator).await
                })
            });
        }
        Msg::VmgsRead(req) => {
            self.push_host_request_handler(|access| {
                req.handle_must_succeed(async |input| request_vmgs_read(access, input).await)
            });
        }
        Msg::VmgsWrite(req) => {
            self.push_host_request_handler(|access| {
                req.handle_must_succeed(async |input| request_vmgs_write(access, input).await)
            });
        }
        Msg::VpciDeviceControl(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::VpciDeviceControlRequest::new(input.code, input.bus_instance_id)
            });
        }
        Msg::VpciDeviceBindingChange(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::VpciDeviceBindingChangeRequest::new(
                    input.bus_instance_id,
                    input.binding_state,
                )
            });
        }
        Msg::VgaProxyPciRead(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::VgaProxyPciReadRequest::new(input)
            });
        }
        Msg::VgaProxyPciWrite(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::VgaProxyPciWriteRequest::new(input.offset, input.value)
            });
        }
        Msg::MapFramebuffer(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::MapFramebufferRequest::new(input)
            });
        }
        Msg::UnmapFramebuffer(req) => {
            self.push_basic_host_request_handler(req, |()| {
                get_protocol::UnmapFramebufferRequest::new()
            });
        }
        Msg::CreateRamGpaRange(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::CreateRamGpaRangeRequest::new(
                    input.slot,
                    input.gpa_start,
                    input.gpa_count,
                    input.gpa_offset,
                    input.flags,
                )
            });
        }
        Msg::ResetRamGpaRange(req) => {
            self.push_basic_host_request_handler(req, |input| {
                get_protocol::ResetRamGpaRangeRequest::new(input)
            });
        }
        Msg::SendServicingState(req) => self.push_host_request_handler(move |access| {
            req.handle_must_succeed(async |data| {
                request_send_servicing_state(access, data).await
            })
        }),
        // If `complete_start_vtl0` fails, `res` is dropped without being
        // completed and the error propagates out of the loop.
        Msg::CompleteStartVtl0(rpc) => {
            let (input, res) = rpc.split();
            self.complete_start_vtl0(input)?;
            res.complete(());
        }
        // The notification needs no response, but it is still queued as a
        // host request. It is therefore written only after every previously
        // queued host request has completed.
        Msg::PowerState(state) => {
            self.push_host_request_handler(async move |mut access| {
                let message = match state {
                    msg::PowerState::PowerOff => get_protocol::PowerOffNotification::new(false)
                        .as_bytes()
                        .to_vec(),
                    msg::PowerState::Hibernate => get_protocol::PowerOffNotification::new(true)
                        .as_bytes()
                        .to_vec(),
                    msg::PowerState::Reset => {
                        get_protocol::ResetNotification::new().as_bytes().to_vec()
                    }
                };
                access.send_message(message);
                Ok(())
            })
        }
        // The remaining notifications go straight onto the write queue.
        Msg::EventLog(event_log_id) => {
            self.send_message(
                get_protocol::EventLogNotification::new(event_log_id)
                    .as_bytes()
                    .to_vec(),
            );
        }
        Msg::ReportRestoreResultToHost(success) => self.report_restore_result_to_host(success),
        Msg::VtlCrashNotification(crash_notification) => {
            self.send_message(crash_notification.as_bytes().to_vec());
        }
        Msg::TripleFaultNotification(triple_fault_notification) => {
            self.send_message(triple_fault_notification);
        }
    }
    Ok(())
}
/// Dispatches a guest notification read from the pipe to its handler.
///
/// Unknown notification IDs are logged and ignored.
///
/// # Errors
/// Fails on an unsupported header version, a wrongly sized payload, or an
/// error from the specific handler.
fn handle_guest_notification(
    &mut self,
    header: get_protocol::HeaderGuestNotification,
    buf: &[u8],
) -> Result<(), FatalError> {
    use get_protocol::GuestNotifications;

    let version = header.message_version;
    if version != get_protocol::MessageVersions::HEADER_VERSION_1 {
        tracing::error!(
            msg = ?buf,
            version = ?version,
            "invalid header version in guest notification",
        );
        return Err(FatalError::InvalidGuestNotificationVersion(version));
    }

    let kind = header.message_id;
    match kind {
        GuestNotifications::UPDATE_GENERATION_ID => {
            self.handle_update_generation_id(read_guest_notification(kind, buf)?)
        }
        GuestNotifications::SAVE_GUEST_VTL2_STATE => {
            self.handle_save_state_notification(read_guest_notification(kind, buf)?)
        }
        GuestNotifications::MODIFY_VTL2_SETTINGS => {
            self.handle_modify_vtl2_settings_notification(buf)
        }
        GuestNotifications::MODIFY_VTL2_SETTINGS_REV1 => {
            self.handle_modify_vtl2_settings_rev1_notification(buf)
        }
        GuestNotifications::VPCI_DEVICE_NOTIFICATION => {
            self.handle_vpci_device_notification(read_guest_notification(kind, buf)?)
        }
        GuestNotifications::BATTERY_STATUS => {
            self.handle_battery_status_notification(read_guest_notification(kind, buf)?)
        }
        invalid_notification => {
            tracing::error!(
                "[HOST GET] ignoring invalid guest notification: {:?}",
                invalid_notification
            );
            Ok(())
        }
    }
}
/// Routes a host response to the request queue waiting for it.
///
/// `IGVM_ATTEST` responses go to the IGVM attest queue. All other responses
/// go to the ordinary host-request queue.
///
/// # Errors
/// - `NoPendingRequest` if no request of any kind is outstanding, or if a
///   non-IGVM response arrives while only IGVM attest requests are queued.
/// - `NoPendingIgvmAttestRequest` if an `IGVM_ATTEST` response arrives with
///   no IGVM attest request queued.
/// - Header validation errors from `validate_response`.
fn handle_host_response(
    &mut self,
    header: get_protocol::HeaderHostResponse,
    buf: &[u8],
) -> Result<(), FatalError> {
    if self.host_requests.is_empty() && self.igvm_attest_requests.is_empty() {
        return Err(FatalError::NoPendingRequest);
    }

    validate_response(header)?;

    // IGVM_ATTEST responses are consumed by the separate IGVM attest queue.
    if header.message_id == HostRequests::IGVM_ATTEST {
        if self.igvm_attest_requests.is_empty() {
            return Err(FatalError::NoPendingIgvmAttestRequest);
        }
        self.igvm_attest_read_send.send(buf.to_vec());
        return Ok(());
    }

    // Only IGVM attest requests may be outstanding here. Forwarding anyway
    // would park this response in the channel, where the next unrelated host
    // request would pick it up as its own answer.
    if self.host_requests.is_empty() {
        return Err(FatalError::NoPendingRequest);
    }

    self.read_send.send(buf.to_vec());
    Ok(())
}
/// Forwards a new generation ID to its listener, buffering it if no
/// receiver has been taken yet.
///
/// # Errors
/// Returns `TooManyGuestNotifications` when the pre-init buffer is full.
fn handle_update_generation_id(
    &mut self,
    response: get_protocol::UpdateGenerationId,
) -> Result<(), FatalError> {
    let generation_id = response.generation_id;
    let listener = &mut self.guest_notification_listeners.generation_id;
    listener.send(generation_id).map_err(|BufferedSenderFull| {
        FatalError::TooManyGuestNotifications(
            get_protocol::GuestNotifications::UPDATE_GENERATION_ID,
        )
    })
}
/// Converts a save-state notification into a `GuestSaveRequest` and
/// forwards it.
///
/// The deadline is now plus the host's timeout hint in seconds.
///
/// # Errors
/// Returns `TooManyGuestNotifications` when the pre-init buffer is full.
fn handle_save_state_notification(
    &mut self,
    notification_header: get_protocol::SaveGuestVtl2StateNotification,
) -> Result<(), FatalError> {
    let timeout =
        std::time::Duration::from_secs(notification_header.timeout_hint_secs as u64);
    let request = GuestSaveRequest {
        correlation_id: notification_header.correlation_id,
        deadline: std::time::Instant::now() + timeout,
        capabilities_flags: notification_header.capabilities_flags,
    };
    self.guest_notification_listeners
        .save_request
        .send(request)
        .map_err(|BufferedSenderFull| {
            FatalError::TooManyGuestNotifications(
                get_protocol::GuestNotifications::SAVE_GUEST_VTL2_STATE,
            )
        })
}
/// Handles a single-message `MODIFY_VTL2_SETTINGS` notification.
///
/// The fixed header is followed by exactly `size` bytes of settings
/// payload, which are forwarded to the VTL2 settings listener.
///
/// # Errors
/// Fails if `buf` is shorter than the header, if the payload length does not
/// match `size`, or if the listener buffer is full.
fn handle_modify_vtl2_settings_notification(&mut self, buf: &[u8]) -> Result<(), FatalError> {
    let kind = get_protocol::GuestNotifications::MODIFY_VTL2_SETTINGS;
    let Ok((request, payload)) =
        get_protocol::ModifyVtl2SettingsNotification::read_from_prefix(buf)
    else {
        return Err(FatalError::MessageSizeGuestNotification {
            len: buf.len(),
            notification: kind,
        });
    };

    let expected = request.size as usize;
    if payload.len() != expected {
        return Err(FatalError::ModifyVtl2SettingsNotification {
            expected,
            len: payload.len(),
        });
    }

    self.send_vtl2_settings(payload.to_vec(), kind)
}
/// Handles one chunk of a multi-message `MODIFY_VTL2_SETTINGS_REV1`
/// notification.
///
/// Payload chunks accumulate in `vtl2_settings_buf` while `payload_state`
/// is `MORE`. On `END` the full payload is forwarded.
///
/// # Errors
/// Fails if `buf` is shorter than the header, if the chunk length does not
/// match `size`, if `payload_state` is unknown (the accumulated data is then
/// discarded), or if the listener buffer is full.
fn handle_modify_vtl2_settings_rev1_notification(
    &mut self,
    buf: &[u8],
) -> Result<(), FatalError> {
    let kind = get_protocol::GuestNotifications::MODIFY_VTL2_SETTINGS_REV1;
    let Ok((request, chunk)) =
        get_protocol::ModifyVtl2SettingsRev1Notification::read_from_prefix(buf)
    else {
        return Err(FatalError::MessageSizeGuestNotification {
            len: buf.len(),
            notification: kind,
        });
    };

    let expected = request.size as usize;
    if chunk.len() != expected {
        return Err(FatalError::ModifyVtl2SettingsNotification {
            expected,
            len: chunk.len(),
        });
    }

    let mut accumulated = self.vtl2_settings_buf.take().unwrap_or_default();
    accumulated.extend_from_slice(chunk);

    match request.payload_state {
        get_protocol::LargePayloadState::MORE => {
            self.vtl2_settings_buf = Some(accumulated);
            Ok(())
        }
        get_protocol::LargePayloadState::END => self.send_vtl2_settings(accumulated, kind),
        _ => Err(FatalError::InvalidResponse),
    }
}
/// Issues a `ModifyVtl2SettingsRequest` RPC to the VTL2 settings listener.
///
/// The completion future is tracked in `guest_notification_responses`, so
/// `run` can report the outcome back to the host.
///
/// # Errors
/// Returns `TooManyGuestNotifications(kind)` when the listener buffer is full.
fn send_vtl2_settings(
    &mut self,
    vtl2_settings_buf: Vec<u8>,
    kind: get_protocol::GuestNotifications,
) -> Result<(), FatalError> {
    let completion = self
        .guest_notification_listeners
        .vtl2_settings
        .try_call_failable(ModifyVtl2SettingsRequest, vtl2_settings_buf)
        .map_err(|_| FatalError::TooManyGuestNotifications(kind))?;
    let tracked = completion
        .map(GuestNotificationResponse::ModifyVtl2Settings)
        .boxed();
    self.guest_notification_responses.push(tracked);
    Ok(())
}
/// Forwards a VPCI device notification to the listener registered for its
/// bus instance.
///
/// Notifications for buses with no registered listener are dropped.
///
/// # Errors
/// Returns `InvalidResponse` for an unknown notification code. The code is
/// only checked when a listener exists.
fn handle_vpci_device_notification(
    &mut self,
    notification: get_protocol::VpciDeviceNotification,
) -> Result<(), FatalError> {
    tracing::debug!(
        "Received VPCI device notification, bus id = {}, code = {:?}",
        notification.bus_instance_id,
        notification.code
    );
    // The lookup key must be a reference to the bus ID. The original text
    // had a mis-encoded `&not` entity here, which did not compile.
    if let Some(sender) = self
        .guest_notification_listeners
        .vpci
        .get(&notification.bus_instance_id)
    {
        let bus_event = match notification.code {
            get_protocol::VpciDeviceNotificationCode::ENUMERATED => {
                VpciBusEvent::DeviceEnumerated
            }
            get_protocol::VpciDeviceNotificationCode::PREPARE_FOR_REMOVAL => {
                VpciBusEvent::PrepareForRemoval
            }
            _ => return Err(FatalError::InvalidResponse),
        };
        sender.send(bus_event);
    }
    Ok(())
}
/// Relays a host battery status notification to the battery listener as a
/// [`HostBatteryUpdate`].
///
/// The boolean state comes from the notification's flag bits. Rate and
/// capacity values are copied through unchanged.
///
/// # Errors
///
/// Returns [`FatalError::TooManyGuestNotifications`] tagged
/// `BATTERY_STATUS` if the listener's `send` fails.
fn handle_battery_status_notification(
    &mut self,
    response: get_protocol::BatteryStatusNotification,
) -> Result<(), FatalError> {
    let flags = response.flags;
    let update = HostBatteryUpdate {
        battery_present: flags.battery_present(),
        charging: flags.charging(),
        discharging: flags.discharging(),
        rate: response.rate,
        remaining_capacity: response.remaining_capacity,
        max_capacity: response.max_capacity,
        ac_online: flags.ac_online(),
    };

    self.guest_notification_listeners
        .battery_status
        .send(update)
        .map_err(|_| {
            FatalError::TooManyGuestNotifications(
                get_protocol::GuestNotifications::BATTERY_STATUS,
            )
        })
}
/// Reports the outcome of a modify-VTL2-settings request to the host.
///
/// On success, a `SUCCESS` completion with an empty payload is sent. On
/// failure, the error list is logged, serialized to JSON, and appended to a
/// `FAILURE` completion. An RPC channel failure (as opposed to a call
/// failure) is reported as a single `InternalFailure` entry carrying the
/// channel error's text.
///
/// # Errors
///
/// Returns [`FatalError::Vtl2SettingsErrorInfoJson`] if the error list cannot
/// be serialized. Nothing is sent to the host in that case.
fn complete_modify_vtl2_settings(
    &mut self,
    result: Result<(), RpcError<Vec<Vtl2SettingsErrorInfo>>>,
) -> Result<(), FatalError> {
    let (status, errors_json) = match result {
        Ok(()) => (get_protocol::ModifyVtl2SettingsStatus::SUCCESS, None),
        Err(rpc_err) => {
            let errors = match rpc_err {
                RpcError::Call(call_errors) => call_errors,
                RpcError::Channel(channel_err) => vec![Vtl2SettingsErrorInfo::new(
                    underhill_config::Vtl2SettingsErrorCode::InternalFailure,
                    channel_err.to_string(),
                )],
            };
            let errors = Vtl2SettingsErrorInfoVec { errors };
            tracing::error!(
                errors = &errors as &dyn std::error::Error,
                "failed to modify vtl2 settings"
            );
            let json = serde_json::to_string(&errors.errors)
                .map_err(FatalError::Vtl2SettingsErrorInfoJson)?;
            (get_protocol::ModifyVtl2SettingsStatus::FAILURE, Some(json))
        }
    };

    // The wire message is the fixed-size completion header immediately
    // followed by the (possibly empty) JSON error text.
    let errors_bytes: &[u8] = errors_json.as_deref().map(str::as_bytes).unwrap_or_default();
    let notification = get_protocol::ModifyVtl2SettingsCompleteNotification::new(
        status,
        errors_bytes.len() as u32,
    );
    let mut buf = notification.as_bytes().to_vec();
    buf.extend_from_slice(errors_bytes);
    self.send_message(buf);
    Ok(())
}
/// Tells the host whether VTL0 started.
///
/// `None` means success. `Some(msg)` means failure, and the message's UTF-8
/// bytes are appended after the fixed-size completion header. This always
/// returns `Ok(())`.
fn complete_start_vtl0(&mut self, error_msg: Option<String>) -> Result<(), FatalError> {
    let status = match &error_msg {
        Some(_) => get_protocol::StartVtl0Status::FAILURE,
        None => get_protocol::StartVtl0Status::SUCCESS,
    };
    let error_bytes: &[u8] = error_msg.as_deref().map(str::as_bytes).unwrap_or_default();

    let notification =
        get_protocol::StartVtl0CompleteNotification::new(status, error_bytes.len() as u32);
    let mut buf = notification.as_bytes().to_vec();
    buf.extend_from_slice(error_bytes);
    self.send_message(buf);
    Ok(())
}
/// Notifies the host whether restoring VTL2 state succeeded, using a
/// `RestoreGuestVtl2StateHostNotification` with `SUCCESS` or `FAILURE`.
fn report_restore_result_to_host(&mut self, success: bool) {
    let status = match success {
        true => get_protocol::GuestVtl2SaveRestoreStatus::SUCCESS,
        false => get_protocol::GuestVtl2SaveRestoreStatus::FAILURE,
    };
    let notification = get_protocol::RestoreGuestVtl2StateHostNotification::new(status);
    let bytes = notification.as_bytes().to_vec();
    self.send_message(bytes);
}
}
/// Requests the device platform settings (DPSv2) from the host and returns
/// the raw settings payload.
///
/// The host may answer in one of two ways:
///
/// * a single `DEVICE_PLATFORM_SETTINGS_V2` response, which ends the
///   exchange;
/// * a sequence of `DEVICE_PLATFORM_SETTINGS_V2_REV1` chunks, concatenated
///   until one arrives with `LargePayloadState::END`.
///
/// # Errors
///
/// * [`FatalError::MessageSizeHeader`] if a response is too short to contain
///   a host response header.
/// * [`FatalError::MessageSizeHostResponse`] if a response is too short for
///   its message type's fixed header.
/// * [`FatalError::DevicePlatformSettingsV2Payload`] if a payload length
///   disagrees with the advertised size.
/// * [`FatalError::ResponseHeaderMismatchId`] if the host replies with an
///   unexpected message id.
async fn request_device_platform_settings_v2(
    mut access: HostRequestPipeAccess,
) -> Result<Vec<u8>, FatalError> {
    access.send_message(
        get_protocol::DevicePlatformSettingsRequestV2::new()
            .as_bytes()
            .to_vec(),
    );
    let mut result = Vec::new();
    loop {
        let buf = access.recv_response().await;
        // The response is host-provided: report a truncated message as a
        // protocol error rather than panicking.
        let (header, _) = get_protocol::HeaderHostResponse::read_from_prefix(buf.as_slice())
            .map_err(|_| FatalError::MessageSizeHeader(buf.len()))?;
        match header.message_id {
            HostRequests::DEVICE_PLATFORM_SETTINGS_V2 => {
                let (response, remaining) =
                    get_protocol::DevicePlatformSettingsResponseV2::read_from_prefix(
                        buf.as_slice(),
                    )
                    .map_err(|_| FatalError::MessageSizeHostResponse {
                        len: buf.len(),
                        response: HostRequests::DEVICE_PLATFORM_SETTINGS_V2,
                    })?;
                if response.size as usize != remaining.len() {
                    return Err(FatalError::DevicePlatformSettingsV2Payload {
                        expected: response.size as usize,
                        len: remaining.len(),
                    });
                }
                result.extend_from_slice(remaining);
                break;
            }
            HostRequests::DEVICE_PLATFORM_SETTINGS_V2_REV1 => {
                let (response, remaining) =
                    get_protocol::DevicePlatformSettingsResponseV2Rev1::read_from_prefix(
                        buf.as_slice(),
                    )
                    // This is a host *response*, so report it as one, tagged
                    // with the message id actually being parsed.
                    .map_err(|_| FatalError::MessageSizeHostResponse {
                        len: buf.len(),
                        response: HostRequests::DEVICE_PLATFORM_SETTINGS_V2_REV1,
                    })?;
                if remaining.len() != response.size as usize {
                    return Err(FatalError::DevicePlatformSettingsV2Payload {
                        expected: response.size as usize,
                        len: remaining.len(),
                    });
                }
                result.extend_from_slice(remaining);
                // Any payload state other than END keeps collecting chunks.
                if response.payload_state == get_protocol::LargePayloadState::END {
                    break;
                }
            }
            _ => {
                return Err(FatalError::ResponseHeaderMismatchId(
                    header.message_id,
                    HostRequests::DEVICE_PLATFORM_SETTINGS_V2,
                ));
            }
        }
    }
    Ok(result)
}
/// Reads `sector_count` sectors, starting at `sector_offset`, from the host's
/// VMGS backing store.
///
/// Returns one of:
///
/// * `Ok(Ok(data))` on success, where `data` is exactly
///   `sector_count * sector_size` bytes;
/// * `Ok(Err(response))` when the host reports a non-success I/O status (the
///   raw response is handed back so the caller can inspect it);
/// * `Err` on protocol violations.
///
/// # Errors
///
/// * [`FatalError::MessageSizeHostResponse`] if the response is too short, or
///   its payload is not the expected length.
/// * [`FatalError::ResponseHeaderMismatchId`] if the host answers with a
///   different message id.
async fn request_vmgs_read(
    mut access: HostRequestPipeAccess,
    input: msg::VmgsReadInput,
) -> Result<Result<Vec<u8>, get_protocol::VmgsReadResponse>, FatalError> {
    let msg::VmgsReadInput {
        sector_offset,
        sector_count,
        sector_size,
    } = input;
    access.send_message(
        get_protocol::VmgsReadRequest::new(
            get_protocol::VmgsReadFlags::NONE,
            sector_offset,
            sector_count,
        )
        .as_bytes()
        .to_vec(),
    );
    let buf = access.recv_response().await;
    // Widen each factor before multiplying, so a large read cannot overflow
    // the sector fields' own integer type if it is narrower than usize.
    let vmgs_buf_len = sector_count as usize * sector_size as usize;
    let (response, remaining) = get_protocol::VmgsReadResponse::read_from_prefix(buf.as_slice())
        .map_err(|_| FatalError::MessageSizeHostResponse {
            len: buf.len(),
            response: HostRequests::VMGS_READ,
        })?;
    if response.message_header.message_id != HostRequests::VMGS_READ {
        return Err(FatalError::ResponseHeaderMismatchId(
            response.message_header.message_id,
            HostRequests::VMGS_READ,
        ));
    }
    if response.status != get_protocol::VmgsIoStatus::SUCCESS {
        return Ok(Err(response));
    }
    if remaining.len() != vmgs_buf_len {
        return Err(FatalError::MessageSizeHostResponse {
            len: buf.len(),
            response: HostRequests::VMGS_READ,
        });
    }
    Ok(Ok(remaining.to_vec()))
}
/// Writes `input.buf` to the host's VMGS backing store, starting at
/// `input.sector_offset`.
///
/// The request header is followed directly by the data bytes. The sector
/// count sent is `input.buf.len() / input.sector_size`, using integer
/// division.
///
/// NOTE(review): a zero `sector_size` panics on the division. Any trailing
/// partial sector is still sent but is not counted in the header.
///
/// Returns `Ok(Err(response))` if the host reports a non-success status.
async fn request_vmgs_write(
    mut access: HostRequestPipeAccess,
    input: msg::VmgsWriteInput,
) -> Result<Result<(), get_protocol::VmgsWriteResponse>, FatalError> {
    let sector_count = input.buf.len() / input.sector_size as usize;
    let header = get_protocol::VmgsWriteRequest::new(
        get_protocol::VmgsWriteFlags::NONE,
        input.sector_offset,
        sector_count as u32,
    );

    let mut message = header.as_bytes().to_vec();
    message.extend_from_slice(&input.buf);

    let response: get_protocol::VmgsWriteResponse =
        access.send_request_fixed_size(message.as_slice()).await?;
    Ok(if response.status == get_protocol::VmgsIoStatus::SUCCESS {
        Ok(())
    } else {
        Err(response)
    })
}
/// Sends the serialized VTL2 servicing (saved) state to the host and waits
/// for the host's acknowledgement.
///
/// If `result` is `Err`, the host is only told that saving failed; the error
/// text is not forwarded.
///
/// Otherwise the state is streamed as a series of `SaveGuestVtl2StateRequest`
/// messages, each carrying at most `MAX_PAYLOAD_SIZE` payload bytes. Every
/// message except the last is marked `MORE_DATA`, and the last is marked
/// `SUCCESS`. At least one message is always sent, so an empty state still
/// produces a terminating `SUCCESS` message. Sending nothing and then waiting
/// for a reply would never complete the exchange.
///
/// Returns `Ok(Ok(()))` or `Ok(Err(()))` according to the host's reported
/// save status.
///
/// # Errors
///
/// Returns [`FatalError::InvalidResponse`] if the host replies with a status
/// other than `SUCCESS` or `FAILURE`, and propagates errors from sending the
/// failure notice or receiving the host's response.
async fn request_send_servicing_state(
    mut access: HostRequestPipeAccess,
    result: Result<Vec<u8>, String>,
) -> Result<Result<(), ()>, FatalError> {
    let saved_state_buf = match result {
        Ok(saved_state_buf) => saved_state_buf,
        Err(_err) => {
            return access
                .send_failed_save_state(&get_protocol::SaveGuestVtl2StateRequest::new(
                    get_protocol::GuestVtl2SaveRestoreStatus::FAILURE,
                ))
                .await
                .map(Ok);
        }
    };

    let saved_state_size = saved_state_buf.len();
    const HEADER_SIZE: usize = size_of::<get_protocol::SaveGuestVtl2StateRequest>();
    let mut saved_state_bytes_written = 0;
    loop {
        let payload_len = min(
            saved_state_size - saved_state_bytes_written,
            MAX_PAYLOAD_SIZE,
        );
        // This chunk is the last one exactly when it reaches the end of the
        // state (always true for an empty state).
        let is_last = saved_state_bytes_written + payload_len == saved_state_size;
        let status_code = if is_last {
            get_protocol::GuestVtl2SaveRestoreStatus::SUCCESS
        } else {
            get_protocol::GuestVtl2SaveRestoreStatus::MORE_DATA
        };
        let host_request_header = get_protocol::SaveGuestVtl2StateRequest::new(status_code);
        tracing::debug!(
            "More data? {:?} saved_state_bytes_written {} saved_state_size {}, payload_len {}",
            status_code,
            saved_state_bytes_written,
            saved_state_size,
            payload_len
        );

        let mut message = Vec::with_capacity(HEADER_SIZE + payload_len);
        message.extend_from_slice(host_request_header.as_bytes());
        message.extend_from_slice(&saved_state_buf[saved_state_bytes_written..][..payload_len]);
        access.send_message(message);

        saved_state_bytes_written += payload_len;
        if is_last {
            break;
        }
    }

    tracing::debug!("Done writing saved state, awaiting host response");
    let response: get_protocol::SaveGuestVtl2StateResponse = access
        .recv_response_fixed_size(HostRequests::SAVE_GUEST_VTL2_STATE)
        .await?;
    match response.save_status {
        get_protocol::GuestVtl2SaveRestoreStatus::SUCCESS => Ok(Ok(())),
        get_protocol::GuestVtl2SaveRestoreStatus::FAILURE => Ok(Err(())),
        _ => Err(FatalError::InvalidResponse),
    }
}
/// Asks the host for the VTL2 saved state and reassembles it from one or
/// more `RESTORE_GUEST_VTL2_STATE` responses.
///
/// `MORE_DATA` chunks are appended and the loop continues. A `SUCCESS` chunk
/// is appended and completes the state, returning `Ok(Ok(state))`. A
/// `FAILURE` status yields `Ok(Err(()))`.
///
/// # Errors
///
/// * [`FatalError::MessageSizeHostResponse`] if a response is too short for
///   its header.
/// * [`FatalError::ResponseHeaderMismatchId`] on an unexpected message id.
/// * [`FatalError::InvalidResponse`] if `data_length` disagrees with the
///   payload size, or the status is unrecognized.
async fn request_saved_state(
    mut access: HostRequestPipeAccess,
) -> Result<Result<Vec<u8>, ()>, FatalError> {
    let request = get_protocol::RestoreGuestVtl2StateRequest::new(
        get_protocol::GuestVtl2SaveRestoreStatus::REQUEST_DATA,
    );
    access.send_message(request.as_bytes().to_vec());

    let mut state = Vec::<u8>::new();
    loop {
        let message_buf = access.recv_response().await;
        let Ok((response_header, chunk)) =
            get_protocol::RestoreGuestVtl2StateResponse::read_from_prefix(message_buf.as_slice())
        else {
            return Err(FatalError::MessageSizeHostResponse {
                len: message_buf.len(),
                response: HostRequests::RESTORE_GUEST_VTL2_STATE,
            });
        };

        let message_id = response_header.message_header.message_id;
        if message_id != HostRequests::RESTORE_GUEST_VTL2_STATE {
            return Err(FatalError::ResponseHeaderMismatchId(
                message_id,
                HostRequests::RESTORE_GUEST_VTL2_STATE,
            ));
        }
        if chunk.len() != response_header.data_length as usize {
            return Err(FatalError::InvalidResponse);
        }

        match response_header.restore_status {
            get_protocol::GuestVtl2SaveRestoreStatus::SUCCESS => {
                state.extend_from_slice(chunk);
                return Ok(Ok(state));
            }
            get_protocol::GuestVtl2SaveRestoreStatus::MORE_DATA => {
                state.extend_from_slice(chunk);
            }
            get_protocol::GuestVtl2SaveRestoreStatus::FAILURE => return Ok(Err(())),
            _ => return Err(FatalError::InvalidResponse),
        }
    }
}
/// Sends an IGVM attestation request to the host and returns the response
/// bytes.
///
/// A DMA buffer of `request.response_buffer_len` bytes is allocated. Its page
/// addresses, with the allocator's PFN bias masked off, are placed in the
/// request's shared-GPA array. After the host replies with an
/// `IgvmAttestResponse` giving the response length, that many bytes are read
/// back from the start of the buffer.
///
/// # Errors
///
/// The outer `Err` ([`FatalError`]) covers:
///
/// * no GPA allocator;
/// * a DMA allocation failure;
/// * an undecodable response message;
/// * a reported length larger than the buffer.
///
/// The inner `Err` ([`IgvmAttestError`]) covers invalid agent-data or report
/// sizes (from `prepare_igvm_attest_request`) and the host's generic-error
/// sentinel length.
///
/// # Panics
///
/// Panics if the allocation spans more pages than
/// `get_protocol::IGVM_ATTEST_MSG_MAX_SHARED_GPA`.
async fn request_igvm_attest(
    mut access: HostRequestPipeAccess,
    request: msg::IgvmAttestRequestData,
    gpa_allocator: Option<Arc<dyn DmaClient>>,
) -> Result<Result<Vec<u8>, IgvmAttestError>, FatalError> {
    let allocator = gpa_allocator.ok_or(FatalError::GpaAllocatorUnavailable)?;
    let dma_size = request.response_buffer_len;
    let mem = allocator
        .allocate_dma_buffer(dma_size)
        .map_err(FatalError::GpaMemoryAllocationError)?;

    // Strip the PFN bias from each page number before converting it to a GPA.
    let pfn_bias = mem.pfn_bias();
    let gpas = mem
        .pfns()
        .iter()
        .map(|pfn| (pfn & !(pfn_bias)) * hvdef::HV_PAGE_SIZE)
        .collect::<Vec<_>>();
    let mut shared_gpa = [0u64; get_protocol::IGVM_ATTEST_MSG_MAX_SHARED_GPA];
    shared_gpa[..gpas.len()].copy_from_slice(&gpas);

    let request =
        match prepare_igvm_attest_request(shared_gpa, &request.agent_data, &request.report) {
            Ok(request) => request,
            Err(e) => return Ok(Err(e)),
        };
    access.send_message(request.as_bytes().to_vec());

    let response_buf = access.recv_response().await;
    let Ok((response, _)) = get_protocol::IgvmAttestResponse::read_from_prefix(response_buf.as_slice())
    else {
        return Err(FatalError::DeserializeIgvmAttestResponse);
    };

    let response_length = response.length as usize;
    if response_length == get_protocol::IGVM_ATTEST_VMWP_GENERIC_ERROR_CODE {
        return Ok(Err(IgvmAttestError::IgvmAgentGenericError));
    }
    if response_length > dma_size {
        return Err(FatalError::InvalidIgvmAttestResponseSize {
            response_size: response_length,
            maximum_size: dma_size,
        });
    }

    // Copy out only the bytes the host reported, not the whole DMA buffer.
    let mut buffer = vec![0u8; response_length];
    mem.read_at(0, &mut buffer);
    Ok(Ok(buffer))
}
/// Builds an `IgvmAttestRequest` from the shared-GPA array, agent data and
/// attestation report.
///
/// `agent_data` and `report` are copied into zero-padded fixed-size arrays,
/// and their real lengths are recorded alongside them.
///
/// NOTE(review): the GPA count passed to the request is `shared_gpa.len()`,
/// which is always the array capacity `IGVM_ATTEST_MSG_MAX_SHARED_GPA`, not
/// the number of entries the caller actually populated. Confirm the host
/// expects this.
///
/// # Errors
///
/// * [`IgvmAttestError::InvalidAgentDataSize`] if `agent_data` exceeds
///   `IGVM_ATTEST_MSG_REQ_AGENT_DATA_MAX_SIZE`.
/// * [`IgvmAttestError::InvalidReportSize`] if `report` exceeds
///   `IGVM_ATTEST_MSG_REQ_REPORT_MAX_SIZE`.
fn prepare_igvm_attest_request(
    shared_gpa: [u64; get_protocol::IGVM_ATTEST_MSG_MAX_SHARED_GPA],
    agent_data: &[u8],
    report: &[u8],
) -> Result<get_protocol::IgvmAttestRequest, IgvmAttestError> {
    use get_protocol::IGVM_ATTEST_MSG_REQ_AGENT_DATA_MAX_SIZE as AGENT_DATA_MAX;
    use get_protocol::IGVM_ATTEST_MSG_REQ_REPORT_MAX_SIZE as REPORT_MAX;

    if agent_data.len() > AGENT_DATA_MAX {
        return Err(IgvmAttestError::InvalidAgentDataSize {
            input_size: agent_data.len(),
            expected_size: AGENT_DATA_MAX,
        });
    }
    if report.len() > REPORT_MAX {
        return Err(IgvmAttestError::InvalidReportSize {
            input_size: report.len(),
            expected_size: REPORT_MAX,
        });
    }

    let mut padded_agent_data = [0u8; AGENT_DATA_MAX];
    padded_agent_data[..agent_data.len()].copy_from_slice(agent_data);
    let mut padded_report = [0u8; REPORT_MAX];
    padded_report[..report.len()].copy_from_slice(report);

    let shared_gpa_count = shared_gpa.len() as u32;
    Ok(get_protocol::IgvmAttestRequest::new(
        shared_gpa,
        shared_gpa_count,
        padded_agent_data,
        agent_data.len() as u32,
        padded_report,
        report.len() as u32,
    ))
}