1#![forbid(unsafe_code)]
7
8pub use hyperv_ic_protocol::shutdown::INSTANCE_ID;
9pub use hyperv_ic_protocol::shutdown::INTERFACE_ID;
10
11use guid::Guid;
12use hyperv_ic_protocol::FRAMEWORK_VERSION_1;
13use hyperv_ic_protocol::FRAMEWORK_VERSION_3;
14use hyperv_ic_protocol::Status;
15use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_1;
16use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3;
17use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_1;
18use hyperv_ic_protocol::shutdown::SHUTDOWN_VERSION_3_2;
19use hyperv_ic_resources::shutdown::ShutdownParams;
20use hyperv_ic_resources::shutdown::ShutdownResult;
21use hyperv_ic_resources::shutdown::ShutdownType;
22use inspect::Inspect;
23use inspect::InspectMut;
24use mesh::rpc::Rpc;
25use mesh::rpc::RpcSend;
26use std::io::IoSlice;
27use std::mem::size_of_val;
28use task_control::Cancelled;
29use task_control::StopTask;
30use thiserror::Error;
31use vmbus_async::async_dgram::AsyncRecvExt;
32use vmbus_async::async_dgram::AsyncSendExt;
33use vmbus_async::pipe::MessagePipe;
34use vmbus_channel::RawAsyncChannel;
35use vmbus_channel::channel::ChannelOpenError;
36use vmbus_relay_intercept_device::OfferResponse;
37use vmbus_relay_intercept_device::SaveRestoreSimpleVmbusClientDevice;
38use vmbus_relay_intercept_device::SimpleVmbusClientDevice;
39use vmbus_relay_intercept_device::SimpleVmbusClientDeviceAsync;
40use vmbus_relay_intercept_device::ring_buffer::MemoryBlockRingBuffer;
41use vmbus_ring::RingMem;
42use vmcore::save_restore::NoSavedState;
43use zerocopy::FromBytes;
44use zerocopy::FromZeros;
45use zerocopy::IntoBytes;
46
/// Device state for the shutdown integration component (IC).
///
/// Owns both halves of a mesh RPC channel for shutdown requests. The channel
/// runner sends each request received from the host on
/// `send_shutdown_notification` and awaits the [`ShutdownResult`] reply. The
/// receiving half is handed out once by
/// [`ShutdownGuestIc::get_shutdown_notifier`].
#[derive(InspectMut)]
pub struct ShutdownGuestIc {
    /// Sending half of the shutdown RPC channel, used when a shutdown message
    /// arrives from the host.
    #[inspect(skip)]
    send_shutdown_notification: mesh::Sender<Rpc<ShutdownParams, ShutdownResult>>,
    /// Receiving half of the shutdown RPC channel; `None` once it has been
    /// taken by `get_shutdown_notifier`.
    #[inspect(skip)]
    recv_shutdown_notification: Option<mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>>>,
}
55
/// Protocol state of an open shutdown IC channel.
#[derive(Debug, Inspect)]
#[inspect(tag = "channel_state")]
enum ShutdownGuestChannelState {
    /// Waiting for a successful version negotiation. Shutdown messages that
    /// arrive in this state are logged as unrecognized and get no reply.
    NegotiateVersion,
    /// Versions have been negotiated and shutdown messages are processed.
    Running {
        /// Negotiated IC framework version, echoed in shutdown replies.
        #[inspect(display)]
        framework_version: hyperv_ic_protocol::Version,
        /// Negotiated shutdown message version, echoed in shutdown replies.
        #[inspect(display)]
        message_version: hyperv_ic_protocol::Version,
    },
}
67
/// Per-channel runner for the shutdown IC, created when the channel is opened.
#[derive(InspectMut)]
pub struct ShutdownGuestChannel {
    /// Current negotiation state of the channel.
    state: ShutdownGuestChannelState,
    /// Message pipe over the channel's ring buffers; host messages are read
    /// from it and replies are written to it.
    #[inspect(mut)]
    pipe: MessagePipe<MemoryBlockRingBuffer>,
}
77
/// Errors produced while servicing the shutdown IC channel.
#[derive(Debug, Error)]
enum Error {
    /// Sending to or receiving from the message pipe failed.
    #[error("ring buffer error")]
    Ring(#[source] std::io::Error),
    /// A host message was too short for the structure being parsed from it.
    #[error("truncated message")]
    TruncatedMessage,
}
85
86impl ShutdownGuestIc {
87 pub fn new() -> Self {
89 let (send_shutdown_notification, recv_shutdown_notification) = mesh::channel();
90 Self {
91 send_shutdown_notification,
92 recv_shutdown_notification: Some(recv_shutdown_notification),
93 }
94 }
95
96 pub fn get_shutdown_notifier(&mut self) -> mesh::Receiver<Rpc<ShutdownParams, ShutdownResult>> {
98 self.recv_shutdown_notification
99 .take()
100 .expect("can only be called once")
101 }
102}
103
104impl ShutdownGuestChannel {
105 fn new(pipe: MessagePipe<MemoryBlockRingBuffer>) -> Self {
106 Self {
107 state: ShutdownGuestChannelState::NegotiateVersion,
108 pipe,
109 }
110 }
111
112 async fn process(&mut self, ic: &mut ShutdownGuestIc) -> Result<(), Error> {
113 loop {
114 match read_from_pipe(&mut self.pipe).await {
115 Ok(buf) => {
116 self.handle_host_message(&buf, ic).await;
117 }
118 Err(err) => {
119 tracelimit::error_ratelimited!(
120 err = &err as &dyn std::error::Error,
121 "reading shutdown packet from host",
122 );
123 }
124 }
125 }
126 }
127
128 async fn handle_host_message(&mut self, buf: &[u8], ic: &ShutdownGuestIc) {
129 let (header, rest) = match hyperv_ic_protocol::Header::read_from_prefix(buf).ok() {
131 Some((h, r)) => (h, r),
132 None => {
133 tracelimit::error_ratelimited!("invalid shutdown packet from host",);
134 return;
135 }
136 };
137 match header.message_type {
138 hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION => {
139 self.state = ShutdownGuestChannelState::NegotiateVersion;
143 if let Err(err) = self.handle_version_negotiation(&header, rest).await {
144 tracelimit::error_ratelimited!(
145 err = &err as &dyn std::error::Error,
146 "Failed version negotiation"
147 );
148 }
149 }
150 hyperv_ic_protocol::MessageType::SHUTDOWN
151 if matches!(self.state, ShutdownGuestChannelState::Running { .. }) =>
152 {
153 if let Err(err) = self.handle_shutdown_notification(&header, rest, ic).await {
154 tracelimit::error_ratelimited!(
155 err = &err as &dyn std::error::Error,
156 "Failed processing shutdown message"
157 );
158 }
159 }
160 _ => {
161 tracelimit::error_ratelimited!(r#type = ?header.message_type, state = ?self.state, "Unrecognized packet");
162 }
163 }
164 }
165
166 fn find_latest_supported_version<'a>(
167 buf: &'a [u8],
168 count: usize,
169 supported: &[hyperv_ic_protocol::Version],
170 ) -> (Option<hyperv_ic_protocol::Version>, &'a [u8]) {
171 let mut rest = buf;
172 let mut next_version;
173 let mut latest_version = None;
174 for _ in 0..count {
175 (next_version, rest) = match hyperv_ic_protocol::Version::read_from_prefix(rest).ok() {
177 Some((n, r)) => (n, r),
178 None => {
179 tracelimit::error_ratelimited!("truncated message version list");
180 return (latest_version, rest);
181 }
182 };
183 for known in supported {
184 if known.major == next_version.major && known.minor == next_version.minor {
185 if latest_version.is_some() {
186 if next_version.major >= latest_version.unwrap().major {
187 if next_version.major > latest_version.unwrap().major
188 || next_version.minor > latest_version.unwrap().minor
189 {
190 latest_version = Some(next_version);
191 }
192 }
193 } else {
194 latest_version = Some(next_version);
195 }
196 }
197 }
198 }
199 (latest_version, rest)
200 }
201
202 async fn handle_version_negotiation(
203 &mut self,
204 header: &hyperv_ic_protocol::Header,
205 msg: &[u8],
206 ) -> Result<(), Error> {
207 const FRAMEWORK_VERSIONS: &[hyperv_ic_protocol::Version] =
208 &[FRAMEWORK_VERSION_1, FRAMEWORK_VERSION_3];
209
210 const SHUTDOWN_VERSIONS: &[hyperv_ic_protocol::Version] = &[
211 SHUTDOWN_VERSION_1,
212 SHUTDOWN_VERSION_3,
213 SHUTDOWN_VERSION_3_1,
214 SHUTDOWN_VERSION_3_2,
215 ];
216
217 let (prefix, rest) = hyperv_ic_protocol::NegotiateMessage::read_from_prefix(msg)
218 .map_err(|_| Error::TruncatedMessage)?; let (latest_framework_version, rest) = Self::find_latest_supported_version(
220 rest,
221 prefix.framework_version_count as usize,
222 FRAMEWORK_VERSIONS,
223 );
224 let framework_version = if let Some(version) = latest_framework_version {
225 version
226 } else {
227 tracelimit::error_ratelimited!("Unsupported framework version");
228 FRAMEWORK_VERSIONS[FRAMEWORK_VERSIONS.len() - 1]
229 };
230 let (latest_message_version, _) = Self::find_latest_supported_version(
231 rest,
232 prefix.message_version_count as usize,
233 SHUTDOWN_VERSIONS,
234 );
235 let message_version = if let Some(version) = latest_message_version {
236 version
237 } else {
238 tracelimit::error_ratelimited!("Unsupported message version");
239 SHUTDOWN_VERSIONS[SHUTDOWN_VERSIONS.len() - 1]
240 };
241
242 let message = hyperv_ic_protocol::NegotiateMessage {
243 framework_version_count: 1,
244 message_version_count: 1,
245 ..FromZeros::new_zeroed()
246 };
247 let response = hyperv_ic_protocol::Header {
248 message_type: hyperv_ic_protocol::MessageType::VERSION_NEGOTIATION,
249 message_size: (size_of_val(&message)
250 + size_of_val(&framework_version)
251 + size_of_val(&message_version)) as u16,
252 status: Status::SUCCESS,
253 transaction_id: header.transaction_id,
254 flags: hyperv_ic_protocol::HeaderFlags::new()
255 .with_transaction(header.flags.transaction())
256 .with_response(true),
257 ..FromZeros::new_zeroed()
258 };
259 self.pipe
260 .send_vectored(&[
261 IoSlice::new(response.as_bytes()),
262 IoSlice::new(message.as_bytes()),
263 IoSlice::new(framework_version.as_bytes()),
264 IoSlice::new(message_version.as_bytes()),
265 ])
266 .await
267 .map_err(Error::Ring)?;
268
269 tracing::info!(%framework_version, %message_version, "version negotiated");
270 self.state = ShutdownGuestChannelState::Running {
271 framework_version,
272 message_version,
273 };
274 Ok(())
275 }
276
277 async fn handle_shutdown_notification(
278 &mut self,
279 header: &hyperv_ic_protocol::Header,
280 buf: &[u8],
281 ic: &ShutdownGuestIc,
282 ) -> Result<(), Error> {
283 let ShutdownGuestChannelState::Running {
284 framework_version,
285 message_version,
286 } = &self.state
287 else {
288 panic!("Shutdown message processing while in invalid state");
289 };
290
291 let message = hyperv_ic_protocol::shutdown::ShutdownMessage::read_from_prefix(buf)
292 .map_err(|_| Error::TruncatedMessage)?
293 .0; let shutdown_type = if message.flags.restart() {
295 ShutdownType::Reboot
296 } else if message.flags.hibernate() {
297 ShutdownType::Hibernate
298 } else {
299 ShutdownType::PowerOff
300 };
301 let params = ShutdownParams {
302 shutdown_type,
303 force: message.flags.force(),
304 };
305
306 let status = match ic.send_shutdown_notification.call(|x| x, params).await {
308 Ok(ShutdownResult::Ok) => Status::SUCCESS,
309 Ok(ShutdownResult::Failed(x)) => Status(x),
310 Ok(ShutdownResult::NotReady) | Ok(ShutdownResult::AlreadyInProgress) | Err(_) => {
311 Status::FAIL
312 }
313 };
314
315 let response = hyperv_ic_protocol::Header {
317 framework_version: *framework_version,
318 message_version: *message_version,
319 message_type: hyperv_ic_protocol::MessageType::SHUTDOWN,
320 message_size: 0,
321 status,
322 transaction_id: header.transaction_id,
323 flags: hyperv_ic_protocol::HeaderFlags::new()
324 .with_transaction(header.flags.transaction())
325 .with_response(true),
326 ..FromZeros::new_zeroed()
327 };
328 self.pipe
329 .send(response.as_bytes())
330 .await
331 .map_err(Error::Ring)
332 }
333}
334
335async fn read_from_pipe<T: RingMem>(pipe: &mut MessagePipe<T>) -> Result<Vec<u8>, Error> {
336 let mut buf = vec![0; hyperv_ic_protocol::MAX_MESSAGE_SIZE];
337 let n = pipe.recv(&mut buf).await.map_err(Error::Ring)?;
338 let buf = &buf[..n];
339 Ok(buf.to_vec())
340}
341
342impl SimpleVmbusClientDevice for ShutdownGuestIc {
343 type SavedState = NoSavedState;
344 type Runner = ShutdownGuestChannel;
345
346 fn instance_id(&self) -> Guid {
347 INSTANCE_ID
348 }
349
350 fn offer(&self, _offer: &vmbus_core::protocol::OfferChannel) -> OfferResponse {
351 OfferResponse::Open
352 }
353
354 fn inspect(&mut self, req: inspect::Request<'_>, runner: Option<&mut Self::Runner>) {
355 req.respond().merge(self).merge(runner);
356 }
357
358 fn open(
359 &mut self,
360 _channel_idx: u16,
361 channel: RawAsyncChannel<MemoryBlockRingBuffer>,
362 ) -> Result<Self::Runner, ChannelOpenError> {
363 let pipe = MessagePipe::new(channel)?;
364 Ok(ShutdownGuestChannel::new(pipe))
365 }
366
367 fn close(&mut self, _channel_idx: u16) {}
368
369 fn supports_save_restore(
370 &mut self,
371 ) -> Option<
372 &mut dyn SaveRestoreSimpleVmbusClientDevice<
373 SavedState = Self::SavedState,
374 Runner = Self::Runner,
375 >,
376 > {
377 None
378 }
379}
380
381impl SimpleVmbusClientDeviceAsync for ShutdownGuestIc {
382 async fn run(
383 &mut self,
384 stop: &mut StopTask<'_>,
385 runner: &mut Self::Runner,
386 ) -> Result<(), Cancelled> {
387 stop.until_stopped(async {
388 match runner.process(self).await {
389 Ok(()) => {}
390 Err(err) => {
391 tracing::error!(
392 error = &err as &dyn std::error::Error,
393 "shutdown ic relay error"
394 )
395 }
396 }
397 })
398 .await
399 }
400}